#!/home/CWA_MDPS/.conda/envs/rd/bin/python
from pytools.caltools import interp_1d
import pytools.nctools as nct
import pytools.timetools as tt
from pytools.fillNans2d import fillNans2d
from tqdm import tqdm
import numpy as np
import os
import netCDF4
import shared as st
# !!! run this file on betago !!!

#       ssu     ssv     ssh     MLD     D20     D15
# 2001
# 2002
# 2003
# 2004
# 2005
# 2006
# 2007
# 2008
# 2009
# 2010
# 2011                                            
# 2012                                            
# 2013  O       O       O       O       O       O                                         
# 2014  O       O       O       O       O       O                                         
# 2015  O       O       O       O       O       O                                         
# 2016  O       O       O       O       O       O                                         
# 2017  O       O       O       O       O       O                                         
# 2018
# 2019
# 2020

def main():
    """Call run() for every (year, variable) combination, with debug off."""
    years = range(2011, 2017+1)
    varNames = ['ssu']

    # year-major order: all variables of one year are handled before the next
    for year in years:
        for varName in varNames:
            run(varName, year=year, debug=False)


def run(varName, year, debug):
    """Build the per-year output file for one variable, then its cdo time sum.

    Steps:
      1. Build the list of valid times for `year`.
      2. Return early if any source file is missing (see checkSrcPaths).
      3. If the destination file does not exist yet, create it and fill it
         record by record with readAndWriteRecord.
      4. If the sum file does not exist yet, run `cdo -b U16 timsum` on the
         destination file to produce it.

    Args:
        varName: output variable name (e.g. 'ssu').
        year: year to process.
        debug: when true, the end time is tt.ymd2float(year, 1, 2, 18) instead
            of tt.ymd2float(year, 12, 31, 18), and the output files get a
            'debug_' prefix (see getDesPath).
    """
    import subprocess  # local import: only needed for the cdo call below

    print(f'{debug=}')
    validTimeStart = tt.ymd2float(year, 1, 1)
    if debug:
        validTimeEnd = tt.ymd2float(year, 1, 2, 18)
    else:
        validTimeEnd = tt.ymd2float(year, 12, 31, 18)
    validTimes = getValidTimeRange(validTimeStart, validTimeEnd)

    if not checkSrcPaths(varName, validTimes):
        return

    desPath = getDesPath(varName, year, debug, suffix='')
    sumDesPath = getDesPath(varName, year, debug, suffix='sum')
    print(f' {desPath=}')

    if not os.path.exists(desPath):
        createDesFile(varName, desPath, validTimes)

        with netCDF4.Dataset(desPath, 'a') as h:
            variableHandle = h[varName]
            for itime, validTime in enumerate(tqdm(validTimes)):
                readAndWriteRecord(varName, variableHandle, itime, validTime)
    else:
        print(f'desPath already exists: {desPath}')

    if os.path.exists(sumDesPath):
        print(f'sumDesPath already exists: {sumDesPath}')
        return

    # Pass an argument list (no shell) so that paths are never re-parsed by a
    # shell, and inspect the exit status instead of discarding it.
    cmd = ['cdo', '-b', 'U16', 'timsum', desPath, sumDesPath]
    cmdText = ' '.join(cmd)
    try:
        status = subprocess.run(cmd).returncode
    except OSError as err:
        # e.g. cdo is not installed / not on PATH; report and carry on
        print(f'could not launch "{cmdText}": {err}')
        return
    if status != 0:
        print(f'"{cmdText}" failed with exit status {status}')


def getValidTimeRange(start, end):
    """Return the times from `start` to `end` (inclusive) in steps of 6/24.

    Args:
        start: first time of the range.
        end: last time of the range; it is included in the result when it
            lies a whole number of steps after `start`.

    Returns:
        1-D numpy array of times.
    """
    step = 6/24
    # the stop value is pushed one step past `end` so that `end` itself is kept
    stop = end + step
    return np.arange(start, stop, step)


def checkSrcPaths(varName, validTimes):
    """Check that a source file exists for every valid time.

    Every missing path is printed. When nothing is missing, a one-line summary
    with the first and last valid time and the number of times is printed.

    Args:
        varName: output variable name, forwarded to getSrcPath.
        validTimes: non-empty sequence of valid times.

    Returns:
        True when every source path exists, False otherwise. (The failure
        case used to return None; False is equally falsy for callers.)
    """
    missings = []
    for validTime in validTimes:
        path = getSrcPath(varName, *valid2initLead(validTime))
        if not os.path.exists(path):
            missings.append(path)

    if missings:
        print('path(s) missing:')
        for path in missings:
            print(f' {path}')
        return False

    # Build the range text outside the f-string: nesting the same quote type
    # and line breaks inside a replacement field only parses on Python 3.12+.
    timeRange = ' to '.join(
        tt.float2format(t, '%Y-%m-%d %Hz')
        for t in (validTimes[0], validTimes[-1])
    )
    print(f' All paths are found for {timeRange}, NT={len(validTimes)}.')
    return True


def getDesPath(varName, year, debug, suffix=''):
    """Return the output file path for one variable and year.

    The file name is '[debug_]<varName>_<year>[_<suffix>].nc': the 'debug_'
    prefix is added when `debug` is truthy, and '_<suffix>' is appended when
    `suffix` is non-empty.
    """
    rootDir = '/data/public/HYCOM/leecw/data/nnans'
    prefix = 'debug_' if debug else ''
    tail = f'_{suffix}' if suffix else ''
    return f'{rootDir}/{prefix}{varName}_{year}{tail}.nc'


def createDesFile(varName, desPath, validTimes):
    """Create the output file and fill in its time/lat/lon coordinates.

    The variable `varName` is created through nct.create with dtype 'u1' and
    shape (len(validTimes), len(LAT_OUT), len(LON_OUT)); the data itself is
    not written here.

    Args:
        varName: name of the variable to create.
        desPath: path of the file to create.
        validTimes: values written to the time coordinate.
    """
    dimNames = [TIMENAME_OUT, LATNAME_OUT, LONNAME_OUT]
    # one length per dimension, in the same (time, lat, lon) order as dimNames
    outputShape = (len(validTimes), len(LAT_OUT), len(LON_OUT))

    nct.create(
        desPath, varName, outputShape, dimNames,
        use_my_attrs=True, significant_digits=None, dtype='u1',
        complevel=0, shuffle=False,
    )

    # fill the coordinate variables
    coordinates = (
        (TIMENAME_OUT, validTimes),
        (LATNAME_OUT, LAT_OUT),
        (LONNAME_OUT, LON_OUT),
    )
    for coordName, coordValues in coordinates:
        nct.write(desPath, coordName, coordValues)

def getVarNameIn(varName):
    """Translate an output variable name into its name in the source files.

    Raises:
        ValueError: if `varName` is not one of the known variables.
    """
    nameInSource = {
        'ssh': 'surf_el',
        'ssu': 'water_u',
        'ssv': 'water_v',
        'MLD': 'MLD',
        'D15': 'D15',
        'D20': 'D20',
    }.get(varName)

    # none of the mapped values is None, so None can only mean "unknown key"
    if nameInSource is None:
        raise ValueError(f'unrecognized {varName=}')
    return nameInSource


def readAndWriteRecord(varName, variableHandle, itime, validTime):
    """Read one source record, regrid it, and store its NaN mask.

    The source field for `validTime` is read, masked with
    st.maskInvalidValues, rearranged so that longitudes start at the first
    non-negative value, interpolated onto LAT_OUT/LON_OUT, and finally
    np.isnan() of the result is written to variableHandle[itime, :].

    Args:
        varName: output variable name (mapped to the source name through
            getVarNameIn).
        variableHandle: writable variable of the destination file.
        itime: index along the first axis of `variableHandle` to write.
        validTime: valid time of the record.
    """
    # read
    src = getSrcPath(varName, *valid2initLead(validTime))
    lon_in = nct.read(src, LONNAME_IN)
    lat_in = nct.read(src, LATNAME_IN)
    data = nct.read(src, getVarNameIn(varName))
    data = np.array(data, dtype=np.float32)
    if data.ndim == 4:
        # drop the third-from-last axis
        # NOTE(review): presumably a single depth level - confirm
        data = np.squeeze(data, axis=-3)

    # masking
    data = st.maskInvalidValues(data, varName)

    # lon = -180 ~ 180 -> 0 ~ 360
    # `ind` is the index of the first non-negative longitude. When there is
    # none, fall back to len(lon_in) so that the whole axis is shifted by 360;
    # a None default would make both slices below select everything and
    # duplicate the longitudes and the data.
    ind = next((i for i, lon in enumerate(lon_in) if lon >= 0), len(lon_in))
    lon_in = np.concatenate((lon_in[ind:], lon_in[:ind]+360), axis=-1)
    data = np.concatenate((data[:, :, ind:], data[:, :, :ind]), axis=-1)

    # LAT_OUT runs from 90 down to -90, so it is reversed to be increasing for
    # interp_1d; the result is flipped back below to match LAT_OUT again.
    data = interp_1d(lat_in, data, LAT_OUT[::-1], axis=-2, extrapolate=True)
    data = interp_1d(lon_in, data, LON_OUT, axis=-1)
    data = data[:, ::-1, :]  # flip latitude

    # write: only the NaN mask of the regridded field is stored
    variableHandle[itime, :] = np.isnan(data)


def getSrcPath(varName, initTime, lead):
    """Return the source file path for a variable, init time and lead.

    The path under the 'patched' tree is returned when that file exists;
    otherwise the path under the 'source' tree is returned, whether or not
    it exists.

    Args:
        varName: output variable name ('ssh', 'ssu', 'ssv', 'MLD', 'D20' or
            'D15').
        initTime: initialization time, formatted through tt.float2format.
        lead: integer lead, written as a zero-padded 3-digit number.

    Raises:
        ValueError: if `varName` is not recognized.
    """
    def buildPath(subDir):
        # pick the sub-directory and the variable tag used in the file name
        if varName == 'ssh':
            group, tag = varName, varName
        elif varName in ('ssu', 'ssv'):
            group, tag = 'ssuv', varName
        elif varName in ('MLD', 'D20', 'D15'):
            # the file tag replaces the first letter of the name with 'M'
            group, tag = 'others', f'M{varName[1:]}'
        else:
            raise ValueError(f'unrecognized {varName=}')

        pathFmt = (
            f'/data/public/HYCOM/leecw/data/{subDir}'
            f'/%Y/{group}/%m/hycom_{tag}_%Y%m%d%H_t{lead:03d}.nc'
        )
        return tt.float2format(initTime, pathFmt)

    patched = buildPath('patched')
    return patched if os.path.exists(patched) else buildPath('source')

def valid2initLead(validTime):
    """Split a valid time into the preceding 12z init time and a lead.

    The init time is 12z of the same day when the valid hour is 12 or later,
    otherwise 12z of the previous day.

    Args:
        validTime: valid time as a float whose unit is days (hours are
            converted with a factor of 24 below).

    Returns:
        Tuple (initTime, lead) where lead is the whole number of hours
        between initTime and validTime.
    """
    validHour = tt.hour(validTime)
    validDate = validTime - validHour/24

    # find the previous 12z
    if validHour >= 12:
        initTime = validDate + 0.5  # today 12z
    else:
        initTime = (validDate-1) + 0.5  # yesterday 12z

    lead = (validTime-initTime) * 24

    # Round instead of truncating: the product above is a float, and a value
    # such as 0.9999999 must become 1, not 0. Exact whole-hour leads are
    # unaffected.
    return initTime, int(round(lead))


# Names and grids read as globals by the functions above. They are defined at
# module level (not under the __main__ guard) so that the functions also work
# when this file is imported instead of being run as a script.
LONNAME_IN = 'lon'
LATNAME_IN = 'lat'

LONNAME_OUT = 'lon'
LATNAME_OUT = 'lat'
TIMENAME_OUT = 'time'
LON_OUT = np.r_[0:360:0.25]  # 0, 0.25, ..., 359.75
LAT_OUT = np.r_[90:-90.25:-0.25]  # 90, 89.75, ..., -90 (decreasing)


if __name__ == '__main__':
    main()
