import os
import sys

import numpy as np
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, sqrt
from numpy import linalg as la
from mpl_toolkits.mplot3d import Axes3D
#import matplotlib.pyplot as plt

from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tkfile import tkFile
from tklib.tkcrystal.tkcif import tkCIF
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType
from tkcrystalbase import *



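"""
Calculate Bragg diffraction angles (2Theta) and lattice spacings d_hkl for
the reflections of a crystal structure read from a CIF file.
  Requirement: tkcrystalbase.py
"""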


# Characteristic X-ray wavelengths in angstrom, keyed by "<anode><line>".
# "Ka" is the 2:1 weighted mean of the Ka1/Ka2 doublet, i.e.
# (2*Ka1 + Ka2) / 3; "Kb1" is the K-beta1 line.
# Insertion order (per anode: Ka, Ka2, Ka1, Kb1) is kept deliberately.
WAVELENGTHS = dict(
    # Copper
    CuKa=1.54184, CuKa2=1.54439, CuKa1=1.54056, CuKb1=1.39222,
    # Molybdenum
    MoKa=0.71073, MoKa2=0.71359, MoKa1=0.70930, MoKb1=0.63229,
    # Chromium
    CrKa=2.29100, CrKa2=2.29361, CrKa1=2.28970, CrKb1=2.08487,
    # Iron
    FeKa=1.93735, FeKa2=1.93998, FeKa1=1.93604, FeKb1=1.75661,
    # Cobalt
    CoKa=1.79026, CoKa2=1.79285, CoKa1=1.78896, CoKb1=1.63079,
    # Silver
    AgKa=0.560885, AgKa2=0.563813, AgKa1=0.559421, AgKb1=0.497082,
)


# Default input CIF file.
infile = 'SrTiO3.cif'

# Default X-ray source: either a key of WAVELENGTHS or a wavelength in angstrom.
Xray_source = 'CuKa1'

# Reflections with |G| below this threshold (the reciprocal-space origin)
# are skipped.
Gmin = 1.0e-5

# Upper limit of the diffraction angle 2Theta, in degree.
Q2max = 150.0


#==========================================
# Command-line overrides
#   presumably argv[1..3] replace the defaults above when given
#   (tklib getarg / getfloatarg) -- confirm against tklib.tkutils
#==========================================
infile = getarg(1, infile)
Xray_source = getarg(2, Xray_source)
Q2max = getfloatarg(3, Q2max)


def _resolve_wavelength(Xray_source):
    """Return the X-ray wavelength in angstrom for *Xray_source*, or None.

    *Xray_source* may be a number (or numeric string) giving the wavelength
    directly, or a key of the module-level WAVELENGTHS table.  None is
    returned for unknown names and for non-finite or non-positive values.
    """
    try:
        wl = float(Xray_source)
    except (TypeError, ValueError):
        wl = WAVELENGTHS.get(Xray_source)
    if wl is None or not np.isfinite(wl) or wl <= 0.0:
        return None
    return wl


def _hkl_search_range(lattice_parameters, wl, Qmax):
    """Return (hmax, kmax, lmax) bounding every hkl with 2Theta <= Qmax.

    For G = h a* + k b* + l c* we have |h| = |G . a| <= a |G| <= a / dmin
    (likewise for k, l), so a/dmin, b/dmin, c/dmin are safe bounds.
    """
    # arcsin never yields 2Theta > 180 deg; beyond 180, sin(Qmax/2) shrinks
    # again and would wrongly narrow the search, so cap the limit there.
    Q2lim = min(float(Qmax), 180.0)
    if Q2lim <= 0.0:
        # Only (000) could qualify, and it is always skipped.
        return 0, 0, 0
    dmin = wl / 2.0 / sin(0.5 * np.deg2rad(Q2lim))
    return tuple(int(lattice_parameters[i] / dmin) for i in range(3))


def _enumerate_reflections(Rgij, wl, Qmax, Gmin, hmax, kmax, lmax):
    """Return [[2Theta, d, h, k, l], ...] for all hkl in range with 2Theta <= Qmax.

    The list is sorted by 2Theta, then h, k, l.
    """
    org = np.array([0.0, 0.0, 0.0])
    qlist = []
    for l in range(-lmax, lmax + 1):
        for k in range(-kmax, kmax + 1):
            for h in range(-hmax, hmax + 1):
                # |G| of (h, k, l) under the reciprocal metric Rgij.
                # `distance` comes from the tkcrystalbase star import.
                G = distance(org, np.array([h, k, l]), Rgij)
                if G < Gmin:
                    continue                # reciprocal-space origin (000)
                d = 1.0 / G
                sinQ = wl / 2.0 / d          # Bragg: lambda = 2 d sin(Theta)
                if sinQ >= 1.0:
                    continue                # not reachable at this wavelength
                Q2 = 2.0 * np.rad2deg(arcsin(sinQ))
                if Q2 > Qmax:
                    continue
                qlist.append([Q2, d, h, k, l])
    qlist.sort(key=lambda x: (x[0], x[2], x[3], x[4]))
    return qlist


def _group_equivalent_reflections(cry, qlist, axis_system):
    """Group *qlist* entries by (2Theta to 1e-6 deg, cry.normalize_hkl(h, k, l)).

    Returns {key: [{"Q2", "h", "k", "l", "d"}, ...]}, each list in *qlist*
    order.  Keys begin with 2Theta formatted to width 10, so sorting the keys
    as strings orders them by 2Theta for the non-negative angles produced here.
    """
    groups = {}
    for Q2, d, h, k, l in qlist:
        h2, k2, l2 = cry.normalize_hkl(h, k, l)
        if axis_system in ('hexagonal', 'trigonal'):
            # Miller-Bravais form; i = -h-k is derived, so grouping is the
            # same as with the 3-index key.
            key = f"{Q2:10.6f}:{h2}_{k2}_{-h2-k2}_{l2}"
        else:
            key = f"{Q2:10.6f}:{h2}_{k2}_{l2}"
        groups.setdefault(key, []).append({"Q2": Q2, "h": h, "k": k, "l": l, "d": d})
    return groups


def cal_diffraction_angles(cry, Xray_source = "CuKa1", Qmax = 150.0, Gmin = 1.0e-5):
    """Print the Bragg angles (2Theta) and d-spacings of the reflections of *cry*.

    Parameters
    ----------
    cry : crystal object (tkCrystal, as returned by CIF GetCrystal())
        Only LatticeParameters, lattice_system, lattice_axis, Metrics,
        normalize_hkl and multiplicity_hkl are used.
    Xray_source : str or float
        A wavelength in angstrom (e.g. "1.5406") or a key of WAVELENGTHS
        (e.g. "CuKa1").
    Qmax : float
        Upper limit of 2Theta in degree.
    Gmin : float
        Reflections with |G| < Gmin (the origin) are skipped.

    Output is printed only; returns None.  An unknown or invalid X-ray
    source ends the program through tklib's `terminate`.
    """
    print(f"X-ray source: {Xray_source}")
    wl = _resolve_wavelength(Xray_source)
    if wl is None:
        # `app` exists only inside main(), so use tklib.tkutils.terminate here.
        source_text = str(Xray_source)
        if '+' in source_text or '*' in source_text:
            terminate(f"\nError: Mixed X-ray source [{Xray_source}] is not supported. Choose a proper source")
        else:
            terminate(f"\nError: Invalid X-ray source [{Xray_source}]. Choose a proper source")
        return
    print(f"  wavelength: {wl} angstrom")

    print(f"2Theta max: {Qmax:9.5g} degree")

    lattice_parameters = cry.LatticeParameters()
    lattice_system = cry.lattice_system()
    # Renamed from `la` so the module-level `numpy.linalg as la` is not shadowed.
    axis_system = cry.lattice_axis()
    _, Rgij = cry.Metrics()          # (real-space, reciprocal) metric tensors

    print("")
    print("Lattice parameters:", lattice_parameters)
    print(f"Lattice system: {lattice_system}")
    print(f"Unitcell axis system: {axis_system}")

    hmax, kmax, lmax = _hkl_search_range(lattice_parameters, wl, Qmax)
    print("")
    print("hkl range:", hmax, kmax, lmax)

    qlist = _enumerate_reflections(Rgij, wl, Qmax, Gmin, hmax, kmax, lmax)
    peak_identical = _group_equivalent_reflections(cry, qlist, axis_system)

    print("")
    print("Diffraction angle, d, h, k, l:")
    # NOTE(review): the grouping key uses 4 indices for 'hexagonal' and
    # 'trigonal', but only 'hexagonal' is printed in h k i l form (kept as-is).
    four_index = (axis_system == 'hexagonal')
    if four_index:
        print(f"{'h':>2} {'k':>2} {'i':>2} {'l':>2}   {'m':>2}    {'dhkl':^10} {'2Theta':^8}")
    else:
        print(f"{'h':>2} {'k':>2} {'l':>2}   {'m':>2}    {'dhkl':^10} {'2Theta':^8}")
    for key in sorted(peak_identical):
        # Representative reflection: the last (largest h, k, l) in the group.
        p = peak_identical[key][-1]
        h, k, l = p['h'], p['k'], p['l']
        m = cry.multiplicity_hkl(h, k, l)
        Q2, d = p['Q2'], p['d']
        if four_index:
            print(f"{h:2} {k:2} {-h-k:2} {l:2}   {m:2}    {d:10.6f} {Q2:8.6f}")
        else:
            print(f"{h:2} {k:2} {l:2}   {m:2}    {d:10.6f} {Q2:8.6f}")


def _read_crystal(path):
    """Read the CIF file *path* and return its crystal object.

    Terminates (via tklib `terminate`) when no valid structure is found.
    """
    reader = tkCIF()
    reader.debug = False
    cifdata = reader.ReadCIF(path, find_valid_structure=True)
    reader.Close()
    if not cifdata:
        terminate(f"Error: Could not get cifdat from infile [{path}]")
    return cifdata.GetCrystal()


def main():
    """Entry point: set up logging, read `infile`, and list its Bragg reflections.

    Output is sent to both stdout and '<dirname>/<filebody>-out.txt'
    (per the targets passed to tkApplication.redirect).
    """
    app = tkApplication()
    logfile = app.replace_path(infile, template=["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    app.redirect(targets=["stdout", logfile], mode='w')

    for line in ("", f"input: {infile}", f"log file: {logfile}", "", f"Read [{infile}]"):
        print(line)

    cry = _read_crystal(infile)
    cry.PrintInf()

    cal_diffraction_angles(cry, Xray_source=Xray_source, Qmax=Q2max, Gmin=Gmin)

    print("")
    app.terminate(pause=True)


if __name__ == '__main__':
    main()
