import geopandas as gpd
from typing import List, Union
import pandas as pd
import numpy as np
from shapely.geometry import LineString, Point
from tqdm import tqdm

def raw_rp_dtw(rp, y, dist_func=None, return_path_cost_matrix=False, return_moves_matrices=False, verbose=False, band_width=np.inf, export_path_cost_matrix=None, tqdm_text="Warping..."):
    """Reference-path dynamic time warping (raw implementation from Vaclav).

    Every sample of the reference signal ``rp`` is matched to exactly one
    sample of ``y``. The only allowed moves are "advance rp, keep y" and
    "advance both", so the warping path visits each row of the cost matrix
    exactly once. The returned index array therefore has one entry per ``rp``
    sample.

    :param rp: reference signal, shape (N,) or (N, D)
    :param y: signal to warp, shape (M,) or (M, D); reshaped to 2-D together
        with ``rp`` when ``rp`` is 1-D
    :param dist_func: callable ``(a, b) -> float`` comparing two samples;
        Euclidean distance if None
    :param return_path_cost_matrix: also return the (N, M) accumulated cost matrix
    :param return_moves_matrices: also return the (N, M) int8 back-pointer
        matrices (x moves, y moves); cells that were never reached hold 0
    :param verbose: show a tqdm progress bar over the rows of ``rp``
    :param band_width: width of the search band around the straight line from
        (0, 0) to (N-1, M-1); ``np.inf`` disables the band
    :param export_path_cost_matrix: optional file path; if given, the cost
        matrix ('M') and the indices ('idx') are written with ``scipy.io.savemat``
    :param tqdm_text: description shown in the progress bar
    :return: ``iy`` (int array of length N: the index into ``y`` for each
        ``rp`` sample), followed by the cost matrix and/or the two move
        matrices when requested
    :raises IndexError: if no admissible warping path reaches (N-1, M-1),
        e.g. when M > N or the band excludes every path
    """
    if dist_func is None:
        dist_func = lambda a, b: np.sqrt(np.sum((a - b) ** 2))  # Euclidean distance

    if rp.ndim == 1:
        rp = rp.reshape(-1, 1)
        y = y.reshape(-1, 1)

    N = rp.shape[0]
    M = y.shape[0]

    path_cost_matrix = np.full((N, M), np.inf)
    # Back-pointers. Every computed cell gets an x-move of -1, so 0 marks a
    # cell that was never reached. (Filling an int8 array with NaN, as done
    # before, is an invalid cast with an unspecified result.)
    best_moves_x = np.zeros((N, M), dtype=np.int8)
    best_moves_y = np.zeros((N, M), dtype=np.int8)

    # Band half-widths are loop invariant.
    half_band_lo = np.floor(band_width / 2)
    half_band_hi = np.ceil(band_width / 2)

    path_cost_matrix[0, 0] = dist_func(rp[0, :], y[0, :])

    for ii in tqdm(range(N), desc=tqdm_text, disable=(not verbose)):
        # Column on the straight line (0, 0) -> (N-1, M-1) for this row.
        # Guard N == 1, which previously raised ZeroDivisionError.
        center = ii * (M - 1) / (N - 1) if N > 1 else 0
        jj_min = np.max([0, center - half_band_lo])
        jj_max = np.min([M, center + half_band_hi])
        for jj in np.arange(start=jj_min, stop=jj_max, step=1, dtype=int):
            # (0, 0) is the fixed start. Cells with jj > ii are unreachable
            # because every move advances rp by exactly one sample.
            if (ii == 0 and jj == 0) or jj > ii:
                continue

            # ii >= 1 here, so the "right" predecessor always exists. An "up"
            # move (advance y, keep rp) is never allowed.
            cost_move_right = path_cost_matrix[ii - 1, jj]  # advance rp, keep y fixed
            cost_move_diag = path_cost_matrix[ii - 1, jj - 1] if jj > 0 else np.inf  # advance both

            # argmin keeps the original tie-breaking: "right" wins ties.
            if np.argmin((cost_move_right, cost_move_diag)) == 0:
                best_moves_x[ii, jj] = -1  # came from (ii-1, jj)
                best_moves_y[ii, jj] = 0
                best_move_cost = cost_move_right
            else:
                best_moves_x[ii, jj] = -1  # came from (ii-1, jj-1)
                best_moves_y[ii, jj] = -1
                best_move_cost = cost_move_diag

            path_cost_matrix[ii, jj] = dist_func(rp[ii, :], y[jj, :]) + best_move_cost

    # Backtrack from (N-1, M-1) to (0, 0) along the stored moves. Each step
    # decrements the rp index by one, so the path has exactly N cells.
    iy = np.empty(N, dtype=int)
    xidx, yidx = N - 1, M - 1
    iy[xidx] = yidx
    while xidx > 0 or yidx > 0:
        # Cast to Python int: under NumPy 2 promotion, int + np.int8 stays
        # int8 and overflows for indices > 127.
        step_x = int(best_moves_x[xidx, yidx])
        step_y = int(best_moves_y[xidx, yidx])
        if step_x == 0:
            # Unreached cell: no admissible path ends at (N-1, M-1).
            # IndexError is kept because warp_geodataframe catches it.
            raise IndexError(
                f"No warping path reaches the end of both signals "
                f"(N={N}, M={M}, band_width={band_width})"
            )
        xidx += step_x
        yidx += step_y
        iy[xidx] = yidx

    if export_path_cost_matrix is not None:
        from scipy.io import savemat
        savemat(export_path_cost_matrix, {'M': path_cost_matrix, 'idx': iy})

    outputs = (iy,)
    if return_path_cost_matrix:
        outputs += (path_cost_matrix,)
    if return_moves_matrices:
        outputs += (best_moves_x, best_moves_y)
    return outputs if len(outputs) > 1 else iy

def warp_geodataframe(
        reference_path : gpd.GeoSeries,
        measurement_series : List[gpd.GeoDataFrame],
        resampling_distance=None, used_crs="EPSG:31287",
        band_width=np.inf, verbose=True
    ) -> List[gpd.GeoDataFrame]:
    """Reindex a list of measurements to a reference path by applying dynamic time warping.

    Each measurement series is warped onto the reference path with
    ``raw_rp_dtw``. Each output row corresponds to one reference-path
    position and carries the measurement row matched to it.

    :param reference_path: geopandas series in EPSG:4326, either consisting of
        multiple points or a single LineString (a LineString requires
        ``resampling_distance``)
    :param measurement_series: list of geopandas dataframes whose geometry is
        a sequence of points in EPSG:4326
    :param resampling_distance: distance in units of ``used_crs`` (meters for
        the default) used to resample the reference path. If None, no
        resampling is done.
    :param used_crs: the coordinate system used for the distance calculation
    :param band_width: band_width parameter for the dynamic time warping algorithm
    :param verbose: show a progress bar during warping
    :return: a list of geodataframes (EPSG:4326) reindexed at the reference
        path positions; the previous measurement index is stored under the
        column ``original_index``
    :raises ValueError: if an input is not in EPSG:4326, if more than one
        LineString is given, or if a LineString is given without a
        resampling distance
    :raises IndexError: if warping fails for a measurement series
    """
    if any(series.crs != "EPSG:4326" for series in [reference_path] + measurement_series):
        raise ValueError("Pass GeoSeries in EPSG:4326!")

    # A LineString reference path must be a single geometry and must be resampled.
    linestring_input = any(isinstance(el, LineString) for el in reference_path)
    if linestring_input:
        if len(reference_path) != 1:
            raise ValueError("Only a single LineString can be used for the reference path")
        if resampling_distance is None:
            raise ValueError("Resampling distance cannot be None for LineString Reference Path")

    # Project to the metric CRS used for the distance calculation.
    reference_path = reference_path.to_crs(used_crs)
    measurement_series = [series.to_crs(used_crs) for series in measurement_series]

    if resampling_distance is not None:
        if linestring_input:
            reference_path_ls = reference_path.iloc[0]
        else:
            # Point input: connect the points in order. Pass a plain list,
            # since shapely cannot build a LineString from the GeoSeries array.
            reference_path_ls = LineString(list(reference_path))
        # The crs is required for the final to_crs("EPSG:4326") of the outputs.
        reference_path = gpd.GeoSeries(
            [reference_path_ls.interpolate(dist)
             for dist in np.arange(0, reference_path_ls.length, resampling_distance)],
            crs=used_crs,
        )

    # Positional RangeIndex, so the geometry assignment below aligns row by row
    # with the reset_index()-ed warped frames.
    reference_path = reference_path.reset_index(drop=True)
    reference_xy = np.column_stack((reference_path.x, reference_path.y))

    try:
        index_array = [
            raw_rp_dtw(
                reference_xy,
                np.column_stack((series.geometry.x, series.geometry.y)),
                band_width=band_width,
                verbose=verbose,
                tqdm_text=f"Warping {i+1}/{len(measurement_series)}",
            ) for i, series in enumerate(measurement_series)
        ]
    except IndexError as err:
        raise IndexError("Encountered Index Error - Choose smaller resampling distance or remove low speed measurement sequences") from err

    warped_measurement_series = []
    for series, indices in zip(measurement_series, index_array):
        # raw_rp_dtw returns positions, not labels, so select with iloc. The
        # previous index label is kept under 'original_index', whether or not
        # the index was named.
        warped_series = series.iloc[indices].rename_axis('original_index').reset_index()

        # Assign the reference path positions as the new geometry.
        warped_series.geometry = reference_path
        warped_measurement_series.append(warped_series.to_crs("EPSG:4326"))

    return warped_measurement_series

