import xml.etree.ElementTree as ET
import numpy as np
from pymatgen.io.vasp.outputs import Vasprun
import csv

# Default number of k-points per line segment of the KPOINTS_OPT path, passed
# as ``Nline`` by the ``__main__`` entry point. KPOINTS_OPT_Parser.get_kpath
# treats every Nline-th k-point as the start of a new segment, so this should
# match the per-segment point count used when generating KPOINTS_OPT.
Nline = 20 # number of kpoints in a line

class KPOINTS_OPT_Parser:
    """Read KPOINTS_OPT band and DOS data from a VASP ``vasprun.xml`` file.

    The file is parsed twice on construction:

    - with pymatgen's ``Vasprun``, which this class only uses for the
      reciprocal lattice of the final structure;
    - with ``xml.etree.ElementTree``, for the ``<eigenvalues_kpoints_opt>``
      and ``<dos comment="kpoints_opt">`` sections.

    Attributes:
        Nline: number of k-points per line segment of the k-path.
        file_path: path to ``vasprun.xml``.
        vasprun: pymatgen ``Vasprun`` object.
        xml_root: root element of the ElementTree parse.
        fermi_level: Fermi energy read from the KPOINTS_OPT DOS block.
        kpoints: list of k-point coordinate arrays. They are multiplied by
            ``reciprocal_lattice`` in ``get_kpath``, i.e. treated as
            fractional coordinates.
        weights_list: k-point weights (floats).
        reciprocal_lattice: pymatgen reciprocal lattice matrix of the final
            structure (pymatgen includes the 2*pi factor).
        bands_data: list of ``{'spin', 'kpoint', 'eigenvalue'}`` dicts;
            eigenvalues are shifted so the Fermi level is 0.
        kpath: cumulative distance along the k-path, one entry per k-point.
    """

    def __init__(self, file_path, Nline=20):
        """Parse *file_path* and precompute k-points, eigenvalues and k-path.

        Args:
            file_path: path to a ``vasprun.xml`` from a run with KPOINTS_OPT.
            Nline: number of k-points in each line segment (must be >= 1).

        Raises:
            ValueError: if *Nline* < 1 or a required KPOINTS_OPT section is
                missing from the XML. Errors raised by pymatgen's ``Vasprun``
                propagate unchanged.
        """
        if Nline < 1:
            # get_kpath uses ``i % Nline``; 0 would raise ZeroDivisionError there.
            raise ValueError(f"Nline must be >= 1, got {Nline!r}")
        self.Nline = Nline
        self.file_path = file_path
        self.vasprun, self.xml_root = self.parse_vasprun()
        self.fermi_level = self.get_fermi_level()
        self.kpoints = self.get_kpoints()
        self.weights_list = self.get_weights()
        self.reciprocal_lattice = self.vasprun.final_structure.lattice.reciprocal_lattice.matrix
        self.bands_data = self.get_eigenvalues()
        self.kpath = self.get_kpath()

    def parse_vasprun(self):
        """Parse ``self.file_path`` with pymatgen and with ElementTree.

        Returns:
            tuple: ``(vasprun, xml_root)``.
        """
        vasprun = Vasprun(self.file_path)
        xml_root = ET.parse(self.file_path).getroot()
        return vasprun, xml_root

    @staticmethod
    def _require(parent, path, what):
        """Return ``parent.find(path)``, raising ValueError if nothing matches."""
        element = parent.find(path)
        if element is None:
            raise ValueError(
                f"{what} not found in vasprun.xml (XPath {path!r}); "
                "was the calculation run with a KPOINTS_OPT file?"
            )
        return element

    def _kpoints_opt_root(self):
        """Return the ``<eigenvalues_kpoints_opt>`` element."""
        return self._require(self.xml_root, ".//eigenvalues_kpoints_opt",
                             "<eigenvalues_kpoints_opt>")

    def _kpoints_opt_dos_root(self):
        """Return the ``<dos comment='kpoints_opt'>`` element."""
        return self._require(self.xml_root, ".//dos[@comment='kpoints_opt']",
                             "KPOINTS_OPT <dos> block")

    def get_fermi_level(self):
        """Return the ``efermi`` value from the KPOINTS_OPT DOS block as a float."""
        efermi = self._require(self._kpoints_opt_dos_root(), ".//i[@name='efermi']",
                               "KPOINTS_OPT efermi")
        return float(efermi.text.strip())

    def get_kpoints(self):
        """Return the KPOINTS_OPT k-point list as a list of ``np.ndarray``."""
        kpointlist = self._require(self._kpoints_opt_root(),
                                   ".//varray[@name='kpointlist']", "KPOINTS_OPT kpointlist")
        return [np.array([float(x) for x in v.text.split()]) for v in kpointlist.findall('v')]

    def get_weights(self):
        """Return the KPOINTS_OPT k-point weights as a list of floats."""
        weights = self._require(self._kpoints_opt_root(),
                                ".//varray[@name='weights']", "KPOINTS_OPT weights")
        return [float(v.text) for v in weights.findall('v')]

    def get_eigenvalues(self):
        """Return KPOINTS_OPT eigenvalues relative to ``self.fermi_level``.

        Returns:
            list[dict]: one dict per (spin, k-point, band) with keys
            ``'spin'`` (e.g. ``"spin 1"``), ``'kpoint'`` (e.g. ``"kpoint 1"``)
            and ``'eigenvalue'`` (energy minus the Fermi level), in document
            order.
        """
        fermi_level = self.fermi_level
        eigenvalues = self._require(self._kpoints_opt_root(), ".//eigenvalues",
                                    "KPOINTS_OPT eigenvalues")

        # Rows can hold several columns, as declared by the array's <field>
        # tags (e.g. "eigene" followed by "occ"). Pick the eigenvalue column,
        # falling back to the first column when no "eigene" field is declared.
        fields = [f.text.strip() for f in eigenvalues.iter('field') if f.text]
        column = fields.index('eigene') if 'eigene' in fields else 0

        bands_data = []
        for spin_set in eigenvalues.findall('.//set[@comment]'):
            spin = spin_set.get('comment')
            # ".//" also matches the nested "kpoint N" sets; keep spin channels only.
            if not spin.startswith('spin'):
                continue
            for kpoint_set in spin_set.findall('set[@comment]'):
                kpoint = kpoint_set.get('comment')
                for r in kpoint_set.findall('r'):
                    energy = float(r.text.split()[column])
                    bands_data.append({
                        'spin': spin,
                        'kpoint': kpoint,
                        'eigenvalue': energy - fermi_level,
                    })
        return bands_data

    def kpoints_opt_dos(self):
        """Return the ``<total>/<array>`` element of the KPOINTS_OPT DOS block."""
        return self._require(self._kpoints_opt_dos_root(), ".//total/array",
                             "KPOINTS_OPT total DOS array")

    def dos_spin(self, dos_data):
        """Split a total-DOS array element into per-spin records.

        Args:
            dos_data: element returned by ``kpoints_opt_dos``.

        Returns:
            dict[str, list[dict]]: maps the spin number (``"spin 1"`` -> ``"1"``)
            to records with keys ``'energy'`` (relative to the Fermi level),
            ``'total'`` and ``'integrated'``.
        """
        fermi_level = self.fermi_level
        dos_spin = {}
        for spin_set in dos_data.findall(".//set[@comment]"):
            spin = spin_set.get('comment').split()[1]  # "spin 1" -> "1"
            records = dos_spin[spin] = []
            for r in spin_set.findall('r'):
                energy, total, integrated = [float(x) for x in r.text.split()]
                records.append({
                    'energy': energy - fermi_level,  # put the Fermi level at 0
                    'total': total,
                    'integrated': integrated,
                })
        return dos_spin

    def get_kpath(self):
        """Return the cumulative k-path distance for every k-point.

        The step between consecutive k-points is the norm of their coordinate
        difference multiplied by ``self.reciprocal_lattice``. At every
        ``Nline``-th index (the first point of a new segment) no distance is
        added, so consecutive segments join even if the path jumps.

        Returns:
            list: one distance per k-point (empty when there are no k-points).
        """
        reciprocal_lattice = self.reciprocal_lattice
        kpoints = self.kpoints
        Nline = self.Nline

        if not kpoints:
            return []
        kpath = [0]
        for i in range(1, len(kpoints)):
            if i % Nline == 0:
                kpath.append(kpath[-1])
            else:
                dk = np.linalg.norm(np.dot(kpoints[i] - kpoints[i - 1], reciprocal_lattice))
                kpath.append(kpath[-1] + dk)
        return kpath

    def eigenvalues_per_kpoint(self):
        """Group ``self.bands_data`` by (spin, zero-based k-point index).

        Returns:
            dict[tuple[str, int], list[float]]: eigenvalues (Fermi level at 0)
            in band order for each ``(spin comment, kpoint_index)`` key.
        """
        eigenvalues_per_kpoint = {}
        for band in self.bands_data:
            kpoint_index = int(band['kpoint'].split()[1]) - 1  # "kpoint 1" -> 0
            key = (band['spin'], kpoint_index)
            eigenvalues_per_kpoint.setdefault(key, []).append(band['eigenvalue'])
        return eigenvalues_per_kpoint
    
def write_kptopt_eig_csv(file_path, Nline=20, output_path='kpoints_opt_eigenvalues.csv'):
    """Write KPOINTS_OPT eigenvalues (Fermi level at 0) to a CSV file.

    One row is written per (spin, k-point), sorted by spin comment and then
    k-point index. The columns are ``spin``, ``kpoint_index``, ``kpath``,
    ``weight``, ``kx``, ``ky`` and ``kz``, followed by ``band_1`` ...
    ``band_N``, where N is the largest band count found.

    Args:
        file_path: path to ``vasprun.xml``.
        Nline: number of k-points per line segment, passed to the parser.
        output_path: destination CSV file (default keeps the historical name).
    """
    parser = KPOINTS_OPT_Parser(file_path, Nline)

    kpoints = parser.kpoints
    weights_list = parser.weights_list
    kpath = parser.kpath

    # Collect the eigenvalues of each (spin, k-point index) pair.
    eigenvalues_dict = parser.eigenvalues_per_kpoint()

    # Largest band count; default=0 keeps an empty result from crashing max().
    max_bands = max((len(values) for values in eigenvalues_dict.values()), default=0)
    header = (['spin', 'kpoint_index', 'kpath', 'weight', 'kx', 'ky', 'kz']
              + [f'band_{i+1}' for i in range(max_bands)])

    with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)

        # One row per k-point, sorted by spin and then k-point index.
        for spin, kpoint_index in sorted(eigenvalues_dict):
            row = [spin, kpoint_index, kpath[kpoint_index], weights_list[kpoint_index],
                   *kpoints[kpoint_index]]
            writer.writerow(row + eigenvalues_dict[(spin, kpoint_index)])

def write_kptopt_dos_csv(file_path, Nline=20, output_path='kpoints_opt_dos.csv'):
    """Write the KPOINTS_OPT total DOS (Fermi level at 0) to a CSV file.

    The columns are ``spin`` (spin number, e.g. ``"1"``), ``energy``,
    ``total`` and ``integrated``, with one row per DOS grid point per spin.

    Args:
        file_path: path to ``vasprun.xml``.
        Nline: number of k-points per line segment, passed to the parser.
        output_path: destination CSV file (default keeps the historical name).
    """
    # Build the parser (this reads vasprun.xml).
    parser = KPOINTS_OPT_Parser(file_path, Nline)

    # Extract the KPOINTS_OPT total DOS, split by spin.
    dos_spin = parser.dos_spin(parser.kpoints_opt_dos())

    with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['spin', 'energy', 'total', 'integrated'])
        writer.writerows(
            [spin, record['energy'], record['total'], record['integrated']]
            for spin, records in dos_spin.items()
            for record in records
        )

def write_high_symmetry_kpoints_opt(file_path, Nline=20, output_path='kpath_origin.csv'):
    """Write vertical-line data marking the k-path segment ends to a CSV file.

    For every segment end (k-point indices ``Nline-1``, ``2*Nline-1``, ...)
    two rows are written at that kpath value. The eigenvalue column cycles
    through (min, max, max, min), so that each pair spans the padded
    eigenvalue range. This is presumably for drawing high-symmetry lines in
    a plotting program such as Origin, judging by the default file name
    (confirm).

    Args:
        file_path: path to ``vasprun.xml``.
        Nline: number of k-points per line segment.
        output_path: destination CSV file (default keeps the historical name).

    Raises:
        ValueError: if no KPOINTS_OPT eigenvalues are found.
    """
    parser = KPOINTS_OPT_Parser(file_path, Nline)

    kpath = parser.kpath
    eigenvalues = [band['eigenvalue'] for band in parser.bands_data]
    if not eigenvalues:
        raise ValueError(f"no KPOINTS_OPT eigenvalues found in {file_path!r}")

    # Pad the eigenvalue range outward by 10%. For the usual case
    # (lowest <= 0 <= highest) this is exactly the historical ``* 1.1``. When
    # a bound has the other sign, ``* 1.1`` would move it inward, so use
    # ``* 0.9`` there instead.
    lowest, highest = min(eigenvalues), max(eigenvalues)
    min_eigenvalue = lowest * 1.1 if lowest <= 0 else lowest * 0.9
    max_eigenvalue = highest * 1.1 if highest >= 0 else highest * 0.9

    # Each segment-end kpath value appears twice (start and end of a vertical line).
    kpath_column = [kpath[i] for i in range(Nline - 1, len(kpath), Nline) for _ in range(2)]

    # Alternate (min, max, max, min) so consecutive lines are drawn without
    # jumping back to the same end.
    pattern = [min_eigenvalue, max_eigenvalue, max_eigenvalue, min_eigenvalue]
    eigenvalue_column = [pattern[i % 4] for i in range(len(kpath_column))]

    with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['kpath', 'eigenvalue'])
        writer.writerows(zip(kpath_column, eigenvalue_column))


if __name__ == '__main__':
    # Each exporter builds its own parser from vasprun.xml and writes one CSV.
    for export in (write_kptopt_eig_csv, write_kptopt_dos_csv, write_high_symmetry_kpoints_opt):
        export("vasprun.xml", Nline)
    
