import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.patheffects as patheffects  # 修正1: ここを追加
import argparse

def draw_custom_colorbar(fig, ax, colors, max_zone):
    """Draw a hand-made legend (colorbar) to the right of the main Axes.

    A new Axes is appended on the right of ``ax``.  It receives one unit-high
    cell per Brillouin zone, filled with that zone's colour, and the zone
    number is written at the centre of each cell.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Owning figure; not used by this implementation.
    ax : matplotlib.axes.Axes
        Axes whose right-hand side receives the legend.
    colors : sequence
        Colours indexed by zone number; ``colors[1]`` .. ``colors[max_zone]``
        are used (index 0 is not read here).
    max_zone : int
        Number of zone cells to draw, stacked bottom (zone 1) to top.
    """
    legend_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, max_zone)

    # No ticks or frame: the coloured cells themselves form the legend.
    legend_ax.set_xticks([])
    legend_ax.set_yticks([])
    for edge in legend_ax.spines.values():
        edge.set_visible(False)

    for zone in range(1, max_zone + 1):
        bottom = zone - 1  # cell for zone n occupies y in [n-1, n]
        cell = patches.Rectangle((0, bottom), 1, 1,
                                 linewidth=0.5, edgecolor='black',
                                 facecolor=colors[zone])
        legend_ax.add_patch(cell)

        # White digits with a black outline stay readable on any fill colour.
        legend_ax.text(0.5, bottom + 0.5, str(zone),
                       ha='center', va='center', fontsize=12, fontweight='bold',
                       color='white',
                       path_effects=[patheffects.withStroke(linewidth=2, foreground='black')])

    legend_ax.set_title("BZ", fontsize=10)

def _compute_zone_grid(limit, resolution):
    """Label every sample of a square k-grid with its Brillouin-zone number.

    The grid spans [-limit, limit] on both axes (units of 2*pi/a) with
    ``resolution`` samples per axis.  Row r / column c correspond to y[r] / x[c],
    the same layout as ``np.meshgrid(x, y)``.

    The zone number of k is 1 plus the number of lattice points G = (i, j) != 0
    that are strictly closer to k than the origin.  Equivalently, it counts the
    G for which k.G > |G|^2 / 2, i.e. k lies beyond the Bragg line of G.

    Returns
    -------
    zone_index : ndarray of int, shape (resolution, resolution)
    bragg_lines : list of (i, j, rhs)
        One entry per G whose Bragg line i*k_x + j*k_y = rhs (rhs = |G|^2 / 2)
        reaches the plotted square.
    """
    x = np.linspace(-limit, limit, resolution)
    y = np.linspace(-limit, limit, resolution)
    zone_index = np.ones((resolution, resolution), dtype=int)
    bragg_lines = []

    # Over the square, the maximum of k.G is limit * (|i| + |j|).  So G can only
    # contribute if (|i| - limit)^2 + (|j| - limit)^2 < 2 * limit^2, which needs
    # |i|, |j| < (1 + sqrt(2)) * limit.  A smaller search range (e.g. limit + 1)
    # misses contributing G and undercounts zones near the plot edges.
    search_range = int(np.ceil((1 + np.sqrt(2)) * limit))
    for i in range(-search_range, search_range + 1):
        for j in range(-search_range, search_range + 1):
            if i == 0 and j == 0:
                continue
            rhs = 0.5 * (i**2 + j**2)
            if rhs >= limit * (abs(i) + abs(j)):
                continue  # Bragg line of G does not cross the plotted square
            # Broadcast the 1-D axes rather than building full X/Y grids.
            zone_index += (j * y)[:, None] + (i * x)[None, :] > rhs
            bragg_lines.append((i, j, rhs))

    return zone_index, bragg_lines


def plot_brillouin_zones(max_zone, limit=None, resolution=1000):
    """Plot the first ``max_zone`` Brillouin zones of a 2-D square lattice.

    Zones 1..max_zone are filled with distinct ``tab10`` colours; everything
    beyond is white.  Bragg lines are drawn dash-dotted, lattice points are
    drawn as black dots, and a numbered legend is added on the right.  The
    figure is displayed with ``plt.show()``.

    Parameters
    ----------
    max_zone : int
        Highest zone number to colour (must be >= 1).
    limit : float, optional
        Half-width of the plotted k-range in units of 2*pi/a.  Defaults to
        ``sqrt(max_zone) + 1``.
    resolution : int, optional
        Samples per axis of the k-grid (default 1000).

    Raises
    ------
    ValueError
        If ``max_zone`` < 1, ``limit`` <= 0, or ``resolution`` < 2.
    """
    if max_zone < 1:
        raise ValueError(f"max_zone must be >= 1, got {max_zone}")
    if limit is None:
        limit = np.sqrt(max_zone) + 1.0
    if limit <= 0:
        raise ValueError(f"limit must be positive, got {limit}")
    if resolution < 2:
        raise ValueError(f"resolution must be >= 2, got {resolution}")

    zone_index, bragg_lines = _compute_zone_grid(limit, resolution)

    # --- Colouring ---
    # Zones beyond max_zone map to 0, which is painted white.
    plot_data = np.where(zone_index <= max_zone, zone_index, 0)

    base_cmap = plt.get_cmap('tab10')
    # NOTE: tab10 has 10 colours, so zones 11+ reuse the colours of zones 1+.
    colors = [base_cmap(i % 10) for i in range(max_zone + 1)]
    colors[0] = (1, 1, 1, 1)
    custom_cmap = ListedColormap(colors)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Pin the colour scale to [0, max_zone] so that value n always maps to
    # colors[n].  Left to autoscale, imshow would stretch the colormap over
    # the data's min/max, and the colours would no longer match the legend
    # whenever 0 or max_zone is absent from the window.
    ax.imshow(plot_data, extent=[-limit, limit, -limit, limit],
              origin='lower', cmap=custom_cmap, interpolation='nearest',
              vmin=0, vmax=max_zone)

    # --- Bragg lines i*k_x + j*k_y = rhs (straight, so two endpoints suffice) ---
    x_ends = np.array([-limit, limit])
    for i, j, rhs in bragg_lines:
        if j == 0:
            # |rhs / i| < limit is guaranteed by the pruning in _compute_zone_grid.
            ax.vlines(rhs / i, -limit, limit, colors='black', linestyles='-.', linewidth=0.5)
        else:
            ax.plot(x_ends, (rhs - i * x_ends) / j, color='black', linestyle='-.', linewidth=0.5)

    # --- Lattice points (only those in or next to the visible window) ---
    n_pts = int(np.ceil(limit))
    pts = np.arange(-n_pts, n_pts + 1)
    grid_x, grid_y = np.meshgrid(pts, pts)
    ax.scatter(grid_x.ravel(), grid_y.ravel(), c='black', s=15, zorder=10)

    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.set_aspect('equal')
    ax.set_title(f'Square Lattice Brillouin Zones (1-{max_zone})')

    # Raw strings keep the LaTeX backslashes literal.
    ax.set_xlabel(r'$k_x / (2\pi/a)$')
    ax.set_ylabel(r'$k_y / (2\pi/a)$')

    draw_custom_colorbar(fig, ax, colors, max_zone)

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Command-line entry point: both positional arguments are optional.
    parser = argparse.ArgumentParser(description='Draw Brillouin Zones with Custom Legend.')
    parser.add_argument('max_zone', type=int, nargs='?', default=10,
                        help='The maximum Brillouin zone number to draw (default: %(default)s)')
    parser.add_argument('limit', type=float, nargs='?', default=None,
                        help='The maximum plot range in 2pi/a '
                             '(default: sqrt(max_zone) + 1)')
    args = parser.parse_args()

    # Report unusable values as a usage error instead of a traceback.
    if args.max_zone < 1:
        parser.error(f'max_zone must be >= 1 (got {args.max_zone})')
    if args.limit is not None and args.limit <= 0:
        parser.error(f'limit must be positive (got {args.limit})')

    plot_brillouin_zones(args.max_zone, args.limit)