import os
import sys
import shutil
import glob
import csv
import re
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
from tklib.tkinifile import tkIniFile
import tklib.tkre as tkre
from tklib.tkutils import save_csv, colors
from tklib.tkapplication import tkApplication
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tkconvolution import convolution, convolve_func
from tklib.tksci.tkconvolution import convolve_xydata
from tklib.tksci.tksci import Reduce01, Round
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType
from tklib.tkcrystal.tkvasp import tkVASP
import tklib.tkcsv


"""
Estimate VBM correction for defect model
"""

#================================
# Global parameters (defaults; most of them can be overridden by positional
# command-line arguments, see updatevars())
#================================
# Debug flag. NOTE: VBM_correction() defines its own local `debug = 0`, so this
# global does not affect that function.
debug = 0

# Algorithm used to determine dEVBM:
#   'mode'    : peak (most frequent value) of the Gaussian-smeared distribution
#               of core-potential differences
#   'average' : average core-potential difference after iteratively removing
#               the site that deviates most from the average
#mode = 'average'
validmodes = ['mode', 'average']
mode = 'mode'

# Paths of the ideal (CAR_dir1) and defect (CAR_dir2) VASP calculations
CAR_dir1 = '.'
CAR_dir2 = '.'
# Path of the summary (.prm) file; assigned in main()
summaryfile = None

# FWHM of the Gaussian smearing function applied to the distribution of
# core-potential differences
Wg  = 0.1  # eV

# FWHM of the Gaussian smearing function applied to the DOS
WG_DOS = 0.1 # eV

# Criterion to exclude sites near defects: allowed deviation of a site's
# core-potential difference from dVmf ('mode') or dVav ('average')
dVth = 0.1  # eV

# Defect position used as the origin of distances, given as "x,y,z" or "x y z"
# (presumably fractional coordinates of the defect cell - confirm).
# Empty string: use the first detected defect site, or (0, 0, 0) if none.
defect_position = ''

# Distance criterion to find vacancies (also used for interstitials): a site
# with no site of the other model closer than this is treated as a defect
Rmin_vacancy  = 0.5   # angstrom
# Sites closer than this to the defect origin are excluded from the
# site-potential statistics used for dEVBM
Rmin_potential = 5.0 # angstrom

# Plot configuration (Emin/Emax: range the y-axis of the dV panel is widened to include)
Emin = 0.0  # eV
Emax = 8.0  # eV
#colors = ['#000000', '#ff0000', '#00aa00', '#0000ff', '#aaaa00', '#ff00ff', '#00ffff', '#aa0000', '#00aa00', '#0000aa']
#colors = ['k', 'r', 'g', 'b', 'y', 'm', 'c']
figsize = (10, 6)
fontsize        = 16
titlefontsize   = 12
labelfontsize   = 8
legend_fontsize = 12


# tklib application helper; used below for stdout redirection (app.redirect)
# and for terminating the program (app.terminate)
app = tkApplication()


#=============================
# Treat argments
#=============================
def usage():
    """Print the command-line usage and an example built from the current settings.

    Side effect: if the global ``mode`` holds an unsupported value, it is reset
    to ``validmodes[0]`` so that the printed example is a valid invocation.
    """
    global mode

    if mode not in validmodes:
        mode = validmodes[0]

    print("")
    print("Usage:")
    # Arguments are whitespace-separated on the command line; defect_position
    # (argv[10]) and WG_DOS (argv[11]) are optional trailing arguments read by
    # updatevars().
    print("  (a)  python {} mode CAR_dir(ideal) CAR_dir(defect) dVth WGauss Rmin_vacancy Rmin_potential Emin Emax [defect_position] [WG_DOS]".format(sys.argv[0]))
    print("              mode: 'mode' for mode (most freqent) algorism, 'average' for average")
    # One placeholder per value: str.format() silently ignores surplus
    # arguments, which previously dropped Emin and Emax from the example.
    print("     ex: python {} {} {} {} {} {} {} {} {} {}".format(
        sys.argv[0], mode, CAR_dir1, CAR_dir2, dVth, Wg,
        Rmin_vacancy, Rmin_potential, Emin, Emax))

def updatevars():
    """Overwrite the module-level settings from positional command-line arguments.

    argv index -> setting:
         1: mode             2: CAR_dir1 (ideal)   3: CAR_dir2 (defect)
         4: dVth             5: Wg                 6: Rmin_vacancy
         7: Rmin_potential   8: Emin               9: Emax
        10: defect_position 11: WG_DOS
    Each current value is passed to tklib's getarg()/getfloatarg() as the
    default (presumably returned when the argument is absent - confirm in tklib).
    terminate() is called with the usage text if mode is not 'average' or 'mode'.
    """
    global mode, CAR_dir1, CAR_dir2
    global dVth, Wg, Rmin_vacancy, Rmin_potential, defect_position
    global Emin, Emax, WG_DOS

    mode = getarg(1, mode)
    CAR_dir1 = getarg(2, CAR_dir1)
    CAR_dir2 = getarg(3, CAR_dir2)
    dVth = getfloatarg(4, dVth)
    Wg = getfloatarg(5, Wg)
    Rmin_vacancy = getfloatarg(6, Rmin_vacancy)
    Rmin_potential = getfloatarg(7, Rmin_potential)
    Emin = getfloatarg(8, Emin)
    Emax = getfloatarg(9, Emax)
    defect_position = getarg(10, defect_position)
    WG_DOS = getfloatarg(11, WG_DOS)

    accepted_modes = ('average', 'mode')
    if mode not in accepted_modes:
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


#def save_csv(path, headerlist, datalist, is_print = 0):

def VBM_correction(mode, CAR_path1, CAR_path2, Emin, Emax, Rmin_vacancy, Rmin_potential):
    """Estimate the VBM shift (dEVBM) of a defect supercell relative to the ideal crystal.

    The average core potential (Vcore, read from OUTCAR) of every site in the
    defect model is compared with that of the nearest site of the ideal model,
    giving dV = Vcore(defect) - Vcore(ideal).  Substituted/interstitial sites
    and sites within Rmin_potential of the defect origin are excluded, then:

    * mode == 'mode'   : dEVBM = dVmf, the peak of the Gaussian-smeared (FWHM Wg)
                         distribution of dV over the retained sites;
    * mode == 'average': dEVBM = dVav, the average dV after iteratively dropping
                         the site that deviates most from the average until all
                         remaining deviations are below dVth.

    The result (band edges of the ideal model, dEVBM, detected defects) is
    written to the global ``summaryfile`` (section 'results').  The site
    potentials and the total DOS of both models are then plotted.

    Args:
        mode: 'mode' or 'average' (see above).
        CAR_path1: path passed to tkVASP.getdir() to locate the ideal-crystal run.
        CAR_path2: path passed to tkVASP.getdir() to locate the defect-model run.
        Emin, Emax: energies (eV) the y-range of the dV panel is widened to include.
        Rmin_vacancy: distance (A) beyond which an unmatched site is treated as a
            vacancy or an interstitial.
        Rmin_potential: sites within this distance (A) of the defect origin are
            excluded from the dEVBM estimate.

    Also reads the globals dVth, Wg, WG_DOS, defect_position and summaryfile.
    """
    debug = 0

    vasp1 = tkVASP()
    vasp2 = tkVASP()

    base_path1 = vasp1.getdir(CAR_path1)
    base_path2 = vasp2.getdir(CAR_path2)

    print("")
    print(f"Summary file: {summaryfile}")
    INCAR_path1    = vasp1.get_INCAR(base_path1)
    POSCAR_path1   = vasp1.get_POSCAR(base_path1)
    CONTCAR_path1  = vasp1.get_CONTCAR(base_path1)
    OUTCAR_path1   = vasp1.get_OUTCAR(base_path1)
    DOSCAR_path1   = vasp1.get_VASPPath(base_path1, 'DOSCAR')
    EIGENVAL_path1 = vasp1.get_VASPPath(base_path1, 'EIGENVAL')
    INCAR_path2    = vasp2.get_INCAR(base_path2)
    POSCAR_path2   = vasp2.get_POSCAR(base_path2)
    CONTCAR_path2  = vasp2.get_CONTCAR(base_path2)
    OUTCAR_path2   = vasp2.get_OUTCAR(base_path2)
    DOSCAR_path2   = vasp2.get_VASPPath(base_path2, 'DOSCAR')
    print("CAR dir(ideal) : ", CAR_path1)
    print("  INCAR   : ", INCAR_path1)
    print("  POSCAR  : ", POSCAR_path1)
    print("  CONTCAR : ", CONTCAR_path1)
    print("  OUTCAR  : ", OUTCAR_path1)
    print("  DOSCAR  : ", DOSCAR_path1)
    print("  EIGENVAL: ", EIGENVAL_path1)
    print("CAR dir(defect): ", CAR_path2)
    print("  INCAR   : ", INCAR_path2)
    print("  POSCAR  : ", POSCAR_path2)
    print("  CONTCAR : ", CONTCAR_path2)
    print("  OUTCAR  : ", OUTCAR_path2)
    print("  DOSCAR  : ", DOSCAR_path2)
    print("")
    if defect_position == '':
        print("Defect position to be the origin of distances will be determined by the following maximum distances")
    else:
        print(f"Defect position to be the origin of distances: {defect_position}")
    print(f"Maximum distance to find vacancies: {Rmin_vacancy} A")
    print(f"Maximum distance from defect to calculate site potentials for dEVBM: {Rmin_potential} A")

    print(f"dVth: {dVth} eV")
    print(f"Width of Gaussian function to deconvolute distrition: {Wg} eV")
    print("")
    print("DOS plot")
    print(f"  FWHM of Gauss function for convolution: {WG_DOS} eV")

    #--------------------------------------------------------------
    # Read the ideal and defect crystal structures from POSCAR
    #--------------------------------------------------------------
    print("")
    print("Idetal crystal model:")
    print("Read from [{}]".format(POSCAR_path1))
    cry1 = vasp1.read_poscar(POSCAR_path1)
    if cry1 is None:
        terminate("Error: Can not read [{}]".format(POSCAR_path1), usage = usage)

    a1, b1, c1, alpha1, beta1, gamm1 = cry1.LatticeParameters()
    Vcell1        = cry1.Volume()
    types1        = cry1.AtomTypeList()
    ntypes1       = len(types1)
    sites1        = cry1.ExpandedAtomSiteList()
    nsites1       = len(sites1)
    atomnamelist1 = cry1.get_atom_name_list(mode = 'all', NameOnly = True)
    poslist1      = cry1.get_position_list(mode = 'all', IsReduce01 = True)
    cry1.PrintInf("cell")

    print("")
    print("Defect crystal model:")
    print("Read from [{}]".format(POSCAR_path2))
    cry2 = vasp2.read_poscar(POSCAR_path2)
    if cry2 is None:
        terminate("Error: Can not read [{}]".format(POSCAR_path2), usage = usage)

    a2, b2, c2, alpha2, beta2, gamm2 = cry2.LatticeParameters()
    Vcell2        = cry2.Volume()
    types2        = cry2.AtomTypeList()
    ntypes2       = len(types2)
    sites2        = cry2.ExpandedAtomSiteList()
    nsites2       = len(sites2)
    atomnamelist2 = cry2.get_atom_name_list(mode = 'all', NameOnly = True)
    poslist2      = cry2.get_position_list(mode = 'all', IsReduce01 = True)
    cry2.PrintInf("cell")

    #--------------------------------------------------------------
    # Read EF, ISPIN and the average core potentials from OUTCAR
    #--------------------------------------------------------------
    print("")
    outcarinf1 = vasp1.read_outcar_inf(OUTCAR_path1)
    if outcarinf1 is None:
        terminate("Error: Can not read [{}]".format(OUTCAR_path1), usage = usage)
    EF1 = vasp1.outcarinf["EF"]
    ISPIN1 = outcarinf1["ISPIN"]

    print("")
    outcarinf2 = vasp2.read_outcar_inf(OUTCAR_path2)
    if outcarinf2 is None:
        terminate("Error: Can not read [{}]".format(OUTCAR_path2), usage = usage)
    EF2 = vasp2.outcarinf["EF"]
    ISPIN2 = outcarinf2["ISPIN"]

    Vcore1 = outcarinf1["Vcore"]
    Vcore2 = outcarinf2["Vcore"]
    print("Average core pontentials in the ideal model:\n", Vcore1)
    print("Average core pontentials in the defect model:\n", Vcore2)

    #--------------------------------------------------------------
    # DOS and band edges of the ideal model
    #--------------------------------------------------------------
    print("")
    print(f"Ideal crystal model: Read DOSCAR from [{DOSCAR_path1}]")
    doscarinf1 = vasp1.read_doscar(DOSCAR_path1, EF = EF1, normalize_E = False, unit = '/cm3')
    nE1 = doscarinf1["nE"]
    E1  = doscarinf1["E"]
    if nE1 == 0:
        Edmin1 = 0.0
        Edmax1 = 0.0
    else:
        Edmin1  = min(E1)
        Edmax1  = max(E1)
    tDOSup1 = doscarinf1["TotalDOSup"]
    Neup1   = doscarinf1["Neup"]
    tDOSdn1 = doscarinf1["TotalDOSdn"]
    Nedn1   = doscarinf1["Nedn"]
    if nE1 > 0:
        print(f"EF={EF1} eV")
        print(f"DOS E range: {Edmin1:10.6g} - {Edmax1:10.6g} eV, {nE1} points")
        print(f"  Energy is not normalized")

    print(f"Read EIGENVAL from [{EIGENVAL_path1}]")
    eigenvalinf = vasp1.read_eigenval(EIGENVAL_path1, EF = EF1, normalize_E = False)
    EF0 = 0.1
    occ_th = 0.5
    # Band edges by two independent routines: (a) from EIGENVAL occupations,
    # (b) tkVASP.gbandedges().  EV/EC/Eg from (b) are the ones saved and plotted.
    bandedgeinf1a = vasp1.find_band_edges_from_eigenval(EF0 = EF0, eigenvalinf = eigenvalinf, ISPIN = ISPIN1, occ_th = occ_th)
    bandedgeinf1b = vasp1.gbandedges(base_path1, OUTCAR_path1, EIGENVAL_path1, EF1)
    EV1   = bandedgeinf1a["EV"]
    EC1   = bandedgeinf1a["EC"]
    Eg1   = bandedgeinf1a["Eg"]
    EV    = bandedgeinf1b["EVBM"]
    EC    = bandedgeinf1b["ECBM"]
    Eg    = bandedgeinf1b["Eg"]
    EHOMO = bandedgeinf1a["EHOMO"]
    ELUMO = bandedgeinf1a["ELUMO"]
    print(f"Band edge from EIGENVAL:")
    print(f"EF0={EF0:10.6f} eV")
    print(f"  find_band_edges: EV={EV1:10.6f}  EC={EC1:10.6f}  Eg={Eg1:10.6f} eV")
    print(f"             HOMO:{EHOMO:10.6f} eV")
    print(f"             LUMO:{ELUMO:10.6f} eV")
    print(f"  gbandedges()   : EV={EV:10.6f}  EC={EC:10.6f}  Eg={Eg:10.6f} eV")

    #--------------------------------------------------------------
    # DOS of the defect model
    #--------------------------------------------------------------
    print("")
    print(f"Defect crystal model: Read DOSCAR from [{DOSCAR_path2}]")
    doscarinf2 = vasp2.read_doscar(DOSCAR_path2, EF = EF2, normalize_E = False, unit = '/cm3')
    nE2 = doscarinf2["nE"]
    E2  = doscarinf2["E"]
    if nE2 == 0:
        Edmin2 = 0.0
        Edmax2 = 0.0
    else:
        Edmin2  = min(E2)
        Edmax2  = max(E2)
    tDOSup2 = doscarinf2["TotalDOSup"]
    Neup2   = doscarinf2["Neup"]
    tDOSdn2 = doscarinf2["TotalDOSdn"]
    Nedn2   = doscarinf2["Nedn"]
    if nE2 > 0:
        print(f"EF={EF2} eV")
        print(f"DOS E range: {Edmin2:10.6g} - {Edmax2:10.6g} eV, {nE2} points")
        print(f"  Energy is not normalized")

    #--------------------------------------------------------------
    # Estimate the supercell size (rounded lattice-parameter ratios)
    #--------------------------------------------------------------
    print("")
    nx = int(a2 / a1 + 0.2)
    ny = int(b2 / b1 + 0.2)
    nz = int(c2 / c1 + 0.2)
    print("Supercell multipliers of the defect crystal with respect to the ideal crystal")
    print("  (nx, ny, nz) = ({}, {}, {})".format(nx, ny, nz))

    #--------------------------------------------------------------
    # Find defect types from the atom counts (ideal counts scaled to the supercell)
    #--------------------------------------------------------------
    nk1 = nx * ny * nz
    natoms1 = {}
    for name in atomnamelist1:
        v = natoms1.get(name, None)
        if v is None:
            natoms1[name] = nk1
        else:
            natoms1[name] += nk1

    natoms2 = {}
    for name in atomnamelist2:
        v = natoms2.get(name, None)
        if v is None:
            natoms2[name] = 1
        else:
            natoms2[name] += 1

    print("natoms1=", natoms1)
    print("natoms2=", natoms2)

    natoms_diff = natoms1.copy()
    for name in natoms2.keys():
        v = natoms_diff.get(name, None)
        if v is None:
            natoms_diff[name] = -natoms2[name]
        else:
            natoms_diff[name] -= natoms2[name]

    interstitials = {}
    vacancies = {}
    print("")
    print("Possible defects judged from the numbers of atoms normalized by the defect crystal cell size")
    print(f"  {'name':4}: {'ideal':5} {'defect':6} {'diff':4}")
    for name in natoms_diff.keys():
        n1 = natoms1.get(name, 0)
        n2 = natoms2.get(name, 0)
        ndiff = natoms_diff[name]
        if ndiff < 0:
            interstitials[name] = 1
            name = f"{name}_i"
            print(f"  {name:4}: {n1:3} - {n2:3} = {ndiff:3}")
        elif ndiff > 0:
            vacancies[name] = 1
            name = f"V_{name}_"
            print(f"  {name:4}: {n1:3} - {n2:3} = {ndiff:3}")
        else:
            print(f"  {name:4}: {n1:3} - {n2:3} = {ndiff:3}")

    #--------------------------------------------------------------
    # Find interstitial sites
    #--------------------------------------------------------------
    print("")
    dsiteinf = []
    inames = list(interstitials.keys())
    ni = len(inames)
    if ni == 0:
        print("No interstitial site")
    else:
        print("Find interstitial sites")
        for i in range(nsites2):
            if atomnamelist2[i] in inames:
                # Convert the fractional coordinates of the defect model (supercell)
                # into those of the ideal model and look for the nearest ideal site.
                x2 = poslist2[i][0] * nx
                y2 = poslist2[i][1] * ny
                z2 = poslist2[i][2] * nz
                idx1 = cry1.FindNearestSite([x2, y2, z2])
                dis  = cry1.GetNearestInterAtomicDistance(poslist1[idx1], [x2, y2, z2])
                # If the nearest ideal site is farther than Rmin_vacancy, it is an interstitial.
                if dis > Rmin_vacancy:
                    dsiteinf.append({"type": "i", "i2": i,
                                     "name2": atomnamelist2[i], "pos2": [poslist2[i][0], poslist2[i][1], poslist2[i][2]] })

    #--------------------------------------------------------------
    # Find vacancy sites
    #--------------------------------------------------------------
    print("")
    print("Find vacancy sites")
    count = 0
    for i in range(nsites1):
        # One image of each ideal site per ideal cell in the supercell.
        # range(n) (not n+1): index n would be a periodic image of index 0
        # and report the same vacancy again.
        for iz in range(nz):
            for iy in range(ny):
                for ix in range(nx):
                    # Convert the ideal-model fractional coordinates into those of the
                    # defect model (supercell) and find the nearest defect-model site.
                    x1 = (poslist1[i][0] + ix) / nx
                    y1 = (poslist1[i][1] + iy) / ny
                    z1 = (poslist1[i][2] + iz) / nz
                    idx2 = cry2.FindNearestSite([x1, y1, z1])
                    dis  = cry2.GetNearestInterAtomicDistance(poslist2[idx2], [x1, y1, z1])
                    # If the nearest defect-model site is farther than Rmin_vacancy, it is a vacancy.
                    if dis > Rmin_vacancy:
                        dsiteinf.append({ "type": "v", "i2": i, "name2": atomnamelist1[i], "pos2": [x1, y1, z1] })
                        count += 1

    #--------------------------------------------------------------
    # Match every defect-model site with its nearest ideal-model site
    #--------------------------------------------------------------
    siteinf   = []
    Vdifflist = []
    excluded = [0] * nsites2
    for i in range(nsites2):
        name2 = atomnamelist2[i]
        # Convert the defect-model (supercell) fractional coordinates into those
        # of the ideal model and find the nearest ideal-model site.
        idx1 = cry1.FindNearestSite([poslist2[i][0] * nx, poslist2[i][1] * ny, poslist2[i][2] * nz])

        defecttype = ''
        if idx1 >= 0:
            name1 = atomnamelist1[idx1]
            V1    = Vcore1[idx1]
            pos1  = poslist1[idx1]
            if name1 != name2:
                defecttype = "{}_{}".format(name2, name1)
                dsiteinf.append({ "type": defecttype, "i2": i, "name2": name2, "pos2": poslist2[i] })
                # Substituted site: no atom of the same species to compare with, so exclude it.
                excluded[i] = 1
                if debug:
                    print(f"idx2={i} {name2} is excluded due to atom name mismatch with atom1={name1}")
            Vdiff = Vcore2[i] - V1
        else:
            # No ideal-model site was found: treat it as an interstitial.  There is
            # no reference potential, so the ideal-side entries are placeholders
            # (index -1 would silently refer to the last ideal site) and the site
            # is excluded from the dEVBM estimate.
            idx1  = -1
            name1 = 'i'
            V1    = float('nan')
            pos1  = None
            defecttype = 'i'
            dsiteinf.append({ "type": "i", "i2": i, "name2": name2, "pos2": poslist2[i] })
            Vdiff = Vcore2[i]
            excluded[i] = 1

        Vdifflist.append(Vdiff)
        inf = { "i1": idx1, "name1": name1, "V1": V1, "pos1": pos1,
                "i2": i,    "name2": name2, "V2": Vcore2[i], "pos2": poslist2[i],
                "Vdiff": Vdiff, "Defect": defecttype }
        siteinf.append(inf)

    print("")
    print("Defect sites:")
    def_sites = []
    for i in range(len(dsiteinf)):
        inf = dsiteinf[i]
        print(f"  {inf['type']}: {inf['name2']:4}{inf['i2']:04d} ({inf['pos2'][0]:8.4f}, {inf['pos2'][1]:8.4f}, {inf['pos2'][2]:8.4f})")
        if inf['type'] == 'i':
            def_sites.append([inf['name2'], 'i'])
        elif inf['type'] == 'v':
            def_sites.append(['V', inf['name2']])
        elif '_' in inf['type']:
            a = inf['type'].split('_')
            def_sites.append(a)
        else:
            print(f"\n ***Error: Invalid defect [{inf['type']}]")

    #--------------------------------------------------------------
    # Origin of the distances (defect position)
    #--------------------------------------------------------------
    print("")
    print("The origin of distances:")
    if defect_position == '':
        if len(dsiteinf) > 0:
            print(f"  determined by the first position of the dfects")
            pos0 = dsiteinf[0]["pos2"]
        else:
            print(f"  assumed to be (0, 0, 0)")
            pos0 = [0, 0, 0]
    else:
        print(f"  determined by defect_position = {defect_position}")
        a = defect_position.split(',')
        if len(a) < 3:
            a = defect_position.split()
        if len(a) < 3:
            app.terminate(f"Error: defect_position must have 3 coordinates seperated by commas / spaces", pause = True)

        pos0 = [pfloat(a[0]), pfloat(a[1]), pfloat(a[2])]
    print("  Origin: ({:8.4f}, {:8.4f}, {:8.4f})".format(*pos0))

    #--------------------------------------------------------------
    # Split the sites by their distance Rmin_potential from the defect and set excluded[]
    #--------------------------------------------------------------
    rlist_all         = []
    rlist_include     = []
    V1list_all        = []
    V1list_include    = []
    Vdifflist_all     = []
    Vdifflist_include = []
    name1list = []
    name2list = []
    for i in range(nsites2):
        inf = siteinf[i]
        V1 = inf["V1"]
        V2 = inf["V2"]

        dis  = cry2.GetNearestInterAtomicDistance(pos0, poslist2[i], idx = i)
        rlist_all.append(dis)
        V1list_all.append(inf["V1"])
        Vdifflist_all.append(Vdifflist[i])
        # Sites already excluded, or within Rmin_potential of the defect, are excluded.
        if excluded[i] or dis <= Rmin_potential:
            excluded[i] = 1
            if debug:
                print(f"idx2={i} {atomnamelist2[i]} is excluded by Rmin(pot)")
        else:
            excluded[i] = 0
            rlist_include.append(dis)
            V1list_include.append(inf["V1"])
            Vdifflist_include.append(Vdifflist[i])

        name1list.append(inf["name1"])
        name2list.append(inf["name2"])
        print("    {:4}{:04d} ({:4}{:04d}) ({:8.4f}, {:8.4f}, {:8.4f})  {:10.4f} eV - {:10.4f} eV = {:10.4f} eV  ({:1}) {:10.4g} A"
                    .format(inf["name2"], inf["i2"], inf["name1"], inf["i1"],
                            *inf["pos2"], V2, V1, inf["Vdiff"], not excluded[i], dis))

    #--------------------------------------------------------------
    # Estimate dEVBM from the sites outside the defect region
    #--------------------------------------------------------------
    print("")

    # Smeared distribution of dV over the retained sites; the energy grid spans
    # all sites so that the plotted distribution covers every point.
    x0 = min(Vdifflist) - Wg
    x1 = max(Vdifflist) + Wg
    xstep = Wg / 5.0
    xdist, ydist = convolve_xydata(Vdifflist_include, None, Wg, x0, x1, xstep)

    if mode == 'mode':
        # dVmf: position of the peak of the distribution.
        dVmf = 0.0
        ymax = -1.0e100
        for i in range(len(xdist)):
            if ymax < ydist[i]:
                ymax = ydist[i]
                dVmf = xdist[i]

        # dVav: average dV of the retained sites lying within dVmf +- dVth;
        # all other sites are marked as excluded.
        dVav = 0.0
        count = 0
        for i in range(nsites2):
            dV = siteinf[i]["Vdiff"]
            if not excluded[i] and dVmf - dVth <= dV <= dVmf + dVth:
                dVav += dV
                count += 1
            else:
                excluded[i] = 1
                if debug:
                    print(f"idx2={i} {atomnamelist2[i]} is excluded due to large difference from dVmf")

        if count == 0:
            dVav = 0.0
        else:
            dVav /= count

        print("Vcore difference threshold       (dVth): {:12.4f} eV".format(dVth))
        print("  Most frequent dVcore           (dVmf): {:12.4f} eV".format(dVmf))
        print("  Averaged dVcore in dVth        (dVav): {:12.4f} eV".format(dVav))
    else:
        # Iteratively drop the site that deviates most from the running average
        # until every remaining site lies within dVth of it.  The threshold is
        # checked BEFORE excluding, so the final excluded[] flags match the
        # sites that were averaged into dVav.
        dVav  = 0.0
        dVmax = 0.0
        for iloop in range(nsites2):
            dV = [siteinf[i]["Vdiff"] for i in range(nsites2) if not excluded[i]]
            if len(dV) == 0:
                # Every site is excluded: nothing to average.
                dVav  = 0.0
                dVmax = 0.0
                break
            dVav = sum(dV) / len(dV)

            dVmax = 0.0
            idxmax = -1
            for i in range(nsites2):
                if excluded[i]:
                    continue

                ddV = abs(siteinf[i]["Vdiff"] - dVav)
                if dVmax < ddV:
                    dVmax = ddV
                    idxmax = i

            # idxmax < 0 only when all deviations are zero (converged).
            if dVmax < dVth or idxmax < 0:
                break

            excluded[idxmax] = 1

        print("Vcore difference threshold       (dVth): {:12.4f} eV".format(dVth))
        print("  Averaged Vcore outside defects (dVav): {:12.4f} eV".format(dVav))
        print("  Max deviation from dVav        (ddV) : {:12.4f} eV".format(dVmax))

    print("")
    print("Average core potentials:")
    print("  {}".format(CAR_path1))
    for i in range(nsites1):
        print("    {:04d} {:4} ({:8.4f}, {:8.4f}, {:8.4f})  {} eV".format(i, atomnamelist1[i], *poslist1[i], Vcore1[i]))

    print("")
    print("VBM correction:")
    if mode == 'mode':
        print("  Most frequent dVcore (dVmf): {:12.4f} eV".format(dVmf))
    print("  Averaged dVcore      (dVav): {:12.4f} eV".format(dVav))
    if mode == 'mode':
        dEVBM = dVmf
        print(f"  dEVBM is taken as dVmf for mode={mode}: {dEVBM:12.6g} eV")
    else:
        dEVBM = dVav
        print(f"  dEVBM is taken as dVav for mode={mode}: {dEVBM:12.6g} eV")

    #--------------------------------------------------------------
    # Save the summary
    #--------------------------------------------------------------
    print("")
    print(f"Save summary to {summaryfile}")
    ini = tkIniFile()
    kwargs = {
                "Car_dir1": base_path1,
                "Car_dir2": base_path2,
                "EV":    EV,
                "EC":    EC,
                "Eg":    Eg,
                "dEVBM": dEVBM
            }
    for i in range(len(def_sites)):
        kwargs[f'defect_atom{i}'] = def_sites[i][0]
        kwargs[f'defect_site{i}'] = def_sites[i][1]

    ini.write_from_scratch(summaryfile, 'results', **kwargs)

    print("")
    print("DOS plot range")
    print(f"  Ideal crystal model:")
    print(f"    EF={EF1} eV")
    print(f"    E range: {Edmin1:10.6g} - {Edmax1:10.6g} eV, {nE1} points")

    # Shift the defect-model energies by dEVBM so that they are aligned with
    # the ideal model (E2 is modified in place).
    print(f"  Defect crystal model:")
    print(f"    Original EF={EF2} eV")
    print(f"    Original E range: {Edmin2:10.6g} - {Edmax2:10.6g} eV, {nE2} points")
    EF2 -= dEVBM
    Edmin2 -= dEVBM
    Edmax2 -= dEVBM
    for i in range(len(E2)):
        E2[i] -= dEVBM
    print(f"    Adjusted EF={EF2} eV")
    print(f"    Adjusted E range: {Edmin2:10.6g} - {Edmax2:10.6g} eV, {nE2} points")

    #=============================
    # Plot graphs
    #=============================
    print("")
    print("Plot")
    print(f"Threshold potential to exclude sites near the defect site {dVth}: {dVth:12.4f} eV")
    print(f"Averaged Vcore outside defects (dVav): {dVav:12.4f} eV")
    print(f"  dVav+-dVth: {dVav-dVth:10.4f} - {dVav+dVth:10.4f} eV")
    if mode == 'mode':
        print(f"Most frequent dVcore           (dVmf): {dVmf:12.4f} eV")
        print(f"  dVmf+-dVth: {dVmf-dVth:10.4f} - {dVmf+dVth:10.4f}eV")

    # Assign a colour index to each species of the defect model (first-seen order).
    idx = 0
    idxname2 = {}
    for i in range(len(name1list)):
        v = idxname2.get(name2list[i], None)
        if v is None:
            idxname2[name2list[i]] = idx
            idx += 1

    fig = plt.figure(figsize = figsize)
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)
    ax1b = ax1.twiny()
    ax2b = ax2.twiny()

    # Left panel: Vcore vs distance for both models (open markers = excluded sites).
    label_check = {}
    for i in range(len(rlist_all)):
        if excluded[i]:
            key = f"{name2list[i]}:excluded"
            key2 = 'ex'
            colorf = 'white'
            colore = colors[idxname2[name2list[i]] % len(colors)]
        else:
            key = f"{name2list[i]}:included"
            key2 = 'inc'
            colorf = colors[idxname2[name2list[i]] % len(colors)]
            colore = colors[idxname2[name2list[i]] % len(colors)]

        if label_check.get(key, None) is None:
            label_check[key] = 0
            label1_dic = {'label': f'{name1list[i]}(ideal)  {key2}'}
            label2_dic = {'label': f'{name2list[i]}(defect) {key2}'}
        else:
            label_check[key] = 1
            label1_dic = {}
            label2_dic = {}

        ax1.plot(rlist_all[i], V1list_all[i], linestyle = 'none',
                        marker = 'o', markeredgewidth = 1.0, markersize = 6.0, markerfacecolor = colorf, markeredgecolor = colore, **label1_dic)
        ax1.plot(rlist_all[i], Vcore2[i], linestyle = 'none',
                        marker = '^', markeredgewidth = 1.0, markersize = 6.0, markerfacecolor = colorf, markeredgecolor = colore, **label2_dic)

    ax1.set_xlabel("$r$ ($\\AA$)",    fontsize = fontsize)
    ax1.set_ylabel("$V_{core}$ (eV)", fontsize = fontsize)
    xlim = ax1.get_xlim()
    ax1.set_xlim(xlim)
    ax1.legend(fontsize = legend_fontsize)
    ax1.tick_params(labelsize = fontsize)
    ax1.set_title(CAR_path1[-35:] + '\n' + CAR_path2[-35:], fontsize = titlefontsize)

    # Right panel: dV vs distance for non-defect sites, plus reference lines.
    label_check = {}
    for i in range(len(rlist_all)):
        if siteinf[i]["Defect"] != "":
            continue

        if excluded[i]:
            key = f"{name2list[i]}:excluded"
            key2 = 'ex'
            colorf = 'white'
            colore = colors[idxname2[name2list[i]] % len(colors)]
        else:
            key = f"{name2list[i]}:included"
            key2 = 'inc'
            colorf = colors[idxname2[name2list[i]] % len(colors)]
            colore = colors[idxname2[name2list[i]] % len(colors)]

        if label_check.get(key, None) is None:
            label_check[key] = 0
            label2_dic = {'label': f'{name2list[i]} {key2}'}
        else:
            label_check[key] = 1
            label2_dic = {}

        ax2.plot(rlist_all[i], Vdifflist_all[i], linestyle = 'none',
                        marker = 'o', markeredgewidth = 1.0, markersize = 6.0, markerfacecolor = colorf, markeredgecolor = colore, **label2_dic)

    ylim = ax2.get_ylim()
    ylim_dV = [min([ylim[0], Emin]), max([ylim[1], Emax])]
    if mode == 'average':
        ax2.plot(xlim, [dVav, dVav],           label = '$dV_{av}$',           linestyle = 'dashed', linewidth = 0.5, color = 'red')
        ax2.plot(xlim, [dVav+dVth, dVav+dVth], label = '$dV_{av} + dV_{th}$', linestyle = 'dashed', linewidth = 0.5, color = 'blue')
        ax2.plot(xlim, [dVav-dVth, dVav-dVth], label = '$dV_{av} - dV_{th}$', linestyle = 'dashed', linewidth = 0.5, color = 'blue')
        ax2.plot([Rmin_potential, Rmin_potential], ylim_dV, label = '$R_{min}$', linestyle = 'dashed', linewidth = 0.5, color = 'red')
    elif mode == 'mode':
        ax2.plot(xlim, [dVav, dVav],           label = '$dV_{av}$',           linestyle = 'dashed', linewidth = 0.5, color = 'red')
        ax2.plot(xlim, [dVmf, dVmf],           label = '$dV_{mf}$',           linestyle = 'dashed', linewidth = 0.5, color = 'green')
        ax2.plot(xlim, [dVmf+dVth, dVmf+dVth], label = '$dV_{mf} + dV_{th}$', linestyle = 'dashed', linewidth = 0.5, color = 'blue')
        ax2.plot(xlim, [dVmf-dVth, dVmf-dVth], label = '$dV_{mf} - dV_{th}$', linestyle = 'dashed', linewidth = 0.5, color = 'blue')
        ax2.plot([Rmin_potential, Rmin_potential], ylim_dV, label = '$R_{min}$', linestyle = 'dashed', linewidth = 0.5, color = 'red')
    ax2.set_xlabel("$r$ ($\\AA$)",      fontsize = fontsize)
    ax2.set_ylabel("$\\Delta$$E$ (eV)", fontsize = fontsize)
    ax2.set_xlim(xlim)
    ax2.set_ylim(ylim_dV)
    ax2.legend(fontsize = legend_fontsize)
    ax2.tick_params(labelsize = fontsize)
    # Show the value actually adopted as dEVBM (dVmf in 'mode', dVav in 'average').
    ax2.set_title("$\\Delta$$E_{VBM}$ = %0.4g eV ($\\Delta$$V_{th}$ = %0.3g eV)" % (dEVBM, dVth), fontsize = titlefontsize)

    # Label the top axis ticks with the site names.  plt.xticks() only acts on
    # the current axes, so the attributes of ax1b are set via plt.setp().
    ax1b.tick_params(labelsize = labelfontsize)
    plt.setp(ax1b, xticks = rlist_all, xticklabels = name2list)
    ax1b.set_xticklabels(name2list, rotation = 90.0)
    ax1b.set_xlim(xlim)

    # Plot the smeared dV distribution on the twin x-axis of the right panel.
    ax2b.plot(ydist, xdist, linestyle = 'dashed', linewidth = 0.3, color = 'green')
    ax2b.set_xlabel("Distribution",      fontsize = fontsize)
    ax2b.set_ylabel("$\\Delta$$E$ (eV)", fontsize = fontsize)
    ax2b.tick_params(labelsize = fontsize)

    # Rearrange the graph axes so that they are not overlapped
    plt.tight_layout()

    plt.pause(0.1)

    #======================================================
    # DOS plot (both models on the same axes)
    #======================================================
    fig, axes = plt.subplots(1, 1, sharex = 'all', figsize = figsize)
    ax1 = axes
    ax2 = ax1

    tDOSup_s1 = convolution(E1, tDOSup1, WG_DOS, func_type = 'gauss')
    ax1.plot(E1, tDOSup_s1, label = 'up(ideal)', linestyle = '-', linewidth = 0.5, color = 'black')
    if ISPIN1 == 2:
        tDOSdn_s1 = convolution(E1, tDOSdn1, WG_DOS, func_type = 'gauss')
        ax1.plot(E1, tDOSdn_s1, label = 'dn(ideal)', linestyle = 'dashed', linewidth = 0.5, color = 'black')
    ylim = ax1.get_ylim()
    ax1.plot([EV, EV],       ylim, label = '$E_V$(ideal)',      linestyle = 'dashed', linewidth = 0.5, color = 'red')
    ax1.plot([EC, EC],       ylim, label = '$E_C$(ideal)',      linestyle = 'dashed', linewidth = 0.5, color = 'red')
    ax1.plot([EF1, EF1],     ylim, label = '$E_F$(ideal)',      linestyle = 'dashed', linewidth = 0.5, color = 'blue')
    ax1.plot([EHOMO, EHOMO], ylim, label = '$E_{HOMO}$(ideal)', linestyle = 'dashed', linewidth = 0.5, color = 'green')
    ax1.plot([ELUMO, ELUMO], ylim, label = '$E_{LUMO}$(ideal)', linestyle = 'dashed', linewidth = 0.5, color = 'green')
    ax1.set_ylabel("DOS (states/cm$^3$/eV)", fontsize = fontsize)
    ax1.tick_params(labelsize = fontsize)

    tDOSup_s2 = convolution(E2, tDOSup2, WG_DOS, func_type = 'gauss')
    ax2.plot(E2, tDOSup_s2, label = 'up(defect)', linestyle = '-', linewidth = 1.0, color = 'red')
    if ISPIN2 == 2:
        tDOSdn_s2 = convolution(E2, tDOSdn2, WG_DOS, func_type = 'gauss')
        ax2.plot(E2, tDOSdn_s2, label = 'dn(defect)', linestyle = 'dashed', linewidth = 1.0, color = 'red')
    ax2.plot([EF2, EF2],     ylim, label = '$E_F$(defect)',      linestyle = 'dashed', linewidth = 1.0, color = 'blue')

    ax1.legend(fontsize = legend_fontsize)

    use_pause = True
    if use_pause:
        plt.pause(0.1)
        print("")
        app.terminate(f"", pause = True)
    else:
        plt.show()
        print("")
        print("Close graph window to terminate")


def main():
    """Script entry point: read arguments, open the log, and run VBM_correction().

    The log file ('VBM_correction-out.txt') and the summary file
    ('VBM_correction-summary.prm') are placed in the directory that
    tkVASP.getdir() returns for CAR_dir2 (the defect calculation).
    """
    global summaryfile

    updatevars()

    out_dir = tkVASP().getdir(CAR_dir2)
    logfile = os.path.join(out_dir, 'VBM_correction-out.txt')
    summaryfile = os.path.join(out_dir, 'VBM_correction-summary.prm')

    print("")
    print(f"Open logfile [{logfile}]")
    # Presumably tees all subsequent output to both stdout and the log file
    # (see tkApplication.redirect).
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    banner = "=============== Estimate VBM correction for defect model calculated by VASP ============"
    print("")
    print(banner)
    print("")
    print("mode: ", mode)

    if mode not in ('average', 'mode'):
        app.terminate(f"Error: Invalide mode [{mode}]", usage = usage, pause = True)
        return

    VBM_correction(mode, CAR_dir1, CAR_dir2, Emin, Emax, Rmin_vacancy, Rmin_potential)


if __name__ == "__main__":
    main()
