import sys
import warnings

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# ─── Lattice constants and Miller indices ───────────
a = 1.0  # length of the a axis
b = 1.2  # length of the b axis
c = 0.8  # length of the c axis
# Interaxial angles in degrees: alpha = angle(b, c), beta = angle(a, c),
# gamma = angle(a, b).
alpha, beta, gamma = 90, 100, 110
h, k, l = 1, 2, 3  # default Miller indices (hkl)
# ───────────────────────────────────────────────────


# Optional command-line override: `python <script> [h [k [l]]]`.
# Indices that are not given keep the defaults above.
argv = sys.argv
nargs = len(argv)
try:
    if nargs > 1:
        h = int(argv[1])
    if nargs > 2:
        k = int(argv[2])
    if nargs > 3:
        l = int(argv[3])
except ValueError:
    sys.exit(f"usage: {argv[0]} [h [k [l]]]  (Miller indices must be integers)")
if h == 0 and k == 0 and l == 0:
    # h*x + k*y + l*z = 1 has no solution, so there is no plane to draw.
    sys.exit("Miller indices (0 0 0) do not define a plane; "
             "at least one index must be non-zero")


# Appearance of the (hkl) plane polygon: fill colour, outline colour, opacity.
plane_color, edge_color, plane_alpha = 'cyan', 'blue', 0.1

# Make MS Gothic the default font family for all text.
plt.rcParams.update({'font.family': 'MS Gothic'})

# Degrees -> radians, and the trigonometric values used below.
alpha_r, beta_r, gamma_r = np.deg2rad(alpha), np.deg2rad(beta), np.deg2rad(gamma)
cos_alpha, cos_beta, cos_gamma = np.cos(alpha_r), np.cos(beta_r), np.cos(gamma_r)
sin_gamma = np.sin(gamma_r)

# Lattice vectors in Cartesian coordinates: a along x, b in the xy-plane at
# angle gamma to a, and c fixed by beta (angle to a) and alpha (angle to b).
a1 = np.array([a, 0.0, 0.0])
a2 = np.array([b * cos_gamma, b * sin_gamma, 0.0])
a3_x = c * cos_beta
a3_y = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
# Squared z-component of the unit vector along c. It is negative when the
# three angles cannot close a parallelepiped.
a3_z_sq = 1 - cos_beta**2 - ((cos_alpha - cos_beta * cos_gamma) / sin_gamma)**2
if a3_z_sq < 0:
    # Keep the best-effort fallback (c is flattened into the ab plane), but
    # do not hide the fact that the requested cell is impossible.
    warnings.warn(
        f"lattice angles alpha={alpha}, beta={beta}, gamma={gamma} do not form "
        "a valid cell; the c axis is flattened into the ab plane",
        RuntimeWarning,
    )
a3_z = c * np.sqrt(max(0, a3_z_sq))
a3 = np.array([a3_x, a3_y, a3_z])
# Columns are a1, a2, a3, so M @ (fractional coords) gives Cartesian coords.
M = np.column_stack((a1, a2, a3))

# Unit-cell corners and edges in fractional coordinates.
def init_frac_cell():
    """Return the corners and edges of the unit cell in fractional coordinates.

    Returns
    -------
    fverts : ndarray of int, shape (8, 3)
        Corners 0-3 are the z=0 square in the order (0,0), (1,0), (1,1),
        (0,1); corners 4-7 are the same square at z=1.
    edges : list of (int, int)
        The 12 cell edges as index pairs into ``fverts``: the bottom ring,
        then the top ring, then the four vertical edges.
    """
    square = [(0, 0), (1, 0), (1, 1), (0, 1)]
    corners = np.array([[x, y, z] for z in (0, 1) for x, y in square])
    bottom = [(i, (i + 1) % 4) for i in range(4)]
    top = [(i + 4, j + 4) for i, j in bottom]
    vertical = [(i, i + 4) for i in range(4)]
    return corners, bottom + top + vertical
fverts, edges = init_frac_cell()

# Intersection of the (hkl) plane with the unit cell.
def compute_frac_polygon(h, k, l, fverts, edges, tol=1e-8):
    """Intersect the plane h*x + k*y + l*z = 1 with the cell edges.

    Everything is in fractional (cell) coordinates, where this plane has
    intercepts 1/h, 1/k, 1/l on the three cell axes.

    Parameters
    ----------
    h, k, l : int or float
        Plane coefficients (Miller indices).
    fverts : ndarray, shape (n_vertices, 3)
        Cell corners in fractional coordinates.
    edges : iterable of (int, int)
        Index pairs into ``fverts`` describing the cell edges.
    tol : float, optional
        Absolute tolerance used to treat an edge as parallel to the plane,
        to accept intersections that round-off pushes just past an edge end
        point, and to merge coinciding intersection points (a plane through a
        corner meets every edge incident to that corner).

    Returns
    -------
    ndarray of shape (m, 3) or None
        The distinct intersection points ordered by angle around their
        centroid (consecutive vertices of the convex cut polygon), or
        ``None`` when the plane meets no edge. ``m`` may be below 3 when the
        plane only touches a corner or an edge.
    """
    pts = []
    for i0, i1 in edges:
        v0, v1 = fverts[i0], fverts[i1]
        d = v1 - v0
        denom = h*d[0] + k*d[1] + l*d[2]
        if abs(denom) < tol:
            # Edge parallel to the plane. If it lies in the plane, its end
            # points are still found through the adjoining cell edges.
            continue
        t = (1 - (h*v0[0] + k*v0[1] + l*v0[2])) / denom
        if not -tol <= t <= 1 + tol:
            continue
        # Snap round-off overshoot back onto the edge (no-op for 0 <= t <= 1).
        t = min(max(t, 0.0), 1.0)
        p = v0 + t*d
        # Skip points that coincide (within tol) with one already found.
        if all(np.linalg.norm(p - q) > tol for q in pts):
            pts.append(p)
    if not pts:
        return None
    pts = np.array(pts)

    # Orthonormal in-plane basis (u, v) built from the unit normal. At least
    # one edge was not parallel to the plane, so (h, k, l) is non-zero here.
    n = np.array([h, k, l], dtype=float)
    n /= np.linalg.norm(n)
    # Cross with x or y, choosing the one that cannot be parallel to n.
    if abs(n[0]) < abs(n[1]):
        u = np.cross(n, [1, 0, 0])
    else:
        u = np.cross(n, [0, 1, 0])
    u /= np.linalg.norm(u)
    v = np.cross(n, u)

    # Sort by polar angle around the centroid to get polygon winding order.
    rel = pts - pts.mean(axis=0)
    ang = np.arctan2(rel @ v, rel @ u)
    return pts[np.argsort(ang)]

# Map the fractional cut polygon and the cell corners to Cartesian
# coordinates (every row r becomes M @ r).
frac_poly = compute_frac_polygon(h, k, l, fverts, edges)
if frac_poly is None:
    cart_poly = None
else:
    cart_poly = M.dot(frac_poly.T).T
cverts = M.dot(fverts.T).T

# Figure and 3-D axes: square figure, cube-shaped box, grid and axes hidden,
# opaque white panes.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(projection='3d')
ax.set_box_aspect((1, 1, 1))
ax.grid(False)
ax.set_axis_off()
white = (1, 1, 1, 1)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_pane_color(white)

# Unit-cell wireframe: one thin black segment per cell edge.
for start, end in edges:
    ax.plot(*zip(cverts[start], cverts[end]), 'k-', lw=1)

# (hkl) plane: translucent filled polygon, drawn only when the cut is a real
# polygon (at least three vertices).
has_polygon = cart_poly is not None and len(cart_poly) >= 3
if has_polygon:
    plane_patch = Poly3DCollection([cart_poly], facecolors=plane_color,
                                   edgecolors=edge_color, alpha=plane_alpha)
    ax.add_collection3d(plane_patch)

# Lattice-vector labels, placed just beyond the tip of each vector.
offset_factor = 1.05
label_pos_a, label_pos_b, label_pos_c = (
    offset_factor * vec for vec in (a1, a2, a3)
)


def label_vec(vec, label):
    """Write *label* centred on the 3-D point *vec* of the global axes ``ax``."""
    ax.text(*vec, label, fontsize=14, fontfamily='MS Gothic',
            ha='center', va='center')


for pos, name in ((label_pos_a, 'a'), (label_pos_b, 'b'), (label_pos_c, 'c')):
    label_vec(pos, name)

# Intercept labels: the (hkl) plane crosses cell axis i at fractional
# coordinate 1/index. Each label gets a hand-tuned Cartesian nudge so it does
# not sit on top of the cell edge.
offset_alabel = -0.1   # y shift of the a-axis intercept label
offset_blabel = -0.2   # x shift of the b-axis intercept label
offset_clabel = -0.05  # y shift of the c-axis intercept label
for index, axis_i, shift, ha, va in (
    (h, 0, (0.0, offset_alabel, 0.0), 'center', 'bottom'),
    (k, 1, (offset_blabel, 0.0, 0.0), 'right', 'center'),
    (l, 2, (0.0, offset_clabel, 0.0), 'right', 'bottom'),
):
    if index == 0:
        continue  # plane parallel to this axis: no intercept to label
    frac = np.zeros(3)
    frac[axis_i] = 1 / index
    pos = np.dot(M, frac) + np.array(shift)
    ax.text(*pos, f'$1/{{{index}}}$', fontsize=12,
            fontfamily='MS Gothic', ha=ha, va=va)

# Draw all three axes at the same scale so that the cell's edge lengths and
# angles are not distorted: fit the limits to the cell (plus a small margin)
# and make the box proportional to those limits. A fixed (1, 1, 1) box over
# autoscaled, unequal ranges would stretch the cell.
view_lo = cverts.min(axis=0)
view_hi = cverts.max(axis=0)
view_pad = 0.05 * (view_hi - view_lo).max()
view_lo, view_hi = view_lo - view_pad, view_hi + view_pad
ax.set_xlim(view_lo[0], view_hi[0])
ax.set_ylim(view_lo[1], view_hi[1])
ax.set_zlim(view_lo[2], view_hi[2])

# Camera direction.
ax.view_init(elev=20, azim=30)
# Pull the camera back a little (the intent of the former `ax.dist = 12`).
# `Axes3D.dist` is deprecated since Matplotlib 3.6, and assigning it has no
# effect once removed; the `zoom` argument of set_box_aspect replaces it.
try:
    ax.set_box_aspect(view_hi - view_lo, zoom=10 / 12)
except TypeError:
    # Older Matplotlib: no `zoom` keyword, but `dist` still works there.
    ax.set_box_aspect(view_hi - view_lo)
    ax.dist = 12

plt.tight_layout()
plt.savefig(f"lattice_plane{h}{k}{l}.png", dpi=300, bbox_inches='tight', transparent=True)
plt.show()
