import sys
from pprint import pprint
import numpy as np
from numpy import sqrt, exp, sin, cos, tan, pi
import numpy.linalg as LA 
import csv
from matplotlib import pyplot as plt

"""
概要:
    自由電子モデルを用いたバンド計算を行うモジュール。

関連リンク:
    free_electron_band_usage
"""

#===================================
# physical constants
#===================================
pi   = 3.14159265358979323846
pi2  = 2.0 * pi
h    = 6.6260755e-34    # Js";
hbar = 1.05459e-34      # "Js";
c    = 2.99792458e8     # m/s";
e    = 1.60218e-19      # C";
e0   = 8.854418782e-12; # C<sup>2</sup>N<sup>-1</sup>m<sup>-2</sup>";
kB   = 1.380658e-23     # JK<sup>-1</sup>";
me   = 9.1093897e-31    # kg";
R    = 8.314462618      # J/K/mol
a0   = 5.29177e-11      # m";


#========================
# Crystal definition
#========================
# Si
a  = 5.4064             # angstrom, lattice parameter
rg = np.zeros([3, 3])
rg[0][0] = pow(2.0 * pi / a, 2) # reciprocal space metric in angstrom^-2
rg[1][1] = rg[0][0]
rg[2][2] = rg[0][0]

# 有効質量
meff = 1.0 # in me

# E(k) = KE * k*k, E(k) in eV, k in angstrom^1
KE = hbar * hbar / 2.0 / (meff * me) * 1.0e20 / e
print("KE = ", KE)

#========================
# Band plot parameters
#========================
# バンド構造をプロットするk点の軌跡: [kx, ky, kz, k点名称]
klist = [
#      [-0.5,  0.0, 0.0,  "-X"]
#    , [0.0,  0.0, 0.0,  r"$\Gamma$"]
#    , [ 0.5,  0.0, 0.0,  "X"]
      [0.5,  0.0, 1.0, "W"]
    , [0.5,  0.5, 0.5,  "L"]
    , [0.0,  0.0, 0.0,  r"$\Gamma$"]
    , [0.0,  0.0, 1.0,  "X"]
    , [0.5,  0.0, 1.0,  "W"]
    , [0.75, 0.0, 0.75, "K"]
    ]
# プロットするバンド構造E(k)のk点数の概数
nk = 101

# Ehkl(k)を計算するhkl範囲
hrange = [-3, 3]
krange = [-3, 3]
lrange = [-3, 3]
#hrange = [0, 0]
#krange = [0, 0]
#lrange = [0, 0]

# プロットするエネルギー範囲
Erange = [0.0, 10.0]    # eV


#===================================
# figure configuration
#===================================
figsize = (6, 8)
#figsize = (6, 4)
fontsize        = 16
legend_fontsize = 8


#==============================================
# fundamental functions
#==============================================
def pfloat(str):
    """
    概要:
        文字列を実数値に変換する。
    詳細説明:
        実数値に変換できない文字列をfloat()で変換するとエラーになってプログラムが終了するため、
        この関数は変換できなかったらNoneを返すが、プログラムは終了させない。
    引数:
        :param str: 変換元の文字列
        :type str: str
    戻り値:
        :returns: 変換された実数値、変換できない場合はNone
        :rtype: float or None
    """
    try:
        return float(str)
    except:
        return None

def pint(str):
    """
    概要:
        文字列を整数値に変換する。
    詳細説明:
        pfloat()と同様に、変換できない文字列が渡されてもエラー終了させずNoneを返す。
    引数:
        :param str: 変換元の文字列
        :type str: str
    戻り値:
        :returns: 変換された整数値、変換できない場合はNone
        :rtype: int or None
    """
    try:
        return int(str)
    except:
        return None

def getarg(position, defval = None):
    """
    概要:
        起動時のコマンドライン引数を安全に取得する。
    詳細説明:
        sys.argvに対して範囲外のインデックスを指定してもエラー終了せず、デフォルト値を返す。
    引数:
        :param position: 取得したい引数のインデックス
        :type position: int
        :param defval: 範囲外アクセス時に返すデフォルト値
        :type defval: Any
    戻り値:
        :returns: 取得した引数の文字列、またはデフォルト値
        :rtype: str or Any
    """
    try:
        return sys.argv[position]
    except:
        return defval

def getfloatarg(position, defval = None):
    """
    概要:
        起動時引数を取得し、実数に変換して返す。
    引数:
        :param position: 取得したい引数のインデックス
        :type position: int
        :param defval: 範囲外アクセス時や変換失敗時に返すデフォルト値
        :type defval: Any
    戻り値:
        :returns: 変換された実数値、またはデフォルト値
        :rtype: float or Any
    """
    return pfloat(getarg(position, defval))

def getintarg(position, defval = None):
    """
    概要:
        起動時引数を取得し、整数値に変換して返す。
    引数:
        :param position: 取得したい引数のインデックス
        :type position: int
        :param defval: 範囲外アクセス時や変換失敗時に返すデフォルト値
        :type defval: Any
    戻り値:
        :returns: 変換された整数値、またはデフォルト値
        :rtype: int or Any
    """
    return pint(getarg(position, defval))

def usage():
    """
    概要:
        プログラムの使用方法を表示する。
    詳細説明:
        コンソール上に使い方のメッセージを出力する。
    戻り値:
        :returns: なし
        :rtype: None
    """
    print("")
    print("Usage:")
#    print("  python {}".format(sys.argv[0]))

def terminate(message = None):
    """
    概要:
        メッセージを表示し、プログラムを安全に終了する。
    詳細説明:
        任意のメッセージと使用方法を表示した上で exit() を呼び出す。
    引数:
        :param message: 終了前に表示する文字列（Noneの場合は表示しない）
        :type message: str or None
    戻り値:
        :returns: なし
        :rtype: None
    """
    print("")
    if message is not None:
        print("")
        print(message)
        print("")

    usage()
    print("")
    exit()

def cal_kdistance(rg, k0, k1):
    """
    概要:
        逆格子の計量テンソルを用いて、2つのk点間の距離を計算する。
    引数:
        :param rg: 逆格子の計量テンソル(3x3行列)
        :type rg: list or numpy.ndarray
        :param k0: 始点のk点座標 [kx, ky, kz]
        :type k0: list
        :param k1: 終点のk点座標 [kx, ky, kz]
        :type k1: list
    戻り値:
        :returns: 2点間の距離
        :rtype: float
    """
    dkx = k1[0] - k0[0]
    dky = k1[1] - k0[1]
    dkz = k1[2] - k0[2]
    r2  = rg[0][0] * dkx*dkx + rg[1][1] * dky*dky + rg[2][2] * dkz*dkz
    r2 += 2.0 * (rg[0][1] * dkx*dky + rg[1][2] * dky*dkz + rg[2][0] * dky*dkx)

    return sqrt(r2)

def cal_E(k, Ghkl):
    """
    概要:
        k点座標と逆格子ベクトルから自由電子のエネルギーを計算する(単位: eV)。
    詳細説明:
        kおよびGhklは内部座標系で与える。
    引数:
        :param k: k点の内部座標 [kx, ky, kz]
        :type k: list
        :param Ghkl: 逆格子ベクトル [h, k, l]
        :type Ghkl: list
    戻り値:
        :returns: 計算されたエネルギー値 (eV)
        :rtype: float
    """
    global rg

    kabs2  = rg[0][0] * (k[0] + Ghkl[0])**2 # in angstrom^-2
    kabs2 += rg[1][1] * (k[1] + Ghkl[1])**2
    kabs2 += rg[2][2] * (k[2] + Ghkl[2])**2
    E = KE * kabs2  # in eV
    return E

def get_dklist(klist, nk):
    """
    概要:
        プロットするk点軌跡の各距離と、最初のk点からの累積距離リストを計算する。
    詳細説明:
        バンド構造プロット用に、k点の名称リストも合わせて生成する。
    引数:
        :param klist: [kx, ky, kz, 名称] の形式を持つk点の軌跡リスト
        :type klist: list
        :param nk: プロットするk点の概算数
        :type nk: int
    戻り値:
        :returns: dklist, ktotal_list, ktotal_namelist, ktotal のタプル。
                  dklistは各k点間の距離リスト。
                  ktotal_listは最初のk点からの累積距離リスト。
                  ktotal_namelistはk点の名称リスト。
                  ktotalは累積距離の合計。
        :rtype: tuple
    """
    print("")

# list of k distances from the first k point of each k region
    dklist     = []
# list of k distance from the first k point of the k list
    ktotal = 0.0
    ktotal_list     = []
    ktotal_namelist = []
    ktotal_list.append(0.0)
    ktotal_namelist.append(klist[0][3])
    for i in range(1, len(klist)):
        print("k [{:<10s}: ({:6.4f}, {:6.4f}, {:6.4f}] to [{:<10s}:({:6.4f}, {:6.4f}, {:6.4f}]"
            .format(
                klist[i-1][3], klist[i-1][0], klist[i-1][1], klist[i-1][2], 
                klist[i][3],   klist[i][0],   klist[i][1],   klist[i][2]))

        dk_ = cal_kdistance(rg,
                [klist[i-1][0], klist[i-1][1], klist[i-1][2]], 
                [klist[i][0],   klist[i][1],   klist[i][2]] )
        ktotal += dk_
        dklist.append(dk_)
        ktotal_list.append(ktotal)
        ktotal_namelist.append(klist[i][3])
#        print("  dk={}  ktotal={}".format(dk_, ktotal))

    return dklist, ktotal_list, ktotal_namelist, ktotal

def get_cal_klist(klist, nk):
    """
    概要:
        指定されたk点軌跡から、等間隔になるように計算用のk点リストを生成する。
    詳細説明:
        プロットに必要な累積距離リストやk点名称などのデータもあわせて作成する。
    引数:
        :param klist: バンド構造をプロットするk点の軌跡リスト
        :type klist: list
        :param nk: プロットするk点数の概数
        :type nk: int
    戻り値:
        :returns: xk, xkvec, ktotallist, ktotal_namelist, res のタプル。
                  xkは各計算k点の累積距離リスト。
                  xkvecは各計算k点の座標リスト。
                  ktotallistは区切りとなるk点の累積距離リスト。
                  ktotal_namelistは区切りとなるk点の名称リスト。
                  resは距離情報や分割数情報を持つ辞書。
        :rtype: tuple
    """
    dklist, ktotallist, ktotal_namelist, ktotal = get_dklist(klist, nk)
    kstep = ktotal / nk

    nklist = []
    xk = []
    xkvec = []
    ktotal_ = 0.0
    for i in range(1, len(klist)):
        nk_ = int(dklist[i-1] / kstep + 1.00001)
        nklist.append(nk_)
        if i == len(klist) - 1:
            ndiv = nk_ - 1
        else: 
            ndiv = nk_    
        kstepx_ = (klist[i][0] - klist[i-1][0]) / ndiv
        kstepy_ = (klist[i][1] - klist[i-1][1]) / ndiv
        kstepz_ = (klist[i][2] - klist[i-1][2]) / ndiv

        print("k: {:<10s} - {:<10s}".format(klist[i-1][3], klist[i][3]), end = '')
        print(": nk={:3d} ktotal={:6.4f} kstep=({:6.4f}, {:6.4f}, {:6.4f})"
                .format(nklist[i-1], ktotal_, kstepx_, kstepy_, kstepz_))

        dk_ = cal_kdistance(rg, [0.0, 0.0, 0.0], [kstepx_,kstepy_, kstepz_])
        for j in range(nklist[i-1]):
            kx = klist[i-1][0] + j * kstepx_
            ky = klist[i-1][1] + j * kstepy_
            kz = klist[i-1][2] + j * kstepz_

            xk.append(ktotal_)
            xkvec.append([kx, ky, kz])
            ktotal_ += dk_

    res = {"dklist": dklist, "nklist": nklist, "ktotal": ktotal, "kstep": kstep}

    return xk, xkvec, ktotallist, ktotal_namelist, res

def get_cal_Elist(xkvec, hrange, krange, lrange):
    """
    概要:
        計算用k点リストと逆格子ベクトルの範囲から、各k点でのエネルギーリストを計算する。
    引数:
        :param xkvec: 計算を行うk点座標 [kx, ky, kz] のリスト
        :type xkvec: list
        :param hrange: 逆格子ベクトルhの計算範囲 [最小値, 最大値]
        :type hrange: list
        :param krange: 逆格子ベクトルkの計算範囲 [最小値, 最大値]
        :type krange: list
        :param lrange: 逆格子ベクトルlの計算範囲 [最小値, 最大値]
        :type lrange: list
    戻り値:
        :returns: 各k点におけるエネルギー群が格納された入れ子リスト
        :rtype: list
    """
    yE = []
    for i in range(len(xkvec)):
        kx = xkvec[i][0]
        ky = xkvec[i][1]
        kz = xkvec[i][2]
        Elist = []
        for ih in range(hrange[0], hrange[1]+1):
            for ik in range(krange[0], krange[1]+1):
                for il in range(lrange[0], lrange[1]+1):
                    E = cal_E([kx, ky, kz], [ih, ik, il])
                    Elist.append(E)

        yE.append(Elist)

    return yE

def plot_band(axis, xk, yE, Erange, ktotallist, ktotal_namelist):    
    """
    概要:
        計算した自由電子のバンド構造をグラフにプロットする。
    引数:
        :param axis: matplotlibのAxesオブジェクト
        :type axis: matplotlib.axes.Axes
        :param xk: プロットするk点の累積距離のリスト
        :type xk: list
        :param yE: 各k点に対応するエネルギーのリスト(入れ子でも可)
        :type yE: list
        :param Erange: プロットするエネルギーの表示範囲 [最小値, 最大値]
        :type Erange: list
        :param ktotallist: k点境界における、最初のk点からの距離の和のリスト
        :type ktotallist: list
        :param ktotal_namelist: k点境界における、k点の名称のリスト
        :type ktotal_namelist: list
    戻り値:
        :returns: なし
        :rtype: None
    """
# 表示範囲は決め打ち
    axis.set_xlim([min(xk), max(xk)])
    axis.set_ylim(Erange)

# バンド構造をプロット
    axis.plot(xk, yE, linestyle = 'none', 
                marker = 'o', markerfacecolor = 'none', markeredgecolor = 'black', 
                markeredgewidth = 0.5, markersize = 2.0)

# Γ点、BZ境界の縦線を引く
    for i in range(1, len(ktotallist)-1):
        axis.plot([ktotallist[i], ktotallist[i]], Erange, 
                    linestyle = '-', color = 'black', linewidth = 0.5)

# k軸の目盛りにk点の名称を表示する
# グラフ枠が一つであれば plt.xtics()で設定できる
# axisに対しては、.setpでattributeを直接書き換える必要があるらしい
    plt.setp(axis, xticks = ktotallist, xticklabels = ktotal_namelist)
    axis.set_xlabel("k", fontsize = fontsize)
    axis.set_ylabel("E (eV)", fontsize = fontsize)
    axis.tick_params(labelsize = fontsize)


def main():
    """
    概要:
        自由電子バンド計算のメインルーチン。
    詳細説明:
        格子定数やプロットパラメータを設定し、k点のリストとエネルギーを計算して
        matplotlibを用いてバンド図を描画する。
    戻り値:
        :returns: なし
        :rtype: None
    """
    global a, rg
    global nk, klist

    print("")
    print("Lattice parameters: ({}) A".format(a))
    print("Effective mass: {} me".format(meff))
    print("Reciprocal lattice metric [A^-2]:")
    pprint(rg)
    print("khl range: h=[{}, {}] k=[{}, {}] l=[{}, {}]"
            .format(*hrange, *krange, *lrange))
    print("Plot E range: {} - {} eV".format(*Erange))

    xk, xkvec, ktotallist, ktotal_namelist, res = get_cal_klist(klist, nk)

    print("")
    print("k vectors")
    print(" k_total: {} A^-1".format(res["ktotal"]))
    print(" nk     : ", nk)
    print(" kstep  : {}".format(res["kstep"]))
    print(" dklist")
    pprint(res["dklist"])
    print(" ktotallist:")
    pprint(ktotallist)
    print(" nk_list:", res["nklist"])
#    print("xk")
#    pprint(xk)

    yE = get_cal_Elist(xkvec, hrange, krange, lrange)

    print("")
    print("plot")
    
    fig = plt.figure(figsize = figsize)
    ax1 = fig.add_subplot(1, 1, 1)

    plot_band(ax1, xk, yE, Erange, ktotallist, ktotal_namelist)

    plt.tight_layout()

    plt.pause(0.1)
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()


if __name__ == "__main__":
    main()