"""
ブラベー格子の単位格子と格子点を3Dで可視化するスクリプト。

このスクリプトは、指定された格子定数と結晶系に基づいて、
ブラベー格子の単位格子、格子点、格子ベクトル、およびオプションで補助線や基本格子を3Dで描画します。
Matplotlibを使用して3Dプロットを生成し、結果を画像ファイルとして保存し、表示します。

関連リンク: :doc:`draw_bravais_lattice_usage`
"""
import sys
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D, proj3d

# 始点の初期値
elev = 5.90
azim = -67.55

draw_lattice_vectors = True
draw_support_lines   = True
draw_primitive_cell  = False

# ブラべー格子
lattice_vector_color = 'blue'
axis_arrow_color = 'blue'
axis_label_font_size = 24
# 格子点
point_color = 'red'
point_size = 500

#cell = 'SC'
#cell = 'ST'
#cell = 'SO'
#cell = 'SR'
#cell = 'SH'
#cell = 'FC'
#cell = 'FO'
cell = 'BC'
#cell = 'BT'
#cell = 'BO'
#cell = 'CO'
#cell = 'SM'
#cell = 'STri'
#cell = 'CM'


class Arrow3D(FancyArrowPatch):
    """
    MatplotlibのFancyArrowPatchを継承し、3D空間に矢印を描画するためのクラス。

    このクラスは、3D座標の矢印を2Dのスクリーン座標に投影し、
    Matplotlibの3D軸上に適切に描画することを可能にします。
    """
    def __init__(self, xs, ys, zs, *args, **kwargs):
        """
        Arrow3Dオブジェクトを初期化します。

        :param xs: 矢印のX座標 (開始点, 終了点)。
        :type xs: list[float]
        :param ys: 矢印のY座標 (開始点, 終了点)。
        :type ys: list[float]
        :param zs: 矢印のZ座標 (開始点, 終了点)。
        :type zs: list[float]
        :param args: FancyArrowPatchに渡される追加の位置引数。
        :param kwargs: FancyArrowPatchに渡される追加のキーワード引数。
        """
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def draw(self, renderer):
        """
        矢印をレンダラーに描画します。

        3D座標を2Dスクリーン座標に変換し、変換された座標に基づいて矢印を描画します。

        :param renderer: 描画に使用されるレンダラーオブジェクト。
        :type renderer: matplotlib.backend_bases.RendererBase
        """
        xs3d, ys3d, zs3d = self._verts3d
        proj = self.axes.get_proj()
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, proj)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        """
        3D投影を計算し、描画順序のためのZ深度を返します。

        矢印の3D座標を2Dスクリーン座標に変換し、最小のZ深度を返して、
        Matplotlibの3Dレンダリングにおける描画順序を適切に処理します。

        :param renderer: 描画に使用されるレンダラーオブジェクト（オプション）。
        :type renderer: matplotlib.backend_bases.RendererBase or None
        :returns: 変換されたZ座標の最小値。
        :rtype: float
        """
        xs3d, ys3d, zs3d = self._verts3d
        proj = self.axes.get_proj()
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, proj)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return np.min(zs)


def draw_vector(ax, vec, color, label, fontsize):
    """
    指定された3Dベクトルを矢印としてMatplotlibの3D軸に描画します。

    矢印は原点(0,0,0)から指定されたベクトルまで描画され、
    ベクトルの先端にはラベルが表示されます。

    :param ax: 矢印を描画する3D軸オブジェクト。
    :type ax: mpl_toolkits.mplot3d.axes3d.Axes3D
    :param vec: 矢印の終点となる3Dベクトル。
    :type vec: numpy.ndarray
    :param color: 矢印とラベルの色。
    :type color: str
    :param label: 矢印のラベル文字列。
    :type label: str
    :param fontsize: ラベルのフォントサイズ。
    :type fontsize: int
    :returns: なし
    """
    arrow = Arrow3D([0, vec[0]], [0, vec[1]], [0, vec[2]],
                    mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
    ax.add_artist(arrow)
    ax.text(vec[0] * 1.05, vec[1] * 1.05, vec[2] * 1.05, label, fontsize=fontsize, color=color)


def lattice_vectors(a, b, c, alpha, beta, gamma):
    """
    単位格子の格子定数から3つの基本格子ベクトルを計算します。

    計算は、a軸をx軸に沿わせ、b軸をxy平面に置き、c軸をz軸正方向に向ける形で
    座標系を設定して行われます。

    :param a: 格子定数aの長さ。
    :type a: float
    :param b: 格子定数bの長さ。
    :type b: float
    :param c: 格子定数cの長さ。
    :type c: float
    :param alpha: 角度α（bとcの間の角度、度数）。
    :type alpha: float
    :param beta: 角度β（aとcの間の角度、度数）。
    :type beta: float
    :param gamma: 角度γ（aとbの間の角度、度数）。
    :type gamma: float
    :returns: 3つの基本格子ベクトル (va, vb, vc)。
    :rtype: tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
    """
    alpha_r, beta_r, gamma_r = np.radians([alpha, beta, gamma])
    va = np.array([a, 0, 0])
    vb = np.array([b * np.cos(gamma_r), b * np.sin(gamma_r), 0])
    cx = c * np.cos(beta_r)
    cy = c * (np.cos(alpha_r) - np.cos(beta_r) * np.cos(gamma_r)) / np.sin(gamma_r)
    cz = np.sqrt(c**2 - cx**2 - cy**2)
    vc = np.array([cx, cy, cz])
    return va, vb, vc


def set_equal_aspect(ax, points):
    """
    Matplotlibの3Dプロットの軸のアスペクト比を等しく設定します。

    描画される全ての点の範囲に基づいて、x, y, z軸の表示範囲を調整し、
    立体が歪まないように真のアスペクト比を維持します。

    :param ax: 軸のアスペクト比を設定する3D軸オブジェクト。
    :type ax: mpl_toolkits.mplot3d.axes3d.Axes3D
    :param points: 描画される全ての点のNumpy配列。
    :type points: numpy.ndarray
    :returns: なし
    """
    xlim = [np.min(points[:, 0]), np.max(points[:, 0])]
    ylim = [np.min(points[:, 1]), np.max(points[:, 1])]
    zlim = [np.min(points[:, 2]), np.max(points[:, 2])]
    max_range = max(
        xlim[1] - xlim[0],
        ylim[1] - ylim[0],
        zlim[1] - zlim[0]
    ) / 2

    mid_x = np.mean(xlim)
    mid_y = np.mean(ylim)
    mid_z = np.mean(zlim)

    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)
    ax.set_box_aspect([1, 1, 1])


def draw_unit_cell_with_lattice(
    a, b, c, alpha, beta, gamma,
    lattice_points=None,
    dashed_lines=None,
    basis_vectors=None
):
    """
    指定された格子定数と格子点を用いて、ブラベー単位格子を3Dで描画します。

    この関数は、単位格子の辺、頂点、格子ベクトル、格子点、オプションで補助線、
    および基本格子を描画し、結果を画像ファイルとして保存した後、画面に表示します。
    ユーザーがプロットを操作すると、現在の視点角度 (elev, azim) がコンソールに出力されます。

    :param a: 格子定数aの長さ。
    :type a: float
    :param b: 格子定数bの長さ。
    :type b: float
    :param c: 格子定数cの長さ。
    :type c: float
    :param alpha: 角度α（bとcの間の角度、度数）。
    :type alpha: float
    :param beta: 角度β（aとcの間の角度、度数）。
    :type beta: float
    :param gamma: 角度γ（aとbの間の角度、度数）。
    :type gamma: float
    :param lattice_points: 格子点を表す係数のリスト。各要素は `[u, v, w]` 形式で `u*va + v*vb + w*vc` と計算されます。
                           Noneの場合、格子点は描画されません。
    :type lattice_points: list[list[float]] or None
    :param dashed_lines: 補助線（破線）の始点と終点を表す係数のリスト。各要素は `([u1, v1, w1], [u2, v2, w2])` 形式です。
                         Noneの場合、補助線は描画されません。
    :type dashed_lines: list[tuple[tuple[float, float, float], tuple[float, float, float]]] or None
    :param basis_vectors: 基本格子ベクトルを表す係数のリスト。各要素は `[u, v, w]` 形式です。
                          Noneの場合、基本格子は描画されません。
    :type basis_vectors: list[list[float]] or None
    :returns: なし
    """
    va, vb, vc = lattice_vectors(a, b, c, alpha, beta, gamma)
    origin = np.array([0, 0, 0])
    corners = [
        origin,
        va,
        vb,
        vc,
        va + vb,
        va + vc,
        vb + vc,
        va + vb + vc
    ]
    corners = np.array(corners)

    edges = [
        (0, 1), (0, 2), (0, 3),
        (1, 4), (1, 5),
        (2, 4), (2, 6),
        (3, 5), (3, 6),
        (4, 7), (5, 7), (6, 7)
    ]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=elev, azim=azim)

    def on_draw(event):
        elev_now = ax.elev
        azim_now = ax.azim
        print(f"elev: {elev_now:.2f}, azim: {azim_now:.2f}")

    fig.canvas.mpl_connect('draw_event', on_draw)

    # 単位格子の辺
    for i, j in edges:
        ax.plot(*zip(corners[i], corners[j]), color='black', linewidth=0.5)

    # 単位格子の頂点
    ax.scatter(corners[:, 0], corners[:, 1], corners[:, 2], color='black', s=20)

    # 格子ベクトル
    if draw_lattice_vectors:
        for vec, label in zip([va, vb, vc], ['a', 'b', 'c']):
            draw_vector(ax, vec, lattice_vector_color, label, axis_label_font_size)

    # 格子点の描画
    real_points = []
    if lattice_points:
        real_points = [p[0] * va + p[1] * vb + p[2] * vc for p in lattice_points]
        real_points = np.array(real_points)
        ax.scatter(real_points[:, 0], real_points[:, 1], real_points[:, 2],
                   color=point_color, s=point_size)

    # 補助線（鎖線）
    if draw_support_lines and dashed_lines:
        for p1, p2 in dashed_lines:
            pt1 = p1[0] * va + p1[1] * vb + p1[2] * vc
            pt2 = p2[0] * va + p2[1] * vb + p2[2] * vc
            ax.plot(*zip(pt1, pt2), color='gray', linestyle='dashed', linewidth=0.8)

    # 基本格子
    if draw_primitive_cell and basis_vectors:
        for vec in basis_vectors:
            v = vec[0] * va + vec[1] * vb + vec[2] * vc
            ax.quiver(0, 0, 0, *v, color='green', linewidth=1.5, arrow_length_ratio=0.2)

        if len(basis_vectors) == 3:
            bv1 = basis_vectors[0][0] * va + basis_vectors[0][1] * vb + basis_vectors[0][2] * vc
            bv2 = basis_vectors[1][0] * va + basis_vectors[1][1] * vb + basis_vectors[1][2] * vc
            bv3 = basis_vectors[2][0] * va + basis_vectors[2][1] * vb + basis_vectors[2][2] * vc

            base_points = [
                origin,
                bv1,
                bv2,
                bv3,
                bv1 + bv2,
                bv1 + bv3,
                bv2 + bv3,
                bv1 + bv2 + bv3
            ]

            base_edges = [
                (0, 1), (0, 2), (0, 3),
                (1, 4), (1, 5),
                (2, 4), (2, 6),
                (3, 5), (3, 6),
                (4, 7), (5, 7), (6, 7)
            ]

            for i, j in base_edges:
                ax.plot(*zip(base_points[i], base_points[j]),
                        color='lightgreen', linewidth=1.2)

            for vec, label in zip([bv1, bv2, bv3], ["a'", "b'", "c'"]):
                ax.text(*vec, label, fontsize=12, color='green')

    # 軸や背景の調整
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    ax.set_xticklabels([]); ax.set_yticklabels([]); ax.set_zticklabels([])
    ax.set_xlabel(''); ax.set_ylabel(''); ax.set_zlabel('')
    ax.grid(False)
    ax.xaxis.pane.set_visible(False)
    ax.yaxis.pane.set_visible(False)
    ax.zaxis.pane.set_visible(False)
    ax.xaxis.line.set_color((0, 0, 0, 0))
    ax.yaxis.line.set_color((0, 0, 0, 0))
    ax.zaxis.line.set_color((0, 0, 0, 0))
    ax.set_facecolor((1, 1, 1, 0))

    all_points = np.vstack([corners] + ([real_points] if lattice_points else []))
    set_equal_aspect(ax, all_points)

    plt.tight_layout()
    plt.savefig("bravais_cell_with_basis_vectors.png",
                dpi=300, bbox_inches='tight', transparent=True)
    plt.show()


def main():
    """
    プログラムのエントリポイント。コマンドライン引数を解析し、ブラベー単位格子の描画を実行します。

    コマンドライン引数に基づいて、描画する結晶系と表示オプション（格子ベクトル、補助線、基本格子）を
    設定し、`draw_unit_cell_with_lattice`関数を呼び出して単位格子を描画します。
    引数が指定されない場合は、デフォルト値が使用されます。

    使用方法:
        python draw_bravais_lattice.py [cell_type] [draw_lattice_vectors] [draw_support_lines] [draw_primitive_cell]

        - cell_type (str): 描画する結晶系 ('SC', 'ST', 'SO', 'SR', 'SH', 'FC', 'FO', 'BC', 'BT', 'BO', 'CO', 'SM', 'STri', 'CM'など)。
        - draw_lattice_vectors (int): 格子ベクトルを描画するか (1: はい, 0: いいえ)。
        - draw_support_lines (int): 補助線を描画するか (1: はい, 0: いいえ)。
        - draw_primitive_cell (int): 基本格子を描画するか (1: はい, 0: いいえ)。

    :returns: なし
    """
    global cell, draw_lattice_vectors, draw_support_lines, draw_primitive_cell

    argv = sys.argv
    nargs = len(argv)
    if nargs > 1:
        cell = argv[1]
    if nargs > 2:
        draw_lattice_vectors = int(argv[2])
    if nargs > 3:
        draw_support_lines = int(argv[3])
    if nargs > 4:
        draw_primitive_cell = int(argv[4])

    print()
    print(f"{cell=}")
    print(f"{draw_lattice_vectors=}")
    print(f"{draw_support_lines=}")
    print(f"{draw_primitive_cell=}")

    if cell[0] == 'S':  # Simple (単純)
        if cell[1] == 'C': # Simple Cubic (単純立方)
            a, b, c, alpha, beta, gamma = 5.0, 5.0, 5.0, 90.0, 90.0, 90.0
        elif cell[1:] == 'T': # Simple Tetragonal (単純正方)
            a, b, c, alpha, beta, gamma = 5.0, 5.0, 7.0, 90.0, 90.0, 90.0
        elif cell[1] == 'O': # Simple Orthorhombic (単純斜方)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 90.0, 90.0, 90.0
        elif cell[1] == 'H': # Simple Hexagonal (単純六方)
            a, b, c, alpha, beta, gamma = 7.0, 7.0, 5.0, 90.0, 90.0, 120.0
        elif cell[1] == 'R': # Simple Rhombohedral (単純菱面体)
            a, b, c, alpha, beta, gamma = 5.0, 5.0, 5.0, 70.0, 70.0, 70.0
        elif cell[1] == 'M': # Simple Monoclinic (単純単斜)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 90.0, 110.0, 90.0
        elif cell[1:] == 'Tri': # Simple Triclinic (単純三斜)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 80.0, 110.0, 60.0

        lattice_points = [
            [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
            [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
        ]

        dashed_lines = []
        basis_vectors = []

    elif cell[0] == 'F': # Face-centered (面心)
        if cell[1] == 'C': # Face-centered Cubic (面心立方)
            a, b, c, alpha, beta, gamma = 5.0, 5.0, 5.0, 90.0, 90.0, 90.0
        elif cell[1] == 'O': # Face-centered Orthorhombic (面心斜方)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 90.0, 90.0, 90.0

        lattice_points = [
            [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
            [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
            [0.0, 0.5, 0.5],
            [0.5, 0.0, 0.5],
            [0.5, 0.5, 0.0],
            [1.0, 0.5, 0.5],
            [0.5, 1.0, 0.5],
            [0.5, 0.5, 1.0],
        ]

        dashed_lines = [
            [(0, 0, 0), (1, 1, 0)],
            [(0, 1, 0), (1, 0, 0)],
            [(0, 0, 1), (1, 1, 1)],
            [(0, 1, 1), (1, 0, 1)],
            [(0, 0, 0), (1, 0, 1)],
            [(0, 0, 1), (1, 0, 0)],
            [(0, 1, 0), (1, 1, 1)],
            [(0, 1, 1), (1, 1, 0)],
            [(0, 0, 0), (0, 1, 1)],
            [(0, 1, 0), (0, 0, 1)],
            [(1, 0, 0), (1, 1, 1)],
            [(1, 1, 0), (1, 0, 1)],
        ]

        basis_vectors = [
            [-0.5,  0.5,  0.5],
            [ 0.5, -0.5,  0.5],
            [ 0.5,  0.5, -0.5]
        ]

    elif cell[0] == 'B': # Body-centered (体心)
        if cell[1] == 'C': # Body-centered Cubic (体心立方)
            a, b, c, alpha, beta, gamma = 5.0, 5.0, 5.0, 90.0, 90.0, 90.0
        elif cell[1] == 'T': # Body-centered Tetragonal (体心正方)
            a, b, c, alpha, beta, gamma = 5.0, 5.0, 7.0, 90.0, 90.0, 90.0
        elif cell[1] == 'O': # Body-centered Orthorhombic (体心斜方)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 90.0, 90.0, 90.0

        lattice_points = [
            [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
            [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
            [0.5, 0.5, 0.5]
        ]

        dashed_lines = [
            [(0, 0, 0), (1, 1, 1)],
            [(1, 0, 0), (0, 1, 1)],
            [(0, 1, 0), (1, 0, 1)],
            [(0, 0, 1), (1, 1, 0)],
            [(0.5, 0.5, 0.5), (0, 0, 0)],
            [(0.5, 0.5, 0.5), (1, 1, 1)]
        ]

        basis_vectors = [
            [-0.5,  0.5,  0.5],
            [ 0.5, -0.5,  0.5],
            [ 0.5,  0.5, -0.5]
        ]

    elif cell[0] == 'C': # Base-centered (底心)
        if cell[1] == 'O': # Base-centered Orthorhombic (底心斜方)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 90.0, 90.0, 90.0
            lattice_points = [
                [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
                [0.5, 0.5, 0.0],
                [0.5, 0.5, 1.0]
            ]
            dashed_lines = [
                [(0, 0, 0), (1, 1, 0)],
                [(0, 1, 0), (1, 0, 0)],
                [(0, 0, 1), (1, 1, 1)],
                [(0, 1, 1), (1, 0, 1)],
            ]
            basis_vectors = [
                [-0.5,  0.5,  0.5],
                [ 0.5, -0.5,  0.5],
                [ 0.5,  0.5, -0.5]
            ]
        elif cell[1] == 'M': # Base-centered Monoclinic (底心単斜)
            a, b, c, alpha, beta, gamma = 5.0, 7.0, 6.0, 90.0, 110.0, 90.0
            lattice_points = [
                [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
                [0.5, 0.0, 0.5],
                [0.5, 1.0, 0.5]
            ]
            dashed_lines = [
                [(0, 0, 0), (1, 0, 1)],
                [(0, 0, 1), (1, 0, 0)],
                [(0, 1, 0), (1, 1, 1)],
                [(0, 1, 1), (1, 1, 0)],
            ]
            basis_vectors = [
                [-0.5,  0.5,  0.5],
                [ 0.5, -0.5,  0.5],
                [ 0.5,  0.5, -0.5]
            ]

    else:
        print(f"\nError: Invalid cell type [{cell}]\n")
        return

    draw_unit_cell_with_lattice(
        a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma,
        lattice_points=lattice_points,
        dashed_lines=dashed_lines,
        basis_vectors=basis_vectors
    )


if __name__ == "__main__":
    main()