import os
import sys
from math import sqrt
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd


# Physical constants shared by the junction model (SI units).
e0: float = 8.85e-12      # permittivity of free space, F/m
e: float = 1.602e-19      # elementary charge, C
pi: float = 3.1415926535  # approximation of pi


class tkPNJunction:
    """One-dimensional p-n junction built from two finite, uniformly doped layers.

    Parameters
    ----------
    ND, NA : donor / acceptor densities of the n and p layers in cm^-3
        (converted to m^-3 internally with a factor 1e6).
    epsn, epsp : relative permittivities of the n and p layers (scaled by e0).
    dn, dp : thicknesses of the n and p layers in m; a depletion width that
        would exceed its layer thickness is clamped to that thickness.

    Widths are returned in m, potentials in V and capacitances per unit area
    in F/m^2.  Uses the module-level constants ``e0`` and ``e``.
    """

    def __init__(self, ND, epsn, dn, NA, epsp, dp):
        self.ND = ND
        self.epsn = epsn
        self.dn = dn
        self.NA = NA
        self.epsp = epsp
        self.dp = dp

    def full_depletion(self, Wn, Wp, V, Vbi):
        """Return the junction state when both layers are fully depleted.

        The incoming ``Wn``/``Wp`` are ignored: the widths are pinned to the
        layer thicknesses, and ``Vbi - V`` is divided between the layers like
        a series capacitor divider (V_k = (Vbi - V) * C / C_k).

        Returns (Wn, Wp, Vn, Vp, Cn, Cp, C).
        """
        Wn = self.dn
        Wp = self.dp
        Cn, Cp, C = self.cal_C(Wn, Wp)
        Vn = (Vbi - V) * C / Cn
        Vp = (Vbi - V) * C / Cp
        return Wn, Wp, Vn, Vp, Cn, Cp, C

    def cal_depletion_WV(self, V, Vbi):
        """Return the depletion widths and potential drops (Wn, Wp, Vn, Vp) at bias V.

        First solves the unbounded abrupt junction.  Charge neutrality
        (ND*Wn = NA*Wp) and Vn + Vp = Vbi - V give

            Wn^2 = 2*eps_n*eps_p*NA*(Vbi - V) / (e*ND*(eps_n*ND + eps_p*NA))

        and the symmetric expression for Wp.  When epsn == epsp this is the
        textbook homojunction result.

        If one width exceeds its layer thickness, that width is clamped.  The
        clamped layer's drop e*N*W^2/(2*eps) is removed from Vbi - V, and the
        remainder sets the other side's width.  If both layers end up
        exhausted, ``full_depletion`` provides the result.

        Raises
        ------
        ValueError
            If V > Vbi, since no real depletion width exists.
        """
        Vdrop = Vbi - V
        if Vdrop < 0.0:
            raise ValueError(
                f"bias V={V} exceeds the built-in potential Vbi={Vbi}")

        ND = self.ND * 1.0e6      # cm^-3 -> m^-3
        NA = self.NA * 1.0e6      # cm^-3 -> m^-3
        eps_n = self.epsn * e0    # F/m
        eps_p = self.epsp * e0    # F/m

        # Both permittivities enter each width.  Using only the local one
        # would violate ND*Wn == NA*Wp whenever epsn != epsp.
        denom = e * (eps_n * ND + eps_p * NA)
        Wn = sqrt(2.0 * eps_n * eps_p * NA * Vdrop / (ND * denom))  # m
        Wp = sqrt(2.0 * eps_n * eps_p * ND * Vdrop / (NA * denom))  # m

        if Wn > self.dn and Wp > self.dp:
            Wn, Wp, Vn, Vp, *_ = self.full_depletion(Wn, Wp, V, Vbi)
        elif Wn > self.dn:
            # n layer exhausted: the rest of the drop goes to the p side.
            Wn = self.dn
            Vn = e * ND * Wn * Wn / (2.0 * eps_n)
            Vp = Vdrop - Vn
            Wp = sqrt(2.0 * eps_p * Vp / (e * NA))
            if Wp > self.dp:
                Wn, Wp, Vn, Vp, *_ = self.full_depletion(Wn, Wp, V, Vbi)
        elif Wp > self.dp:
            # p layer exhausted: the rest of the drop goes to the n side.
            Wp = self.dp
            Vp = e * NA * Wp * Wp / (2.0 * eps_p)
            Vn = Vdrop - Vp
            Wn = sqrt(2.0 * eps_n * Vn / (e * ND))
            if Wn > self.dn:
                Wn, Wp, Vn, Vp, *_ = self.full_depletion(Wn, Wp, V, Vbi)
        else:
            Vn = e * ND * Wn * Wn / (2.0 * eps_n)
            Vp = e * NA * Wp * Wp / (2.0 * eps_p)

        return Wn, Wp, Vn, Vp

    def cal_C(self, Wn, Wp):
        """Return (Cn, Cp, Ctot): layer and series-total capacitance per area, F/m^2."""
        Cn = self.epsn * e0 / Wn     # F/m2
        Cp = self.epsp * e0 / Wp
        Ctot = 1.0 / (1.0 / Cn + 1.0 / Cp)
        return Cn, Cp, Ctot

    def cal_N_W(self, i, V, C2inv, Wn, Wp):
        """Estimate the doping density at sweep point ``i`` from the 1/C^2-V slope.

        Uses N = -2 / (e * eps_av * e0 * d(1/C^2)/dV).  The slope is a central
        difference inside the sweep and one-sided at either end, so at least
        two points are required.  eps_av is the series-averaged relative
        permittivity of the two depleted layers.

        Returns
        -------
        (W, N) : total depletion width Wn[i] + Wp[i] in m, and N in m^-3.
            N is 0.0 when the local slope is exactly zero.
        """
        nV = len(V)
        ip = min(i + 1, nV - 1)
        im = max(i - 1, 0)

        W = Wn[i] + Wp[i]
        # Series combination: W / eps_av = Wn / epsn + Wp / epsp.
        eps_av = W / (Wn[i] / self.epsn + Wp[i] / self.epsp)

        slope = (C2inv[ip] - C2inv[im]) / (V[ip] - V[im])
        if slope == 0.0:
            N = 0.0
        else:
            N = -2.0 / e / eps_av / e0 / slope   # m^-3
        return W, N

    def cal_depletion_layers(self, V, Vbi):
        """Solve the junction at bias V.

        Returns (Wn, Wp, W, Vn, Vp, Vtot, Cn, Cp, Ctot), where W = Wn + Wp,
        Vtot = Vn + Vp, and the capacitances are per unit area (F/m^2).
        """
        Wn, Wp, Vn, Vp = self.cal_depletion_WV(V, Vbi)
        Cn, Cp, Ctot = self.cal_C(Wn, Wp)
        return Wn, Wp, Wn + Wp, Vn, Vp, Vn + Vp, Cn, Cp, Ctot


def initialize():
    """Return a fresh parameter object holding the default run settings.

    Units follow the trailing comments: lengths in m, area in m^2, doping in
    cm^-3, permittivities relative to e0, voltages in V.
    """
    class tkParams():
        # Simple attribute bag: each keyword becomes an instance attribute.
        def __init__(self, **fields):
            for name, value in fields.items():
                setattr(self, name, value)

    return tkParams(
        mode='NW',
        # Device area and bias conditions.
        S=850e-6 * 850e-6,   # m2
        Vbi=1.0,             # V
        V=0.0,               # V
        # n side: thickness, donor density, relative permittivity.
        dn=3.7e-6,           # m
        ND=2.7e16,           # cm-3
        epsn=10,             # e0
        # p side: thickness, acceptor density, relative permittivity.
        dp=375e-9,           # m
        NA=5.0e18,           # cm-3
        epsp=10,             # e0
        # Bias sweep range and number of points.
        Vmin=-800,           # V
        Vmax=0,              # V
        nV=1001,
        outfile='CV.xlsx',
        # Plot styling.
        figsize=[8, 6],
        fontsize=14,
        legend_fontsize=12,
    )

def update_variables(cparams):
    """Apply command-line overrides to ``cparams`` in place.

    The first positional argument, if given, replaces ``cparams.mode``; any
    further arguments are ignored.
    """
    cli_args = sys.argv[1:]
    if cli_args:
        cparams.mode = cli_args[0]


def _sweep_cv(pnj, V, Vbi):
    """Solve the junction model at every bias in ``V``.

    Returns seven lists with one entry per bias point, (C, C2inv, Wn, Wp, W, Vn, Vp):
    total capacitance per unit area (F/m^2), 1/C^2, the n/p/total depletion
    widths (m), and the n/p potential drops (V).
    """
    C, C2inv, Wn, Wp, W, Vn, Vp = [], [], [], [], [], [], []
    for bias in V:
        _Wn, _Wp, _W, _Vn, _Vp, _, _, _, _C = pnj.cal_depletion_layers(bias, Vbi)
        C.append(_C)
        C2inv.append(1.0 / _C / _C)
        Wn.append(_Wn)
        Wp.append(_Wp)
        W.append(_W)
        Vn.append(_Vn)
        Vp.append(_Vp)
    return C, C2inv, Wn, Wp, W, Vn, Vp


def _print_cv_table(V, Vn, Vp, W, N, Wn, Wp, C, C2inv):
    """Print one row per bias point, with widths in nm and N in cm^-3."""
    print()
    # Column widths match the header labels so the table stays aligned.
    print(f"{'V':>10} {'Vn':>10} {'Vp':>10} {'W (nm)':>10} {'N(W) (cm-3)':>12} "
          f"{'Wn (nm)':>10} {'Wp (nm)':>10} {'C (F/m2)':>10} {'1/C^2 (m^4F^-2)':>16}")
    for v, vn, vp, w, n, wn, wp, c, c2inv in zip(V, Vn, Vp, W, N, Wn, Wp, C, C2inv):
        print(f"{v:10.4g} {vn:10.4g} {vp:10.4g} {w*1e9:10.4g} {n*1.0e-6:12.4g} "
              f"{wn*1e9:10.4g} {wp*1e9:10.4g} {c:10.4g} {c2inv:16.4g}")


def _plot_cv(cparams, V, C, C2inv, Wn, Wp, W, N):
    """Draw a 2x2 summary: 1/C^2, C, the depletion widths, and N versus V."""
    plt.rcParams['font.size'] = cparams.fontsize
    fig = plt.figure(figsize = (8, 8))
    ax1 = fig.add_subplot(2, 2, 1)
    ax2 = fig.add_subplot(2, 2, 2)
    ax3 = fig.add_subplot(2, 2, 3)
    ax4 = fig.add_subplot(2, 2, 4)

    # These are simulated per-unit-area values, hence m^4 F^-2 and F/m^2.
    ax1.plot(V, np.array(C2inv), label = '1/$C^2$', linestyle = '', marker = 'o', markersize = 5.0)
    ax1.set_xlabel('V (V)', fontsize = cparams.fontsize)
    ax1.set_ylabel('1/$C^2$ (m$^4$F$^{-2}$)', fontsize = cparams.fontsize)
    ax1.legend(fontsize = cparams.legend_fontsize)

    ax2.plot(V, np.array(C), label = 'C', linestyle = '', marker = 'o', markersize = 5.0)
    ax2.set_xlabel('V (V)', fontsize = cparams.fontsize)
    ax2.set_ylabel('C (F/m$^2$)', fontsize = cparams.fontsize)
    ax2.legend(fontsize = cparams.legend_fontsize)

    # Marker colour encodes which layers are fully depleted at each bias.
    colors = []
    for wn, wp in zip(Wn, Wp):
        if wn >= cparams.dn and wp >= cparams.dp:
            colors.append('black')   # both layers fully depleted
        elif wn >= cparams.dn:
            colors.append('red')     # only the n layer fully depleted
        elif wp >= cparams.dp:
            colors.append('green')   # only the p layer fully depleted
        else:
            colors.append('blue')    # neither layer fully depleted

    ax3.plot(V, np.array(Wn) * 1.0e9, label = '$W_n$ (nm)', color = 'red', linewidth = 0.5)
    ax3.scatter(V, np.array(Wn) * 1.0e9, color = colors, marker = '>', s = 3.0)
    ax3.plot(V, np.array(Wp) * 1.0e9, label = '$W_p$ (nm)', color = 'blue', linewidth = 0.5)
    ax3.scatter(V, np.array(Wp) * 1.0e9, color = colors, marker = '^', s = 3.0)
    ax3.plot(V, np.array(W) * 1.0e9,  label = '$W$ (nm)',   color = 'black', linewidth = 1.0)
    ax3.scatter(V, np.array(W) * 1.0e9, color = colors, marker = 'o', s = 6.0)
    ax3.set_xlabel('V (V)', fontsize = cparams.fontsize)
    ax3.set_ylabel('W (nm)', fontsize = cparams.fontsize)
    ax3.set_yscale('log')
    ax3.legend(fontsize = cparams.legend_fontsize)

    ax4.plot(V, np.array(N) * 1.0e-6, label = '$N$ (cm$^{-3}$)')
    ax4.set_xlabel('V (V)',        fontsize = cparams.fontsize)
    ax4.set_ylabel('N (cm$^{-3}$)', fontsize = cparams.fontsize)
    ax4.set_yscale('log')
    ax4.legend(fontsize = cparams.legend_fontsize)

    plt.tight_layout()


def main():
    """Simulate and display the C-V characteristics of the configured junction.

    Builds the parameters (with command-line overrides), sweeps the bias
    linearly from ``Vmin`` to ``Vmax`` in ``nV`` points, prints a table, and
    shows a 2x2 summary figure until the user presses ENTER.
    """
    cparams = initialize()
    update_variables(cparams)
    # NOTE(review): cparams.mode, S, outfile and figsize are set up but not
    # consulted by this routine.

    pnj = tkPNJunction(cparams.ND, cparams.epsn, cparams.dn,
                       cparams.NA, cparams.epsp, cparams.dp)

    nV = cparams.nV
    Vstep = (cparams.Vmax - cparams.Vmin) / (nV - 1)
    V = [cparams.Vmin + i * Vstep for i in range(nV)]

    C, C2inv, Wn, Wp, W, Vn, Vp = _sweep_cv(pnj, V, cparams.Vbi)
    # Doping profile from the 1/C^2 slope, via the model's own helper
    # (cal_N_W returns (W, N); only N is needed here).
    N = [pnj.cal_N_W(i, V, C2inv, Wn, Wp)[1] for i in range(nV)]

    _print_cv_table(V, Vn, Vp, W, N, Wn, Wp, C, C2inv)
    _plot_cv(cparams, V, C, C2inv, Wn, Wp, W, N)

    # Let the GUI draw the figure, then block until the user is done.
    plt.pause(0.0001)
    input("\nPress ENTER to terminate>>")


# Run the simulation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

