import sys
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D, proj3d


# Initial camera viewpoint in degrees (elevation and azimuth). The figure
# prints the current values whenever it is redrawn, so a preferred view can be
# copied back here.
elev = 5.90
azim = -67.55

# Default drawing switches (overridable from the command line, see below).
draw_lattice_vectors = True
draw_support_lines   = True
draw_primitive_cell  = False


# Bravais lattice: colour of the lattice-vector arrows and size of their labels.
lattice_vector_color = 'blue'
axis_arrow_color = 'blue'   # not used by the active drawing code
axis_label_font_size = 24
# Lattice points.
point_color = 'red'
point_size = 500


# Lattice code to draw; the first command-line argument overrides it.
#   'SC'   simple cubic              'FC'  face-centred cubic
#   'ST'   simple tetragonal         'FO'  face-centred orthorhombic
#   'SO'   simple orthorhombic       'BC'  body-centred cubic
#   'SR'   rhombohedral              'BT'  body-centred tetragonal
#   'SH'   hexagonal                 'BO'  body-centred orthorhombic
#   'SM'   simple monoclinic         'CO'  base-centred orthorhombic
#   'STri' triclinic                 'CM'  base-centred monoclinic
cell = 'BC'


def _parse_flag(text):
    """Interpret a command-line drawing switch as a bool.

    Accepts integers ('0', '1', ...) as well as true/false, yes/no, on/off
    (case-insensitive).

    Raises:
        ValueError: if *text* is none of the accepted spellings.
    """
    lowered = text.strip().lower()
    if lowered in ('true', 'yes', 'on'):
        return True
    if lowered in ('false', 'no', 'off'):
        return False
    try:
        return bool(int(lowered))
    except ValueError:
        raise ValueError(
            f"expected 0/1/true/false for a drawing switch, got {text!r}"
        ) from None


# Usage: python <script> [CELL [LATTICE_VECTORS [SUPPORT_LINES [PRIMITIVE_CELL]]]]
# The switches are converted to bool so they have the same type as the defaults.
argv = sys.argv
nargs = len(argv)
if nargs > 1: cell = argv[1]
if nargs > 2: draw_lattice_vectors = _parse_flag(argv[2])
if nargs > 3: draw_support_lines = _parse_flag(argv[3])
if nargs > 4: draw_primitive_cell = _parse_flag(argv[4])

print()
print(f"{cell=}")
print(f"{draw_lattice_vectors=}")
print(f"{draw_support_lines=}")
print(f"{draw_primitive_cell=}")


# Vertices of the conventional unit cell in fractional coordinates; every
# lattice type below has lattice points at all of them.
_CELL_CORNERS = [
    [0,0,0], [1,0,0], [0,1,0], [0,0,1],
    [0,1,1], [1,0,1], [1,1,0], [1,1,1],
]

# Conventional-cell parameters (a, b, c, alpha, beta, gamma) for each lattice
# code. Lengths are arbitrary plotting units; angles are in degrees.
_CUBIC_PARAMS        = (5.0, 5.0, 5.0, 90.0, 90.0, 90.0)
_TETRAGONAL_PARAMS   = (5.0, 5.0, 7.0, 90.0, 90.0, 90.0)
_ORTHORHOMBIC_PARAMS = (5.0, 7.0, 6.0, 90.0, 90.0, 90.0)
_MONOCLINIC_PARAMS   = (5.0, 7.0, 6.0, 90.0, 110.0, 90.0)

_CELL_PARAMETERS = {
    'SC':   _CUBIC_PARAMS,
    'ST':   _TETRAGONAL_PARAMS,
    'SO':   _ORTHORHOMBIC_PARAMS,
    'SH':   (7.0, 7.0, 5.0, 90.0, 90.0, 120.0),
    'SR':   (5.0, 5.0, 5.0, 70.0, 70.0, 70.0),
    'SM':   _MONOCLINIC_PARAMS,
    'STri': (5.0, 7.0, 6.0, 80.0, 110.0, 60.0),
    'FC':   _CUBIC_PARAMS,
    'FO':   _ORTHORHOMBIC_PARAMS,
    'BC':   _CUBIC_PARAMS,
    'BT':   _TETRAGONAL_PARAMS,
    'BO':   _ORTHORHOMBIC_PARAMS,
    'CO':   _ORTHORHOMBIC_PARAMS,
    'CM':   _MONOCLINIC_PARAMS,
}

if cell not in _CELL_PARAMETERS:
    # Reject unknown codes up front so the variables below are always defined;
    # exit with a non-zero status so the failure is visible to callers.
    print(f"\nError: Invalid cell type [{cell}]\n")
    sys.exit(1)

a, b, c, alpha, beta, gamma = _CELL_PARAMETERS[cell]

# lattice_points / dashed_lines / basis_vectors are all in fractional
# coordinates of the conventional cell.
if cell[0] == 'S':
    # Primitive lattices: lattice points only at the cell corners.
    lattice_points = list(_CELL_CORNERS)
    dashed_lines = []
    basis_vectors = []
elif cell[0] == 'F':
    # Face-centred: one extra lattice point at the centre of every face.
    lattice_points = _CELL_CORNERS + [
        [0.0,0.5,0.5],
        [0.5,0.0,0.5],
        [0.5,0.5,0.0],
        [1.0,0.5,0.5],
        [0.5,1.0,0.5],
        [0.5,0.5,1.0],
    ]
    # Both diagonals of each of the six faces (they cross at the face centre).
    dashed_lines = [
        [(0,0,0), (1,1,0)],
        [(0,1,0), (1,0,0)],
        [(0,0,1), (1,1,1)],
        [(0,1,1), (1,0,1)],
        [(0,0,0), (1,0,1)],
        [(0,0,1), (1,0,0)],
        [(0,1,0), (1,1,1)],
        [(0,1,1), (1,1,0)],
        [(0,0,0), (0,1,1)],
        [(0,1,0), (0,0,1)],
        [(1,0,0), (1,1,1)],
        [(1,1,0), (1,0,1)],
    ]
    # Primitive vectors: from the origin to the three adjacent face centres
    # (cell volume 1/4 of the conventional cell).
    basis_vectors = [
        [0.0, 0.5, 0.5],
        [0.5, 0.0, 0.5],
        [0.5, 0.5, 0.0],
    ]
elif cell[0] == 'B':
    # Body-centred: one extra lattice point at the cell centre.
    lattice_points = _CELL_CORNERS + [
        [0.5,0.5,0.5]
    ]
    # The four body diagonals, which pass through the centre point.
    dashed_lines = [
        [(0,0,0), (1,1,1)],
        [(1,0,0), (0,1,1)],
        [(0,1,0), (1,0,1)],
        [(0,0,1), (1,1,0)],
        [(0.5,0.5,0.5), (0,0,0)],
        [(0.5,0.5,0.5), (1,1,1)]
    ]
    # Primitive vectors pointing to body centres (volume 1/2).
    basis_vectors = [
        [-0.5,  0.5,  0.5],
        [ 0.5, -0.5,  0.5],
        [ 0.5,  0.5, -0.5]
    ]
elif cell == 'CO':
    # Base-centred orthorhombic: extra points at the centres of the two
    # faces spanned by a and b.
    lattice_points = _CELL_CORNERS + [
        [0.5,0.5,0.0],
        [0.5,0.5,1.0]
    ]
    dashed_lines = [
        [(0,0,0), (1,1,0)],
        [(0,1,0), (1,0,0)],
        [(0,0,1), (1,1,1)],
        [(0,1,1), (1,0,1)],
    ]
    # a' = (a - b)/2, b' = (a + b)/2, c' = c  (volume 1/2).
    basis_vectors = [
        [0.5, -0.5, 0.0],
        [0.5,  0.5, 0.0],
        [0.0,  0.0, 1.0],
    ]
else:  # 'CM'
    # Base-centred monoclinic: extra points at the centres of the two faces
    # spanned by a and c.
    # NOTE(review): with b as the unique (non-90 degree) axis, the standard
    # C-centred setting puts the extra point at (1/2, 1/2, 0); centring on the
    # a-c face is usually reducible to a primitive cell. Confirm which
    # convention the figure should show.
    lattice_points = _CELL_CORNERS + [
        [0.5,0.0,0.5],
        [0.5,1.0,0.5]
    ]
    dashed_lines = [
        [(0,0,0), (1,0,1)],
        [(0,0,1), (1,0,0)],
        [(0,1,0), (1,1,1)],
        [(0,1,1), (1,1,0)],
    ]
    # a' = (a - c)/2, b' = b, c' = (a + c)/2  (volume 1/2).
    basis_vectors = [
        [0.5, 0.0, -0.5],
        [0.0, 1.0,  0.0],
        [0.5, 0.0,  0.5],
    ]


class Arrow3D(FancyArrowPatch):
    """2-D ``FancyArrowPatch`` whose two endpoints live in 3-D data space.

    The stored 3-D endpoints are re-projected with the axes' current
    projection each time the arrow is drawn, so it follows view rotation.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2-D positions; the real ones are set on every projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project_endpoints(self):
        """Project the 3-D endpoints, update the 2-D positions, return depths."""
        x3d, y3d, z3d = self._verts3d
        px, py, pz = proj3d.proj_transform(x3d, y3d, z3d, self.axes.get_proj())
        start, end = (px[0], py[0]), (px[1], py[1])
        self.set_positions(start, end)
        return pz

    def draw(self, renderer):
        self._project_endpoints()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # The smallest projected depth is returned; the 3-D axes use this
        # value when ordering artists.
        depths = self._project_endpoints()
        return np.min(depths)

def draw_vector(ax, vec, color, label, fontsize):
    """Draw an arrow from the origin to *vec* on *ax* and label its tip.

    The label is placed at 1.05 * vec, just past the arrow head, so it does
    not sit on top of the arrow.
    """
    tip_x, tip_y, tip_z = vec[0], vec[1], vec[2]
    ax.add_artist(
        Arrow3D([0, tip_x], [0, tip_y], [0, tip_z],
                mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
    )
    label_scale = 1.05
    ax.text(tip_x * label_scale, tip_y * label_scale, tip_z * label_scale,
            label, fontsize=fontsize, color=color)

def lattice_vectors(a, b, c, alpha, beta, gamma):
    """Return Cartesian lattice vectors for the given cell parameters.

    Orientation: ``va`` lies along +x, ``vb`` lies in the xy-plane, and
    ``vc`` has a non-negative z component.

    Parameters
    ----------
    a, b, c : float
        Cell edge lengths.
    alpha, beta, gamma : float
        Inter-axial angles in degrees: alpha = angle(b, c),
        beta = angle(a, c), gamma = angle(a, b).

    Returns
    -------
    va, vb, vc : numpy.ndarray
        The three lattice vectors as length-3 arrays.

    Raises
    ------
    ValueError
        If gamma is 0 or 180 degrees (a and b collinear), or if the three
        angles cannot form a real cell (the z component of c would be
        imaginary).
    """
    alpha_r, beta_r, gamma_r = np.radians([alpha, beta, gamma])
    sin_gamma = np.sin(gamma_r)
    if np.isclose(sin_gamma, 0.0):
        raise ValueError(f"gamma must not be 0 or 180 degrees, got {gamma}")

    va = np.array([a, 0, 0])
    vb = np.array([b * np.cos(gamma_r), b * sin_gamma, 0])
    cx = c * np.cos(beta_r)
    cy = c * (np.cos(alpha_r) - np.cos(beta_r) * np.cos(gamma_r)) / sin_gamma
    cz_sq = c**2 - cx**2 - cy**2
    # Round-off can push a valid but nearly flat cell slightly below zero;
    # a clearly negative value means the angle combination is impossible.
    if cz_sq < -1e-12 * c**2:
        raise ValueError(
            f"angles alpha={alpha}, beta={beta}, gamma={gamma} "
            "do not define a valid unit cell"
        )
    vc = np.array([cx, cy, np.sqrt(max(cz_sq, 0.0))])
    return va, vb, vc

def set_equal_aspect(ax, points):
    """Give *ax* a cubic view box that encloses every row of *points*.

    Every axis gets the same span (the largest extent of the points along
    x, y or z). Each span is centred on the midpoint of the points' extent
    along that axis, so lengths appear equal in all directions.
    """
    xyz = points[:, :3]
    lows = np.min(xyz, axis=0)
    highs = np.max(xyz, axis=0)
    half_span = max(highs - lows) / 2
    centres = (lows + highs) / 2

    limit_setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for set_limits, centre in zip(limit_setters, centres):
        set_limits(centre - half_span, centre + half_span)
    ax.set_box_aspect([1, 1, 1])

def _frac_to_cart(frac, va, vb, vc):
    """Convert fractional coordinates along (va, vb, vc) into a Cartesian vector."""
    return frac[0] * va + frac[1] * vb + frac[2] * vc


def _parallelepiped_vertices(v1, v2, v3):
    """Return the 8 vertices of the parallelepiped spanned by v1, v2, v3 at the origin.

    The vertex order matches the index pairs in _PARALLELEPIPED_EDGES.
    """
    origin = np.zeros(3)
    return np.array([
        origin, v1, v2, v3,
        v1 + v2, v1 + v3, v2 + v3,
        v1 + v2 + v3,
    ])


# Index pairs into _parallelepiped_vertices' output that form the 12 edges.
_PARALLELEPIPED_EDGES = [
    (0, 1), (0, 2), (0, 3),
    (1, 4), (1, 5),
    (2, 4), (2, 6),
    (3, 5), (3, 6),
    (4, 7), (5, 7), (6, 7),
]


def _has_items(seq):
    """Return True if *seq* is not None and has at least one element.

    Unlike plain truthiness this also works for NumPy arrays.
    """
    return seq is not None and len(seq) > 0


def draw_unit_cell_with_lattice(
    a, b, c, alpha, beta, gamma,
    lattice_points=None,
    dashed_lines=None,
    basis_vectors=None
):
    """Plot a conventional unit cell with its lattice points, save it and show it.

    Parameters
    ----------
    a, b, c, alpha, beta, gamma : float
        Cell parameters passed to ``lattice_vectors`` (angles in degrees).
    lattice_points : sequence of 3-sequences, optional
        Lattice points in fractional coordinates, drawn as large markers.
    dashed_lines : sequence of (start, end) pairs, optional
        Guide lines in fractional coordinates, drawn dashed when the
        module-level ``draw_support_lines`` flag is set.
    basis_vectors : sequence of 3-sequences, optional
        Primitive vectors in fractional coordinates. When
        ``draw_primitive_cell`` is set they are drawn as green arrows. When
        there are exactly three, the primitive cell they span is drawn too.

    Besides its arguments, the function reads the module-level settings
    ``elev``, ``azim``, ``draw_lattice_vectors``, ``draw_support_lines``,
    ``draw_primitive_cell``, ``lattice_vector_color``,
    ``axis_label_font_size``, ``point_color`` and ``point_size``.

    Side effects: writes ``bravais_cell_with_basis_vectors.png`` (300 dpi,
    transparent background) relative to the current working directory, then
    opens the figure with ``plt.show()``.
    """
    va, vb, vc = lattice_vectors(a, b, c, alpha, beta, gamma)
    corners = _parallelepiped_vertices(va, vb, vc)

    fig = plt.figure(figsize=(8,6))
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=elev, azim=azim)

    def on_draw(event):
        # Report the current view so a good angle can be copied into elev/azim.
        print(f"elev: {ax.elev:.2f}, azim: {ax.azim:.2f}")
    fig.canvas.mpl_connect('draw_event', on_draw)

    # Everything drawn is collected here, so that the final view box encloses it.
    extent_points = [corners]

    # Edges of the conventional unit cell.
    for i, j in _PARALLELEPIPED_EDGES:
        ax.plot(*zip(corners[i], corners[j]), color='black', linewidth=0.5)

    # Vertices of the conventional unit cell.
    ax.scatter(corners[:,0], corners[:,1], corners[:,2], color='black', s=20)

    # Lattice vectors a, b, c.
    if draw_lattice_vectors:
        for vec, label in zip([va, vb, vc], ['a', 'b', 'c']):
            draw_vector(ax, vec, lattice_vector_color, label, axis_label_font_size)

    # Lattice points.
    if _has_items(lattice_points):
        real_points = np.array([_frac_to_cart(p, va, vb, vc) for p in lattice_points])
        ax.scatter(real_points[:,0], real_points[:,1], real_points[:,2], color=point_color, s=point_size)
        extent_points.append(real_points)

    # Dashed guide lines.
    if draw_support_lines and _has_items(dashed_lines):
        for p1, p2 in dashed_lines:
            pt1 = _frac_to_cart(p1, va, vb, vc)
            pt2 = _frac_to_cart(p2, va, vb, vc)
            ax.plot(*zip(pt1, pt2), color='gray', linestyle='dashed', linewidth=0.8)

    # Primitive vectors and the primitive cell they span.
    if draw_primitive_cell and _has_items(basis_vectors):
        primitive = [_frac_to_cart(vec, va, vb, vc) for vec in basis_vectors]
        for v in primitive:
            ax.quiver(0, 0, 0, *v, color='green', linewidth=1.5, arrow_length_ratio=0.2)
        # The primitive cell may stick out of the conventional cell (e.g. the
        # BCC vector (-1/2, 1/2, 1/2)), so include it when fitting the view.
        extent_points.append(np.array(primitive))

        if len(primitive) == 3:
            bv1, bv2, bv3 = primitive
            base_points = _parallelepiped_vertices(bv1, bv2, bv3)
            for i, j in _PARALLELEPIPED_EDGES:
                ax.plot(*zip(base_points[i], base_points[j]), color='lightgreen', linewidth=1.2)
            extent_points.append(base_points)

            # Labels of the primitive vectors.
            for vec, label in zip(primitive, ["a'", "b'", "c'"]):
                ax.text(*vec, label, fontsize=12, color='green')

    # Hide ticks, labels, grid, panes and axis lines; transparent background.
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    ax.set_xticklabels([]); ax.set_yticklabels([]); ax.set_zticklabels([])
    ax.set_xlabel(''); ax.set_ylabel(''); ax.set_zlabel('')
    ax.grid(False)
    ax.xaxis.pane.set_visible(False)
    ax.yaxis.pane.set_visible(False)
    ax.zaxis.pane.set_visible(False)
    ax.xaxis.line.set_color((0,0,0,0))
    ax.yaxis.line.set_color((0,0,0,0))
    ax.zaxis.line.set_color((0,0,0,0))
    ax.set_facecolor((1,1,1,0))

    set_equal_aspect(ax, np.vstack(extent_points))

    plt.tight_layout()
    plt.savefig("bravais_cell_with_basis_vectors.png", dpi=300, bbox_inches='tight', transparent=True)
    plt.show()

# Render, save and display the lattice selected above.
draw_unit_cell_with_lattice(
    a, b, c, alpha, beta, gamma,
    lattice_points, dashed_lines, basis_vectors,
)
