#!/home/CWA_MDPS/.conda/envs/rd/bin/python
from pytools.caltools import interp_1d
import pytools.nctools as nct
import pytools.timetools as tt
import pytools.filetools as ft
from pytools.caltools import fillNans2d
from tqdm import tqdm
import numpy as np
import os
import netCDF4
import shared as st
# !!! run this file on betago !!!

#       ssu     ssv     ssh     MLD     D20     D15
# 2001
# 2002
# 2003
# 2004
# 2005
# 2006
# 2007
# 2008
# 2009
# 2010
# 2011                                            
# 2012                                            
# 2013                                            
# 2014  O       O
# 2015                                            
# 2016                                            
# 2017                                            
# 2018
# 2019
# 2020

def main():
    """Merge every configured variable for every configured year.

    Each (year, variable) pair is passed to ``run`` with ``debug`` and
    ``force`` both off. Years form the outer loop, variables the inner.
    """
    targetYears = (2005, 2006)
    targetVars = ('MLD', 'D20', 'D15', 'ssh', 'ssu', 'ssv')
    jobs = [(y, v) for y in targetYears for v in targetVars]
    for y, v in jobs:
        run(varName=v, year=y, debug=False, force=False)


class Global:
    """Land/sea mask state that is loaded once and shared.

    Attributes set by ``__init__``:
      numStats  -- the constant 4
      isLand    -- whatever ``read_land_sea_mask`` returns
      isSea     -- the complement, ``1 - isLand``
      nIsLand   -- sum over ``isLand``
      nIsSea    -- sum over ``isSea``
    """

    def __init__(self):
        self.numStats = 4

        landMask = self.read_land_sea_mask()
        seaMask = 1 - landMask

        self.isLand = landMask
        self.isSea = seaMask
        self.nIsLand = np.sum(landMask)
        self.nIsSea = np.sum(seaMask)

    def read_land_sea_mask(self):
        """Return the 'isLand' variable of the HYCOM land/sea mask file."""
        maskPath = '../data/land_sea_mask_hycom.nc'
        return nct.read(maskPath, 'isLand')


def run(varName, year, debug=False, it_list=None, force=False):
    """Build the merged output file of one variable for one year.

    Parameters
    ----------
    varName : str
        Variable to process; also the variable name in the output file.
    year : int
        Year to process, from Jan 1 00z.
    debug : bool
        When true, the time axis stops at Jan 2 18z instead of Dec 31 18z
        and the output path is the debug one from ``getDesPath``.
    it_list : iterable of int, optional
        Time indices to process. Defaults to every index of the axis.
    force : bool
        When false, an already existing output file makes this a no-op.

    Returns without writing anything when ``checkSrcPaths`` reports
    failure or when the output exists and ``force`` is false.
    """
    print(f'{debug=}')

    # Only the end of the window depends on the debug flag.
    lastMonth, lastDay = (1, 2) if debug else (12, 31)
    validTimes = getValidTimeRange(
        tt.ymd2float(year, 1, 1),
        tt.ymd2float(year, lastMonth, lastDay, 18),
    )

    if not checkSrcPaths(varName, validTimes):
        return

    desPath = getDesPath(varName, year, debug)
    if not force and os.path.exists(desPath):
        print(f'file already exists: {desPath}')
        return

    print(f' {desPath=}')

    ft.check_des_path(desPath)
    createDesFile(varName, desPath, validTimes)

    recordIndices = list(range(len(validTimes))) if it_list is None else it_list

    # Append mode: createDesFile has just written the file and its axes.
    with netCDF4.Dataset(desPath, 'a') as ncFile:
        target = ncFile[varName]
        for idx in tqdm(recordIndices):
            readAndWriteRecord(varName, target, idx, validTimes[idx])


def getValidTimeRange(start, end):
    """Return the time axis from ``start`` through ``end`` in 6/24 steps.

    ``end`` itself is included: the exclusive stop handed to NumPy is
    placed one step beyond it.
    """
    step = 6 / 24
    stop = end + step
    return np.arange(start, stop, step)


def checkSrcPaths(varName, validTimes):
    """Check that a source file exists for every valid time.

    Parameters
    ----------
    varName : str
        Variable name, forwarded to ``getSrcPath``.
    validTimes : sequence of float
        Valid times to check. Each is converted to an (init time, lead)
        pair by ``valid2initLead`` before the path is built.

    Returns
    -------
    bool
        True when every path exists (a summary line is printed).
        False when at least one is missing (the missing paths are
        printed).
    """
    missings = []
    for validTime in validTimes:
        path = getSrcPath(varName, *valid2initLead(validTime))
        if not os.path.exists(path):
            missings.append(path)

    if missings:
        print('path(s) missing:')
        for path in missings:
            print(f' {path}')
        # Explicit False so both outcomes return a bool (the bare
        # ``return`` here used to yield None).
        return False

    # The span is assembled outside the f-string: reusing the enclosing
    # quote character and breaking lines inside a replacement field only
    # parses on Python 3.12 and newer.
    span = ' to '.join(
        tt.float2format(validTimes[i], '%Y-%m-%d %Hz') for i in (0, -1)
    )
    print(f' All paths are found for {span}, NT={len(validTimes)}.')
    return True


def getDesPath(varName, year, debug):
    """Return the output file path for ``varName`` and ``year``.

    Debug runs get a ``debug_`` prefix on the file name so they never
    collide with the regular output.
    """
    prefix = 'debug_' if debug else ''
    outDir = '/data/public/HYCOM/leecw/data/merged_interp'
    return f'{outDir}/{prefix}{varName}_{year}.nc'


def createDesFile(varName, desPath, validTimes):
    """Create the output file and write its three coordinate axes.

    The variable ``varName`` is created as float64 ('f8') with shape
    (time, lat, lon), taken from ``validTimes`` and the module-level
    ``LAT_OUT`` / ``LON_OUT`` grids. The time, latitude and longitude
    values are then written under the ``*NAME_OUT`` names.
    """
    dims = (len(validTimes), len(LAT_OUT), len(LON_OUT))
    print(f' creating output shape = {dims}...')

    axisNames = [TIMENAME_OUT, LATNAME_OUT, LONNAME_OUT]
    nct.create(
        desPath, varName, dims,
        axisNames, use_my_attrs=True,
        significant_digits=None, dtype='f8', complevel=0, shuffle=False,
    )

    # Same write order as the dimension order: time, lat, lon.
    for axisName, axisValues in zip(axisNames, (validTimes, LAT_OUT, LON_OUT)):
        nct.write(desPath, axisName, axisValues)

def getVarNameIn(varName):
    """Translate an output variable name into its source-file name.

    Raises
    ------
    ValueError
        If ``varName`` is not one of the six known variables.
    """
    sourceNames = {
        'ssh': 'surf_el',
        'ssu': 'water_u',
        'ssv': 'water_v',
        'MLD': 'MLD',
        'D15': 'D15',
        'D20': 'D20',
    }
    nameIn = sourceNames.get(varName)
    if nameIn is None:
        raise ValueError(f'unrecognized {varName=}')
    return nameIn


def readAndWriteRecord(varName, variableHandle, itime, validTime):
    """Read one source record, regrid it, and store it in the output.

    Parameters
    ----------
    varName : str
        Output variable name; mapped to the source name by
        ``getVarNameIn`` and to the source file by ``getSrcPath``.
    variableHandle
        Writable output variable; the record is assigned to
        ``variableHandle[itime, :]``.
    itime : int
        Index of the record along the output time axis.
    validTime : float
        Valid time of the record, split into init time and lead by
        ``valid2initLead``.

    Notes
    -----
    Reads the module globals ``glb``, ``LONNAME_IN``, ``LATNAME_IN``,
    ``LON_OUT`` and ``LAT_OUT``.
    """
    # read
    src = getSrcPath(varName, *valid2initLead(validTime))
    lon_in = nct.read(src, LONNAME_IN)
    lat_in = nct.read(src, LATNAME_IN)
    data = nct.read(src, getVarNameIn(varName))
    data = np.array(data, dtype=np.float32)
    # Bring the field to three dimensions: drop the third-from-last axis
    # of a 4-D field (np.squeeze raises unless it has length 1), or add a
    # leading axis to a 2-D field.
    if data.ndim == 4:
        data = np.squeeze(data, axis=-3)
    elif data.ndim == 2:
        data = data[None, :, :]

    # masking
    data = st.maskInvalidValues(data, varName, glb.isLand)

    # # [DEPRECATED - using st.maskInvalidValues()] missing values
    # if varName in ['ssh', 'MLD', 'D20', 'D15']:
    #     data[(np.abs(data) > 1e4)] = np.nan
    # elif varName in ['ssu', 'ssv']:
    #     data[(np.abs(data) > 1e4)] = np.nan
    #     data[(data==0)] = np.nan
    #     data[(data>15)] = np.nan
    # else:
    #     raise ValueError(f'Unrecognized {varName=}')

    # lon = -180 ~ 180 -> 0 ~ 360
    # Rotate so the axis starts at the first non-negative longitude and
    # the columns that were negative follow, shifted by +360.
    ind = next((i for i, l in enumerate(lon_in) if l >= 0), None)
    if ind is None:
        # No longitude is >= 0. Slicing with None would select the whole
        # axis in both halves and duplicate the grid, so treat every
        # column as part of the wrapped (negative) half instead.
        ind = len(lon_in)
    lon_in = np.concatenate((lon_in[ind:], lon_in[:ind]+360), axis=-1)
    data = np.concatenate((data[:, :, ind:], data[:, :, :ind]), axis=-1)

    # Interpolate onto the reversed (increasing) output latitudes for
    # interp_1d, then onto the output longitudes, and finally restore
    # the LAT_OUT ordering.
    data = interp_1d(lat_in, data, LAT_OUT[::-1], axis=-2, extrapolate=True) 
    data = interp_1d(lon_in, data, LON_OUT, axis=-1)
    data = data[:, ::-1, :] # flip latitude

    # fill nans
    # NOTE(review): only the first slice is filled and the result is
    # broadcast over the leading axis -- this assumes one record per
    # source file; confirm.
    if np.sum(np.isnan(data)) > 0:
        numSmoothsNan = 20 * 4 # 20 deg * 4 grid / deg
        data[:, :, :] = fillNans2d(data[0, :, :], numSmoothsNan)

    # write
    variableHandle[itime, :] = data


def getSrcPath(varName, initTime, lead):
    """Return the source file path for one variable, init time and lead.

    The file under the ``patched`` tree is returned when it exists on
    disk; otherwise the path under the ``source`` tree is returned
    without checking that it exists.

    Raises
    ------
    ValueError
        If ``varName`` is not a known variable.
    """
    # Sub-folder and file-name tag depend on the variable family.
    if varName == 'ssh':
        folder, tag = varName, varName
    elif varName in ('ssu', 'ssv'):
        folder, tag = 'ssuv', varName
    elif varName in ('MLD', 'D20', 'D15'):
        # The file tag swaps the first letter for 'M'.
        folder, tag = 'others', f'M{varName[1:]}'
    else:
        raise ValueError(f'unrecognized {varName=}')

    tail = f'%Y/{folder}/%m/hycom_{tag}_%Y%m%d%H_t{lead:03d}.nc'
    dataRoot = '/data/public/HYCOM/leecw/data'

    patched = tt.float2format(initTime, f'{dataRoot}/patched/{tail}')
    if os.path.exists(patched):
        return patched
    return tt.float2format(initTime, f'{dataRoot}/source/{tail}')

def valid2initLead(validTime):
    """Split a valid time into its 12z init time and lead.

    The init time is the most recent 12z at or before ``validTime``:
    the same day when the valid hour is 12 or later, the previous day
    otherwise.

    Parameters
    ----------
    validTime : float
        Valid time. The arithmetic treats one unit as a day
        (``hour / 24``), with the hour taken from ``tt.hour``.

    Returns
    -------
    tuple
        ``(initTime, lead)``: ``initTime`` is in the same units as
        ``validTime``; ``lead`` is the difference times 24, as an int.
    """
    validHour = tt.hour(validTime)
    validDate = validTime - validHour/24

    # find the previous 12z
    if validHour >= 12:
        initTime = validDate + 0.5  # today 12z
    else:
        initTime = (validDate-1) + 0.5  # yesterday 12z

    lead = (validTime-initTime) * 24

    # Round rather than truncate: the float product can land a hair
    # below the intended integer (e.g. 12.999999...), and int() alone
    # would then pick the previous hour.
    return initTime, int(round(lead))


# The functions in this file read the names below as module globals, so
# they are defined at module level: importing the file and calling those
# functions then no longer fails with NameError on the constants.

# Coordinate variable names in the source files.
LONNAME_IN = 'lon'
LATNAME_IN = 'lat'

# Dimension / coordinate names in the merged output file.
LONNAME_OUT = 'lon'
LATNAME_OUT = 'lat'
TIMENAME_OUT = 'time'

# Output grid at 0.25 spacing: longitude 0 .. 359.75 ascending,
# latitude 90 .. -90 descending.
LON_OUT = np.r_[0:360:0.25]
LAT_OUT = np.r_[90:-90.25:-0.25]


if __name__ == '__main__':
    # Global() reads the land/sea mask from disk, so it is still created
    # only when the file is run as a script.
    glb = Global()
    main()
