"""
単位格子および超格子を描画するためのスクリプト。

概要:
    結晶学における単位格子（unit cell）や、その繰り返しからなる超格子（supercell）を3Dで可視化します。

詳細説明:
    格子定数と角度を入力として、3D空間における単位格子を構築し、指定された数だけ繰り返して描画します。
    基本ベクトルを計算し、平行六面体を積み重ねることで超格子の形状を表現します。

関連リンク: draw_unit_cells_usage
"""

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#===================
# デフォルト設定
#===================
ELEV_DEF = 5.90
AZIM_DEF = -67.55
AXIS_FONT_SIZE_DEF = 24
DRAW_LATTICE_VECTORS_DEF = False

def lattice_vectors(a, b, c, alpha, beta, gamma):
    """
    概要:
        結晶の格子定数と角度から基本格子ベクトルを計算します。

    詳細説明:
        直交座標系における3つの基本格子ベクトル va, vb, vc を導出します。
        va はX軸に沿い、vb はXY平面にあり、vc は3次元空間に配置されるように定義されます。

    引数:
        :param a: 格子定数aの長さ。
        :type a: float
        :param b: 格子定数bの長さ。
        :type b: float
        :param c: 格子定数cの長さ。
        :type c: float
        :param alpha: 角度α（度）。vb と vc のなす角。
        :type alpha: float
        :param beta: 角度β（度）。va と vc のなす角。
        :type beta: float
        :param gamma: 角度γ（度）。va と vb のなす角。
        :type gamma: float
    戻り値:
        :returns: 3つの基本格子ベクトル va, vb, vc。
        :rtype: tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
    """
    alpha_r, beta_r, gamma_r = np.radians([alpha, beta, gamma])
    va = np.array([a, 0, 0])
    vb = np.array([b * np.cos(gamma_r), b * np.sin(gamma_r), 0])
    cx = c * np.cos(beta_r)
    cy = c * (np.cos(alpha_r) - np.cos(beta_r) * np.cos(gamma_r)) / np.sin(gamma_r)
    cz = np.sqrt(max(0, c**2 - cx**2 - cy**2))
    vc = np.array([cx, cy, cz])
    return va, vb, vc

def set_equal_aspect(ax, points):
    """
    概要:
        Matplotlib 3Dプロットのアスペクト比を全ての軸で等しく設定します。

    引数:
        :param ax: 3DプロットのAxesオブジェクト。
        :type ax: matplotlib.axes.Axes
        :param points: 表示範囲を決定するための点群。
        :type points: numpy.ndarray
    戻り値:
        :returns: None
        :rtype: None
    """
    xlim = [np.min(points[:,0]), np.max(points[:,0])]
    ylim = [np.min(points[:,1]), np.max(points[:,1])]
    zlim = [np.min(points[:,2]), np.max(points[:,2])]
    max_range = max(
        xlim[1] - xlim[0],
        ylim[1] - ylim[0],
        zlim[1] - zlim[0]
    ) / 2

    mid_x = np.mean(xlim)
    mid_y = np.mean(ylim)
    mid_z = np.mean(zlim)

    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)
    ax.set_box_aspect([1,1,1])

def draw_cell(ax, va, vb, vc, origin, linewidth=0.5):
    """
    概要:
        指定された基本格子ベクトルと原点に基づいて単一の単位格子を描画します。

    引数:
        :param ax: 3DプロットのAxesオブジェクト。
        :type ax: matplotlib.axes.Axes
        :param va: 第1基本格子ベクトル。
        :type va: numpy.ndarray
        :param vb: 第2基本格子ベクトル。
        :type vb: numpy.ndarray
        :param vc: 第3基本格子ベクトル。
        :type vc: numpy.ndarray
        :param origin: 描画開始点。
        :type origin: numpy.ndarray
        :param linewidth: 線の太さ。
        :type linewidth: float
    戻り値:
        :returns: None
        :rtype: None
    """
    p0 = origin
    p1 = origin + va
    p2 = origin + vb
    p3 = origin + vc
    p4 = origin + va + vb
    p5 = origin + va + vc
    p6 = origin + vb + vc
    p7 = origin + va + vb + vc
    verts = [p0, p1, p2, p3, p4, p5, p6, p7]

    edges = [
        (0,1), (0,2), (0,3),
        (1,4), (1,5),
        (2,4), (2,6),
        (3,5), (3,6),
        (4,7), (5,7), (6,7)
    ]
    for i,j in edges:
        ax.plot(*zip(verts[i], verts[j]), color='black', linewidth=linewidth)

def draw_supercell_plot(a, b, c, alpha, beta, gamma, nx=3, ny=3, nz=3, 
                        elev=ELEV_DEF, azim=AZIM_DEF, 
                        draw_lattice_vectors=DRAW_LATTICE_VECTORS_DEF, 
                        font_size=AXIS_FONT_SIZE_DEF):
    """
    概要:
        指定された格子定数と角度に基づき、超格子を描画し、画像として保存して表示します。

    引数:
        :param a: 格子定数aの長さ。
        :type a: float
        :param b: 格子定数bの長さ。
        :type b: float
        :param c: 格子定数cの長さ。
        :type c: float
        :param alpha: 角度α（度）。
        :type alpha: float
        :param beta: 角度β（度）。
        :type beta: float
        :param gamma: 角度γ（度）。
        :type gamma: float
        :param nx: X方向の単位格子の繰り返し数。
        :type nx: int
        :param ny: Y方向の単位格子の繰り返し数。
        :type ny: int
        :param nz: Z方向の単位格子の繰り返し数。
        :type nz: int
        :param elev: 3Dプロットの視点仰角。
        :type elev: float
        :param azim: 3Dプロットの視点方位角。
        :type azim: float
        :param draw_lattice_vectors: 基本格子ベクトルを描画するかどうかのフラグ。
        :type draw_lattice_vectors: bool
        :param font_size: 軸ラベルのフォントサイズ。
        :type font_size: int
    戻り値:
        :returns: None
        :rtype: None
    """
    va, vb, vc = lattice_vectors(a, b, c, alpha, beta, gamma)
    fig = plt.figure(figsize=(8,6))
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=elev, azim=azim)

    # 角度確認用のイベントハンドラ
    def on_draw(event):
        curr_elev = ax.elev
        curr_azim = ax.azim
        print(f"Current elev: {curr_elev:.2f}, azim: {curr_azim:.2f}")
    fig.canvas.mpl_connect('draw_event', on_draw)

    # 超格子の全セルの描画（細い黒線）
    for i in range(nx):
        for j in range(ny):
            for k in range(nz):
                offset = i * va + j * vb + k * vc
                draw_cell(ax, va, vb, vc, offset, linewidth=0.5)

    # 原点にあるセルの強調（太い黒線）
    draw_cell(ax, va, vb, vc, np.array([0,0,0]), linewidth=2)

    # 基本格子ベクトルの描画（オプション）
    if draw_lattice_vectors:
        for vec, label in zip([va, vb, vc], ['a', 'b', 'c']):
            ax.quiver(0, 0, 0, *vec, color='blue', linewidth=2, arrow_length_ratio=0.1)
            ax.text(*vec * 1.05, label, fontsize=font_size, color='blue')

    # 不要な座標軸やグリッドを非表示にする
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    ax.set_xticklabels([]); ax.set_yticklabels([]); ax.set_zticklabels([])
    ax.set_xlabel(''); ax.set_ylabel(''); ax.set_zlabel('')
    ax.grid(False)
    ax.xaxis.pane.set_visible(False)
    ax.yaxis.pane.set_visible(False)
    ax.zaxis.pane.set_visible(False)
    ax.xaxis.line.set_color((0,0,0,0))
    ax.yaxis.line.set_color((0,0,0,0))
    ax.zaxis.line.set_color((0,0,0,0))
    ax.set_facecolor((1,1,1,0))

    # 全格子点の計算と描画（小さな●）
    all_points = []
    for i in range(nx+1):
        for j in range(ny+1):
            for k in range(nz+1):
                all_points.append(i*va + j*vb + k*vc)
    all_points = np.array(all_points)
    ax.scatter(all_points[:,0], all_points[:,1], all_points[:,2], color='black', s=20)

    # アスペクト比の調整
    set_equal_aspect(ax, all_points)

    plt.tight_layout()
    plt.savefig("unit_cells.png", dpi=300, bbox_inches='tight', transparent=True)
    plt.show()

def main():
    """
    概要:
        メインの実行フロー。特定のパラメータで超格子を描画します。

    戻り値:
        :returns: None
        :rtype: None
    """
    params = {
        'a': 5.0, 'b': 5.5, 'c': 4.5,
        'alpha': 80, 'beta': 70, 'gamma': 100,
        'nx': 3, 'ny': 3, 'nz': 3
    }
    
    print(f"Drawing supercell: {params['nx']}x{params['ny']}x{params['nz']}")
    draw_supercell_plot(**params)

if __name__ == "__main__":
    main()