#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
try:
    from skimage.measure import marching_cubes
except ImportError:
    # marching_cubes is required for the isosurfaces, so stop with an install
    # hint.  The ENTER prompt presumably keeps a console window open when the
    # script is launched outside a terminal.
    print("\nError: Can not import skimage")
    print("  Install: pip install scikit-image")
    try:
        input("\nPress ENTER to terminate>>\n")
    except EOFError:
        # No interactive stdin (piped / batch run): nothing to wait for.
        pass
    # sys.exit(1) rather than the site-provided exit(), which exited with
    # status 0 and so reported success for a fatal missing dependency.
    sys.exit(1)

sys.path.append('d:/git/tkProg/tklib/python')


import tklib.tkgraphic.tkPlot3d as tk3d


# --- CHGCAR 読み込み（空行スキップ対応 + dtype 選択） ---
def read_chgcar(filename, dtype=np.float32):
    """Read a VASP CHGCAR-style volumetric file.

    Blank lines between the atomic coordinates and the grid-size line are
    skipped.  The header is parsed assuming the VASP 5 layout, i.e. with an
    element-symbol line after the lattice vectors.

    Parameters
    ----------
    filename : str
        Path of the file to read.
    dtype : numpy floating type
        Type used for the lattice, the coordinates and the volumetric data.

    Returns
    -------
    info : dict
        'title', 'scale', 'lattice' ((3, 3), rows are the lattice vectors as
        written, i.e. not multiplied by 'scale'), 'atoms', 'numbers',
        'coord_type', 'selective_dynamics' (bool), 'coords' ((natoms, 3)) and
        'grid' ((nx, ny, nz)).
    P : list
        Nested Python lists with ``P[i][j][k] == F[i, j, k]`` (elements are
        Python floats).
    F : numpy.ndarray
        Volumetric data of shape (nx, ny, nz).  Values are stored with the
        first index varying fastest, hence the Fortran-order reshape.

    Raises
    ------
    ValueError
        If the file ends before the grid-size line, or if the number of data
        values read differs from nx * ny * nz.
    """
    info = {}
    with open(filename, 'r') as f:
        info['title'] = f.readline().strip()
        info['scale'] = float(f.readline().strip())
        lattice = [[float(x) for x in f.readline().split()] for _ in range(3)]
        info['lattice'] = np.array(lattice, dtype=dtype)

        info['atoms'] = f.readline().split()
        info['numbers'] = [int(x) for x in f.readline().split()]

        # POSCAR-style headers may carry an optional "Selective dynamics" line
        # (recognised by its leading 's'/'S') before the coordinate-type line.
        # Without skipping it, "Direct" would be parsed as a coordinate row.
        line = f.readline().strip()
        selective = line[:1] in ('s', 'S')
        if selective:
            line = f.readline().strip()
        info['coord_type'] = line
        info['selective_dynamics'] = selective

        natoms = sum(info['numbers'])
        coords = []
        for _ in range(natoms):
            # [:3] drops trailing selective-dynamics flags such as "T T F".
            coords.append([float(x) for x in f.readline().split()[:3]])
        info['coords'] = np.array(coords, dtype=dtype)

        # Skip blank lines up to the grid-size line.
        while True:
            line = f.readline()
            if not line:
                raise ValueError("Unexpected EOF before grid size.")
            if line.strip():
                break
        nx, ny, nz = [int(x) for x in line.split()]
        info['grid'] = (nx, ny, nz)

        # Volumetric data: read whole lines until nx*ny*nz values are collected.
        ngrid = nx * ny * nz
        data = []
        while len(data) < ngrid:
            line = f.readline()
            if not line:
                break
            data.extend(float(x) for x in line.split())
        if len(data) != ngrid:
            raise ValueError("CHGCAR data size mismatch")

        F = np.array(data, dtype=dtype).reshape((nx, ny, nz), order='F')

    # Same nested-list view of F as before, but built in C by tolist() instead
    # of nx*ny*nz Python-level indexing operations.
    P = F.tolist()
    return info, P, F


# --- 直交座標メッシュ生成（分率→直交） ---
def build_cartesian_grid(lattice, nx, ny, nz, dtype=np.float32):
    """Return Cartesian coordinates (X, Y, Z) of an nx x ny x nz grid.

    Grid point (i, j, k) sits at fractional position (i/nx, j/ny, k/nz) and is
    mapped to Cartesian space with the rows of ``lattice`` as cell vectors.
    Each returned array has shape (nx, ny, nz) ('ij' indexing).
    """
    axes = [np.arange(count, dtype=dtype) / dtype(count) for count in (nx, ny, nz)]
    frac_x, frac_y, frac_z = np.meshgrid(*axes, indexing='ij')

    vec_a, vec_b, vec_c = lattice[0], lattice[1], lattice[2]
    # One Cartesian component per pass: d = 0 -> X, 1 -> Y, 2 -> Z.
    return tuple(
        frac_x * vec_a[d] + frac_y * vec_b[d] + frac_z * vec_c[d]
        for d in range(3)
    )


# --- ダウンサンプリング（各軸のステップ） ---
def downsample_volume(F, dsx, dsy, dsz):
    """Keep every dsx-th / dsy-th / dsz-th sample along axes 0 / 1 / 2 of F.

    Uses basic slicing, so for a numpy array the result is a strided view.
    """
    step_slices = tuple(slice(None, None, step) for step in (dsx, dsy, dsz))
    return F[step_slices]

def fractional_spacing_after_downsampling(nx, ny, nz, dsx, dsy, dsz):
    """Fractional distance between neighbouring samples after downsampling.

    Keeping every ds-th of n points along an axis puts adjacent samples ds/n
    apart in fractional coordinates.  Returns the (x, y, z) spacings as a tuple.
    """
    counts = (nx, ny, nz)
    steps = (dsx, dsy, dsz)
    return tuple(step / count for step, count in zip(steps, counts))


# --- 等値面（分率で抽出→直交に変換） ---
def plot_isosurfaces_fractional(ax, F, lattice, levels, alpha=0.3, edgecolor='none', linewidth=0.0,
                                colors=None, spacing_frac=(1,1,1), legend=True, antialias=False):
    """Draw one isosurface per level onto a 3D axes.

    Marching cubes runs on the index grid of ``F`` with ``spacing_frac`` as
    the voxel spacing, so vertices come out in fractional coordinates; they
    are converted to Cartesian coordinates as ``verts_frac @ lattice`` (rows
    of ``lattice`` are the cell vectors).

    Parameters
    ----------
    ax : 3D axes; ``add_collection3d`` (and ``legend`` when requested) is
        called on it.
    F : 3D array of scalar values.
    lattice : (3, 3) array whose rows are the cell vectors.
    levels : sequence of iso values.
    alpha, edgecolor, linewidth, antialias : styling forwarded to every
        ``tk3d.Poly3DCollection``.
    colors : sequence of colours, one per level, reused cyclically if it is
        shorter than ``levels``.  ``None`` or empty selects the 'tab10' map.
    spacing_frac : fractional distance between neighbouring samples of ``F``
        along each axis.
    legend : when true, add a ``Level=...`` legend entry per drawn surface.

    Levels for which marching cubes fails (e.g. a level outside the data
    range) are skipped with a warning on stderr instead of aborting the plot.

    NOTE: ``F`` is not padded periodically, so surfaces end at the last stored
    sample rather than wrapping to the far faces of the cell.
    """
    if colors is None or len(colors) == 0:
        cmap = plt.get_cmap('tab10')
        colors = [cmap(i % 10) for i in range(len(levels))]

    patches = []
    for il, level in enumerate(levels):
        try:
            verts_frac, faces, _, _ = marching_cubes(F, level=level, spacing=spacing_frac)
        except (ValueError, RuntimeError) as err:
            # Best effort: one unusable level must not prevent the others,
            # but the user should know it was dropped.
            print(f"Warning: skipping isosurface level {level:.4g}: {err}", file=sys.stderr)
            continue
        verts_cart = verts_frac @ lattice  # (N, 3) fractional -> Cartesian
        color = colors[il % len(colors)]

        mesh = tk3d.Poly3DCollection(verts_cart[faces], alpha=alpha,
                                     facecolor=color, edgecolor=edgecolor,
                                     linewidth=linewidth, antialiased=antialias)
        ax.add_collection3d(mesh)
        if legend:
            patches.append(Patch(facecolor=color, edgecolor='k', label=f"Level={level:.4g}"))

    if legend and patches:
        ax.legend(handles=patches, loc="upper right", fontsize=9)


# --- スライス面を正しく直交座標に配置して描画 ---
def _facecolors_from_vals(vals, cmap, *, vmin=None, vmax=None):
    """Map a 2D grid of node values to per-cell RGBA face colours.

    Parameters
    ----------
    vals : numpy.ndarray
        (M, N) scalar values at the grid nodes.
    cmap : str or Colormap
        Anything accepted by ``plt.get_cmap``.
    vmin, vmax : float, optional
        Colour-scale limits.  They default to the min / max of ``vals``; pass
        shared limits to make several slices directly comparable.

    Returns
    -------
    numpy.ndarray
        (M-1, N-1, 4) RGBA array: each cell is coloured by the mean of its four
        corner values, matching a surface whose vertex grids are (M, N).
    """
    lo = float(vals.min()) if vmin is None else float(vmin)
    hi = float(vals.max()) if vmax is None else float(vmax)
    norm = plt.Normalize(vmin=lo, vmax=hi)
    # Average the four corners of every cell -> one value per face.
    vc = 0.25 * (vals[:-1, :-1] + vals[1:, :-1] + vals[:-1, 1:] + vals[1:, 1:])
    return plt.get_cmap(cmap)(norm(vc))

def _nearest_periodic_index(frac, n):
    """Index i in [0, n) whose fractional position i/n is nearest to frac, wrapping periodically."""
    # Grid index i sits at fraction i/n, so the nearest index is round(frac*n);
    # the modulo treats the data as periodic over the cell (frac=1.0 -> 0).
    return int(round(frac * n)) % n


def render_slice_plane(ax, lattice, F, which, frac, cmap='viridis', alpha=0.6):
    """Draw a lattice-plane slice of F, coloured by the data on that plane.

    Parameters
    ----------
    ax : 3D axes passed to ``tk3d.plot_surface3d``.
    lattice : (3, 3) array whose rows are the cell vectors.
    F : 3D array of shape (nx, ny, nz).
    which : 'xy' fixes fractional z = frac, 'yz' fixes fractional x = frac,
        'zx' fixes fractional y = frac.
    frac : fractional coordinate of the plane.  The plane is drawn exactly at
        this position, and its colours come from the grid plane whose
        fractional coordinate i/n is nearest to it (periodically, so 1.0
        shows the same data as 0.0).
    cmap, alpha : colour map and transparency of the slice.

    The in-plane grid points (fractions i/n) are mapped to Cartesian
    coordinates via the rows of ``lattice``; each cell is coloured by the mean
    of its four corner values.

    Raises
    ------
    ValueError
        If ``which`` is not one of 'xy', 'yz', 'zx'.
    """
    nx, ny, nz = F.shape
    a, b, c = lattice[0], lattice[1], lattice[2]

    if which == 'xy':
        # fz fixed; the plane spans fractional x (axis 0) and y (axis 1).
        FX, FY = np.meshgrid(np.arange(nx) / nx, np.arange(ny) / ny, indexing='ij')  # (nx, ny)
        FZ = np.full_like(FX, frac)
        vals = F[:, :, _nearest_periodic_index(frac, nz)]  # (nx, ny)
    elif which == 'yz':
        # fx fixed; the plane spans fractional y (axis 1) and z (axis 2).
        FY, FZ = np.meshgrid(np.arange(ny) / ny, np.arange(nz) / nz, indexing='ij')  # (ny, nz)
        FX = np.full_like(FY, frac)
        # F[ix] is already (ny, nz), matching FY/FZ -- no transpose needed.
        vals = F[_nearest_periodic_index(frac, nx), :, :]
    elif which == 'zx':
        # fy fixed; the plane spans fractional x (axis 0) and z (axis 2).
        FX, FZ = np.meshgrid(np.arange(nx) / nx, np.arange(nz) / nz, indexing='ij')  # (nx, nz)
        FY = np.full_like(FX, frac)
        vals = F[:, _nearest_periodic_index(frac, ny), :]  # (nx, nz)
    else:
        raise ValueError("which must be one of 'xy', 'yz', 'zx'")

    # Fractional -> Cartesian, using the lattice rows as cell vectors.
    X = FX * a[0] + FY * b[0] + FZ * c[0]
    Y = FX * a[1] + FY * b[1] + FZ * c[1]
    Z = FX * a[2] + FY * b[2] + FZ * c[2]

    # Vertex grids are (M, N); face colours are per cell, (M-1, N-1, 4).
    facecolors = _facecolors_from_vals(vals, cmap)
    tk3d.plot_surface3d(ax, X, Y, Z, facecolors=facecolors, edgecolor='none', alpha=alpha, shade=False)

# Utility: convert a list of fraction strings (e.g. ["0.25", "0.5"] from "--slice-xy 0.25 0.5") to floats
def parse_fracs(str_list):
    """Convert an iterable of numeric strings into a list of floats.

    ``None`` (the option was not given on the command line) yields ``[]``.
    """
    if str_list is None:
        return []
    return list(map(float, str_list))


def main():
    """Command-line entry point.

    Parses the arguments, reads the volumetric file with ``read_chgcar`` and
    draws, on a single 3D axes, isosurfaces (modes 'iso' and 'both'), a
    scatter of high-value grid points (mode 'dots' only) and any requested
    slice planes (every mode).  With --save the figure is written to disk
    before it is shown.
    """
    ap = argparse.ArgumentParser(description="Visualize VASP volumetric data with isosurfaces and slice planes.")
    ap.add_argument("infile")

    # Display mode.  NOTE: dots are drawn only in 'dots' mode; 'both' draws the
    # isosurfaces, and slice planes are drawn in every mode.
    ap.add_argument("--mode", choices=["iso", "dots", "both"], default="both",
                    help="Isosurfaces, dots, or both (isosurface + slices).")

    # Isosurfaces
    ap.add_argument("--levels", type=float, nargs="+", help="Iso-surface levels (absolute values)")
    ap.add_argument("--nlevels", type=int, default=None, help="Auto-generate N iso levels between min/max")
    ap.add_argument("--alpha", type=float, default=0.30, help="Alpha for isosurfaces/dots")
    ap.add_argument("--edge", default="none")
    ap.add_argument("--lw", type=float, default=0.0)
    ap.add_argument("--no-legend", action="store_true")

    # Dots
    ap.add_argument("--cutoff", type=float, default=None)
    ap.add_argument("--quantile", type=float, default=None)
    ap.add_argument("--subsample", type=int, default=1)
    ap.add_argument("--max-points", type=int, default=None)
    ap.add_argument("--size", type=float, default=0.4)
    ap.add_argument("--cmap", default="viridis")

    # Slices (given in fractional coordinates, 0 to 1)
    ap.add_argument("--slice-xy", nargs="+", help="Fractions for XY slices (fix fractional z), e.g., 0.25 0.5")
    ap.add_argument("--slice-yz", nargs="+", help="Fractions for YZ slices (fix fractional x)")
    ap.add_argument("--slice-zx", nargs="+", help="Fractions for ZX slices (fix fractional y)")
    ap.add_argument("--slice-cmap", default="viridis")
    ap.add_argument("--slice-alpha", type=float, default=0.6)

    # Speed-ups
    ap.add_argument("--float32", action="store_true")
    ap.add_argument("--ds", type=int, nargs=3, metavar=("DSX","DSY","DSZ"), default=[1,1,1],
                    help="Downsample steps per axis (e.g., --ds 2 2 2)")

    # Display
    ap.add_argument("--ortho", action="store_true")
    ap.add_argument("--pad", type=float, default=0.0)
    ap.add_argument("--save", default=None)
    ap.add_argument("--title", default=None)
    args = ap.parse_args()

    # Validate before the (potentially slow) file read: a zero step makes
    # numpy slicing fail and a negative one silently reverses the volume.
    dsx, dsy, dsz = args.ds
    if min(dsx, dsy, dsz) < 1:
        ap.error("--ds steps must be integers >= 1")

    dtype = np.float32 if args.float32 else np.float64

    info, _, F = read_chgcar(args.infile, dtype=dtype)
    lattice = (info['lattice'] * info['scale']).astype(dtype, copy=False)
    nx, ny, nz = info['grid']

    # Downsampling
    if dsx > 1 or dsy > 1 or dsz > 1:
        F = downsample_volume(F, dsx, dsy, dsz)
        spacing_frac = fractional_spacing_after_downsampling(nx, ny, nz, dsx, dsy, dsz)
        nx2, ny2, nz2 = F.shape
    else:
        spacing_frac = (1.0/nx, 1.0/ny, 1.0/nz)
        nx2, ny2, nz2 = nx, ny, nz

    # Cartesian grid of the stored samples (used for dots and axis limits).
    # Downsampled sample i lies at fraction i*ds/n, but build_cartesian_grid
    # places it at i/n2 (n2 = samples kept); these agree only when ds divides
    # n.  Stretching each lattice row by n2*ds/n (exactly 1 when ds divides n)
    # keeps dots and axis limits consistent with the isosurfaces, which use
    # spacing_frac = ds/n.
    stretch = np.array([[nx2 * dsx / nx], [ny2 * dsy / ny], [nz2 * dsz / nz]])
    grid_lattice = (lattice * stretch).astype(dtype, copy=False)
    X, Y, Z = build_cartesian_grid(grid_lattice, nx2, ny2, nz2, dtype=dtype)

    # Figure
    fig = plt.figure(figsize=(9, 7))
    ax = fig.add_subplot(111, projection="3d")
    if args.ortho:
        try:
            ax.set_proj_type('ortho')
        except Exception:
            # Best effort: keep the default perspective if unsupported.
            pass

    title = args.title if args.title else info['title']
    ax.set_title(title)

    # Axis limits (with optional padding)
    minx, maxx = float(X.min()), float(X.max())
    miny, maxy = float(Y.min()), float(Y.max())
    minz, maxz = float(Z.min()), float(Z.max())
    if args.pad != 0.0:
        minx -= args.pad; maxx += args.pad
        miny -= args.pad; maxy += args.pad
        minz -= args.pad; maxz += args.pad
    ax.set_xlim([minx, maxx]); ax.set_ylim([miny, maxy]); ax.set_zlim([minz, maxz])
    ax.set_box_aspect([1, 1, 1])
    ax.set_xlabel("X (Å)"); ax.set_ylabel("Y (Å)"); ax.set_zlabel("Z (Å)")

    # Isosurfaces
    if args.mode in ("iso", "both"):
        vmin, vmax = float(F.min()), float(F.max())
        if args.levels:
            levels = args.levels
        elif args.nlevels:
            # Interior points only: the extreme values give degenerate surfaces.
            levels = np.linspace(vmin, vmax, args.nlevels + 2, dtype=float)[1:-1]
        else:
            levels = np.linspace(vmin, vmax, 5, dtype=float)[1:-1]

        cmap = plt.get_cmap('tab10')
        colors = [cmap(i % 10) for i in range(len(levels))]
        plot_isosurfaces_fractional(
            ax, F, lattice, levels,
            alpha=args.alpha,
            edgecolor=(None if args.edge == 'none' else args.edge),
            linewidth=args.lw,
            colors=colors,
            spacing_frac=spacing_frac,
            legend=not args.no_legend,
            antialias=False
        )

    # Dots
    if args.mode == "dots":
        # Select grid points by value: quantile, else cutoff, else median.
        vals = F.ravel()
        xs = X.ravel(); ys = Y.ravel(); zs = Z.ravel()
        if args.quantile is not None:
            thr = np.quantile(vals, args.quantile)
            mask = vals >= thr
        elif args.cutoff is not None:
            mask = vals >= float(args.cutoff)
        else:
            thr = np.median(vals)
            mask = vals >= thr
        xs = xs[mask]; ys = ys[mask]; zs = zs[mask]; cs = vals[mask]
        if args.subsample > 1:
            xs = xs[::args.subsample]; ys = ys[::args.subsample]; zs = zs[::args.subsample]; cs = cs[::args.subsample]
        if args.max_points is not None and xs.size > args.max_points:
            idx = np.random.default_rng(0).choice(xs.size, size=args.max_points, replace=False)
            xs = xs[idx]; ys = ys[idx]; zs = zs[idx]; cs = cs[idx]
        if cs.size == 0:
            # e.g. --cutoff above the data maximum; cs.min() would raise.
            print("Warning: no grid points pass the dots threshold; nothing drawn.", file=sys.stderr)
        else:
            norm = plt.Normalize(vmin=float(cs.min()), vmax=float(cs.max()))
            tk3d.plot_scatter3d(ax, xs, ys, zs,
                                minx, maxx, miny, maxy, minz, maxz,
                                cmap=args.cmap, c=cs, norm=norm, size=args.size, alpha=args.alpha)
            if not args.no_legend:
                tk3d.make_colorbar(ax, cs, plt.get_cmap(args.cmap), vmin=float(cs.min()), vmax=float(cs.max()), label='Value')

    # Slices
    fr_xy = parse_fracs(args.slice_xy)
    fr_yz = parse_fracs(args.slice_yz)
    fr_zx = parse_fracs(args.slice_zx)

    for fz in fr_xy:
        render_slice_plane(ax, lattice, F, which='xy', frac=fz, cmap=args.slice_cmap, alpha=args.slice_alpha)
    for fx in fr_yz:
        render_slice_plane(ax, lattice, F, which='yz', frac=fx, cmap=args.slice_cmap, alpha=args.slice_alpha)
    for fy in fr_zx:
        render_slice_plane(ax, lattice, F, which='zx', frac=fy, cmap=args.slice_cmap, alpha=args.slice_alpha)

    if args.save:
        plt.tight_layout()
        plt.savefig(args.save, dpi=220)
    plt.show()


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()
