import os
import sys
import glob
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.special import legendre
from scipy.interpolate import interp1d
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_absolute_error, mean_squared_error


from tklib.tkapplication import tkApplication
from tklib.tkfile import tkFile
from tklib.tkutils import getarg, getintarg, getfloatarg, pint, pfloat, split_file_path, replace_path
from tklib.tkvariousdata import tkVariousData
from tklib.tksci.tksci import Gaussian, Lorentzian, GaussLorentz
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tkcrystal.tkxrd import Xray_wavelengths
from tklib.tkgraphic.tkplotevent import tkPlotEvent


"""
Convert CIF file to VASP POSCAR file
"""

# Xray_wavelengths
# URL: https://github.com/materialsproject/pymatgen/blob/v2023.5.10/pymatgen/analysis/diffraction/xrd.py
# 'CuKa', 'CuKa2', 'CuKa1', 'CuKb1', 'MoKa', 'MoKa2', 'MoKa1', 'MoKb1',
# 'CrKa', 'CrKa2', 'CrKa1', 'CrKb1', 'FeKa', 'FeKa2', 'FeKa1', 'FeKb1', 'CoKa', 'CoKa2', 'CoKa1', 'CoKb1',
# 'AgKa', 'AgKa2', 'AgKa1', 'AgKb1'



# Usage template. Every line is itself a Python expression: a quoted string
# followed by .format(...) that references sys.argv and cparams.
# usage() prints such lines by wrapping each one in print(...) and eval()-ing it.
# The trailing [1:-1] strips the newline right after the opening quotes and the
# one right before the closing quotes.
# NOTE(review): usage() reads app.usage_str, not this module-level name -
#   presumably this string is handed to the application object elsewhere; confirm.
# NOTE(review): the argument order shown in line (i) does not mention the plugin
#   directory that update_vars() reads as argument 2 - confirm which one is current.
usage_str = '''
"  (i) usage: python {} mode input_path cif_dir wavelength Q2min Q2max Q2step fwhm yscale BG_order LASSO_alpha".format(sys.argv[0])
"         ex: python {} plot_all  {} {} {} {} {} {} {} {}".format(sys.argv[0], cparams.infile, cparams.wavelength, cparams.xmin, cparams.xmax, cparams.xstep, cparams.fwhm, cparams.yscale, cparams.BGorder, cparams.alpha)
"         ex: python {} plot_one  {} {} {} {} {} {} {} {}".format(sys.argv[0], cparams.infile, cparams.wavelength, cparams.xmin, cparams.xmax, cparams.xstep, cparams.fwhm, cparams.yscale, cparams.BGorder, cparams.alpha)
"         ex: python {} plot_one2 {} {} {} {} {} {} {} {} {}".format(sys.argv[0], cparams.infile, cparams.wavelength, cparams.xmin, cparams.xmax, cparams.xstep, cparams.fwhm, cparams.yscale, cparams.BGorder, cparams.alpha)
"         ex: python {} fit {} {} {} {} {} {} {} {} {}".format(sys.argv[0], cparams.infile, cparams.wavelength, cparams.xmin, cparams.xmax, cparams.xstep, cparams.fwhm, cparams.yscale, cparams.BGorder, cparams.alpha)
'''[1:-1]



#================================
# global parameters
#================================
# Parallel lists: module_names[i] is the display name of modules[i].
# read_file() walks them in order and uses the first module whose
# check_file_type() accepts the file.
# NOTE(review): both lists start empty here; presumably they are filled
#   elsewhere (e.g. from cparams.plugin_dir) - not visible in this part of the file.
module_names = []
modules = []

# Marker / color palettes, cycled (index modulo length) per CIF phase in the plots.
markers = ['o', 's', '+', 'x', 'D', 'v', '^', '<', '>', '8', 'h', 'H']
colors  = ['black', 'red', 'blue', 'darkgreen', 'darkorange', 'hotpink', 'lightgreen', 'cyan', 'yellow', 'magenta', 'chocolate', 
           'navy', 'slategray', 'olive' ]

# Figure sizes passed to matplotlib as figsize:
#   figsize     : multi-panel figures (plot_all, check_overwrap)
#   figsize_one : single-panel figures (plot_input, plot_one, plot_one2)
figsize       = (12, 8)
figsize_one   = (10, 6)
figsize_coeff = (8, 6)
# Font sizes for tick/axis labels and for legends.
fontsize            = 12
legend_fontsize     = 8
legend_fontsize_one = 12


#=============================
# Handle arguments
#=============================
def usage(app = None, cparams = None):
    """
    Print the usage text stored in app.usage_str.

    Each line of app.usage_str is a Python expression (typically
    '"...{}...".format(...)'); it is wrapped in print(...) and evaluated,
    so the expressions can refer to the local names 'app' and 'cparams'.

    The cparams argument is ignored: it is always replaced by app.get_params().

    NOTE: eval() executes whatever the usage text contains, so the text must
    come from a trusted source.
    """
    # Must stay bound to the name 'cparams': the evaluated lines refer to it.
    cparams = app.get_params()
    for usage_line in app.usage_str.split('\n'):
        eval('print({})'.format(usage_line.rstrip()))

def initialize(app = None, cparams = None):
    """
    Store the default value of every run parameter on cparams.

    app is not used. cparams is modified in place; nothing is returned.
    """
    defaults = (
        # general switches
        ("debug",              0),
        ("findvalidstructure", True),
        # directory name of the file-format plugins
        ("plugin_dir",         'filter'),
        ("mode",               'plot'),
        # input pattern file and glob mask of the reference (CIF) files
        ("infile",             '*.txt'),
        ("cif_files",          'data/*.*'),
        # radiation and 2theta range
        ("beam",               'X-ray'),
        ("wavelength",         "CuKa"),
        ("xmin",               20.0),
        ("xmax",               120.0),
        ("xstep",              0.02),
        # peak profile: width and Gaussian fraction
        ("fwhm",               0.2),
        ("Gfraction",          0.5),
        # additional smearing (0.0 width means no smearing is applied)
        ("fwhm_smear",         0.0),
        ("Gfraction_smear",    0.0),
        # plot scale, background order, LASSO alpha
        ("yscale",             'linear'),
        ("BGorder",            3),
        ("alpha",              0.1),
    )
    for attr_name, default_value in defaults:
        setattr(cparams, attr_name, default_value)

def update_vars(app = None, cparams = None):
    """
    Overwrite the parameters on cparams with the positional command line arguments.

    For every (position, attribute, reader) entry below, the reader
    (getarg / getintarg / getfloatarg) is called with the argument position and
    the current value of the attribute as the default, and the result is stored
    back on cparams. app is not used.
    """
    arg_table = (
        ( 1, "mode",            getarg),
        ( 2, "plugin_dir",      getarg),
        ( 3, "infile",          getarg),
        ( 4, "cif_files",       getarg),
        ( 5, "wavelength",      getarg),
        ( 6, "xmin",            getfloatarg),
        ( 7, "xmax",            getfloatarg),
        ( 8, "xstep",           getfloatarg),
        ( 9, "fwhm",            getfloatarg),
        (10, "Gfraction",       getfloatarg),
        (11, "fwhm_smear",      getfloatarg),
        (12, "Gfraction_smear", getfloatarg),
        (13, "yscale",          getarg),
        (14, "BGorder",         getintarg),
        (15, "alpha",           getfloatarg),
    )
    for position, attr_name, read_arg in arg_table:
        current = getattr(cparams, attr_name)
        setattr(cparams, attr_name, read_arg(position, current))

def read_file(path, app, cparams):
    """
    Read 'path' with the first registered module that recognizes it.

    The modules in the global list 'modules' are tried in order. A module
    matches when its check_file_type() returns something that is not None and
    does not contain 'Error'.

    Returns:
        (module, data) where data is the result of module.read_data(), or
        (None, None) when no module matched.
    """
    matched = None
    for idx, candidate in enumerate(modules):
        label = module_names[idx]

        detected = candidate.check_file_type(path, app = app, cparams = cparams)
        print(f"try [{label}] for [{path}]: file_type={detected}")
        if detected is None or 'Error' in detected:
            continue

        print("   type matched.")
        matched = candidate
        break

    if matched is None:
        return None, None

    data = matched.read_data(path, app = app, cparams = cparams)
    return matched, data

def read_all_files(app, cparams, input_only = False):
    """
    Read the observed pattern file and the reference (CIF etc.) files.

    The observed file (cparams.infile) is read and converted first. Its x axis
    then clips cparams.xmin / cparams.xmax and replaces cparams.xstep, so
    cparams is modified by this call.

    Unless input_only is true, every file matching the glob mask
    cparams.cif_files is read and converted as well. Files whose name body is
    empty or starts with '~', files containing '-out.' in their name, and files
    that no module can read are skipped.

    Args:
        app: application object; provides terminate() and call().
        cparams: parameter object (infile, cif_files, xmin, xmax, xstep, fwhm, Gfraction).
        input_only: when true, only the observed file is read.

    Returns:
        dict with "module_input" and "inf_input"; when input_only is false also
        "module_cif" (module of the last successfully read reference file, or
        None if none was read) and "inf_cif_list" (one entry per file read).
    """
    print("")
    print(f"read_all_files(): Read input file [{cparams.infile}]")
    module_input, inf_input = read_file(cparams.infile, app, cparams)
    if module_input:
        inf_input = module_input.convert(inf_input, cparams = cparams)
    else:
        app.terminate(f"Error in read_all_files(): Could not read [{cparams.infile}]", pause = True)

    # Range and step of the observed data (step taken from the first two points)
    xQ2_infile  = inf_input["data_list"][0]
    xmin = min(xQ2_infile)
    xmax = max(xQ2_infile)
    xstep = xQ2_infile[1] - xQ2_infile[0]
    print(f"  2Theta range: {xmin} - {xmax}, {xstep} step")
    print(f"  fwhm: {cparams.fwhm}")
    print(f"  Gaussian fraction: {cparams.Gfraction}")

    # Restrict the requested range to what the observed data covers
    cparams.xmin  = max([cparams.xmin, xmin])
    cparams.xmax  = min([cparams.xmax, xmax])
    cparams.xstep = xstep
    print(f"2Theta range to be calculated: {cparams.xmin} - {cparams.xmax} degrees, {cparams.xstep} step")

    if input_only:
        inf = {
            "module_input": module_input,
            "inf_input"   : inf_input,
            }
        return inf

    cif_mask = cparams.cif_files
    files = glob.glob(cif_mask)
    print("")
    print(f"Read cif and xlsx files from [{cif_mask}]")
    inf_cif_list = []
    module_cif = None
    for f in files:
        print("")
        print(f"  Read [{f}]")
        dirname, basename, filebody, ext = split_file_path(f)
        if len(filebody) == 0 or filebody[0] == '~':
            print(f"    [{basename}] has '~' at the first character. may be a temprary file. skip")
            continue
        if '-out.' in basename.lower():
            print(f"    [{basename}] include '-out.'. maybe an output file of some program. skip")
            continue

        module_read, inf_cif = read_file(f, app, cparams)
        if not module_read:
            # An unreadable file must neither add a None entry to the list
            # (the plot functions index every entry) nor reset module_cif
            # (the plot functions treat module_cif is None as "no CIF found").
            print(f"    Could not read [{f}]. skip")
            continue

        module_cif = module_read
        print(f"    File [{f}] is red by [{module_cif.name}] module")
        # The return value of convert is not used here: the plot functions read
        # inf_cif["conv_data"], so convert presumably updates inf_cif in place
        # - TODO confirm.
        app.call(module_cif, "convert", inf_cif, app = app, cparams = cparams)

        inf_cif_list.append(inf_cif)

    inf = {
        "module_input": module_input,
        "module_cif"  : module_cif,
        "inf_input"   : inf_input,
        "inf_cif_list": inf_cif_list,
        }

    return inf

def max_none(x):
    """
    Return the largest non-None value in x.

    The search starts from -1.0e100, so that value is returned when x is empty
    or holds only None (or only values not larger than -1.0e100).
    """
    largest = -1.0e100
    for value in x:
        if value is None:
            continue
        if largest < value:
            largest = value
    return largest

def min_none(x):
    """
    Return the smallest non-None value in x.

    The search starts from 1.0e100, so that value is returned when x is empty
    or holds only None (or only values not smaller than 1.0e100).
    """
    smallest = 1.0e100
    for value in x:
        if value is None:
            continue
        if smallest > value:
            smallest = value
    return smallest

def normalize_none(l, Amin = 0.0, Amax = 1.0, vmin = None, vmax = None):
    """
    Linearly map the non-None values of l from [vmin, vmax] to [Amin, Amax].

    None entries are left untouched. l is modified IN PLACE and also returned.
    When vmin / vmax are not given they are taken from the non-None values of l.
    When vmax equals vmin, vmax is replaced by vmin + 1.0 to avoid a division by zero.
    """
    if vmax is None:
        vmax = max_none(l)
    if vmin is None:
        vmin = min_none(l)
    if vmax - vmin == 0.0:
        vmax = vmin + 1.0

    span = vmax - vmin
    for idx, value in enumerate(l):
        if value is not None:
            l[idx] = (value - vmin) / span * (Amax - Amin) + Amin

    return l

def normalize(l, Amin = 0.0, Amax = 1.0, vmin = None, vmax = None):
    """
    Linearly map the values of l from [vmin, vmax] to [Amin, Amax].

    Returns None when l is None. Otherwise l is modified IN PLACE and also returned.
    When vmin / vmax are not given, min(l) / max(l) are used.
    When vmax equals vmin, vmax is replaced by vmin + 1.0 to avoid a division by zero.
    """
    if l is None:
        return None

    vmax = max(l) if vmax is None else vmax
    vmin = min(l) if vmin is None else vmin
    if vmax - vmin == 0.0:
        vmax = vmin + 1.0

    span = vmax - vmin
    for idx, value in enumerate(l):
        l[idx] = (value - vmin) / span * (Amax - Amin) + Amin

    return l

def plot_all(inf, app, cparams):
    """
    Plot the observed pattern and every CIF pattern in separate, stacked panels.

    The top panel shows the observed (and, if the input has a third column, the
    simulated) curve together with the reflection markers of all CIF phases.
    Each following panel shows one CIF profile with its own reflection markers.
    All curves are normalized to [0, 1]; normalize() works in place, so the
    lists stored in inf are rescaled by this call.

    Args:
        inf: dict returned by read_all_files() (input_only = False).
        app: application object; app.terminate() is called if no CIF module is set.
        cparams: parameter object (yscale, xmin, xmax are read).

    Blocks on input() until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Plot one XRD curve in a different box")
    print("#########################################")
    module_cif   = inf["module_cif"]
    inf_input    = inf["inf_input"]
    inf_cif_list = inf["inf_cif_list"]
    
    if module_cif is None:
        app.terminate(f"Error in plot_all(): CIFファイルが見つかりませんでした", pause = True)

    ncif = len(inf_cif_list)
    print("")
    print("plot")
    print(f"yscale: {cparams.yscale}")
    print("# of cif data:", ncif)

    sample_infile = inf_input["sample_name"]
    xQ2_infile  = inf_input["data_list"][0]
    yobs_infile = inf_input["data_list"][1]
    if len(inf_input["data_list"]) >= 3:
        ysim_infile = inf_input["data_list"][2]
    else:
        ysim_infile = None
    
    # The simulated curve is scaled with the range of the observed one
    # so that the two stay comparable.
    vmax = max(yobs_infile)
    vmin = min(yobs_infile)
    yobs_infile = normalize(yobs_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
    if ysim_infile is not None:
        ysim_infile = normalize(ysim_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

    fig, axes = plt.subplots(ncif + 1, 1, figsize = figsize, sharex=True, gridspec_kw={'hspace': 0})
    plot_event = tkPlotEvent(plt, distance = 'x')
    ncolors  = len(colors)
    nmarkers = len(markers)

    ax0 = axes[0]
    ax_bottom = axes[ncif]

    # Common decoration of all panels: no y ticks, x label only on the bottom panel
    for i in range(ncif + 1):
        ax = axes[i]
        ax.tick_params(labelsize = fontsize)
        if i < ncif:
            ax.set_xlabel(None)
        ax.set_yticks([])

    # Top panel: observed / simulated curves
    ax0.set_title(f"{sample_infile}")
    ax0.plot(xQ2_infile, yobs_infile, label = "obs", color = colors[0], linewidth = 0.5)
    if ysim_infile is not None:
        ax0.plot(xQ2_infile, ysim_infile, label = "sim", color = colors[1], linewidth = 0.3)
    ax0.axhline(0.0, linestyle = 'dashed', color = 'red', linewidth = 0.5)
    if cparams.yscale == 'log':
        ax0.set_yscale('log')
        if ysim_infile is None:
            ymax = max(yobs_infile)
        else:
            ymax = max([max(ysim_infile), max(yobs_infile)])
        # show four decades below the maximum
        ax0.set_ylim([1.0e-4 * ymax, ax0.get_ylim()[1]])
    ax0.legend(fontsize = legend_fontsize)
    
    # One panel per CIF phase
    for i in range(1, ncif + 1):
        inf_cif = inf_cif_list[i - 1]
        xQ2, xrd_cal = inf_cif["conv_data"]
        filename = inf_cif["filename"]
        dirname, basename, filebody, ext = split_file_path(filename)

        src  = inf_cif["diffractions"]["source"]
        Q2   = inf_cif["diffractions"]["Q2"]
        dhkl = inf_cif["diffractions"]["dhkl"]
        hkl  = inf_cif["diffractions"]["hkl"]
        mul  = inf_cif["diffractions"].get("mul", None)
        if mul is None:
            mul = [0 for _ in range(len(src))]
        Int  = inf_cif["diffractions"]["intensity"]
        ndiffractions = len(Q2)

        vmax = max(xrd_cal)
        vmin = min(xrd_cal)
        xrd_cal = normalize(xrd_cal, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
        Int     = normalize(Int, Amin = 0.0, Amax = 1.0)

        ax = axes[i]
        color  = colors[(i - 1) % ncolors]
        marker = markers[(i - 1) % nmarkers]
        phase = [filebody] * ndiffractions

        ax.plot(xQ2, xrd_cal, label = filebody, color = 'black', linewidth = 0.5)
        # Reflection markers go both to the phase panel and to the top panel
        data  = ax.plot(Q2, Int, linestyle = '', marker = marker, markerfacecolor = color, markeredgecolor = color, markersize = 2.0)
        data0 = ax0.plot(Q2, Int, linestyle = '', marker = marker, markerfacecolor = color, markeredgecolor = color, markersize = 2.0)
        for j in range(ndiffractions):
            ax.plot([Q2[j], Q2[j]], [0.0, Int[j]], linestyle = 'dashed', color = color, linewidth = 0.5)
            ax0.plot([Q2[j], Q2[j]], [0.0, Int[j]], linestyle = 'dashed', color = color, linewidth = 0.5)

        # Label text of each reflection. 4-index reflections are written as
        # "h k (i) l", the same format plot_one() uses, instead of dropping the
        # last index.
        hkl_str = []
        for j in range(ndiffractions):
            idx = hkl[j]
            if len(idx) == 3:
                hkl_str.append(f"{idx[0]} {idx[1]} {idx[2]}")
            else:
                hkl_str.append(f"{idx[0]} {idx[1]} ({idx[2]}) {idx[3]}")

        # Register the markers so that a mouse click shows the reflection data
        plot_event.add_data({"label": "diffraction", "plot_type": "2D", "axis": ax, "data": data,
                                 "x": Q2, 'y': Int, "xlabel": "2Theta", "ylabel": "Intensity",
                                 "xlist": [src, phase, Q2, dhkl,  hkl_str, mul, Int],
                                 "xlabels": ['X-ray', 'phase', 'Q2', 'dhkl', 'hkl', 'multiplicity', 'Intensity']
                                 })
        plot_event.add_data({"label": "diffraction", "plot_type": "2D", "axis": ax0, "data": data0,
                                 "x": Q2, 'y': Int, "xlabel": "2Theta", "ylabel": "Intensity",
                                 "xlist": [src, phase, Q2, dhkl,  hkl_str, mul, Int],
                                 "xlabels": ['X-ray', 'phase', 'Q2', 'dhkl', 'hkl', 'multiplicity', 'Intensity']
                                 })

        ax.axhline(0.0, linestyle = 'dashed', color = 'red', linewidth = 0.5)
        if cparams.yscale == 'log':
            ax.set_yscale('log')
            ymax = max(xrd_cal)
            ax.set_ylim([1.0e-4 * ymax, ax.get_ylim()[1]])
        ax.legend(fontsize = legend_fontsize)

    xmin = max([inf_input["xmin"], cparams.xmin])
    xmax = min([inf_input["xmax"], cparams.xmax])

    # Vertical guide lines at every integer x; stronger ones at multiples of 5
    for ax in axes:
        for q2 in range(int(xmin), int(xmax)):
            if q2 % 5 == 0:
                ax.axvline(q2, linestyle = 'dotted', linewidth = 0.5, color = 'black')
            else:
                ax.axvline(q2, linestyle = 'dotted', linewidth = 0.3, color = 'gray')

    # The panels share the x axis, so setting the limits on one of them is enough
    ax_bottom.set_xlim(xmin, xmax)

    plot_event.register_click(fig)

    ax_bottom.set_xlabel(r'2$\theta$', fontsize = fontsize)
    axes[int(ncif/2)].set_ylabel('Intensity', fontsize = fontsize)
    ax_bottom.legend(fontsize = legend_fontsize)

    plt.tight_layout()
    plt.pause(0.1)

    input("Press ENTER to terminate>>")

def plot_input(inf, app, cparams):
    """
    Plot only the observed pattern of the input file in a single panel.

    Reads inf["inf_input"] (sample_name, data_list, xmin, xmax) and
    cparams (yscale, xmin, xmax). app is not used.
    Blocks on input() until ENTER is pressed.
    """
    for banner_line in ("",
                        "#########################################",
                        "  Plot input XRD curve",
                        "#########################################"):
        print(banner_line)
    _module_input = inf["module_input"]
    observed      = inf["inf_input"]

    print("")
    print("plot")
    print(f"yscale: {cparams.yscale}")

    title  = observed["sample_name"]
    x_obs  = observed["data_list"][0]
    y_obs  = observed["data_list"][1]

    figure = plt.figure(figsize = figsize_one)
    panel = figure.add_subplot(1, 1, 1)
    panel.tick_params(labelsize = fontsize)

    panel.set_title(f"{title}")
    panel.plot(x_obs, y_obs, label = "obs", color = colors[0], linewidth = 0.5)

    # Show only the part of the data that lies inside the requested range
    lower = max([observed["xmin"], cparams.xmin])
    upper = min([observed["xmax"], cparams.xmax])
    panel.set_xlim(lower, upper)
    panel.set_xlabel('2$\\theta$ ($\\degree$)', fontsize = fontsize)
    panel.set_ylabel('Intensity', fontsize = fontsize)
    if cparams.yscale == 'log':
        panel.set_yscale('log')

    plt.tight_layout()
    plt.pause(0.1)

    input("Press ENTER to terminate>>")

def plot_one(inf, app, cparams):
    """
    Plot the observed pattern and the reflection positions of all CIF phases in one panel.

    The observed (and optional simulated) curve is normalized to [0, 1] with the
    range of the observed data. For each CIF phase the reflection intensities
    are normalized with the min / max of that phase's calculated profile and
    drawn as markers with dashed vertical lines. normalize() works in place, so
    the lists stored in inf are rescaled by this call.

    Args:
        inf: dict returned by read_all_files() (input_only = False).
        app: application object; app.terminate() is called if no CIF module is set.
        cparams: parameter object (yscale, xmin, xmax are read).

    Blocks on input() until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Plot input XRD curve and cif diffraction angles in one graph")
    print("#########################################")
    module_cif   = inf["module_cif"]
    inf_input    = inf["inf_input"]
    inf_cif_list = inf["inf_cif_list"]

    if module_cif is None:
        # The message used to name plot_all(), which pointed to the wrong function
        app.terminate(f"Error in plot_one(): CIFファイルが見つかりませんでした", pause = True)

    ncif = len(inf_cif_list)
    print("")
    print("plot")
    print(f"yscale: {cparams.yscale}")
    print("# of cif data:", ncif)

    sample_infile = inf_input["sample_name"]
    xQ2_infile  = inf_input["data_list"][0]
    yobs_infile = inf_input["data_list"][1]
    if len(inf_input["data_list"]) >= 3:
        ysim_infile = inf_input["data_list"][2]
    else:
        ysim_infile = None

    vmax = max(yobs_infile)
    vmin = min(yobs_infile)
    yobs_infile = normalize(yobs_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
    if ysim_infile is not None:
        ysim_infile = normalize(ysim_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

    fig = plt.figure(figsize = figsize_one)
    plot_event = tkPlotEvent(plt, distance = 'x')
    ncolors  = len(colors)
    nmarkers = len(markers)

    ax = fig.add_subplot(1, 1, 1)
    ax.set_yticks([])
    ax.set_xlabel(None)
    ax.set_title(f"{sample_infile}")
    ax.tick_params(labelsize = fontsize)
    ax.plot(xQ2_infile, yobs_infile, label = "obs", color = colors[0], linewidth = 0.5)
    if ysim_infile is not None:
        ax.plot(xQ2_infile, ysim_infile, label = "sim", color = colors[1], linewidth = 0.3)

    # ymax tracks the largest plotted value; used for the lower limit in log scale
    if ysim_infile is None:
        ymax = max(yobs_infile)
    else:
        ymax = max([max(yobs_infile), max(ysim_infile)])
    for i in range(1, ncif + 1):
        inf_cif = inf_cif_list[i - 1]
        xQ2, xrd_cal = inf_cif["conv_data"]
        filename = inf_cif["filename"]
        dirname, basename, filebody, ext = split_file_path(filename)

        src  = inf_cif["diffractions"]["source"]
        Q2   = inf_cif["diffractions"]["Q2"]
        dhkl = inf_cif["diffractions"]["dhkl"]
        hkl  = inf_cif["diffractions"]["hkl"]
        mul  = inf_cif["diffractions"].get("mul", None)
        if mul is None:
            mul = [0 for _ in range(len(src))]
        Int  = inf_cif["diffractions"]["intensity"]
        ndiffractions = len(Q2)

        # Both the profile and the reflection intensities are scaled with the
        # range of the calculated profile.
        vmax = max(xrd_cal)
        vmin = min(xrd_cal)
        xrd_cal = normalize(xrd_cal, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
        Int     = normalize(Int, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

        color  = colors[(i - 1) % ncolors]
        marker = markers[(i - 1) % nmarkers]
        phase = [filebody] * ndiffractions

        data0  = ax.plot(Q2, Int, linestyle = '', marker = marker, markerfacecolor = color, markeredgecolor = color, markersize = 2.0)
        ymax = max([max(Int), ymax])
        for j in range(ndiffractions):
            ax.plot([Q2[j], Q2[j]], [0.0, Int[j]], linestyle = 'dashed', color = color, linewidth = 0.5)

        # 3-index reflections: "h k l"; otherwise 4 indices: "h k (i) l"
        if len(hkl[0]) == 3:
            hkl_str = [f"{hkl[j][0]} {hkl[j][1]} {hkl[j][2]}" for j in range(ndiffractions)]
        else:
            hkl_str = [f"{hkl[j][0]} {hkl[j][1]} ({hkl[j][2]}) {hkl[j][3]}" for j in range(ndiffractions)]
        # Register the markers so that a mouse click shows the reflection data
        plot_event.add_data({"label": "diffraction", "plot_type": "2D", "axis": ax, "data": data0,
                                 "x": Q2, 'y': Int, "xlabel": "2Theta", "ylabel": "Intensity",
                                 "xlist": [src, phase, Q2, dhkl,  hkl_str, mul, Int],
                                 "xlabels": ["X-ray", 'phase', 'Q2', 'dhkl', 'hkl', 'multiplicity', 'Intensity']
                                 })

    # The reflection markers carry no label, so one legend call after the loop is enough
    ax.legend(fontsize = legend_fontsize)

    plot_event.register_click(fig)

    Qmin = max([inf_input["xmin"], cparams.xmin])
    Qmax = min([inf_input["xmax"], cparams.xmax])
    ax.set_xlim(Qmin, Qmax)
    ax.set_ylabel('Intensity', fontsize = fontsize)
    if cparams.yscale == 'log':
        ax.set_yscale('log')
        # show four decades below the maximum
        ax.set_ylim([1.0e-4 * ymax, ax.get_ylim()[1]])

    plt.tight_layout()
    plt.pause(0.1)

    input("Press ENTER to terminate>>")

def plot_one2(inf, app, cparams):
    """
    Plot the observed pattern and the calculated profiles of all CIF phases in one panel.

    Every curve is normalized to [0, 1] (the observed one with its own range,
    each CIF profile with its own range). normalize() works in place, so the
    lists stored in inf are rescaled by this call. The simulated column of the
    input file, if present, is normalized but not drawn.

    Args:
        inf: dict returned by read_all_files() (input_only = False).
        app: application object; app.terminate() is called if no CIF module is set.
        cparams: parameter object (yscale, xmin, xmax are read).

    Blocks on input() until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Plot XRD curves in one graph")
    print("#########################################")
    module_cif   = inf["module_cif"]
    inf_input    = inf["inf_input"]
    inf_cif_list = inf["inf_cif_list"]

    if module_cif is None:
        # The message used to name plot_all(), which pointed to the wrong function
        app.terminate(f"Error in plot_one2(): CIFファイルが見つかりませんでした", pause = True)

    ncif = len(inf_cif_list)
    print("")
    print("plot")
    print(f"yscale: {cparams.yscale}")
    print("# of cif data:", ncif)

    sample_infile = inf_input["sample_name"]
    xQ2_infile  = inf_input["data_list"][0]
    yobs_infile = inf_input["data_list"][1]
    if len(inf_input["data_list"]) >= 3:
        ysim_infile = inf_input["data_list"][2]
    else:
        ysim_infile = None
    
    vmax = max(yobs_infile)
    vmin = min(yobs_infile)
    yobs_infile = normalize(yobs_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
    if ysim_infile is not None:
        # Kept although the curve is not drawn: the in-place rescaling of the
        # stored data is a side effect other code may rely on.
        ysim_infile = normalize(ysim_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

    fig = plt.figure(figsize = figsize_one)
    ncolors  = len(colors)

    ax = fig.add_subplot(1, 1, 1)
    ax.set_yticks([])
    ax.set_xlabel(None)
    ax.set_title(f"{sample_infile}")
    ax.tick_params(labelsize = fontsize)
    ax.plot(xQ2_infile, yobs_infile, label = "obs", color = colors[0], linewidth = 1.0)

    # ymax tracks the largest plotted value; used for the lower limit in log scale
    ymax = max(yobs_infile)
    for i in range(1, ncif + 1):
        inf_cif = inf_cif_list[i - 1]
        xQ2, xrd_cal = inf_cif["conv_data"]
        filename = inf_cif["filename"]
        dirname, basename, filebody, ext = split_file_path(filename)

        vmax = max(xrd_cal)
        vmin = min(xrd_cal)
        xrd_cal = normalize(xrd_cal, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

        # colors[0] is used by the observed curve, so the phases start at colors[1]
        ax.plot(xQ2, xrd_cal, label = filebody, linestyle = '-', color = colors[i % ncolors], linewidth = 0.5)
        ymax = max([ymax, max(xrd_cal)])

    Qmin = max([inf_input["xmin"], cparams.xmin])
    Qmax = min([inf_input["xmax"], cparams.xmax])
    ax.set_xlim(Qmin, Qmax)
    ax.set_ylabel('Intensity', fontsize = fontsize)
    if cparams.yscale == 'log':
        ax.set_yscale('log')
        # show four decades below the maximum
        ax.set_ylim([1.0e-4 * ymax, ax.get_ylim()[1]])
    ax.legend(fontsize = legend_fontsize)

    plt.tight_layout()
    plt.pause(0.1)

    input("Press ENTER to terminate>>")

def convolution(x, y, filter, nskip = 1):
    """
    Smear y with 'filter' and keep only every nskip-th point.

    'filter' is indexed by the signed offset from the center point, i.e.
    filter[0] is the center weight and negative offsets use Python's negative
    indexing (filter[-1] is the weight of offset -1). Offsets run from
    -(len(filter) // 2) to +(len(filter) // 2).

    Only points whose index is a multiple of nskip are calculated, and only
    neighbors whose index is a multiple of nskip contribute to the sum.

    Returns:
        None when y is None; otherwise the pair (x values, smeared y values)
        at the indices 0, nskip, 2 * nskip, ... as two lists.
    """
    if y is None:
        return None

    npoints = len(y)
    half_width = len(filter) // 2
    smeared = np.zeros(npoints)
    kept = range(0, npoints, nskip)
    for center in kept:
        for offset in range(-half_width, half_width + 1):
            neighbor = center + offset
            if neighbor % nskip != 0 or not (0 <= neighbor < npoints):
                continue
            smeared[center] += filter[offset] * y[neighbor]

    x_kept = [x[k] for k in kept]
    y_kept = [smeared[k] for k in kept]

    return x_kept, y_kept

def collect_data(inf, app, cparams, is_print = True):
    """
    Put the observed pattern, the CIF profiles and the background basis on one common grid.

    Steps:
      1. Build a smearing kernel from cparams.fwhm_smear / cparams.Gfraction_smear
         (only applied when fwhm_smear > 0).
      2. Observed / simulated curves: optional log (cparams.yscale == 'log'),
         optional smearing with down-sampling, normalization with the range of
         the observed curve.
      3. Each CIF profile: optional log, optional smearing, normalization to [0, 1].
      4. Interpolate everything onto an evenly spaced grid spanning the x range
         of the last CIF profile (0.0 outside the range of the source data).
      5. Evaluate Legendre polynomials of order 0 .. cparams.BGorder as background basis.

    normalize() works in place, so data stored in inf may be rescaled by this call.

    Args:
        inf: dict returned by read_all_files() (input_only = False).
        app: application object; app.terminate() is called when no CIF data exists.
        cparams: parameter object (fwhm_smear, Gfraction_smear, xstep, yscale, BGorder).
        is_print: print the number of CIF data and the background order.

    Returns:
        (Q2_list, yobs_infile, ysim_infile, bg_names, bg_list, sample_names,
         ycif_list, labels_train, x_train); ysim_infile is None when the input
         has fewer than 3 spectra.
    """
    inf_input    = inf["inf_input"]
    inf_cif_list = inf["inf_cif_list"]
    ncif      = len(inf_cif_list)
    nspectrum = inf_input["nspectrum"]
    
    if ncif < 1:
        app.terminate(f"\nError: CIF file is not found.\n", pause = True)

    # Smearing kernel. It is indexed by the signed offset from the center
    # (negative offsets via Python's negative indexing), which is the layout
    # convolution() reads. The values are divided by the center value, so
    # filter[0] == 1.
    fwhm = cparams.fwhm_smear
    Gf   = cparams.Gfraction_smear
    dQ2 = 12.0 * fwhm
    nconv = int(dQ2 / cparams.xstep + 1.00001)
    filter = [0.0] * (2 * nconv + 1)
    if fwhm <= 0.0:
        wGL = 1.0
    else:
        wGL = GaussLorentz(0.0, 0.0, fwhm / 2.0, C0 = 1.0, Gfraction = Gf, Gwratio = 1.0, A = None)
    for i in range(-nconv, nconv + 1):
        if fwhm <= 0.0:
            filter[i] = 0.0
        else:
            filter[i] = GaussLorentz(0.0, i * cparams.xstep, fwhm / 2.0, C0 = 1.0, Gfraction = Gf, Gwratio = 1.0, A = None) / wGL

    # Down-sampling step: about 5 points per fwhm, but keep at least minnx points.
    xQ2_infile  = inf_input["data_list"][0]
    nx_in = len(xQ2_infile)
    dx = xQ2_infile[1] - xQ2_infile[0]
    nskip = int(fwhm / dx / 5.0 + 1.0e-5)
    if nskip == 0:
        nskip = 1
    minnx = 500
    if int(nx_in / nskip) < minnx:
        # For inputs shorter than minnx points the quotient is 0; a step of 0
        # makes range() in convolution() raise ValueError, so never go below 1.
        nskip = max(1, int(nx_in / minnx))

    yobs_infile = inf_input["data_list"][1]
    if nspectrum >= 3:
        ysim_infile = inf_input["data_list"][2]
    else:
        ysim_infile = None

    if cparams.yscale == 'log':
        yobs_infile = log(yobs_infile)
        if ysim_infile is not None:
            ysim_infile = log(ysim_infile)
    if fwhm > 0.0:
        _xQ2_infile, yobs_infile = convolution(xQ2_infile, yobs_infile, filter, nskip = nskip)
        if ysim_infile is not None:
            _xQ2_infile, ysim_infile = convolution(xQ2_infile, ysim_infile, filter, nskip = nskip)
    else:
        _xQ2_infile = xQ2_infile

    vmax = max(yobs_infile)
    vmin = min(yobs_infile)
    yobs_infile = normalize(yobs_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
    if ysim_infile is not None:
        ysim_infile = normalize(ysim_infile, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

    print(f"Read CIF data from {ncif} files")
    xQ2 = None
    ycif_list    = []
    sample_names = []
    for i in range(1, ncif + 1):
        inf_cif = inf_cif_list[i - 1]
        xQ2, xrd_cal = inf_cif["conv_data"]
        maxy = abs(max(xrd_cal))
        if cparams.yscale == 'log':
            # small offset relative to the maximum keeps log() away from zero
            # NOTE(review): 'xrd_cal + float' assumes xrd_cal is a numpy array - confirm
            xrd_cal = log(xrd_cal + maxy * 1.0e-7)

        if fwhm > 0.0:
            _xQ2, xrd_cal = convolution(xQ2, xrd_cal, filter, nskip = nskip)
        else:
            _xQ2 = xQ2

        # First scaled with the range of the observed curve, then rescaled to
        # its own [0, 1] range.
        xrd_cal = normalize(xrd_cal, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)
        xrd_cal = normalize(xrd_cal, Amin = 0.0, Amax = 1.0)

        filename = inf_cif["filename"]
        dirname, basename, filebody, ext = split_file_path(filename)

        ycif_list.append(xrd_cal)
        sample_names.append(filebody)

    # Common grid: same number of points and range as the x axis of the last CIF profile
    nx = len(xQ2)
    xmin = min(xQ2)
    xmax = max(xQ2)
    xstep = (xmax - xmin) / (nx - 1)
    x_list  = []
    Q2_list = []
    for i in range(nx):
        # NOTE(review): x_list is the abscissa of the Legendre background terms.
        #   It starts at -1 but advances with the 2theta step, so it ends at
        #   -1 + (xmax - xmin) rather than +1. If a mapping onto [-1, 1] was
        #   intended, the step should be 2.0 / (nx - 1). Left unchanged because
        #   it alters the fitted background - confirm.
        x_list.append(-1 + i * xstep)
        Q2_list.append(xmin + i * xstep)

    # Interpolate all curves onto the common grid; 0.0 outside the source range
    func1d = interp1d(_xQ2_infile, yobs_infile, bounds_error = False, fill_value = (0.0, 0.0))
    yobs_infile = func1d(Q2_list)
    if ysim_infile is not None:
        func1d = interp1d(_xQ2_infile, ysim_infile, bounds_error = False, fill_value = (0.0, 0.0))
        ysim_infile = func1d(Q2_list)
    for i in range(1, ncif + 1):
        # _xQ2 is the (possibly down-sampled) x axis of the last CIF profile
        func1d = interp1d(_xQ2, ycif_list[i-1], bounds_error = False, fill_value = (0.0, 0.0))
        ycif_list[i-1] = func1d(Q2_list)

    # Background basis: Legendre polynomials of order 0 .. BGorder
    bg_list = []
    for order in range(cparams.BGorder + 1):
        poly1d = legendre(order)
        bg_list.append(poly1d(x_list))

    if is_print:
        print("# of cif data:", ncif)
        print("BG order   :", cparams.BGorder)

    # NOTE(review): bg_list has BGorder + 1 entries but bg_names only BGorder.
    #   labels_train gets an extra "obs" entry that has no counterpart in
    #   x_train, which makes the two lists equally long; the labels after the
    #   background terms may therefore be shifted by one. Left unchanged because
    #   the consumers are not visible here - confirm.
    bg_names = [f"{i}-th" for i in range(cparams.BGorder)]
    x_train      = bg_list.copy()
    labels_train = bg_names.copy()
    x_train.extend(ycif_list)
    labels_train.append("obs")
    labels_train.extend([inf_cif["sample_name"] for inf_cif in inf_cif_list])

    return Q2_list, yobs_infile, ysim_infile, bg_names, bg_list, sample_names, ycif_list, labels_train, x_train

def check_overwrap(inf, app, cparams):
    """
    Show, for every CIF phase, where its reflections overlap with the profiles of the other phases.

    For a pair (ic, ic2) an overlap curve is accumulated on the common grid:
    around every reflection of phase ic, the profile of phase ic2 is weighted
    with the reflection intensity and a Gauss-Lorentz window of width
    cparams.fwhm + cparams.fwhm_smear. One panel is drawn per phase with its
    own profile (dashed, secondary y axis), the overlap curves with every other
    phase, and its reflection markers.

    normalize() works in place, so the reflection intensities stored in inf
    are rescaled to [0, 1] by this call.

    Args:
        inf: dict returned by read_all_files() (input_only = False).
        app: application object (passed on to collect_data()).
        cparams: parameter object (fwhm, fwhm_smear, Gfraction_smear, xstep, yscale).

    Blocks on input() until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Check overwrap between cif XRD patterns")
    print("#########################################")
    print(f"  smearing: fwhm={cparams.fwhm_smear} Gaussian fraction={cparams.Gfraction_smear}")
    inf_cif_list = inf["inf_cif_list"]

    Q2_list, yobs_infile, ysim_infile, bg_names, bg_list, sample_names, ycif_list, labels_train, x_train = collect_data(inf, app, cparams)
    nx   = len(Q2_list)
    Qmin = Q2_list[0]
    Qmax = Q2_list[nx - 1]
    ncif = len(ycif_list)

    # Window around each reflection. It is indexed by the signed offset from
    # the center (negative offsets via Python's negative indexing) and divided
    # by its center value, so filter[0] == 1.
    fwhm = cparams.fwhm + cparams.fwhm_smear
    Gf   = cparams.Gfraction_smear
    dQ2 = fwhm
    nconv = int(dQ2 / cparams.xstep + 1.00001)
    filter = [0.0] * (2 * nconv + 1)
    if fwhm <= 0.0:
        wGL = 1.0
    else:
        wGL = GaussLorentz(0.0, 0.0, fwhm / 2.0, C0 = 1.0, Gfraction = Gf, Gwratio = 1.0, A = None)
    for i in range(-nconv, nconv + 1):
        if fwhm <= 0.0:
            filter[i] = 0.0
        else:
            filter[i] = GaussLorentz(0.0, i * cparams.xstep, fwhm / 2.0, C0 = 1.0, Gfraction = Gf, Gwratio = 1.0, A = None) / wGL

    # corr[ic][ic2][i]: overlap of the reflections of phase ic with the profile
    # of phase ic2 at grid point i. Entries that are never touched stay None
    # (the 'is None' test below relies on make_matrix3 filling with None).
    corr = make_matrix3(ncif, ncif, nx)
    for ic in range(ncif):
        inf_cif = inf_cif_list[ic]
        Q2  = inf_cif["diffractions"]["Q2"]
        Int = inf_cif["diffractions"]["intensity"]
        nd = len(Q2)

        # Index of the first grid point at or above each reflection (None if the
        # reflection lies above the grid). It does not depend on ic2, so it is
        # determined once per phase.
        peak_index = [next((i for i in range(nx) if Q2[k] <= Q2_list[i]), None) for k in range(nd)]

        for ic2 in range(ncif):
            if ic == ic2:
                continue
    
            ycif = ycif_list[ic2]

            for k in range(nd):
                id0 = peak_index[k]
                if id0 is None:
                    continue

                for idf in range(-nconv, nconv + 1):
                    i1 = id0 + idf
                    if 0 <= i1 < nx:
                        v = filter[idf] * Int[k] * ycif[i1]
                        if corr[ic][ic2][i1] is None:
                            corr[ic][ic2][i1] = v
                        else:
                            corr[ic][ic2][i1] += v

    # squeeze = False keeps 'axes' a 2D array even for a single CIF, so that
    # indexing by phase works for any ncif.
    fig, axes = plt.subplots(ncif, 1, figsize = figsize, sharex=True, gridspec_kw={'hspace': 0}, squeeze = False)
    axes = axes[:, 0]
    plot_event = tkPlotEvent(plt, distance = 'x')
    ncolors  = len(colors)
    nmarkers = len(markers)

    for ic in range(ncif):
        inf_cif = inf_cif_list[ic]
        src  = inf_cif["diffractions"]["source"]
        Q2   = inf_cif["diffractions"]["Q2"]
        dhkl = inf_cif["diffractions"]["dhkl"]
        hkl  = inf_cif["diffractions"]["hkl"]
        mul  = inf_cif["diffractions"].get("mul", None)
        if mul is None:
            mul = [0 for _ in range(len(src))]
        Int  = inf_cif["diffractions"]["intensity"]
        ndiffractions = len(Q2)
        filename = inf_cif["filename"]
        dirname, basename, filebody, ext = split_file_path(filename)
        phase = [filebody] * ndiffractions

        ax = axes[ic]
        ax.tick_params(labelsize = fontsize)
        ax2 = ax.twinx()
        ax2.tick_params(labelsize = 0)
        ycif = ycif_list[ic]
        # Empty plot: only puts the name of this phase at the top of the legend
        ax.plot([], [], label = filebody, linestyle = '')
        ax2.plot(Q2_list, ycif, picker = True, linestyle = 'dashed', color = 'black', linewidth = 0.5)

        Int = normalize(Int, Amin = 0.0, Amax = 1.0)

        # Fallback style for the reflection markers when there is no other
        # phase; otherwise the style of the last other phase is used, as before.
        color  = colors[ic % ncolors]
        marker = markers[ic % nmarkers]

        for ic2 in range(ncif):
            if ic == ic2:
                continue

            inf_cif2 = inf_cif_list[ic2]
            filename = inf_cif2["filename"]
            dirname, basename, filebody, ext = split_file_path(filename)

            color  = colors[ic2 % ncolors]
            # markers is shorter than colors: it needs its own modulus
            marker = markers[ic2 % nmarkers]

            y = corr[ic][ic2]
            vmax = max_none(y)
            vmin = min_none(y)
            y   = normalize_none(y, Amin = 0.0, Amax = 1.0, vmin = vmin, vmax = vmax)

            ax.plot(Q2_list, y, label = filebody, picker = True, color = color, linewidth = 1.0)

        data = ax.plot(Q2, Int, linestyle = '', marker = marker, markerfacecolor = color, markeredgecolor = 'black', markersize = 2.0)
        for j in range(ndiffractions):
             ax.plot([Q2[j], Q2[j]], [0.0, Int[j]], linestyle = 'dashed', color = 'black', linewidth = 0.5)

        hkl_str = [f"{hkl[j][0]} {hkl[j][1]} {hkl[j][2]}" for j in range(ndiffractions)]
        # Register the markers so that a mouse click shows the reflection data
        plot_event.add_data({"label": "diffraction", "plot_type": "2D", "axis": ax, "data": data,
                                 "x": Q2, 'y': Int, "xlabel": "2Theta", "ylabel": "Intensity",
                                 "xlist": [src, phase, Q2, dhkl,  hkl_str, mul, Int],
                                 "xlabels": ['X-ray', 'phase', 'Q2', 'dhkl', 'hkl', 'multiplicity', 'Intensity']
                                 })

        ax.set_xlim([Qmin, Qmax])
        ax.legend(fontsize = legend_fontsize)
        if cparams.yscale == 'log':
            ax.set_yscale('log')
            ax.set_ylim([1.0e-4, ax.get_ylim()[1]])
        ax2.set_yscale('linear')

    # x label only on the bottom panel
    axes[ncif - 1].set_xlabel(r'2$\theta$', fontsize = fontsize)

    plot_event.register_click(fig)
    plot_event.register_pick(fig)

    plt.tight_layout()
    plt.pause(0.1)

    input("Press ENTER to terminate>>")

def CIF_correlation(inf, app, cparams):
    """Print and plot correlations between the observed and CIF XRD profiles.

    Two analyses are performed on the profiles returned by ``collect_data``:

    1. Every pair of profiles (observed + each CIF profile) is compared with
       a normalized dot product (no mean subtraction) and printed together
       with a tag that grades the similarity.
    2. The same normalized dot product between the observed profile and each
       CIF profile is recomputed for several smearing FWHM values and
       plotted as a function of FWHM.

    Args:
        inf: Data container forwarded unchanged to ``collect_data``.
        app: Application object forwarded unchanged to ``collect_data``.
        cparams: Parameter object.  ``fwhm_smear`` and ``Gfraction_smear``
            are read; ``fwhm_smear`` is overwritten temporarily during the
            FWHM scan and is always restored, even if the scan fails.

    Returns:
        None.  The function blocks on ``input()`` until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Check correlation between cif XRD patterns")
    print("#########################################")
    print(f"  smearing: fwhm={cparams.fwhm_smear} Gaussian fraction={cparams.Gfraction_smear}")

    # Only the abscissa grid, the observed profile, the sample names and the
    # CIF profiles are needed here; the remaining return values are ignored.
    Q2_list, yobs_infile, _, _, _, sample_names, ycif_list, _, _ = collect_data(inf, app, cparams)
    xlist   = [yobs_infile] + list(ycif_list)
    xlabels = ["obs"] + list(sample_names)
    nx = len(xlist)
    dx = Q2_list[1] - Q2_list[0]

    # Correlation coefficients
    print("")
    print("Correlation coefficients:")
    # Upper triangle (including the diagonal) of the raw dot products.
    corr = np.zeros([nx, nx])
    for i in range(nx):
        for j in range(i, nx):
            corr[i][j] = np.dot(xlist[i], xlist[j])

    # (lower bound, tag) pairs, tested from the highest bound downwards.
    # Values below every bound (and NaN) get an empty tag.
    grades = ((0.9, "identical?"), (0.8, "similar"), (0.5, "*****"), (0.4, "****"), (0.3, "***"))
    for i in range(nx):
        for j in range(i + 1, nx):
            # The diagonal is never modified, so it still holds the squared norms.
            corr[i][j] /= np.sqrt(corr[i][i] * corr[j][j])
            tag = next((t for bound, t in grades if corr[i][j] >= bound), "")
            print(f"{tag:>11} {corr[i][j]:8.4f}: ({i:2}-{j:2}): ({xlabels[i]:20})-({xlabels[j]:20})")

    # Plot the correlation with the input data while varying the smearing FWHM
    print("")
    print("Correlation coefficients with input data with varied smearing fwhm:")
    fwhm_smear_original = cparams.fwhm_smear
    fwhm_list = np.arange(0.0, 1.001, 0.2)
    corr_list = make_matrix2(nx, len(fwhm_list))
    try:
        for ifwhm, _fwhm in enumerate(fwhm_list):
            cparams.fwhm_smear = _fwhm
            print(f"  Calculating for fwhm={_fwhm}")
            _, _, _, _, _, _, ycif_list0, _, _ = collect_data(inf, app, cparams, is_print = False)

            # Keep about one point per fwhm/5 of abscissa, and at least every
            # point; the small offset presumably guards against round-off.
            nskip = int((_fwhm / 5.0) / dx + 1.00001)

            # NOTE(review): the observed profile from the initial collect_data
            # call is reused here, not the re-collected one -- confirm intended.
            xlist0 = [np.asarray(yobs_infile)[::nskip]]
            xlist0.extend(np.asarray(y)[::nskip] for y in ycif_list0)
            print(f"    _fwhm={_fwhm}  dx={dx}  nskip={nskip}  nx={len(xlist0[0])}")

            norm = []
            for j, y in enumerate(xlist0):
                print(f"    calculating {j}-th diagonal norm")
                norm.append(np.sqrt(np.dot(y, y)))
            for j, y in enumerate(xlist0):
                print(f"    calculating obs - {j}-th non-diagonal correlation")
                corr_list[j][ifwhm] = np.dot(xlist0[0], y) / norm[0] / norm[j]
    finally:
        # Never leave the caller's parameter object with a scan value in it.
        cparams.fwhm_smear = fwhm_smear_original

    print("")
    print("plot")

    fig, ax = plt.subplots(1, 1, figsize = figsize_one)
    plot_event = tkPlotEvent(plt, distance = 'x')
    ncolors  = len(colors)

    ax.tick_params(labelsize = fontsize)
    # Index 0 is the observed profile against itself, so it is not plotted.
    for i in range(1, nx):
        ax.plot(fwhm_list, corr_list[i], label = xlabels[i], picker = True, color = colors[i % ncolors], linewidth = 1.0)

    ax.legend(fontsize = legend_fontsize)
    ax.set_xlabel(r'fwhm ($\degree$)', fontsize = fontsize)
    ax.set_ylabel(r'Correlation coefficient', fontsize = fontsize)

    plot_event.register_pick(fig) # callback = lambda event: plot_event.onclick(event))

    plt.tight_layout()
    plt.pause(0.1)

    print("")
    input("Press ENTER to terminate>>")


def correlation(inf, app, cparams):
    """Print and plot correlations between the input XRD and CIF XRD profiles.

    The profiles returned by ``collect_data`` (observed profile first, then
    one profile per CIF sample) are compared pairwise with a normalized dot
    product (no mean subtraction) and each pair is printed with a similarity
    tag.  Afterwards the normalized dot product between the observed profile
    and each CIF profile is evaluated for a series of smearing FWHM values
    and plotted against the FWHM.

    NOTE(review): the body is nearly identical to ``CIF_correlation``; the
    two differ only in a few printed/plotted strings.

    Args:
        inf: Data container forwarded unchanged to ``collect_data``.
        app: Application object forwarded unchanged to ``collect_data``.
        cparams: Parameter object.  ``fwhm_smear`` and ``Gfraction_smear``
            are read; ``fwhm_smear`` is overwritten temporarily during the
            FWHM scan and is always restored, even if the scan fails.

    Returns:
        None.  The function blocks on ``input()`` until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Check correlation between the input XRD and the cif XRD patterns")
    print("#########################################")
    print(f"  smearing: fwhm={cparams.fwhm_smear} Gaussian fraction={cparams.Gfraction_smear}")

    # Only the abscissa grid, the observed profile, the sample names and the
    # CIF profiles are needed here; the remaining return values are ignored.
    Q2_list, yobs_infile, _, _, _, sample_names, ycif_list, _, _ = collect_data(inf, app, cparams)
    xlist   = [yobs_infile] + list(ycif_list)
    xlabels = ["obs"] + list(sample_names)
    nx = len(xlist)
    dx = Q2_list[1] - Q2_list[0]

    # Correlation coefficients
    print("")
    print("Correlation coefficients:")
    # Upper triangle (including the diagonal) of the raw dot products.
    corr = np.zeros([nx, nx])
    for i in range(nx):
        for j in range(i, nx):
            corr[i][j] = np.dot(xlist[i], xlist[j])

    # (lower bound, tag) pairs, tested from the highest bound downwards.
    # Values below every bound (and NaN) get an empty tag.
    grades = ((0.9, "identical?"), (0.8, "similar"), (0.5, "*****"), (0.4, "****"), (0.3, "***"))
    for i in range(nx):
        for j in range(i + 1, nx):
            # The diagonal is never modified, so it still holds the squared norms.
            corr[i][j] /= np.sqrt(corr[i][i] * corr[j][j])
            tag = next((t for bound, t in grades if corr[i][j] >= bound), "")
            print(f"{tag:>11} {corr[i][j]:8.4f}: ({i:2}-{j:2}): ({xlabels[i]:20})-({xlabels[j]:20})")

    # Plot the correlation with the input data while varying the smearing FWHM
    print("")
    print("Correlation coefficients with input data with varied smearing fwhm:")
    fwhm_smear_original = cparams.fwhm_smear
    fwhm_list = np.arange(0.0, 1.001, 0.2)
    corr_list = make_matrix2(nx, len(fwhm_list))
    try:
        for ifwhm, _fwhm in enumerate(fwhm_list):
            cparams.fwhm_smear = _fwhm
            print("")
            print(f"  Calculating for fwhm={_fwhm}")
            _, _, _, _, _, _, ycif_list0, _, _ = collect_data(inf, app, cparams, is_print = False)

            # Keep about one point per fwhm/5 of abscissa, and at least every
            # point; the small offset presumably guards against round-off.
            nskip = int((_fwhm / 5.0) / dx + 1.00001)

            # NOTE(review): the observed profile from the initial collect_data
            # call is reused here, not the re-collected one -- confirm intended.
            xlist0 = [np.asarray(yobs_infile)[::nskip]]
            xlist0.extend(np.asarray(y)[::nskip] for y in ycif_list0)
            print(f"    _fwhm={_fwhm}  dx={dx}  nskip={nskip}  nx={len(xlist0[0])}")

            norm = []
            for j, y in enumerate(xlist0):
                print(f"    calculating {j}-th diagonal norm")
                norm.append(np.sqrt(np.dot(y, y)))
            for j, y in enumerate(xlist0):
                print(f"    calculating obs - {j}-th non-diagonal correlation")
                corr_list[j][ifwhm] = np.dot(xlist0[0], y) / norm[0] / norm[j]
    finally:
        # Never leave the caller's parameter object with a scan value in it.
        cparams.fwhm_smear = fwhm_smear_original

    print("")
    print("plot")

    fig, ax = plt.subplots(1, 1, figsize = figsize_one)
    plot_event = tkPlotEvent(plt, distance = 'x')
    ncolors  = len(colors)

    ax.tick_params(labelsize = fontsize)
    # Index 0 is the observed profile against itself, so it is not plotted.
    for i in range(1, nx):
        ax.plot(fwhm_list, corr_list[i], label = xlabels[i], picker = True, color = colors[i % ncolors], linewidth = 1.0)

    ax.legend(fontsize = legend_fontsize)
    ax.set_xlabel(r'smearing FWHM ($\degree$)', fontsize = fontsize)
    ax.set_ylabel(r'Correlation coefficient', fontsize = fontsize)

    plot_event.register_pick(fig) # callback = lambda event: plot_event.onclick(event))

    plt.tight_layout()
    plt.pause(0.1)

    print("")
    input("Press ENTER to terminate>>")

def fit(inf, app, cparams):
    """Fit the observed XRD profile with background and CIF profiles by LASSO.

    The training matrix returned by ``collect_data`` is transposed so that
    each column is one regressor.  As used below, the first
    ``cparams.BGorder + 1`` coefficients belong to the background terms in
    ``bg_list`` and the following ``len(ycif_list)`` coefficients belong to
    the CIF profiles.

    Steps:
        1. Scan the LASSO ``alpha`` upward from 1e-8, doubling each step, and
           record the coefficients and RMSE.  The scan stops early once every
           coefficient has shrunk below 1e-5 of the largest magnitude found
           at the first ``alpha``.
        2. Plot RMSE and the CIF coefficients against ``alpha``.
        3. Fit once more with ``cparams.alpha``, print coefficients and error
           metrics, and plot the observed, background, fitted and scaled CIF
           profiles.

    Args:
        inf: Data container forwarded unchanged to ``collect_data``.
        app: Application object forwarded unchanged to ``collect_data``.
        cparams: Parameter object; ``alpha``, ``fwhm_smear``,
            ``Gfraction_smear``, ``yscale`` and ``BGorder`` are read.

    Returns:
        None.  The function blocks on ``input()`` until ENTER is pressed.
    """
    print("")
    print("#########################################")
    print("  Fitting by LASSO")
    print("#########################################")
    print(f"  alpha={cparams.alpha}")
    print(f"  smearing: FWHM={cparams.fwhm_smear} Gaussian fraction={cparams.Gfraction_smear}")
    print("yscale:", cparams.yscale)

    Q2_list, yobs_infile, ysim_infile, bg_names, bg_list, sample_names, ycif_list, labels_train, x_train \
                            = collect_data(inf, app, cparams)
    nx   = len(Q2_list)
    ncif = len(ycif_list)
    nbg  = cparams.BGorder + 1   # number of leading background coefficients

    print("")
    print("LASSO alpha:", cparams.alpha)
    x_train = np.array(x_train).T

    alpha0 = 1.0e-8
    ntry = 80
    print("")
    print(f"Lasso regression with varied alpha: start from {alpha0}")
    print(f"{'':42}", labels_train[1:])
    _alpha = alpha0
    alpha_list = []
    c_list     = []
    RMSE_list  = []
    maxC = None
    for i in range(ntry):
        model = Lasso(alpha = _alpha, fit_intercept = False)
        model.fit(x_train, yobs_infile)
        yfit = model.predict(x_train)
        RMSE = np.sqrt(mean_squared_error(yobs_infile, yfit))

        alpha_list.append(_alpha)
        c_list.append(model.coef_)
        RMSE_list.append(RMSE)

        s = " ".join("{:10.4g}".format(v) for v in model.coef_)
        print(f"  {i:2}  alpha={_alpha:8.2g} RMSE={RMSE:12.4g}   coeff={s}")

        # Use the largest magnitude, so that a dominant negative coefficient
        # is not missed when setting the reference for the stop criterion.
        cmax = np.max(np.abs(model.coef_))
        if maxC is None:
            maxC = cmax
        elif cmax < 1.0e-5 * maxC:
            break

        _alpha *= 2.0

    plot_event = tkPlotEvent(plt, distance = 'x')

    print("")
    print("plot LASSO analysis")
    print("sample_names=", sample_names)
    fig_lasso = plt.figure(figsize = figsize_coeff)
    ax  = fig_lasso.add_subplot(1, 1, 1)
    ax2 = ax.twinx()
    ax.tick_params(labelsize = fontsize)
    ax2.tick_params(labelsize = fontsize)

    ax.plot(alpha_list, RMSE_list, label = 'RMSE', picker = True, linestyle = 'dashed', color = 'black', linewidth = 1.5)
    # One row per coefficient, one column per scanned alpha.
    coef_paths = np.array(c_list).T
    for i in range(ncif):
        # CIF coefficients follow the background coefficients, matching the
        # indexing used for the final fit below.
        ax2.plot(alpha_list, coef_paths[nbg + i], label = sample_names[i], picker = True, linewidth = 1.0)
    ax2.axhline(0.0, linestyle = 'dashed', color = 'red', linewidth = 0.5)

    ax.set_xlabel('alpha', fontsize = fontsize)
    ax.set_ylabel('RMSE', fontsize = fontsize)
    ax2.set_ylabel('coefficient', fontsize = fontsize)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax2.legend(fontsize = legend_fontsize_one)

    plt.tight_layout()
    plt.pause(0.1)

    # Final fit with the user-selected alpha.
    model = Lasso(alpha = cparams.alpha, fit_intercept = False)
    model.fit(x_train, yobs_infile)
    yfit = model.predict(x_train)

    print("")
    print(f"Lasso regression with alpha={cparams.alpha}")
    print("  Coefficients:")
    # Background profile = sum of the background terms weighted by their coefficients.
    yBG = np.zeros(nx)
    for order in range(nbg):
        print(f"    BG({order}-th order): {model.coef_[order]:10.6g}")
        yBG += model.coef_[order] * np.asarray(bg_list[order])

    # Scale into new arrays so the profiles returned by collect_data are
    # left untouched.
    ycif_scaled = []
    for i in range(ncif):
        c = model.coef_[nbg + i]
        print(f"    {sample_names[i]:20}: {c:10.6g}")
        ycif_scaled.append(c * np.asarray(ycif_list[i]))

    MAE  = mean_absolute_error(yobs_infile, yfit)
    MSE  = mean_squared_error(yobs_infile, yfit)
    RMSE = np.sqrt(MSE)
    print(f"  MAE : {MAE:12.4g}")
    print(f"  MSE : {MSE:12.4g}")
    print(f"  RMSE: {RMSE:12.4g}")

    print("")
    print("plot XRD profiles")
    fig = plt.figure(figsize = figsize_one)
    ncolors  = len(colors)
    ax = fig.add_subplot(1, 1, 1)
    ax.tick_params(labelsize = fontsize)

    ax.plot(Q2_list, yobs_infile, label = "obs", picker = True, linestyle = '',
                marker = 'o', markersize = 2.0, markerfacecolor = colors[0], markeredgecolor = colors[0])
    ax.plot(Q2_list, yBG, label = "background", picker = True, linestyle = '-', color = 'cyan', linewidth = 0.5)
    ax.plot(Q2_list, yfit, label = "fit", picker = True, linestyle = 'dashed', color = colors[1], linewidth = 1.5)
    for i in range(ncif):
        # colors[0] and colors[1] are taken by "obs" and "fit".
        ax.plot(Q2_list, ycif_scaled[i], label = f"{sample_names[i]}", picker = True, color = colors[(i + 2) % ncolors], linewidth = 1.2)

    ax.set_xlabel('$2\\theta$', fontsize = fontsize)
    if cparams.yscale == 'log':
        ax.set_ylabel('log($y$)', fontsize = fontsize)
    else:
        ax.set_ylabel('$y$', fontsize = fontsize)
    ax.legend(fontsize = legend_fontsize_one)

    plot_event.register_pick(fig_lasso) # callback = lambda event: plot_event.onclick(event))
    plot_event.register_pick(fig) # callback = lambda event: plot_event.onclick(event))

    plt.tight_layout()
    plt.pause(0.1)

    input("Press ENTER to terminate>>")

def main():
    """Entry point: set up the application, load plug-ins and run one mode.

    The mode is taken from ``cparams.mode``.  Every mode first reads the
    input files with ``read_all_files`` (mode ``plot`` passes
    ``input_only = True``) and then calls the matching analysis function.
    An unrecognized mode is reported instead of being ignored silently.
    ``app.terminate`` is called in every case.

    Side effects:
        Rebinds the module-level names ``module_names`` and ``modules`` and
        redirects output to stdout plus a log file via ``app.redirect``.
    """
    global module_names, modules

#==================================================================
# Initialize parameters
#==================================================================
    app     = tkApplication(usage_str  = usage_str, globals = globals(), locals = locals())
    cparams = app.get_params()

    logfile = app.replace_path(None, template = ["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    initialize(app, cparams)
    update_vars(app, cparams)

    cparams.outfile = replace_path(cparams.infile, template = os.path.join("{dirname}", "{filebody}-out.txt"))

    print("")
    print( "==========================================================================")
    print(" Convert CIF file to powder XRD pattern")
    print( "==========================================================================")
    print(f"mode: {cparams.mode}")
    print(f"Plug-in dir: {cparams.plugin_dir}")
    print(f"Input  file: {cparams.infile}")
    print(f"Output file: {cparams.outfile}")

# Load modules
    print("")
    print("Load modules:")
    module_names, modules = app.load_modules(cparams.plugin_dir, "*.py", target = "read_data", is_print = True)
    for m in modules:
        input_type  = m.get_input_type(app = app, cparams = cparams)
        output_type = m.get_output_type(app = app, cparams = cparams)
        print(f"  {m.name}: input_type={input_type}  output_type={output_type}")

# Dispatch on the requested mode
    if cparams.mode == 'plot':
        inf = read_all_files(app, cparams, input_only = True)
        plot_input(inf, app, cparams)
    elif cparams.mode == 'overwrap':
        inf = read_all_files(app, cparams)
        check_overwrap(inf, app, cparams)
    elif cparams.mode == 'CIFcorrelation':
        inf = read_all_files(app, cparams)
        CIF_correlation(inf, app, cparams)
    elif cparams.mode == 'correlation':
        inf = read_all_files(app, cparams)
        correlation(inf, app, cparams)
    elif cparams.mode == 'plot_all':
        inf = read_all_files(app, cparams)
        plot_all(inf, app, cparams)
    elif cparams.mode == 'plot_one':
        inf = read_all_files(app, cparams)
        plot_one(inf, app, cparams)
    elif cparams.mode == 'plot_one2':
        inf = read_all_files(app, cparams)
        plot_one2(inf, app, cparams)
    elif cparams.mode == 'fit':
        inf = read_all_files(app, cparams)
        fit(inf, app, cparams)
    else:
        # Tell the user why nothing was done instead of exiting quietly.
        print("")
        print(f"Error: Invalid mode [{cparams.mode}]")
        print("  Supported modes: plot, overwrap, CIFcorrelation, correlation, plot_all, plot_one, plot_one2, fit")

    app.terminate(usage = usage)


# Run the command-line workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
