import csv
import numpy as np
from matplotlib import pyplot as plt


"""
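  Note: Standard Python has a single floating point type, float, which is a C double
        (64-bit double precision on practically all platforms).
        The numerical Python module (numpy) can handle different precisions for the ndarray type:
    np.float16  # half precision   (半精度 浮動小数点数)
    np.float32  # single precision (単精度 浮動小数点数)
    np.float64  # double precision (倍精度 浮動小数点数)
    np.float128 # alias of np.longdouble where it exists: usually 80-bit x87 extended
                # precision padded to 128 bits, NOT true quadruple precision (四倍精度 浮動小数点数).
                # It is platform dependent: numpy builds for Windows (e.g. Anaconda on Windows)
                # do not provide np.float128.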
"""


#===================
# parameters
#===================
outfile = 'sum_error.csv'   # CSV file that receives one row per summation step
N = 101                     # number of times h is added up
h = 0.01                    # increment; 0.01 is not exactly representable in binary
iprintstep = 10             # print a console line every iprintstep steps

#==========================================================
# Define each precision type as the first element of a one-element ndarray.
# All the variables are referred to e.g. as h16[0], sum16[0] ...
# Assignments into x[0] are cast to the array's dtype, so every accumulator
# stays in its own precision throughout the summation.
#==========================================================
h16    = np.array([h],   dtype=np.float16)
sum16  = np.array([0.0], dtype=np.float16)
h32    = np.array([h],   dtype=np.float32)
sum32  = np.array([0.0], dtype=np.float32)
h64    = np.array([h],   dtype=np.float64)
sum64  = np.array([0.0], dtype=np.float64)
# np.float128 is platform dependent (missing on Windows builds); enable only where available.
#h128   = np.array([h],   dtype=np.float128)
#sum128 = np.array([0.0], dtype=np.float128)

# The reference value and the error variables must have the highest precision
# used here (float64). If ex - sumXX were stored back into a float16/float32
# array, the error itself would be rounded to that low precision and the
# reported error would be inaccurate.
ex     = np.array([0.0], dtype=np.float64)
err16  = np.array([0.0], dtype=np.float64)
err32  = np.array([0.0], dtype=np.float64)
err64  = np.array([0.0], dtype=np.float64)

#===================
# main routine
#===================
print("Summing up {} for {} times with "
      "different precision floating point types".format(h, N))
print("Write to [{}]".format(outfile))

# x-axis values for the plots: the number of additions performed so far.
# Loop step i holds the sum of i+1 copies of h, so the axis runs 1 .. N.
xN = list(range(1, N + 1))
yerr16 = []
yerr32 = []
yerr64 = []

# newline='' is what the csv module expects; together with lineterminator='\n'
# it gives '\n' line endings on every platform. The with-block closes the file
# even if something inside the loop raises.
with open(outfile, 'w', newline='') as f:
    fout = csv.writer(f, lineterminator='\n')
    fout.writerow(['exact', 'float16', 'float32', 'float64',
                   'error(float16)', 'error(float32)', 'error(float64)'])

    print("")
    print("{:^3}:\t{:^28}\t{:^28}\t{:^28}".format('exact', 'sum16 (error)', 'sum32 (error)', 'sum64 (error)'))

    for i in range(N):  # repeat N times, i = 0 .. N-1
        # Reference value: a single float64 multiplication instead of an
        # accumulation, so it does not carry the summation round-off.
        ex[0] = (i + 1) * h
        # Accumulate in each precision; storing into [0] keeps the array's dtype.
        sum16[0] += h16[0]
        err16[0] = ex[0] - sum16[0]
        sum32[0] += h32[0]
        err32[0] = ex[0] - sum32[0]
        sum64[0] += h64[0]
        err64[0] = ex[0] - sum64[0]
        # Where np.float128 is available (and h128/sum128/err128 are defined):
        # sum128[0] += h128[0]
        # err128[0] = ex[0] - sum128[0]

        yerr16.append(err16[0])
        yerr32.append(err32[0])
        yerr64.append(err64[0])

        fout.writerow([ex[0], sum16[0], sum32[0], sum64[0], err16[0], err32[0], err64[0]])
        if i % iprintstep == 0:
            print(f"{ex[0]:<0.4f}: "
                  f"{sum16[0]:<0.18f} ({err16[0]:<+9.2e})  "
                  f"{sum32[0]:<0.18f} ({err32[0]:<+9.2e})  "
                  f"{sum64[0]:<0.18f} ({err64[0]:<+9.2e})  ")


#=============================
# Plot graphs: one panel per precision, showing the accumulated error
#=============================
fig = plt.figure(figsize=(8, 4))

# Create the three side-by-side panels first (left to right).
ax1, ax2, ax3 = [fig.add_subplot(1, 3, pos) for pos in (1, 2, 3)]

# Then fill each panel with its error series, axis labels and legend.
for ax, yerr, precision in ((ax1, yerr16, 'float16'),
                            (ax2, yerr32, 'float32'),
                            (ax3, yerr64, 'float64')):
    ax.plot(xN, yerr, label=precision, linestyle='none', marker='o', markersize=0.5)
    ax.set_xlabel("N")
    ax.set_ylabel("error")
    ax.legend()

plt.tight_layout()

# Let the GUI event loop draw the figure briefly, then keep the window
# open until the user presses ENTER.
plt.pause(0.1)

input("Press ENTER to exit>>")
