import os
import sys
import glob
import numpy as np

from tklib.tkapplication import tkApplication


# Announce the import on the console: an empty line followed by the banner text.
print("\nXRD_GUI_lib loaded")

def error_message(message):
    """
    Print *message* to stdout, framed by '#' banner lines and blank lines.

    Args:
        message: text (or any object formatted with str()) describing the error.
    """
    banner = "#############################################"
    for text in ("", banner, f"Error in XRD_GUI_lib: {message}", banner, ""):
        print(text)

# The filter-plugin directory is located under $tkprog_X_path. Without the
# variable the plugins cannot be found, so report it and stop the process.
root_dir = os.getenv('tkprog_X_path', None)
if root_dir is None:
    error_message("Environment variable tkprog_X_path must be specified")
    # Pause so the message stays readable (presumably for console windows
    # that close as soon as the process ends).
    input("Press ENTER to terminate>>\n")
    # sys.exit() works even where the site-provided exit() builtin is absent
    # (python -S, frozen executables); a non-zero status signals the failure.
    sys.exit(1)

# Directory scanned for file-format filter plugins.
filter_dir = os.path.join(root_dir, "xrd", "filter")

app = tkApplication()
cparams = app.get_params()

# Default parameters passed to every plugin call in this module. They are
# applied in insertion order, exactly like individual attribute assignments.
# xmin/xmax/xstep are the 2-theta range, which set_two_theta_range() can change later.
_DEFAULT_PARAMS = {
    "debug": 0,
    "findvalidstructure": True,
    "plugin_dir": 'filter',
    "mode": 'plot',
    "infile": '*.txt',
    "cif_files": 'data/*.*',
    "beam": 'X-ray',
    "wavelength": "CuKa",
    "xmin": 20.0,
    "xmax": 120.0,
    "xstep": 0.02,
    "fwhm": 0.2,
    "Gfraction": 0.5,
    "fwhm_smear": 0.0,
    "Gfraction_smear": 0.0,
    "yscale": 'linear',
    "BGorder": 3,
    "alpha": 0.1,
}
for _param_name, _param_value in _DEFAULT_PARAMS.items():
    setattr(cparams, _param_name, _param_value)
del _param_name, _param_value


# Load the filter plugins found in filter_dir and report what each one accepts.
# NOTE(review): the selection rule of app.load_modules(..., target="read_data")
# lives in tklib.tkapplication. Presumably only *.py files that define
# read_data are returned; confirm there.
# module_names[i] is used as the display name of modules[i] by parse_xrd()
# and parse_reference().
print(f"Load modules from {filter_dir}")
module_names, modules = app.load_modules(filter_dir, "*.py", target = "read_data", is_print = True)
# Runs at import time. The loop variables (m, input_type, output_type) remain
# as module-level globals afterwards.
for m in modules:
    input_type  = m.get_input_type(app = app, cparams = cparams)
    output_type = m.get_output_type(app = app, cparams = cparams)
    print(f"  {m.name}: input_type={input_type}  output_type={output_type}")


def set_two_theta_range(xmin=None, xmax=None, xstep=None):
    """
    Update the module-level 2-theta range settings (called from the GUI side).

    Each argument that is not None is converted with float() and stored on
    the shared ``cparams`` (xmin, xmax, xstep, in that order). Arguments left
    as None keep their current value.
    """
    for attr_name, new_value in (("xmin", xmin), ("xmax", xmax), ("xstep", xstep)):
        if new_value is None:
            continue
        setattr(cparams, attr_name, float(new_value))

def set_wavelength(wavelength=None):
    """
    Store *wavelength* (converted with str()) on the shared ``cparams``.

    Passing None leaves the current wavelength setting unchanged.
    """
    if wavelength is None:
        return
    cparams.wavelength = str(wavelength)

def parse_xrd(path):
    """
    Read a measured XRD pattern using the first filter plugin that accepts *path*.

    Plugins in ``modules`` are tried in order. The first whose
    ``check_file_type`` result is not None and does not contain 'Error' is
    used to ``read_data`` and ``convert`` the file.

    Args:
        path: file to read.

    Returns:
        (sample_name, x, y): ``sample_name`` from the converted data, and x / y
        as numpy arrays built from the first two columns of ``data_list``.

    Raises:
        ValueError: if no plugin accepts the file, or the chosen plugin's
            ``read_data`` returns a falsy result.
    """
    print(f"Read {path} in XRD_GUI_lib.parse_xrd() using filters in {filter_dir}")

    module = None
    for name, m in zip(module_names, modules):
        file_type = m.check_file_type(path, app=app, cparams=cparams)
        print(f"try [{name}] for [{path}]: file_type={file_type}")
        if file_type is not None and 'Error' not in file_type:
            print("   type matched.")
            module = m
            break

    if module is None:
        raise ValueError(f"Failed to find modules in {filter_dir}")

    inf = module.read_data(path, app=app, cparams=cparams)
    if not inf:
        # The f prefix was missing, so the message showed literal placeholders.
        raise ValueError(f"Failed to read {path} by modules in {filter_dir}")

    inf_input = module.convert(inf, cparams=cparams)

    # If the first element of data_list[0] is a scalar, data_list[0] is a
    # single column, so data_list itself is the list of columns. Otherwise
    # data_list[0] is taken as the first dataset's list of columns.
    # numpy scalar types are accepted too (type(x) is float misses np.float64).
    data_list = inf_input["data_list"][0]
    if isinstance(data_list[0], (int, float, np.integer, np.floating)):
        data_list = inf_input["data_list"]

    sample_name = inf_input["sample_name"]
    xQ2_infile = data_list[0]
    yobs_infile = data_list[1]
    return sample_name, np.array(xQ2_infile), np.array(yobs_infile)

def _to_hkl_str_and_raw(hkl):
    """
    pymatgenのhkl(tuple) → "(h k l)" 文字列 と raw_indices(dict) に変換
    3指数: (h,k,l)
    4指数: (h,k,i,l) の可能性もあり
    """
    try:
        if len(hkl) == 3:
            h, k, l = int(hkl[0]), int(hkl[1]), int(hkl[2])
            return f"({h} {k} {l})", {"h": h, "k": k, "l": l, "i": -(h + k)}
        elif len(hkl) == 4:
            h, k, i_val, l = int(hkl[0]), int(hkl[1]), int(hkl[2]), int(hkl[3])
            return f"({h} {k} {i_val} {l})", {"h": h, "k": k, "l": l, "i": i_val}
    except:
        pass
    return str(hkl), {"h": 0, "k": 0, "l": 0, "i": 0}


def parse_reference(path, xmin=None, xmax=None, xstep=None, wavelength=None, normalize=True):
    """
    Build a reference diffraction-peak list (Ref) from a CIF or similar file, for the GUI.

    Any of xmin/xmax/xstep/wavelength that is not None is first written into
    the module-level ``cparams`` via set_two_theta_range / set_wavelength.
    These settings persist for later calls.

    Args:
        path: file to read. The first plugin whose ``check_file_type``
            result is not None and does not contain 'Error' is used.
        xmin, xmax, xstep: optional 2-theta range overrides.
        wavelength: optional wavelength override.
        normalize: if True and there are intensities, each one is divided
            by the maximum (all 0.0 if the maximum is not positive).

    Returns:
        (reference_name, positions, intensities, hkls, raw_indices):
        reference_name is the file name without directory or extension;
        positions / intensities come from ``diffractions["Q2"]`` /
        ``diffractions["intensity"]`` of the converted data; hkls are index
        labels and raw_indices dicts, both produced by _to_hkl_str_and_raw().

    Raises:
        ValueError: if no plugin accepts the file, or ``read_data`` returns
            a falsy result.
    """
    # Apply any range / wavelength requested by the GUI before reading.
    if xmin is not None or xmax is not None or xstep is not None:
        set_two_theta_range(xmin=xmin, xmax=xmax, xstep=xstep)
    if wavelength is not None:
        set_wavelength(wavelength=wavelength)

    print(f"Read {path} in XRD_GUI_lib.parse_reference() using filters in {filter_dir}")

    module = None
    for name, m in zip(module_names, modules):
        file_type = m.check_file_type(path, app=app, cparams=cparams)
        print(f"try [{name}] for [{path}]: file_type={file_type}")
        if file_type is not None and 'Error' not in file_type:
            print("   type matched.")
            module = m
            break

    if module is None:
        raise ValueError(f"Failed to find modules in {filter_dir}")

    inf = module.read_data(path, app=app, cparams=cparams)
    if not inf:
        raise ValueError(f"Failed to read {path} by modules in {filter_dir}")

    inf_input = module.convert(inf, cparams=cparams)

    # The peak list is stored under "diffractions"; treat a missing or None entry as empty.
    diff = inf_input.get("diffractions")
    if diff is None:
        diff = {}
    positions = diff.get("Q2", [])
    intensities = diff.get("intensity", [])
    hkls_raw = diff.get("hkl", [])

    # Convert each hkl into a display label plus a raw index dict.
    converted = [_to_hkl_str_and_raw(hkl) for hkl in hkls_raw]
    hkls = [label for label, _ in converted]
    raw_indices = [raw for _, raw in converted]

    # Normalize intensities for reference plotting. Use an explicit
    # None/len() check: plain truthiness raises for numpy arrays or pandas
    # Series with more than one element.
    if normalize and intensities is not None and len(intensities) > 0:
        max_intensity = max(intensities)
        if max_intensity > 0:
            intensities = [float(I) / float(max_intensity) for I in intensities]
        else:
            intensities = [0.0 for _ in intensities]

    reference_name = os.path.splitext(os.path.basename(path))[0]
    return reference_name, positions, intensities, hkls, raw_indices

def get_supported_file_filters():
    """
    Build a file-dialog filter string covering every loaded plugin.

    Data and reference plugins are not distinguished. Each plugin's
    ``get_input_type()`` result (when it is a dict) supplies one entry:
      * 'file_type': a string whose last space-separated token is taken as
        the extension when it contains '.'; otherwise the whole string is the
        description and its last word, lower-cased, becomes "*.word";
      * or separate 'description' and 'extension' keys.
    Plugins that raise are reported and skipped. Duplicate entries are
    dropped, keeping the first occurrence.

    Returns:
        Entries formatted "desc (ext)" and joined with ";;" (presumably
        for a Qt file dialog), always ending with "全てのファイル (*.*)".
        With no plugins loaded, a text-file filter plus that catch-all.
    """
    if not modules:
        return "テキストファイル (*.txt);;全てのファイル (*.*)"

    filters = []
    for m in modules:
        try:
            input_type_dict = m.get_input_type(app=app, cparams=cparams)
            if not isinstance(input_type_dict, dict):
                continue

            desc, ext = None, None
            if 'file_type' in input_type_dict:
                file_type_str = input_type_dict['file_type'].strip()
                last_space_index = file_type_str.rfind(' ')
                if last_space_index != -1 and '.' in file_type_str[last_space_index:]:
                    desc = file_type_str[:last_space_index].strip()
                    ext_part = file_type_str[last_space_index + 1:].strip()
                    # Don't turn an existing wildcard ("*.txt") into "**.txt".
                    ext = ext_part if ext_part.startswith('*') else f"*{ext_part}"
                elif file_type_str:
                    desc = file_type_str
                    ext = f"*.{file_type_str.split()[-1].lower().lstrip('.')}"
                # An empty file_type yields no filter entry.
            elif 'description' in input_type_dict and 'extension' in input_type_dict:
                desc = input_type_dict['description']
                ext = input_type_dict['extension']

            if desc and ext:
                filters.append(f"{desc} ({ext})")

        except Exception as e:
            # Best effort: one misbehaving plugin must not break the dialog.
            print(f"フィルター取得エラー ({getattr(m, 'name', m)}): {e}")

    if filters:
        # dict.fromkeys keeps first-occurrence order in O(n).
        unique_filters = list(dict.fromkeys(filters))
        return ";;".join(unique_filters) + ";;全てのファイル (*.*)"
    return "全てのファイル (*.*)"


def main():
    """
    Run a manual smoke test: parse one measured XRD file, then exit.

    The file may be given as the first command-line argument; otherwise the
    developer's hard-coded sample path is used. To check a reference file
    instead, call e.g. parse_reference(".../phase1.cif").
    """
    default_infile = "D:/git/tkProg/tkprog_COE/XRD/data/240219_AlScN_300_5h_oradw_25_Al100.txt"
    infile = sys.argv[1] if len(sys.argv) > 1 else default_infile
    parse_xrd(infile)
    # sys.exit() rather than the interactive-shell exit() builtin, which
    # is not guaranteed to exist (python -S, frozen executables).
    sys.exit()


# Run the manual smoke test only when executed as a script, not when imported as a library.
if __name__ == "__main__":
    main()
    