import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D, proj3d


# Initial viewpoint of the 3D camera, passed to ax.view_init() in
# draw_unit_cell; matplotlib interprets both angles in degrees.
elev, azim = 5.90, -67.55

# Styling shared by the lattice-vector arrows and their a/b/c labels.
axis_label_font_size = 24
lattice_vector_color = 'blue'


class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose two end points are given in 3D data coordinates.

    The patch is created with placeholder 2D positions; the stored 3D end
    points are projected with the owning axes' projection matrix and the 2D
    positions are refreshed every time the arrow is projected or drawn.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # (0, 0) -> (0, 0) is a placeholder; real positions come from
        # _update_2d_positions() once the arrow belongs to an axes.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _update_2d_positions(self):
        """Project the 3D end points, move the patch there, return projected z."""
        x_pts, y_pts, z_pts = self._verts3d
        projection = self.axes.get_proj()
        px, py, pz = proj3d.proj_transform(x_pts, y_pts, z_pts, projection)
        start, end = (px[0], py[0]), (px[1], py[1])
        self.set_positions(start, end)
        return pz

    def draw(self, renderer):
        self._update_2d_positions()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # The returned depth is what matplotlib's 3D axes uses to order artists.
        return np.min(self._update_2d_positions())

def draw_vector(ax, vec, color, label, fontsize):
    """Draw *vec* as an arrow starting at the origin and label its tip.

    ax       -- target axes; expected to be a 3D axes (ax.text gets x, y, z)
    vec      -- indexable with at least 3 components; the arrow tip
    color    -- color used for both the arrow and the label
    label    -- text placed just beyond the tip
    fontsize -- font size of the label
    """
    tip_x, tip_y, tip_z = vec[0], vec[1], vec[2]
    ax.add_artist(
        Arrow3D(
            [0, tip_x], [0, tip_y], [0, tip_z],
            mutation_scale=20, lw=2, arrowstyle="-|>", color=color,
        )
    )
    # Put the label 5 % past the tip so it does not sit on the arrowhead.
    label_scale = 1.05
    ax.text(tip_x * label_scale, tip_y * label_scale, tip_z * label_scale,
            label, fontsize=fontsize, color=color)

def lattice_vectors(a, b, c, alpha, beta, gamma):
    """Return Cartesian lattice vectors (va, vb, vc) for the given cell parameters.

    Convention: va lies along x and vb lies in the xy plane:
        va = (a, 0, 0)
        vb = (b cos(gamma), b sin(gamma), 0)
        vc = (c cos(beta), c (cos(alpha) - cos(beta) cos(gamma)) / sin(gamma), cz)
    where cz >= 0 is chosen so that |vc| = c.

    Parameters
    ----------
    a, b, c : float
        Cell edge lengths.
    alpha, beta, gamma : float
        Cell angles in degrees: alpha = angle(vb, vc), beta = angle(va, vc),
        gamma = angle(va, vb).

    Returns
    -------
    tuple of numpy.ndarray
        (va, vb, vc), each of shape (3,).

    Raises
    ------
    ValueError
        If gamma is (numerically) 0 or 180 degrees, so vb is collinear with
        va, or if the three angles cannot form a real parallelepiped
        (cz**2 clearly negative). Previously these cases silently produced
        inf/NaN components.
    """
    alpha_r, beta_r, gamma_r = np.radians([alpha, beta, gamma])
    sin_gamma = np.sin(gamma_r)
    if abs(sin_gamma) < 1e-12:
        raise ValueError(f"gamma={gamma} deg makes b collinear with a")

    va = np.array([a, 0, 0])
    vb = np.array([b * np.cos(gamma_r), b * sin_gamma, 0])
    cx = c * np.cos(beta_r)
    cy = c * (np.cos(alpha_r) - np.cos(beta_r) * np.cos(gamma_r)) / sin_gamma

    cz_squared = c**2 - cx**2 - cy**2
    # Round-off can push a (nearly) flat cell slightly below zero; only
    # reject angle combinations that are clearly impossible.
    if cz_squared < -1e-12 * c**2:
        raise ValueError(
            f"angles alpha={alpha}, beta={beta}, gamma={gamma} deg "
            "do not define a valid unit cell"
        )
    cz = np.sqrt(max(cz_squared, 0.0))
    vc = np.array([cx, cy, cz])
    return va, vb, vc

def set_equal_aspect(ax, points):
    """Give a 3D axes a cubic data box that encloses *points*.

    Every axis receives the same span, equal to the largest of the x/y/z
    extents of *points*, centered on the midpoint of that axis' own extent.
    The box aspect is then forced to 1:1:1 so equal data lengths appear
    equal on screen.

    ax     -- axes providing set_xlim/set_ylim/set_zlim/set_box_aspect
    points -- 2D array whose columns 0, 1, 2 hold x, y, z
    """
    lower = points.min(axis=0)
    upper = points.max(axis=0)
    half_span = max(upper - lower) / 2
    centers = (lower + upper) / 2

    limit_setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for set_limits, center in zip(limit_setters, centers):
        set_limits(center - half_span, center + half_span)
    ax.set_box_aspect([1,1,1])

def _unit_cell_geometry(va, vb, vc):
    """Return the 8 corner points and 12 edges of the cell spanned by va, vb, vc.

    Corner indices: 0 = origin, 1 = a, 2 = b, 3 = c, 4 = a+b, 5 = a+c,
    6 = b+c, 7 = a+b+c. Each edge is a pair of corner indices.
    """
    origin = np.array([0, 0, 0])
    points = np.array([
        origin,
        va,
        vb,
        vc,
        va + vb,
        va + vc,
        vb + vc,
        va + vb + vc,
    ])
    edges = [
        (0, 1), (0, 2), (0, 3),
        (1, 4), (1, 5),
        (2, 4), (2, 6),
        (3, 5), (3, 6),
        (4, 7), (5, 7), (6, 7),
    ]
    return points, edges


def _hide_axes_decorations(ax):
    """Remove ticks, axis labels, grid, panes, axis lines and background from a 3D axes."""
    for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        set_ticks([])  # with no ticks there are no tick labels either
    for set_label in (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel):
        set_label('')
    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.set_visible(False)       # no filled pane behind the data
        axis.line.set_color((0, 0, 0, 0))  # fully transparent axis line
    ax.set_facecolor((1, 1, 1, 0))         # transparent background


def draw_unit_cell(a, b, c, alpha, beta, gamma, *,
                   view_elev=None, view_azim=None,
                   filename="unit_cell.png", show=True):
    """Plot the unit cell for the given lattice parameters, save it and show it.

    The cell edges are drawn as thin black lines and the corners as small
    black dots. The three lattice vectors are drawn as labelled arrows
    (a, b, c). All axis decorations are hidden and the figure is saved with a
    transparent background.

    Parameters
    ----------
    a, b, c : float
        Cell edge lengths.
    alpha, beta, gamma : float
        Cell angles in degrees (see lattice_vectors).
    view_elev, view_azim : float, optional
        Initial camera angles in degrees; default to the module-level
        ``elev`` / ``azim`` values.
    filename : str or path-like, optional
        Where the PNG is written (default "unit_cell.png").
    show : bool, optional
        Whether to call plt.show() after saving (default True).

    Returns
    -------
    (matplotlib.figure.Figure, Axes3D)
        The figure and the 3D axes that were drawn.
    """
    va, vb, vc = lattice_vectors(a, b, c, alpha, beta, gamma)
    points, edges = _unit_cell_geometry(va, vb, vc)

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=elev if view_elev is None else view_elev,
                 azim=azim if view_azim is None else view_azim)

    def on_draw(event):
        # Print the current viewpoint on every redraw (e.g. after rotating
        # the view with the mouse).
        print(f"elev: {ax.elev:.2f}, azim: {ax.azim:.2f}")
    fig.canvas.mpl_connect('draw_event', on_draw)

    # Cell edges (thin black lines).
    for i, j in edges:
        ax.plot(*zip(points[i], points[j]), color='black', linewidth=0.5)

    # Cell corners (small black dots).
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], color='black', s=20)

    # Lattice vectors (thick arrows) with their labels.
    for vec, label in zip([va, vb, vc], ['a', 'b', 'c']):
        draw_vector(ax, vec, lattice_vector_color, label, axis_label_font_size)

    _hide_axes_decorations(ax)
    set_equal_aspect(ax, points)

    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches='tight', transparent=True)
    if show:
        plt.show()
    return fig, ax



if __name__ == "__main__":
    # Example triclinic cell; angles are in degrees. Guarded so that
    # importing this module does not open a window or write unit_cell.png.
    draw_unit_cell(a=5.0, b=5.5, c=4.5, alpha=80, beta=70, gamma=100)
