#!/usr/bin/env python3
import argparse
import numpy as np
import sys, math, functools
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial import Voronoi
import warnings
# Silence one specific DeprecationWarning raised from the spglib.spglib module
# about its dict-style SpglibDataset access (presumably reached through
# SeeK-path, which this script imports lazily); all other warnings still show.
warnings.filterwarnings(
    action="ignore",
    message=r"dict interface \(SpglibDataset\['.*'\]\) is deprecated",
    category=DeprecationWarning,
    module=r"spglib\.spglib",
)

# ================== 幾何ユーティリティ ==================
def build_conventional_basis(a, b, c, alpha, beta, gamma):
    """Return the conventional cell vectors as the columns of a 3x3 matrix.

    Orientation: a along +x, b in the xy-plane, c with a non-negative z
    component. Angles are in degrees: alpha = angle(b, c),
    beta = angle(a, c), gamma = angle(a, b).

    Args:
        a, b, c: lattice lengths (must be positive).
        alpha, beta, gamma: inter-axial angles in degrees.

    Returns:
        3x3 float array whose columns are a, b, c.

    Raises:
        ValueError: a length is not positive, gamma makes a and b collinear,
            or the three angles cannot form a cell with non-zero volume.
    """
    if min(a, b, c) <= 0:
        raise ValueError(f"Lattice lengths must be positive, got a={a}, b={b}, c={c}")
    ar, br, gr = np.radians([alpha, beta, gamma])
    sin_g = np.sin(gr)
    # sin(gamma) = 0 means a and b are parallel: the cell is flat. Substituting
    # a tiny denominator would only produce a huge, meaningless c vector.
    if abs(sin_g) < 1e-12:
        raise ValueError(f"gamma={gamma} deg makes a and b collinear")
    a_vec = np.array([a, 0.0, 0.0])
    b_vec = np.array([b * np.cos(gr), b * sin_g, 0.0])
    cx = c * np.cos(br)
    cy = c * (np.cos(ar) - np.cos(br) * np.cos(gr)) / sin_g
    cz2 = c**2 - cx**2 - cy**2
    # cz^2 <= 0 means no real c vector has these angles to a and b (e.g. the
    # angle triangle inequality is violated); clamping to 0 would silently
    # yield a zero-volume cell.
    if cz2 <= 1e-12 * c**2:
        raise ValueError(
            f"Angles alpha={alpha}, beta={beta}, gamma={gamma} do not define "
            "a cell with non-zero volume")
    c_vec = np.array([cx, cy, math.sqrt(cz2)])
    return np.column_stack((a_vec, b_vec, c_vec))  # columns: a, b, c

def primitive_from_centering(A_conv, lattice_tag):
    """Return primitive lattice vectors (as columns) for a centred conventional cell.

    Args:
        A_conv: 3x3 array whose columns are the conventional vectors a, b, c.
        lattice_tag: centring symbol, case-insensitive, surrounding whitespace
            ignored: 'P'/'SC', 'I'/'BCC', 'F'/'FCC', 'A', 'B', 'C' or 'R'.
            'R' expects the hexagonal conventional cell (gamma = 120 deg) in
            the obverse setting (centring vector (2/3, 1/3, 1/3)).

    Returns:
        ``A_conv @ P`` -- its columns are a primitive basis with the same
        handedness as A_conv. The cell volume is 1, 1/2, 1/4, 1/2, 1/2, 1/2 or 1/3
        of the conventional one for P, I, F, A, B, C, R respectively.

    Raises:
        ValueError: unknown lattice tag.
    """
    tag = lattice_tag.strip().upper()
    alias = {'SC': 'P', 'P': 'P', 'BCC': 'I', 'I': 'I', 'FCC': 'F', 'F': 'F',
             'A': 'A', 'B': 'B', 'C': 'C', 'R': 'R'}
    if tag not in alias:
        raise ValueError(f"Unknown lattice '{lattice_tag}'")
    # Each column is one primitive vector in conventional fractional
    # coordinates. Every matrix has det > 0 so that the primitive basis keeps
    # the handedness of A_conv. For B and C this is achieved by the sign of the
    # last non-trivial column, which does not change the generated lattice.
    transforms = {
        'P': np.eye(3),
        'I': np.array([[-0.5,  0.5,  0.5],
                       [ 0.5, -0.5,  0.5],
                       [ 0.5,  0.5, -0.5]]),
        'F': np.array([[0.0, 0.5, 0.5],
                       [0.5, 0.0, 0.5],
                       [0.5, 0.5, 0.0]]),
        'R': np.array([[ 2/3, -1/3, -1/3],
                       [ 1/3,  1/3, -2/3],
                       [ 1/3,  1/3,  1/3]]),
        'A': np.array([[1.0,  0.0,  0.0],
                       [0.0,  0.5,  0.5],
                       [0.0, -0.5,  0.5]]),
        'B': np.array([[0.5,  0.0, -0.5],
                       [0.0,  1.0,  0.0],
                       [0.5,  0.0,  0.5]]),
        'C': np.array([[ 0.5, -0.5, 0.0],
                       [ 0.5,  0.5, 0.0],
                       [ 0.0,  0.0, 1.0]]),
    }
    return A_conv @ transforms[alias[tag]]  # columns are the primitive basis

def order_polygon_vertices_on_plane(P):
    """Return indices that put (near-)coplanar 3-D points into polygon order.

    The best-fit plane normal is taken from the SVD of the centred points.
    Vertices are then sorted by their polar angle around the centroid in that
    plane. Whether the result runs clockwise or counter-clockwise depends on
    the sign of the SVD normal and is not specified.

    Args:
        P: (m, 3) array-like of vertex coordinates.

    Returns:
        Integer index array of length m; with two or fewer points the
        identity order ``arange(m)`` is returned.
    """
    P = np.asarray(P)
    # Guard first: the centroid of an empty array is NaN (with a warning),
    # and one or two points have no meaningful cyclic order anyway.
    if len(P) <= 2:
        return np.arange(len(P))
    Q = P - P.mean(axis=0)
    _, _, Vt = np.linalg.svd(Q, full_matrices=False)
    n = Vt[-1]  # direction of least spread = plane normal
    ref = Q[0]
    # Q[0] can sit (almost) at the centroid; fall back to the next vertex.
    if np.linalg.norm(np.cross(ref, n)) < 1e-12:
        ref = Q[1]
    u = ref - np.dot(ref, n) * n
    if np.linalg.norm(u) < 1e-12:
        # Still degenerate: use any global axis not close to the normal.
        tmp = np.array([1.0, 0.0, 0.0])
        if abs(np.dot(tmp, n)) > 0.9:
            tmp = np.array([0.0, 1.0, 0.0])
        u = tmp - np.dot(tmp, n) * n
    u /= np.linalg.norm(u)
    v = np.cross(n, u)  # completes an in-plane orthonormal frame (u, v)
    return np.argsort(np.arctan2(Q @ v, Q @ u))

def reciprocal_points_from_basis(B, n):
    """Return all lattice points i*b1 + j*b2 + k*b3 with i, j, k in [-n, n].

    Despite its name, this is used for both spaces: with reciprocal vectors
    (reciprocal lattice) and with the primitive real-space vectors.

    Args:
        B: 3x3 array whose columns are the basis vectors b1, b2, b3.
        n: shell index; the result has (2n+1)**3 rows.

    Returns:
        (M, 3) array ordered with i outermost and k innermost, exactly as a
        nested i/j/k loop would produce. For n < 0 it is an empty (0, 3)
        array rather than a shapeless one.
    """
    B = np.asarray(B)
    b1, b2, b3 = B[:, 0], B[:, 1], B[:, 2]
    idx = np.arange(-n, n + 1)
    # indexing='ij' plus C-order flattening reproduces the i-outer / k-inner
    # order; each index becomes a column so it broadcasts against (3,) vectors.
    i, j, k = (g.reshape(-1, 1) for g in np.meshgrid(idx, idx, idx, indexing='ij'))
    return i * b1 + j * b2 + k * b3

def get_cell_faces_and_neighbors(vor, point_index):
    """Collect the bounded faces of one Voronoi cell and the point behind each.

    Args:
        vor: a ``scipy.spatial.Voronoi`` result.
        point_index: index (into ``vor.points``) of the cell's generator.

    Returns:
        ``(faces, neighbors)`` -- parallel lists. ``faces[m]`` is an index
        array into ``vor.vertices``; ``neighbors[m]`` is the index of the
        other generator sharing that face. Ridges that reach infinity
        (vertex -1) or have fewer than three vertices are skipped.
    """
    faces = []
    neighbors = []
    for ridge_no, (p, q) in enumerate(vor.ridge_points):
        if p != point_index and q != point_index:
            continue  # this ridge does not bound the requested cell
        verts = vor.ridge_vertices[ridge_no]
        if -1 in verts or len(verts) < 3:
            continue  # unbounded or degenerate ridge
        faces.append(np.array(verts))
        neighbors.append(q if p == point_index else p)
    return faces, neighbors

def cell_polygons(vor, point_index):
    """Return the faces of one Voronoi cell as vertex-ordered polygons.

    Args:
        vor: a ``scipy.spatial.Voronoi`` result.
        point_index: index of the cell's generator point.

    Returns:
        ``(polys, neighbors)`` -- ``polys`` is a list of (m, 3) coordinate
        arrays, each sorted into polygon order by
        ``order_polygon_vertices_on_plane``. ``neighbors`` is the parallel
        list of generator indices on the other side of each face.
    """
    faces, neighbors = get_cell_faces_and_neighbors(vor, point_index)
    ordered_faces = []
    for face in faces:
        corners = vor.vertices[face]
        ordered_faces.append(corners[order_polygon_vertices_on_plane(corners)])
    return ordered_faces, neighbors

# ================== SeeK-path（キャッシュ付き） ==================
def _key_from_Aprim(A_prim, symprec):
    # 行が a1,a2,a3 のフラット（数値は丸めてキャッシュ安定化）
    Arows = A_prim.T
    return (float(symprec),) + tuple(np.round(Arows.ravel(), 12))

@functools.lru_cache(maxsize=128)
def _seekpath_cached(key):
    symprec = key[0]
    Arows = np.array(key[1:]).reshape(3,3)
    import seekpath
    cell_tuple = (Arows, np.array([[0.0,0.0,0.0]]), [1])
    res = seekpath.get_path(cell_tuple, with_time_reversal=True, symprec=symprec)
    B_rows = np.array(res['reciprocal_primitive_lattice'])
    Bstd = B_rows.T  # 列が b1,b2,b3（2π含む）
    kpts_abs = {lbl: Bstd @ np.array(frac) for lbl, frac in res['point_coords'].items()}
    path = res['path']
    return Bstd, kpts_abs, path

def bz_and_kpath_from_seekpath_cached(A_prim, symprec=1e-7):
    """Return ``(Bstd, kpts_abs, path)`` from SeeK-path for primitive basis *A_prim*.

    *A_prim* holds the primitive vectors as columns. Repeated calls whose
    lattice agrees to 12 decimals (and with the same *symprec*) are served
    from the LRU cache of ``_seekpath_cached``.
    """
    return _seekpath_cached(_key_from_Aprim(A_prim, symprec))

def pretty_label(lbl):
    """Return *lbl* with every ``GAMMA`` written as the Greek letter Γ."""
    greek_gamma = "Γ"
    return lbl.replace("GAMMA", greek_gamma)

# ================== 着色＆透明度ユーティリティ ==================
def color_checker(i, j, k):
    """Pick one of 8 colours from the parities of the tile indices (i, j, k).

    The parity bits of i, j and k form a 3-bit index (i is the low bit, k the
    high bit), so neighbouring tiles along any axis get different colours.
    """
    palette = ('#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
               '#9467bd', '#8c564b', '#e377c2', '#7f7f7f')
    slot = (i & 1) + 2 * (j & 1) + 4 * (k & 1)  # 0..7
    return palette[slot]

def color_hash(i, j, k, cmap=plt.cm.tab20):
    """Map tile indices (i, j, k) to a pseudo-random colour from *cmap*.

    The indices are mixed with large multipliers and XOR and reduced to
    32 bits. The result is then folded into 256 buckets spread over [0, 1].
    """
    mixed = (i * 73856093) ^ (j * 19349663) ^ (k * 83492791)
    bucket = (mixed & 0xFFFFFFFF) % 256
    return cmap(bucket / 255.0)

def color_distance(t, rmax, cmap=plt.cm.viridis):
    """Colour a tile by the length of its translation vector *t*.

    |t| / rmax, clipped to 1, is passed to *cmap*. Without a positive
    *rmax*, every tile gets ``cmap(0.0)``.
    """
    dist = np.linalg.norm(t)
    if rmax is None or rmax <= 0:
        frac = 0.0
    else:
        frac = min(1.0, dist / rmax)
    return cmap(frac)

def alpha_fade(t, rmax, a_min=0.06, a_max=0.22, power=1.0):
    """Return an opacity that fades from *a_max* at the origin to *a_min* at *rmax*.

    With u = min(1, |t| / rmax), the result is
    a_min + (a_max - a_min) * (1 - u) ** power.
    Without a positive *rmax*, the result is always *a_max*.
    """
    if rmax is None or rmax <= 0:
        return a_max
    u = min(1.0, np.linalg.norm(t) / rmax)
    weight = (1.0 - u) ** power
    span = a_max - a_min
    return a_min + span * weight

# ================== 描画マネージャ（差分更新） ==================
class SceneManager:
    """Own one 3-D matplotlib figure and redraw the Voronoi-cell scene in place.

    Every artist created by ``draw_scene`` is recorded. The next call removes
    exactly those artists and draws fresh ones, so the Figure, the 3-D Axes
    and any widgets the caller added (sliders, buttons) survive redraws.
    """

    # Margin factor applied to the longest basis vector for the default view box.
    _VIEW_MARGIN = 1.35

    def __init__(self, args):
        """Create the figure and its single 3-D Axes.

        Args:
            args: parsed argparse namespace; ``draw_scene`` reads
                ``args.plot_limit`` from it.
        """
        self.args = args
        self.fig = plt.figure(figsize=(10, 9))
        self.ax = self.fig.add_subplot(111, projection='3d')
        self.artists = []  # artists drawn by the previous draw_scene call
        # tight_layout runs only on the first draw. The interactive mode then
        # reserves a strip for its widgets with subplots_adjust. Re-running
        # tight_layout on every redraw would undo that and let the 3-D Axes
        # grow over the sliders.
        self._layout_done = False

    def clear_artists(self):
        """Remove the artists drawn by the previous ``draw_scene`` call.

        Removal is best effort: an artist that cannot be removed (already
        gone, or not removable) is skipped so a redraw never aborts halfway.
        """
        for artist in self.artists:
            try:
                artist.remove()
            except Exception:
                pass
        self.artists = []

    def add_artist(self, *objs):
        """Record artists so the next redraw removes them.

        Each argument may be a single artist, ``None`` (ignored), or a
        list/tuple of artists (flattened one level).
        """
        for o in objs:
            if isinstance(o, (list, tuple)):
                self.artists.extend(o)
            elif o is not None:
                self.artists.append(o)

    def draw_scene(self, space, A_prim, Bstd, n_shell,
                   tile_nx, tile_ny, tile_nz, tile_rmax,
                   show_outside_points, highlight_face_points,
                   draw_connectors_to_faces, draw_kpath,
                   color_mode='checker', fade_power=1.0,
                   kpts_abs=None, kpath=None,
                   label_fontsize=10):
        """Replace the current scene with a freshly computed one.

        Args:
            space: ``'reciprocal'`` draws the Voronoi cell of the lattice
                spanned by the columns of ``Bstd`` (first Brillouin zone).
                Any other value uses the columns of ``A_prim`` (real-space
                Wigner-Seitz cell).
            A_prim: 3x3 array, primitive real-space vectors as columns.
            Bstd: 3x3 array, reciprocal vectors as columns (reciprocal only).
            n_shell: index range ``-n..n`` of lattice points fed to Voronoi.
            tile_nx, tile_ny, tile_nz: copies of the cell are drawn for
                indices ``-n..n`` along each basis vector (origin excluded).
                Tiling happens only if at least one of them is > 0.
            tile_rmax: copies with a translation longer than this are
                skipped. It also scales the distance colour and alpha fade
                (None = no limit).
            show_outside_points: scatter every lattice point except the centre.
            highlight_face_points: scatter the neighbours that define faces.
            draw_connectors_to_faces: dotted lines from the origin to them.
            draw_kpath: connect high-symmetry points along ``kpath``.
            color_mode: ``'checker'`` (index parity), ``'distance'``
                (colormap by |t|); any other value hashes the indices.
            fade_power: exponent of the distance-based alpha fade.
            kpts_abs: ``{label: xyz}`` high-symmetry points (reciprocal only).
            kpath: iterable of ``(label_a, label_b)`` segments.
            label_fontsize: font size of the high-symmetry labels.
        """
        self.clear_artists()
        reciprocal = (space == 'reciprocal')

        # The same Voronoi construction serves both spaces; only the basis
        # and the cosmetics differ.
        basis = Bstd if reciprocal else A_prim
        lattice_pts = reciprocal_points_from_basis(basis, n_shell)
        vor = Voronoi(lattice_pts)
        idx_center = np.argmin(np.linalg.norm(lattice_pts, axis=1))
        polys, neighbor_idx = cell_polygons(vor, idx_center)
        if reciprocal:
            label_axes = (r'$k_x$', r'$k_y$', r'$k_z$')
            face_color_main = 'lightskyblue'
            origin_label = 'Γ (origin)'
            points_label = 'Reciprocal points'
        else:
            label_axes = ('x', 'y', 'z')
            face_color_main = 'peachpuff'
            origin_label = 'origin'
            points_label = 'Lattice points'

        if show_outside_points:
            self._draw_lattice_points(lattice_pts, idx_center, points_label)
        self._draw_central_cell(polys, face_color_main)
        centers = (lattice_pts[np.array(neighbor_idx)] if len(neighbor_idx) > 0
                   else np.empty((0, 3)))
        self._draw_face_neighbors(centers, highlight_face_points, draw_connectors_to_faces)
        self._draw_tiles(polys, basis, tile_nx, tile_ny, tile_nz, tile_rmax,
                         color_mode, fade_power)
        # High-symmetry points and k-path exist only in reciprocal space.
        if reciprocal and kpts_abs:
            self._draw_kpoints(kpts_abs, kpath, draw_kpath, label_fontsize)

        origin = self.ax.scatter([0], [0], [0], color='k', s=50, label=origin_label)
        self.add_artist(origin)

        # Limits, labels and title are (re)applied on every call.
        self._configure_axes(basis, label_axes, reciprocal)
        # Rebuild the legend every time. The previous one is registered as an
        # artist, so clear_artists has already removed it.
        leg = self.ax.legend(loc='upper right')
        self.add_artist(leg)

        if not self._layout_done:
            self.fig.tight_layout()
            self._layout_done = True
        self.fig.canvas.draw_idle()

    def _draw_lattice_points(self, pts, idx_center, label):
        """Scatter every lattice point except the central one."""
        mask = np.ones(len(pts), dtype=bool)
        mask[idx_center] = False
        sc = self.ax.scatter(pts[mask, 0], pts[mask, 1], pts[mask, 2],
                             s=10, color='gray', alpha=0.45, label=label)
        self.add_artist(sc)

    def _draw_central_cell(self, polys, facecolor):
        """Draw the central cell as translucent faces with black outlines."""
        for poly in polys:
            col = Poly3DCollection([poly], alpha=0.32, facecolor=facecolor, edgecolor='none')
            self.ax.add_collection3d(col)
            self.add_artist(col)
            ring = np.vstack([poly, poly[0]])  # repeat first vertex to close the outline
            ln, = self.ax.plot(ring[:, 0], ring[:, 1], ring[:, 2], color='k', linewidth=1.8)
            self.add_artist(ln)

    def _draw_face_neighbors(self, centers, highlight, connectors):
        """Mark the face-defining lattice points and/or join them to the origin."""
        if highlight and len(centers) > 0:
            sc2 = self.ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], s=28,
                                  color='tab:blue', alpha=0.9, label='Face-defining points')
            self.add_artist(sc2)
        if connectors:
            lines = []
            for p in centers:
                ln, = self.ax.plot([0, p[0]], [0, p[1]], [0, p[2]], linestyle=':',
                                   linewidth=1.2, color='k', alpha=0.9)
                lines.append(ln)
            self.add_artist(lines)

    def _draw_tiles(self, polys, basis, tile_nx, tile_ny, tile_nz, tile_rmax,
                    color_mode, fade_power):
        """Draw translated copies of the central cell (the origin copy is skipped)."""
        if not any(n > 0 for n in (tile_nx, tile_ny, tile_nz)):
            return
        b1, b2, b3 = basis[:, 0], basis[:, 1], basis[:, 2]
        for i in range(-tile_nx, tile_nx + 1):
            for j in range(-tile_ny, tile_ny + 1):
                for k in range(-tile_nz, tile_nz + 1):
                    if i == 0 and j == 0 and k == 0:
                        continue  # the central cell is drawn separately
                    t = i * b1 + j * b2 + k * b3
                    if tile_rmax is not None and np.linalg.norm(t) > tile_rmax + 1e-12:
                        continue
                    if color_mode == 'checker':
                        fc = color_checker(i, j, k)
                    elif color_mode == 'distance':
                        fc = color_distance(t, tile_rmax)
                    else:
                        fc = color_hash(i, j, k)
                    a = alpha_fade(t, tile_rmax, a_min=0.06, a_max=0.22, power=fade_power)
                    cols, lns = [], []
                    for poly in polys:
                        P = poly + t
                        c = Poly3DCollection([P], alpha=a, facecolor=fc, edgecolor='none')
                        self.ax.add_collection3d(c)
                        cols.append(c)
                        ring = np.vstack([P, P[0]])
                        ln, = self.ax.plot(ring[:, 0], ring[:, 1], ring[:, 2],
                                           color='#666666', linewidth=0.6)
                        lns.append(ln)
                    self.add_artist(cols, lns)

    def _draw_kpoints(self, kpts_abs, kpath, draw_kpath, label_fontsize):
        """Draw labelled high-symmetry points and, optionally, the k-path segments."""
        pts, texts = [], []
        for lbl, k in kpts_abs.items():
            pts.append(self.ax.scatter(k[0], k[1], k[2], s=40, color='crimson'))
            texts.append(self.ax.text(k[0], k[1], k[2], " " + pretty_label(lbl),
                                      fontsize=label_fontsize, color='crimson'))
        self.add_artist(pts, texts)
        if draw_kpath and kpath:
            segs = []
            for a_lbl, b_lbl in kpath:
                seg = np.vstack([kpts_abs[a_lbl], kpts_abs[b_lbl]])
                ln, = self.ax.plot(seg[:, 0], seg[:, 1], seg[:, 2], color='crimson',
                                   linewidth=1.6, alpha=0.85)
                segs.append(ln)
            self.add_artist(segs)

    def _configure_axes(self, basis, label_axes, reciprocal):
        """Set a cubic view box, axis labels and the title."""
        ax = self.ax
        # Size the box by the LONGEST basis vector: the first column alone
        # clips cells whose lattice lengths differ strongly.
        L = self._VIEW_MARGIN * np.linalg.norm(basis, axis=0).max()
        if self.args.plot_limit is not None:
            L = self.args.plot_limit  # explicit user override
        ax.set_xlim(-L, L)
        ax.set_ylim(-L, L)
        ax.set_zlim(-L, L)
        ax.set_box_aspect([1, 1, 1])
        ax.set_xlabel(label_axes[0])
        ax.set_ylabel(label_axes[1])
        ax.set_zlabel(label_axes[2])
        ax.set_title('First Brillouin Zone (tiled)' if reciprocal else 'Wigner–Seitz cell (tiled)')

# ================== 実行系 ==================
def plot_once(args):
    """Draw the scene once from the command-line options and show it.

    In reciprocal space, SeeK-path supplies the reciprocal basis and the
    high-symmetry points/path. If it is missing or raises RuntimeError, the
    message goes to stderr and the process exits with status 1. With
    ``--save PATH``, the figure is written to PATH before the window is shown.
    """
    A_conv = build_conventional_basis(args.a, args.b, args.c, args.alpha, args.beta, args.gamma)
    A_prim = primitive_from_centering(A_conv, args.lattice)

    mgr = SceneManager(args)
    reciprocal = (args.space == 'reciprocal')
    Bstd = kpts_abs = kpath = None
    if reciprocal:
        try:
            Bstd, kpts_abs, kpath = bz_and_kpath_from_seekpath_cached(A_prim)
        except ImportError as e:
            # seekpath is imported lazily by the cached helper.
            print(f"--space reciprocal requires the 'seekpath' package: {e}", file=sys.stderr)
            sys.exit(1)
        except RuntimeError as e:
            print(e, file=sys.stderr)
            sys.exit(1)

    mgr.draw_scene('reciprocal' if reciprocal else 'real', A_prim, Bstd, args.n_shell,
                   args.tile_nx, args.tile_ny, args.tile_nz, args.tile_rmax,
                   (not args.no_outside_points), (not args.no_highlight_face_points),
                   (not args.no_connectors),
                   (not args.no_kpath) if reciprocal else False,  # no k-path in real space
                   color_mode=args.tile_color_mode, fade_power=args.fade_power,
                   kpts_abs=kpts_abs, kpath=kpath, label_fontsize=args.label_fontsize)

    # --save was declared on the CLI but never honoured; write the figure now.
    save_path = getattr(args, 'save', None)
    if save_path:
        mgr.fig.savefig(save_path)
    plt.show()

def plot_interactive_apply(args):
    """Show the scene with sliders for a, b, c and redraw when Apply is pressed.

    Only the three lattice lengths are adjustable; each slider spans 0.5x to
    2x the command-line value. Angles, lattice type and display options come
    from ``args`` unchanged. Moving a slider alone does not redraw anything;
    the scene is rebuilt only when Apply is clicked.
    """
    mgr = SceneManager(args)
    fig = mgr.fig
    reciprocal = (args.space == 'reciprocal')
    widget_strip = 0.20  # figure fraction kept free below the 3-D Axes for the widgets

    def do_draw(a, b, c):
        """Rebuild the lattice from lengths a, b, c and redraw the scene."""
        A_conv = build_conventional_basis(a, b, c, args.alpha, args.beta, args.gamma)
        A_prim = primitive_from_centering(A_conv, args.lattice)
        Bstd = kpts_abs = kpath = None
        if reciprocal:
            # Cached by the rounded lattice, so re-applying the same values is fast.
            Bstd, kpts_abs, kpath = bz_and_kpath_from_seekpath_cached(A_prim)
        mgr.draw_scene('reciprocal' if reciprocal else 'real', A_prim, Bstd, args.n_shell,
                       args.tile_nx, args.tile_ny, args.tile_nz, args.tile_rmax,
                       (not args.no_outside_points), (not args.no_highlight_face_points),
                       (not args.no_connectors),
                       (not args.no_kpath) if reciprocal else False,
                       color_mode=args.tile_color_mode, fade_power=args.fade_power,
                       kpts_abs=kpts_abs, kpath=kpath, label_fontsize=args.label_fontsize)
        # draw_scene may run tight_layout, which resizes the 3-D Axes over the
        # whole figure; restore the strip reserved for the widgets.
        fig.subplots_adjust(bottom=widget_strip)

    # Initial draw.
    do_draw(args.a, args.b, args.c)

    # UI: sliders + Apply button (redraw only on click).
    ax_a = fig.add_axes([0.15, 0.08, 0.65, 0.03])
    ax_b = fig.add_axes([0.15, 0.05, 0.65, 0.03])
    ax_c = fig.add_axes([0.15, 0.02, 0.65, 0.03])
    s_a = Slider(ax_a, 'a', 0.5 * args.a, 2.0 * args.a, valinit=args.a)
    s_b = Slider(ax_b, 'b', 0.5 * args.b, 2.0 * args.b, valinit=args.b)
    s_c = Slider(ax_c, 'c', 0.5 * args.c, 2.0 * args.c, valinit=args.c)
    btn = Button(fig.add_axes([0.83, 0.02, 0.12, 0.09]), 'Apply')
    # matplotlib's widget docs require keeping a reference for widgets to stay
    # responsive; attach them to the manager so they live as long as the figure.
    mgr.widgets = (s_a, s_b, s_c, btn)

    busy = False

    def set_button(text, facecolor):
        btn.label.set_text(text)
        btn.ax.set_facecolor(facecolor)

    def on_apply(event):
        nonlocal busy
        if busy:  # ignore clicks delivered while a redraw is in progress
            return
        busy = True
        set_button('Redrawing...', '#dddddd')
        # draw_idle() only requests a repaint for when control returns to the
        # event loop, which is after the redraw below. Without this, the busy
        # label would never be seen. Paint now and let the GUI process it.
        fig.canvas.draw()
        fig.canvas.flush_events()
        try:
            do_draw(s_a.val, s_b.val, s_c.val)
        finally:
            set_button('Apply', '#f0f0f0')
            fig.canvas.draw_idle()
            busy = False

    btn.on_clicked(on_apply)
    plt.show()

# ================== CLI ==================
def main():
    """Parse the command line, validate it, and run the one-shot or interactive viewer."""
    ap = argparse.ArgumentParser(description="Voronoi 多面体（第一BZ / Wigner–Seitz）: キャッシュ・差分更新・Applyボタン対応")
    ap.add_argument('--space', choices=['reciprocal','real'], default='reciprocal',
                    help="reciprocal: 第一BZ（SeeK-pathで高対称点） / real: 実空間 Wigner–Seitz")
    ap.add_argument('--lattice', default='P', help="格子タイプ: P|SC, I|BCC, F|FCC, A, B, C")
    ap.add_argument('--a', type=float, default=1.0); ap.add_argument('--b', type=float, default=1.0); ap.add_argument('--c', type=float, default=1.0)
    ap.add_argument('--alpha', type=float, default=90.0); ap.add_argument('--beta', type=float, default=90.0); ap.add_argument('--gamma', type=float, default=90.0)
    ap.add_argument('--n-shell', type=int, default=3, help='Voronoi 計算用に並べる格子点の範囲')
    ap.add_argument('--tile-nx', type=int, default=0); ap.add_argument('--tile-ny', type=int, default=0); ap.add_argument('--tile-nz', type=int, default=0)
    ap.add_argument('--tile-rmax', type=float, default=None,
                    help='タイル描画の半径上限（中心から距離 r ≤ rmax のセルのみ）')
    # The plotted coordinates are absolute (reciprocal vectors include 2π), so
    # L is in those units; the old help's "L=1.0 ↔ [-2π/a, 2π/a]" was wrong.
    ap.add_argument('--plot-limit', type=float, default=None,
                    help='軸の表示範囲を [-L, L] に強制的に制限する（描画座標と同じ単位: 逆格子空間では 2π を含む 1/長さ、実空間では a, b, c と同じ長さ単位）')
    ap.add_argument('--tile-color-mode', choices=['checker','byindex','distance'], default='checker',
                    help="タイル色: checker=チェッカー, byindex=ハッシュ, distance=距離グラデーション")
    ap.add_argument('--fade-power', type=float, default=1.0, help='距離フェードの指数（1=線形, 2=二乗など）')
    ap.add_argument('--no-outside-points', action='store_true')
    ap.add_argument('--no-highlight-face-points', action='store_true')
    ap.add_argument('--no-connectors', action='store_true')
    ap.add_argument('--no-kpath', action='store_true')
    ap.add_argument('--label-fontsize', type=int, default=10)
    ap.add_argument('--interactive-apply', action='store_true', help='スライダ + Applyボタンで再描画')
    ap.add_argument('--save', default=None, help='（一括描画時のみ）画像保存パス')

    args = ap.parse_args()
    # n_shell = 0 generates only the origin, which Voronoi cannot triangulate.
    if args.n_shell < 1:
        ap.error('--n-shell must be >= 1')
    # A negative tile count yields an empty index range, which silently
    # disables tiling along every axis; reject it instead.
    for opt in ('tile_nx', 'tile_ny', 'tile_nz'):
        if getattr(args, opt) < 0:
            ap.error(f"--{opt.replace('_', '-')} must be >= 0")

    if args.interactive_apply:
        plot_interactive_apply(args)
    else:
        plot_once(args)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
