import sys
from pprint import pprint
import numpy as np
from numpy import sqrt, exp, sin, cos, tan, pi
import numpy.linalg as LA 
import csv
from matplotlib import pyplot as plt

"""
Free electron band calculation
"""

#===================================
# physical constants
#===================================
pi   = 3.14159265358979323846
pi2  = 2.0 * pi
h    = 6.6260755e-34    # Js";
hbar = 1.05459e-34      # "Js";
c    = 2.99792458e8     # m/s";
e    = 1.60218e-19      # C";
e0   = 8.854418782e-12; # C<sup>2</sup>N<sup>-1</sup>m<sup>-2</sup>";
kB   = 1.380658e-23     # JK<sup>-1</sup>";
me   = 9.1093897e-31    # kg";
R    = 8.314462618      # J/K/mol
a0   = 5.29177e-11      # m";


#========================
# Crystal definition
#========================
# Si
a  = 5.4064             # angstrom, lattice parameter
rg = np.zeros([3, 3])
rg[0][0] = pow(2.0 * pi / a, 2) # reciprocal space metric in angstrom^-2
rg[1][1] = rg[0][0]
rg[2][2] = rg[0][0]

# 有効質量
meff = 1.0 # in me

# E(k) = KE * k*k, E(k) in eV, k in angstrom^1
KE = hbar * hbar / 2.0 / (meff * me) * 1.0e20 / e
print("KE = ", KE)

#========================
# Band plot parameters
#========================
# バンド構造をプロットするk点の軌跡: [kx, ky, kz, k点名称]
klist = [
#      [-0.5,  0.0, 0.0,  "-X"]
#    , [0.0,  0.0, 0.0,  "$\Gamma$"]
#    , [ 0.5,  0.0, 0.0,  "X"]
      [0.5,  0.0, 1.0, "W"]
    , [0.5,  0.5, 0.5,  "L"]
    , [0.0,  0.0, 0.0,  "$\Gamma$"]
    , [0.0,  0.0, 1.0,  "X"]
    , [0.5,  0.0, 1.0,  "W"]
    , [0.75, 0.0, 0.75, "K"]
    ]
# プロットするバンド構造E(k)のk点数の概数
nk = 101

# Ehkl(k)を計算するhkl範囲
hrange = [-3, 3]
krange = [-3, 3]
lrange = [-3, 3]
#hrange = [0, 0]
#krange = [0, 0]
#lrange = [0, 0]

# プロットするエネルギー範囲
Erange = [0.0, 10.0]    # eV


#===================================
# figure configuration
#===================================
figsize = (6, 8)
#figsize = (6, 4)
fontsize        = 16
legend_fontsize = 8


#==============================================
# fundamental functions
#==============================================
# 実数値に変換できない文字列をfloat()で変換するとエラーになってプログラムが終了する
# この関数は、変換できなかったらNoneを返すが、プログラムは終了させない
def pfloat(str):
    try:
        return float(str)
    except:
        return None

# pfloat()のint版
def pint(str):
    try:
        return int(str)
    except:
        return None

# 起動時引数を取得するsys.argリスト変数は、範囲外のindexを渡すとエラーになってプログラムが終了する
# egtarg()では、範囲外のindexを渡したときは、defvalを返す
def getarg(position, defval = None):
    try:
        return sys.argv[position]
    except:
        return defval

# 起動時引数を実数に変換して返す
def getfloatarg(position, defval = None):
    return pfloat(getarg(position, defval))

# 起動時引数を整数値に変換して返す
def getintarg(position, defval = None):
    return pint(getarg(position, defval))

def usage():
    print("")
    print("Usage:")
#    print("  python {}".format(sys.argv[0]))

def terminate(message = None):
    print("")
    if message is not None:
        print("")
        print(message)
        print("")

    usage()
    print("")
    exit()


# 逆格子のmetricsから、2点のk点間の距離を計算する
def cal_kdistance(rg, k0, k1):
    dkx = k1[0] - k0[0]
    dky = k1[1] - k0[1]
    dkz = k1[2] - k0[2]
    r2  = rg[0][0] * dkx*dkx + rg[1][1] * dky*dky + rg[2][2] * dkz*dkz
    r2 += 2.0 * (rg[0][1] * dkx*dky + rg[1][2] * dky*dkz + rg[2][0] * dky*dkx)

    return sqrt(r2)

# k点を与えて自由電子のエネルギーを計算する。eV単位
# k, Ghklは内部座標で与える
def cal_E(k, Ghkl):
    global rg

    kabs2  = rg[0][0] * (k[0] + Ghkl[0])**2 # in angstrom^-2
    kabs2 += rg[1][1] * (k[1] + Ghkl[1])**2
    kabs2 += rg[2][2] * (k[2] + Ghkl[2])**2
    E = KE * kabs2  # in eV
    return E

# プロットするk点リスト klistのk点間距離 dklist と
# 最初のk点からの距離の和 ktotal_list を計算する。
# また、バンド構造プロットのため、k点名のリスト ktotal_namelistを作る
def get_dklist(klist, nk):
    print("")

# list of k distances from the first k point of each k region
    dklist     = []
# list of k distance from the first k point of the k list
    ktotal = 0.0
    ktotal_list     = []
    ktotal_namelist = []
    ktotal_list.append(0.0)
    ktotal_namelist.append(klist[0][3])
    for i in range(1, len(klist)):
        print("k [{:<10s}: ({:6.4f}, {:6.4f}, {:6.4f}] to [{:<10s}:({:6.4f}, {:6.4f}, {:6.4f}]"
            .format(
                klist[i-1][3], klist[i-1][0], klist[i-1][1], klist[i-1][2], 
                klist[i][3],   klist[i][0],   klist[i][1],   klist[i][2]))

        dk_ = cal_kdistance(rg,
                [klist[i-1][0], klist[i-1][1], klist[i-1][2]], 
                [klist[i][0],   klist[i][1],   klist[i][2]] )
        ktotal += dk_
        dklist.append(dk_)
        ktotal_list.append(ktotal)
        ktotal_namelist.append(klist[i][3])
#        print("  dk={}  ktotal={}".format(dk_, ktotal))

    return dklist, ktotal_list, ktotal_namelist, ktotal

# プロットするk点リスト klistとk点数の概数 nk から、
# なるべくk点間隔が等間隔になるように、
# 計算するk点などをリストアップする
# バンド構造プロットに必要なリストも返す
def get_cal_klist(klist, nk):
    dklist, ktotallist, ktotal_namelist, ktotal = get_dklist(klist, nk)
    kstep = ktotal / nk

    nklist = []
    xk = []
    xkvec = []
    ktotal_ = 0.0
    for i in range(1, len(klist)):
        nk_ = int(dklist[i-1] / kstep + 1.00001)
        nklist.append(nk_)
        if i == len(klist) - 1:
            ndiv = nk_ - 1
        else: 
            ndiv = nk_    
        kstepx_ = (klist[i][0] - klist[i-1][0]) / ndiv
        kstepy_ = (klist[i][1] - klist[i-1][1]) / ndiv
        kstepz_ = (klist[i][2] - klist[i-1][2]) / ndiv

        print("k: {:<10s} - {:<10s}".format(klist[i-1][3], klist[i][3]), end = '')
        print(": nk={:3d} ktotal={:6.4f} kstep=({:6.4f}, {:6.4f}, {:6.4f})"
                .format(nklist[i-1], ktotal_, kstepx_, kstepy_, kstepz_))

        dk_ = cal_kdistance(rg, [0.0, 0.0, 0.0], [kstepx_,kstepy_, kstepz_])
        for j in range(nklist[i-1]):
            kx = klist[i-1][0] + j * kstepx_
            ky = klist[i-1][1] + j * kstepy_
            kz = klist[i-1][2] + j * kstepz_

            xk.append(ktotal_)
            xkvec.append([kx, ky, kz])
            ktotal_ += dk_

    res = {"dklist": dklist, "nklist": nklist, "ktotal": ktotal, "kstep": kstep}

    return xk, xkvec, ktotallist, ktotal_namelist, res

def get_cal_Elist(xkvec, hrange, krange, lrange):
    yE = []
    for i in range(len(xkvec)):
        kx = xkvec[i][0]
        ky = xkvec[i][1]
        kz = xkvec[i][2]
        Elist = []
        for ih in range(hrange[0], hrange[1]+1):
            for ik in range(krange[0], krange[1]+1):
                for il in range(lrange[0], lrange[1]+1):
                    E = cal_E([kx, ky, kz], [ih, ik, il])
                    Elist.append(E)

        yE.append(Elist)

    return yE

# バンド構造をプロット
# xk: プロットするk点の蓄積距離のリスト
# yE: E(xk[i])。入れ子になったリストで構わない
# ktotallist: k点境界における、最初のk点からの距離の和のリスト
# ktotal_namelist: k点境界における、k点の名称
def plot_band(axis, xk, yE, Erange, ktotallist, ktotal_namelist):    
# 表示範囲は決め打ち
    axis.set_xlim([min(xk), max(xk)])
    axis.set_ylim(Erange)

# バンド構造をプロット
    axis.plot(xk, yE, linestyle = 'none', 
                marker = 'o', markerfacecolor = 'none', markeredgecolor = 'black', 
                markeredgewidth = 0.5, markersize = 2.0)

# Γ点、BZ境界の縦線を引く
    for i in range(1, len(ktotallist)-1):
        axis.plot([ktotallist[i], ktotallist[i]], Erange, 
                    linestyle = '-', color = 'black', linewidth = 0.5)

# k軸の目盛りにk点の名称を表示する
# グラフ枠が一つであれば plt.xtics()で設定できる
# axisに対しては、.setpでattributeを直接書き換える必要があるらしい
    plt.setp(axis, xticks = ktotallist, xticklabels = ktotal_namelist)
    axis.set_xlabel("k", fontsize = fontsize)
    axis.set_ylabel("E (eV)", fontsize = fontsize)
    axis.tick_params(labelsize = fontsize)


def main():
    global a, rg
    global nk, klist

    print("")
    print("Lattice parameters: ({}) A".format(a))
    print("Effective mass: {} me".format(meff))
    print("Reciprocal lattice metric [A^-2]:")
    pprint(rg)
    print("khl range: h=[{}, {}] k=[{}, {}] l=[{}, {}]"
            .format(*hrange, *krange, *lrange))
    print("Plot E range: {} - {} eV".format(*Erange))

    xk, xkvec, ktotallist, ktotal_namelist, res = get_cal_klist(klist, nk)

    print("")
    print("k vectors")
    print(" k_total: {} A^-1".format(res["ktotal"]))
    print(" nk     : ", nk)
    print(" kstep  : {}".format(res["kstep"]))
    print(" dklist")
    pprint(res["dklist"])
    print(" ktotallist:")
    pprint(ktotallist)
    print(" nk_list:", res["nklist"])
#    print("xk")
#    pprint(xk)

    yE = get_cal_Elist(xkvec, hrange, krange, lrange)

    print("")
    print("plot")
    
    fig = plt.figure(figsize = figsize)
    ax1 = fig.add_subplot(1, 1, 1)

    plot_band(ax1, xk, yE, Erange, ktotallist, ktotal_namelist)

    plt.tight_layout()

    plt.pause(0.1)
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()


if __name__ == "__main__":
    main()
