# https://zenn.dev/shittoku_xxx/articles/13afd6fdfac44e

import sys
import argparse
import numpy as np
from numpy import abs, angle, exp, sqrt, cos, sin, arccos, arctan2, pi, min, max
from scipy.integrate import quad
from scipy.constants import physical_constants
from scipy.special import sph_harm, lpmv, assoc_laguerre

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from matplotlib.colors import LightSource, Normalize, LinearSegmentedColormap

import tkWavefunction_H as wf
import tkPlot3d as plt3d


# Bohr radius, converted from metres to angstroms (1 m = 1e10 angstrom).
a0 = physical_constants['Bohr radius'][0] * 1e10
# Same value as the `pi` imported from numpy above; rebound explicitly.
pi = np.pi


# Default plot mode and orbital; both can be overridden on the command line.
mode = 'rplot'
functype = '2px'

# Number of candidate points for Monte Carlo (rejection) sampling.
nsamples = 30000

figsize = (12, 10)
fontsize = 16


# Number of discrete levels in the custom colour maps.
nbins = 100
# Wave-function colour map: blue -> green -> red (first colour = low end).
colors_wf    = [(0, 0, 1), (0, 1, 0), (1, 0, 0)]
# Phase colour map: red -> green -> blue -> yellow -> red.  It starts and ends
# with red so that the two ends of the phase range share a colour.
colors_phase = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (1, 0, 0)]
cmap_name_wf    = 'custom cmap for wavefunction'
cmap_name_phase = 'custom cmap for phase'


# ArgumentParser with extra help framing.
class CustomArgumentParser(argparse.ArgumentParser):
    """ArgumentParser whose help text is framed by a header and usage examples.

    ``print_help`` prints an empty line and a "General usage:" header, then the
    standard argparse help, then an "Examples:" section with example command
    lines.
    """

    # Example invocations appended after the standard help text.
    _EXAMPLES = (
        r"usage: python .\plot_wf.py [rplot|ranim|dot|dotc|dotanim|iso3d|iso3danim|rsurface] [100r|2pxa etc]",
        r"usage: python .\plot_wf.py [Rnl|Ylm|Theta|Phi|normalize]",
        r"usage: python .\plot_wf.py [3dMC_animation|contour3d|torus]",
    )

    def print_help(self, file=None):
        """Print the framed help text.

        Args:
            file: text stream to write to.  Defaults to ``sys.stdout`` (resolved
                at call time, as ``argparse.ArgumentParser.print_help`` does).
                The header and the examples are written to the same stream
                as the argparse help, so redirecting the help keeps it whole.
        """
        if file is None:
            file = sys.stdout
        print(file=file)
        print("General usage:", file=file)
        super().print_help(file)
        print(file=file)
        print("Examples:", file=file)
        for example in self._EXAMPLES:
            print(example, file=file)

# Command-line interface.  All positional arguments are optional and default
# to the module-level settings defined above.
parser = CustomArgumentParser(description='Plot wave functions of H\n\n')
parser.add_argument('mode', nargs='?', default=mode, help='Mode of operation')
parser.add_argument('functype', nargs='?', default=functype, help='Function type')
parser.add_argument('nsamples', nargs='?', type=int, default=nsamples, help='Number of samples for Monte Carlo sampling')
parser.add_argument('--fplot', type=int, default=1, help='Flag to plot graphs (1: enable, 0: disable)')
parser.add_argument('--fsave', type=int, default=1, help='Flag to save figure file (1: enable, 0: disable)')

# Parse the command line.
args = parser.parse_args()

# Overwrite the module-level settings with the parsed values.
mode = args.mode
functype = args.functype
nsamples = args.nsamples
fplot = args.fplot
fsave = args.fsave

# Echo every effective setting so a run's output records how it was invoked.
print(f"mode = {mode}")
print(f"functype = {functype}")
print(f"nsamples = {nsamples}")
print(f"fplot = {fplot}")
print(f"fsave = {fsave}")


def usage():
    """Print the command-line help, including the example invocations, via the module parser."""
    show_help = parser.print_help
    show_help()

def torus(theta, phi, R = 3, r = 1):
    """Return the Cartesian coordinates (X, Y, Z) of points on a torus.

    Args:
        theta: angle around the tube cross-section (scalar or array).
        phi: angle around the z axis (scalar or array).
        R: major radius, i.e. distance from the z axis to the tube centre
            (default 3, the previously hard-coded value).
        r: minor radius, i.e. radius of the tube (default 1, previously
            hard-coded).

    Returns:
        Tuple (X, Y, Z) broadcast from ``theta`` and ``phi``.
    """
    ring = R + r * np.cos(theta)   # distance of the point from the z axis
    X = ring * np.cos(phi)
    Y = ring * np.sin(phi)
    Z = r * np.sin(theta)
    return X, Y, Z

def sphere(X, Y, Z):
    """Return x**2 + y**2 + z**2, the squared distance from the origin.

    The level set F = c is a sphere of radius sqrt(c) centred at the origin.
    Works element-wise on arrays.
    """
    x2, y2, z2 = X**2, Y**2, Z**2
    return x2 + y2 + z2

def exp_r2(X, Y, Z):
    """Return the Gaussian exp(-(x**2 + y**2 + z**2)), evaluated element-wise."""
    return np.exp(-(X**2 + Y**2 + Z**2))


print()
print(f"mode: {mode}")
print(f"function type: {functype}")
# Decode the function-type string (e.g. '2px') with wf.analyze_function.
# nlmt[:3] holds the quantum numbers (presumably a name such as 'sphere' for
# ftype 'f' -- TODO confirm against tkWavefunction_H) and nlmt[3] the
# orbital type ('r', 'i', 'a', 'a2', ...).
ftype, nlmt = wf.analyze_function(functype)
nlm, orb_type = nlmt[:3], nlmt[3]
print(f"   type: {ftype}  nlm:", nlm, f"  orb_type: {orb_type}")

# Scalar field `f` used by the generic 3D helpers (e.g. plot_isosurface3d).
# Every case not matched explicitly falls back to `sphere`, so `f` is always
# defined (an unknown ftype 'f' name used to leave it unbound).
if mode == 'torus':
    f = torus
elif ftype in ('c', 'r'):
    f = wf.psi_r
elif ftype == 'f' and nlm == 'exp_r2':
    f = exp_r2
else:
    f = sphere


def plot_torus():
    """Draw a torus surface whose face colours vary with the angle phi.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show or save it.
    """
    # Parametric grid: 100 x 100 samples of both angles over [0, 2*pi].
    angles = np.linspace(0, 2 * np.pi, 100)
    theta, phi = np.meshgrid(angles, angles)

    X, Y, Z = torus(theta, phi)

    # Custom colour map along phi: light red -> green -> light red.  The start
    # and end colours match because phi = 0 and phi = 2*pi are the same place.
    colors = [(1, 0.5, 0.5), (0, 1, 0), (1, 0.5, 0.5)]
    color_map = plt3d.make_color_map(phi, colors, cmap_name = 'custom_cmap', nbins = nbins)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    plt3d.plot_surface3d(ax, X, Y, Z, facecolors = color_map, edgecolor = 'none', alpha = 0.7, shade = True)

    return plt

def plot_isosurface3d():
    """Draw the isosurfaces F = 0.1, 0.2, 0.3 of the module-level field ``f``.

    ``f`` is evaluated on a 40 x 40 x 40 grid spanning [-1.2, 1.2] on every
    axis, built with ``indexing='ij'`` so that array axes 0, 1, 2 follow x, y, z.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    # The cube is symmetric, so a single axis serves for x, y and z.
    axis = np.linspace(-1.2, 1.2, 40)
    X, Y, Z = np.meshgrid(axis, axis, axis, indexing='ij')
    F = f(X, Y, Z)

    step = axis[1] - axis[0]
    lo = axis.min()

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    plt3d.plot_isosurface3d(
        ax, X, Y, Z, F,
        levels = [0.1, 0.2, 0.3],
        origin = [lo, lo, lo],
        spacing = [step, step, step],
        colors = ['cyan', 'magenta', 'yellow'],
        edgecolor = 'k', alpha = 0.3, linewidth = 0.1,
    )
    return plt

def plot_contour3d():
    """Draw a contour map of exp(-r^2) on the plane z = z[50] of a 100-point grid.

    The z values come from linspace(-5, 5, 100), which has no exact zero, so
    z[50] is about 0.05.  The plane is evaluated directly on a 100 x 100 grid.
    This gives the same values as slicing ``[:, :, 50]`` out of a full
    100^3 ``meshgrid`` but avoids allocating about a million points per array.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    x = np.linspace(-5, 5, 100)
    y = np.linspace(-5, 5, 100)
    z = np.linspace(-5, 5, 100)
    # Same layout as X[:, :, 50], Y[:, :, 50] of meshgrid(x, y, z) with the
    # default 'xy' indexing: rows follow y, columns follow x.
    X2d, Y2d = np.meshgrid(x, y)
    F2d = exp_r2(X2d, Y2d, z[50])

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    plt3d.contour3d(ax, X2d, Y2d, F2d, nbins = nbins, cmap = cm.viridis)

    return plt

def plot_3dMC_animation():
    """Animate a growing 3D Monte Carlo point cloud.

    Each frame draws ``nadd`` candidate points and resamples them with weights
    |x| * 4*pi*r^3*exp(-r^2).  The resampled points are appended to the scatter
    plot, and all points are coloured red (x > 0) or blue (x <= 0).
    ``num_points // nadd`` frames are generated.

    The window is only shown, and the ENTER prompt only issued, when the
    module-level ``fplot`` flag is set (as in ``plot_ranim``).  With plotting
    disabled the function no longer blocks waiting for input.

    Returns:
        None.
    """
    num_points = 100000  # total number of points after the last frame
    nadd = 100           # points added per frame

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    sc = ax.scatter([], [], [], marker='o', s=1, alpha=0.5)

    # Static axis decoration is set once: update() never clears the axes,
    # so these settings persist across frames.
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_xlim([-10, 10])
    ax.set_ylim([-10, 10])
    ax.set_zlim([-10, 10])

    def init():
        """Start the animation from an empty point cloud."""
        sc._offsets3d = ([], [], [])
        return sc,

    def update(frame):
        """Append ``nadd`` resampled points and recolour the whole cloud."""
        u = np.random.rand(nadd)
        r = np.cbrt(-np.log(1 - u))   # cube root of an Exp(1) sample

        # Isotropic directions: phi uniform in [0, 2*pi), cos(theta) uniform in [-1, 1).
        phi = np.random.rand(nadd) * 2.0 * np.pi
        cos_theta = np.random.rand(nadd) * 2 - 1
        theta = np.arccos(cos_theta)

        x = r * np.sin(theta) * np.cos(phi)
        y = r * np.sin(theta) * np.sin(phi)
        z = r * np.cos(theta)

        # NOTE(review): each batch is rescaled by its own max |coordinate| so it
        # fills [-10, 10]; the scale therefore differs from batch to batch.
        x = (x / np.max(np.abs(x))) * 10
        y = (y / np.max(np.abs(y))) * 10
        z = (z / np.max(np.abs(z))) * 10

        # Resampling weights |x| * 4*pi*r^3*exp(-r^2), normalised to sum to 1.
        weights = np.abs(x) * 4 * np.pi * r**3 * np.exp(-r**2)
        weights /= np.sum(weights)

        indices = np.random.choice(np.arange(nadd), size=nadd, p=weights)
        new_xyz = np.column_stack((x[indices], y[indices], z[indices]))

        # Append to the existing cloud, stored as a (3, N) coordinate array.
        offsets = np.array(sc._offsets3d)
        offsets = np.vstack([offsets.T, new_xyz]).T
        sc._offsets3d = (offsets[0], offsets[1], offsets[2])

        sc.set_color(np.where(offsets[0] > 0, 'red', 'blue'))
        ax.set_title(f'Total points: {len(offsets[0])}')

        return sc,

    # Keep a reference so the animation is not garbage-collected while shown.
    ani = FuncAnimation(fig, update, frames=num_points // nadd, init_func=init, blit=False, repeat=False)

    if fplot:
        plt.pause(0.1)
        input("Press ENTER to terminate>>")

    return None

def plot_Phi():
    """Plot the azimuthal factor wf.Ylm_phi(m, phi) for m = 0..4.

    The factor is labelled exp(j*m*phi) on the axes.  The top panel shows the
    real parts and the bottom panel the imaginary parts, both against phi in
    degrees over [0, 360].

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    phi = np.linspace(0.0, 2.0 * pi, 101)

    curves = []
    for m in range(5):
        values = wf.Ylm_phi(m, phi)
        curves.append((m, values.real, values.imag))

    fig, (ax_re, ax_im) = plt.subplots(2, 1, figsize = (10, 10))

    phi_deg = 180.0 / pi * phi
    for m, re_part, im_part in curves:
        ax_re.plot(phi_deg, re_part, label = f"m={m},real")
        ax_im.plot(phi_deg, im_part, label = f"m={m},imag")

    for ax, ylabel in ((ax_re, 'exp(j*m*phi),real'), (ax_im, 'exp(j*m*phi),imag')):
        ax.set_xlabel('phi')
        ax.set_ylabel(ylabel)
        ax.legend()

    return plt


def plot_Ylm():
    """Plot the theta dependence wf.Ylm_theta(l, m, theta) for l = 0, 1, 2 and all m in [-l, l].

    theta runs over [0, pi] (101 points) and is shown in degrees.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    theta = np.linspace(0.0, pi, 101)
    curves = [(l, m, wf.Ylm_theta(l, m, theta))
              for l in range(3)
              for m in range(-l, l + 1)]

    fig, ax = plt.subplots(1, 1, figsize = (10, 10))

    theta_deg = 180.0 / pi * theta
    for l, m, values in curves:
        ax.plot(theta_deg, values, label = f"l,m={l},{m}")
    ax.set_title('Spherical harmonic (theta dependence)')
    ax.set_xlabel('theta')
    ax.set_ylabel('Ylm')
    ax.legend()

    return plt

def _plot_radial_grid(r, names, curves_module, curves_explicit, suptitle, ylabel, ncols = 3):
    """Return a figure with one panel per orbital comparing two radial curves.

    Each panel shows ``curves_module[i]`` as a line ("Module") and
    ``curves_explicit[i]`` as dots ("explicit").  A curve that is None is
    skipped.  The grid has enough rows for every orbital, and panels beyond
    ``len(names)`` are hidden.
    """
    nrows = max(1, (len(names) + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize = (10, 10 * nrows / 3), squeeze = False)
    axes = axes.flatten()
    fig.subplots_adjust(hspace=0.4, wspace=0.3)
    fig.suptitle(suptitle)

    for ax, name, f_module, f_explicit in zip(axes, names, curves_module, curves_explicit):
        if f_module is not None:
            ax.plot(r, f_module, label='Module')
        if f_explicit is not None:
            ax.plot(r, f_explicit, label='explicit', linestyle = 'None', marker='o', markersize = 2.0)
        ax.set_title(name)
        ax.grid()
        # Only add a legend when there is something labelled to show.
        if f_module is not None or f_explicit is not None:
            ax.legend()
        ax.set_xlabel('r/$a_0$')
        ax.set_ylabel(ylabel)

    for ax in axes[len(names):]:
        ax.set_visible(False)

    return fig


def plot_Rnl():
    """Plot radial functions R_nl(r) and radial densities P_nl(r) = r^2 R_nl(r)^2.

    For each orbital 1s ... 4f, two implementations are compared:
    ``wf.Rnl`` ("Module", line) and ``wf.Rnl_if`` ("explicit", dots).  Either
    may return None, in which case that curve is omitted.  r runs over
    [0, 40] (101 points) and is labelled in units of a0.

    Two figures are produced: R_nl and P_nl.  When the module-level ``fsave``
    is set, they are saved to plot_wf_Rnl.png and plot_wf_Pnl.png.  The panel
    grid is sized from the number of orbitals, so all ten, including 4f, are
    shown (a fixed 3x3 grid dropped the last one).

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    r = np.linspace(0.0, 40.0, 101)

    orbitals = [('1s', 1, 0), ('2s', 2, 0), ('2p', 2, 1), ('3s', 3, 0), ('3p', 3, 1), ('3d', 3, 2),
                ('4s', 4, 0), ('4p', 4, 1), ('4d', 4, 2), ('4f', 4, 3)]

    R_module, R_explicit, P_module, P_explicit = [], [], [], []
    for _name, n, l in orbitals:
        f1 = wf.Rnl(r, n, l)
        f2 = wf.Rnl_if(r, n, l)
        R_module.append(f1)
        R_explicit.append(f2)
        P_module.append(None if f1 is None else r**2 * f1**2)
        P_explicit.append(None if f2 is None else r**2 * f2**2)

    names = [name for name, _n, _l in orbitals]

    fig1 = _plot_radial_grid(r, names, R_module, R_explicit,
                             '$R_{nl}(r)$: Radial wave function', '$R_{nl}(r)$')
    if fsave: fig1.savefig('plot_wf_Rnl.png')

    # Radial probability density distribution.
    fig2 = _plot_radial_grid(r, names, P_module, P_explicit,
                             '$P_{nl}(r)$: Radial probability density distribution', '$P_{nl}(r)$')
    if fsave: fig2.savefig('plot_wf_Pnl.png')

    return plt

def plot_rplot():
    """Plot the angular function Ylm(theta, phi) of the selected orbital as a 3D polar surface.

    Reads the module-level ``nlm``, ``orb_type`` and ``nbins``.  ``wf.Ylm`` is
    evaluated on a 101 x 101 (theta, phi) grid, and ``wf.get_by_type`` reduces it
    to a value ``r`` (selected by ``orb_type``) and an optional ``phase``.

    Left panel: points at distance |r| from the origin in direction
    (theta, phi).  They are coloured by r when ``phase`` is None and by the
    phase otherwise.  2D contour maps of the same quantity (via ``wf.Ylm_xyz``)
    are drawn on three back planes.
    Right panel (only filled when ``phase`` is not None): a polar surface
    whose radius is |phase|, with coordinates multiplied by 180/pi.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    n, l, m = nlm
    orb_name = wf.get_orb_name(n, l, m)
    # wf.get_qnumbers converts n, l, m into the form used below
    # (TODO confirm exactly what it changes).
    n, l, m = wf.get_qnumbers(n, l, m)

    print()
    print(f"Plot R(theta, phi) for {orb_name}  n, l, m=({n}, {l}, {m})")
    print(f"  orbital type: {orb_type}")
    En = wf.En(n)
    print(f"  En={En:.6g} eV")

    # Angular grid; meshgrid's default 'xy' indexing gives arrays of shape
    # (len(phi), len(theta)).
    theta = np.linspace(0, pi, 101)
    phi   = np.linspace(0, 2.0 * pi, 101)
    Theta2d, Phi2d = np.meshgrid(theta, phi)
#    r  = wf.Ylm_real(Theta2d, Phi2d, l, m)
    f = wf.Ylm(Theta2d, Phi2d, l, m)
    r, phase = wf.get_by_type(f, orb_type)
    nfigures = 2   # left: value surface, right: phase surface
    # NOTE(review): `title` stays unbound for any orb_type other than
    # 'a', 'a2', 'r', 'i'; ax.set_title(title) below would then raise
    # UnboundLocalError.
    if orb_type == 'a':
        title = r"|$\psi$|"
    elif orb_type == 'a2':
        title = r"|$\psi$|$^2$"
    elif orb_type == 'r':
        title = r"$\psi_r$"
    elif orb_type == 'i':
        title = r"$\psi_i$"

    # Polar surface: radius |r| along the direction (theta, phi).
    # `abs`, `cos`, `sin` are the numpy versions imported at the top.
    ra = abs(r)
    z   = ra * cos(Theta2d)
    rxy = ra * sin(Theta2d)
    x   = rxy * cos(Phi2d)
    y   = rxy * sin(Phi2d)

    alpha = 0.7   # surface transparency

    # Custom colour map for the wave-function value.
    color_list = colors_wf
    cmap_name = cmap_name_wf
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, color_list, N = nbins)
    # Map r onto the colour map.  `min`/`max` are numpy.min/numpy.max
    # (imported from numpy at the top), so they reduce over the whole 2D array.
    norm = Normalize(vmin = min(r), vmax = max(r))
    face_colors = custom_cmap(norm(r))

#    ax = fig.add_subplot(111, projection='3d')
    fig, axes = plt.subplots(1, nfigures, figsize = (8 * nfigures, 10), subplot_kw = {'projection': '3d'})
    if nfigures == 1:
        ax = axes
    else:
        ax = axes[0]
        ax2 = axes[1]

    if phase is None:
        # No phase available: colour the faces by r.
        options = {"facecolors": face_colors, "edgecolor": None}
    else:
        # Colour the faces by phase, rescaled to [0, 1] over the range actually
        # present.  A constant phase is widened by +/-0.1 to avoid dividing by zero.
        phase_min = np.min(phase)
        phase_max = np.max(phase)
        if phase_min == phase_max:
            phase_min -= 0.1
            phase_max += 0.1
        normalized_phase = (phase - phase_min) / (phase_max - phase_min)

        # Custom colour map for the phase.
        colors = colors_phase
        cmap_name = cmap_name_phase
        custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N = nbins)
        facecolors = custom_cmap(normalized_phase)

        options = {"facecolors": facecolors, "edgecolor": None}

    plt3d.plot_surface3d(ax, x, y, z, **options, alpha = alpha, shade = False,
                    xlabel = 'x', ylabel = 'y', zlabel = 'z')

    def contour_func(x, y, z):
        # Same quantity as the surface, evaluated at Cartesian points.
        f = wf.Ylm_xyz(x, y, z, l, m)
        r, phase = wf.get_by_type(f, orb_type)
        return r

    # Contour maps on the three back planes of the cube [-maxx, maxx]^3.
    # `custom_cmap` is the phase map here when a phase exists.
    maxx = plt3d.get_max_xyz(x, y, z)
    plt3d.plot_contours_xyz_by_func(ax, contour_func, minx = -maxx, maxx = maxx, nmesh = 100,
                            posx = 0.1, posy = 0.1, posz = 0.1,
                            offsetx = -maxx, offsety = maxx, offsetz = -maxx, 
                            cmap = custom_cmap, levels = 20, alpha = 0.3)

    plt3d.set_cubic_scale(ax, maxx)
    ax.set_title(title)

    if phase is None:
        # NOTE(review): the face colours are normalised by r, but the colour
        # bar is built from `scale = z` (= |r| cos(theta)); confirm this is
        # intended.  Any orb_type other than 'i' is labelled psi_r here.
        cbar = plt3d.show_color_bar(fig, ax, scale = z, cmap = custom_cmap, shrink = 0.5, aspect = 10.0)
        if orb_type == 'i':
            cbar.set_label(r'$\psi_i$')
        else:
            cbar.set_label(r'$\psi_r$')
    else:
        # Phase colour bar labelled in degrees over the phase range present.
        vmin = phase_min * 180.0 / pi
        vmax = phase_max * 180.0 / pi
        scmap = plt.cm.ScalarMappable(cmap = custom_cmap, norm = plt.Normalize(vmin = vmin, vmax = vmax))
        scmap.set_array(normalized_phase)
        cbar = plt.colorbar(scmap, ax=ax, shrink=0.5, aspect=10.0)
        cbar.set_label('phase')

    # Right panel: polar surface with radius |phase|, reusing the phase face
    # colours.  Coordinates are multiplied by 180/pi, presumably to show the
    # radius in degrees (assumes phase is in radians -- TODO confirm).
    # If phase is None, ax2 is left empty.
    if phase is not None and nfigures > 1:
        ra = abs(phase)
        z   = ra * cos(Theta2d)
        rxy = ra * sin(Theta2d)
        x   = rxy * cos(Phi2d)
        y   = rxy * sin(Phi2d)
        plt3d.plot_surface3d(ax2, x * 180.0/pi, y * 180.0/pi, z * 180.0/pi, **options, alpha = alpha, shade = False,
                    xlabel = 'x', ylabel = 'y', zlabel = 'z')
#        plt3d.set_cubic_scale(ax2, maxx)
        ax2.set_title("phase")

    return plt

def plot_ranim():
    """Animate the polar surface of Ylm(theta, phi) while advancing a phase parameter.

    Reads the module-level ``nlm``, ``orb_type``, ``nbins``, ``mode``,
    ``functype``, ``fplot`` and ``fsave``.  The initial surface is drawn from
    ``wf.Ylm`` on a 101 x 101 (theta, phi) grid, coloured by phase.  Each of
    the ``nmaxiter`` frames re-evaluates ``wf.Ylm(..., phase_m = frame * 0.1)``
    and redraws the surface.  The cube scale from the first evaluation is kept,
    so the axes do not jump between frames.

    With ``fsave`` the animation is written to plot_wf_<mode>_<functype>.gif
    using matplotlib's 'imagemagick' writer (ImageMagick must be available).
    With ``fplot`` the function waits for ENTER before returning.

    Returns:
        None.
    """
    n, l, m = nlm
    orb_name = wf.get_orb_name(n, l, m)
    # wf.get_qnumbers converts n, l, m into the form used below
    # (TODO confirm exactly what it changes).
    n, l, m = wf.get_qnumbers(n, l, m)

    print()
    print(f"Animation of R(theta, phi)=Ylm(theta, phi) for {orb_name}  n, l, m=({n}, {l}, {m})")
    print(f"  orbital type: {orb_type}")
    En = wf.En(n)
    print(f"  En={En:.6g} eV")

    theta = np.linspace(0, pi, 101)
    phi   = np.linspace(0, 2.0 * pi, 101)
    Theta2d, Phi2d = np.meshgrid(theta, phi)
#    r  = wf.Ylm_real(Theta2d, Phi2d, l, m)
    f = wf.Ylm(Theta2d, Phi2d, l, m)
    r, phase = wf.get_by_type(f, orb_type)
    nfigures = 1   # a single 3D axes (plt.subplots then returns one Axes)
    # NOTE(review): `title` stays unbound for any orb_type other than
    # 'a', 'a2', 'r', 'i'; ax.set_title(title) below would then raise.
    if orb_type == 'a':
        title = rf"|$\psi$| ({n}{l}{m})"
    elif orb_type == 'a2':
        title = rf"|$\psi$|$^2$ ({n}{l}{m})"
    elif orb_type == 'r':
        title = rf"$\psi_r$ ({n}{l}{m})"
    elif orb_type == 'i':
        title = rf"$\psi_i$ ({n}{l}{m})"

    # Polar surface: radius |r| along the direction (theta, phi).
    ra = abs(r)
    z   = ra * cos(Theta2d)
    rxy = ra * sin(Theta2d)
    x   = rxy * cos(Phi2d)
    y   = rxy * sin(Phi2d)

    alpha = 0.7   # surface transparency

    fig, ax = plt.subplots(1, nfigures, figsize = (8, 8), subplot_kw = {'projection': '3d'})

    # Fixed phase range [-pi, pi] so colours are comparable between frames.
    # NOTE(review): assumes get_by_type returns a phase in [-pi, pi]; if it
    # returns phase None for this orb_type, the subtraction raises TypeError.
    phase_min = -pi
    phase_max =  pi
    normalized_phase = (phase - phase_min) / (phase_max - phase_min)

    # Custom colour map for the phase.
    colors = colors_phase
    cmap_name = cmap_name_phase
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N = nbins)
    facecolors = custom_cmap(normalized_phase)

    options = {"facecolors": facecolors, "edgecolor": None}

    plt3d.plot_surface3d(ax, x, y, z, **options, alpha = alpha, shade = False,
                    xlabel = 'x', ylabel = 'y', zlabel = 'z')

    # Cube half-width from the initial surface; reused for every frame.
    maxx = plt3d.get_max_xyz(x, y, z)
    plt3d.set_cubic_scale(ax, maxx)
    ax.set_title(title)

    # Phase colour bar in degrees.
    vmin = -180.0
    vmax =  180.0
    scmap = plt.cm.ScalarMappable(cmap = custom_cmap, norm = plt.Normalize(vmin = vmin, vmax = vmax))
    scmap.set_array(normalized_phase)
    cbar = plt.colorbar(scmap, ax = ax, shrink = 0.5, aspect = 10.0)
    cbar.set_label('phase')

    nmaxiter = 100   # number of frames
    wait = 5 # ms between frames
    ani = None
    window_closed = False   # set by on_close when the figure window is closed
    def update_animation(frame):
        """Redraw the surface for ``frame`` (phase parameter frame * 0.1)."""
        nonlocal window_closed, ani

        # Stop the timer once the window has been closed.
        if window_closed :
            if ani.event_source: ani.event_source.stop()
            return ax

        if frame % 10 == 0: print(f"frame: {frame}/{nmaxiter}")

        ax.cla()  # clear the axes; cube scale and title are re-applied below

        # NOTE(review): phase_m presumably shifts the phase of Ylm -- TODO
        # confirm its meaning in tkWavefunction_H.
        f = wf.Ylm(Theta2d, Phi2d, l, m, phase_m = frame * 0.1)
        r, phase = wf.get_by_type(f, orb_type)
        ra = abs(r)
        z   = ra * cos(Theta2d)
        rxy = ra * sin(Theta2d)
        x   = rxy * cos(Phi2d)
        y   = rxy * sin(Phi2d)

        phase_min = -pi
        phase_max =  pi
        normalized_phase = (phase - phase_min) / (phase_max - phase_min)
        custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N = nbins)
        facecolors = custom_cmap(normalized_phase)
        options = {"facecolors": facecolors, "edgecolor": None}

        plt3d.plot_surface3d(ax, x, y, z, **options, alpha = alpha, shade = False,
                    xlabel = 'x', ylabel = 'y', zlabel = 'z')

        plt3d.set_cubic_scale(ax, maxx)
        ax.set_title(f"{title} frame: {frame}/{nmaxiter}")

        if fplot: plt.pause(0.1)

#        if frame >= nmaxiter - 1: 
#            ani.event_source.stop()
#            return ax
        
        return ax

    def on_close(event):
        """Record that the figure window was closed so the animation stops."""
        nonlocal window_closed

        window_closed = True

    fig.canvas.mpl_connect('close_event', on_close)

    ani = FuncAnimation(fig, update_animation, frames = nmaxiter, interval = wait)
#ani = FuncAnimation(fig, update, frames=nmaxiter, init_func=init, blit=False, repeat=False)

    # Saving renders every frame through update_animation.
    outfile = f'plot_wf_{mode}_{functype}.gif'
    if fsave:
        print(f"saving to {outfile}")
        ani.save(outfile, writer='imagemagick')

    # Keep the function (and the `ani` reference) alive while the window is shown.
    if fplot:
        input("Press Enter to close the window...")

#    plt.show()

    return None


def plot_rsurface():
    """Plot the real spherical harmonic wf.Ylm_real(theta, phi) of the selected orbital.

    The horizontal coordinates come from the polar construction with radius
    |r|: x = |r| sin(theta) cos(phi), y = |r| sin(theta) sin(phi).  The
    vertical coordinate passed to the surface plot is the signed value r
    itself, not |r| cos(theta).  The cube scale and the colour bar are computed
    from (x, y, z) with z = |r| cos(theta).  Faces are coloured by r, and
    contour maps of ``wf.Ylm_xyz_real`` are drawn on three back planes.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    n, l, m = nlm
    orb_name = wf.get_orb_name(n, l, m)
    n, l, m = wf.get_qnumbers(n, l, m)

    print()
    print(f"Plot R(theta, phi) for {orb_name}  n, l, m=({n}, {l}, {m})")

    theta_axis = np.linspace(0, pi, 101)
    phi_axis = np.linspace(0, 2.0 * pi, 101)
    Theta2d, Phi2d = np.meshgrid(theta_axis, phi_axis)

    r = wf.Ylm_real(Theta2d, Phi2d, l, m)
    radius = abs(r)
    planar = radius * sin(Theta2d)
    x = planar * cos(Phi2d)
    y = planar * sin(Phi2d)
    z = radius * cos(Theta2d)

    # Face colours: the signed value r on the wave-function colour map.
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name_wf, colors_wf, N = nbins)
    face_colors = custom_cmap(Normalize(vmin = np.min(r), vmax = np.max(r))(r))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection = '3d')

    plt3d.plot_surface3d(ax, x, y, r, facecolors = face_colors, edgecolor = None, alpha = 0.7, shade = False,
                         xlabel = 'x', ylabel = 'y', zlabel = 'z')

    maxx = plt3d.get_max_xyz(x, y, z)

    def contour_field(xc, yc, zc):
        return wf.Ylm_xyz_real(xc, yc, zc, l, m)

    plt3d.plot_contours_xyz_by_func(ax, contour_field, minx = -maxx, maxx = maxx, nmesh = 100,
                                    posx = 0.1, posy = 0.1, posz = 0.1,
                                    offsetx = -maxx, offsety = maxx, offsetz = -maxx,
                                    cmap = custom_cmap, levels = 20, alpha = 0.3)

    plt3d.set_cubic_scale(ax, maxx)
    plt3d.show_color_bar(fig, ax, scale = z, cmap = custom_cmap, shrink = 0.5, aspect = 10.0)

    return plt


# Estimate the upper bound used by rejection sampling.
def estimate_max_f(func, minx = -3, maxx = 3, nsamples_estimate = 1000):
    """Estimate the maximum of ``func`` over the cube [minx, maxx]^3 by random sampling.

    ``nsamples_estimate`` points are drawn uniformly with NumPy's global RNG,
    and ``func`` is evaluated on all of them in one call, so it must accept
    NumPy arrays.  The result is the largest sampled value, which can
    underestimate the true maximum if a sharp peak is missed.

    Args:
        func: callable f(x, y, z) -> array of values.
        minx: lower bound of every coordinate.
        maxx: upper bound of every coordinate.  The default is 3, which gives
            a symmetric cube; the former default of -3 collapsed the range to
            a single point.
        nsamples_estimate: number of random points.

    Returns:
        The maximum sampled value of ``func``.
    """
    x_samples, y_samples, z_samples = np.random.uniform(minx, maxx, (3, nsamples_estimate))
    return np.max(func(x_samples, y_samples, z_samples))

def rejection_sampling(func, func2, M = None,  minx = -3, maxx = 3, nsamples = 10000, nsamples_estimate = 1000):
    """Draw points in [minx, maxx]^3, accepting each with probability |value| / M.

    Args:
        func: f(x, y, z) -> value.  Only used, through ``estimate_max_f`` on
            |func|, to estimate M when M is None; may be None otherwise.
        func2: f(x, y, z) -> (value, phase), evaluated one point at a time.
        M: upper bound for |value|.  Estimated when None.  If it is too small,
            points with |value| > M are always accepted (clipped acceptance).
        minx, maxx: bounds of every coordinate.
        nsamples: number of candidate points (not the number accepted).
        nsamples_estimate: sample count used to estimate M.

    Returns:
        Tuple (x, y, z, value, phase, M): NumPy arrays of the accepted points
        and their values and phases, plus the bound M actually used.
    """
    if M is None:
        M = estimate_max_f(lambda x, y, z: abs(func(x, y, z)), minx, maxx, nsamples_estimate)

    columns = ([], [], [], [], [])   # x, y, z, value, phase
    for _ in range(nsamples):
        x, y, z = np.random.uniform(minx, maxx, 3)
        value, phase = func2(x, y, z)
        # Accept when a uniform draw in [0, M) falls below |value|.
        if np.random.uniform(0, M) < abs(value):
            for column, item in zip(columns, (x, y, z, value, phase)):
                column.append(item)

    xs, ys, zs, values, phases = (np.array(column) for column in columns)
    return xs, ys, zs, values, phases, M

def plot_dot(plot_contours = True):
    """Dot (Monte Carlo) plot of the selected hydrogen orbital.

    Points in a cube around the nucleus are drawn by ``rejection_sampling``,
    with acceptance proportional to the quantity selected by the module-level
    ``orb_type`` (see ``wf.get_by_type``).  The points are coloured by phase
    on a cyclic colour map over [-pi, pi].  The cube half-width depends on n
    and l through the tables below.

    Args:
        plot_contours: also draw 2D contour maps of the sampled quantity on
            three back planes, plus a second colour bar for it.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    nsamples_estimate = 1000
    n, l, m = nlm

    # Cube half-width per principal quantum number n, shrunk for small l.
    maxx_list = [None, 3.0, 5.0, 10.0, 15.0, 25.0]
    kmaxx_list = [0.5, 0.7, 1.0, 1.0, 1.0, 1.0]
    maxx = maxx_list[n] * kmaxx_list[l]
    minx = -maxx

    # Contour planes sit on the back faces of the cube.
    offsetx, offsety, offsetz = -maxx, maxx, -maxx
    posx = posy = posz = 0.1

    # Axis/colour-bar label per orbital type; unknown types fall back to
    # their own name instead of leaving `title` unbound.
    titles = {
        'r': r'$\psi_r$',
        'i': r'$\psi_i$',
        'a': r'|$\psi$|',
        'a2': r'|$\psi$|$^2$',
    }
    title = titles.get(orb_type, str(orb_type))

    print()
    print(f"Plot wave function with dots: n,l,m=({n},{l},{m})")
    print(f"nsamples={nsamples}")

    def func2(x, y, z):
        """Selected quantity and its phase at (x, y, z)."""
        f, phase = wf.get_by_type(wf.psi_xyz(x, y, z, n, l, m), orb_type)
        return f, phase

    def func(x, y, z):
        """Selected quantity only (for the max estimate and the contours)."""
        return func2(x, y, z)[0]

    x_sampled, y_sampled, z_sampled, f_sampled, phase_sampled, M = rejection_sampling(
        func, func2, None, minx, maxx, nsamples, nsamples_estimate)

    print(f"nsampled={f_sampled.size}")
    print("Estimated max func:", M)

    alpha_s = 0.2   # scatter transparency
    alpha_c = 0.1   # contour transparency
    pad = 0.1       # gap between the plot and the colour bars

    fig = plt.figure(figsize = (10, 10))
    ax = fig.add_subplot(111, projection = '3d')

    # Scatter coloured by phase.
    phase_cmap = LinearSegmentedColormap.from_list(cmap_name_phase, colors_phase, N = nbins)
    phase_norm = Normalize(vmin = -pi, vmax = pi)
    plt3d.plot_scatter3d(ax, x_sampled, y_sampled, z_sampled, minx, maxx, minx, maxx, minx, maxx,
                         size = 1.0, c = phase_sampled, cmap = phase_cmap, norm = phase_norm, alpha = alpha_s)

    scmap = plt.cm.ScalarMappable(cmap = phase_cmap, norm = phase_norm)
    scmap.set_array(phase_sampled)
    cbar = fig.colorbar(scmap, ax = ax, shrink = 0.5, aspect = 10.0, pad = pad)
    cbar.set_label('phase')

    if plot_contours:
        # Colour bar range follows the values of the accepted samples.
        wf_cmap = LinearSegmentedColormap.from_list(cmap_name_wf, colors_wf, N = nbins)
        wf_norm = Normalize(vmin = np.min(f_sampled), vmax = np.max(f_sampled))

        plt3d.plot_contours_xyz_by_func(ax, func, minx = minx, maxx = maxx, nmesh = 100,
                                        posx = posx, posy = posy, posz = posz,
                                        offsetx = offsetx, offsety = offsety, offsetz = offsetz,
                                        cmap = wf_cmap, levels = 20, alpha = alpha_c)

        scmap = plt.cm.ScalarMappable(cmap = wf_cmap, norm = wf_norm)
        scmap.set_array(f_sampled)
        cbar = fig.colorbar(scmap, ax = ax, shrink = 0.5, aspect = 10.0, pad = pad)
        cbar.set_label(title)

    plt3d.set_cubic_scale(ax, maxx)

    ax.set_title(f"dot plot of {title} for nlm={n}{l}{m}")

    return plt

def plot_dotanim(plot_contours = True):
    """Animate the dot plot of the selected orbital while advancing a phase parameter.

    The first point set is sampled as in ``plot_dot``; its estimated bound M is
    reused for every frame.  Frame k re-samples
    ``wf.psi_xyz(..., phase = k * 0.1)`` and redraws the scatter coloured by
    phase.  With ``fsave`` the animation is written to
    plot_wf_<mode>_<functype>.gif using matplotlib's 'imagemagick' writer.
    With ``fplot`` the function waits for ENTER before returning.

    Args:
        plot_contours: accepted for interface parity with ``plot_dot``; this
            animation draws no contours.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    nsamples_estimate = 1000
    n, l, m = nlm

    # Cube half-width per principal quantum number n, shrunk for small l.
    maxx_list = [None, 3.0, 5.0, 10.0, 15.0, 25.0]
    kmaxx_list = [0.5, 0.7, 1.0, 1.0, 1.0, 1.0]
    maxx = maxx_list[n] * kmaxx_list[l]
    minx = -maxx

    # Label per orbital type; unknown types fall back to their own name.
    titles = {
        'r': r'$\psi_r$',
        'i': r'$\psi_i$',
        'a': r'|$\psi$|',
        'a2': r'|$\psi$|$^2$',
    }
    title = titles.get(orb_type, str(orb_type))

    print()
    print(f"Animation of wave function with dots: n,l,m=({n},{l},{m})")
    print(f"nsamples={nsamples}")

    def func2(x, y, z):
        """Selected quantity and its phase at (x, y, z)."""
        f, phase = wf.get_by_type(wf.psi_xyz(x, y, z, n, l, m), orb_type)
        return f, phase

    def func(x, y, z):
        """Selected quantity only (for the max estimate)."""
        return func2(x, y, z)[0]

    x_sampled, y_sampled, z_sampled, f_sampled, phase_sampled, M = rejection_sampling(
        func, func2, None, minx, maxx, nsamples, nsamples_estimate)
    print(f"nsampled={f_sampled.size}")
    print("Estimated max func:", M)

    fig = plt.figure(figsize = (10, 10))
    ax = fig.add_subplot(111, projection = '3d')

    # Scatter coloured by phase on a cyclic colour map over [-pi, pi].
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name_phase, colors_phase, N = nbins)
    norm = Normalize(vmin = -pi, vmax = pi)

    plt3d.plot_scatter3d(ax, x_sampled, y_sampled, z_sampled, minx, maxx, minx, maxx, minx, maxx,
                         size = 1.0, c = phase_sampled, cmap = custom_cmap, norm = norm)

    scmap = plt.cm.ScalarMappable(cmap = custom_cmap, norm = norm)
    scmap.set_array(phase_sampled)
    cbar = fig.colorbar(scmap, ax = ax, shrink = 0.5, aspect = 10.0)
    cbar.set_label('phase')

    plt3d.set_cubic_scale(ax, maxx)
    ax.set_title(f"dot plot of {title} for nlm={n}{l}{m}")

    nmaxiter = 62   # number of frames; phase parameter 0.0 ... 6.1
    wait = 5        # ms between frames
    ani = None
    window_closed = False       # set by on_close when the window is closed
    init_call_skipped = False

    def update_animation(frame):
        """Re-sample and redraw the point cloud for ``frame``."""
        nonlocal window_closed, init_call_skipped

        if window_closed:
            if ani.event_source: ani.event_source.stop()
            return ax

        # Without init_func, FuncAnimation draws the first frame once as its
        # initial state; skip that first call so frame 0 is not sampled twice.
        if frame == 0 and not init_call_skipped:
            init_call_skipped = True
            return
        if frame % 10 == 0: print(f"frame: {frame}/{nmaxiter}")

        ax.cla()  # clear the axes; the cube scale and title are re-applied below

        def func2_frame(x, y, z):
            f, phase = wf.get_by_type(wf.psi_xyz(x, y, z, n, l, m, phase = frame * 0.1), orb_type)
            return f, phase

        xs, ys, zs, _values, phases, _M = rejection_sampling(
            None, func2_frame, M, minx, maxx, nsamples, nsamples_estimate)

        plt3d.plot_scatter3d(ax, xs, ys, zs, minx, maxx, minx, maxx, minx, maxx,
                             size = 1.0, c = phases, cmap = custom_cmap, norm = norm)

        plt3d.set_cubic_scale(ax, maxx)
        ax.set_title(f"{title} frame: {frame}/{nmaxiter}")

        if fplot: plt.pause(0.1)

        return ax

    def on_close(event):
        """Record that the figure window was closed so the animation stops."""
        nonlocal window_closed
        window_closed = True

    fig.canvas.mpl_connect('close_event', on_close)

    ani = FuncAnimation(fig, update_animation, frames = nmaxiter, interval = wait)

    outfile = f'plot_wf_{mode}_{functype}.gif'
    if fsave:
        print(f"saving to {outfile}")
        ani.save(outfile, writer='imagemagick')

    if fplot:
        plt.pause(0.1)
        input("Press Enter to close the window...")

    return plt


def plot_iso3d():
    """Draw isosurfaces and side-plane contour maps of the selected orbital.

    The quantity selected by ``orb_type`` is evaluated on a 50^3 grid covering
    [-maxx, maxx]^3.  Isosurfaces are drawn at -0.6, -0.3, 0.3 and 0.6 times the
    grid maximum, and contour maps on three back planes.  The grid is built
    with ``indexing='ij'`` so that array axes 0, 1, 2 follow x, y, z.  This
    matches the ``origin``/``spacing`` passed to ``plt3d.plot_isosurface3d``
    (as in ``plot_isosurface3d``); the default 'xy' indexing swaps the x and y
    axes.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    n, l, m = nlm
    n, l, m = wf.get_qnumbers(n, l, m)

    # Cube half-width per principal quantum number n, shrunk for small l.
    maxx_list = [None, 3.0, 5.0, 10.0, 15.0, 25.0]
    kmaxx_list = [0.5, 0.7, 1.0, 1.0, 1.0, 1.0]
    maxx = maxx_list[n] * kmaxx_list[l]

    x = np.linspace(-maxx, maxx, 50)
    y = np.linspace(-maxx, maxx, 50)
    z = np.linspace(-maxx, maxx, 50)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    def func(x, y, z):
        """Selected quantity and its phase at (x, y, z)."""
        f, phase = wf.get_by_type(wf.psi_xyz(x, y, z, n, l, m), orb_type)
        return f, phase

    def func1(x, y, z):
        """Selected quantity only (for the contour maps)."""
        return func(x, y, z)[0]

    F, _phase = func(X, Y, Z)

    # Isosurface levels relative to the grid maximum.
    minF = np.min(F)
    maxF = np.max(F)
    print(f"F range: {minF} - {maxF}")
    levels = [-0.6 * maxF, -0.3 * maxF, 0.3 * maxF, 0.6 * maxF]
    colors = ['blue', 'cyan', 'pink', 'red']  # one colour per level

    fig = plt.figure(figsize = (8, 8))
    ax = fig.add_subplot(111, projection='3d')

    dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
    plt3d.plot_isosurface3d(ax, X, Y, Z, F, levels = levels, origin = [x.min(), y.min(), z.min()], spacing = [dx, dy, dz],
            colors = colors, edgecolor = 'k', alpha = 0.3, linewidth = 0.1)

    # 2D contour maps on the back planes, coloured on the wave-function map.
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name_wf, colors_wf, N = nbins)
    norm = Normalize(vmin = minF, vmax = maxF)

    maxx = plt3d.get_max_xyz(X, Y, Z)
    plt3d.plot_contours_xyz_by_func(ax, func1, minx = -maxx, maxx = maxx, nmesh = 100,
                            posx = 0.1, posy = 0.1, posz = 0.1,
                            offsetx = -maxx, offsety = maxx, offsetz = -maxx,
                            cmap = custom_cmap, levels = 20, alpha = 0.3)

    plt3d.set_cubic_scale(ax, maxx)

    scmap = plt.cm.ScalarMappable(cmap = custom_cmap, norm = norm)
    scmap.set_array(F)
    plt.colorbar(scmap, ax=ax, shrink=0.5, aspect=10.0)
    ax.set_title(f"isosurfaces for nlm={n}{l}{m}")

    return plt

def plot_iso3danim():
    """Animate 3D isosurfaces of the hydrogen wavefunction selected by `nlm`/`orb_type`.

    Reads module-level settings defined elsewhere in the file: `nlm` (quantum
    numbers), `orb_type` (which component wf.get_by_type extracts), `fplot`,
    `fsave`, `mode` and `functype`.

    An initial isosurface plot is drawn on a 50x50x50 grid. A FuncAnimation of
    `nmaxiter` frames then redraws the isosurfaces with wf.psi_xyz evaluated at
    phase = frame * 0.1. The isosurface levels are fixed from the initial field.
    If `fsave` is set, the animation is saved to plot_wf_<mode>_<functype>.gif
    with the ImageMagick writer. If `fplot` is set, the function blocks on
    input() so the window stays open.

    Returns:
        The matplotlib.pyplot module.
    """
    n, l, m = nlm
    orb_name = wf.get_orb_name(n, l, m)
    n, l, m = wf.get_qnumbers(n, l, m)

    # Per-frame title prefix. Unlisted orb_type values fall back to the raw
    # string instead of leaving `title` unbound (which used to raise NameError
    # inside the animation callback).
    titles = {
        'r': r'$\psi_r$',
        'i': r'$\psi_i$',
        'a': r'|$\psi$|',
        'a2': r'|$\psi$|$^2$',
    }
    title = titles.get(orb_type, str(orb_type))

    # Half-width of the plotting cube, indexed by n and scaled down for small l.
    maxx_list = [None, 3.0, 5.0, 10.0, 15.0, 25.0]
    kmaxx_list = [0.5, 0.7, 1.0, 1.0, 1.0, 1.0]
    maxx = maxx_list[n] * kmaxx_list[l]

    x = np.linspace(-maxx, maxx, 50)
    y = np.linspace(-maxx, maxx, 50)
    z = np.linspace(-maxx, maxx, 50)
    X, Y, Z = np.meshgrid(x, y, z)
    origin = [x.min(), y.min(), z.min()]
    dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]

    F, _ = wf.get_by_type(wf.psi_xyz(X, Y, Z, n, l, m), orb_type)

    # Isosurface levels, taken from the initial field and reused for every frame.
    fr1d = F.flatten()
    minF = min(fr1d)
    maxF = max(fr1d)
    print(f"F range: {minF} - {maxF}")
    levels = [-0.6 * maxF, -0.3 * maxF, 0.3 * maxF, 0.6 * maxF]
    colors = ['blue', 'cyan', 'pink', 'red']  # isosurface colour per level

    # 3D plot
    fig = plt.figure(figsize = (8, 8))
    ax = fig.add_subplot(111, projection='3d')

    def draw(field):
        """Draw the isosurfaces of `field` on `ax` and apply the cubic axis range."""
        plt3d.plot_isosurface3d(ax, X, Y, Z, field, levels = levels, origin = origin, spacing = [dx, dy, dz],
                colors = colors, edgecolor = 'k', alpha = 0.3, linewidth = 0.1)
        plt3d.set_cubic_scale(ax, maxx)

    draw(F)
    plt.title(f"isosurfaces for nlm={n}{l}{m}")

    nmaxiter = 62
    wait = 5  # ms between frames
    ani = None
    window_closed = False
    initialized = False

    def update_animation(frame):
        nonlocal initialized

        if window_closed:
            # The user closed the figure: stop the animation timer.
            if ani is not None and ani.event_source:
                ani.event_source.stop()
            return ax

        # Skip the very first frame-0 call (presumably FuncAnimation's initial
        # draw, since no init_func is given); the initial plot is already drawn.
        if frame == 0 and not initialized:
            initialized = True
            return ax
        if frame % 10 == 0: print(f"frame: {frame}/{nmaxiter}")

        ax.cla()  # clear the whole axes; draw() re-applies the axis range
        F_frame, _ = wf.get_by_type(wf.psi_xyz(X, Y, Z, n, l, m, phase = frame * 0.1), orb_type)
        draw(F_frame)
        ax.set_title(f"{title} frame: {frame}/{nmaxiter}")

        if fplot: plt.pause(0.1)

        return ax

    def on_close(event):
        nonlocal window_closed
        window_closed = True

    fig.canvas.mpl_connect('close_event', on_close)

    ani = FuncAnimation(fig, update_animation, frames = nmaxiter, interval = wait)

    outfile = f'plot_wf_{mode}_{functype}.gif'
    if fsave:
        print(f"saving to {outfile}")
        ani.save(outfile, writer='imagemagick')

    if fplot:
        input("Press Enter to close the window...")

    return plt


def make_3d_density_map(output_file="EDens.vasp", Z=1.0, n=1, l=0, m=0, lx=10.0, ly=10.0, lz=10.0, step=0.2):
    """Write wf.cal_wave_function values on a 3D grid as a VASP CHGCAR-style file.

    The file starts with a POSCAR-like header: an orthorhombic lx x ly x lz cell
    written with scale factor lx, and one atom at fractional (0.5, 0.5, 0.5).
    The grid dimensions follow, then the grid values, 10 per line. Values are
    written with the x index running fastest, then y, then z, which is the
    ordering used by VASP volumetric files.

    Args:
        output_file: Output path; every "{Orb}" is replaced by the orbital name.
        Z: Passed to wf.cal_wave_function (presumably the nuclear charge).
        n, l, m: Quantum numbers.
        lx, ly, lz: Cell edge lengths.
        step: Grid spacing; the mesh size along x is int(lx / step + 0.1),
            and likewise for y and z.

    Returns:
        None on success, -1 if the output file cannot be opened, or -2 if
        wf.cal_wave_function returns None (invalid n, l, m). On -2 the partially
        written file is closed and left on disk.
    """
    print("\nCalc 3D Electron Density map\n")

    # NOTE(review): the rest of this file calls wf.get_orb_name(n, l, m);
    # confirm that wf also provides GetOrbName.
    OrbName = wf.GetOrbName(n, l, m)
    print(f"  Z={Z}")
    print(f"  n,l,m={n},{l},{m}")
    print(f"  Orbital: {OrbName}")
    output_file = output_file.replace("{Orb}", OrbName)
    print(f"OutFile: [{output_file}]")

    try:
        out = open(output_file, "w")
    except OSError:  # IOError is an alias of OSError
        print(f"Error: Can not write to [{output_file}].")
        return -1

    # The context manager closes the file on every exit path, including the
    # early return below and any exception raised while computing values.
    with out:
        nx = int(lx / step + 0.1)
        ny = int(ly / step + 0.1)
        nz = int(lz / step + 0.1)
        print(f"Range: {lx} {ly} {lz}")
        print(f"nMesh: {nx} {ny} {nz}")

        # POSCAR-style header: comment, scale, lattice vectors (in units of the
        # scale), atom count, coordinate mode, atom position, blank, grid size.
        out.write(f"Electron density for Z={Z}, n={n}, l={l}, m={m}\n")
        out.write(f"   {lx:12.6f}\n")
        out.write(f"     {lx/lx:12.6f}   {0.0:12.6f}   {0.0:12.6f}\n")
        out.write(f"     {0.0:12.6f}   {ly/lx:12.6f}   {0.0:12.6f}\n")
        out.write(f"     {0.0:12.6f}   {0.0:12.6f}   {lz/lx:12.6f}\n")
        out.write("  1\n")
        out.write("Direct\n")
        out.write("  0.500000  0.500000  0.500000\n")
        out.write("\n")
        out.write(f"  {nx:3d}  {ny:3d}  {nz:3d}\n")

        # Volumetric data in VASP order: x fastest, then y, then z.
        count = 0
        for iz in range(nz):
            z = iz * step
            for iy in range(ny):
                y = iy * step
                for ix in range(nx):
                    x = ix * step
                    # The orbital is centred at (0.5*lx, 0.5*ly, 0.5*lz).
                    f = wf.cal_wave_function(Z, n, l, m, x, y, z, 0.5*lx, 0.5*ly, 0.5*lz)
                    if f is None:
                        print(f"Invalid n,l,m({n},{l},{m}).")
                        return -2

                    out.write(f" {f:11.5g}")
                    count += 1
                    if count % 10 == 0:
                        out.write("\n")

        # Terminate the last, partially filled data line.
        if count % 10:
            out.write("\n")

# Plain Monte Carlo integration over a rectangular box (uniform sampling;
# no rejection step is involved).
def integ_MC3d(f, x_min, x_max, y_min, y_max, z_min, z_max, n_samples):
    """Estimate the integral of f over an axis-aligned box by plain Monte Carlo.

    Draws n_samples points uniformly in [x_min, x_max) x [y_min, y_max) x
    [z_min, z_max) from numpy's global random state. Returns
    volume * mean(f(points)).

    Args:
        f: Vectorized callable f(x, y, z) taking three 1-D arrays of length
            n_samples and returning n_samples values.
        x_min, x_max, y_min, y_max, z_min, z_max: Box bounds.
        n_samples: Number of random points.

    Returns:
        The integral estimate (a numpy scalar).

    Raises:
        ValueError: If n_samples < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Draw x, y, z in this fixed order so that a given np.random.seed
    # reproduces the same estimate.
    x_samples = np.random.uniform(x_min, x_max, n_samples)
    y_samples = np.random.uniform(y_min, y_max, n_samples)
    z_samples = np.random.uniform(z_min, z_max, n_samples)

    values = f(x_samples, y_samples, z_samples)
    volume = (x_max - x_min) * (y_max - y_min) * (z_max - z_min)
    # np.mean reduces in C with pairwise summation, whereas the builtin sum()
    # walks the array element by element in Python (slow for 10**6 samples).
    return volume * np.mean(values)

def normalize():
    """Print numerical normalization checks for the hydrogen wavefunctions.

    Part 1 ("Rnl:"): for n = 1..4 and l = 0..n-1, integrates
    4*pi*r**2 * |wf.psi_r(r, 0.0, 0.0, n, l, 0)|**2 over r in [0, inf) with
    scipy.integrate.quad. It prints the value and quad's error estimate.

    Part 2 ("psi:"): for the same (n, l) with m = 0, prints wf.En(n). It then
    prints a 1,000,000-sample Monte Carlo estimate (integ_MC3d) of the integral
    of |wf.psi_xyz|**2 over the cube [-a, a]**3, where a = 5.0 for n = 1 and
    20.0 otherwise.
    """
    print("\nCalculate |psi(x,y,z)|^2 to check normalization\n")

    # NOTE(review): psi_r is evaluated at the fixed arguments (0.0, 0.0),
    # presumably angles, and the angular integral is replaced by 4*pi. That
    # gives the true norm only if |psi|^2 is angle-independent (l = 0). For
    # l > 0 the printed S may differ from 1. Confirm whether a purely radial
    # integral (e.g. of wf.Rnl) was intended.
    print("Rnl:")
    for n_q in range(1, 5):
        for l_q in range(n_q):
            def radial(r, n_q=n_q, l_q=l_q):
                return 4 * pi * r**2 * abs(wf.psi_r(r, 0.0, 0.0, n_q, l_q, 0))**2

            S, S_err = quad(radial, 0, np.inf)
            print(f"  n,l={n_q},{l_q}: S={S} +- {S_err}")

    n_draws = 1000000
    print()
    print("psi:")
    for n_q in range(1, 5):
        # Half-width of the integration cube; the n = 1 orbital is compact.
        half = 5.0 if n_q == 1 else 20.0
        for l_q in range(n_q):
            print(f"Orbital: {n_q}{l_q}0: En={wf.En(n_q):.6g} eV", end = '')

            def density(x, y, z, n_q=n_q, l_q=l_q):
                psi = wf.psi_xyz(x, y, z, n_q, l_q, 0)
                return psi.real**2 + psi.imag**2

            total = integ_MC3d(density, -half, half, -half, half, -half, half, n_draws)
            print(f"  |psi|^2={total}")

def main(mode):
    """Run the plotting or analysis routine selected by `mode`.

    Each mode calls the matching plot_* function of this file ('Ylm' and
    'Theta' both call plot_Ylm). 'normalize' prints the normalization checks
    and exits the process. An unknown mode prints an error and exits with
    status 1.

    If the selected routine returns a truthy object (the pyplot module), the
    figure is handled according to the module-level flags:
    - if `fsave` is set, it is saved to plot_wf_<mode>_<functype>.png;
    - if `fplot` is set, it is shown.
    """
    if mode == 'normalize':
        normalize()
        sys.exit()

    # Each entry is a lambda, so a plot_* name is only looked up when its mode
    # is selected, just as in an if/elif chain.
    dispatch = {
        'torus': lambda: plot_torus(),
        'Rnl': lambda: plot_Rnl(),
        'Ylm': lambda: plot_Ylm(),
        'Theta': lambda: plot_Ylm(),
        'Phi': lambda: plot_Phi(),
        'rplot': lambda: plot_rplot(),
        'ranim': lambda: plot_ranim(),
        'dot': lambda: plot_dot(plot_contours = False),
        'dotc': lambda: plot_dot(plot_contours = True),
        'dotanim': lambda: plot_dotanim(plot_contours = False),
        'iso3d': lambda: plot_iso3d(),
        'iso3danim': lambda: plot_iso3danim(),
        'rsurface': lambda: plot_rsurface(),
        'isosurface3d': lambda: plot_isosurface3d(),
        'contour3d': lambda: plot_contour3d(),
        '3dMC_animation': lambda: plot_3dMC_animation(),
    }

    action = dispatch.get(mode)
    if action is None:
        print(f"\nError: Invalid mode [{mode}]")
        # Non-zero status so shells and scripts can detect the failure.
        sys.exit(1)

    # Named `result` so the pyplot module name `plt` is not shadowed.
    result = action()
    if not result:
        return

    if fsave:
        output_file = f"plot_wf_{mode}_{functype}.png"
        print()
        print(f"Save plot to {output_file}")
        result.savefig(output_file)
    if fplot: result.show()

if __name__ == "__main__":
    # Script entry point: run the routine for the module-level `mode`, then
    # print the usage examples.
    # NOTE(review): usage() is not defined in the visible part of this file.
    # It is only reached when main() returns normally; main() exits the process
    # for 'normalize' and for an invalid mode. Confirm that usage() exists.
    main(mode)
    usage()


# Disabled snippet (histograms of sampled r, theta and phi), kept for reference.
# It is a bare string literal, so it is never executed. The names r_sampled,
# theta and phi are not defined at module level in the visible code.
'''
# ヒストグラムの表示
    if 'h' in mode:
        fig, axes = plt.subplots(2, 2)
        axes[0, 0].hist(r_sampled, bins=100, color='blue', alpha=0.7, density=True)
        axes[0, 0].set_xlabel('r')
        axes[0, 0].set_ylabel('Density')
        axes[1, 0].hist(theta, bins=100, color='blue', alpha=0.7, density=True)  # 修正
        axes[1, 0].set_xlabel('θ')
        axes[1, 0].set_ylabel('Density')
        axes[0, 1].hist(phi, bins=100, color='blue', alpha=0.7, density=True)  # 修正
        axes[0, 1].set_xlabel('φ')
        axes[0, 1].set_ylabel('Density')

        plt.tight_layout()
        plt.show()
        exit()
'''