#!/home/CWA_MDPS/.conda/envs/rd/bin/python
import pytools as pyt
import shared as st
from tqdm import tqdm
import numpy as np


def main():
    """Re-grid the forecasts of the corrupted dates and patch the merged file.

    For every date listed in ``corrupted_dates_{year}.nc`` the 'patched'
    forecast file is read, values with magnitude above 1e4 are set to NaN,
    the longitude axis is rotated from -180~180 to 0~360, the field is
    interpolated onto the lat/lon grid read from the merged yearly file, and
    the matching time slice of that merged file is overwritten.
    """
    year = 2014
    varName, varName2 = 'ssh', 'surf_el'

    #
    # ---- auto
    desPath = f'../data/merged/{varName}_{year}.nc'
    LON_OUT = pyt.nct.read(desPath, 'lon')
    LAT_OUT = pyt.nct.read(desPath, 'lat')

    # interp_1d is fed the reversed target latitudes (presumably it needs an
    # increasing target axis -- confirm); loop-invariant, so built only once.
    latTarget = LAT_OUT[::-1]

    #
    # ---- read corrupted dates
    dates = pyt.nct.read(f'../data/qcstat/ssh/corrupted_dates_{year}.nc', 'date')
    date0 = pyt.tt.ymd2float(year, 1, 1)
    for date in tqdm(dates):
        date = float(date)
        # index of this date along the time axis of the merged file
        # (assumes dates are in days and 4 steps per day -- TODO confirm).
        # NOTE(review): int() truncates; if a date can sit slightly below a
        # grid point because of float rounding, round() would be safer.
        iDate = int(4 * (date - date0))
        srcPath = st.Forecast(varName, *st.validTime2InitLead(date)).getPath('patched')
        data, dims = pyt.nct.ncreadByDimRange(srcPath, varName2, [[None]*2]*3)

        # missing values
        data[(np.abs(data) > 1e4)] = np.nan

        # lon = -180 ~ 180 -> 0 ~ 360
        lon = dims[-1]
        # First non-negative longitude marks the split point.  When there is
        # none, split at the end so that *every* point is shifted by 360; a
        # None index here would make both slices the full array and silently
        # duplicate the axis and the data.
        ind = next((i for i, x in enumerate(lon) if x >= 0), len(lon))
        dims[-1] = np.concatenate((lon[ind:], lon[:ind]+360), axis=-1)
        data = np.concatenate((data[:, :, ind:], data[:, :, :ind]), axis=-1)

        # interpolate onto the reversed target latitudes, then onto the
        # target longitudes
        data = pyt.ct.interp_1d(dims[-2], data, latTarget, axis=-2, extrapolate=True)
        data = pyt.ct.interp_1d(dims[-1], data, LON_OUT, axis=-1)
        data = data[:, ::-1, :]  # flip latitude back to the order of LAT_OUT

        # overwrite only the time slice belonging to this date
        pyt.nct.write(
            desPath, varName, data, [slice(iDate, iDate+1), slice(None), slice(None)]
        )


# Run the patch only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
