"""
ひとつ前の「移動度に対する散乱機構の寄与率」をプロンプトにコピペ
prompt
Q: 添付のスライドから、各散乱機構の寄与率をプロットする mode=weightを実装してください
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import argparse
import os
import json

# 定数: ボルツマン定数 (eV/K)
K_B = 8.617333262e-5

# ---------------------------------------------------------
# 1. データ入出力関連の関数
# ---------------------------------------------------------
def load_hall_data(file_path):
    """ファイルの存在確認を行い、データを読み込む"""
    if not os.path.exists(file_path):
        print(f"エラー: ファイル '{file_path}' が見つかりません。")
        return None
    if file_path.endswith('.csv'):
        return pd.read_csv(file_path)
    return pd.read_excel(file_path)

def save_params(params, filename='llsq_params.json'):
    """解析結果のパラメータをJSONで保存"""
    with open(filename, 'w') as f:
        json.dump(params, f, indent=4)
    print(f"パラメータを保存しました: {filename}")

def load_params(filename='llsq_params.json'):
    """保存されたパラメータを読み込む"""
    if not os.path.exists(filename):
        print(f"エラー: '{filename}' がありません。先に mode=llsq を実行してください。")
        return None
    with open(filename, 'r') as f:
        return json.load(f)

# ---------------------------------------------------------
# 2. 物理モデルと解析関連の関数
# ---------------------------------------------------------
def get_inv_mu_components(T, params, Eop):
    """各散乱機構の逆移動度 (1/mu) を個別に計算する"""
    aop, a1, a2, a3, VB = params
    
    # 各散乱機構の基底関数
    f_op = 1.0 / (np.exp(Eop / (K_B * T)) - 1.0) # 光学フォノン
    f_ac = T**1.5                                # 音響フォノン
    f_ni = np.ones_like(T)                       # 中性不純物
    f_ii = T**-1.5                               # イオン化不純物
    
    # バルク成分
    components = {
        'Optical Phonon': aop * f_op,
        'Acoustic Phonon': a1 * f_ac,
        'Neutral Impurity': a2 * f_ni,
        'Ionized Impurity': a3 * f_ii
    }
    
    # 粒界散乱の考慮: mu_total = mu_bulk * exp(-VB/kBT)
    # => 1/mu_total = (1/mu_bulk) * exp(VB/kBT)
    inv_mu_bulk = np.maximum(sum(components.values()), 1e-10)
    exp_factor = np.exp(VB / (K_B * T))
    inv_mu_total = inv_mu_bulk * exp_factor
    
    # 粒界による散乱頻度の増加分を抽出
    components['Grain Boundary'] = inv_mu_total - inv_mu_bulk
    
    return components, inv_mu_total

def solve_llsq(T, mu_exp, Eop):
    """線形最小二乗法で aop, a1, a2, a3 を推定 (VB=0)"""
    f_op = 1.0 / (np.exp(Eop / (K_B * T)) - 1.0)
    f_ac = T**1.5
    f_ni = np.ones_like(T)
    f_ii = T**-1.5
    X = np.column_stack([f_op, f_ac, f_ni, f_ii])
    y = 1.0 / mu_exp
    coeffs, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
    return coeffs

# ---------------------------------------------------------
# 3. 可視化関連の関数
# ---------------------------------------------------------
def visualize_fit(T, mu_exp, mu_fit=None, title='Hall Mobility Fit', save_name='plot.png'):
    """実験データとフィッティング曲線の比較プロット"""
    plt.figure(figsize=(8, 6))
    plt.scatter(T, mu_exp, color='red', label='Experimental', alpha=0.6)
    if mu_fit is not None:
        idx = np.argsort(T)
        plt.plot(T[idx], mu_fit[idx], color='blue', label='Model Fit', linewidth=2)
    plt.xlabel('Temperature (K)')
    plt.ylabel('Mobility (cm²/Vs)')
    plt.yscale('log')
    plt.title(title)
    plt.legend()
    plt.grid(True, which='both', alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_name)
    plt.show()

def visualize_weights(T, components, total, save_name='weight_plot.png'):
    """各散乱機構の寄与率を100%積み上げグラフで表示"""
    plt.figure(figsize=(10, 6))
    names = list(components.keys())
    # 寄与率(%)に変換
    weights = [components[name] / total * 100 for name in names]
    
    idx = np.argsort(T)
    plt.stackplot(T[idx], [w[idx] for w in weights], labels=names, alpha=0.8)
    
    plt.xlabel('Temperature (K)')
    plt.ylabel('Contribution to Scattering (%)')
    plt.title('Scattering Mechanism Weights')
    plt.xlim(min(T), max(T))
    plt.ylim(0, 100)
    plt.legend(loc='upper right', bbox_to_anchor=(1.25, 1))
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_name)
    plt.show()

# ---------------------------------------------------------
# 4. メイン処理
# ---------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description='Hall効果解析 統合ツール')
    parser.add_argument('--input', type=str, default='Hall-T1.xlsx', help='入力ファイル名')
    parser.add_argument('--temp_col', type=int, default=0, help='温度列(0開始)')
    parser.add_argument('--mu_col', type=int, default=2, help='移動度列(0開始)')
    parser.add_argument('--mode', type=str, choices=['read', 'llsq', 'fit', 'weight'], default='read')
    parser.add_argument('--eop', type=float, default=0.045, help='光学フォノンエネルギー(eV)')
    args = parser.parse_args()

    df = load_hall_data(args.input)
    if df is None: return

    T = df.iloc[:, args.temp_col].values
    mu_exp = df.iloc[:, args.mu_col].values

    if args.mode == 'read':
        print(df)
        visualize_fit(T, mu_exp, title='Experimental Data Only')

    elif args.mode == 'llsq':
        c = solve_llsq(T, mu_exp, args.eop)
        params_dict = {'aop': c[0], 'a1': c[1], 'a2': c[2], 'a3': c[3], 'VB': 0.0}
        save_params(params_dict)
        _, mu_fit = get_inv_mu_components(T, list(c) + [0.0], args.eop)
        visualize_fit(T, mu_exp, 1/mu_fit, title='LLSQ Initial Fit (VB=0)')

    elif args.mode == 'fit':
        p_base = load_params()
        if p_base is None: return
        init = [p_base['aop'], p_base['a1'], p_base['a2'], p_base['a3'], 0.0]
        
        def objective(p):
            _, inv_total = get_inv_mu_components(T, p, args.eop)
            return np.sum((mu_exp - 1/inv_total)**2)
        
        print("最適化中 (Nelder-Mead)...")
        res = minimize(objective, init, method='Nelder-Mead')
        labels = ['aop', 'a1', 'a2', 'a3', 'VB']
        final_params = {l: v for l, v in zip(labels, res.x)}
        save_params(final_params, filename='fit_params.json')
        
        print("\n--- 最適化パラメータ ---")
        for k, v in final_params.items(): print(f"{k:4s}: {v:.4e}")
        
        _, inv_fit = get_inv_mu_components(T, res.x, args.eop)
        visualize_fit(T, mu_exp, 1/inv_fit, title=f'Final Fit (VB={res.x[4]:.4e} eV)')

    elif args.mode == 'weight':
        # fit_params.json があれば優先、なければ llsq_params.json
        fname = 'fit_params.json' if os.path.exists('fit_params.json') else 'llsq_params.json'
        p_dict = load_params(fname)
        if p_dict is None: return
        p_list = [p_dict['aop'], p_dict['a1'], p_dict['a2'], p_dict['a3'], p_dict.get('VB', 0.0)]
        comp, total = get_inv_mu_components(T, p_list, args.eop)
        visualize_weights(T, comp, total)

if __name__ == '__main__':
    main()