import sys
import csv
import numpy as np
from math import exp, sqrt, sin, cos, pi
import matplotlib.pyplot as plt


"""
  Solve the first-order differential equation dx/dt = -x^2 by the Euler method
"""


#===================
# parameters
#===================
# Defaults; each can be overridden positionally from the command line:
#   python <script> [x0 [dt [nt [iprint_interval]]]]
outfile = 'diffeq_euler.csv'  # CSV output file
x0 = 1.0                      # initial value x(0)
dt = 0.5                      # time step
nt = 501                      # number of time points, including t = 0
iprint_interval = 20          # print the first step and every iprint_interval-th step

argv = sys.argv
n = len(argv)
if n >= 2:
    x0 = float(argv[1])
if n >= 3:
    dt = float(argv[2])
if n >= 4:
    nt = int(argv[3])
if n >= 5:
    iprint_interval = int(argv[4])

# The stepping loop in main() evaluates `i % iprint_interval`; a value of 0
# would raise ZeroDivisionError only after the run (plot window, CSV file) has
# already started, and a negative interval is meaningless. Reject both here.
if iprint_interval < 1:
    sys.exit("iprint_interval must be >= 1, got {}".format(iprint_interval))

# Right-hand side of the ODE being integrated: dx/dt = force(t, x).
def force(t, x):
    """Return dx/dt = -x**2 at state x (the time t is accepted but unused)."""
    squared = x * x
    return -squared

# Exact solution of dx/dt = -x*x: separating variables gives 1/x = t + C,
# i.e. x(t) = 1 / (C + t) with C = 1/x(0)  (C = 1 for x(0) = 1.0).
def fsolution(t, x0=1.0):
    """Return the exact solution x(t) of dx/dt = -x*x with x(0) = x0.

    Args:
        t: time.
        x0: initial value x(0). The default 1.0 gives 1 / (1 + t), the
            closed form this function returned before it took x0.

    Returns:
        1 / (1/x0 + t), or 0.0 when x0 == 0 (the zero solution).

    Raises:
        ZeroDivisionError: if 1/x0 + t evaluates to exactly 0 (the blow-up
            time t = -1/x0).
    """
    if x0 == 0:
        # x = 0 is an equilibrium of dx/dt = -x*x; 1/x0 is undefined there.
        return 0.0
    return 1.0 / (1.0 / x0 + t)

def diffeq_euler(force, t0, x0, dt):
    """Advance the state by one explicit Euler step.

    Args:
        force: callable f(t, x) giving dx/dt.
        t0: time of the current state.
        x0: current state x(t0).
        dt: step size.

    Returns:
        x0 + dt * force(t0, x0), the approximation of x(t0 + dt).
    """
    slope = force(t0, x0)
    return x0 + dt * slope


#===================
# main routine
#===================
def main(x0, dt, nt):
    """Integrate dx/dt = force(t, x) by the Euler method and show the result.

    Starting from x(0) = x0, takes nt - 1 Euler steps of size dt. After each
    step the numerical value is compared with fsolution(t), the live plot is
    redrawn, the first step and every iprint_interval-th step are printed, and
    a row (t, x(cal), x(exact)) is written to the CSV file named by the
    module-level `outfile`. Waits for ENTER before returning so the plot
    window stays open.

    Args:
        x0: initial value x(0).
        dt: time step.
        nt: number of time points including t = 0 (nt - 1 steps).
    """
    print("Solve first order diffrential equation by Euler method")

    print("Write to [{}]".format(outfile))

    x = x0  # current numerical state; x0 keeps the initial value

    # Histories for the graph; index 0 is the initial condition at t = 0.
    xt = [0.0]
    yxex = [x0]
    yxsim = [x0]
    # Running y-range of both curves, so the axis limits can be updated
    # without rescanning the whole history on every step.
    ymin = ymax = x0

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    # Call plot() once now so the legend has entries, and keep the returned
    # line objects: their data is replaced via set_data() on every step.
    line1, = ax1.plot(xt, yxex, label='exact')
    line2, = ax1.plot(xt, yxsim, label='euler')
    # Uncomment for logarithmic axes:
    # ax1.set_xscale('log')
    # ax1.set_yscale('log')
    ax1.set_xlabel("t")
    ax1.set_ylabel("x(t)")
    ax1.legend()

    # `with` guarantees the CSV file is closed even if plotting raises;
    # newline='' is what the csv module requires for file objects.
    with open(outfile, 'w', newline='') as f:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow([
            't', 'x(cal)', 'x(exact)'
            ])

        print("{:^5}  {:^12}  {:^12}".format('t', 'x(cal)', 'x(exact)'))
        for i in range(1, nt):
            # One Euler step goes from t_prev to t: the slope must be
            # evaluated at the time of the current state, not the target time.
            t_prev = (i - 1) * dt
            t = i * dt
            x = diffeq_euler(force, t_prev, x, dt)
            # NOTE(review): fsolution is called with t only, i.e. the closed
            # form for x(0) = 1; for a different x0 this column does not
            # match the integrated solution.
            xexact = fsolution(t)

            xt.append(t)
            yxex.append(xexact)
            yxsim.append(x)
            ymin = min(ymin, x, xexact)
            ymax = max(ymax, x, xexact)

            # set_data() does not rescale the axes, so set the limits from the
            # data range explicitly, then redraw via plt.pause().
            line1.set_data(xt, yxex)
            line2.set_data(xt, yxsim)
            ax1.set_xlim((min(0.0, t), max(0.0, t)))
            ax1.set_ylim((ymin, ymax))
            plt.tight_layout()
            plt.pause(0.00001)

            if i == 1 or i % iprint_interval == 0:
                print("t={:5.2f}  {:12.6f}  {:12.6f}".format(t, x, xexact))

            fout.writerow([t, x, xexact])

    print("Press ENTER to exit>>", end='')
    input()


if __name__ == '__main__':
    # Run with the module-level parameters (possibly overridden from argv).
    main(x0=x0, dt=dt, nt=nt)
