#!/usr/bin/env python
import pytools as pyt
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import matplotlib
import numpy as np
from dataclasses import dataclass
import os
matplotlib.use("Agg")


class Global:
    """Run-wide settings shared by the plotting functions in this script.

    Args:
        figPrefix: Root directory under which figures are written. Defaults
            to the path this script has always used, so ``Global()`` keeps
            its previous behaviour.
    """

    def __init__(self, figPrefix:str = '/home/CWA_MDPS/hycom/figs') -> None:
        # Root directory of every figure path built by the plot functions.
        self.figPrefix = figPrefix


def main():
    """Render the monthly-mean figure for every variable/year combination.

    Calls ``plot_map_monthly`` for each variable in a fixed list and each
    year from 2007 to 2015 (inclusive). An exception raised for one
    combination is printed and the loop moves on to the next one, so a single
    failure does not abort the whole batch.
    """
    varNames = ['ssh', 'ssu', 'ssv', 'MLD', 'D20', 'D15']
    years = range(2007, 2015+1)

    for varName in varNames:
        for year in years:
            try:
                plot_map_monthly(varName, year)
            except Exception as e:
                # Best-effort batch run: report the failure and keep going.
                print(repr(e))


def plot_map_monthly(varName:str, year:int) -> None:
    """Plot every monthly-mean map of one variable for one year into one PNG.

    The figure is written to ``<glb.figPrefix>/<varName>_monmean_<year>.png``.
    If that file already exists nothing is read or drawn. Otherwise the
    variable is read from the ``monmean`` source file, one panel is drawn per
    entry along the first axis of the data on a 4x3 grid, and every panel is
    labelled with ``mean, std, min, max, number-of-NaNs`` of its slice.

    Relies on the module-level ``glb`` and ``fp`` objects created in the
    ``__main__`` section of this script.

    Args:
        varName: Name of the variable to read and plot.
        year: Year whose monthly-mean file is plotted.
    """
    figName = f'{glb.figPrefix}/{varName}_monmean_{year}.png'
    if os.path.exists(figName):
        fp.print(f'fig already exists: {figName}')
        return
    path = get_src_path(varName, pyt.tt.ymd2float(year, 1, 1), '%Y', 'monmean')

    fp.flush(f'reading {path}..')
    data = pyt.nct.read(path, varName)
    # Reverse the second-to-last axis so the rows follow the ascending `lat`
    # built below. NOTE(review): presumably the file stores latitude from
    # north to south -- confirm.
    data = np.flip(data, axis=-2)

    fp.flush('plotting..')
    # Hard-coded coordinates: 0.25-degree steps, lon 0..359.75, lat -90..90.
    # NOTE(review): assumes the data are on exactly this grid -- confirm.
    lon = np.r_[0:360:0.25]
    lat = np.r_[-90:90+0.25:0.25]

    npy, npx = 4, 3

    # One colour range for all panels, taken from the 15th/85th percentiles of
    # the whole data set. NaN-aware so that missing values do not turn both
    # limits into NaN.
    vmin, vmax = np.nanpercentile(data, [15, 85])

    fig = plt.figure(layout='constrained')
    try:
        for it, z in enumerate(data):
            ax = fig.add_subplot(npy, npx, it+1)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.pcolor(lon, lat, z, vmin=vmin, vmax=vmax)
            pyt.pt.plotcoast(ax, color='tab:red')

            # Per-panel statistics; all NaN-aware, the NaN count is shown last.
            zmean = np.nanmean(z)
            zstd = np.nanstd(z)
            zmin = np.nanmin(z)
            zmax = np.nanmax(z)
            nnans = np.sum(np.isnan(z))

            pyt.pt.titleCorner(
                ax,
                f'{zmean:.2f}, {zstd:.2f}, {zmin:.2f}, {zmax:.2f}, {nnans}',
                fontsize=7
            )

        # Make sure the target directory exists before writing the file.
        os.makedirs(os.path.dirname(figName), exist_ok=True)
        fig.savefig(figName, dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called many times in one run, so release the figure explicitly.
        plt.close(fig)
    fp.print(f'fig saved to {figName}')



def plot_map_6hrly(varName:str, day:float) -> None:
    """Plot four consecutive time steps of one variable into one PNG.

    Data are requested from ``read_data`` for the range ``[day, day + 0.75]``
    and the first four entries along the first axis are drawn, one panel
    each, with a colour bar and a title showing ``mean, std, min, max`` of the
    panel. Values that are exactly zero are replaced by NaN before plotting.
    The figure is written to
    ``<glb.figPrefix>/6hrly/<varName>/<%Y>/<%y%m%d>.png``; the directory is
    created when missing.

    Relies on the module-level ``glb`` and ``fp`` objects created in the
    ``__main__`` section of this script.

    NOTE(review): ``read_data`` is not defined in the visible part of this
    file (only a commented-out ``read_data_raw`` is) -- confirm that it is
    provided elsewhere before calling this function.

    Args:
        varName: Name of the variable to read and plot.
        day: Start of the plotted range, in the float date representation
            used by ``pyt.tt`` (presumably days; 0.75 would then span the
            four 6-hourly steps of one day -- confirm).
    """
    npanels = 4
    dateRange = [day, day+0.75]

    # ---- begins
    data, dims = read_data(varName, dateRange)
    # Exact zeros are treated as missing values and excluded from the plot
    # and from the statistics.
    data[(data==0)] = np.nan

    # Near-square panel layout (2x2 for four panels).
    nrows = int(np.ceil(np.sqrt(npanels)))
    ncols = int(np.ceil(npanels / nrows))

    x = dims[-1]
    y = dims[-2]
    fig = plt.figure(layout='constrained')
    try:
        for iax in range(npanels):
            z = data[iax, :]
            ax = fig.add_subplot(nrows, ncols, iax+1)
            hp = ax.pcolor(x, y, z)
            pyt.pt.plotcoast(ax, color='k')
            fig.colorbar(hp, ax=ax)
            glb_mean = np.nanmean(z)
            glb_std = np.nanstd(z)
            glb_max = np.nanmax(z)
            glb_min = np.nanmin(z)
            ax.set_title(f'{glb_mean:.2f}, {glb_std:.2f}, {glb_min:.2f}, {glb_max:.2f}', fontsize=10)
            # The x label carries the time of the panel, taken from the first
            # dimension array.
            ax.set_xlabel(f'{pyt.tt.hour(float(dims[0][iax]))}')

        dayLabel = pyt.tt.float2format(day, '%Y-%m-%d')
        fig.suptitle(f'{varName}, {dayLabel}')

        figName = pyt.tt.float2format(day, f'{glb.figPrefix}/6hrly/{varName}/%Y/%y%m%d.png')
        # Create the directory without going through a shell command string.
        os.makedirs(os.path.dirname(figName), exist_ok=True)
        fig.savefig(figName)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    # Report success only after the file has actually been written.
    fp.print(f'fig saved to {figName}')


# def read_data_raw(varName:str, timeRange:list[float]) -> tuple[np.ndarray, list[np.ndarray]]:
#     if len(timeRange) != 2:
#         raise ValueError(f'expected 2 elements in timeRange but received {len(timeRange)}')
#
#     if pyt.tt.year(timeRange[0]) != pyt.tt.year(timeRange[1]):
#         raise ValueError(f'not supported for reading multiple years')
#
#     path = get_src_path(varName, pyt.tt.year(timeRange[0]))
#     minMaxs = [timeRange, [None, None], [None, None]]
#     fp.flush(f'reading {path=}, time={[pyt.tt.float2format(d, '%y-%m-%d_%H') for d in timeRange]}')
#     data, dims = pyt.nct.ncreadByDimRange(path, varName, minMaxs)
#     return data, dims


def get_src_path(varName:str, date:float, dateFormat:str, freq:str) -> str:
    """Return the source file path for a variable, date and frequency.

    The template ``<root>/<freq>/<varName>_<dateFormat>_<freq>.nc`` is
    assembled first and then handed, as a whole, to ``pyt.tt.float2format``
    together with ``date``; the result of that call is returned.

    Args:
        varName: Variable name used as the file-name prefix.
        date: Date passed to ``pyt.tt.float2format``.
        dateFormat: Date pattern placed in the middle of the file name
            (callers in this file pass ``'%Y'``).
        freq: Frequency tag; used both as sub-directory and file-name suffix
            (callers in this file pass ``'monmean'``).
    """
    template = '/data/public/HYCOM/leecw/data/merged_interp/{0}/{1}_{2}_{0}.nc'.format(
        freq, varName, dateFormat
    )
    return pyt.tt.float2format(date, template)


def get_des_path(varName:str, suffix:str) -> str:
    """Placeholder for a destination-path builder; not implemented yet.

    Both arguments are currently ignored and ``None`` is returned, despite
    the ``str`` return annotation.
    """
    return None


def test_read_data(varName:str, timeRange:list[float]):
    """Smoke-test ``read_data`` for one variable and time range.

    Calls ``read_data(varName, timeRange)``, prints the shape of the returned
    data and then a "test passed" line in which the entries of ``timeRange``
    are formatted with ``pyt.tt.float2format`` and joined by ``-``. Uses the
    module-level ``fp`` printer.
    """
    fp.flush('test reading..')
    data, _dims = read_data(varName, timeRange)
    fp.print(f'{data.shape = }')

    span = '-'.join(pyt.tt.float2format(t) for t in timeRange)
    fp.print(f'test passed for {varName=}, timeRange={span}')


if __name__ == '__main__':
    # `glb` and `fp` are read as module-level globals by the functions in
    # this file, so both must exist before main() is called.
    glb = Global()
    fp = pyt.tmt.FlushPrinter()
    main()
