#!/home/CWA_MDPS/.conda/envs/rd/bin/python
import os
import shutil

import netCDF4 as nc
import numpy as np

from pytools import nctools as nct
from pytools import timetools as tt


def main():
    """Script entry point: run exactly one of the workflows below.

    The inactive workflows are kept as comments; enable one by swapping
    which call is uncommented.
    """
    # auto()                    # interpolate every file listed in parseMissingFiles()
    # manual('interpolated')    # interpolate the single hand-written case in manual()
    # findPath()                # report which file exists for each hard-coded valid time
    fix()

def fix():
    """Re-create, in 'fix' mode, the forecast for each hard-coded valid time.

    Every valid time is mapped to its (initTime, leadHour) pair with
    ``validTime2InitLead``. The target is then interpolated by ``run`` from
    the two forecasts valid 6 hours before and 6 hours after it, crossing
    into the previous/next initialisation when the lead hour would leave
    the 0..18 range.

    Raises:
        ValueError: if a valid time maps to a lead hour other than 0, 6, 12
            or 18, for which no neighbouring forecasts are defined.
    """
    varName = 'ssh'
    mode = 'fix'

    validStrs = [
        '2012-11-27 12Z',
        '2012-11-28 00Z',
        '2012-11-29 06Z',
        '2012-11-29 12Z',
        '2012-12-01 00Z',
        '2012-12-04 00Z',
        '2012-12-06 06Z',
        '2012-12-09 18Z',
        '2012-12-10 18Z',
        '2012-12-13 06Z',
        '2012-12-14 18Z',
        '2012-12-15 12Z',
        '2012-12-15 18Z',
        '2012-12-16 06Z',
        '2012-12-18 00Z',
        '2012-12-18 12Z',
        '2012-12-19 00Z',
        '2012-12-19 18Z',
        '2012-12-21 06Z',
        '2012-12-21 12Z',
        '2012-12-22 00Z',
        '2012-12-23 06Z',
        '2012-12-23 12Z',
        '2012-12-26 00Z',
        '2012-12-27 18Z',
        '2012-12-29 00Z',
        '2012-12-31 00Z',
        '2012-12-31 12Z',
    ]

    for validTime in [tt.format2float(d, '%Y-%m-%d %Hz') for d in validStrs]:
        initTime, leadHour = validTime2InitLead(validTime)

        # Neighbours are the forecasts valid 6 h before (1) and after (2).
        # initTime is counted in days, so +/- 1 moves to the adjacent run.
        if leadHour in [6, 12]:
            initTime1, leadHour1 = initTime, leadHour - 6
            initTime2, leadHour2 = initTime, leadHour + 6
        elif leadHour == 0:
            initTime1, leadHour1 = initTime - 1, leadHour + 18
            initTime2, leadHour2 = initTime, leadHour + 6
        elif leadHour == 18:
            initTime1, leadHour1 = initTime, leadHour - 6
            initTime2, leadHour2 = initTime + 1, leadHour - 18
        else:
            # Without this branch the previous iteration's neighbours would
            # be reused silently (or the names would be unbound).
            raise ValueError(
                f'unsupported {leadHour=} for valid time '
                f'{formatTime(validTime)}; expected 0, 6, 12 or 18'
            )

        srcs = [
            Forecast(varName, initTime1, leadHour1),
            Forecast(varName, initTime2, leadHour2),
        ]
        des = Forecast(varName, initTime, leadHour)
        run(varName, des, srcs, maxDeltaHour=6, mode=mode)


def findPath():
    """Report which file, if any, exists for each hard-coded valid time.

    For every valid time the 'source', 'fix' and 'interpolated' locations
    are probed in that order and the first existing path is printed next to
    the init/valid stamps ('file not found' when none exists). A final line
    prints all results joined by spaces.
    """
    validStrs = [
        '2012-11-27 12Z',
        '2012-11-28 00Z',
        '2012-11-29 06Z',
        '2012-11-29 12Z',
        '2012-12-01 00Z',
        '2012-12-04 00Z',
        '2012-12-06 06Z',
        '2012-12-09 18Z',
        '2012-12-10 18Z',
        '2012-12-13 06Z',
        '2012-12-14 18Z',
        '2012-12-15 12Z',
        '2012-12-15 18Z',
        '2012-12-16 06Z',
        '2012-12-18 00Z',
        '2012-12-18 12Z',
        '2012-12-19 00Z',
        '2012-12-19 18Z',
        '2012-12-21 06Z',
        '2012-12-21 12Z',
        '2012-12-22 00Z',
        '2012-12-23 06Z',
        '2012-12-23 12Z',
        '2012-12-26 00Z',
        '2012-12-27 18Z',
        '2012-12-29 00Z',
        '2012-12-31 00Z',
        '2012-12-31 12Z',
    ]
    validTimes = [tt.format2float(s, '%Y-%m-%d %Hz') for s in validStrs]

    found = []
    for validTime in validTimes:
        initTime, leadHour = validTime2InitLead(validTime)

        print(f'init = {formatTime(initTime)} ', end='', flush=True)
        print(f'valid = {formatTime(validTime)} ', end='', flush=True)

        # First existing candidate wins; the order encodes the preference.
        for mode in ('source', 'fix', 'interpolated'):
            candidate = getPath('ssh', initTime, leadHour, mode)
            if os.path.exists(candidate):
                break
        else:
            candidate = 'file not found'

        print(candidate)
        found.append(candidate)

    print(' '.join(found))

        
def formatTime(time):
    """Return *time* rendered by ``tt.float2format`` as '%y-%m-%d %H'.

    Used for the short init/valid stamps in console output.
    """
    stamp = tt.float2format(time, '%y-%m-%d %H')
    return stamp


def manual(mode):
    """Interpolate one hand-picked missing forecast.

    Currently builds hycom_ssh_2014092312_t018.nc from the t012 forecast of
    the same run and the t000 forecast of the following day's run.

    Args:
        mode: output category handed to ``run`` (e.g. 'interpolated').
    """
    varName = 'ssh'
    year = 2014
    initHour = 12

    def forecast(month, day, leadHour):
        # All forecasts here share the variable, year and init hour.
        return Forecast(
            varName, tt.ymd2float(year, month, day, initHour), leadHour
        )

    # Earlier recipe, kept for reference: filling the 12/31 run's t012 and
    # t018 across the year boundary.
    # srcs = [
    #     Forecast(varName, tt.ymd2float(year - 1, 12, 31, initHour), 6),
    #     Forecast(varName, tt.ymd2float(year, 1, 1, initHour), 0),
    # ]
    # for lead in [12, 18]:
    #     des = Forecast(varName, tt.ymd2float(year - 1, 12, 31, initHour), lead)
    #     run(varName, des, srcs, maxDeltaHour=12, mode=mode)

    bracketing = [
        forecast(9, 23, 12),
        forecast(9, 24, 0),
    ]
    target = forecast(9, 23, 18)
    run(varName, target, bracketing, maxDeltaHour=6, mode=mode)

def auto():
    """Interpolate every still-missing file reported by parseMissingFiles().

    Does nothing when parseMissingFiles() signals (via a ``None`` variable
    name) that there is no work left.
    """
    varName, desList, srcList = parseMissingFiles()
    if varName is not None:
        for target, sources in zip(desList, srcList):
            run(varName, target, sources)


def parseMissingFiles():
    """Turn the hard-coded list of missing files into interpolation jobs.

    Each file name is parsed into variable name, init time and lead hour.
    Entries whose 'interpolated' path already exists are skipped. For the
    rest, the two source forecasts are the ones valid 6 hours before and
    after the target, moving to the previous/next day's run when the lead
    hour would fall outside 0..23.

    Returns:
        ``(varName, desList, srcList)`` where ``desList`` holds the target
        ``Forecast`` objects and ``srcList`` the matching ``[before, after]``
        pairs; ``varName`` is the one parsed from the last listed file.
        ``(None, None, None)`` when nothing is left to create.
    """
    missingFiles = [
        '/data/public/HYCOM/leecw/data/interpolated/2013/ssh/12/hycom_ssh_2013123112_t012.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2013/ssh/12/hycom_ssh_2013123112_t018.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/01/hycom_ssh_2014010212_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/01/hycom_ssh_2014010412_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/01/hycom_ssh_2014011212_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/01/hycom_ssh_2014011612_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/01/hycom_ssh_2014012612_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/01/hycom_ssh_2014012812_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/02/hycom_ssh_2014021512_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/02/hycom_ssh_2014022712_t018.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/03/hycom_ssh_2014031012_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/03/hycom_ssh_2014032312_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/04/hycom_ssh_2014040412_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/04/hycom_ssh_2014041212_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014050112_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014051112_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014051212_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014051312_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014052112_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014052212_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/05/hycom_ssh_2014052612_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014061512_t012.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062412_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062612_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062712_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062812_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062812_t006.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062812_t012.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014062812_t018.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/06/hycom_ssh_2014063012_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/07/hycom_ssh_2014070812_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/07/hycom_ssh_2014072712_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/08/hycom_ssh_2014080212_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/08/hycom_ssh_2014080512_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/09/hycom_ssh_2014090712_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/09/hycom_ssh_2014092012_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/09/hycom_ssh_2014092312_t018.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/10/hycom_ssh_2014100612_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/10/hycom_ssh_2014102412_t012.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/10/hycom_ssh_2014102712_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/12/hycom_ssh_2014120912_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/12/hycom_ssh_2014121312_t000.nc',
        '/data/public/HYCOM/leecw/data/interpolated/2014/ssh/12/hycom_ssh_2014122112_t000.nc',
    ]

    desList = []
    srcList = []
    for filePath in missingFiles:
        # '.../hycom_<var>_<YYYYMMDDHH>_t<LLL>.nc' -> last three '_' fields.
        stem = os.path.splitext(filePath)[0]
        varName, initStr, leadStr = stem.split('_')[-3:]
        initTime = tt.format2float(initStr, '%Y%m%d%H')
        leadHour = int(leadStr[1:4])

        target = Forecast(varName, initTime, leadHour)
        if os.path.exists(target.getPath()):
            continue

        # Forecast valid 6 h earlier; borrow from the previous day's run
        # when that would give a negative lead hour.
        if leadHour >= 6:
            before = Forecast(varName, initTime, leadHour - 6)
        else:
            before = Forecast(varName, initTime - 1, leadHour + 18)

        # Forecast valid 6 h later; roll into the next day's run when the
        # lead hour would reach 24.
        if leadHour < 18:
            after = Forecast(varName, initTime, leadHour + 6)
        else:
            after = Forecast(varName, initTime + 1, leadHour - 18)

        desList.append(target)
        srcList.append([before, after])

    if desList:
        return varName, desList, srcList

    print('all files created')
    return None, None, None


def run(varName, des, srcs, maxDeltaHour=12, mode='interpolated'):
    """Create the file for *des* by linear time interpolation between *srcs*.

    The function prints its progress and returns without writing anything
    when: the target already has a source file (unless ``mode == 'fix'``),
    the output file already exists, a source file is missing, a source is
    further than *maxDeltaHour* hours from the target, or both sources
    share the same valid time.

    Otherwise the first source file is copied to the output path and the
    variable ``getNcVarName(varName)`` in the copy is overwritten with the
    interpolated field.

    Args:
        varName: short variable name (e.g. 'ssh').
        des: target ``Forecast``.
        srcs: sequence of source ``Forecast`` objects. All are checked for
            existence and distance, but only the first two are used for
            the interpolation.
        maxDeltaHour: largest allowed |valid time difference| in hours
            between the target and any source.
        mode: output category passed to ``des.getPath``; 'fix' forces the
            run even if the target's source file exists.

    Raises:
        OSError: if the output directory or file cannot be created.
    """
    # Built outside the f-string: a replacement field spanning several
    # lines inside a single-quoted f-string only parses on Python >= 3.12.
    srcStamps = [
        tt.float2format(src.getValidTime(), "%y/%m/%d_%H")
        for src in srcs
    ]
    desStamp = tt.float2format(des.getValidTime(), "%y/%m/%d_%H")
    print(f'from={srcStamps}, to={desStamp}', end=' ')

    #
    # ---- check if des needs interpolated (force to run for fix mode)
    desSourcePath = des.getPath('source')
    if mode != 'fix' and os.path.exists(desSourcePath):
        print("No need to interpolate: the target's source path exists")
        print(desSourcePath)
        return

    desPath = des.getPath(mode)
    if os.path.exists(desPath):
        print("des file already exists.", end=' ')
        print(desPath)
        return

    #
    # ---- check if the src paths exists
    srcPaths = [src.getPath('source') for src in srcs]
    for srcPath in srcPaths:
        if not os.path.exists(srcPath):
            print(f'src path not found: {srcPath}')
            return

    #
    # ---- check delta hours to interpolate
    newTime = des.getValidTime()
    times = np.array([src.getValidTime() for src in srcs])
    deltas = 24 * np.abs(newTime - times)  # valid times are in days

    if (deltas > maxDeltaHour).any():
        print(f'deltas too big ({deltas} > max={maxDeltaHour})')
        return

    print(f'hours={deltas}')

    span = times[1] - times[0]
    if span == 0:
        # Would divide by zero below and write inf/nan into the output.
        print('src valid times are identical; cannot interpolate')
        return

    #
    # ---- read data and interpolate
    datas = np.array([src.read() for src in srcs])
    newData = datas[0] + (datas[1] - datas[0]) / span * (newTime - times[0])

    #
    # ---- save data
    desDir = os.path.dirname(desPath)
    if desDir:
        os.makedirs(desDir, exist_ok=True)

    # The first source serves as the template for the output file.
    # NOTE(review): only the data variable is overwritten below; everything
    # else in the file stays as in srcs[0] -- confirm this is intended.
    shutil.copyfile(srcPaths[0], desPath)
    try:
        with nc.Dataset(desPath, 'a') as h:
            h[getNcVarName(varName)][:] = newData
    except Exception:
        # Do not leave an unmodified copy of the source behind: later runs
        # would treat it as finished output ("des file already exists").
        os.remove(desPath)
        raise
    

class Forecast:
    """One forecast field, identified by variable, init time and lead hour.

    ``initTime`` is a float time to which ``leadHour / 24`` is added to get
    the valid time, i.e. it is counted in days.
    """

    def __init__(self, varName, initTime, leadHour):
        self.varName = varName
        self.initTime = initTime
        self.leadHour = leadHour

    def getValidTime(self):
        """Return the valid time: init time plus the lead converted to days."""
        leadDays = self.leadHour/24
        return self.initTime + leadDays

    def getPath(self, mode='interpolated'):
        """Return the file path of this forecast for *mode*.

        For ``mode='source'`` an existing 'fix' file takes precedence over
        the plain source path.

        Raises:
            ValueError: if *mode* is not 'interpolated', 'source' or 'fix'.
        """
        if mode not in ['interpolated', 'source', 'fix']:
            raise ValueError(f'unrecognized {mode=}')

        if mode == 'source':
            fixedPath = getPath(
                self.varName, self.initTime, self.leadHour, 'fix'
            )
            if os.path.exists(fixedPath):
                return fixedPath

        return getPath(self.varName, self.initTime, self.leadHour, mode)

    def read(self):
        """Read this forecast's variable from its 'source' path via nct.read."""
        sourcePath = self.getPath('source')
        return nct.read(sourcePath, getNcVarName(self.varName))
       
 
def getPath(varName, initTime, leadHour, mode):
    """Build the relative path of a forecast file under ``../data/<mode>``.

    The variable name, mode and zero-padded three-digit lead hour are
    inserted first; the remaining ``%``-fields are then filled from
    *initTime* by ``tt.float2format``.
    """
    pattern = f'../data/{mode}/%Y/{varName}/%m/hycom_{varName}_%Y%m%d%H_t{leadHour:03d}.nc'
    return tt.float2format(initTime, pattern)


def validTime2InitLead(validTime):
    """Split a valid time into the (initTime, leadHour) of its hour-12 run.

    Valid hours from 12 on belong to the same day's run (lead = hour - 12);
    earlier hours belong to the previous day's run (lead = hour + 12).
    """
    year, month, day = tt.float2ymd(validTime)
    hour = tt.hour(validTime)
    sameDayInit = tt.ymd2float(year, month, day, 12)

    if hour < 12:
        # Before the day's own run starts: fall back one day.
        return sameDayInit - 1, hour + 12
    return sameDayInit, hour - 12

def getNcVarName(varName):
    """Map the short variable name to the name used inside the NetCDF file.

    Raises:
        NotImplementedError: for any name other than 'ssh'.
    """
    if varName != 'ssh':
        raise NotImplementedError(f'unrecognized {varName=}')
    return 'surf_el'


# Run the workflow selected in main() only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()

