#!/home/CWA_MDPS/.conda/envs/rd/bin/python
import pytools as pyt
import shared as st
import numpy as np
from tqdm import tqdm
import os

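#
# Progress log of the QC-statistics runs: one cell per output file
# ../data/qcstat/{varName}/{year}.nc produced by run().
# "O": finished; ".": running; blank: not done yet
#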
#       ssu     ssv     ssh     MLD     D20     D15
# 2001
# 2002
# 2003
# 2004
# 2005  O       O       O       O       O       O
# 2006  O       O       O       O       O       O
# 2007  O       O       O       O       O       O      
# 2008  O       O       O       O       O       O         
# 2009  O       O       O       O       O       O         
# 2010  O       O       O       O       O       O         
# 2011  O       O       O       O       O       O
# 2012  O       O       O       O       O       O
# 2013  O       O       O       O       O       O
# 2014  O       O       O       O       O       O
# 2015  O       O       O       O       O       O
# 2016  O       O       O       O       O       O
# 2017                                           
# 2018                                           
# 2019
# 2020


def main():
    """Script entry point: compute and save QC statistics for one year/variable.

    For batch processing, loop over years and variable names (e.g. 'ssh',
    'ssu', 'ssv', 'D20', 'D15', 'MLD') and call ``run`` for each pair; pass
    ``forceUpdate=True`` to regenerate files that already exist.
    """
    run(year=2015, varName='sst', debug=False, forceUpdate=False)


class Global:
    """Shared state for the QC statistics: the land/sea mask and derived counts."""

    # Default location of the HYCOM land/sea mask (relative to the working dir).
    LAND_SEA_MASK_PATH = '../data/land_sea_mask_hycom.nc'

    def __init__(self, landSeaMaskPath=LAND_SEA_MASK_PATH):
        """Load the land/sea mask and precompute the cell counts.

        Args:
            landSeaMaskPath: netCDF file holding the 'isLand' variable;
                defaults to ``LAND_SEA_MASK_PATH``.
        """
        # Number of statistics per time step produced by readMeanStd:
        # mean, std, land_nans_ratio, sea_nans_ratio.
        self.numStats = 4
        self.landSeaMaskPath = landSeaMaskPath

        self.isLand = self.read_land_sea_mask()
        # Complement of the land mask; assumes isLand holds 0/1 (or bool)
        # values -- NOTE(review): confirm against the mask file.
        self.isSea = 1 - self.isLand
        # Cell counts used as denominators of the NaN ratios.
        self.nIsLand = np.sum(self.isLand)
        self.nIsSea = np.sum(self.isSea)

    def read_land_sea_mask(self):
        """Read the 'isLand' variable from the configured mask file."""
        # getattr fallback keeps this usable on instances built without __init__.
        path = getattr(self, 'landSeaMaskPath', self.LAND_SEA_MASK_PATH)
        return pyt.nct.read(path, 'isLand')


def run(year, varName, debug, forceUpdate):
    """Compute per-time-step QC statistics of one variable for one year and save them.

    Dates step by 0.25 (presumably 6-hourly, assuming ``pyt.tt.ymd2float``
    works in days) from 1 Jan up to the last day + 0.75, inclusive. The last
    day is 31 Dec normally and 1 Jan in debug mode. For each date,
    ``readMeanStd`` supplies the statistics. They are written to
    ``../data/qcstat/{varName}/{year}.nc`` (``debug_{year}.nc`` in debug mode).

    Args:
        year: calendar year to process.
        varName: forecast variable name passed to ``readMeanStd``.
        debug: if True, process only the first day and write a debug file.
        forceUpdate: if False and the output file already exists, do nothing.

    Raises:
        ValueError: if the stat names below disagree with ``glb.numStats``.
    """
    print(f'{year=}, {varName=}, {debug=}, {forceUpdate=}')

    # Output variable names, in the order readMeanStd returns its values.
    statNames = ('mean', 'std', 'land_nans_ratio', 'sea_nans_ratio')
    if len(statNames) != glb.numStats:
        # zip() below would silently drop columns on a mismatch.
        raise ValueError(
            f'expected {glb.numStats} stat names, got {len(statNames)}'
        )

    dateStart = pyt.tt.ymd2float(year, 1, 1)
    if debug:
        path = f'../data/qcstat/{varName}/debug_{year}.nc'
        dateEnd = dateStart + 0.75
    else:
        path = f'../data/qcstat/{varName}/{year}.nc'
        dateEnd = pyt.tt.ymd2float(year, 12, 31) + 0.75

    pyt.ft.check_des_path(path)

    # Skip outputs that already exist unless a rerun is forced.
    if not forceUpdate and os.path.exists(path):
        print(f'path already exists: {path}')
        return

    # The slice stop is exclusive, so extend it by one step to include dateEnd.
    dates = np.r_[dateStart:dateEnd + 0.25:0.25]

    # Rows: time steps; columns: statistics. NaN until filled.
    stats = np.full((len(dates), glb.numStats), np.nan)
    for iDate, date in enumerate(tqdm(dates)):
        stats[iDate, :] = readMeanStd(varName, date)

    # One netCDF variable per statistic, each stored with the time axis.
    # The loop variable must not be named varName (it would shadow the parameter).
    for statName, series in zip(statNames, stats.T):
        pyt.nct.save(path, {statName: series, 'time': dates}, overwrite=True)

    print('')
    print(f'saved: {path}')


def readMeanStd(varName, date):
    """Summarise one forecast field valid at ``date``.

    The field is located through ``st.Forecast`` using the init/lead pair from
    ``st.validTime2InitLead(date)``. It is read and passed through
    ``st.maskInvalidValues``.

    Returns:
        ``(mean, std, land_nans_ratio, sea_nans_ratio)`` where the ratios are
        the fraction of NaN cells over land (expected near 1) and over sea
        (expected near 0). If the source file is missing, a tuple of
        ``glb.numStats`` NaNs is returned instead.
    """
    initLead = st.validTime2InitLead(date)
    forecast = st.Forecast(varName, *initLead)
    sourcePath = forecast.getPath('source')

    if os.path.exists(sourcePath):
        raw = forecast.read()
        field = st.maskInvalidValues(raw, varName, glb.isLand)

        fieldMean = float(np.nanmean(field))
        fieldStd = float(np.nanstd(field))

        nanCells = np.isnan(field)
        landNanFraction = np.sum(nanCells * glb.isLand) / glb.nIsLand
        seaNanFraction = np.sum(nanCells * glb.isSea) / glb.nIsSea

        return fieldMean, fieldStd, landNanFraction, seaNanFraction

    print(f'path not found {sourcePath}')
    return tuple([np.nan] * glb.numStats)


if __name__ == '__main__':
    # glb must be bound at module level before main() runs: run() and
    # readMeanStd() read it as a global. Importing this module without
    # running it as a script leaves glb undefined.
    glb = Global()
    main()
