"""
六方格子（Hexagonal）と菱面体格子（Rhombohedral）の幾何学的関係を可視化するモジュール。

概要:
    このスクリプトは、六方格子の基本ベクトル a1, a2, a3 を定義し、
    変換行列 T を用いて菱面体格子の基本ベクトル r1, r2, r3 へ変換します。
    変換後の格子ベクトルおよび単位格子を 3D 空間上に描画し、両者の関係を視覚化します。

詳細説明:
    このスクリプトは以下の手順で処理を実行します。
    1. 六方格子の軸長 aH, cH に基づき、直交座標系での格子ベクトルを生成します。
    2. 六方→菱面体の変換行列 T を適用し、菱面体格子の基底を計算します。
    3. 3D 矢印（Arrow3D クラス）を用いて格子ベクトルを描画します。
    4. 平行六面体（単位格子）を描画し、空間的な重なりを示します。

変換式:
    菱面体格子ベクトルを Rrhombo、六方格子ベクトルを Ahex とすると、
    Rrhombo = T @ Ahex で表されます。
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D, proj3d

#===================
# グローバル設定
#===================
# フォント設定（日本語化対策）
rcParams['font.family'] = 'MS Gothic'

class Arrow3D(FancyArrowPatch):
    """
    概要:
        3D空間に矢印を描画するためのmatplotlib拡張クラスです。

    詳細説明:
        matplotlib.patches.FancyArrowPatch を継承し、
        3D座標を受け取って2D投影された矢印として描画します。
        主に draw メソッドと do_3d_projection メソッドが3D描画のためにオーバーライドされています。

    引数:
        :param xs: 矢印のX座標のリストまたはタプル (開始点X, 終了点X)。
        :type xs: list or tuple
        :param ys: 矢印のY座標のリストまたはタプル (開始点Y, 終了点Y)。
        :type ys: list or tuple
        :param zs: 矢印のZ座標のリストまたはタプル (開始点Z, 終了点Z)。
        :type zs: list or tuple
        :param args: FancyArrowPatch に渡される追加の位置引数。
        :param kwargs: FancyArrowPatch に渡される追加のキーワード引数。
    """
    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        proj = self.axes.get_proj()
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, proj)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        xs3d, ys3d, zs3d = self._verts3d
        proj = self.axes.get_proj()
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, proj)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return np.min(zs)

def draw_unit_cell(ax, origin, v1, v2, v3, color='black', lw=0.5):
    """
    概要:
        指定された3つのベクトルに基づいて平行六面体（単位格子）を描画します。

    詳細説明:
        原点 origin から始まり、3つの基底ベクトル v1, v2, v3 で定義される
        平行六面体の全ての辺を描画します。これにより単位格子が可視化されます。

    引数:
        :param ax: 描画対象の3D Axesオブジェクト。
        :type ax: mpl_toolkits.mplot3d.axes3d.Axes3D
        :param origin: 格子の原点座標。
        :type origin: numpy.ndarray
        :param v1: 単位格子を構成する第一のベクトル。
        :type v1: numpy.ndarray
        :param v2: 単位格子を構成する第二のベクトル。
        :type v2: numpy.ndarray
        :param v3: 単位格子を構成する第三のベクトル。
        :type v3: numpy.ndarray
        :param color: 描画する線の色。デフォルトは 'black'。
        :type color: str
        :param lw: 描画する線の太さ。デフォルトは 0.5。
        :type lw: float
    """
    O = origin
    A = O + v1
    B = O + v2
    C = O + v3
    D = O + v1 + v2
    E = O + v2 + v3
    F = O + v3 + v1
    G = O + v1 + v2 + v3

    edges = [
        (O, A), (O, B), (O, C),
        (A, D), (A, F),
        (B, D), (B, E),
        (C, E), (C, F),
        (D, G), (E, G), (F, G)
    ]

    for start, end in edges:
        ax.plot([start[0], end[0]],
                [start[1], end[1]],
                [start[2], end[2]],
                color=color, lw=lw)

def draw_vector(ax, vec, color, label, fontsize):
    """
    概要:
        3D空間にベクトル（矢印）とそのラベルを描画します。

    詳細説明:
        指定されたベクトルを原点 (0,0,0) から開始する矢印として Axes オブジェクト ax に追加します。
        矢印の終点付近に、指定されたラベルを配置します。

    引数:
        :param ax: 描画対象の3D Axesオブジェクト。
        :type ax: mpl_toolkits.mplot3d.axes3d.Axes3D
        :param vec: 描画するベクトルの終点座標。原点 (0,0,0) から開始します。
        :type vec: numpy.ndarray
        :param color: 矢印とラベルの色。
        :type color: str
        :param label: ベクトルに表示するテキストラベル。
        :type label: str
        :param fontsize: ラベルのフォントサイズ。
        :type fontsize: int or float
    """
    arrow = Arrow3D([0, vec[0]], [0, vec[1]], [0, vec[2]],
                    mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
    ax.add_artist(arrow)
    ax.text(vec[0] * 1.05, vec[1] * 1.05, vec[2] * 1.05, label, fontsize=fontsize, color=color)

def main():
    """
    概要:
        格子変換の計算と描画を実行するメインルーチンです。

    詳細説明:
        六方格子の軸長を定義し、六方格子ベクトルと格子点を計算します。
        その後、六方→菱面体の変換行列を適用して菱面体格子ベクトルと格子点を導出します。
        最後に、matplotlib を使用してこれらの格子ベクトルと単位格子を3D空間に描画し、
        六方格子の外郭も表示して、両格子の関係性を視覚的に示します。
    """
    axis_label_font_size = 24

    # 六方格子の軸長
    a_H = 1.0
    c_H = 1.6

    # 六方格子ベクトル
    a1 = np.array([a_H, 0, 0])
    a2 = np.array([-a_H/2, a_H*np.sqrt(3)/2, 0])
    a3 = np.array([0, 0, c_H])

    # 六方格子点
    hex_points = []
    for i in range(2):
        for j in range(2):
            for k in range(2):
                pt = i*a1 + j*a2 + k*a3
                hex_points.append(pt)
    hex_points = np.array(hex_points)

    # 六方格子ベクトルを行ベクトルとして並べる
    A_hex = np.array([a1, a2, a3])

    # 六方→菱面体変換行列 T
    T = np.array([
        [ 2/3,  1/3,  1/3],
        [-1/3,  1/3,  1/3],
        [-1/3, -2/3,  1/3]
    ])

    # 菱面体格子ベクトルを計算（T @ A_hex）
    R_rhombo = T @ A_hex
    r1, r2, r3 = R_rhombo

    # 菱面体格子点
    rhombo_points = []
    for i in range(2):
        for j in range(2):
            for k in range(2):
                pt = i*r1 + j*r2 + k*r3
                rhombo_points.append(pt)
    rhombo_points = np.array(rhombo_points)

    # --- 描画開始 ---
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # 格子ベクトル描画
    draw_vector(ax, a1, 'blue', 'a(H)', axis_label_font_size)
    draw_vector(ax, a2, 'blue', 'b(H)', axis_label_font_size)
    draw_vector(ax, a3, 'blue', 'c(H)', axis_label_font_size)
    draw_vector(ax, r1, 'red',  'a(R)', axis_label_font_size)
    draw_vector(ax, r2, 'red',  'b(R)', axis_label_font_size)
    draw_vector(ax, r3, 'red',  'c(R)', axis_label_font_size)

    # 単位格子描画
    draw_unit_cell(ax, origin=np.array([0,0,0]), v1=a1, v2=a2, v3=a3, color='blue', lw=0.5)
    draw_unit_cell(ax, origin=np.array([0,0,0]), v1=r1, v2=r2, v3=r3, color='red', lw=0.5)

    # 格子点のプロット
    ax.scatter(hex_points[:,0], hex_points[:,1], hex_points[:,2], color='blue', alpha=0.6)
    ax.scatter(rhombo_points[:,0], rhombo_points[:,1], rhombo_points[:,2], color='red', alpha=0.6)

    # 六角柱描画（外郭）
    theta = np.linspace(0, 2*np.pi, 7)
    r = a_H
    z_vals = [0, c_H]
    for z in z_vals:
        x = r * np.cos(theta)
        y = r * np.sin(theta)
        ax.plot(x, y, z, color='gray', alpha=0.5)
    for i in range(6):
        x = [r*np.cos(theta[i]), r*np.cos(theta[i])]
        y = [r*np.sin(theta[i]), r*np.sin(theta[i])]
        z = [0, c_H]
        ax.plot(x, y, z, color='gray', alpha=0.5)

    # グラフ設定
    ax.set_axis_off()
    ax.grid(False)
    ax.set_box_aspect([1,1,1])
    ax.view_init(elev=20, azim=30)
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()