import os
import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import simpson

# Configure the tkvasp library location before it is imported below.
# NOTE: the install root is hard-coded for one machine; adjust it when the
# script is run from a different checkout.
_TKPROG_ROOT = 'd:/git/tkProg'
_TKLIB_PYTHON_DIR = _TKPROG_ROOT + '/tklib/python'

os.environ['tkProg_Root'] = _TKPROG_ROOT
os.environ['tklib'] = _TKLIB_PYTHON_DIR
sys.path.append(_TKLIB_PYTHON_DIR)

from tklib.tkcrystal.tkvasp import tkVASP

def fermi_dirac(E, Ef, T):
    """
    Evaluate the Fermi-Dirac occupation 1 / (exp((E - Ef) / (k_B T)) + 1).

    The exponent is capped from above before exponentiation so that states
    far above the Fermi level do not overflow ``exp``; they simply evaluate
    to a vanishingly small occupation.

    Args:
        E: Energy or array of energies (eV).
        Ef: Fermi energy (eV).
        T: Temperature (K).

    Returns:
        Occupation probability, with the broadcast shape of ``E`` and ``Ef``.
    """
    boltzmann_ev_per_k = 8.617333262e-5  # Boltzmann constant (eV/K)
    max_exponent = 700.0  # upper cap applied to the exponent

    reduced_energy = (E - Ef) / (boltzmann_ev_per_k * T)

    # Only the upper side needs a cap: for very negative exponents exp()
    # tends to 0 and the occupation tends to 1 without any overflow.
    capped = np.minimum(reduced_energy, max_exponent)

    return 1.0 / (1.0 + np.exp(capped))

def calculate_concentrations(dos_info, fermi_energy_list, T):
    """
    Integrate occupied and empty states for each trial Fermi energy.

    For every Fermi energy the occupied-state integral ``∫ D(E) f(E) dE`` and
    the empty-state integral ``∫ D(E) (1 - f(E)) dE`` are evaluated with
    Simpson's rule over the *entire* energy grid in ``dos_info``. No band-edge
    window is applied here, so a caller that wants conduction-band electrons
    or valence-band holes only must zero the DOS outside the band of interest
    before calling.

    Args:
        dos_info: Mapping with keys ``'E'`` (energy grid) and ``'TotalDOS'``
            (total DOS on that grid). The inputs are copied, not modified.
        fermi_energy_list: Iterable of Fermi energies to evaluate.
        T: Temperature (K).

    Returns:
        Tuple ``(electron_concentrations, hole_concentrations)`` of 1-D
        arrays, one entry per Fermi energy, in the DOS units times eV.
    """
    energy = np.array(dos_info['E'])
    total_dos = np.array(dos_info['TotalDOS'])
    # Negative DOS values are treated as zero.
    total_dos[total_dos < 0] = 0.0

    electron_concentrations = []
    hole_concentrations = []

    for Ef in fermi_energy_list:
        # Electrons: states weighted by the occupation f(E).
        occupied = fermi_dirac(energy, Ef, T)
        electron_concentrations.append(simpson(total_dos * occupied, x=energy))

        # Holes: states weighted by 1 - f(E). Computing "1.0 - f" rounds to
        # exactly 0 for levels far below Ef, so use the identity
        # 1 - f(E; Ef) = 1 / (exp((Ef - E) / kT) + 1), i.e. the same function
        # with the roles of E and Ef swapped.
        empty = fermi_dirac(Ef, energy, T)
        hole_concentrations.append(simpson(total_dos * empty, x=energy))

    return np.array(electron_concentrations), np.array(hole_concentrations)

def calculate_seebeck_coefficient(dos_info, fermi_energy_list, T):
    """
    Estimate the Seebeck coefficient in the constant relaxation time
    approximation (CRTA) for each trial Fermi energy.

    The total DOS is used as the energy-dependent weight of the transport
    integrals

        L0 = ∫ D(E) (df/dE) dE
        L1 = ∫ D(E) (E - Ef) (df/dE) dE

    and the Seebeck coefficient is ``S = -(1 / (e T)) * L1 / L0``. Because the
    energies are in eV, ``L1 / L0`` divided by ``e`` is numerically a voltage,
    so ``S`` in V/K is ``-(L1 / L0) / T``. The sign convention gives S < 0
    when transport is dominated by states above the Fermi level.

    A table of Fermi energy, occupied/empty state integrals (over the whole
    energy grid) and Seebeck coefficient is printed to stdout as a side
    effect.

    Args:
        dos_info: Mapping with keys ``'E'`` (energy grid, eV) and
            ``'TotalDOS'`` (total DOS on that grid). Inputs are copied.
        fermi_energy_list: Iterable of Fermi energies (eV).
        T: Temperature (K).

    Returns:
        1-D array of Seebeck coefficients in μV/K, one per Fermi energy.
        Entries are 0.0 where L0 is exactly zero.
    """
    energy = np.array(dos_info['E'])
    total_dos = np.array(dos_info['TotalDOS'])
    # Negative DOS values are treated as zero.
    total_dos[total_dos < 0] = 0.0

    k_B = 8.617333262e-5  # Boltzmann constant (eV/K)
    kT = k_B * T

    seebeck_coefficients = []

    print(f"{'E_F (eV)':>12} | {'Ne (cm^-3)':>12} | {'Nh (cm^-3)':>12} | {'S (uV/K)':>12}")
    print("-" * 55)

    for Ef in fermi_energy_list:
        arg = (energy - Ef) / kT

        # df/dE = -(1/kT) * exp(x) / (exp(x) + 1)^2 is even in x, so it can be
        # written with exp(-|x|), which never overflows and decays smoothly
        # to zero on both sides of the Fermi level.
        decay = np.exp(-np.abs(arg))
        f_prime = -decay / (kT * (1.0 + decay) ** 2)

        # Transport integrals L0 and L1.
        L0 = simpson(f_prime * total_dos, x=energy)
        L1 = simpson(f_prime * (energy - Ef) * total_dos, x=energy)

        # L1 / L0 is the mean carrier energy relative to Ef in eV; dividing
        # by T (and the unit charge) gives V/K.
        if L0 != 0:
            seebeck = -(L1 / L0) / T
        else:
            seebeck = 0.0

        # Convert V/K to μV/K.
        seebeck_uv = seebeck * 1.0e6
        seebeck_coefficients.append(seebeck_uv)

        # Also report the occupied / empty state integrals on the console.
        # 1 - f(E) is evaluated as the Fermi function with E and Ef swapped
        # to avoid the cancellation in "1.0 - f" far below the Fermi level.
        Ne = simpson(fermi_dirac(energy, Ef, T) * total_dos, x=energy)
        Nh = simpson(fermi_dirac(Ef, energy, T) * total_dos, x=energy)
        print(f"{Ef:12.6f} | {Ne:12.6g} | {Nh:12.6g} | {seebeck_uv:12.6f}")

    return np.array(seebeck_coefficients)

def _plot_dos(dos_energy, dos_data, EVBM, ECBM, fermi_level, NELECT):
    """Plot the total DOS (left axis) and its running integral (right axis)."""
    _, ax1 = plt.subplots(figsize=(10, 6))

    # Total DOS on the left y-axis.
    color = 'tab:blue'
    ax1.set_xlabel('Energy (eV)')
    ax1.set_ylabel('DOS (states/cm$^3$/eV)', color=color)
    ax1.plot(dos_energy, dos_data, color=color, label='Total DOS')
    ax1.fill_between(dos_energy, 0, dos_data, color='gray', alpha=0.3)
    ax1.tick_params(axis='y', labelcolor=color)

    # Running integral of the DOS (cumulative trapezoidal rule), starting at 0
    # on the first grid point so it has the same length as the energy grid.
    integrated_dos = np.cumsum((dos_data[:-1] + dos_data[1:]) / 2.0 * np.diff(dos_energy))
    integrated_dos = np.insert(integrated_dos, 0, 0.0)

    # Integrated DOS on the right y-axis.
    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('Integrated Electron Count (states/cm$^3$)', color=color)
    ax2.plot(dos_energy, integrated_dos, color=color, linestyle='--', label='Integrated DOS')
    ax2.tick_params(axis='y', labelcolor=color)

    # Reference lines for the band edges and the Fermi level.
    ax1.axvline(x=EVBM, color='blue', linestyle='--', label=f'EVBM ({EVBM:.3f} eV)')
    ax1.axvline(x=ECBM, color='red', linestyle='--', label=f'ECBM ({ECBM:.3f} eV)')
    ax1.axvline(x=fermi_level, color='green', linestyle=':', label=f'Fermi Level ({fermi_level:.3f} eV)')
    if NELECT is not None:
        # NOTE(review): NELECT is drawn on the integrated-DOS axis, which is
        # labelled in states/cm^3 — confirm both quantities share units.
        ax2.axhline(y=NELECT, color='purple', linestyle='-.', label=f'NELECT ({NELECT})')

    plt.title('Total Density of States and Integrated Electron Count')
    ax1.grid(True)
    ax1.legend(loc='upper left')
    ax2.legend(loc='upper right')
    plt.show()


def _plot_carrier_concentrations(dos_energy, dos_data, EVBM, ECBM, T):
    """Plot conduction-band electron and valence-band hole concentrations vs. Fermi level."""
    energy = np.asarray(dos_energy, dtype=float)
    dos = np.asarray(dos_data, dtype=float)
    fermi_level_range = np.linspace(EVBM - 2.0, ECBM + 2.0, 200)

    # calculate_concentrations() integrates over whatever DOS it is given, so
    # restrict the DOS to one band at a time: electrons are counted only in
    # states at or above ECBM, holes only in states at or below EVBM.
    # Without this the "electron" curve would include every occupied valence
    # state and the "hole" curve every empty conduction state.
    conduction_states = {'E': energy, 'TotalDOS': np.where(energy >= ECBM, dos, 0.0)}
    valence_states = {'E': energy, 'TotalDOS': np.where(energy <= EVBM, dos, 0.0)}
    e_conc, _ = calculate_concentrations(conduction_states, fermi_level_range, T)
    _, h_conc = calculate_concentrations(valence_states, fermi_level_range, T)

    plt.figure(figsize=(10, 6))
    plt.plot(fermi_level_range, e_conc, label='Electron Concentration ($n$)', color='red')
    plt.plot(fermi_level_range, h_conc, label='Hole Concentration ($p$)', color='blue')
    plt.axvline(x=(EVBM + ECBM) / 2.0, color='gray', linestyle='--', label='Midgap')

    plt.title(f'Carrier Concentration vs. Fermi Level (T = {T} K)')
    plt.xlabel('Fermi Energy ($E_F$) relative to $E_{VBM}$')
    plt.ylabel('Concentration (cm$^{-3}$)')
    plt.yscale('log')  # concentrations span many orders of magnitude
    plt.grid(True, which="both", linestyle=':')
    plt.legend()
    plt.show()


def _plot_seebeck(doscar_inf, EVBM, ECBM, T):
    """Plot the Seebeck coefficient vs. Fermi level."""
    fermi_level_range = np.linspace(EVBM - 2.0, ECBM + 2.0, 200)
    seebeck_coefficients = calculate_seebeck_coefficient(doscar_inf, fermi_level_range, T)

    plt.figure(figsize=(10, 6))
    plt.plot(fermi_level_range, seebeck_coefficients, color='green')
    plt.axvline(x=EVBM, color='blue', linestyle='--', label=f'EVBM ({EVBM:.3f} eV)')
    plt.axvline(x=ECBM, color='red', linestyle='--', label=f'ECBM ({ECBM:.3f} eV)')
    plt.axhline(y=0.0, color='black', linestyle='-')

    plt.title(f'Seebeck Coefficient vs. Fermi Level (T = {T} K)')
    plt.xlabel('Fermi Energy ($E_F$) relative to $E_{VBM}$')
    plt.ylabel('Seebeck Coefficient ($μV/K$)')
    plt.grid(True)
    plt.legend()
    plt.show()


def main():
    """
    Command-line entry point.

    Reads POSCAR, DOSCAR and (optionally) OUTCAR from the given directory,
    prints the band edges found from the DOS, and shows one of three plots
    selected by ``--mode``: the total DOS with its running integral (``dos``),
    carrier concentrations vs. Fermi level (``n``), or the Seebeck coefficient
    vs. Fermi level (``seebeck``). Exits with status 1 when POSCAR/DOSCAR are
    missing or the DOSCAR could not be read.
    """
    parser = argparse.ArgumentParser(description='Plot DOS or carrier concentrations from VASP DOSCAR.')
    parser.add_argument('path', type=str, help='Path to VASP calculation directory (e.g., ./vasp_run_dir/)')
    parser.add_argument('--mode', type=str, default='dos', choices=['dos', 'n', 'seebeck'], help='Plot mode: "dos" for total DOS, "n" for carrier concentrations vs. Fermi level, "seebeck" for Seebeck coefficient.')
    parser.add_argument('--T', type=float, default=300.0, help='Temperature in Kelvin for calculation (default: 300K).')
    args = parser.parse_args()

    # The Fermi-function based modes divide by k_B * T, so reject T <= 0 (and
    # NaN) up front instead of producing NaNs or a ZeroDivisionError later.
    if args.mode != 'dos' and not (args.T > 0.0):
        parser.error('--T must be a positive temperature in Kelvin.')

    # Input file locations.
    base_path = args.path
    poscar_path = os.path.join(base_path, 'POSCAR')
    doscar_path = os.path.join(base_path, 'DOSCAR')
    outcar_path = os.path.join(base_path, 'OUTCAR')

    if not all(os.path.exists(p) for p in [poscar_path, doscar_path]):
        print(f"Error: POSCAR or DOSCAR not found in directory '{base_path}'")
        sys.exit(1)

    # Read the VASP files.
    vasp_reader = tkVASP()
    # The POSCAR result is not used directly below; the call is kept because
    # the reader may rely on it for the '/cm3' DOS normalisation — TODO confirm.
    poscar_inf = vasp_reader.read_poscar_inf(poscar_path)
    doscar_inf = vasp_reader.read_doscar(doscar_path, normalize_E=True, unit='/cm3')

    outcar_inf = {}
    if os.path.exists(outcar_path):
        outcar_inf = vasp_reader.read_outcar_inf(outcar_path)
    else:
        print(f"Warning: OUTCAR not found in directory '{base_path}'. NELECT will be unavailable.")

    if not doscar_inf or 'E' not in doscar_inf:
        print("Error: Failed to read DOSCAR data.")
        sys.exit(1)

    # Total electron count from OUTCAR (None when OUTCAR is unavailable).
    NELECT = outcar_inf.get('NELECT', None)

    # Locate the band edges from the DOS with tkVASP.find_band_edges_from_dos().
    dos_energy = doscar_inf['E']
    dos_data = doscar_inf['TotalDOS']
    EF_vasp = doscar_inf['Efermi']
    dos_threshold = 1.0e18

    EVBM, ECBM = vasp_reader.find_band_edges_from_dos(dos_energy, dos_data, EF0=EF_vasp, DOSth=dos_threshold)

    print(f"EVBM: {EVBM:.3f} eV")
    print(f"ECBM: {ECBM:.3f} eV")
    print(f"Band Gap: {ECBM - EVBM:.3f} eV")
    if NELECT is not None:
        print(f"NELECT: {NELECT}")

    # Dispatch to the requested plot.
    if args.mode == 'dos':
        # NOTE(review): the Fermi level drawn in this plot is read from the
        # 'EF' key while the band-edge search above uses 'Efermi' — confirm
        # both keys exist in the reader output and which one matches the
        # plotted energy axis.
        _plot_dos(dos_energy, dos_data, EVBM, ECBM, doscar_inf['EF'], NELECT)
    elif args.mode == 'n':
        _plot_carrier_concentrations(dos_energy, dos_data, EVBM, ECBM, args.T)
    elif args.mode == 'seebeck':
        _plot_seebeck(doscar_inf, EVBM, ECBM, args.T)


if __name__ == "__main__":
    main()