import os
import sys
import xml.etree.ElementTree as ET
import xml.dom.minidom as md
import numpy as np

import xmltodict
import dicttoxml


# Input XML path: 'opt_2onl.xml' unless a first command-line argument is
# given, which then takes its place.
nargs = len(sys.argv)
infile = sys.argv[1] if nargs > 1 else 'opt_2onl.xml'


def get_elements_recursive(element):
    '''
    Collect information about the descendants of an element.

    input: element: xml.etree.ElementTree.Element instance
    return: if element has no children: element.text (a string, or None when
                the element has no text)
            else: list with one dict per child holding its "tag", "text" and
                "attrib"; a child that has children of its own additionally
                carries their list, built the same way, under "child"
    '''

    # Leaf element: only its text can be reported.
    if len(element) == 0:
        return element.text

    inf = []
    for child_element in element:
        child_inf = {
            "tag": child_element.tag,
            "text": child_element.text,
            "attrib": child_element.attrib,
        }
        # Decide on the presence of grandchildren, not on the text: the text
        # of a parent is None when the XML has no whitespace between the tags,
        # and a leaf may have no text at all.
        if len(child_element) > 0:
            child_inf["child"] = get_elements_recursive(child_element)
        inf.append(child_inf)

    return inf

def get_unique_key(d, section, val = None):
    '''
    Build a key from section (and val) that is not yet in use in d.

    input: d: dict the key is meant for
           section: base name of the key
           val: optional suffix, appended as ":val"
    return: section (or "section:val") if that key is free in d,
            else the first free one of "section[0]", "section[1]", ...
            (with ":val" appended when val is given).
            A key counts as free if d lacks it or maps it to None.
    '''
    if val is None:
        suffix = ""
        candidate = section
    else:
        suffix = f":{val}"
        candidate = f"{section}{suffix}"

    # Try the plain key first, then the numbered variants in ascending order.
    index = 0
    while d.get(candidate) is not None:
        candidate = f"{section}[{index}]{suffix}"
        index += 1

    return candidate

def get_section_inf(parent, sections = None, inf = None, params = None, ret_type = 'list', add_parent_params = True):
    '''
    Follow a chain of element tags and collect the elements at its end.

    input: parent: element instance the search starts from
           sections: list of element tags; all but the last are followed with
                     find() (first match), the last one with findall()
           inf: list/dict to start from; it is copied, not modified
           params: dict to start the parent attributes from; it is copied
           ret_type: return variable type ['list'|anything else means dict]
           add_parent_params: if True, the attributes of the elements passed
                     on the way are added to the result (appended to a list,
                     stored under "params" in a dict)
    return: dict/list of information about the elements found for the last tag;
            an element with attributes contributes a dict of its attributes
            (in a dict result under the value of its first attribute),
            one without attributes contributes the result of
            get_elements_recursive() (in a dict result under "list").
            A copy of inf (or an empty list/dict) if sections is empty;
            None if one of the leading tags is not found.
    '''
    as_list = (ret_type == 'list')

    if inf is not None:
        collected = inf.copy()
    elif as_list:
        collected = []
    else:
        collected = {}

    if not sections:
        return collected

    path_attribs = {} if params is None else params.copy()

    # Descend through every tag except the last, remembering the attributes
    # of each element passed on the way.
    node = parent
    for tag in sections[:-1]:
        node = node.find(tag)
        if node is None:
            return None

        for attr_name, attr_val in node.items():
            path_attribs[get_unique_key(path_attribs, tag, attr_name)] = attr_val

    for leaf in node.findall(sections[-1]):
        attribs = dict(leaf.items())
        if attribs:
            # The value of the first attribute names the entry in a dict result.
            entry_key = next(iter(attribs.values()))
            entry = attribs
        else:
            entry_key = "list"
            entry = get_elements_recursive(leaf)

        if as_list:
            collected.append(entry)
        else:
            collected[entry_key] = entry

    if add_parent_params:
        if as_list:
            collected.append(path_attribs)
        else:
            collected["params"] = path_attribs

    return collected

def get_section_inf_all(parent, sections = None, section_parent = 'root', inf = None,
                level = 0, params = None, pkey = None, ret_type = 'list', last_node_only = True):
    '''
    Walk the tree along a chain of element tags and collect every match.

    input: parent: element instance the search starts from
           sections: list of element tags/paths; sections[0] is searched below
                     parent with findall(), sections[1] below each match, ...
           section_parent: prefix of the keys under which the attributes of
                     parent are stored in "params"
           inf: list/dict the results are added to (modified in place);
                a new one is created if None
           level: position of the current tag in the chain (0 for the first call)
           params: attributes collected from the elements above parent
           pkey: (dict result only) name of the attribute whose value becomes
                 the key of an entry; the first attribute is used if None
           ret_type: return variable type ['list'|anything else means dict]
           last_node_only: if True, only the elements matching the last entry
                     of sections are stored; else every element on the way
    return: list/dict of dicts with the keys "level", "tag", "text", "attrib"
            and "params" (attributes of the elements above the entry)
    '''

    # Every value other than 'list' selects a dict, both for creating the
    # container and for storing into it, so an unknown ret_type cannot leave
    # the entry key undefined.
    as_list = (ret_type == 'list')

    if inf is None:
        inf = [] if as_list else {}

    if sections is None or len(sections) == 0:
        return inf

    params_child = {} if params is None else params.copy()

    for _key, _val in parent.items():
        key = get_unique_key(params_child, section_parent, _key)
        params_child[key] = _val

    store_here = not last_node_only or len(sections) == 1

    for section in parent.findall(sections[0]):
        # All entries found below the same parent share one params dict.
        _inf = {"level": level, "tag": section.tag, "text": section.text,
                "attrib": section.attrib, "params": params_child}

        if store_here:
            if as_list:
                inf.append(_inf)
            else:
                keys = list(section.keys())
                if len(keys) == 0:
                    # Elements without attributes all share this key, so a
                    # later one replaces an earlier one.
                    _pkey = 'list'
                else:
                    attr_name = keys[0] if pkey is None else pkey
                    # NOTE(review): if the element lacks the pkey attribute,
                    # the fallback is the *name* of its first attribute, not
                    # its value -- confirm that this is intended.
                    _pkey = get_unique_key(inf, section.get(attr_name, keys[0]), None)
                inf[_pkey] = _inf

        # The recursion adds to the same container; with no tags left it
        # returns immediately.
        get_section_inf_all(section, section_parent = sections[0], sections = sections[1:],
                        level = level + 1, inf = inf, params = params_child,
                        pkey = pkey, ret_type = ret_type, last_node_only = last_node_only)

    return inf

def xml2dict(xml):
    '''
    Parse XML with xmltodict.

    input: xml: XML document, handed unchanged to xmltodict.parse()
    return: the result of xmltodict.parse() for the input
    '''
    parsed = xmltodict.parse(xml)
    return parsed

def _element_to_dict(element):
    """
    Recursively convert an xml.etree.ElementTree Element into a dict.

    Attributes are stored under the '@attributes' key and the stripped text
    content under the '#text' key; either key is omitted when there is nothing
    to store. Each child is stored under its tag; when several children share
    a tag, their dicts are collected in a list.
    """
    result = {}

    # Attributes of the element.
    if element.attrib:
        # Copy, so that editing the result never alters the element itself.
        result['@attributes'] = dict(element.attrib)

    # Text content of the element.
    text = (element.text or '').strip()
    if text:
        result['#text'] = text

    # Child elements.
    for child in element:
        child_dict = _element_to_dict(child)
        if child.tag not in result:
            # First child with this tag.
            result[child.tag] = child_dict
        elif isinstance(result[child.tag], list):
            # Third or later child with this tag.
            result[child.tag].append(child_dict)
        else:
            # Second child with this tag: switch to a list.
            result[child.tag] = [result[child.tag], child_dict]
    return result

def file2dict(infile):
    '''
    Read an XML file and convert its content with xml2dict().

    input: infile: path of the XML file to read
    return: dict converted from the XML read from the input file
    '''
    with open(infile) as xml_file:
        return xml2dict(xml_file.read())

def dict2xml(d, attr_type = True, root = True):
    '''
    Convert a dict to XML with dicttoxml.

    input: d: dict to convert
           attr_type, root: forwarded unchanged to dicttoxml.dicttoxml()
    return: the XML produced by dicttoxml.dicttoxml() for the input dict
    '''
    options = {"attr_type": attr_type, "root": root}
    return dicttoxml.dicttoxml(d, **options)

def to_xml(outfile, element, encoding = 'utf-8', newl = '', indent = '', addindent = '    ',
            xml_declaration = True, use_minidom = False):
    '''
    Write an element tree to an XML file.

    input: outfile: path to write
           element: root element to write
           encoding: encoding of the output file, also named in the declaration
           use_minidom: reformat xml
              False: written by ElementTree.write() as is
              True: written by minidom, reformatted with
                    newl: new line characters
                    indent: indent characters
                    addindent: indent characters
           xml_declaration: flag to add XML document declaration
    return: null
    '''
    if not use_minidom:
        tree = ET.ElementTree(element)
        tree.write(outfile, encoding = encoding, xml_declaration = xml_declaration)
        return

    doc = md.parseString(ET.tostring(element, encoding = encoding))

    # The file has to be encoded as the declaration says. 'unicode' is a
    # pseudo-encoding of ElementTree and no codec, so the default encoding
    # of open() is kept for it (and for None).
    if encoding is None or encoding.lower() == 'unicode':
        file_encoding = None
    else:
        file_encoding = encoding

    # Characters the encoding cannot represent are written as character
    # references instead of aborting with a half-written file.
    with open(outfile, 'w', encoding = file_encoding, errors = 'xmlcharrefreplace') as fp:
        if xml_declaration:
            doc.writexml(fp, encoding = encoding, newl = newl, indent = indent, addindent = addindent)
        else:
            # Document.writexml() always emits the declaration; writing the
            # root element alone leaves it out.
            doc.documentElement.writexml(fp, indent = indent, addindent = addindent, newl = newl)

def get_attrib(element):
    '''
    Return the basic data of an element.

    input: element: element instance
    return: tuple (tag, text, attrib) of the input element; attrib is the
            attribute dict of the element itself, not a copy
    '''
    tag = element.tag
    text = element.text
    attrib = element.attrib
    return tag, text, attrib
    
def get_root(xml):
    '''
    Get the root element from a file or from XML text.

    input: xml: path of an existing file, or XML text
    return: XML root element; the input is parsed as a file if
            os.path.isfile() accepts it, else as XML text
    '''
    if not os.path.isfile(xml):
        return ET.fromstring(xml)

    return ET.parse(xml).getroot()


def main():
    '''
    Print the information found in the XML file named by the global infile.

    A file containing a 'parameter-list' element is handled as ATLAS
    optimization XML: its parameters, settings and targets are printed.
    A file containing an 'incar' element is handled as vasprun.xml: its
    incar entries are printed.
    Any other file is reported as an error and the script ends with exit
    status 1.
    '''
    print()
    print(f"infile: {infile}")

    xml_root = get_root(infile)

    # Compare with None: an Element without children is falsy, so a plain
    # truth test would not recognise e.g. an empty <parameter-list/>.
    if xml_root.find('.//parameter-list') is not None:
        print()
        print("ATLAS optimization XML:")
        for label, sections in (
                ("parameter_inf", ['parameter-list', 'parameter']),
                ("setting_inf", ['settings', 'setting']),
                ("target_inf", ['target-list', 'target'])):
            section_inf = get_section_inf_all(xml_root, sections, ret_type = 'dict')
            print(f"{label}:")
            for key, val in section_inf.items():
                print(f"  {key}: {val}")
    elif xml_root.find('.//incar') is not None:
        print()
        print("vasprun.xml:")
        # NOTE(review): band_inf is collected, but only its header is printed;
        # the lines printed below it are the incar entries.
        band_inf = get_section_inf_all(xml_root, [".//projected_kpoints_opt", ".//eigenvalues", './/set[@comment]', './/set[@comment]', './/r'])
        print("band_inf:")
        incar_inf = get_section_inf_all(xml_root, [".//incar", "*"], pkey = 'name', ret_type = 'dict', last_node_only = False)
        for key, val in incar_inf.items():
            print(f"  {key}: {val}")
    else:
        print()
        print("Error: Invalid XML type")
        # sys.exit() instead of the interactive helper exit(), with a non-zero
        # status so that the failure is visible to the caller.
        sys.exit(1)


if __name__ == '__main__':
    main()
    