#!/usr/bin/env python3
"""
檢查指定年份內所有月份的 NetCDF 檔案中的陸地分佈吻合度，並自動修復有問題或遺失的檔案。

使用方法:
  在命令行中執行，並在後面加上想要檢查的年份。
  python check_traningData.py <年份>
  例如: python check_traningData.py 2013

程式功能:
1) 讀取一個專用的陸海遮罩檔案（in.nc）。
2) 遍歷指定年份的所有月份與日期。
3) 動態調整資料夾路徑以找到對應的檔案。
4) 檢查檔案是否存在，並驗證其內容是否與陸海遮罩吻合。
5) 檔案有問題或遺失時，腳本會根據狀態重新命名或創建標記檔案，然後嘗試用前後兩個最近的正常檔案進行平均，或複製單一正常檔案來創建一個新的檔案。
6) 印出報告，顯示不一致的數量與最多3個點的位置。
"""

from __future__ import annotations
import sys
import datetime
from pathlib import Path
import re
import numpy as np
import xarray as xr
import shutil

# ==================================
# === User Settings ===
# ==================================
# The date range and data directory are now derived from the year passed
# on the command line, so these no longer need to be edited by hand:
# START_DATE_STR, END_DATE_STR, DATA_DIR
MASK_FILE_PATH = "../in_s.nc"  # land/sea mask NetCDF, relative to the working directory
# ==================================
# === End of User Settings ===
# ==================================

# --- Internal fixed parameters (rarely need changing) ---
FILE_TEMPLATE = "hycom_sss_{base}_t{offset:03d}.nc"  # per-file name pattern
START_HOUR = 12        # model run hour baked into the base tag, e.g. "2013010112"
END_OFFSET = 18        # last forecast-hour offset checked for each day
STEP = 6               # offset spacing: files t000, t006, t012, t018
VAR_NAME = "salinity"  # NetCDF variable compared against the mask
SAMPLE_LIMIT = 3       # max number of mismatch samples printed per file
FILL_VALUE = -30000    # raw land fill value (files are opened with decode_cf=False)
# ==================================
# === End of Internal Settings ===
# ==================================


# 顏色碼
class bcolors:
    """ANSI escape sequences used to colour terminal output."""

    OKGREEN = "\033[92m"    # green: check passed / repair succeeded
    WARNING = "\033[93m"    # yellow: non-fatal problems
    FAIL = "\033[91m"       # red: errors and missing files
    ENDC = "\033[0m"        # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def build_path(directory: Path, base: str, offset: int, *, template: str | None = None) -> Path:
    """Return the path of one data file inside *directory*.

    Parameters
    ----------
    directory : Path
        Folder that should contain the file (e.g. ``2013/sss/01/``).
    base : str
        Date/hour tag such as ``"2013010112"``.
    offset : int
        Forecast-hour offset; zero-padded to three digits in the name.
    template : str | None, keyword-only
        File-name format string with ``{base}`` and ``{offset}`` fields.
        Defaults to the module-level ``FILE_TEMPLATE``.

    Notes
    -----
    The previous revision formatted the *directory* into the template
    string as well (``f"{directory}/{FILE_TEMPLATE}".format(...)``), so a
    directory whose name contained ``{`` or ``}`` would corrupt the
    ``str.format`` call. Joining with ``/`` on the Path avoids that.
    """
    name_template = FILE_TEMPLATE if template is None else template
    return directory / name_template.format(base=base, offset=offset)

def open_dataset(path: Path, engine: str | None, decode_cf: bool) -> xr.Dataset:
    """Open *path* as an :class:`xarray.Dataset`.

    ``decode_cf`` is forwarded unchanged; ``engine`` selects the backend
    and, when falsy, is omitted so xarray picks one automatically.
    """
    if engine:
        return xr.open_dataset(path, decode_cf=decode_cf, engine=engine)
    return xr.open_dataset(path, decode_cf=decode_cf)

def print_changes(mask_raw: np.ndarray, current_raw: np.ndarray, limit: int, label: str):
    """Report grid points that should be land but hold data.

    Every location where *mask_raw* equals ``FILL_VALUE`` must also be
    ``FILL_VALUE`` in *current_raw*. Prints a coloured summary for *label*
    plus up to *limit* sample coordinates, and returns the mismatch count.
    """
    land = mask_raw == FILL_VALUE
    bad = land & (current_raw != FILL_VALUE)
    n_bad = int(bad.sum())

    if n_bad == 0:
        print(f"{bcolors.OKGREEN}[正常]{bcolors.ENDC} {label}: All {FILL_VALUE} locations match perfectly.")
        return n_bad

    print(f"{bcolors.FAIL}[錯誤]{bcolors.ENDC} {label}:")
    print(f"  - {bcolors.WARNING}陸地上的數據不一致:{bcolors.ENDC} {n_bad} locations that should be {FILL_VALUE} have different values.")
    if limit > 0:
        for i, idx in enumerate(np.argwhere(bad)[:limit], 1):
            pos = tuple(idx)
            coords = ", ".join(f"dim{j}={idx[j]}" for j in range(len(idx)))
            print(f"    • sample {i}: ({coords}) - Mask value: {mask_raw[pos]}, Current value: {current_raw[pos]}")
    return n_bad

def _open_if_valid(path: Path, mask_raw_array: np.ndarray) -> xr.Dataset | None:
    """Open *path* and return the dataset only when it agrees with the mask.

    A file is accepted when its ``VAR_NAME`` data has the mask's shape and
    every mask land point (== ``FILL_VALUE``) is also ``FILL_VALUE`` in the
    data. Returns ``None`` (closing the dataset when one was opened) for
    missing, unreadable, or non-matching files.
    """
    if not path.exists():
        return None
    ds = None
    try:
        ds = open_dataset(path, None, False)
        data = ds[VAR_NAME].squeeze().to_numpy()
        on_land = mask_raw_array == FILL_VALUE
        if data.shape == mask_raw_array.shape and not np.any(on_land & (data != FILL_VALUE)):
            return ds
    except Exception:
        # Unreadable or malformed file: treat as unusable, fall through to close.
        # (The previous revision referenced an unbound `temp_ds` here and
        # could raise NameError when open_dataset itself failed.)
        pass
    if ds is not None:
        ds.close()
    return None


def _neighbor_path(target_dt: datetime.datetime, target_offset: int, step: int) -> Path:
    """Path of the series file *step* forecast hours away from the target.

    Rolls over to the previous/next day when the offset leaves the
    0..END_OFFSET range, mirroring the ``<year>/sss/<month>/`` layout.
    """
    dt, off = target_dt, target_offset + step
    if off < 0:
        dt -= datetime.timedelta(days=1)
        off = END_OFFSET
    elif off > END_OFFSET:
        dt += datetime.timedelta(days=1)
        off = 0
    base = f"{dt.strftime('%Y%m%d')}{START_HOUR:02d}"
    return build_path(Path(f"{dt.year}/sss/{dt.month:02d}/"), base, off)


def find_and_fix_file(target_dt: datetime.datetime, target_offset: int, directory: Path, mask_raw_array: np.ndarray, status: str) -> bool:
    """Repair one broken or missing series file.

    The broken file is first renamed to ``*.nc.error`` (status ``'error'``)
    or marked with an empty ``*.nc.miss`` file (status ``'miss'``). The two
    temporally adjacent files (±STEP hours, possibly on the neighbouring
    day) are then inspected: if both are valid their data is averaged into
    a new file; if only one is valid it is copied; otherwise the repair
    fails. Returns True when a replacement file was written.

    Fixes over the previous revision: no NameError on an unbound dataset
    variable when a neighbour fails to open; averaging is done in float64
    so summing two large raw integer values cannot overflow the storage
    dtype; neighbour datasets are closed in ``finally`` even when writing
    or copying the replacement raises.
    """
    target_base = f"{target_dt.strftime('%Y%m%d')}{START_HOUR:02d}"
    target_path = build_path(directory, target_base, target_offset)

    print(f"  > 正在嘗試修復檔案: {target_path.name}")

    if status == 'error':
        error_file_path = target_path.with_suffix('.nc.error')
        try:
            # Keep the corrupt file (renamed) for later inspection.
            shutil.move(target_path, error_file_path)
            print(f"  > 檔案已重新命名為: {error_file_path.name}")
        except Exception as e:
            print(f"{bcolors.FAIL}  > 重新命名檔案失敗: {e}{bcolors.ENDC}")
            return False
    elif status == 'miss':
        miss_file_path = target_path.with_suffix('.nc.miss')
        try:
            # Leave an empty marker so the gap stays visible on disk.
            open(miss_file_path, 'a').close()
            print(f"  > 檔案遺失，已創建標記檔案: {miss_file_path.name}")
        except Exception as e:
            print(f"{bcolors.FAIL}  > 創建標記檔案失敗: {e}{bcolors.ENDC}")
            return False

    # Look for the nearest valid file before and after the target slot.
    prev_ds = _open_if_valid(_neighbor_path(target_dt, target_offset, -STEP), mask_raw_array)
    if prev_ds is not None:
        print("  > 找到前一個正常檔案。")
    next_ds = _open_if_valid(_neighbor_path(target_dt, target_offset, STEP), mask_raw_array)
    if next_ds is not None:
        print("  > 找到後一個正常檔案。")

    if prev_ds is not None and next_ds is not None:
        print("  > 找到兩個正常檔案。正在進行平均...")
        try:
            prev_da = prev_ds[VAR_NAME].squeeze()
            prev_vals = prev_da.to_numpy()
            next_vals = next_ds[VAR_NAME].squeeze().to_numpy()

            merged = np.full_like(prev_vals, FILL_VALUE)
            prev_valid = prev_vals != FILL_VALUE
            next_valid = next_vals != FILL_VALUE

            # Average in float64 to avoid integer overflow of raw values;
            # the assignment casts back to the on-disk dtype as before.
            both = prev_valid & next_valid
            merged[both] = (prev_vals[both].astype(np.float64) + next_vals[both].astype(np.float64)) / 2.0
            only_prev = prev_valid & ~next_valid
            merged[only_prev] = prev_vals[only_prev]
            only_next = ~prev_valid & next_valid
            merged[only_next] = next_vals[only_next]

            new_ds = xr.Dataset({VAR_NAME: (prev_da.dims, merged)}, coords=prev_ds.coords)
            new_ds[VAR_NAME].attrs = prev_ds[VAR_NAME].attrs
            new_ds.to_netcdf(target_path)
            new_ds.close()
            print(f"{bcolors.OKGREEN}  > 成功產生修復檔案: {target_path.name}{bcolors.ENDC}")
            return True
        except Exception as e:
            print(f"{bcolors.FAIL}  > 儲存新檔案失敗: {e}{bcolors.ENDC}")
            return False
        finally:
            prev_ds.close()
            next_ds.close()

    if prev_ds is not None or next_ds is not None:
        source_ds = prev_ds if prev_ds is not None else next_ds
        source_path = source_ds.encoding['source']
        print(f"  > 只找到一個正常檔案 ({Path(source_path).name})。正在進行複製...")
        try:
            shutil.copy(source_path, target_path)
            print(f"{bcolors.OKGREEN}  > 成功產生修復檔案: {target_path.name}{bcolors.ENDC}")
            return True
        except Exception as e:
            print(f"{bcolors.FAIL}  > 複製新檔案失敗: {e}{bcolors.ENDC}")
            return False
        finally:
            source_ds.close()

    print(f"{bcolors.WARNING}  > 無法找到足夠的正常檔案進行修復。{bcolors.ENDC}")
    return False

def check_date_series(current_dt: datetime.datetime, directory: Path, mask_ds: xr.Dataset):
    """Validate every offset file of one date against the land mask.

    For each offset in 0..END_OFFSET (step STEP) the file is checked for
    existence, presence of VAR_NAME, matching shape, and land-point
    agreement. Broken or missing files are handed to find_and_fix_file()
    and, when a repair succeeds, re-checked.

    Fix over the previous revision: datasets are now closed in ``finally``
    blocks, so a failure partway through validation (or the recheck) no
    longer leaks the open file handle.
    """
    print(f"\n{'-'*20} 正在檢查系列: {bcolors.BOLD}{current_dt.strftime('%Y%m%d')}{bcolors.ENDC} {'-'*20}")

    mask_raw_array = mask_ds[VAR_NAME].squeeze().to_numpy()
    current_base_tag = f"{current_dt.strftime('%Y%m%d')}{START_HOUR:02d}"

    for off in range(0, END_OFFSET + 1, STEP):
        path = build_path(directory, current_base_tag, off)
        label = f"t{off:03d} ({path.name})"

        needs_fix = False
        status = ''

        if not path.exists():
            print(f"{bcolors.FAIL}[遺失]{bcolors.ENDC} {label}: 檔案不存在。")
            needs_fix = True
            status = 'miss'
        else:
            ds = None
            try:
                ds = open_dataset(path, None, False)
                if VAR_NAME not in ds:
                    print(f"{bcolors.FAIL}[錯誤]{bcolors.ENDC} {label}: 資料中沒有 '{VAR_NAME}' 變數。")
                    needs_fix = True
                    status = 'error'
                else:
                    current_raw_array = ds[VAR_NAME].squeeze().to_numpy()
                    if current_raw_array.shape != mask_raw_array.shape:
                        print(f"{bcolors.FAIL}[錯誤]{bcolors.ENDC} {label}: 檔案維度大小不一致。")
                        print(f"  - 遮罩檔案維度大小: {mask_raw_array.shape}")
                        print(f"  - 檢查檔案維度大小: {current_raw_array.shape}")
                        needs_fix = True
                        status = 'error'
                    elif print_changes(mask_raw_array, current_raw_array, SAMPLE_LIMIT, label) > 0:
                        needs_fix = True
                        status = 'error'
            except Exception as e:
                print(f"{bcolors.FAIL}[錯誤]{bcolors.ENDC} {label}: 開啟失敗 ({e})")
                needs_fix = True
                status = 'error'
            finally:
                # Close even when validation raised mid-way (was leaked before).
                if ds is not None:
                    ds.close()

        if needs_fix and find_and_fix_file(current_dt, off, directory, mask_raw_array, status):
            print(f"  > 重新檢查 {label}...")
            new_ds = None
            try:
                new_ds = open_dataset(path, None, False)
                new_raw = new_ds[VAR_NAME].squeeze().to_numpy()
                recheck_n_mismatch = int(np.sum((mask_raw_array == FILL_VALUE) & (new_raw != FILL_VALUE)))
                if recheck_n_mismatch == 0:
                    print(f"{bcolors.OKGREEN}  > {label} 已成功修復，所有 {FILL_VALUE} 位置吻合。{bcolors.ENDC}")
                else:
                    print(f"{bcolors.FAIL}  > 修復後仍有 {recheck_n_mismatch} 個不一致。{bcolors.ENDC}")
            except Exception as e:
                print(f"{bcolors.FAIL}  > 重新檢查失敗 ({e})。{bcolors.ENDC}")
            finally:
                if new_ds is not None:
                    new_ds.close()

    print(f"{'-'*40} 系列檢查結束 {'-'*40}\n")

def main() -> int:
    """Entry point: validate one year of daily series given on the CLI.

    Expects exactly one argument, the year (e.g. ``2013``). Returns the
    process exit code: 0 when all days were processed, 2 on a usage error
    or when the mask file cannot be used.

    Fix over the previous revision: the mask dataset is closed in a
    ``finally`` block, so it is no longer leaked when VAR_NAME is missing
    or a series check raises.
    """
    if len(sys.argv) != 2:
        print(f"{bcolors.FAIL}用法錯誤:{bcolors.ENDC} 請提供一個年份作為參數。")
        print(f"範例: python {Path(sys.argv[0]).name} 2013")
        return 2

    try:
        year = int(sys.argv[1])
        # Reject clearly implausible years (before 1900 or far in the future).
        if year < 1900 or year > datetime.date.today().year + 5:
            raise ValueError("年份超出合理範圍。")
    except ValueError as e:
        print(f"{bcolors.FAIL}ERROR:{bcolors.ENDC} 無效的年份參數: {e}", file=sys.stderr)
        return 2

    try:
        mask_ds = open_dataset(Path(MASK_FILE_PATH), None, False)
    except FileNotFoundError:
        print(f"{bcolors.FAIL}ERROR:{bcolors.ENDC} 找不到陸海遮罩檔案: {MASK_FILE_PATH}。請確認路徑正確。", file=sys.stderr)
        return 2
    except Exception as e:
        print(f"{bcolors.FAIL}ERROR:{bcolors.ENDC} 開啟陸海遮罩檔案失敗: {e}", file=sys.stderr)
        return 2

    try:
        if VAR_NAME not in mask_ds:
            print(f"{bcolors.FAIL}ERROR:{bcolors.ENDC} 陸海遮罩檔案中沒有找到變數 '{VAR_NAME}'。", file=sys.stderr)
            return 2

        # Walk every day of the requested year; the data directory follows
        # the <year>/sss/<month>/ layout used by the repair logic as well.
        current_dt = datetime.datetime(year, 1, 1)
        end_dt = datetime.datetime(year, 12, 31)
        while current_dt <= end_dt:
            data_dir = Path(f"{current_dt.year}/sss/{current_dt.month:02d}/")
            check_date_series(current_dt, data_dir, mask_ds)
            current_dt += datetime.timedelta(days=1)
    finally:
        mask_ds.close()

    print(f"{bcolors.OKGREEN}所有日期檢查完畢。{bcolors.ENDC}")
    return 0

# Script entry point: exit with main()'s return code (0 = OK, 2 = usage/setup error).
if __name__ == "__main__":
    sys.exit(main())
