import os  
import xarray as xr
import numpy as np
from tqdm import tqdm
from numba import njit

# === 🧩 Set Parameters ===
region = "glb"  # Available regions: glb, pac, ind, atl, nh, sh, eq
max_depth = 1000  # Levels deeper than this (dataset `depth` coordinate) are dropped
calc_d20 = True  # Compute and save the depth of the 20C isotherm
calc_d15 = True  # Compute and save the depth of the 15C isotherm
calc_mld = True  # Compute and save the mixed layer depth
# =====================

# The input path is taken from the environment.
hfile = os.getenv('hfile')
if hfile is None:
    raise ValueError("❌ 'hfile' environment variable is not set!")

print(f"Reading from file: {hfile}")

# Output paths, also from the environment: ohfile1 -> D15, ohfile2 -> D20, ohfile3 -> MLD.
ohfile1 = os.getenv('ohfile1')
ohfile2 = os.getenv('ohfile2')
ohfile3 = os.getenv('ohfile3')

# Fail fast: a missing output path would otherwise only surface after the
# whole computation, when the writer is handed None instead of a file name.
for _enabled, _env_name, _path in (
    (calc_d15, 'ohfile1', ohfile1),
    (calc_d20, 'ohfile2', ohfile2),
    (calc_mld, 'ohfile3', ohfile3),
):
    if _enabled and _path is None:
        raise ValueError(f"❌ '{_env_name}' environment variable is not set!")


# 📁 Load data
ds = xr.open_dataset(hfile, engine='netcdf4', decode_times=False)

# 🔧 Select Region
def select_region(ds, region):
    """Return the subset of *ds* that covers the named region.

    Args:
        ds: Dataset-like object with ``lon``/``lat`` coordinates that supports
            xarray-style ``sel`` and ``isel``.
        region: One of "glb", "pac", "ind", "atl", "nh", "sh", "eq".

    Returns:
        ``ds`` itself for "glb", otherwise the selected subset.

    Raises:
        ValueError: If *region* is not one of the supported names.
    """
    if region == "glb":
        return ds

    if region == "atl":
        # The Atlantic box (290E..20E) straddles the prime meridian, so a
        # single label slice(290, 20) is empty on an ascending longitude axis.
        # Pick both sides explicitly and keep them in west-to-east order.
        # NOTE(review): assumes longitudes run 0..360, as the other boxes
        # (e.g. Pacific 120..290) suggest — confirm against the input file.
        lon_values = np.asarray(ds.lon.values)
        west_side = np.flatnonzero(lon_values >= 290)
        east_side = np.flatnonzero(lon_values <= 20)
        return ds.isel(lon=np.concatenate([west_side, east_side]))

    # Regions that are a single contiguous coordinate range.
    # NOTE(review): label slices presume ascending lat/lon coordinates — confirm.
    boxes = {
        "pac": {"lon": slice(120, 290)},
        "ind": {"lon": slice(30, 120)},
        "nh": {"lat": slice(0, 90)},
        "sh": {"lat": slice(-90, 0)},
        "eq": {"lat": slice(-10, 10)},
    }
    if region in boxes:
        return ds.sel(**boxes[region])

    raise ValueError(f"❌ Unsupported region: {region}")

# 🔧 Restrict the dataset to the chosen region, then to levels no deeper than max_depth
ds = select_region(ds, region)
_kept_levels = ds.depth.where(ds.depth <= max_depth, drop=True)
ds = ds.sel(depth=_kept_levels)
print(f"🔎 Depth layers: {len(ds.depth)} layers, range: {ds.depth.min().values} ~ {ds.depth.max().values} m")

# ➤ Temperature as float, with any scale/offset/fill recorded in its encoding applied
# NOTE(review): xr.open_dataset masks and scales by default, so this manual
# pass may be redundant (or a double application) — confirm against the file.
temp_raw = ds['water_temp'].astype(float)
_temp_encoding = temp_raw.encoding
scale_factor = _temp_encoding.get('scale_factor', 1.0)
add_offset = _temp_encoding.get('add_offset', 0.0)
fill_value = _temp_encoding.get('_FillValue', -10)  # -10 is the fallback when none is recorded
temp = temp_raw * scale_factor + add_offset

# ➤ Cells equal to the fill value become NaN
temp = temp.where(temp != fill_value)

# ➤ Plain numpy views of the data and its coordinates for the compiled kernels
temp_np = temp.values
depth_np = ds["depth"].values
lat = ds["lat"].values
lon = ds["lon"].values
time = ds["time"].values

# ➤ One int16 output grid per requested product, pre-filled with the fill value
shape = (temp_np.shape[0], len(lat), len(lon))
if calc_d20:
    d20_array = np.full(shape, np.int16(fill_value))
if calc_d15:
    d15_array = np.full(shape, np.int16(fill_value))
if calc_mld:
    mld_array = np.full(shape, np.int16(fill_value))

# ➤ Numba acceleration functions
@njit
def calculate_depth_of_isotherm(profile, depth, threshold):
    """Return the depth where *profile* first falls from above *threshold* to at or below it.

    Adjacent level pairs are scanned from index 0. A pair containing NaN is
    skipped. The crossing depth is linearly interpolated between the two
    levels of the first qualifying pair. NaN is returned when no pair
    qualifies.
    """
    n_pairs = len(profile) - 1
    for upper in range(n_pairs):
        lower = upper + 1
        t_upper = profile[upper]
        t_lower = profile[lower]
        if np.isnan(t_upper) or np.isnan(t_lower):
            continue
        if not (t_upper > threshold and t_lower <= threshold):
            continue
        z_upper = depth[upper]
        z_lower = depth[lower]
        return z_upper + (threshold - t_upper) * (z_lower - z_upper) / (t_lower - t_upper)
    return np.nan

@njit
def calculate_mld(profile, depth):
    """Return the depth where *profile* first deviates from its first value by more than 0.5.

    ``profile[0]`` is the reference. Levels holding NaN are skipped. At the
    first level whose absolute deviation exceeds 0.5, the depth of the 0.5
    deviation is linearly interpolated between that level and the one
    directly above it. NaN is returned when no level exceeds the limit or
    when the two deviations are equal.
    """
    reference = profile[0]
    for level in range(len(profile)):
        value = profile[level]
        if np.isnan(value):
            continue
        # Written as a negated ">" so a NaN deviation keeps scanning.
        if not (abs(value - reference) > 0.5):
            continue

        # Interpolate between the level above and the one that tripped the limit.
        upper = level - 1 if level > 0 else 0
        if upper < len(depth) - 1:
            dev_upper = abs(profile[upper] - reference)
            dev_lower = abs(profile[upper + 1] - reference)
            z_upper = depth[upper]
            z_lower = depth[upper + 1]
            if dev_lower - dev_upper != 0:
                return z_upper + (0.5 - dev_upper) * (z_lower - z_upper) / (dev_lower - dev_upper)
        # Only the first exceeding level is ever considered.
        return np.nan
    return np.nan

# ➤ Main loop (using Numba for acceleration)
def _depth_to_int16(value):
    """Return *value* truncated to int16, or int16(0) when it is NaN (no depth found)."""
    return np.int16(0) if np.isnan(value) else np.int16(value)


# Loop-invariant marker written to columns that are NaN at every level (land).
_land_value = np.int16(fill_value)

# Visit every (time, lat, lon) column and fill in each enabled product.
for t in tqdm(range(temp_np.shape[0]), desc=f"⏳ Calculating d20 for {region}"):
    for i in range(len(lat)):
        for j in range(len(lon)):
            profile = temp_np[t, :, i, j]
            if np.all(np.isnan(profile)):
                # Only write to arrays that exist: an output array is created
                # only when its calc_* flag is on, so unconditional writes
                # here would raise NameError for a disabled product.
                if calc_d20:
                    d20_array[t, i, j] = _land_value
                if calc_d15:
                    d15_array[t, i, j] = _land_value
                if calc_mld:
                    mld_array[t, i, j] = _land_value
                continue
            if calc_d20:
                d20_array[t, i, j] = _depth_to_int16(
                    calculate_depth_of_isotherm(profile, depth_np, 20)
                )
            if calc_d15:
                d15_array[t, i, j] = _depth_to_int16(
                    calculate_depth_of_isotherm(profile, depth_np, 15)
                )
            if calc_mld:
                mld_array[t, i, j] = _depth_to_int16(calculate_mld(profile, depth_np))

# ➤ Save function
def save_output(array, name, desc, output_file):
    """Write *array* to a netCDF file as a (time, lat, lon) variable.

    The coordinates are taken from the module-level ``time``, ``lat`` and
    ``lon`` arrays.

    Args:
        array: Values to store, indexed as (time, lat, lon).
        name: Variable name; its lower-cased form is stored as ``standard_name``.
        desc: Text stored as the ``long_name`` attribute.
        output_file: Destination path of the netCDF file.

    Raises:
        ValueError: If *output_file* is None or an empty string.
    """
    # Without a path the writer would not create a file, yet the success
    # message below would still be printed — refuse up front instead.
    if output_file is None or output_file == "":
        raise ValueError(f"❌ No output file path given for '{name}'")

    da = xr.DataArray(
        array,
        coords={"time": time, "lat": lat, "lon": lon},
        dims=["time", "lat", "lon"],
        name=name,
        attrs={
            "long_name": desc,
            "units": "m",
            "standard_name": name.lower()
        }
    )
    da.to_netcdf(output_file)  # Save to the provided output file path
    print(f"✅ Saved: {output_file}")

# ➤ Save selected results to specific output files
# Write each enabled product to its own file.
# NOTE(review): D20 goes to ohfile2 and D15 to ohfile1 — confirm this numbering is intended.
if calc_d20:
    save_output(
        array=d20_array,
        name="D20",
        desc="Depth of 20C Isotherm",
        output_file=ohfile2,
    )
if calc_d15:
    save_output(
        array=d15_array,
        name="D15",
        desc="Depth of 15C Isotherm",
        output_file=ohfile1,
    )
if calc_mld:
    save_output(
        array=mld_array,
        name="MLD",
        desc="Mixed Layer Depth (ΔT > 0.5°C)",
        output_file=ohfile3,
    )
