import os
import sys
import argparse
from types import SimpleNamespace
import numpy as np
from numpy import sin, cos, tan, pi, exp, log, sqrt
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
import webbrowser
import tempfile


# Physical constants (SI units)
e0 = 8.8541878128e-12  # vacuum permittivity [F/m = C^2/(N m^2)] (CODATA 2018)
e  = 1.602176634e-19   # elementary charge [C] (exact in the 2019 SI)

# Not read by the functions visible in this file: simulate_IV_fet builds its
# own local output filename.  Kept in case external code references it.
outfile = None


def parse_arguments():
    """Parse the simulator's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option registered below.
        Units follow the help strings (lengths in m, NA in m^-3, mu in
        m^2/Vs); the simulator reports EV, EC, psi_S and psi_M in eV and the
        gate/drain sweep limits in V.  There is no explicit threshold-voltage
        option: the threshold is derived from the flat-band voltage and phi_F.
    """
    # (flag, type, default, extra add_argument keywords).  The tuple order is
    # the registration order, so ``--help`` lists the options as before.
    option_table = (
        # Simulation control
        ("--mode",   str,   "IV",           dict(choices=["IV"], help="Simulation mode")),
        ("--model",  str,   "fet",          dict(choices=["fet", "tft"], help="Device model")),
        ("--output", str,   "matplotlib",   dict(choices=["matplotlib", "plotly"], help="Output mode")),
        # Band edges (donor/acceptor level options ND, ED, EA are not implemented)
        ("--EV",     float, 0.0,            {}),
        ("--EC",     float, 1.12,           {}),
        ("--NA",     float, 1.0e16 * 1.0e6, dict(help="Acceptor density [m^-3]")),
        # Semiconductor / metal potentials
        ("--psi_S",  float, 0.2,            {}),
        ("--psi_M",  float, 0.5,            {}),
        # Semiconductor permittivity and channel mobility
        ("--ers",    float, 11.9,           {}),
        ("--mu",     float, 10.0e-4,        dict(help="Mobility [m^2/Vs]")),
        # Gate insulator
        ("--dg",     float, 100.0e-9,       dict(help="Gate insulator thickness [m]")),
        ("--erg",    float, 3.9,            {}),
        # Channel geometry
        ("--W",      float, 300.0e-6,       dict(help="Channel width [m]")),
        ("--L",      float, 50.0e-6,        dict(help="Channel length [m]")),
        # Bias sweeps
        ("--Vg_min", float, 0.0,            {}),
        ("--Vg_max", float, 20.0,           {}),
        ("--nVg",    int,   51,             {}),
        ("--Vd_min", float, 0.0,            {}),
        ("--Vd_max", float, 20.0,           {}),
        ("--nVd",    int,   51,             {}),
    )

    cli = argparse.ArgumentParser(description="TFT Simulation")
    for flag, value_type, default_value, extra in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, **extra)
    return cli.parse_args()

def save_IV(outfile, Vd_list, Vg_list, Id_VgVd_list):
    """Export the output characteristics to an Excel workbook.

    The sheet has one row per drain voltage: a ``Vd [V]`` column followed by
    one ``Id@Vg=<Vg> [A]`` column per gate voltage.

    Args:
        outfile: path handed to ``DataFrame.to_excel``.
        Vd_list: drain voltages, one per row.
        Vg_list: gate voltages, used only to label the current columns.
        Id_VgVd_list: currents indexed ``[i_vg][i_vd]`` (one row per gate
            voltage); transposed into columns here.
    """
    print(f"\nSave IV characteristics to {outfile}")

    # Header: the Vd column, then one current column per gate voltage.
    column_names = ["Vd [V]", *(f"Id@Vg={Vg:.2f} [A]" for Vg in Vg_list)]

    # zip(Vd, *rows) transposes the [i_vg][i_vd] table into one record per Vd.
    records = list(zip(Vd_list, *Id_VgVd_list))

    table = pd.DataFrame(records, columns=column_names)
    table.to_excel(outfile, index=False)

def plot_IV(Vd_list, Vg_list, Id_VdVg_list, Id_VgVd_list, Vp_list, Isat_list):
    """Show the transfer and output characteristics side by side (matplotlib).

    Left panel: Id vs Vg on a log axis, one curve per drain voltage; the first
    five Vd curves are drawn thicker in blue.  Right panel: Id vs Vd, one curve
    per gate voltage; every ``nskip``-th curve (about five in total) is
    highlighted in blue, and the pinch-off locus ``(Vp_list, Isat_list)`` is
    drawn as a dashed red line.

    Blocks until the user presses Enter, then closes the figure.

    Args:
        Vd_list, Vg_list: drain / gate voltage grids [V].
        Id_VdVg_list: currents indexed ``[i_vd][i_vg]`` [A].
        Id_VgVd_list: currents indexed ``[i_vg][i_vd]`` [A].
        Vp_list, Isat_list: pinch-off voltages [V] and the currents there [A].
    """
    print("\nPlot")
    fig, (ax_transfer, ax_output) = plt.subplots(1, 2, figsize=(8, 6), dpi=100, tight_layout=True)

    ax_transfer.set_title("Transfer characteristics")
    for i_d, Vd in enumerate(Vd_list):
        emphasized = i_d <= 4
        ax_transfer.plot(Vg_list, Id_VdVg_list[i_d], label=f"Vd = {Vd:.1f}",
                         color='blue' if emphasized else 'black',
                         linewidth=1.0 if emphasized else 0.5)
    ax_transfer.set_xlabel("Vg [V]")
    ax_transfer.set_ylabel("Id [A]")
    ax_transfer.set_yscale("log")

    ax_output.set_title("Output characteristics")
    ntext = 5
    # With fewer than ntext gate voltages the stride would round down to 0 and
    # the modulo below would raise ZeroDivisionError, so keep it at least 1.
    nskip = max(1, int(len(Vg_list) / ntext + 1.0e-4))
    for ig, Vg in enumerate(Vg_list):
        emphasized = ig % nskip == 0
        ax_output.plot(Vd_list, Id_VgVd_list[ig], label=f"Vg = {Vg:.1f}",
                       color='blue' if emphasized else 'black',
                       linewidth=1.0 if emphasized else 0.5)
    ax_output.plot(Vp_list, Isat_list, label="Vp", linestyle='dashed', color='red', linewidth=0.5)
    ax_output.set_xlabel("Vd [V]")
    ax_output.set_ylabel("Id [A]")

    # Draw without blocking the event loop, then wait for the user.
    plt.pause(0.1)
    input("\nPress Enter to terminate>>\n")
    plt.close(fig)

def plot_IV_plotly(Vd_list, Vg_list, Id_VdVg_list, Id_VgVd_list, Vp_list, Isat_list,
                   open_browser=False):
    """Build Plotly figures of the transfer and output characteristics.

    Args:
        Vd_list, Vg_list: drain / gate voltage grids [V].
        Id_VdVg_list: currents indexed ``[i_vd][i_vg]`` [A].
        Id_VgVd_list: currents indexed ``[i_vg][i_vd]`` [A].
        Vp_list, Isat_list: pinch-off voltages [V] and currents [A], drawn as
            a dashed red trace on the output plot.
        open_browser: if true, also write each figure to a standalone HTML
            file in the temp directory and open it in the default web browser.
            Those files are intentionally not deleted so the browser can load
            them.  Defaults to False, in which case no files are written.

    Returns:
        (html_transfer, html_output): HTML fragments (``full_html=False``)
        suitable for embedding the two figures in a page.
    """
    # Transfer characteristics: Id vs Vg on a log axis, one trace per Vd;
    # the first five drain voltages are drawn thicker in blue.
    fig = go.Figure()
    for i_d, Vd in enumerate(Vd_list):
        emphasized = i_d <= 4
        fig.add_trace(go.Scatter(x=Vg_list, y=Id_VdVg_list[i_d],
                                 mode='lines', name=f"Vd = {Vd:.1f}",
                                 line=dict(color='blue' if emphasized else 'black',
                                           width=1.0 if emphasized else 0.5)))

    fig.update_layout(title="Transfer characteristics",
                      xaxis_title="Vg [V]",
                      yaxis_title="Id [A]",
                      yaxis_type="log",
                      font=dict(size=16),
                      showlegend=False)  # hide the legend

    # Output characteristics: Id vs Vd, one trace per Vg, every nskip-th
    # gate voltage highlighted, plus the pinch-off locus.
    fig2 = go.Figure()
    ntext = 5
    # With fewer than ntext gate voltages the stride would round down to 0 and
    # the modulo below would raise ZeroDivisionError, so keep it at least 1.
    nskip = max(1, int(len(Vg_list) / ntext + 1.0e-4))

    for ig, Vg in enumerate(Vg_list):
        emphasized = ig % nskip == 0
        fig2.add_trace(go.Scatter(x=Vd_list, y=Id_VgVd_list[ig],
                                  mode='lines', name=f"Vg = {Vg:.1f}",
                                  line=dict(color='blue' if emphasized else 'black',
                                            width=1.0 if emphasized else 0.5)))

    fig2.add_trace(go.Scatter(x=Vp_list, y=Isat_list,
                              mode='lines', name="Vp",
                              line=dict(color='red', dash='dash', width=0.5)))

    fig2.update_layout(title="Output characteristics",
                       xaxis_title="Vd [V]",
                       yaxis_title="Id [A]",
                       font=dict(size=16),
                       showlegend=False)  # hide the legend

    # Embeddable HTML fragments returned to the caller.
    html_output1 = pio.to_html(fig, full_html=False)
    html_output2 = pio.to_html(fig2, full_html=False)

    # Standalone pages are only written when they will actually be opened;
    # previously two temp files were leaked on every call.
    if open_browser:
        from pathlib import Path

        for figure in (fig, fig2):
            with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as temp_file:
                temp_file.write(pio.to_html(figure, full_html=True).encode('utf-8'))
            webbrowser.open(Path(temp_file.name).as_uri())

    return html_output1, html_output2
    
def simulate_IV(cfg, calculate_only = False, print_level = 1):
    """Dispatch an I-V simulation to the model selected by ``cfg.model``.

    Args:
        cfg: simulation namespace; ``cfg.model`` must be ``'tft'`` or ``'fet'``.
        calculate_only: forwarded to the selected simulator.
        print_level: forwarded to the selected simulator.

    Returns:
        Whatever the selected simulator returns.

    Raises:
        SystemExit: with status 1 when ``cfg.model`` is not a known model.
    """
    if cfg.model == 'tft':
        return simulate_IV_tft(cfg, calculate_only = calculate_only, print_level = print_level)
    if cfg.model == 'fet':
        return simulate_IV_fet(cfg, calculate_only = calculate_only, print_level = print_level)

    # Exit with a failure status, matching main().  The bare site-provided
    # exit() used before reported success (status 0) and is not guaranteed to
    # exist (e.g. under ``python -S``).
    print(f"Error: Invalid model [{cfg.model}]")
    sys.exit(1)

def simulate_IV_fet(cfg, calculate_only = False, print_level = 1):
    """Compute (and optionally save/plot) the Id-Vd-Vg characteristics.

    Gradual-channel model.  With ``KI = Cg * mu * W / L`` and
    ``Vth_tft = VFB + 2 phi_F``:

        Id = KI * (Vg - Vth_tft - Vd/2) * Vd
             - KI * (2/3) * sqrt(2 ers e0 e NA) / Cg
               * ((2 phi_F + Vd)^1.5 - (2 phi_F)^1.5)

    The second (bulk-charge) term is included only when ``cfg.NA > 0``; with
    ``NA == 0`` this reduces to the TFT square-law model.  For drain voltages
    beyond the pinch-off voltage ``Vp`` the current is held at its value at
    ``Vp`` (saturation).  When ``Vp <= 0`` the channel is treated as cut off
    (Id = 0).

    Args:
        cfg: namespace providing EC, EV, EFi, psi_S, psi_M, phi_F, VFB,
            NA [m^-3], ers, erg, dg [m], mu [m^2/Vs], W and L [m],
            Vg_min/Vg_max/nVg, Vd_min/Vd_max/nVd, Vd_max, model and output.
        calculate_only: if true, return the computed arrays instead of saving
            them to ``<model>-IV.xlsx`` and plotting.
        print_level: if truthy, print per-bias diagnostic values.

    Returns:
        ``(Vd_list, Vg_list, Id_VdVg_list, Id_VgVd_list, Vp_list, Isat_list)``
        when ``calculate_only`` is true, otherwise None.  ``Id_VdVg_list`` has
        shape (nVd, nVg) and ``Id_VgVd_list`` shape (nVg, nVd).  ``Vp_list`` /
        ``Isat_list`` hold the (clamped, >= 0) pinch-off voltage and the
        current there for every Vg whose pinch-off voltage is <= Vd_max.
    """
    print("\n# FET simulator")

    Cg = cfg.erg * e0 / cfg.dg  # gate capacitance per area (F/m^2)
    Vth_tft = cfg.VFB + 2.0 * cfg.phi_F
    # Depletion-charge threshold shift; printed for reference only, it is not
    # used in the current expressions below.
    Vth_fet = sqrt(4.0 * e * cfg.ers * e0 * cfg.NA * cfg.phi_F) / Cg

    KI = Cg * cfg.mu * cfg.W / cfg.L
    KI2 = -KI * 2.0 / 3.0 * sqrt(2.0 * cfg.ers * e0 * e * cfg.NA) / Cg

    print(f"EC   : {cfg.EC} eV")
    print(f"EV   : {cfg.EV} eV")
    print(f"EFi  : {cfg.EFi} eV")
    print(f"psi_S: {cfg.psi_S} eV")
    print(f"psi_M: {cfg.psi_M} eV")
    print(f"phi_F: {cfg.phi_F} eV")
    print(f"NA   : {cfg.NA * 1.0e-6} cm^-3")
    print(f"Cg   : {Cg:.6g} [F/m^2]")
    print(f"VFB  : {cfg.VFB} eV")
    print(f"Vth (TFT)             : {Vth_tft} V")
    print(f"dVth (FET correction) : {Vth_fet} V")
    with_bulk = cfg.NA > 0.0
    if with_bulk:
        WMAX = sqrt(4.0 * cfg.ers * e0 * cfg.phi_F / e / cfg.NA)
        print(f"WMAX : {WMAX * 1.0e9} nm")
    print("KI :", KI)
    print("KI2:", KI2)

    two_phi_F = 2.0 * cfg.phi_F

    def drain_current(Vg, Vd):
        """Return ``(Id, bulk-charge correction)`` at gate Vg and effective drain Vd."""
        if with_bulk:
            dI = KI2 * (pow(two_phi_F + Vd, 1.5) - pow(two_phi_F, 1.5))
        else:
            dI = 0.0
        return KI * (Vg - Vth_tft - 0.5 * Vd) * Vd + dI, dI

    # eps nudges the sweep end point just past the requested maximum.
    eps = 1.0e-6
    Vd_list = np.linspace(cfg.Vd_min, cfg.Vd_max + eps, cfg.nVd)
    Vg_list = np.linspace(cfg.Vg_min, cfg.Vg_max + eps, cfg.nVg)
    Id_VdVg_list = np.zeros([cfg.nVd, cfg.nVg])
    Id_VgVd_list = np.zeros([cfg.nVg, cfg.nVd])
    Vp_list, Isat_list = [], []

    print("calculate Id-Vg-Vd list:")
    for ig, Vg in tqdm(enumerate(Vg_list), total=len(Vg_list)):
        Vp_tft = Vg - Vth_tft

        if with_bulk:
            sqrt_term = 1.0 + 2.0 * (Vg - cfg.VFB) * Cg * Cg / (cfg.ers * e0) / e / cfg.NA
            if sqrt_term >= 0.0:
                Vp_fet = cfg.ers * e0 * e * cfg.NA / Cg / Cg * (1.0 - sqrt(sqrt_term))
            else:
                # Gate far below flat band: the depletion solution does not
                # exist; NaN is mapped to cut-off just below.
                Vp_fet = np.nan
        else:
            Vp_fet = 0.0
        Vp = Vp_tft + Vp_fet

        # Below threshold (Vp <= 0, or NaN from above) the channel is cut off.
        # Evaluating the model at a negative Vd = Vp would instead give a
        # spurious positive current that grows as Vg decreases.
        Vsat = Vp if Vp > 0.0 else 0.0
        Isat, _ = drain_current(Vg, Vsat)

        if print_level:
            print(f"Vg={Vg}")
            print(f"  Vp(TFT): {Vp_tft} V = {Vg} - {Vth_tft}")
            print("    dVp(FET correction):", Vp_fet)

        if Vsat <= cfg.Vd_max:
            Vp_list.append(Vsat)
            Isat_list.append(Isat)

        for i_d, Vd in enumerate(Vd_list):
            # Beyond pinch-off the current saturates at its value at Vsat.
            I, I2 = drain_current(Vg, min(Vd, Vsat))

            if print_level:
                print("  I(fet):", I)
                print("  dI(fet correction):", I2)

            Id_VdVg_list[i_d, ig] = I
            Id_VgVd_list[ig, i_d] = I

    if calculate_only:
        return Vd_list, Vg_list, Id_VdVg_list, Id_VgVd_list, Vp_list, Isat_list

    outfile = f"{cfg.model}-IV.xlsx"
    save_IV(outfile, Vd_list, Vg_list, Id_VgVd_list)

    if cfg.output == 'matplotlib':
        plot_IV(Vd_list, Vg_list, Id_VdVg_list, Id_VgVd_list, Vp_list, Isat_list)
    else:
        plot_IV_plotly(Vd_list, Vg_list, Id_VdVg_list, Id_VgVd_list, Vp_list, Isat_list)


def simulate_IV_tft(cfg, calculate_only = False, print_level = 1):
    """Run the I-V simulation with the TFT model.

    The TFT model is the FET simulation with the bulk-charge term disabled,
    i.e. ``NA = 0``.  A shallow copy of ``cfg`` is modified so the caller's
    doping value survives (it used to be overwritten in place, silently
    breaking any later FET run with the same ``cfg``).

    Args:
        cfg: simulation namespace (see ``simulate_IV_fet``); not modified.
        calculate_only: forwarded to ``simulate_IV_fet``.
        print_level: forwarded to ``simulate_IV_fet``.

    Returns:
        Whatever ``simulate_IV_fet`` returns.
    """
    import copy

    print("\n# TFT simulator")
    tft_cfg = copy.copy(cfg)
    tft_cfg.NA = 0.0
    return simulate_IV_fet(tft_cfg, calculate_only = calculate_only, print_level = print_level)


def main():
    """Command-line entry point: parse options, derive potentials, run the mode.

    Derived values stored on the configuration namespace:
      Eg    = EC - EV
      EFi   = (EC + EV) / 2   (mid-gap level)
      Vbi   = psi_S - psi_M
      VFB   = psi_M - psi_S   (flat-band voltage used by the simulators)
      phi_F = EFi - psi_S
    Exits with status 1 for an unknown ``--mode``.
    """
    cfg = SimpleNamespace(**vars(parse_arguments()))

    # Quantities derived from the band-edge and potential options.
    cfg.Eg = cfg.EC - cfg.EV
    cfg.EFi = (cfg.EC + cfg.EV) / 2.0
    cfg.Vbi = cfg.psi_S - cfg.psi_M
    cfg.VFB = cfg.psi_M - cfg.psi_S
    cfg.phi_F = cfg.EFi - cfg.psi_S

    if cfg.mode != 'IV':
        print(f"Error: Unknown mode [{cfg.mode}]")
        sys.exit(1)

    simulate_IV(cfg)


# Run the simulator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
