#!/home/CWA_MDPS/.conda/envs/rd/bin/python
import shared as st
import pytools as pyt
import numpy as np
import os

#       ssu     ssv     ssh     MLD     D20     D15
# 2001
# 2002
# 2003
# 2004
# 2005  O       O       O       O       O       O
# 2006  O       O       O       O       O       O
# 2007  O       O       O       O       O       O
# 2008  O       O       O       O       O       O
# 2009  O       O       O       O       O       O
# 2010  O       O       O       O       O       O
# 2011  O       O       O       O       O       O
# 2012  O       O       O       O       O       O
# 2013  O       O       O       O       O       O
# 2014  O       O       O       O       O       O
# 2015  O       O       O       O       O       O
# 2016  O       O       O       O       O       O   
# 2017                                              
# 2018
# 2019
# 2020


def main():
    """Script entry point: process every configured variable and year."""
    run_all()
    # Alternative entry point for stepping through a single variable:
    # var_by_var()


def run_all():
    """Call run() with interpolation enabled for every variable and year.

    The years to process are hard-coded here; edit ``years`` to change them.
    """
    # years = [2007]
    years = [2005, 2006]
    doInterpolate = True
    debug_meanRange = False
    varNames = ('ssu', 'ssv', 'ssh', 'MLD', 'D20', 'D15')

    # Variable-major order: every year of one variable before the next one.
    jobs = [(name, yr) for name in varNames for yr in years]
    for name, yr in jobs:
        run(yr, name, doInterpolate, debug_meanRange)


def var_by_var():
    """Step through the QC workflow for a single variable.

    ``istep`` selects which stage is executed:

    * 1 -- run one year with the debug mean range and no interpolation
      (every date should be flagged),
    * 2 -- run one year with the real ranges and no interpolation,
    * 3 -- run all listed years with interpolation enabled.

    Any other value of ``istep`` does nothing.
    """
    varName = 'ssh'
    istep = 2

    year = 2017
    years = [2005, 2006, 2017]

    # istep -> (doInterpolate, debug_meanRange, years to process)
    stages = {
        1: (False, True, [year]),   # should all fail
        2: (False, False, [year]),
        3: (True, False, years),
    }

    stage = stages.get(istep)
    if stage is None:
        return

    doInterpolate, debug_meanRange, targetYears = stage
    for targetYear in targetYears:
        run(targetYear, varName, doInterpolate, debug_meanRange)


def run(year, varName, doInterpolate, debug_meanRange):
    """Detect the corrupted dates of one variable/year and optionally repair them.

    The flagged dates are always written to
    ``../data/qcstat/{varName}/corrupted_dates_{year}.nc``. When
    ``doInterpolate`` is true, every flagged date is then rebuilt with
    ``st.interpolate`` from the nearest valid date on each side.

    Parameters
    ----------
    year : year whose QC statistics are inspected.
    varName : one of 'ssu', 'ssv', 'ssh', 'MLD', 'D20', 'D15'.
    doInterpolate : if false, stop after saving the list of flagged dates.
    debug_meanRange : if true, replace the accepted mean range by
        (0.04, 0.06) regardless of the variable.

    Raises
    ------
    ValueError
        If ``varName`` is not one of the known variables.
    """
    print(f'{year=}, {varName=}, {doInterpolate=}')
    maxDelta = 15
    overwrite = True

    # (variable, accepted range of the mean, accepted range of the std)
    qcRanges = (
        ('ssu', (-0.01, 0.08), (0.17, 0.25)),
        ('ssv', (-0.01, 0.08), (0.14, 0.20)),
        ('ssh', (-0.43, -0.35), (0.69, 0.75)),
        ('MLD', (60, 112), (55, 135)),
        ('D20', (25, 35), (55, 65)),
        ('D15', (64, 71), (109, 119)),
    )
    for knownName, knownMeanRange, knownStdRange in qcRanges:
        if varName == knownName:
            meanRange, stdRange = knownMeanRange, knownStdRange
            break
    else:
        raise ValueError(f'Unrecognized {varName=}')

    if debug_meanRange:
        meanRange = (0.04, 0.06)

    #
    # ---- find corrupted dates
    corruptedDates = findCorruptedDates(year, varName, meanRange, stdRange)

    desPath = f'../data/qcstat/{varName}/corrupted_dates_{year}.nc'
    if overwrite and os.path.exists(desPath):
        os.remove(desPath)

    pyt.nct.save(desPath, {'date': corruptedDates}, overwrite=True)

    if not doInterpolate:
        print(f'{doInterpolate=}: not going to do the interpolation.')
        return

    #
    # ---- interpolation
    for rawDate in corruptedDates:
        target = float(rawDate)
        # Nearest valid date before, then after, the corrupted one; the
        # search moves in steps of 0.25 and gives up after maxDelta steps.
        neighbors = [
            findNeighborValidDate(varName, target, corruptedDates, step, maxDelta)
            for step in (-0.25, 0.25)
        ]
        srcs = [
            st.Forecast(varName, *st.validTime2InitLead(neighbor))
            for neighbor in neighbors
        ]
        des = st.Forecast(varName, *st.validTime2InitLead(target))
        st.interpolate(varName, des, srcs, maxDelta * 6)


def findCorruptedDates(year, varName, meanRange, stdRange):
    """Flag the time steps of one year whose QC statistics look corrupted.

    Reads the per-time-step series 'mean', 'std', 'time', 'land_nans_ratio'
    and 'sea_nans_ratio' from ``../data/qcstat/{varName}/{year}.nc`` and
    flags a time step when at least one of the following holds:

    * its mean lies outside ``meanRange`` (both bounds count as valid),
    * its std is not strictly inside ``stdRange`` (a value equal to either
      bound is flagged),
    * its mean or std is NaN,
    * its land NaN ratio is not greater than 0.97,
    * its sea NaN ratio is not smaller than the allowed maximum (0.01 by
      default, relaxed for some variable/year combinations below).

    A diagnostic report is printed along the way.

    Parameters
    ----------
    year : year of the QC statistics file to read.
    varName : variable name; selects the file and the sea-NaN threshold.
    meanRange : (low, high) accepted range of the mean.
    stdRange : (low, high) accepted range of the std.

    Returns
    -------
    The entries of the 'time' series that were flagged.
    """
    qcpath = f'../data/qcstat/{varName}/{year}.nc'
    means = pyt.nct.read(qcpath, 'mean')
    stds = pyt.nct.read(qcpath, 'std')
    dates = pyt.nct.read(qcpath, 'time')
    land_nans_ratio = pyt.nct.read(qcpath, 'land_nans_ratio')
    sea_nans_ratio = pyt.nct.read(qcpath, 'sea_nans_ratio')

    max_sea_nans_ratio = 0.01
    min_land_nans_ratio = 0.97

    # Variable/year combinations for which more NaNs over sea are tolerated.
    if varName in ['MLD', 'D20', 'D15'] and year <= 2010:
        max_sea_nans_ratio = 0.08

    elif varName in ['D15', 'D20'] and year in [2016]:
        max_sea_nans_ratio = 0.03

    elif varName in ['MLD'] and year in [2016]:
        max_sea_nans_ratio = 0.08

    meanmean = np.nanmean(means)
    stdmean = np.nanstd(means)
    meanstd = np.nanmean(stds)
    stdstd = np.nanstd(stds)

    # Distance of each (mean, std) pair from the yearly average, measured in
    # standard deviations of the respective series. Only used for reporting.
    deviations = np.sqrt(
        ((means - meanmean) / stdmean)**2
        + ((stds - meanstd) / stdstd)**2
    )

    corrupted_m = (~isBetween(means, meanRange))
    corrupted_s = ((stds <= stdRange[0]) | (stdRange[1] <= stds))
    corrupted_n = (np.isnan(means) | np.isnan(stds))
    # Written as "not (ratio passes)" so that a NaN ratio is flagged too.
    corrupted_ln = (~(land_nans_ratio > min_land_nans_ratio))
    corrupted_sn = (~(sea_nans_ratio < max_sea_nans_ratio))

    isCorrupted = (
        corrupted_m | corrupted_s | corrupted_n
        | corrupted_ln | corrupted_sn
    )
    notCorrupted = ~isCorrupted

    #
    # ---- report 1: every flagged time step, '*' marks the failing criteria
    # NOTE: literals inside the replacement fields use double quotes so that
    # these f-strings also parse on interpreters older than Python 3.12.
    print(f'{"date":>11s} {"dev.":>8s} {"mean":>8s} {"std":>8s} {"l_nans":>8s} {"s_nans":>8s}')
    endChars = {True: '*', False: ' '}
    rows = zip(
        means,
        stds,
        land_nans_ratio,
        sea_nans_ratio,
        dates,
        deviations,
        isCorrupted,
        corrupted_m,
        corrupted_s,
        corrupted_ln,
        corrupted_sn,
    )
    for (mean, std, l_nans, s_nans, date, deviation,
         c, c_m, c_s, c_ln, c_sn) in rows:
        if not c:
            continue
        print(f'{pyt.tt.float2format(date, "%y-%m-%d_%H")}', end=' ')
        print(f'{deviation:+8.2f}', end=' ')

        print(f'{mean:+8.2f}', end=endChars[c_m])
        print(f'{std:+8.2f}', end=endChars[c_s])
        print(f'{l_nans:8.3f}', end=endChars[c_ln])
        print(f'{s_nans:8.3f}', end=endChars[c_sn])

        # Flagged only by the mean/std ranges, yet close to the yearly
        # average: worth a manual look.
        if (deviation < 3) and (not c_sn) and (not c_ln):
            print(' MAY NOT BE CORRUPTED', end='')
        print('')

    #
    # ---- report 2: time steps flagged by the mean/std ranges only (no NaN
    # statistics, no land/sea mask problem), in order of increasing deviation
    ind = np.argsort(deviations)
    print('----')

    for i in ind:
        if corrupted_n[i]:
            continue
        if not isCorrupted[i]:
            continue
        if corrupted_ln[i] or corrupted_sn[i]:
            continue
        print(f'{pyt.tt.float2format(dates[i], "%y-%m-%d_%H")}', end=' ')
        print(f'{deviations[i]:8.2f}', end='')
        print(f'{means[i]:+8.2f}', end=' ')
        print(f'{stds[i]:+8.2f}', end=' ')
        print('')

    #
    # ---- summary: accepted ranges next to the observed extremes
    print('')
    print(f'valid {meanRange=}')
    print(f'mean: (min, max) = {np.nanmin(means):5f}, {np.nanmax(means):5f}')
    print(f'valid {stdRange=}')
    print(f'std: (min, max) = {np.nanmin(stds):5f}, {np.nanmax(stds):5f}')
    print(f'land_nans_ratio: (min, max) = {np.nanmin(land_nans_ratio):5f}, {np.nanmax(land_nans_ratio):5f}')
    print(f'sea_nans_ratio: (min, max) = {np.nanmin(sea_nans_ratio):5f}, {np.nanmax(sea_nans_ratio):5f}')

    # The largest deviation among accepted time steps and the smallest one
    # among rejected time steps are printed so they can be compared by eye.
    print('')
    if sum(notCorrupted):
        i = np.argmax(deviations[notCorrupted])
        worstDate = pyt.tt.float2format(dates[notCorrupted][i], "%y-%m-%d_%H")
        worstDeviation = deviations[notCorrupted][i]
        print(f'max(dev(not_corrupted)) @ {worstDate}:{worstDeviation:+6.2f} ←↓ should be separated')
    else:
        print('all corrupted :(((')

    c_not_due_to_mask = (
        (corrupted_m | corrupted_s) & (~corrupted_ln) & (~corrupted_sn)
    )
    if sum(c_not_due_to_mask):
        print(f'min(dev(corrutped_not_due_to_mask))={np.min(deviations[c_not_due_to_mask]):+8.2f} <-should be large')
    else:
        print(f'all the corrupted are with a land-sea mask problem')
    print('')

    print(f'num(corrupted) = {sum(isCorrupted)}')
    return dates[isCorrupted]


def isBetween(values, rng):
    """Return whether ``values`` lie within ``rng``, bounds included.

    ``rng`` is indexed as (low, high). The comparison is element-wise for
    arrays and combined with ``&``, so the result has the shape of
    ``values``; NaN entries compare as outside the range.
    """
    atLeastLow = rng[0] <= values
    atMostHigh = values <= rng[1]
    return atLeastLow & atMostHigh


def findNeighborValidDate(varName, date, invalid, shift, maxNShift):
    """Find the closest usable date on one side of ``date``.

    Starting from ``date``, move by ``shift`` at a time (a negative shift
    searches backwards) for at most ``maxNShift`` steps, and return the
    first candidate that is not listed in ``invalid`` and whose 'source'
    file, as given by ``st.getPath``, exists on disk.

    Parameters
    ----------
    varName : variable name, forwarded to ``st.getPath``.
    date : starting date; converted with ``float``.
    invalid : collection of dates that must not be returned.
    shift : signed step added to the candidate on every iteration.
    maxNShift : maximum number of steps to try.

    Returns
    -------
    The first acceptable candidate date, as a float.

    Raises
    ------
    RuntimeError
        If no acceptable candidate is found within ``maxNShift`` steps.
    """
    neighbor = float(date)
    nShift = 0
    while nShift < maxNShift:
        nShift += 1
        neighbor += shift
        path = st.getPath(varName, *st.validTime2InitLead(neighbor), 'source')
        if neighbor not in invalid and os.path.exists(path):
            return neighbor
    raise RuntimeError(f'{varName=}: cannot find a valid neighbor for {pyt.tt.float2format(date, "%y/%m/%d_%H")} with {maxNShift=}')


# Run the workflow only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

