# -*- coding: utf-8 -*-
"""
train_piston_v5_3_hybrid.py
------------------------------------------------
v5.3 Hybrid 版（快速+穩定，無巨大快取）
- 統一輸入大小：960x544（灰階）
- 自動掃描 png/jpg/jpeg，壞圖自動略過＆記錄
- 混合精度 AMP + channels_last + cuDNN benchmark
- 從 v5.2 最佳權重續訓（若存在）
- 自動分出少量驗證集（不需 val 資料夾）
- 每個 epoch 產生預覽：outputs/preview_v53.png
- 儲存最新 ckpt 與最佳權重
"""

import os, time, math, random, traceback
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ---------------- Basic settings ----------------
# Project root. Defaults to the original Windows location; set the CS_AI_ROOT
# environment variable to run elsewhere without editing this file.
ROOT = Path(os.environ.get("CS_AI_ROOT", r"C:\xampp\htdocs\cs_ai"))
# Sub-paths are joined component by component (not as one backslash string) so
# they also resolve correctly on non-Windows systems; on Windows the result is identical.
DIR_IMG = ROOT / "data" / "train" / "images"   # single flat folder of training images
OUT_DIR = ROOT / "outputs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_DIR = ROOT / "ai_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
INIT_WEIGHT = MODEL_DIR / "piston_v5_2_best.pth"      # warm-start weights (used only if present)
CKPT_PATH   = MODEL_DIR / "piston_v5_3_hybrid_latest.ckpt"
BEST_PATH   = MODEL_DIR / "piston_v5_3_hybrid_best.pth"

LOG_BAD = ROOT / "bad_images.txt"

# Image settings
W, H = 960, 544
CHANNELS = 1

# Training hyper-parameters
EPOCHS = 100
INIT_BATCH = 8            # starting batch size; halved automatically on OOM
LR = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
PIN_MEMORY = True
VAL_RATIO = 0.02          # fraction of the training images held out for validation
VAL_MAX = 64              # upper bound on the number of validation images
PREVIEW_EVERY = 1

# CUDA / cuDNN speed-ups (TF32 matmuls/convolutions, cuDNN autotuning)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

# ---------------- 模型 ----------------
def conv_block(in_ch, out_ch):
    """Return two stacked (3x3 Conv -> BatchNorm -> LeakyReLU(0.1)) stages.

    The first convolution maps in_ch -> out_ch and the second out_ch -> out_ch.
    padding=1 keeps the spatial size unchanged. The convolutions carry no bias
    because each one is immediately followed by a BatchNorm layer.
    """
    layers = []
    for src_ch in (in_ch, out_ch):
        layers += [
            nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        ]
    return nn.Sequential(*layers)

class UNetAE(nn.Module):
    """Four-level U-Net used as an image autoencoder (output reconstructs the input).

    Args:
        ch: number of input channels.
        base: width of the first encoder stage; each deeper stage doubles it
            (base, 2*base, 4*base, 8*base, bottleneck 16*base).
        out_ch: number of output channels (default 1, the previously fixed value).

    forward() takes an (N, ch, H, W) tensor of any spatial size. If H or W is not
    a multiple of 16 (four 2x max-pools), the input is edge-padded up to the next
    multiple and the output is cropped back. The result is (N, out_ch, H, W) with
    values in [0, 1] (sigmoid).
    """

    # Combined downsampling factor of the four MaxPool2d(2) encoder stages.
    _STRIDE = 16

    def __init__(self, ch=1, base=32, out_ch=1):
        super().__init__()
        # NOTE: registration order is kept identical to earlier versions so that
        # parameter order (optimizer state) and state_dict keys stay compatible.
        # Encoder
        self.c1 = conv_block(ch, base);       self.p1 = nn.MaxPool2d(2)
        self.c2 = conv_block(base, base*2);   self.p2 = nn.MaxPool2d(2)
        self.c3 = conv_block(base*2, base*4); self.p3 = nn.MaxPool2d(2)
        self.c4 = conv_block(base*4, base*8); self.p4 = nn.MaxPool2d(2)
        # Bottleneck
        self.bn = conv_block(base*8, base*16)
        # Decoder: 2x transposed-conv upsampling, then concat with the skip connection
        self.up4 = nn.ConvTranspose2d(base*16, base*8, 2, 2); self.d4 = conv_block(base*16, base*8)
        self.up3 = nn.ConvTranspose2d(base*8,  base*4, 2, 2); self.d3 = conv_block(base*8,  base*4)
        self.up2 = nn.ConvTranspose2d(base*4,  base*2, 2, 2); self.d2 = conv_block(base*4,  base*2)
        self.up1 = nn.ConvTranspose2d(base*2,  base,   2, 2); self.d1 = conv_block(base*2,  base)
        self.out = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        h, w = x.shape[-2:]
        pad_h = (-h) % self._STRIDE
        pad_w = (-w) % self._STRIDE
        if pad_h or pad_w:
            # Without padding, odd feature sizes after pooling make the upsampled
            # tensors mismatch their skip connections in torch.cat.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
        c1 = self.c1(x); p1 = self.p1(c1)
        c2 = self.c2(p1); p2 = self.p2(c2)
        c3 = self.c3(p2); p3 = self.p3(c3)
        c4 = self.c4(p3); p4 = self.p4(c4)
        bn = self.bn(p4)
        u4 = self.up4(bn); d4 = self.d4(torch.cat([u4, c4], 1))
        u3 = self.up3(d4); d3 = self.d3(torch.cat([u3, c3], 1))
        u2 = self.up2(d3); d2 = self.d2(torch.cat([u2, c2], 1))
        u1 = self.up1(d2); d1 = self.d1(torch.cat([u1, c1], 1))
        y = torch.sigmoid(self.out(d1))
        if pad_h or pad_w:
            y = y[..., :h, :w]   # crop the padding back off
        return y

# ---------------- 資料集 ----------------
IMG_EXTS = {".png", ".jpg", ".jpeg"}   # accepted suffixes (compared lower-cased)

def list_images(folder: Path):
    """Return the image files directly inside *folder*, sorted by path.

    Only regular files whose suffix (case-insensitive) is in IMG_EXTS are kept.
    Sub-directories are skipped even if their name ends in an image suffix, and
    the search is not recursive. Errors from Path.iterdir() (e.g. a missing
    folder) propagate to the caller.
    """
    return sorted(
        p for p in folder.iterdir()
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    )

def read_gray_resize(path: Path, w=W, h=H):
    """Load *path* as a single-channel image, resize it to (h, w) and scale to [0, 1].

    The file bytes are read with numpy and decoded with cv2.imdecode, so paths
    with non-ASCII characters work on Windows (cv2.imread is known to fail on
    such paths there). cv2.imread is kept as a fallback.

    Returns:
        float32 ndarray of shape (h, w) with values in [0, 1].

    Raises:
        RuntimeError: "讀檔失敗" if the file cannot be read/decoded,
            "尺寸無效" if it decodes to an empty image.
    """
    img = None
    try:
        buf = np.fromfile(str(path), dtype=np.uint8)
        if buf.size:   # cv2.imdecode raises on an empty buffer
            img = cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)
    except (OSError, cv2.error):
        img = None
    if img is None:
        img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise RuntimeError("讀檔失敗")
    if img.shape[0] <= 0 or img.shape[1] <= 0:
        raise RuntimeError("尺寸無效")
    if img.shape[1] != w or img.shape[0] != h:
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0
    return img  # (H, W) float32 in [0, 1]

class PistonDs(Dataset):
    """Autoencoder dataset over the images in one folder; each item is (x, x).

    On construction, every candidate file from list_images() is decoded once.
    Files that fail are dropped and appended to *log_bad* as "[預掃失敗] <path>"
    lines. Items are (1, H, W) float32 tensors in [0, 1] built by read_gray_resize.
    If reading still fails at load time, __getitem__ returns None so that
    safe_collate can drop the item.

    Attributes:
        files: paths actually served (the same list object as ``good``).
        good:  paths that decoded during the pre-scan.
        bad:   paths that failed the pre-scan.
    """

    def __init__(self, folder: Path, log_bad: Path):
        self.files = list_images(folder)
        self.good = []
        self.bad = []
        # Pre-scan: decode each file once so corrupt/unsupported images are
        # filtered out here, instead of failing later inside DataLoader workers.
        for p in self.files:
            if self._is_decodable(p):
                self.good.append(p)
            else:
                self.bad.append(p)
        if self.bad:
            with open(log_bad, "a", encoding="utf-8") as f:
                for p in self.bad:
                    f.write(f"[預掃失敗] {p}\n")
        self.files = self.good

    @staticmethod
    def _is_decodable(p: Path) -> bool:
        """Return True if *p* decodes to a non-empty grayscale image.

        cv2.imdecode signals undecodable data by returning None rather than
        raising, so the result must be checked explicitly. Bytes are read via
        numpy first, so non-ASCII paths work on Windows; cv2.imread is the fallback.
        """
        img = None
        try:
            buf = np.fromfile(str(p), dtype=np.uint8)
            if buf.size:   # cv2.imdecode raises on an empty buffer
                img = cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)
        except Exception:
            img = None
        if img is None:
            try:
                img = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
            except Exception:
                img = None
        return img is not None and img.size > 0

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        p = self.files[idx]
        try:
            img = read_gray_resize(p, W, H)
        except Exception:
            # Reading still failed at load time -> return None; safe_collate filters it out.
            return None
        t = torch.from_numpy(img).unsqueeze(0)          # (1, H, W)
        return t, t

def safe_collate(batch):
    """Stack (input, target) pairs into two batch tensors, skipping None items.

    Items are None when the dataset failed to load that image. If no usable
    item remains (or *batch* is empty), None is returned so the training and
    validation loops can skip the whole batch.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None
    inputs, targets = zip(*kept)
    return torch.stack(inputs, dim=0), torch.stack(targets, dim=0)

def make_loader(ds, bs, shuffle):
    """Build a DataLoader over *ds* using the script's worker/pinning settings.

    Uses safe_collate, so a batch whose items all failed to load comes back as None.
    drop_last is normally True (full batches only). It is turned off when
    len(ds) < bs, because otherwise the loader would yield zero batches: the
    batch probe would hit StopIteration, or training would silently do nothing.
    persistent_workers and prefetch_factor=4 are used only when NUM_WORKERS > 0.
    """
    drop_last = len(ds) >= bs
    use_workers = NUM_WORKERS > 0
    return DataLoader(
        ds, batch_size=bs, shuffle=shuffle,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
        collate_fn=safe_collate, drop_last=drop_last,
        persistent_workers=use_workers,
        prefetch_factor=(4 if use_workers else None),
    )

# ---------------- 輔助 ----------------
def describe_device():
    """Choose the compute device and print a short description of it.

    Returns torch.device("cuda") when CUDA is available, after printing the
    current GPU's name and torch.version.cuda. Otherwise prints a warning and
    returns torch.device("cpu").
    """
    if not torch.cuda.is_available():
        print("⚠️ 未偵測到 GPU，改用 CPU")
        return torch.device("cpu")

    gpu_index = torch.cuda.current_device()
    print("✅ 使用裝置：cuda")
    gpu_name = torch.cuda.get_device_name(gpu_index)
    print(f"   GPU：{gpu_name} | CUDA：{torch.version.cuda}")
    return torch.device("cuda")

def save_preview(xb, pred, out_path):
    """Save an Input / Predict / Diff side-by-side image for the first sample.

    Args:
        xb: input batch; element [0, 0] (first sample, first channel) is shown.
        pred: model output batch with the same layout as *xb*.
        out_path: destination image file (overwritten).

    Input and Predict are drawn on a fixed [0, 1] gray scale so the two panels
    can be compared directly. Inputs are normalized to [0, 1] and the model ends
    in a sigmoid. Diff = |pred - input| keeps matplotlib's automatic scaling so
    small reconstruction errors stay visible. Tensors are converted to float32
    first, because predictions made under autocast may be float16.
    """
    inp = xb[0, 0].detach().float().cpu().numpy()
    prd = pred[0, 0].detach().float().cpu().numpy()
    panels = [
        ("Input", inp, {"vmin": 0.0, "vmax": 1.0}),
        ("Predict", prd, {"vmin": 0.0, "vmax": 1.0}),
        ("Diff", np.abs(prd - inp), {}),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (title, img, scale) in zip(axes, panels):
            ax.imshow(img, cmap="gray", **scale)
            ax.set_title(title)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if drawing/saving fails, so figures
        # do not pile up over a long training run.
        plt.close(fig)

# ---------------- 主流程 ----------------
def _atomic_save(obj, path: Path):
    """torch.save *obj* to *path* via a temporary file followed by os.replace.

    If the process is interrupted mid-write, the previous file at *path* stays
    intact instead of being replaced by a truncated, unloadable checkpoint.
    """
    tmp = path.with_name(path.name + ".tmp")
    torch.save(obj, str(tmp))
    os.replace(tmp, path)


def _probe_batch_size(ds, batch, device):
    """Return a batch size whose first batch loads and transfers to *device*.

    Starting at *batch*, the size is halved whenever an attempt fails with an
    error whose message mentions "out of memory" or "cuda". Batch size 1 is also
    tried before giving up. Any other error, or a failure at batch size 1, is
    re-raised.

    NOTE(review): this only exercises data loading and the host-to-device copy,
    not the model's forward/backward pass, so it cannot predict training-time OOM;
    per-step OOMs are handled in the training loop.
    The probe loader is local to this function, so it (and its persistent
    worker processes) can be released once the probe is done.
    """
    while True:
        try:
            loader = make_loader(ds, batch, shuffle=True)
            first = next(iter(loader))
            if first is None:
                raise RuntimeError("首個 batch 無有效資料")
            xb, _ = first
            xb.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            return batch
        except Exception as e:
            msg = str(e).lower()
            if ("out of memory" not in msg and "cuda" not in msg) or batch == 1:
                raise
            batch = max(1, batch // 2)
            print(f"⚠️ OOM/初始化失敗，降低 batch_size → {batch}")
            if device.type == "cuda":
                torch.cuda.empty_cache()


def _validate(model, loader, crit, device, amp):
    """Evaluate *model* on *loader* in eval mode without gradients.

    Returns:
        (mean loss over the non-None batches, number of such batches,
         (inputs, predictions) of the first such batch, or None if there was none).
    """
    model.eval()
    vloss, steps, sample = 0.0, 0, None
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=amp):
        for val_batch in loader:
            if val_batch is None:
                continue
            xv, yv = val_batch
            xv = xv.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            yv = yv.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            pv = model(xv)
            vloss += crit(pv, yv).item()
            steps += 1
            if sample is None:
                sample = (xv, pv)
    return vloss / max(1, steps), steps, sample


def main():
    """Train the UNetAE autoencoder end to end.

    Steps:
    1. Pre-scan the image folder and split off a small validation set.
    2. Build the model, warm-starting from INIT_WEIGHT if it exists.
    3. Probe a batch size and resume from CKPT_PATH if it exists.
    4. Train until EPOCHS. After every epoch: validate, write the preview image,
       save BEST_PATH when the validation loss improves, and write a resumable
       checkpoint to CKPT_PATH.
    """
    print(f"✅ 使用訓練資料路徑：{DIR_IMG}")
    ds_all = PistonDs(DIR_IMG, LOG_BAD)
    if len(ds_all) == 0:
        raise RuntimeError("資料夾中沒有可用影像（或副檔名非 png/jpg/jpeg）。")
    if len(ds_all) < 2:
        # A single image would go entirely to validation, leaving nothing to train on.
        raise RuntimeError("至少需要 2 張可用影像（訓練集與驗證集各至少 1 張）。")

    # Split off a small validation set (at least 1, at most VAL_MAX images).
    val_len = max(1, min(VAL_MAX, int(len(ds_all)*VAL_RATIO)))
    train_len = len(ds_all) - val_len
    ds_train, ds_val = random_split(ds_all, [train_len, val_len],
                                    generator=torch.Generator().manual_seed(2025))

    device = describe_device()
    amp = (device.type == "cuda")

    model = UNetAE(ch=CHANNELS, base=32).to(device).to(memory_format=torch.channels_last)
    # Warm start from the previous version's weights (strict=False tolerates key mismatches).
    if INIT_WEIGHT.exists():
        try:
            w = torch.load(str(INIT_WEIGHT), map_location="cpu")
            if "model" in w:
                model.load_state_dict(w["model"], strict=False)
            else:
                model.load_state_dict(w, strict=False)
            print(f"✅ 已載入前版模型: {INIT_WEIGHT.name}")
        except Exception as e:
            print(f"⚠️ 載入前版模型失敗：{e}")

    opt = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sch = CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=1e-6)
    scaler = torch.amp.GradScaler('cuda', enabled=amp)
    crit = nn.MSELoss()

    batch = _probe_batch_size(ds_train, INIT_BATCH, device)
    train_loader = make_loader(ds_train, batch, shuffle=True)
    val_loader   = make_loader(ds_val,   bs=1,  shuffle=False)

    best_val = float("inf")
    start_epoch = 0
    # Resume from the rolling checkpoint if it exists.
    if CKPT_PATH.exists():
        try:
            ck = torch.load(str(CKPT_PATH), map_location=device)
            model.load_state_dict(ck["model"], strict=False)
            opt.load_state_dict(ck["opt"])
            sch.load_state_dict(ck["sch"])
            scaler.load_state_dict(ck["scaler"])
            best_val = float(ck.get("best", best_val))
            start_epoch = int(ck.get("epoch", 0))+1
            print(f"🔁 從 epoch {start_epoch} 續訓（best_val={best_val:.6f}）")
        except Exception as e:
            print(f"⚠️ 載入舊 ckpt 失敗：{e}")

    preview_path = OUT_DIR / "preview_v53.png"

    for epoch in range(start_epoch, EPOCHS):
        model.train()
        if device.type=="cuda":
            torch.cuda.reset_peak_memory_stats()
        t0 = time.time()
        losses, seen = [], 0

        pbar = tqdm(train_loader, total=len(train_loader),
                    desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=110)
        for batch_data in pbar:
            if batch_data is None:
                continue
            xb, yb = batch_data
            xb = xb.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            yb = yb.to(device, non_blocking=True).to(memory_format=torch.channels_last)

            opt.zero_grad(set_to_none=True)
            try:
                with torch.amp.autocast("cuda", enabled=amp):
                    pred = model(xb)
                    loss = crit(pred, yb)
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    if device.type=="cuda":
                        torch.cuda.empty_cache()
                    print("⚠️ 單步 OOM，略過此 batch")
                    continue
                raise

            step_loss = loss.item()
            losses.append(step_loss)
            seen += xb.size(0)
            spd = seen / max(1e-6, time.time()-t0)
            pbar.set_postfix_str(f"loss={step_loss:.6f}, 速度={spd:.2f}/s")

        tr_loss = float(np.mean(losses)) if losses else 0.0
        sec = time.time() - t0
        peak = (torch.cuda.max_memory_allocated()/1024**2) if device.type=="cuda" else 0.0
        print(f"✅ Epoch [{epoch+1}/{EPOCHS}] | train loss: {tr_loss:.6f} | {sec:.1f}s | GPU峰值: {peak:.1f} MB | bs={batch}")

        # ---- validation ----
        vloss, vsteps, sample_for_preview = _validate(model, val_loader, crit, device, amp)

        if ((epoch+1) % PREVIEW_EVERY == 0) and sample_for_preview is not None:
            xv, pv = sample_for_preview
            save_preview(xv, pv, preview_path)
            print(f"🖼️ 已輸出預覽圖：{preview_path}")

        # Best model. This only makes sense if at least one validation batch was
        # evaluated; otherwise vloss is 0.0 and would lock in a bogus permanent "best".
        if vsteps == 0:
            print("⚠️ 本 epoch 沒有有效的驗證資料，略過最佳模型判斷")
        elif vloss < best_val:
            best_val = vloss
            _atomic_save(model.state_dict(), BEST_PATH)
            print(f"🏅 新最佳 (val={best_val:.6f}) → {BEST_PATH}")

        # Step the LR schedule BEFORE checkpointing, so the saved scheduler state
        # already matches start_epoch (= epoch + 1) on resume. Stepping after the
        # save made a resumed run repeat this epoch's learning rate.
        sch.step()

        # Rolling checkpoint (written atomically)
        _atomic_save({
            "epoch": epoch,
            "model": model.state_dict(),
            "opt": opt.state_dict(),
            "sch": sch.state_dict(),
            "scaler": scaler.state_dict(),
            "best": best_val
        }, CKPT_PATH)

    print(f"🎉 訓練完成。最佳 val = {best_val:.6f}")
    print(f"最後 ckpt：{CKPT_PATH}\n最佳權重：{BEST_PATH}")

if __name__ == "__main__":
    # Start training only when the file is run as a script, not when it is imported.
    main()
