import signal
import sys
import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import rcParams


# Catch the keyboard interrupt (Ctrl-C) to stop the animation.
def signal_handler(sig, frame):
    """SIGINT handler: wait for one line of input (ENTER), then exit with status 0."""
    input('Press ENTER to stop animation>>')
    raise SystemExit(0)

signal.signal(signal.SIGINT, signal_handler)


class tkParams():
    """Empty namespace; settings and plot handles are attached as attributes at run time."""
    
# pos: fractional coordinates
# x, xeq, xall, xk: absolute coordinates
class tkCrystal():
    """Plain record describing a 1-D crystal: lattice, atoms and force constants.

    Args:
        a0: lattice (cell) length.
        positions: fractional coordinates of the atoms; stored as given.
        masses: atomic masses.
        l0: equilibrium interatomic distance.
        k2_list, k3_list, k4_list: 2nd-, 3rd- and 4th-order force-constant
            matrices, indexed [i][j] by atom index.
    """
    def __init__(self, a0, positions, masses, l0, k2_list, k3_list, k4_list):
        # Geometry and atoms
        self.a0, self.pos_list, self.m_list = a0, positions, masses
        self.natoms = len(positions)

        # Force constants
        self.l0 = l0
        self.k2_list, self.k3_list, self.k4_list = k2_list, k3_list, k4_list

class tkPotential():
    """Open-chain nearest-neighbour anharmonic spring potential.

    For a bond between atoms i < j the stretch is s = x[j] - x[i] - l0 and the
    bond energy is k2/2 s^2 + k3/6 s^3 + k4/24 s^4, with the coefficients read
    from k2_list[i][j], k3_list[i][j] and k4_list[i][j] of ``mdcell``.
    The virtual neighbours -1 and n do not exist (get_x returns None), so the
    bonds to them contribute neither energy nor force.
    """
    def __init__(self, mdcell):
        self.mdcell = mdcell
        self.n      = mdcell.natoms
        self.m_list = mdcell.m_list
        self.l0 = mdcell.l0
        self.k2 = mdcell.k2_list
        self.k3 = mdcell.k3_list
        self.k4 = mdcell.k4_list

    def get_x(self, i, xall):
        """Return xall[i], or None for the virtual neighbours -1 and n."""
        if i == -1    : return None
        if i == self.n: return None
        # Periodic-boundary alternative (inactive; asl = supercell length):
        # if i == -1    : return x[self.n-1] - asl
        # if i == self.n: return x[0] + asl
        return xall[i]

    def dx(self, i, j, xall):
        """Signed bond deviation seen from atom i.

        Returns x[j] - x[i] - l0 for i < j (the bond stretch),
        x[j] - x[i] + l0 for i > j (minus the bond stretch), 0.0 for i == j,
        and None if either atom is a virtual neighbour.
        """
        xj = self.get_x(j, xall)
        xi = self.get_x(i, xall)
        if xj is None or xi is None: return None

        if i < j: return xj - xi - self.l0
        if i > j: return xj - xi + self.l0
        return 0.0

    # Interatomic force
    def fij(self, i, j, xall):
        """Force on atom i from its bond to neighbour j (0.0 if not adjacent).

        With s the bond stretch, dU/ds = k2 s + k3/2 s^2 + k4/6 s^3 and the
        force on i is -dU/dx_i = +dU/ds if i is the left atom of the bond and
        -dU/ds if it is the right atom. Working with s instead of the signed
        ``dx`` keeps the even-order k3 term correct for left-neighbour bonds
        (the force is -grad U and Newton's third law holds).
        """
        if abs(j - i) != 1: return 0.0

        _dx = self.dx(i, j, xall)
        if _dx is None: return 0.0

        sign = 1.0 if i < j else -1.0   # +1 when j is the right neighbour of i
        s = sign * _dx                  # bond stretch x_right - x_left - l0
        dUds =         self.k2[i][j] * s \
               + 0.5 * self.k3[i][j] * s**2 \
               + 1/6 * self.k4[i][j] * s**3
        return sign * dUds

    # Potential energy of one bond
    def Uij(self, i, j, xall):
        """Energy of the bond between adjacent atoms i and j (0.0 otherwise).

        The stretch is used, so the result does not depend on argument order.
        """
        if abs(j - i) != 1: return 0.0

        _dx = self.dx(i, j, xall)

        if _dx is None: return 0.0

        s = _dx if i < j else -_dx      # bond stretch x_right - x_left - l0
        return  0.5 * self.k2[i][j] * s**2 \
             +  1/6 * self.k3[i][j] * s**3 \
             + 1/24 * self.k4[i][j] * s**4

    def Uk(self, xall):
        """Total potential energy of the chain at positions ``xall``."""
        _Uk = 0.0
        # Bonds inside the supercell
        for i in range(self.n - 1):
            _Uk += self.Uij(i, i+1, xall)

        # Left boundary bond (comment out for a periodic structure)
        _Uk += self.Uij(-1, 0, xall)
        # Right boundary bond (keep even for a periodic structure)
        _Uk += self.Uij(self.n - 1, self.n, xall)

        return _Uk

    # Kinetic energy
    def UK(self, v):
        """Total kinetic energy sum(m v^2 / 2) for velocities ``v``."""
        _UK = 0.0
        for i in range(self.n):
            _UK += 0.5 * self.m_list[i] * v[i]**2
        return _UK

    # Total energy
    def Utot(self, x, v):
        """Potential plus kinetic energy."""
        _Uk = self.Uk(x)
        _UK = self.UK(v)
        return _Uk + _UK

    # Force on every atom
    def F_list(self, x):
        """Array of net forces: contributions from the left and right neighbours."""
        f = np.empty(self.n, dtype = float)
        for i in range(self.n):
            f[i]  = self.fij(i, i - 1, x)
            f[i] += self.fij(i, i + 1, x)

        return f

# MD model: equilibrium, constraint flags and strained initial positions
class tkModel():
    """Reference positions and per-atom constraint flags for the MD run.

    Attributes:
        xeq: equilibrium absolute positions, a0 * fractional positions.
        optid_md: per-atom flag used during MD (1 = free, 0 = fixed);
            both chain ends are fixed.
        optid_relax: flags for relaxing the strained initial structure;
            additionally fixes atoms 1 and n-2, which carry the strain.
        xini: xeq with atom 1 shifted by -dxmax and atom n-2 by +dxmax.
    """
    def __init__(self, mdcell, dxmax):
        self.mdcell = mdcell
        self.n = mdcell.natoms
        # Equilibrium positions. np.asarray accepts a plain list of fractional
        # positions (tkCrystal stores whatever it is given), and dtype=float
        # keeps the in-place strain below from failing on an integer array.
        self.xeq = mdcell.a0 * np.asarray(mdcell.pos_list, dtype=float)

        # Constraint flags
        self.optid_md = [1] * self.n
        self.optid_md[0] = 0               # fixed end
        self.optid_md[self.n - 1] = 0      # fixed end

        self.optid_relax = self.optid_md.copy()
        self.optid_relax[1] = 0            # keep the strained atoms fixed while relaxing
        self.optid_relax[self.n - 2] = 0   # keep the strained atoms fixed while relaxing

        # Initial positions
        self.xini = np.copy(self.xeq)
        self.xini[1]          += -dxmax    # initial strain
        self.xini[self.n - 2] += dxmax     # initial strain

class tkMD():
    """Time integrator for the chain; one call of next_step per frame.

    method == 'euler': semi-implicit Euler (v updated first, then x with the
    new v) on every frame.
    Any other method: the Euler step on frame 0 bootstraps the history, then
    x_{k+1} = 2 x_k - x_{k-1} + a_k dt^2 with the central-difference velocity
    v_k = (x_{k+1} - x_{k-1}) / (2 dt).
    Atoms whose model.optid_md flag is falsy are never moved.
    """
    def __init__(self, model, potential, x0, method, tstep):
        self.model = model
        self.potential = potential
        self.optid = model.optid_md
        self.n  = model.n

        self.method = method
        self.tstep = tstep

        # Energy histories: kinetic, potential, total
        self.UKs = []
        self.Uks = []
        self.Uts = []

        # Current position self.x and velocity self.v (start from rest).
        # Float arrays so that integer input cannot truncate the dynamics.
        self.x = np.array(x0, dtype=float)
        self.v = np.zeros_like(self.x)
        for i in range(len(self.v)):
            if self.optid[i] == 0:
                self.v[i] = 0.0            # fixed atoms never move

        # History buffers for the Verlet step (x0 = x_{k-1}, xprev = x_k)
        self.x0 = self.x.copy()
        self.v0 = self.v.copy()
        self.xprev = None
        self.vprev = None

    def next_step(self, frame):
        """Advance the system by one time step and record the energies.

        Args:
            frame: step index; frame 0 always uses the Euler start-up step.
        Returns:
            (UKs, Uks, Uts): kinetic, potential and total energy histories
            (the same list objects stored on self).
        """
        tstep = self.tstep
        m_list = self.potential.m_list

        f1 = self.potential.F_list(self.x)
        vnext = self.v.copy()
        xnext = self.x.copy()
        if frame == 0 or self.method == 'euler':
            # Semi-implicit Euler: new v from the force, then x moved with new v.
            for i in range(len(self.x)):
                if self.optid[i]:
                    vnext[i] = self.v[i] + f1[i] / m_list[i] * tstep
                    xnext[i] = self.x[i] + vnext[i] * tstep

            self.vprev = vnext.copy()
            self.xprev = xnext.copy()
            # vnext and xnext belong to the same time level.
            x_energy = xnext
        else:
            # Here self.xprev holds x_k (same values as self.x), self.x0 holds x_{k-1}.
            self.x1 = self.x.copy()

            for i in range(len(self.x)):
                if self.optid[i]:
                    xnext[i] = 2.0 * self.xprev[i] - self.x0[i] + f1[i] / m_list[i] * tstep * tstep
                    vnext[i] = (xnext[i] - self.x0[i]) / 2.0 / tstep

            self.v0 = self.vprev
            self.x0 = self.xprev
            self.vprev = vnext
            self.xprev = xnext
            # The central difference yields v_k, the velocity at x_k rather than
            # at x_{k+1}; evaluate the potential energy at x_k so kinetic and
            # potential energy refer to the same instant.
            x_energy = self.x1

        self.v = vnext
        self.x = xnext
        _Uk = self.potential.Uk(x_energy)
        _UK = self.potential.UK(self.v)

        self.UKs.append(_UK)
        self.Uks.append(_Uk)
        self.Uts.append(_UK + _Uk)

        return self.UKs, self.Uks, self.Uts

class tkOptimize():
    """Relax selected coordinates of a structure by minimising its potential energy.

    Only coordinates whose ``optid`` flag is 1 are varied (BFGS through
    scipy.optimize.minimize); every other coordinate keeps its ``xini`` value.
    """

    def __init__(self, potential):
        self.potential = potential
        self.iter = 0     # number of optimiser iterations seen by callback()

    def extract_parameters(self, xkall):
        """Return, as a list, the entries of ``xkall`` whose optid flag is truthy."""
        return [value for idx, value in enumerate(xkall) if self.optid[idx]]

    def recover_parameters(self, xk):
        """Expand the free-parameter vector ``xk`` into a full coordinate list.

        Free slots (optid == 1) are filled from ``xk`` in order; all other
        slots take their value from ``self.xini``.
        """
        full = []
        next_free = 0
        for idx, fixed_value in enumerate(self.xini):
            if self.optid[idx] != 1:
                full.append(fixed_value)
                continue
            full.append(xk[next_free])
            next_free += 1
        return full

    def minimize_func(self, xk):
        """Objective: potential energy of the structure rebuilt from ``xk``."""
        return self.potential.Uk(self.recover_parameters(xk))

    def callback(self, xk):
        """Per-iteration hook: count the iteration and print the current energy."""
        self.iter += 1
        U = self.potential.Uk(self.recover_parameters(xk))
        print(f"iter #{self.iter:03d}: U={U:8.4f}")

    def optimize(self, xini, optid):
        """Minimise the energy over the free coordinates.

        Returns:
            (xopt, Umin): the relaxed full coordinate list and the minimum
            energy reported by the optimiser.
        """
        self.xini, self.optid = xini, optid
        result = minimize(self.minimize_func, self.extract_parameters(xini),
                          method='BFGS',
                          callback=self.callback)
        return self.recover_parameters(result.x), result.fun


def supercell(nrepeat, unitcell):
    """Build a crystal made of ``nrepeat`` copies of ``unitcell`` laid end to end.

    Supercell atom i is a copy of unit-cell site i % natoms located in unit
    cell i // natoms. Positions stay fractional, now relative to the supercell
    length a0 * nrepeat. Masses and force-constant matrices are copied from
    the unit cell site by site.
    """
    per_cell = unitcell.natoms
    n = nrepeat * per_cell
    asl = unitcell.a0 * nrepeat            # supercell lattice length

    def _site(i):
        # Unit-cell site that supercell atom i copies.
        return i % per_cell

    def _tile(matrix):
        # n x n matrix whose [i][j] entry is the unit-cell entry of sites (i, j).
        rows = [[matrix[_site(i)][_site(j)] for j in range(n)] for i in range(n)]
        return np.array(rows, dtype=float).reshape(n, n)

    # Absolute coordinate along the chain, then rescaled to the supercell.
    pos_list = np.array(
        [(i // per_cell) * unitcell.a0 + unitcell.pos_list[_site(i)] for i in range(n)],
        dtype=float,
    )
    pos_list /= nrepeat
    m_list = np.array([unitcell.m_list[_site(i)] for i in range(n)], dtype=float)

    return tkCrystal(a0=asl,
                     positions=pos_list,
                     masses=m_list,
                     l0=unitcell.l0,
                     k2_list=_tile(unitcell.k2_list),
                     k3_list=_tile(unitcell.k3_list),
                     k4_list=_tile(unitcell.k4_list))


def initialize():
    """Create the run settings, the unit cell and the MD supercell.

    Returns:
        (cfg, unitcell, mdcell): a tkParams holding the run settings, the
        two-atom unit cell, and the supercell of cfg.nrepeat unit cells.
    """
    cfg = tkParams()

    # Integration scheme for the equations of motion: 'euler' or 'verlet'.
    cfg.method = 'verlet'
    cfg.dxmax = 0.02            # maximum initial displacement
    cfg.tstep = 2.0e-1          # time step
    cfg.nstep = 100000          # number of steps
    cfg.tsleep = 0              # animation time step [ms]
    cfg.plotinterval = 10       # redraw every this many steps
    cfg.nrepeat = 50            # unit cells in the supercell

    # Force-constant matrices; entry [a][b] couples unit-cell sites a and b.
    k2 = [[0, 1e-2],
          [1e-2, 0]]
    k3 = [[0, 0],
          [0, 0]]
    k4 = [[0, 0],
          [0, 0]]
    # Quartic (anharmonic) alternative: k4 = [[0, 1e6], [1e6, 0]]

    unitcell = tkCrystal(a0=1.0, positions=[0, 0.5], masses=[1.0, 1.0],
                         l0=0.5, k2_list=k2, k3_list=k3, k4_list=k4)

    mdcell = supercell(nrepeat=cfg.nrepeat, unitcell=unitcell)
    return cfg, unitcell, mdcell

def _padded_limits(lo, hi, frac=0.1):
    """Return (lo, hi) widened on both sides by ``frac`` of the span.

    When lo == hi a non-degenerate window is returned, so matplotlib never
    receives identical axis limits.
    """
    span = hi - lo
    if span > 0:
        pad = frac * span
    else:
        pad = frac * abs(hi) if hi != 0 else 1.0
    return lo - pad, hi + pad


# Animation update function
def update(frame, cfg, model, potential, md, plot, unitcell):
    """Advance the MD by one step and redraw every ``plot.plotinterval`` frames.

    Top panel: displacement of each atom from equilibrium. Bottom panel:
    kinetic, potential and total energy histories versus time.
    Axis limits are padded from the actual data range. This avoids the clipping
    that plain 1.1*min / 1.1*max scaling causes when all values share a sign,
    and allows negative (cubic-term) potential energy to be shown.
    """
    md.next_step(frame)

    if frame % plot.plotinterval != 0:
        return

    plot.ax1.set_title(f"frame {frame}")

    disp = md.x - model.xeq
    plot.line1.set_data(model.xeq, disp)
    plot.ax1.set_xlim([-unitcell.a0, unitcell.a0 * (cfg.nrepeat + 1)])
    plot.ax1.set_ylim(*_padded_limits(float(np.min(disp)), float(np.max(disp))))

    n_rec = len(md.UKs)
    t_list = cfg.tstep * np.arange(n_rec)
    plot.line2_UK.set_data(t_list, md.UKs)
    plot.line2_Uk.set_data(t_list, md.Uks)
    plot.line2_Ut.set_data(t_list, md.Uts)
    plot.ax2.set_xlim([0, cfg.tstep * n_rec])

    e_min = min(min(md.UKs), min(md.Uks), min(md.Uts))
    e_max = max(max(md.UKs), max(md.Uks), max(md.Uts))
    e_lo, e_hi = _padded_limits(min(0.0, e_min), max(0.0, e_max))
    if e_min >= 0.0:
        e_lo = 0.0      # keep the energy axis anchored at zero when nothing is negative
    plot.ax2.set_ylim(e_lo, e_hi)

    plt.tight_layout()
    plt.pause(0.001)

def exec_md(cfg, unitcell, mdcell):
    """Relax the strained chain, then run and animate the MD simulation.

    Prints the equilibrium positions, the relaxation progress and a per-atom
    table of initial/relaxed positions and forces. Builds a two-panel figure
    (displacements, energies), steps the simulation cfg.nstep times and waits
    for ENTER before returning.
    """
    model = tkModel(mdcell, cfg.dxmax)
    print()
    print("Equilibrium positions:")
    print("xeq=", model.xeq)

    potential = tkPotential(mdcell)

    print()
    print("optimize:")
    opt = tkOptimize(potential)
    xopt, Umin = opt.optimize(model.xini, model.optid_relax)
    print(f"optimized Umin={Umin}")

    md = tkMD(model, potential, xopt, cfg.method, cfg.tstep)

    print(f"{' ':3}: {'xini':8} {'fini':12}     {'xopt':8}  {'fopt':12}   {'v':8}")
    fini = potential.F_list(model.xini)
    fopt = potential.F_list(xopt)
    for i in range(model.n):
        print(f"#{i:03d}: {model.xini[i]:8.5f} {fini[i]:12.4g}    {xopt[i]:8.5f} {fopt[i]:12.4g}   {md.v[i]:8.5f}")

    # Plot setup
    plot = tkParams()
    plot.plotinterval = cfg.plotinterval

    rcParams['font.sans-serif'] = ['MS Gothic']
    fig, (plot.ax1, plot.ax2) = plt.subplots(2, 1, figsize = (10, 6))

    plot.line1, = plot.ax1.plot(xopt, xopt - model.xeq,  label = "running")
    plot.ax1.plot(model.xini, model.xini - model.xeq, label = "ini")
    plot.ax1.plot(xopt, xopt - model.xeq, label = "opt")
    plot.ax1.set_xlabel("x",  fontdict = {'fontname': 'MS Gothic'})
    plot.ax1.set_ylabel("dx", fontdict = {'fontname': 'MS Gothic'})
    plot.ax1.legend()
    plot.ax1.grid(True)

    # line2_Uk is fed md.Uks (potential energy, tkPotential.Uk) and line2_UK
    # is fed md.UKs (kinetic energy, tkPotential.UK) by update().
    plot.line2_Uk, = plot.ax2.plot([], [], label='potential energy')
    plot.line2_UK, = plot.ax2.plot([], [], label='kinetic energy')
    plot.line2_Ut, = plot.ax2.plot([], [], label='total energy')
    plot.ax2.set_xlabel("t",      fontdict = {'fontname': 'MS Gothic'})
    plot.ax2.set_ylabel("Energy", fontdict = {'fontname': 'MS Gothic'})
    plot.ax2.legend()
    plot.ax2.grid(True)

    # Main loop
    for frame in range(cfg.nstep):
        update(frame, cfg, model, potential, md, plot, unitcell)

    input("\nPress ENTER to terminate>>")


def main():
    """Entry point: build the crystal and run the animated MD simulation."""
    exec_md(*initialize())


if __name__ == '__main__':
    main()
    