from netCDF4 import Dataset
import numpy as np
import datetime
from more_itertools import locate
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from cartopy.io import shapereader
from matplotlib.cm import get_cmap
import cartopy as cart
import calendar 
from dateutil.relativedelta import relativedelta
from cartopy.io.shapereader import Reader
from cartopy.feature import ShapelyFeature
import shapefile as shp



class GraceMaps:
    """Load a GRACE liquid-water-equivalent (TWS) NetCDF file and map it.

    Typical use: call ``read_tws``, ``read_tws_time`` and
    ``read_tws_coordinates`` to load the field, its time axis and its grid.
    Then call ``basin_mask`` (outlines drawn by ``plot_data``) and
    ``basin_mask_grace`` (masks averaged by ``plot_timeserie``).
    """

    # Default directories for the basin-mask files; can be overridden per
    # call through the ``mask_dir`` argument of the mask loaders.
    MASK_DIR = '/data/dmdpesq/bianca.maske/scripts/mascara_bacias/new'
    GRACE_MASK_DIR = '/data/dmdpesq/bianca.maske/dados/grace/csr'

    def __init__(self, file):
        """Remember the path of the GRACE NetCDF file; nothing is read yet."""
        super().__init__()
        self.file = file

    def read_tws(self):
        """Store the ``lwe_thickness`` variable as ``self.tws``.

        The Dataset is not closed here. ``self.tws`` is the netCDF4 variable
        itself, not an in-memory copy, and the other methods slice it on
        demand.
        """
        ff = Dataset(self.file)
        self.tws = ff['lwe_thickness']

    def read_tws_time(self):
        """Build ``self.time``, one ``datetime`` per time step.

        ``self.ndays`` keeps the raw ``time`` variable. Each value is treated
        as a (possibly fractional) day offset from 2002-01-01 00:00.
        NOTE(review): the base date is hard-coded rather than read from the
        variable's ``units`` attribute; confirm that it matches the file.
        """
        ff = Dataset(self.file)
        base = datetime.datetime(2002, 1, 1, 0)
        self.ndays = ff['time']
        self.time = [base + datetime.timedelta(days=float(x)) for x in self.ndays[:]]

    def read_tws_coordinates(self):
        """Store the ``lat`` and ``lon`` variables as ``self.lat`` / ``self.lon``."""
        ff = Dataset(self.file)
        self.lat = ff['lat']
        self.lon = ff['lon']

    def longterm_mean(self):
        """Return the time mean of ``self.tws`` (averaged over axis 0)."""
        return np.mean(self.tws[:], axis=0)

    def month_climatology(self):
        """Return the mean TWS field for each calendar month.

        Returns:
            Array of shape ``(12, ny, nx)``; index 0 is January. A month with
            no time step in the record is filled with NaN.
        """
        monclim_tws = np.empty((12, self.tws.shape[1], self.tws.shape[2]))
        months = [t.month for t in self.time]
        for mm in range(1, 13):
            indices = [i for i, m in enumerate(months) if m == mm]
            if not indices:
                # Don't index with an empty list; an absent month has no mean.
                monclim_tws[mm - 1] = np.nan
                continue
            monclim_tws[mm - 1, :, :] = np.mean(self.tws[indices, :, :], axis=0)
        return monclim_tws

    def years_average(self, step_y, start=2002, end=2020):
        """Average TWS over consecutive blocks of ``step_y`` calendar years.

        Block start years are ``start, start + step_y, ...`` while they do
        not exceed ``end``. A block runs from the first time step of its start
        year up to (not including) the first time step of the next block. The
        last block runs to the end of the record, which may extend past
        ``end``. Assumes ``self.time`` is in chronological order.

        Args:
            step_y: block length in years; must be positive.
            start: first block-start year.
            end: last year a block may start in; also appended as the
                closing bound of the returned year list.

        Returns:
            ``(datelist, year_avg)``. ``datelist`` holds the block start years
            followed by ``end``. ``year_avg`` has shape ``(n_blocks, ny, nx)``.

        Raises:
            ValueError: if ``step_y`` is not positive (the block loop would
                never terminate) or a block start year has no time step.
        """
        if step_y <= 0:
            raise ValueError(f'step_y must be positive, got {step_y!r}')

        grace_years = [t.year for t in self.time]
        datelist = [start]
        while datelist[-1] + step_y <= end:
            datelist.append(datelist[-1] + step_y)

        date_idx = []
        for year in datelist:
            print(year)
            try:
                date_idx.append(grace_years.index(year))
            except ValueError:
                raise ValueError(f'no time step found for year {year}') from None

        year_avg = np.empty((len(date_idx), self.tws.shape[1], self.tws.shape[2]))
        print(year_avg.shape)
        # Pair each block start with the next block's start; None means
        # "through the end of the record" for the final block.
        stops = date_idx[1:] + [None]
        for t, (first, stop) in enumerate(zip(date_idx, stops)):
            year_avg[t] = np.mean(self.tws[first:stop], axis=0)
        datelist.append(end)
        print(datelist)
        print(date_idx)
        return datelist, year_avg

    @staticmethod
    def _read_mask(path, var_name):
        """Read one mask file and close it; return ``(mask * -1, lat, lon)``."""
        # The sign flip mirrors how the masks have always been loaded.
        # NOTE(review): presumably the files store basin cells as negative
        # values; confirm.
        with Dataset(path) as mask_file:
            return (mask_file[var_name][:, :] * -1,
                    mask_file['lat'][:],
                    mask_file['lon'][:])

    def basin_mask(self, mask_dir=None):
        """Load the PA, NE, NO, AM and SU basin masks drawn by ``plot_data``.

        For each basin code ``b`` this sets ``self.<b>msk``, ``self.<b>_lat``
        and ``self.<b>_lon``.

        Args:
            mask_dir: directory containing ``PA.nc`` ... ``SU.nc``; defaults
                to ``MASK_DIR``.
        """
        d = self.MASK_DIR if mask_dir is None else mask_dir
        self.pamsk, self.pa_lat, self.pa_lon = self._read_mask(f'{d}/PA.nc', 'pamsk')
        self.nemsk, self.ne_lat, self.ne_lon = self._read_mask(f'{d}/NE.nc', 'nemsk')
        self.nomsk, self.no_lat, self.no_lon = self._read_mask(f'{d}/NO.nc', 'nomsk')
        self.ammsk, self.am_lat, self.am_lon = self._read_mask(f'{d}/AM.nc', 'ammsk')
        self.sumsk, self.su_lat, self.su_lon = self._read_mask(f'{d}/SU.nc', 'sumsk')

    def basin_mask_grace(self, mask_dir=None):
        """Load the basin masks averaged by ``plot_timeserie``.

        For each basin code ``b`` this sets ``self.<b>msk_grace``,
        ``self.<b>_lat_grace`` and ``self.<b>_lon_grace``.

        Args:
            mask_dir: directory containing ``PA_grace.nc`` ...
                ``SU_grace.nc``; defaults to ``GRACE_MASK_DIR``.
        """
        d = self.GRACE_MASK_DIR if mask_dir is None else mask_dir
        self.pamsk_grace, self.pa_lat_grace, self.pa_lon_grace = \
            self._read_mask(f'{d}/PA_grace.nc', 'pamsk')
        self.nemsk_grace, self.ne_lat_grace, self.ne_lon_grace = \
            self._read_mask(f'{d}/NE_grace.nc', 'nemsk')
        self.nomsk_grace, self.no_lat_grace, self.no_lon_grace = \
            self._read_mask(f'{d}/NO_grace.nc', 'nomsk')
        self.ammsk_grace, self.am_lat_grace, self.am_lon_grace = \
            self._read_mask(f'{d}/AM_grace.nc', 'ammsk')
        self.sumsk_grace, self.su_lat_grace, self.su_lon_grace = \
            self._read_mask(f'{d}/SU_grace.nc', 'sumsk')

    def plot_timeserie(self):
        """Plot basin-mean TWS time series for the five basins and save them.

        Each panel of a 3x2 grid shows the spatial mean of ``self.tws`` over
        the cells where ``tws * mask`` is non-zero. Cells where TWS itself is
        exactly 0 are therefore excluded too. The figure is saved to
        ``time_serie.png`` and then closed. Requires ``basin_mask_grace``.
        """
        masks = [self.pamsk_grace, self.nemsk_grace, self.nomsk_grace,
                 self.ammsk_grace, self.sumsk_grace]
        basins = ['PARANA', 'NORTHEAST', 'NORTHERN SOUTH AMERICA', 'AMAZON',
                  'SOUTHERN SOUTH AMERICA']
        nrows, ncols = 3, 2
        # Read the whole field once instead of twice per basin.
        tws = self.tws[:]
        fig = plt.figure()
        for plot_number, (msk, name) in enumerate(zip(masks, basins), start=1):
            # Treat missing mask cells as outside the basin (0) so their fill
            # values can't leak into the average.
            basin = np.ma.filled(msk, 0)
            # The (ny, nx) mask broadcasts over the leading time axis.
            tws2 = np.ma.mean(np.ma.masked_where(tws * basin == 0, tws), axis=(1, 2))
            print(nrows, ncols, plot_number)
            plt.subplot(nrows, ncols, plot_number)
            plt.plot(self.time, tws2, color='k')
            plt.title(name, fontsize=8)
            plt.xticks(rotation=90)
        plt.tight_layout()
        plt.savefig('time_serie.png')
        plt.close(fig)

    def plot_data(self, data, title, namefig, levels, colorPallet='RdYlBu',
                  extend_bar="both", mask="None"):
        """Draw a filled-contour map of ``data`` with basin outlines and save it.

        Args:
            data: 2-D field on the ``self.lat`` x ``self.lon`` grid.
            title: figure title.
            namefig: output image path (saved at 300 dpi).
            levels: contour levels passed to ``contourf``.
            colorPallet: colormap name or Colormap used for the fill.
            extend_bar: ``extend`` option passed to ``contourf``.
            mask: ``"ocean"`` or ``"land"`` (case-insensitive) paints that
                surface white above the field. Any other value, including
                ``None``, draws no mask.

        Requires ``read_tws_coordinates`` and ``basin_mask`` to have run.
        """
        mask_key = str(mask).lower()
        plt.figure(figsize=(6, 4))
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.add_feature(cart.feature.COASTLINE, edgecolor='k')
        if mask_key == "ocean":
            ax.add_feature(cart.feature.OCEAN, zorder=100, edgecolor='k',
                           facecolor='white')
        elif mask_key == "land":
            ax.add_feature(cart.feature.LAND, zorder=100, edgecolor='k',
                           facecolor='white')

        x, y = np.meshgrid(self.lon, self.lat)
        # contourf accepts a colormap name directly. matplotlib.cm.get_cmap
        # is deprecated and was removed in Matplotlib 3.9.
        filled = plt.contourf(x, y, data, levels=levels,
                              cmap=colorPallet, extend=extend_bar)
        cb = plt.colorbar(filled, ax=ax)
        cb.ax.tick_params(labelsize=8)

        outlines = [
            (self.pa_lon, self.pa_lat, self.pamsk),
            (self.am_lon, self.am_lat, self.ammsk),
            (self.ne_lon, self.ne_lat, self.nemsk),
            (self.no_lon, self.no_lat, self.nomsk),
            (self.su_lon, self.su_lat, self.sumsk),
        ]
        for lon, lat, msk in outlines:
            mx, my = np.meshgrid(lon, lat)
            plt.contour(mx, my, msk, colors='k', linewidths=0.5, linestyles='solid')
        ax.coastlines('50m')
        ax.set_extent([-80.974976, -34.825012, -55.925, 11.975006], ccrs.PlateCarree())

        plt.title(title, fontsize=10)
        plt.savefig(namefig, bbox_inches='tight', dpi=300)
        plt.clf()
        plt.close('all')







if __name__ == "__main__":
    grace_path = '/data/dmdpesq/bianca.maske/dados/grace/csr/grace_csr_rl06_preenchido_AS.nc'
    grace = GraceMaps(grace_path)

    # Load the TWS field, its time axis and grid, then both basin-mask sets:
    # the outline masks used by plot_data and the GRACE-grid masks used by
    # plot_timeserie.
    loaders = (
        grace.read_tws,
        grace.read_tws_time,
        grace.read_tws_coordinates,
        grace.basin_mask,
        grace.basin_mask_grace,
    )
    for load in loaders:
        load()

    # Optional map products (uncomment to regenerate):
    # levels = np.arange(-20, 20, 1)
    # longavg = grace.longterm_mean()
    # grace.plot_data(longavg, 'Longterm TWS average', 'longavg.png', levels)
    # month_longavg = grace.month_climatology()
    # for tt, clim in enumerate(month_longavg, start=1):
    #     mm = calendar.month_name[tt]
    #     grace.plot_data(clim, f'TWS average {mm}', f'{tt}_monthavg_{mm}.png', levels)
    # time_bounds, years_avg = grace.years_average(1)
    # for yy, avg in zip(time_bounds[:-1], years_avg):
    #     grace.plot_data(avg, f'TWS average {yy}', f'{yy}_avg.png', levels)

    grace.plot_timeserie()


