#!/home/CWA_MDPS/.conda/envs/rd/bin/python
from pytools.caltools import interp_1d
import pytools.nctools as nct
import pytools.timetools as tt
import numpy as np
import os
import netCDF4
# !!! run this file on betago !!!


def main():
    """Merge SSH for each target year in full (non-debug) mode."""
    years = (2016, 2017)
    for year in years:
        run(year=year, debug=False)


def run(year, debug):
    """Merge one year of HYCOM SSH files into a single netCDF file.

    Parameters
    ----------
    year : int
        Calendar year to process.
    debug : bool
        If True, only a short window at the start of January is processed
        and the result goes to the debug output path.

    Nothing is written when any source file is missing.
    """
    print(f'{debug=}')

    firstValid = tt.ymd2float(year, 1, 1)
    lastValid = (
        tt.ymd2float(year, 1, 2, 18) if debug
        else tt.ymd2float(year, 12, 31, 18)
    )
    validTimes = getValidTimeRange(firstValid, lastValid)

    # Abort before creating any output if a source file is missing.
    if not checkSrcPaths(validTimes):
        return

    # NOTE: the name `desPath` is part of the printed text via `=`.
    desPath = getDesPath(year, debug)
    print(f' {desPath=}')

    createDesFile(desPath, validTimes)

    with netCDF4.Dataset(desPath, 'a') as dataset:
        outVar = dataset[VARNAME_OUT]
        # NOTE: the name `itime` is part of the printed text via `=`.
        for itime, validNow in enumerate(validTimes):
            stamp = tt.float2format(validNow, '%Y-%m-%d %Hz')
            print(f'{itime = }, {stamp}')
            readAndWriteRecord(outVar, itime, validNow)


def getValidTimeRange(start, end, delta=6/24):
    """Return evenly spaced valid times from ``start`` through ``end``.

    Parameters
    ----------
    start, end : float
        First and last valid time, in days (this script treats one hour
        as 1/24).
    delta : float, optional
        Spacing in days; defaults to 6 hours.

    Returns
    -------
    numpy.ndarray
        ``start + k*delta`` for every k >= 0 with ``start + k*delta <= end``
        (up to a tiny rounding tolerance). Empty when ``end < start``.
    """
    # A float-step np.arange / np.r_ can gain or drop its last element through
    # rounding. It also overshoots `end` when `end` is not a whole number of
    # steps from `start`. Counting the steps explicitly avoids both problems.
    tolerance = 1e-9  # in units of steps; absorbs float error in the division
    nsteps = int(np.floor((end - start) / delta + tolerance))
    return start + delta * np.arange(max(nsteps + 1, 0))


def checkSrcPaths(validTimes):
    """Report whether a source file exists for every valid time.

    Prints each missing path, or a one-line summary when all are found.

    Parameters
    ----------
    validTimes : sequence of float
        Valid times to check, e.g. the output of getValidTimeRange.

    Returns
    -------
    bool
        True if every source file exists. False if any is missing or if
        ``validTimes`` is empty.
    """
    # len() rather than truthiness: validTimes is typically a numpy array.
    if len(validTimes) == 0:
        print('No valid times to check.')
        return False

    missings = []
    for validTime in validTimes:
        path = getSrcPath(*valid2initLead(validTime))
        if not os.path.exists(path):
            missings.append(path)

    if missings:
        print('path(s) missing:')
        for path in missings:
            print(f' {path}')
        return False

    first = tt.float2format(validTimes[0], '%Y-%m-%d %Hz')
    last = tt.float2format(validTimes[-1], '%Y-%m-%d %Hz')
    print(f' All paths are found for {first} to {last}, NT={len(validTimes)}.')
    return True


def getDesPath(year, debug):
    """Return the merged-output netCDF path for ``year``.

    Debug runs write to a separate ``debug_``-prefixed file so they never
    overwrite the production output.
    """
    mergedDir = '/data/public/HYCOM/leecw/data/merged'
    return (
        f'{mergedDir}/debug_ssh_{year}.nc' if debug
        else f'{mergedDir}/ssh_{year}.nc'
    )


def createDesFile(desPath, validTimes):
    """Create the output file with a NaN-filled SSH variable and coordinates.

    Parameters
    ----------
    desPath : str
        Output netCDF path, written via nct.save.
    validTimes : array-like of float
        Stored as the time coordinate; its length sets the time dimension.
    """
    outputShape = (len(validTimes), len(LAT_OUT), len(LON_OUT))

    print(f' creating output shape = {outputShape}...')
    # np.full allocates the array once; np.nan * np.ones(...) would allocate
    # it twice (the ones, then the product). On the 1440x721 grid a full year
    # (~1464 steps) is roughly 12 GB of float64 per copy.
    # NOTE(review): nct.save still needs the whole array in memory; creating
    # the variable without materialising it would avoid that if nct allows it.
    nct.save(desPath, {
        VARNAME_OUT: np.full(outputShape, np.nan),
        TIMENAME_OUT: validTimes,
        LATNAME_OUT: LAT_OUT,
        LONNAME_OUT: LON_OUT,
    })


def readAndWriteRecord(variableHandle, itime, validTime):
    """Regrid one source SSH record onto the output grid and write it.

    Parameters
    ----------
    variableHandle : netCDF4.Variable
        Output variable, created with shape (time, lat, lon) by createDesFile.
    itime : int
        Index along the output time dimension to fill.
    validTime : float
        Valid time whose source file is read (see getSrcPath/valid2initLead).
    """
    # read
    src = getSrcPath(*valid2initLead(validTime))
    lon_in = nct.read(src, LONNAME_IN)
    lat_in = nct.read(src, LATNAME_IN)
    data = nct.read(src, VARNAME_IN)
    # data is indexed [:, :, lon] below, so it is 3-D; the leading axis is
    # presumably a length-1 time axis -- TODO confirm.

    # missing values: magnitudes above 1e4 are treated as fill values
    data[np.abs(data) > 1e4] = np.nan

    # lon = -180 ~ 180 -> 0 ~ 360: rotate so the first non-negative longitude
    # comes first and shift the negative part by +360 (assumes lon_in is
    # ascending). If no longitude is non-negative, ind = len(lon_in) shifts
    # every point instead of duplicating the grid.
    ind = next((i for i, lon in enumerate(lon_in) if lon >= 0), len(lon_in))
    lon_in = np.concatenate((lon_in[ind:], lon_in[:ind] + 360), axis=-1)
    data = np.concatenate((data[:, :, ind:], data[:, :, :ind]), axis=-1)

    # LAT_OUT is decreasing. Interpolate onto its reversed (increasing) copy,
    # since interp_1d is fed increasing targets here, then flip the result
    # back to LAT_OUT order. lat_in is presumably ascending too -- confirm.
    data = interp_1d(lat_in, data, LAT_OUT[::-1], axis=-2, extrapolate=True)
    data = interp_1d(lon_in, data, LON_OUT, axis=-1)
    data = data[:, ::-1, :]

    # write
    variableHandle[itime, :] = data


def getSrcPath(initTime, lead):
    """Return the source SSH file path for an initialisation time and lead.

    A file under the ``patched`` tree takes precedence. Otherwise the path
    under the ``source`` tree is returned, whether or not that file exists.

    Parameters
    ----------
    initTime : float
        Initialisation time, formatted into the directory and file name.
    lead : int
        Forecast lead in hours, zero-padded to three digits in the name.
    """
    fileName = f'%Y/ssh/%m/hycom_ssh_%Y%m%d%H_t{lead:03d}.nc'

    patchedPath = tt.float2format(
        initTime, '/data/public/HYCOM/leecw/data/patched/' + fileName
    )
    if os.path.exists(patchedPath):
        return patchedPath

    return tt.float2format(
        initTime, '/data/public/HYCOM/leecw/data/source/' + fileName
    )

def valid2initLead(validTime):
    """Map a valid time to the most recent 12z initialisation and its lead.

    Parameters
    ----------
    validTime : float
        Valid time in days; the hour is obtained via tt.hour. Presumably a
        whole hour -- minutes would not be removed by the day arithmetic.

    Returns
    -------
    initTime : float
        12z of the valid day if the valid hour is >= 12, otherwise 12z of
        the previous day.
    lead : int
        Hours from initTime to validTime, rounded to the nearest hour.
    """
    validHour = tt.hour(validTime)
    validDate = validTime - validHour/24  # 00z of the valid day

    # find the previous 12z
    if validHour >= 12:
        initTime = validDate + 0.5  # today 12z
    else:
        initTime = (validDate-1) + 0.5  # yesterday 12z

    lead = (validTime-initTime) * 24

    # Round rather than truncate: float error such as 17.9999... must still
    # give lead 18, or the wrong source file name is built.
    return initTime, int(round(lead))


# Module-level configuration. It is defined outside the __main__ guard so the
# functions above also work when this module is imported, not only when it
# is run as a script.

# Variable/coordinate names in the source files.
VARNAME_IN = 'surf_el'
LONNAME_IN = 'lon'
LATNAME_IN = 'lat'

# Variable/coordinate names and target grid of the merged output.
VARNAME_OUT = 'ssh'
LONNAME_OUT = 'lon'
LATNAME_OUT = 'lat'
TIMENAME_OUT = 'time'
LON_OUT = np.r_[0:360:0.25]       # 0 .. 359.75, increasing, 1440 points
LAT_OUT = np.r_[90:-90.25:-0.25]  # 90 .. -90, decreasing, 721 points


if __name__ == '__main__':
    main()
