import csv
import os
import re
import sys
from pprint import pprint

import numpy as np
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, sqrt
from numpy import linalg as la
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt


"""
Base library for calculating crystal properties
"""



# Mathematical constants
pi          = 3.14159265358979323846
pi2         = pi + pi                  # 2*pi
torad       = pi / 180.0               # rad/deg
todeg       = 180.0 / pi               # deg/rad
basee       = 2.718281828459045        # base of the natural logarithm

# Physical constants (SI units)
h           = 6.6260755e-34    # Planck constant, J s
h_bar       = h / pi2          # h/(2*pi), kept consistent with h; J s
hbar        = h_bar
c           = 2.99792458e8     # speed of light in vacuum, m/s
e           = 1.60218e-19      # elementary charge, C
me          = 9.1093897e-31    # electron mass, kg
mp          = 1.6726231e-27    # proton mass, kg
mn          = 1.67495e-27      # neutron mass, kg
u0          = 4.0 * pi * 1e-7  # vacuum permeability, N s^2 C^-2
e0          = 8.854187817e-12  # vacuum permittivity, C^2 N^-1 m^-2 (1/(u0*c^2))
e2_4pie0    = 2.30711e-28      # e^2/(4*pi*e0), N m^2
a0          = 5.29177e-11      # Bohr radius, m
kB          = 1.380658e-23     # Boltzmann constant, J/K
NA          = 6.0221367e23     # Avogadro constant, 1/mol
R           = 8.31451          # molar gas constant, J/K/mol
F           = 96485.3          # Faraday constant, C/mol
g           = 9.81             # gravitational acceleration (approx.), m/s^2


#===================================
# Treat arguments
#===================================
def pint(str, defval = None):
    """Convert *str* to int, returning *defval* when that is not possible.

    Anything int() accepts is converted: integer strings (surrounding
    whitespace allowed) and numbers (floats truncate toward zero).  Text such
    as "1.5", None, inf or nan yields *defval*.  The parameter name shadows
    the builtin ``str`` but is kept for keyword-argument callers.
    """
    try:
        return int(str)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back; KeyboardInterrupt etc. propagate.
        return defval

def pfloat(s, strict = True, defval = 0.0):
    """Convert *s* to float, returning *defval* when that is not possible.

    Args:
        s: value to convert; anything float() accepts is used directly.
        strict: when True (default) no further parsing is attempted.  When
            False and *s* is a string, the first number embedded in the text
            is used, so "300K" and "T = 300" both give 300.0.
        defval: returned when *s* is None, cannot be converted, or (in
            non-strict mode) contains no number.

    Returns:
        The float value, or *defval*.
    """
    try:
        return float(s)
    except (TypeError, ValueError, OverflowError):
        pass
    if strict or not isinstance(s, str):
        return defval

    # Require at least one digit so letters such as the 'e' in "Temperature"
    # are not mistaken for a number; the match is always a valid float literal.
    match = re.search(r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?', s)
    if match is None:
        return defval
    return float(match.group())


def getarg(idx, defval = None):
    """Return the command-line argument sys.argv[idx], or *defval*.

    *defval* is returned when there are too few arguments (IndexError) or
    *idx* is not a valid index type (TypeError).
    """
    try:
        return sys.argv[idx]
    except (IndexError, TypeError):
        return defval

def getfloatarg(idx, defval = None):
    """Return sys.argv[idx] parsed with pfloat(), or *defval*.

    *defval* is returned both when the argument is missing and when it cannot
    be parsed as a float; previously an unparsable argument silently became
    pfloat's own default of 0.0.
    """
    try:
        arg = sys.argv[idx]
    except (IndexError, TypeError):
        return defval
    return pfloat(arg, defval = defval)

def getintarg(idx, defval = None):
    """Return sys.argv[idx] parsed with pint(), or *defval*.

    *defval* is returned both when the argument is missing and when it cannot
    be parsed as an int (previously the latter returned None regardless of
    *defval*).
    """
    try:
        arg = sys.argv[idx]
    except (IndexError, TypeError):
        return defval
    return pint(arg, defval)

def usage(csvfile = None):
    """Print a command-line usage message for the running script.

    Args:
        csvfile: example CSV path shown on the "ex:" line.  When omitted, a
            module-level ``csvfile`` is used if one has been defined;
            otherwise the example line is skipped.
    """
    argv = sys.argv
    if csvfile is None:
        # Older code read a module global `csvfile` unconditionally and
        # raised NameError when it was not defined; honor it only if present.
        csvfile = globals().get('csvfile')
    print("")
    print("Usage: python {} csv_path".format(argv[0]))
    if csvfile is not None:
        print("   ex: python {} {}".format(argv[0], csvfile))

def terminate(message = None, usage = usage, postmessage = None):
    """Print an optional message and usage text, then exit the program.

    Args:
        message: printed after a blank line before the usage text; skipped
            if None.
        usage: callable that prints usage text; defaults to this module's
            usage() (bound at definition time).  Pass None to suppress it.
        postmessage: printed after the usage text; skipped if None.

    Raises:
        SystemExit: always.
    """
    if message is not None:
        print("")
        print(message)

    if usage is not None:
        usage()

    print("")
    if postmessage is not None:
        print(postmessage)
        print("")

    # sys.exit() instead of the site-module exit(), which is intended for the
    # interactive interpreter and is missing when Python runs with -S.
    sys.exit()

def is_dir(path):
    """Return True if *path* exists and is a directory (symlinks followed)."""
    found_dir = os.path.isdir(path)
    return found_dir

def is_file(path):
    """Return True if *path* exists and is a regular file (symlinks followed)."""
    found_file = os.path.isfile(path)
    return found_file

def make_path(dir, *args):
    """Join *dir* and any further components with os.path.join."""
    parts = (dir,) + args
    return os.path.join(*parts)

def split_path(path):
    """Split *path* into (dirname, basename, filebody, ext).

    A single trailing '/' is ignored, so "a/b/" is treated like "a/b".  On
    Windows (os.sep == '\\') backslashes are normalized to '/' for the split
    and dirname is converted back to backslashes.  ext is '' when *path* is
    an existing directory (filebody still drops any dot suffix).

    Example (POSIX): "/data/run1.csv" -> ("/data", "run1.csv", "run1", ".csv")
    """
    if os.sep == '\\':
        path0 = re.sub(r'\\', '/', path)
    else:
        path0 = path
    path0 = re.sub(r'/$', '', path0)

    basename    = os.path.basename(path0)
    dirname     = os.path.dirname(path0)
    header, ext = os.path.splitext(path0)
    filebody    = os.path.basename(header)

    # A dot in a directory name is not reported as a file extension.
    if os.path.isdir(path):
        ext = ''
    if os.sep == '\\':
        dirname = re.sub('/', r'\\', dirname)

    return dirname, basename, filebody, ext

def modify_path(path, addpath):
    """Return *path* with its extension removed and *addpath* appended.

    Example: modify_path("out/run1.csv", "-fit.png") gives "out/run1-fit.png"
    (re-joined with os.path.join, so the separator follows the host OS).
    """
    dirname  = os.path.dirname(path)
    filebody = os.path.basename(os.path.splitext(path)[0])
    return os.path.join(dirname, filebody + addpath)

def safe_getelement(var, key, defval = None):
    """Return var[key], or *defval* if the lookup fails.

    Missing keys / out-of-range indices (LookupError) and unsubscriptable
    containers or keys of an unusable type (TypeError) give *defval*; any
    other exception propagates instead of being silently swallowed.
    """
    try:
        return var[key]
    except (LookupError, TypeError):
        return defval

def sfmt(str, fmt):
    """Format one value with a replacement-field spec and trim whitespace.

    *fmt* is the text placed between the braces of a str.format field, e.g.
    ':8.3f' or ':>10'; leading/trailing whitespace of the result is removed.
    """
    template = ''.join(('{', fmt, '}'))
    formatted = template.format(str)
    return formatted.strip()

def get_charcode(path):
    """Guess the text encoding of the file at *path*.

    The file is fed line by line to a UniversalDetector until the detector
    reports it is done.  Returns the detected encoding name, or None when no
    encoding could be determined.  Any name containing 'Windows' is reported
    as 'SHIFT-JIS'.
    """
    # NOTE(review): UniversalDetector is not imported at the top of this
    # module; presumably chardet's detector is intended — confirm and import.
    detector = UniversalDetector()
    detector.reset()

    # Iterate instead of readlines() so detection can stop early without
    # loading the whole file; `with` closes it even if feed() raises.
    with open(path, 'rb') as fp:
        for line in fp:
            detector.feed(line)
            if detector.done:
                break
    detector.close()

    charcode = detector.result['encoding']
    # NOTE(review): Windows-* guesses are remapped to Shift-JIS, presumably
    # because Japanese files are misdetected — confirm this is intended.
    if charcode is not None and 'Windows' in charcode:
        charcode = 'SHIFT-JIS'

    return charcode

def savecsv(outfile, header, datalist):
    """Write column-oriented data to a CSV file.

    Args:
        outfile: destination path (opened in text mode, '\\n' terminators).
        header: sequence written as the first row.
        datalist: list of columns; output row i is
            (datalist[0][i], datalist[1][i], ...).  The row count is taken
            from the first column.  An empty list writes the header only.

    If the file cannot be opened, an error is printed and the function
    returns without raising.
    """
    print("Write to [{}]".format(outfile))
    try:
        f = open(outfile, 'w')
    except (OSError, ValueError):
        print("Error: Can not write to [{}]".format(outfile))
        return

    # `with` guarantees the file is closed even if a write fails.
    with f:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow(header)
        nrows = len(datalist[0]) if datalist else 0
        for i in range(nrows):
            fout.writerow([column[i] for column in datalist])

def read_csv(infile, xmin = None, xmax = None, delimiter = ','):
    """Read the first two columns of a CSV file, optionally filtered on x.

    The first row is the header.  In every following row, column 0 (x) and
    column 1 (y) are parsed with pfloat().  Rows are kept when
    xmin <= x <= xmax; either bound may be None, meaning unbounded on that
    side.  Blank lines are skipped.

    Returns:
        (header, xs, ys)

    On a read/parse error (missing file, bad encoding, empty file, a kept
    row with fewer than two columns) an error is printed and the program
    exits.
    """
    print("xrange=", xmin, xmax)
    xs = []
    ys = []
    try:
        with open(infile, "r") as infp:
            f = csv.reader(infp, delimiter = delimiter)
            header = next(f)
            print("header=", header)
            for row in f:
                if not row:
                    continue
                x = pfloat(row[0])
                # Each bound is optional; previously xmin=None dropped every
                # row and xmax=None raised a TypeError.
                if (xmin is None or xmin <= x) and (xmax is None or x <= xmax):
                    ys.append(pfloat(row[1]))
                    xs.append(x)
    except (OSError, UnicodeDecodeError, csv.Error, StopIteration, IndexError):
        print("Error: Can not read [{}]".format(infile))
        sys.exit()
    return header, xs, ys

def read_excel(infile):
    """Read three numeric columns from the first worksheet of an Excel file.

    Row 1 holds the labels in columns A-C.  Data rows start at row 2 and
    continue until the first row whose column-A cell is None or ''.  Each
    of the three cells is converted with float().

    Returns:
        ([Tlabel, Nlabel], xT, yN, yS).  The column-C label is printed but
        is not part of the returned label list.
    """
    # NOTE(review): openpyxl is not imported at the top of this module —
    # confirm it is imported where this function is used.
    wb = openpyxl.load_workbook(infile, data_only = True)
    # NOTE(review): load_workbook appears to raise on failure rather than
    # return a falsy value, so this branch is probably unreachable — confirm.
    if not wb:
        print("")
        print("Error to read [{}]".format(infile))
        print("")
        terminate()

    sheetnames = wb.sheetnames
    first_sheet = sheetnames[0]
    print("sheet names:", sheetnames)
    print("  read [{}]".format(first_sheet))

    ws = wb[first_sheet]
    label_row = ws[1]
    Tlabel = label_row[0].value
    Nlabel = label_row[1].value
    Slabel = label_row[2].value
    print("Labels:", Tlabel, Nlabel, Slabel)

    xT, yN, yS = [], [], []
    row = 2
    while True:
        t_value = ws.cell(row = row, column = 1).value
        if t_value is None or t_value == '':
            break
        n_value = ws.cell(row = row, column = 2).value
        s_value = ws.cell(row = row, column = 3).value
        xT.append(float(t_value))
        yN.append(float(n_value))
        yS.append(float(s_value))
        row += 1

    return [Tlabel, Nlabel], xT, yN, yS
