#Pymatgenチュートリアル⑥ XRDのシミュレーションをする https://qiita.com/ojiya/items/1b154c3698cff91c8a2b

import math
from math import exp

import numpy as np
import scipy as sp
import scipy.special
import matplotlib.pyplot as plt

from pymatgen.core.periodic_table import Element
from pymatgen.io.cif import CifParser
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.analysis.diffraction.xrd import XRDCalculator, WAVELENGTHS
from pymatgen.analysis.diffraction.tem import TEMCalculator
from pymatgen.analysis.diffraction.neutron import NDCalculator

from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tkvariousdata import tkVariousData
from tklib.tkgraphic.tkplotevent import tkPlotEvent


# Default input structure file (CIF).
infile = 'SrTiO3.cif'

# X-ray line used for the simulation; normally a key of pymatgen's
# WAVELENGTHS table.  Keys and wavelengths (angstrom) as defined in
# https://github.com/materialsproject/pymatgen/blob/v2023.5.10/pymatgen/analysis/diffraction/xrd.py
#   Cu: CuKa 1.54184   CuKa2 1.54439   CuKa1 1.54056   CuKb1 1.39222
#   Mo: MoKa 0.71073   MoKa2 0.71359   MoKa1 0.70930   MoKb1 0.63229
#   Cr: CrKa 2.29100   CrKa2 2.29361   CrKa1 2.28970   CrKb1 2.08487
#   Fe: FeKa 1.93735   FeKa2 1.93998   FeKa1 1.93604   FeKb1 1.75661
#   Co: CoKa 1.79026   CoKa2 1.79285   CoKa1 1.78896   CoKb1 1.63079
#   Ag: AgKa 0.560885  AgKa2 0.563813  AgKa1 0.559421  AgKb1 0.497082
Xray_source = 'CuKa1'

# 2Theta grid of the simulated spectrum: start, end, step (degree).
Q2min, Q2max, Q2step = 10.0, 80.0, 0.02

# Intended full width at half maximum of each simulated peak (degree).
fwhm = 0.05

# Matplotlib figure size and tick/label font size.
figsize = (8, 4)
fontsize = 12


# Positional command-line overrides (index i of getarg/getfloatarg,
# presumably sys.argv[i] -- confirm in tklib.tkutils):
#   1: CIF file  2: X-ray source  3: 2Theta min  4: 2Theta max
#   5: 2Theta step  6: FWHM
app = tkApplication()
infile = getarg(1, infile)
Xray_source = getarg(2, Xray_source)
Q2min, Q2max, Q2step, fwhm = (
    getfloatarg(3, Q2min),
    getfloatarg(4, Q2max),
    getfloatarg(5, Q2step),
    getfloatarg(6, fwhm),
)


#==========================================
# Peak profile
#==========================================
def Gaussian(x, x0, whalf):
    """Return a unit-height Gaussian peak centred at x0.

    Args:
        x: Position(s) at which the profile is evaluated; a float or an
            array-like (evaluated element-wise).
        x0: Peak centre, same unit as x.
        whalf: Half width at half maximum (HWHM), same unit as x, so that
            Gaussian(x0 +/- whalf, x0, whalf) == 0.5.  Pass fwhm / 2 to get
            a peak whose full width at half maximum is fwhm.

    Returns:
        exp(-((x - x0) / a)**2) with a = whalf / sqrt(ln 2): a numpy float64
        (a float subclass) for scalar x, otherwise an array shaped like x.
    """
    # exp(-X**2) falls to 1/2 at |X| = sqrt(ln 2) (= 0.832554611...), which
    # makes whalf the half width at half maximum.
    a = whalf / math.sqrt(math.log(2.0))
    X = (np.asarray(x, dtype=float) - x0) / a
    return np.exp(-X * X)


#==========================================
# Helpers for the main routine
#==========================================
def _miller_indices(hkl):
    """Return the (h, k, l) triple of a reflection index tuple.

    A 4-index tuple is treated as Miller-Bravais (h, k, i, l), as used for
    hexagonal lattices, and the redundant third index i is dropped; a
    3-index tuple is returned as (h, k, l).
    """
    if len(hkl) == 4:
        return hkl[0], hkl[1], hkl[3]
    return hkl[0], hkl[1], hkl[2]


def _resolve_wavelength(source):
    """Return the wavelength (angstrom) for *source*, or None if unknown.

    *source* is either a number / numeric string, taken as the wavelength
    itself, or a key of pymatgen's WAVELENGTHS table such as 'CuKa1'.
    """
    try:
        return float(source)
    except (TypeError, ValueError):
        return WAVELENGTHS.get(source)


#==========================================
# Main routine
#==========================================
def main():
    """Simulate the powder XRD pattern of `infile`, save it and plot it.

    Reads the module-level settings (infile, Xray_source, Q2min, Q2max,
    Q2step, fwhm, figsize, fontsize) and
      1. redirects stdout to "<filebody>-out.txt" and logs the settings;
      2. computes the reflection list with pymatgen's XRDCalculator;
      3. broadens every reflection with a Gaussian whose full width at half
         maximum is `fwhm` on the 2Theta grid and writes the spectrum to
         "<filebody>-pxrd.xlsx";
      4. plots the spectrum and the stick pattern.

    Raises:
        ValueError: if Xray_source is neither a number nor a WAVELENGTHS key.
    """
    logfile = app.replace_path(infile, template = ["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    outxlsxfile = app.replace_path(infile, template = ["{dirname}", "{filebody}-pxrd.xlsx"])
    nQ2 = int((Q2max - Q2min) / Q2step + 1.000001)
    # Each peak is evaluated only within +/- dQ2 (nconv grid points) of its centre.
    dQ2 = 6.0 * fwhm
    nconv = int(dQ2 / Q2step + 1.00001)

    print("")
    print(f"input: {infile}")
    print(f"log file: {logfile}")
    print(f"(output xlsx file: {outxlsxfile}")
    print(f"X-ray source: {Xray_source}")
    wl = _resolve_wavelength(Xray_source)
    if wl is None:
        raise ValueError(
            f"Unknown X-ray source [{Xray_source}]: give a wavelength in angstrom "
            f"or one of {sorted(WAVELENGTHS)}")
    print(f"  wavelength: {wl} angstrom")
    print(f"2Theta range: {Q2min:9.5g} - {Q2max:9.5g} at {Q2step:9.5g} degree ({nQ2} points)")
    print(f"2Theta range for a peak: {dQ2:9.5g} degree for {nconv} points")
    print(f"  FWHM of Gaussian functin: {fwhm} degree")

    parser = CifParser(infile)
    mat = parser.get_structures(primitive=False, symmetrized=False)[0]

    # XRDCalculator accepts a float or a WAVELENGTHS key; a numeric string from
    # the command line (e.g. "1.5406") must therefore be passed as a float.
    xrd_wavelength = Xray_source if Xray_source in WAVELENGTHS else wl
    xrd = XRDCalculator(wavelength = xrd_wavelength) #, debye_waller_factors = {Element('Sr'): 0.2, Element('O'): 0.6})
    # Also take reflections just outside [Q2min, Q2max] whose tails reach the grid.
    Q2min_diffraction = max(0.0, Q2min - dQ2)
    Q2max_diffraction = min(180.0, Q2max + dQ2)
    diffractions = xrd.get_pattern(mat, two_theta_range = (Q2min_diffraction, Q2max_diffraction))

    Q2   = diffractions.x
    dhkl = diffractions.d_hkls
    ndiffractions = len(Q2)
    # For each peak, report the first (hkl) family listed for it.
    hkl  = [family[0]['hkl'] for family in diffractions.hkls]
    mul  = [family[0]['multiplicity'] for family in diffractions.hkls]
    # y includes mul, LP and T
    Int  = list(diffractions.y)

    print("")
    print(f"Diffractions: {ndiffractions} diffractions")
    print(f"{'h':2} {'k':2} {'l':2}   {'m':2}    {'dhkl':8} {'2Theta':8} {'Intensity'}")
    for q2, d, miller, m, intensity in zip(Q2, dhkl, hkl, mul, Int):
        h, k, l = _miller_indices(miller)
        print(f"{h:2} {k:2} {l:2}   {m:2}    {d:8.6f} {q2:8.6f} {intensity:8.4f}")

    print("")
    print("Calculate specrum")
    # Build the grid from nQ2 so its length always matches xrd_cal;
    # np.arange(Q2min, Q2max + Q2step, Q2step) can yield one extra point
    # because of floating-point rounding.
    xQ2     = Q2min + Q2step * np.arange(nQ2)
    xrd_cal = np.zeros(nQ2)
    # Gaussian() takes the half width at half maximum, so pass fwhm / 2 to
    # obtain peaks whose full width at half maximum really is fwhm.
    hwhm = 0.5 * fwhm
    for q2_peak, intensity in zip(Q2, Int):
        idx0 = int(round((q2_peak - Q2min) / Q2step))
        for j in range(max(idx0 - nconv, 0), min(idx0 + nconv + 1, nQ2)):
            xrd_cal[j] += intensity * Gaussian(xQ2[j], q2_peak, hwhm)

    print("")
    print(f"Save to [{outxlsxfile}]")
    tkVariousData().to_excel(outxlsxfile, ['2Theta (degree)', 'Intensity'], [xQ2, xrd_cal])

    # Hard-coded to '' so these alternative pymatgen plots are skipped;
    # set it to 'xrd', 'nd' or 'tem' to show that plot and exit instead.
    mode = ''
    if mode == 'xrd':
        xrd.plot_structures([mat], two_theta_range=[Q2min, Q2max], show = True)
        exit()
    elif mode == 'nd':
        NDCalculator().plot_structures([mat], two_theta_range=[Q2min, Q2max], show = True)
        exit()
    elif mode == 'tem':
        tem_pattern = TEMCalculator().get_plot_2d(mat)
        tem_pattern.write_image("si_tem.pdf")
        exit()

    #=====================
    # plot
    #=====================
    fig = plt.figure(figsize = figsize)
    plot_event = tkPlotEvent(plt)

    ax = fig.add_subplot(111)
    ax.set_title(f"{infile} - source: {Xray_source}")
    ax.tick_params(labelsize = fontsize)

    ax.plot(xQ2, xrd_cal, color = 'black', linewidth = 0.5)
    data = ax.plot(Q2, Int, linestyle = '', marker = 'o', markerfacecolor = 'red', markeredgecolor = 'red', markersize = 3.0)

    # Same h k l convention as the printed table (4-index tuples -> h k l).
    hkl_str = [f"{h} {k} {l}" for h, k, l in map(_miller_indices, hkl)]
    plot_event.add_data({"label": "diffraction", "plot_type": "2D", "axis": ax, "data": data,
                         "x": Q2, 'y': Int, "xlabel": "2Theta", "ylabel": "Intensity",
                         "xlist": [Q2, dhkl, hkl_str, mul, Int],
                         "xlabels": ['Q2', 'dhkl', 'hkl', 'multiplicity', 'Intensity'],
                         })

    ax.set_xlim(Q2min, Q2max)

    ax.set_xlabel(r'2$\theta$', fontsize = fontsize)
    ax.set_ylabel('Intensity', fontsize = fontsize)

    plot_event.register_event(fig, event = "button_press_event",
                    callback = lambda event: plot_event.onclick(event))

    plt.tight_layout()
    plt.pause(0.1)

    app.terminate(pause = True)


if __name__ == "__main__":
    main()


