import signal
import sys
import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import rcParams


# Catch keyboard interrupts (Ctrl+C) so the animation can be stopped cleanly.
def signal_handler(sig, frame):
    """SIGINT handler: wait for ENTER, close all figures and exit with status 0.

    Uses ``plt.close('all')`` rather than the module-level ``fig`` because
    SIGINT can arrive before the figure has been created (for example during
    the structure optimization), where referring to ``fig`` would raise
    NameError inside the handler.
    """
    input('Press ENTER to stop animation>>')
    plt.close('all')
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)


# Time integrator used by update(): 'euler' selects the Euler branch,
# anything else (here 'verlet') the position-Verlet branch.
#method = 'euler'
method = 'verlet'

# Naming convention:
# pos: fractional coordinates
# x, xeq, xall, xk: absolute coordinates

# Unit-cell length (lattice constant)
a0 = 1.0  
# Fractional coordinates of the atoms in the unit cell
pos_list0 = [0, 0.5]
# Masses of the atoms in the unit cell
m_list0 = [1.0, 1.0]
n0 = len(pos_list0)

# Force constants between unit-cell sites, indexed [site_i][site_j]:
# harmonic (k2), cubic (k3) and quartic (k4) coefficients of the bond
# potential U = k2/2 dx^2 + k3/6 dx^3 + k4/24 dx^4 evaluated in Uij.
k2_list0 = [[0, 0.1e-1],
            [0.1e-1, 0]]
k3_list0 = [[0, 0],
            [0, 0]]
k4_list0 = [[0, 1e3],
            [1e3, 0]]
# Alternative: purely harmonic chain (no quartic term).
#k4_list0 = [[0, 0],
#            [0, 0]]
# Equilibrium interatomic distance
l0 = 0.5

# Maximum initial displacement (applied to atoms 1 and n-2 as the initial strain)
dxmax = 0.02

# Number of unit-cell repetitions in the supercell
nrepeat = 100
n = nrepeat * n0

tstep = 1.0e-1  # time step
nstep = 100000  # number of steps
tsleep = 0 # animation frame interval [ms]; NOTE(review): not referenced elsewhere in the visible code
# The plots are refreshed every `plotinterval` frames in update().
plotinterval = 100

# Atoms of the supercell: atom i is site (i % n0) of unit cell (i // n0).
asl = a0 * nrepeat   # supercell length
optid = [1] * n      # per-atom flag: 1 = free in the structure optimization, 0 = held fixed
site_index = np.arange(n) % n0
cell_index = np.arange(n) // n0
# Fractional coordinates relative to the whole supercell.
pos_list = (cell_index * a0 + np.asarray(pos_list0, dtype=float)[site_index]) / nrepeat
m_list = np.asarray(m_list0, dtype=float)[site_index]
# Force constants for every atom pair, looked up from the pair's unit-cell sites.
site_pairs = np.ix_(site_index, site_index)
k2_list = np.asarray(k2_list0, dtype=float)[site_pairs]
k3_list = np.asarray(k3_list0, dtype=float)[site_pairs]
k4_list = np.asarray(k4_list0, dtype=float)[site_pairs]


# Equilibrium positions (absolute coordinates).
xeq = asl * pos_list
# Fixed end points: one nearest-neighbour spacing (xeq[1] - xeq[0]) beyond
# the first and the last atom.
dx01 = xeq[1] - xeq[0]
x_l_fixed = xeq[0] - dx01
x_r_fixed = xeq[n - 1] + dx01

# Initial structure.
# Atoms 0 and n-1 are the end atoms; atoms 1 and n-2 carry the imposed strain.
# All four are held fixed while the strained structure is relaxed.
optid[0] = optid[1] = optid[n - 2] = optid[n - 1] = 0
xini = xeq.copy()
xini[1] -= dxmax        # initial strain: atom 1 shifted left
xini[n - 2] += dxmax    # initial strain: atom n-2 shifted right
# Initial velocities: all atoms at rest.
v = np.zeros(xini.shape, dtype=xini.dtype)

def get_x(i, xall):
    """Return the absolute position of atom *i*, including the virtual end atoms.

    Index -1 is the fixed left end point and index n the fixed right end
    point; every other index is looked up in *xall*.
    For a periodic chain one would instead return ``xall[n - 1] - asl`` for
    i == -1 and ``xall[0] + asl`` for i == n.
    """
    if i not in (-1, n):
        return xall[i]
    return x_l_fixed if i == -1 else x_r_fixed

def dx(i, j, xall):
    """Signed deviation of the (i, j) bond from the equilibrium length l0.

    For i < j this is x_j - x_i - l0 (positive when the bond is stretched);
    for i > j it is x_j - x_i + l0, i.e. minus the stretch; for i == j it is 0.0.
    """
    if i == j:
        return 0.0
    separation = get_x(j, xall) - get_x(i, xall)
    return separation - l0 if i < j else separation + l0
   
# Interatomic (bond) force
def fij(i, j, xall):
    """Force exerted on atom *i* by the bond to its nearest neighbour *j*.

    The bond potential is U(s) = k2/2 s^2 + k3/6 s^3 + k4/24 s^4, where
    s = x_max(i,j) - x_min(i,j) - l0 is the bond stretch. The force on i is
    -dU/dx_i = sign * U'(s), with sign = +1 when j > i and -1 when j < i.
    This is consistent with Uij for every term, including the odd k3 term.

    Returns 0.0 when i and j are not nearest neighbours. Indices -1 and n
    denote the fixed end points; their force constants are taken from the
    wrapped indices n - 1 and 0.
    """
    if abs(j - i) != 1:
        return 0.0

    # -1 -> n - 1 and n -> 0; indices 0..n-1 are unchanged.
    ik, jk = i % n, j % n

    # dx() returns the stretch for j > i and minus the stretch for j < i.
    # Recover the true stretch so the even power in the k3 term gets the
    # correct sign in both directions.
    sign = 1.0 if j > i else -1.0
    s = sign * dx(i, j, xall)
    dU_ds = k2_list[ik][jk] * s \
          + 0.5 * k3_list[ik][jk] * s**2 \
          + 1/6 * k4_list[ik][jk] * s**3
    return sign * dU_ds

# Bond potential energy
def Uij(i, j, xall):
    """Potential energy of the bond between nearest neighbours *i* and *j*.

    U(s) = k2/2 s^2 + k3/6 s^3 + k4/24 s^4, where s = x_max(i,j) - x_min(i,j) - l0
    is the bond stretch. The stretch is the same for (i, j) and (j, i), so the
    odd k3 term no longer depends on the argument order.

    Returns 0.0 when |i - j| != 1. Indices -1 and n denote the fixed end
    points; force constants for those bonds come from the wrapped indices
    n - 1 and 0.
    """
    if abs(j - i) != 1:
        return 0.0

    # -1 -> n - 1 and n -> 0; indices 0..n-1 are unchanged.
    ik, jk = i % n, j % n

    # dx() returns the stretch for i < j and minus the stretch for i > j.
    _dx = dx(i, j, xall)
    s = _dx if i < j else -_dx

    return  0.5 * k2_list[ik][jk] * s**2 \
         +  1/6 * k3_list[ik][jk] * s**3 \
         + 1/24 * k4_list[ik][jk] * s**4

def Uk(xall):
    """Total potential energy of the chain for absolute positions *xall*.

    Sums Uij over the n - 1 bonds inside the supercell, then adds the bond
    from the left fixed end point (index -1) to atom 0 and the bond from atom
    n - 1 to the right fixed end point (index n). For a periodic chain the
    left-end bond would be dropped and the right-end bond kept.
    """
    bonds = [(k, k + 1) for k in range(n - 1)]
    bonds.append((-1, 0))       # left end bond
    bonds.append((n - 1, n))    # right end bond
    energy = 0.0
    for left, right in bonds:
        energy += Uij(left, right, xall)
    return energy

# Kinetic energy
def UK(v):
    """Total kinetic energy sum_k m_k v_k^2 / 2 over the n atoms of the supercell."""
    kinetic = 0.0
    for k in range(n):
        kinetic = kinetic + 0.5 * m_list[k] * v[k] ** 2
    return kinetic

# Total energy
def Utot(x, v):
    """Return potential energy Uk(x) plus kinetic energy UK(v)."""
    return Uk(x) + UK(v)

# Force buffer shared by every F_list call. It lives at module level so the
# array is allocated only once; callers must copy the result to keep it.
f = np.zeros(n, dtype=float)
def F_list(x):
    """Total nearest-neighbour force on each of the n atoms for positions *x*.

    Each entry is the sum of the forces from the left and right bonds.
    The module-level buffer ``f`` is filled in place and returned, so the
    returned array is overwritten by the next call.
    """
    out = f
    for atom in range(n):
        out[atom] = fij(atom, atom - 1, x) + fij(atom, atom + 1, x)
    return out

def extract_parameters(xkall, optid):
    """Collect, in order, the entries of *xkall* whose flag in *optid* is truthy.

    Returns a new list holding the coordinates that are free to move during
    the optimization.
    """
    return [value for idx, value in enumerate(xkall) if optid[idx]]

def recover_parameters(xk, xkall, optid):
    """Inverse of extract_parameters: rebuild a full coordinate list.

    Positions whose flag in *optid* is truthy are filled, in order, from the
    reduced vector *xk*; all other positions are copied from *xkall*.
    The truthiness test matches extract_parameters, so
    ``recover_parameters(extract_parameters(xkall, optid), xkall, optid)``
    reproduces *xkall* for any flag values (not only 0/1).

    Returns a new list of length ``len(xkall)``.
    """
    pk = []
    c = 0  # next unused entry of xk
    for i in range(len(xkall)):
        if optid[i]:
            pk.append(xk[c])
            c += 1
        else:
            pk.append(xkall[i])
    return pk

def minimize_func(xk):
    """Objective for the minimizer: potential energy Uk of the full configuration.

    The reduced vector *xk* is expanded with the module-level ``xini`` and
    ``optid`` before the energy is evaluated.
    """
    return Uk(recover_parameters(xk, xini, optid))

# Structure optimization
def optimize(xini, optid):
    """Relax the free coordinates of *xini* by minimizing the potential energy Uk.

    Parameters
    ----------
    xini : sequence of float
        Full starting configuration (absolute coordinates).
    optid : sequence
        Per-atom flags; truthy entries are optimized, the others stay at
        their value in *xini*.

    Returns
    -------
    (xopt, Umin)
        The full optimized configuration (a list) and the minimized energy.

    The objective is built from *xini* and *optid* as passed in. Previously it
    used the module-level globals, so other arguments were silently ignored.
    The iteration counter is local to each call and no longer shadows the
    builtin ``iter``.
    """
    n_iter = 0

    def objective(xk):
        return Uk(recover_parameters(xk, xini, optid))

    def callback(xk):
        nonlocal n_iter
        n_iter += 1
        U = Uk(recover_parameters(xk, xini, optid))
        print(f"iter #{n_iter:03d}: U={U:8.4f}")

    xk = extract_parameters(xini, optid)
    res = minimize(objective, xk, method='BFGS', callback=callback)
    xopt = recover_parameters(res.x, xini, optid)
    return xopt, res.fun


# Relax the strained initial structure and report the starting state.
print()
print("optimize:")
print("initial x:", xini)
xopt, Umin = optimize(xini, optid)
print("optimized x:", xopt, f"  Umin={Umin}")

print("Equilibrium positions:")
print("xeq=", xeq)
print("x_l_fixed=", x_l_fixed)
print("x_r_fixed=", x_r_fixed)

print("Initial values:")
print("pos=", xini)
print("dx=", xini - xeq)
print("v  =", v)

print("xini and f0:")
# Forces on the strained (unrelaxed) initial structure. F_list returns a
# shared buffer that later calls overwrite; f0 is only used right here.
f0 = F_list(xini)
for i in range(n):
    print(f"#{i}: {xini[i]:8.5f} {f0[i]:10.6g}")

# The simulation state (x, UKs, Uks, Uts, x0, x1, v0, v1) is initialised
# right before the main loop, after the figure has been created.
import signal 
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Animation update step (drawn with plt.pause instead of FuncAnimation).
def update(frame):
    """Advance the chain by one MD step; refresh the plots every `plotinterval` frames.

    Integrators (selected by the module-level ``method``):
      * 'euler'  : semi-implicit Euler,
                   v(t+dt) = v(t) + a(t) dt,  x(t+dt) = x(t) + v(t+dt) dt.
      * otherwise: position Verlet, x(t+dt) = 2 x(t) - x(t-dt) + a(t) dt^2,
                   started on frame 0 with the Taylor step
                   x(dt) = x(0) + v(0) dt + a(0) dt^2 / 2.
                   The central-difference velocity (x(t+dt) - x(t-dt)) / (2 dt)
                   belongs to time t, not t+dt.

    The energies appended to UKs / Uks / Uts on frame k are those of the state
    at time k * tstep, with positions and velocities taken at the same instant.
    This matches the time axis ``tstep * index`` used for plotting.

    Reads and writes the module globals x, v, x0, x1, v0, v1 (integrator
    state) and appends to UKs, Uks, Uts (energy histories).
    """
    global x, v, x0, x1, v0, v1, UKs, Uks, Uts

    if frame == 0 or method == 'euler':
        x0 = x
        v0 = v
        acc = F_list(x0) / m_list
        v1 = v0 + acc * tstep
        if method == 'euler':
            x1 = x0 + v1 * tstep
        else:
            # Second-order start so the Verlet recursion begins consistently.
            x1 = x0 + v0 * tstep + 0.5 * acc * tstep * tstep
        # Energies of the pre-step state at time frame * tstep.
        x_now, v_now = x0, v0
        x = x1
        v = v1
    else:
        acc = F_list(x1) / m_list
        x2 = 2.0 * x1 - x0 + acc * tstep * tstep
        # Velocity at the time of x1 (central difference), not at x2.
        v2 = (x2 - x0) / 2.0 / tstep
        # Pair it with x1 so kinetic and potential energy refer to the same time.
        x_now, v_now = x1, v2

        x0 = x1
        v0 = v1
        x1 = x2
        v1 = v2
        x = x2
        v = v2

    _Uk = Uk(x_now)
    _UK = UK(v_now)
    UKs.append(_UK)
    Uks.append(_Uk)
    Uts.append(_UK + _Uk)

    if frame % plotinterval == 0:
        ax1.set_title(f"frame {frame}")

        # ax1: displacement profile
        disp = x - xeq
        line1.set_data(xeq, disp)
        ax1.set_xlim([-a0, a0 * (nrepeat + 1)])
        # Always include zero so the limits bracket the data even when every
        # displacement has the same sign; skip a degenerate (zero-width) range.
        lo = min(float(np.min(disp)), 0.0)
        hi = max(float(np.max(disp)), 0.0)
        if hi > lo:
            ax1.set_ylim([1.1 * lo, 1.1 * hi])

        # ax2: energy histories
        t_list = tstep * np.arange(len(UKs))
        line2_kinetic.set_data(t_list, UKs)
        line2_potential.set_data(t_list, Uks)
        line2_total.set_data(t_list, Uts)
        ax2.set_xlim([0, tstep * len(UKs)])
        emax = max(max(UKs), max(Uks), max(Uts))
        if emax > 0.0:
            ax2.set_ylim([0.0, 1.1 * emax])

        plt.tight_layout()
        plt.pause(0.001)

# Figure with two panels: displacement profile (top) and energies (bottom).
rcParams['font.sans-serif'] = ['MS Gothic']
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10, 6))

# Simulation state: the dynamics start from the relaxed structure xopt with
# velocities v; the energy histories start empty.
x = np.copy(xopt)
UKs, Uks, Uts = [], [], []
x0 = x1 = v0 = v1 = None

# Top panel: displacement x - xeq of the running, initial and relaxed structures.
(line1,) = ax1.plot([], [], label="running")
ax1.plot(xini, xini - xeq, label="ini")
ax1.plot(xopt, xopt - xeq, label="opt")
ax1.set_xlabel("x", fontdict={'fontname': 'MS Gothic'})
ax1.set_ylabel("dx", fontdict={'fontname': 'MS Gothic'})
ax1.legend()
ax1.grid(True)

# Bottom panel: kinetic, potential and total energy versus time.
line2_kinetic, line2_potential, line2_total = (
    ax2.plot([], [], label=curve)[0]
    for curve in ('kinetic energy', 'potential energy', 'total energy')
)
ax2.set_xlabel("t", fontdict={'fontname': 'MS Gothic'})
ax2.set_ylabel("Energy", fontdict={'fontname': 'MS Gothic'})
ax2.legend()
ax2.grid(True)

# Main loop: one MD step per frame.
for frame in range(nstep):
    update(frame)

input("\nPress ENTER to terminate>>")
    
