import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D, proj3d


# Font size for the text labels placed next to each lattice vector.
axis_label_font_size = 24


# Hexagonal cell parameters: a (= b) and c.
a_H = 1.0
c_H = 1.6

# Hexagonal basis vectors in Cartesian coordinates:
# a1 along x, a2 in the xy-plane at 120 degrees to a1 (same length), a3 along z.
a1 = np.array([a_H, 0, 0])
a2 = np.array([-a_H / 2, a_H * np.sqrt(3) / 2, 0])
a3 = np.array([0, 0, c_H])

# The 8 corners of the hexagonal unit cell: i*a1 + j*a2 + k*a3 with
# i, j, k in {0, 1}, enumerated with k varying fastest (shape (8, 3)).
hex_points = np.array([
    i * a1 + j * a2 + k * a3
    for i in range(2)
    for j in range(2)
    for k in range(2)
])

# Hexagonal basis vectors stacked as the rows of a 3x3 matrix.
A_hex = np.vstack((a1, a2, a3))

# Hexagonal -> rhombohedral basis transformation (obverse setting).
# Row n holds the coefficients of rhombohedral vector n on (a_H, b_H, c_H):
#   a_R = ( 2 a_H +   b_H + c_H) / 3
#   b_R = (  -a_H +   b_H + c_H) / 3
#   c_R = (  -a_H - 2 b_H + c_H) / 3
T = np.array([
    [2/3, 1/3, 1/3],
    [-1/3, 1/3, 1/3],
    [-1/3, -2/3, 1/3],
])

# Rhombohedral basis vectors as rows: R = T @ A_hex, shape (3, 3).
R_rhombo = T @ A_hex
r1, r2, r3 = R_rhombo

# The 8 corners of the rhombohedral cell: i*r1 + j*r2 + k*r3 with
# i, j, k in {0, 1}, enumerated with k varying fastest (shape (8, 3)).
rhombo_points = np.array(
    [i*r1 + j*r2 + k*r3 for i in (0, 1) for j in (0, 1) for k in (0, 1)]
)


# Font family used for all text drawn in the figure.
# NOTE(review): 'MS Gothic' is a Japanese font that is typically only present
# on Windows, while the vector labels drawn in this script are plain ASCII —
# confirm whether this font is still required.
rcParams.update({'font.family': 'MS Gothic'})

# 3D arrow artist built on FancyArrowPatch
class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose tail and head are given in 3D data coordinates.

    The underlying 2D patch is created at a dummy (0, 0) -> (0, 0) position;
    its real 2D endpoints are recomputed from the stored 3D vertices with the
    owning axes' projection matrix every time it is projected or drawn.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        """Store the 3D endpoints.

        ``xs``, ``ys`` and ``zs`` each hold the tail (index 0) and head
        (index 1) coordinate along one axis. Remaining arguments are passed
        through to ``FancyArrowPatch``.
        """
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _sync_2d_positions(self):
        """Project the 3D endpoints, move the 2D patch there, return projected z."""
        view_matrix = self.axes.get_proj()
        px, py, pz = proj3d.proj_transform(*self._verts3d, view_matrix)
        self.set_positions((px[0], py[0]), (px[1], py[1]))
        return pz

    def draw(self, renderer):
        # Re-project here as well so the endpoints are current even if
        # do_3d_projection was not called first (presumably needed for older
        # Matplotlib releases — confirm).
        self._sync_2d_positions()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # Smallest projected depth of the two endpoints; presumably used by
        # Axes3D to order patches by depth.
        depths = self._sync_2d_positions()
        return np.min(depths)

# Draw a unit cell (parallelepiped) as a wireframe
def draw_unit_cell(ax, origin, v1, v2, v3, color='black', lw=0.5):
    """Draw the 12 edges of the parallelepiped spanned by ``v1``, ``v2``, ``v3``.

    Parameters
    ----------
    ax : 3D axes
        Target axes; each edge is drawn with one ``ax.plot(xs, ys, zs, ...)``
        call.
    origin : array_like, shape (3,)
        Corner of the cell from which the three edge vectors start.
    v1, v2, v3 : array_like, shape (3,)
        Edge vectors of the cell.
    color : color spec, optional
        Edge colour passed to ``ax.plot`` (default ``'black'``).
    lw : float, optional
        Edge line width passed to ``ax.plot`` (default ``0.5``).

    Notes
    -----
    All points are converted with ``np.asarray(..., dtype=float)`` so plain
    Python lists work as well; without the conversion ``origin + v1`` on two
    lists would concatenate them instead of adding the vectors.
    """
    O, v1, v2, v3 = (np.asarray(p, dtype=float) for p in (origin, v1, v2, v3))

    # Corners: O, one edge vector (A, B, C), two edge vectors (D, E, F), all three (G).
    A = O + v1
    B = O + v2
    C = O + v3
    D = O + v1 + v2
    E = O + v2 + v3
    F = O + v3 + v1
    G = O + v1 + v2 + v3

    # Each edge joins two corners that differ by exactly one edge vector.
    edges = [
        (O, A), (O, B), (O, C),
        (A, D), (A, F),
        (B, D), (B, E),
        (C, E), (C, F),
        (D, G), (E, G), (F, G)
    ]

    for start, end in edges:
        ax.plot([start[0], end[0]],
                [start[1], end[1]],
                [start[2], end[2]],
                color=color, lw=lw)

# Draw a vector from the origin as a labelled 3D arrow
def draw_vector(vec, color, label, fontsize, ax=None):
    """Draw an arrow from the origin to ``vec`` and label it near its tip.

    Parameters
    ----------
    vec : array_like, shape (3,)
        Vector to draw, in data coordinates, starting at the origin.
    color : color spec
        Colour of both the arrow and its label.
    label : str
        Text placed at ``1.05 * vec``, just beyond the arrow head.
    fontsize : float
        Font size of the label.
    ax : 3D axes, optional
        Target axes. Defaults to the current pyplot axes (``plt.gca()``), so
        the function no longer relies on a module-level ``ax`` variable.
    """
    if ax is None:
        ax = plt.gca()
    tip = np.asarray(vec, dtype=float)
    arrow = Arrow3D([0, tip[0]], [0, tip[1]], [0, tip[2]],
                    mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
    ax.add_artist(arrow)
    # Offset the label slightly past the head so it does not overlap the arrow.
    label_pos = tip * 1.05
    ax.text(label_pos[0], label_pos[1], label_pos[2], label,
            fontsize=fontsize, color=color)


# Start drawing
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Lattice vectors from the origin: hexagonal (H) in blue, rhombohedral (R) in red
draw_vector(a1, 'blue', 'a(H)', axis_label_font_size)
draw_vector(a2, 'blue', 'b(H)', axis_label_font_size)
draw_vector(a3, 'blue', 'c(H)', axis_label_font_size)
draw_vector(r1, 'red',  'a(R)', axis_label_font_size)
draw_vector(r2, 'red',  'b(R)', axis_label_font_size)
draw_vector(r3, 'red',  'c(R)', axis_label_font_size)

# Unit cells spanned by each set of basis vectors
draw_unit_cell(ax, origin=np.array([0,0,0]), v1=a1, v2=a2, v3=a3, color='blue', lw=0.5)
draw_unit_cell(ax, origin=np.array([0,0,0]), v1=r1, v2=r2, v3=r3, color='red', lw=0.5)

# Lattice points at the corners of each unit cell
ax.scatter(hex_points[:,0], hex_points[:,1], hex_points[:,2], color='blue')
ax.scatter(rhombo_points[:,0], rhombo_points[:,1], rhombo_points[:,2], color='red')

# Hexagonal prism of circumradius a_H and height c_H, centred on the z axis.
# theta has 7 samples (0 .. 2*pi) so each hexagon outline closes on itself.
theta = np.linspace(0, 2*np.pi, 7)
r = a_H
z_vals = [0, c_H]
for z in z_vals:
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    ax.plot(x, y, z, color='gray')
# Vertical edges through the six distinct vertices
for i in range(6):
    x = [r*np.cos(theta[i]), r*np.cos(theta[i])]
    y = [r*np.sin(theta[i]), r*np.sin(theta[i])]
    z = [0, c_H]
    ax.plot(x, y, z, color='gray')

ax.set_axis_off()
ax.grid(False)
# Scale each edge of the 3D box by its axis' data range so one data unit has
# the same on-screen length along x, y and z. A fixed [1, 1, 1] box would
# stretch the shorter ranges and misrepresent c/a and the 120-degree angles.
limits = np.array([ax.get_xlim(), ax.get_ylim(), ax.get_zlim()])
ax.set_box_aspect(limits[:, 1] - limits[:, 0])
ax.view_init(elev=20, azim=30)
plt.tight_layout()
plt.show()
