import sys
from sys import exit
import os
import time
import re
import numpy as np
from numpy import sqrt, log
import openpyxl
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

try:
    import physbo
except:
    print("")
    print("#########################################################")
    print("# IMPORT ERROR for physbo")
    print("#   Use python 3.6 environment and install physbo")
    print("#   For anaconda and if you made py36 as python 3.6 virtual environment:")
    print("#      > conda activate py36")
    print("#      > pip install physbo")
    print("#      or")
    print("#      > pip3 install physbo")
    print("#      or")
    print("#      > python -m pip3 install physbo")
    print("#########################################################")
    exit()

from tklib.tkutils import getarg, getintarg, getfloatarg, pint, pfloat, pintfloat
from tklib.tkinifile import tkIniFile
from tklib.tkparams import tkParams
from tklib.tkapplication import tkApplication


#===================================================================================
"""
Perform Bayes optimization based on Gauss process using PHYSBO
"""
#===================================================================================


#===================================================================================
# Usage text.  Each line is the *source text* of an f-string literal: usage()
# wraps every line in print(...) and evaluates it, so the {sys.argv[0]} and
# {cparams.*} placeholders are filled in when the help is displayed.
# The [1:-1] slice removes the newline right after the opening quotes and the
# one right before the closing quotes.
usage_str = '''
f"(i) usage: python {sys.argv[0]} infile max_num_probes num_rand_basis score_mode interval standardize"
f"             max_num_probes: Use a number to reach convergence"
f"             num_rand_basis: Use a large number so as to reproduce training data"
f"             score_mode    : [EI|PI|TS]"
f"             interval      : Number of cycle to update hyper parameters"
f"             standardize   : Flag to standardize descriptors [0|1]"
f"       ex: python {sys.argv[0]} {cparams.infile} {cparams.max_num_probes} {cparams.num_rand_basis} {cparams.score_mode} {cparams.interval} {cparams.standardize}"
'''[1:-1]
#===================================================================================

#==============================================
# Global variables
#==============================================
# PHYSBO project page; printed in the start-up banner ("Requires PHYSBO: ...")
physbo_url = "https://www.pasums.issp.u-tokyo.ac.jp/physbo/"
# Page printed by citation() so users can find the references to cite
citation_url = "https://issp-center-dev.github.io/PHYSBO/manual/master/en/introduction.html"

# Program name (not referenced in the part of the file reviewed here)
prog_name = 'Bayes/GP CLI'
# Script version as [major, minor, patch]; printed in the banner as "ver 1.3.1"
version = [1, 3, 1]


#===================================================================================
# Other functions
#===================================================================================
def citation(app):
    """Print a notice asking users to cite the PHYSBO references.

    Args:
        app: Application object.  Unused; accepted so that the function has
            the same call signature as the other helpers.
    """
    print("")
    print( "====================================================================================")
    print("Please cite designated references for PHYSBO. See")
    print(f"              Citation: {citation_url}")
    print( "====================================================================================")

def usage(app):
    """Print the usage text followed by the citation notice.

    Every line of ``usage_str`` is the source text of an f-string literal.
    The line is wrapped in ``print(...)`` and evaluated so that its
    placeholders are resolved at the moment the help is shown.

    Args:
        app: Application object that provides the current parameters.
    """
    # Must stay a local variable: the evaluated f-strings refer to ``cparams``.
    cparams = app.get_params()

    for line in usage_str.split('\n'):
        # usage_str is a constant defined in this file, not external input.
        eval('print({})'.format(line.rstrip()))

    citation(app)

def initialize():
    """Create the application object and fill its parameters with defaults.

    Returns:
        tuple: ``(app, cparams)`` - the tkApplication instance and the
        parameter object obtained from ``app.get_params()``.
    """
    app        = tkApplication(usage_str = usage_str)
    argv, narg = app.get_argv()
    cparams    = app.get_params()

    # Defaults are applied in this order (dicts preserve insertion order).
    defaults = {
        'debug'          : 0,
        'print_level'    : 0,
        # Input table; a CSV file such as 'data_simple.csv' can be given instead
        'infile'         : 'data_simple.xlsx',
        # Grid resolution of the 2D contour plots
        'nx_2D'          : 11,
        'ny_2D'          : 11,
        'standardize'    : 1,
        'num_search_each': 1,
        'max_num_probes' : 1,
        'num_rand_basis' : 200,   # -1 for non-approximated
        'score_mode'     : 'EI',
        'interval'       : 0,
    }
    for name, value in defaults.items():
        setattr(cparams, name, value)

    return app, cparams

def update_vars(app, cparams):
    """Update parameters from positional arguments and set the graph settings.

    The tklib readers (``getarg`` / ``getintarg``) are called with the
    argument position and the current parameter value.
    NOTE(review): presumably they return the command-line argument when it is
    given and the passed value otherwise - confirm against tklib.

    Args:
        app: Application object (unused here).
        cparams: Parameter object that is updated in place.
    """
    # (argument position, parameter name, reader)
    positional_args = (
        (1, 'infile',         getarg),
        (2, 'max_num_probes', getintarg),
        (3, 'num_rand_basis', getintarg),
        (4, 'score_mode',     getarg),
        (5, 'interval',       getintarg),
        (6, 'standardize',    getintarg),
    )
    for position, name, reader in positional_args:
        current = getattr(cparams, name)
        setattr(cparams, name, reader(position, current))

    # Graph parameters
    cparams.figsize         = (8, 6)
    cparams.fontsize        = 16
    cparams.legend_fontsize = 12

def read_data_file(app, infile):
    """Read the input table into a pandas DataFrame.

    A path whose extension is ``.xlsx`` (case-insensitive) is read with
    ``pandas.read_excel`` using the openpyxl engine; any other path is read
    as CSV with ``pandas.read_csv``.

    Args:
        app: Application object.  Unused; kept for a uniform call signature.
        infile: Path of the data file.

    Returns:
        pandas.DataFrame or None: the loaded table, or None (after printing
        an error message) when the file cannot be read.
    """
    # Decide by the real extension: a substring test would treat paths such
    # as "book.xlsx.csv" or "dir.xlsx/data.csv" as Excel files.
    is_excel = os.path.splitext(str(infile))[1].lower() == '.xlsx'

    try:
        if is_excel:
            return pd.read_excel(infile, engine = 'openpyxl')
        return pd.read_csv(infile)
    except Exception as e:
        # "except Exception" lets KeyboardInterrupt / SystemExit propagate,
        # which a bare "except:" would have swallowed.
        print(f"Error in read_data_file(): Can not read [{infile}]: {e}")
        return None

# Header control codes.  Compiled once because every column label is tested.
#   "=<number>:<label>"                      -> objective driven towards <number>
#   "max<n>:<label>", "t<n>:...", "o<n>:..." -> objective to maximize
#   "min<n>:<label>"                         -> objective to minimize
# '-' is listed first in the character class so that it is a literal minus
# sign and not the start of a range.
_VALUE_TARGET_RE = re.compile(r'=([-+.\deE]+):(.*)', flags = re.IGNORECASE)
_MODE_TARGET_RE  = re.compile(r'(max|min|[to])(\d*):(.*)', flags = re.IGNORECASE)

# Index stored for the third and later objectives that carry no explicit number
_UNNUMBERED_OVERFLOW_INDEX = 999999


def _append_target(targets, label0, label, index, mode, value):
    """Append one objective to the parallel lists held by *targets*."""
    targets.labels0.append(label0)
    targets.labels.append(label)
    targets.indexes.append(index)
    targets.modes.append(mode)
    targets.values.append(value)


def split_raw_data(df_original):
    """Classify the columns of the input table by their header control codes.

    Headers are interpreted as follows (prefixes are case-insensitive):
      * ``=<number>:<label>``  objective, mode 'value', target value <number>
      * ``max<n>:<label>``, ``t<n>:<label>``, ``o<n>:<label>``  objective, mode 'max'
      * ``min<n>:<label>``     objective, mode 'min'
      * empty header or header starting with ``-``  ignored
      * anything else          descriptor
    ``<n>`` is an optional 1-based number stored as a 0-based index.  Objectives
    without a number receive 0, then 1, then 999999 in order of appearance.
    When no column is marked as an objective, the first column is used as a
    single objective to maximize.

    Args:
        df_original: pandas DataFrame as read from the input file.

    Returns:
        tuple: ``(targets, descriptors)``.  ``targets`` has the parallel lists
        ``labels0`` (raw headers), ``labels`` (headers without control code),
        ``modes``, ``indexes`` and ``values``; ``descriptors`` has ``labels0``.
    """
    targets             = tkParams()
    descriptors         = tkParams()
    targets.labels0     = []
    targets.labels      = []
    targets.modes       = []
    targets.indexes     = []
    targets.values      = []
    descriptors.labels0 = []

    columns = df_original.columns.to_list()

    # Number of objectives seen so far that carry no explicit number
    n_unnumbered = 0
    for s in columns:
        # Checked first so that a missing header never reaches the regexes
        if s is None or s == '':
            continue

        # Optimize towards a numeric value
        m = _VALUE_TARGET_RE.match(s)
        if m:
            value_text, label = m.groups()
            _append_target(targets, s, label, None, 'value', pfloat(value_text))
            continue

        # Maximize / minimize
        m = _MODE_TARGET_RE.match(s)
        if m:
            prefix, index_text, label = m.groups()
            index = pint(index_text, defval = None)
            if index is None or index == '':
                if n_unnumbered < 2:
                    index = n_unnumbered
                else:
                    index = _UNNUMBERED_OVERFLOW_INDEX
                n_unnumbered += 1
            else:
                index -= 1    # the header number is 1-based
            mode = 'min' if prefix.lower() == 'min' else 'max'
            _append_target(targets, s, label, index, mode, None)
            continue

        # Excluded from both descriptors and objectives
        if s.startswith('-'):
            continue

        descriptors.labels0.append(s)

    if len(targets.labels0) == 0:
        # No header marks an objective: fall back to the first column
        _append_target(targets, columns[0], columns[0], 0, 'max', None)

    # A column used as an objective must not be a descriptor as well.  Each
    # label is checked on its own so one absent label cannot stop the others
    # from being removed.
    for label0 in targets.labels0:
        if label0 in descriptors.labels0:
            descriptors.labels0.remove(label0)

    return targets, descriptors

def get_plot_descriptors(descriptors):
    """Choose the descriptors drawn on the plot axes and strip axis codes.

    By default the first (up to three) descriptors are assigned to the axes
    'x', 'y' and 'z' in order.  A descriptor name containing ``x:``, ``y:`` or
    ``z:`` (case-insensitive) overrides the assignment of that axis, and the
    text after the code, without surrounding whitespace, becomes its label.

    Args:
        descriptors: List of descriptor names, possibly with axis codes.
            The list is not modified.

    Returns:
        tuple: ``(plot_indexes, plot_labels, descriptors_label)`` where
        ``plot_indexes`` maps axis name to descriptor index, ``plot_labels``
        maps axis name to label, and ``descriptors_label`` is a new list of
        all descriptor names with the axis codes removed.  For an empty input
        both dicts and the list are empty.
    """
    plot_indexes = {}
    plot_labels  = {}
    # zip() stops at the shorter input, so 0, 1, 2 or 3 axes are assigned
    for var, i in zip('xyz', range(len(descriptors))):
        plot_indexes[var] = i
        plot_labels[var]  = descriptors[i]

    # Work on a copy so that the caller's list is left untouched
    descriptors_label = list(descriptors)
    for i, name in enumerate(descriptors):
        # Non-greedy label group so that the trailing \s* strips whitespace
        m = re.search(r'([xyzXYZ]):\s*(.*?)\s*$', name)
        if m:
            var = m.group(1).lower()
            descriptors_label[i] = m.group(2)
            plot_indexes[var] = i
            plot_labels[var]  = descriptors_label[i]

    return plot_indexes, plot_labels, descriptors_label

def load_data(app, cparams):
    """Read the input file and build the full and the training data sets.

    Rows with a missing descriptor value are removed from all data.  Of the
    remaining rows, those in which every objective has a value form the
    training data.  Objectives are converted so that larger is always better:
    'min' columns are negated and 'value' columns become
    ``-(t - target_value)**2``.

    Args:
        app: Application object, passed on to ``read_data_file``.
        cparams: Parameter object; ``cparams.infile`` is the file to read.

    Returns:
        tuple: ``(df_original, targets_params, descriptors, data_all,
        data_train)``, or five ``None`` values when the file cannot be read.
        ``data_all`` / ``data_train`` carry ``ndata``, ``columns``,
        ``indexes``, ``targets`` and ``descriptors``.  ``data_train.indexes``
        are row positions inside ``data_all.targets`` /
        ``data_all.descriptors``.  ``data_train.targets_org`` holds the
        training objectives before the min/value conversion.
    """
    print("")
    print(f"Read data from [{cparams.infile}]")
    df_original = read_data_file(app, cparams.infile)
    if df_original is None:
        return None, None, None, None, None

    targets_params, descriptors = split_raw_data(df_original)
    ntargets = len(targets_params.labels0)
    print( "  descriptors:", descriptors.labels0)
    print( "  objective functions:")
    for i in range(ntargets):
        line = f"    i={i} Pareto plot index={targets_params.indexes[i]} label={targets_params.labels[i]} mode={targets_params.modes[i]}"
        if targets_params.values[i] is not None:
            line += f" value={targets_params.values[i]}"
        print(line)

    t_frame = df_original[targets_params.labels0]
    x_frame = df_original[descriptors.labels0]

    ndata   = len(df_original.index)
    columns = df_original.columns.to_list()
    index   = df_original.index.to_list()

    # Row masks over the original table.
    #   desc_ok  : no descriptor is missing -> the row is kept in "all" data
    #   target_ok: every objective has a value -> candidate training row
    # Only objective and descriptor columns are inspected, so a blank cell in
    # an excluded column does not turn a row into a test row.
    desc_ok   = x_frame.notna().all(axis = 1).to_numpy()
    target_ok = t_frame.notna().all(axis = 1).to_numpy()

    # Remove the rows that have a missing descriptor
    t_all = t_frame.loc[desc_ok].to_numpy()
    x_all = x_frame.loc[desc_ok].to_numpy()

    # Training rows as positions inside the filtered arrays.  Keeping the
    # original row labels would point at the wrong rows of x_all / t_all as
    # soon as a row with a missing descriptor has been removed above.
    idx_train = np.flatnonzero(target_ok[desc_ok])
    # Integer-array indexing returns new arrays, so the in-place conversion
    # below does not touch t_all.
    t_train = t_all[idx_train]
    x_train = x_all[idx_train]

    # Convert the objectives according to the header control codes.  The
    # unconverted values are kept in t_train_org.
    t_train_org = t_train.copy()
    for i in range(ntargets):
        mode = targets_params.modes[i]
        if mode == 'min':
            t_train[:, i] = -t_train[:, i]
        elif mode == 'value':
            t_train[:, i] = -(t_train[:, i] - targets_params.values[i])**2

    data_all   = tkParams()
    data_train = tkParams()

    # ndata / indexes describe the table as read from the file, i.e. before
    # rows with missing descriptors were removed.
    data_all.ndata         = ndata
    data_all.columns       = columns
    data_all.indexes       = index
    data_all.targets       = t_all
    data_all.descriptors   = x_all

    data_train.ndata       = len(idx_train)
    data_train.columns     = columns
    data_train.indexes     = idx_train
    data_train.targets     = t_train
    data_train.targets_org = t_train_org
    data_train.descriptors = x_train

    return df_original, targets_params, descriptors, data_all, data_train

def validate(app, cparams, wait_by_input = True):
    """Print the banner and parameters, load the data and check test data.

    Sets the output file templates ``cparams.outresfile`` and
    ``cparams.outfile``, prints the program banner and the parameters, loads
    the input file, and prints a notice when every row already has objective
    values (that is, there is no test data).

    Args:
        app: Application object.
        cparams: Parameter object; updated with the output file templates.
        wait_by_input: Unused here; kept for interface compatibility.

    Returns:
        None in every case (including when the input file cannot be read).
    """
    argv, narg = app.get_argv()
    # The banner line is labelled "Modified time", so read the modification
    # time; getctime() is the creation / metadata-change time instead.
    mtime = time.localtime(os.path.getmtime(app.script_path))

    cparams.outresfile = app.replace_path(cparams.infile, 
                    template = ["{filebody}-save{i}.npz"], ext_dict = {'i': '{i}'})
    cparams.outfile    = app.replace_path(cparams.infile, 
                    template = ["{filebody}-predict{i}.xlsx"], ext_dict = {'i': '{i}'})

    print("")
    print( "====================================================================================")
    print(f"  {argv[0]}: Perform Bayes optimization based on Gauss process ")
    print(f"       Requires PHYSBO: {physbo_url}")
    print(f"# {app.script_path} ver {version[0]}.{version[1]}.{version[2]}")
    print(f"#       Modified time: {mtime.tm_year}/{mtime.tm_mon}/{mtime.tm_mday} {mtime.tm_hour}:{mtime.tm_min}:{mtime.tm_sec}")
    print( "====================================================================================")
    cparams.printinf(app)

    df_original, targets_params, descriptors_params, data_all_params, data_train_params = load_data(app, cparams)
    if df_original is None:
        return None

    # Test data are the rows that are not training data
    ndata       = len(data_all_params.targets)
    n_traindata = len(data_train_params.indexes)
    n_testdata  = ndata - n_traindata
    if n_testdata == 0:
        # Only the notice is printed here; nothing computed in this function
        # is returned, so no dummy row is built.
        print("")
        print( "====================================================================================")
        print( "  At least one data must be test data (blank target function)")
        print( "    A dummy data is added at top.")
        print( "====================================================================================")
        print( "")

    return None


def execute(app, cparams, wait_by_input = True):
    argv, narg = app.get_argv()
    mtime = time.localtime(os.path.getctime(app.script_path))

    cparams.outresfile = app.replace_path(cparams.infile, 
                    template = ["{filebody}-save{i}.npz"], ext_dict = {'i': '{i}'})
    outresfile = cparams.outresfile

    cparams.outfile    = app.replace_path(cparams.infile, 
                    template = ["{filebody}-predict{i}.xlsx"], ext_dict = {'i': '{i}'})
    outfile = cparams.outfile

    print("")
    print( "====================================================================================")
    print(f"  {argv[0]}: Perform Bayes optimization based on Gauss process ")
    print(f"       Requires PHYSBO: {physbo_url}")
    print(f"# {app.script_path} ver {version[0]}.{version[1]}.{version[2]}")
    print(f"#       Modified time: {mtime.tm_year}/{mtime.tm_mon}/{mtime.tm_mday} {mtime.tm_hour}:{mtime.tm_min}:{mtime.tm_sec}")
    print( "====================================================================================")
    cparams.printinf(app)

# descriptors_params: 記述子のパラメータインスタンス
# targts_params: 目的関数のパラメータインスタンス
# ntargets: 目的関数の数
# descriptors       : 記述子名 (制御子を含む)
# X_all: 全記述子
# t_all: 全目的関数値
# idx_train  : 学習データの、全データ中のindex
# X_train    : 学習データの記述子
# t_train_org: 学習データのmin,value変換前の値
# t_train    : 学習データのmin,value変換後の値
# targets      : 目的関数の元ラベル
# targets_label: 目的関数名（ラベルから制御子を除いたもの)
# targets_mode: min,max,value
# targets_value: 目的関数の値
    df_original, targets_params, descriptors_params, data_all_params, data_train_params = load_data(app, cparams)
    if df_original is None:
        return None

    X_all         = data_all_params.descriptors
    t_all         = data_all_params.targets
    X_train       = data_train_params.descriptors
    t_train       = data_train_params.targets
    idx_train     = data_train_params.indexes
    descriptors   = descriptors_params.labels0

    ndescriptors  = len(descriptors)
    ntargets      = len(targets_params.labels0)
    targets       = targets_params.labels0
    targets_label = targets_params.labels
    targets_mode  = targets_params.modes
    targets_value = targets_params.values

    ndata       = len(t_all)
    n_traindata = len(idx_train)
    n_testdata  = ndata - n_traindata
    if n_testdata == 0:
        print("")
        print( "====================================================================================")
        print( "  At least one data must be test data (blank target function)")
        print( "    A dummy data is added at top.")
        print( "====================================================================================")
        print( "")
        line0 = X_all[0]
#        print("line0=", line0)
        X_all = np.insert(X_all, 0, line0, axis = 0)
        none_list = [None for d in targets_label]
        t_all = np.insert(t_all, 0, none_list, axis = 0)
        ndata += 1
        n_testdata += 1

#        if wait_by_input:
#            exit()
#        else:
#            return

# plot_indexes: 2D(等高線)-3Dグラフにプロットする記述子番号の辞書型変数 ('x', 'y', 'z')
# plot_labels : 2D(等高線)-3Dグラフにプロットする記述子名 (descriptorsから制御子を除いたもの)
# descriptors_label: 全記述子名 (descriptorsから制御子を除いたもの
    plot_indexes, plot_labels, descriptors_label = get_plot_descriptors(descriptors.copy())

    print("")
    print(f"# of all data: {len(t_all)}")
    print(f"# of training data: {n_traindata}")
    print(f"# of descriptors        : {ndescriptors}")
    print(f"# of objective variables: {ntargets}")
    print(f"  Objective variables: {targets_label}")
    print(f"# of descriptors: {len(descriptors)}")
    print(f"  Descriptors: {descriptors_label}")
    print(f"  Descriptors for plot: {plot_labels}")
#    print("X_train=")
#    print(X_train)
#    print("t_train=")
#    print(t_train)
#    print("X_all=")
#    print(X_all)
#    print("t_all=")
#    print(t_all)

# policy のセット
    print("")
    print("Make policy:")
    print("  Training data indexes:")
    print(idx_train)
    print("    descriptors from X_train:")
    print(X_train)
    print("    descriptors from X_all:")
    print(X_all[idx_train])
    print("    objective values:")
    print(t_train)

    def standardize(X, labels):
        print("")
        print("Standardize descriptors:")
        X_std = X.copy()

        stdX  = np.std(X_train, 0)
        meanX = np.mean(X_train, 0)
        xstd = []
        for i in range(len(stdX)):
            sigma = stdX[i]
            x0    = meanX[i]
            if sigma == 0.0:
                sigma = 1.0
                x0 = 0.0
                print(f"  Warning: {labels[i]:>10} has no distribution (sigma = 0.0). Standardization not applied")
            print(f"  {labels[i]:>10}: X(standardized)[i] = (X[i] - {x0:12.4g}) / {sigma:12.4g}")

            xstd.append((X.T[i] - x0) / sigma)

        X_std= np.array(xstd).T

        return X_std, meanX, stdX

    if cparams.standardize:
        print("Execute standardization")
        X_std, meanX, stdX = standardize(X_all, descriptors)
    else:
        print("Descriptors are not standardized")
        X_std = X_all

    inf = []
    for i in range(ntargets):
        print("")
        print(f"Objective #{i+1}")
# 2021-05-23 物性研CCMS講習会 本山裕一
#   「ベイズ最適化パッケージ PHYSBOの使い方」 physbo_usage.pdf
        t_train1 = t_train.T[i]
        policy = physbo.search.discrete.policy(test_X = X_std, initial_data = (idx_train, t_train1))
#        policy = physbo.search.discrete.policy(test_X = X_all, initial_data = (idx_train, t_train1))
        inf.append({"policy": policy})

# シード値のセット
        seed = cparams.get("random_seed", None)
        if seed is not None:
            seed = pint(seed, defval = None)
            if seed is not None:
                policy.set_seed(0)

# bayes_searchは、simulator = Noneでは actions が返り、simulatorに関数を渡すと Hisotry object が返ってくる
        print("")
        print("Start Bayes search:")
        actions = policy.bayes_search(max_num_probes = cparams.max_num_probes, simulator = None, 
                        score = cparams.score_mode, interval = cparams.interval, num_rand_basis = cparams.num_rand_basis)

        print("")
        print("Bayse search:")
        print("show_search_results")
        physbo.search.utility.show_search_results(policy.history, 10)

# Hisotry objectの取得
        res = policy.export_history()
        best_fx, best_action = res.export_all_sequence_best_fx()
        bayes_x = res.chosen_actions
#        x_bayes = X_all[bayes_x]
        x_bayes = X_std[bayes_x]
        y_bayes = res.fx

# 獲得関数
#        score = policy.get_score(mode = "EI", xs = X_all)
        score = policy.get_score(mode = "EI", xs = X_std)

# 回帰。事後分布の平均値、分散
        mean = policy.get_post_fmean(X_std)
        var  = policy.get_post_fcov(X_std)
#        mean = policy.get_post_fmean(X_all)
#        var  = policy.get_post_fcov(X_all)
        std  = np.sqrt(var)
        mean_m_sigma = mean - std
        mean_p_sigma = mean + std

#print("score=", score)
        idx_best = np.argmax(score)
        print("  Best candidate    :", idx_best, X_all[idx_best], mean[idx_best])
#       print("  Best candidate from hisotry:", int(best_action[-1]), X_best, Y_best)

        inf[i]["mean"] = mean
        inf[i]["var"] = var
        inf[i]["std"] =std
        inf[i]["mean_m_sigma"] = mean - std
        inf[i]["mean_p_sigma"] = mean + std
        inf[i]["idx_best"] = idx_best
        inf[i]["score"] = score

        if cparams.get("outresfile", None) is None or cparams.outresfile == "":
            dirname = get_dirname(cparams.infile)
            cparams.outresfile = app.replace_path(cparams.infile, 
                    template = ["dir", "{filebody}-save{i}.npz"], ext_dict = {'i': '{i}'})

        outresfile = cparams.outresfile.format(i = i + 1)
        print("")
        print(f"Save seach result to [{outresfile}]")
        res.save(outresfile)

        outfile = cparams.outfile.format(i = i + 1)
        print("")
        print(f"Save predictions to [{outfile}]")
        print("X_all.T=")
        print(X_all.T)
        zlist = zip(*X_all.T, t_all.T[i], mean, mean_m_sigma, mean_p_sigma)
        df = pd.DataFrame(list(zlist), 
                      columns = [*descriptors_params.labels0, targets_params.labels0[i], 'mean', 'mean-std', 'mean+std'])
        df.to_excel(outfile)


#=====================================================
# plot
#=====================================================
# TkAggを使うと、自動的に複数のウインドウ位置がずれて表示してくれる
# しかし、TkAggを使うと、plt.pause()のあと、input()でグラフウインドウの制御ができなくなる
    if not wait_by_input:
        matplotlib.use('TkAgg')

    def get_winpos(plt):
        if wait_by_input:
            return None

        pos = plt.get_current_fig_manager().window.wm_geometry()
        print("pos=", pos[0])
        m = re.match(r'(\d+)x(\d+)\+(\d+)\+(\d+)', pos)
        if m:
            g = m.groups()
            return pint(g[0]), pint(g[1]), pint(g[2]), pint(g[3])

        return None

    def set_winpos(plt, x, y):
        if wait_by_input:
            return

        plt.get_current_fig_manager().window.wm_geometry(f"{x}+{y}")

    nxy = int(sqrt(ntargets) + 0.5)
    nx_fig = nxy
    ny_fig = nxy
    if nx_fig * ny_fig < ntargets:
        nx_fig += 1

    axis_inf = []
    win_pos = None
    fig_contour = None
    fig_pareto  = None
    fig_scores  = None

    print("")
    print("Plot")
    print(f"  ntargets={ntargets} in graphs {nx_fig} x {ny_fig}")
    print(f"  ndescriptors={ndescriptors}")
    
    if ndescriptors > 1:
##############################################################################
# 記述子が複数ある場合、等高線図を描く
# x軸はidx_x、y軸はidx_yで指定される記述子 X_allT[idx_x]/[idx_y]
##############################################################################
# indexes of descriptors for 2D plots
        idx_x = plot_indexes['x']
        idx_y = plot_indexes['y']
        print(f"  indexes for 2D contour plot: idx={idx_x} for x-axis, {idx_y} for y-axis")

# mean図のグラフ枠
        fig_contour, axes = plt.subplots(ny_fig, nx_fig, figsize = cparams.figsize)
        fig_contour.canvas.manager.set_window_title('contour: mean')
#        if ny_fig > 1 or nx_fig > 1:
#            fig_contour = [fig_contour]

# std図のグラフ枠
        fig_std_contour, axes_std = plt.subplots(ny_fig, nx_fig, figsize = cparams.figsize)
        fig_std_contour.canvas.manager.set_window_title('contour: std')
        if not isinstance(axes_std, np.ndarray) and type(axes_std) is not list:
#        if ny_fig == 1 and nx_fig == 1:
            axes_std = [axes_std]

        axes     = np.reshape(axes, nx_fig * ny_fig)
        axes_std = np.reshape(axes_std, nx_fig * ny_fig)

        mean_list = []
        std_list = []
        for i in range(ntargets):
            ax1 = axes[i]
            ax1_std = axes_std[i]
#            if isinstance(ax1_std, np.ndarray):
#                ax1_std = ax1_std[0]
            ax1.tick_params(labelsize = cparams.fontsize)
            ax1_std.tick_params(labelsize = cparams.fontsize)

            policy = inf[i]["policy"]
            ndata = len(X_all)
            ndesc = len(X_all[idx_x])

            X_allT = X_all.T
#            x1_unique = sorted(set(X_allT[idx_x]))
#            x2_unique = sorted(set(X_allT[idx_y]))
#            nx = len(x1_unique)
#            ny = len(x2_unique)
            xmin1 = min(X_allT[idx_x])
            xmax1 = max(X_allT[idx_x])
            if xmin1 == xmax1:
                xmin1, xmax1 = xmin1 - 0.5, xmin1 + 0.5
            xmin2 = min(X_allT[idx_y])
            xmax2 = max(X_allT[idx_y])
            if xmin2 == xmax2:
                xmin2, xmax2 = xmin2 - 0.5, xmin2 + 0.5
            xstep1 = (xmax1 - xmin1) / (cparams.nx_2D - 1)
            xstep2 = (xmax2 - xmin2) / (cparams.ny_2D - 1)
            x1_unique = np.arange(xmin1, xmax1 + xstep1 * 0.5, xstep1)
            x2_unique = np.arange(xmin2, xmax2 + xstep2 * 0.5, xstep2)
#            print("x1_unique=", x1_unique)
#            print("x2_unique=", x2_unique)

# 2Dプロット用の記述子の組 X_plot を作成。描画x,y軸用以外の記述子は、最初のデータの記述子の値 X_allT[0] を使う
            n_plot = cparams.nx_2D * cparams.ny_2D
            X_plot = []
            for ix in range(cparams.nx_2D):
                for iy in range(cparams.ny_2D):
                    l = []
                    for idx in range(len(X_all[0])):
                        if idx == idx_x:
                            l.append(x1_unique[ix])
                        elif idx == idx_y:
                            l.append(x2_unique[iy])
                        else:
                            l.append(X_all[0][idx])
                    X_plot.append(l)

            if cparams.standardize:
                X_plot_std, meanX_plat, stdX_plat = standardize(np.array(X_plot), descriptors)
            else:
                X_plot_std = np.array(X_plot)

            mean = policy.get_post_fmean(X_plot_std)
            var  = policy.get_post_fcov(X_plot_std)
            std  = np.sqrt(var)
            mean_list.append(mean)
            std_list.append(std)

            mean_plot = np.zeros([cparams.nx_2D, cparams.ny_2D])
            c = 0
            index_mean = []
            for ix in range(cparams.nx_2D):
                for iy in range(cparams.ny_2D):
                    mean_plot[ix][iy] = mean[c]
                    index_mean.append(c)
                    c += 1

            std_plot = np.zeros([cparams.nx_2D, cparams.ny_2D])
            c = 0
            index_std = []
            for ix in range(cparams.nx_2D):
                for iy in range(cparams.ny_2D):
                    std_plot[ix][iy] = std[c]
                    index_std.append(c)
                    c += 1

#            print("len=", len(x1_unique), len(x2_unique), len(mean_plot), len(mean_plot[idx_x]))
# mean値のマップ描画
            cont = ax1.contourf(x1_unique, x2_unique, mean_plot.T, cmap = "jet_r")
#複数グラフにcolorbarをつける
# https://bourbaki.biz/fit-colorbar-to-a-graph-on-matplotlib/
            div     = make_axes_locatable(ax1)
            cax     = div.append_axes("right", size = "5%", pad = 0.1)
            cb_mean = fig_contour.colorbar(cont, cax = cax)
            cb_mean.ax.tick_params(labelsize = cparams.fontsize)
            ax1.set_title(targets_label[i], fontname="MS Gothic")
            ax1.set_xlabel(descriptors_label[idx_x], fontsize = cparams.fontsize, fontname="MS Gothic")
            ax1.set_ylabel(descriptors_label[idx_y], fontsize = cparams.fontsize, fontname="MS Gothic")

# std値のマップ描画
            cont_std = ax1_std.contourf(x1_unique, x2_unique, std_plot.T, cmap = "jet_r")
#複数グラフにcolorbarをつける
# https://bourbaki.biz/fit-colorbar-to-a-graph-on-matplotlib/
            div_std = make_axes_locatable(ax1_std)
            cax_std = div_std.append_axes("right", size = "5%", pad = 0.1)
            cb_std = fig_std_contour.colorbar(cont_std, cax = cax_std)
            cb_std.ax.tick_params(labelsize = cparams.fontsize)
            ax1_std.set_title(targets_label[i], fontname="MS Gothic")
            ax1_std.set_xlabel(descriptors_label[idx_x], fontsize = cparams.fontsize, fontname="MS Gothic")
            ax1_std.set_ylabel(descriptors_label[idx_y], fontsize = cparams.fontsize, fontname="MS Gothic")

            for j in range(cparams.nx_2D):
                ax1.scatter([x1_unique[j]] * cparams.ny_2D, x2_unique, s = 5.0, 
                        marker = 'o', c = 'yellow', alpha = 0.2, linewidth = 0.5, edgecolors = 'black')
                ax1_std.scatter([x1_unique[j]] * cparams.ny_2D, x2_unique, s = 5.0, 
                        marker = 'o', c = 'yellow', alpha = 0.2, linewidth = 0.5, edgecolors = 'black')

            t = t_train.T[i]
            tmin = min(t)
            tmax = max(t)
            size = (t - tmin) / (tmax - tmin)
#            size = log((t - tmin) / (tmax - tmin) * 1000.0 + 1.0)
#            ax1.scatter(X_train.T[idx_x], X_train.T[idx_y], s = 300.0, 
            ax1.scatter(X_train.T[idx_x], X_train.T[idx_y], s = size * 300.0,
                        marker = '*', c = 'yellow', alpha = 0.5, linewidth = 0.5, edgecolors = 'black')
#            tmax = max(abs(t_train.T[i]))
#            ax1.scatter(X_train.T[idx_x], X_train.T[idx_y], s = t_train.T[i] / tmax * 500, 
#                        marker = '*', c = 'yellow', alpha = 0.7, linewidth = 0.5, edgecolors = 'black')

            ax1_std.scatter(X_train.T[idx_x], X_train.T[idx_y], s = size * 300.0,
                        marker = '*', c = 'yellow', alpha = 0.5, linewidth = 0.5, edgecolors = 'black')

            axis_inf.append({"type": "contour", "label": f"{targets_label[i]}", "axis": ax1, 
                             "x1": x1_unique, "y1": x2_unique,
                             "x2": X_train.T[idx_x], "y2": X_train.T[idx_y],
                             "idx": index_mean, "descriptors": X_plot, "target": None, 
                             "mean": mean, "std": std, "score": None})

            axis_inf.append({"type": "contour", "label": f"{targets_label[i]}", "axis": ax1_std, 
                             "x1": x1_unique, "y1": x2_unique,
                             "x2": X_train.T[idx_x], "y2": X_train.T[idx_y],
                             "idx": index_std, "descriptors": X_plot, "target": None, 
                             "mean": mean, "std": std, "score": None})

        plt.tight_layout()
        plt.pause(0.1)

#        win_pos = get_winpos(plt)
#        pos = [win_pos[2] + 10, win_pos[3] + 10]

    if ntargets > 1:
# 目的関数が複数ある場合、パレート図を描く
# indexes of objective functions for Pareto plot
        idx_Pareto_x = targets_params.indexes[0]
        idx_Pareto_y = targets_params.indexes[1]
        print(f"  indexes for Pareto plot: idx={idx_Pareto_x} for x-axis, {idx_Pareto_y} for y-axis")

        fig_pareto, ax2 = plt.subplots(1, 1, figsize = cparams.figsize)
        fig_pareto.canvas.manager.set_window_title('Pareto')
        ax2.scatter(mean_list[idx_Pareto_x], mean_list[idx_Pareto_y], s = 20.0, 
                        marker = 'o', c = 'red', alpha = 1.0, linewidth = 0.5, edgecolors = 'black')
        ax2.scatter(t_train.T[idx_Pareto_x], t_train.T[idx_Pareto_y], s = 300.0, 
                    marker = '*', c = 'blue', alpha = 0.5, linewidth = 0.5, edgecolors = 'black')
        ax2.set_xlabel(targets_label[idx_Pareto_x], fontsize = cparams.fontsize, fontname="MS Gothic")
        ax2.set_ylabel(targets_label[idx_Pareto_y], fontsize = cparams.fontsize, fontname="MS Gothic")

        axis_inf.append({"type": "Pareto", "label": "Pareto", "axis": ax2, 
                         "x1": mean_list[idx_Pareto_x], "y1": mean_list[idx_Pareto_y],
                         "x2": t_train.T[idx_Pareto_x], "y2": t_train.T[idx_Pareto_y],
                         "idx": idx_train, "descriptors": X_plot, "target": None, 
                         "mean": mean, "std": std, "score": None})


        plt.tight_layout()
#        set_winpos(plt, pos[0], pos[1])
#        pos = [pos[0] + 10, pos[1] + 10]
        plt.pause(0.1)


# 予測値、獲得関数のグラフを描く
    fig_scores, axes = plt.subplots(ny_fig, nx_fig, figsize = cparams.figsize)
    fig_scores.canvas.manager.set_window_title('prediction and aquisition functions')
    axes = np.reshape(axes, nx_fig * ny_fig)

    for i in range(ntargets):
        t_train1 = t_train.T[i]
        mean = inf[i]["mean"]
        var  = inf[i]["var"]
        std  = inf[i]["std"]
        mean_m_sigma = inf[i]["mean_m_sigma"] #= mean - std
        mean_p_sigma = inf[i]["mean_p_sigma"] #= mean + std
        score = inf[i]["score"]
        idx_best = inf[i]["idx_best"]

        ax3 = axes[i]
        ax3b = ax3.twinx()

        """
        ax1.plot(best_fx, label = 'best action', color = 'black')
        ax1.set_xlabel("sequence", fontsize = cparams.fontsize, fontname="MS Gothic")
        ax1.set_ylabel("value",    fontsize = cparams.fontsize, fontname="MS Gothic")
        ax1.tick_params(labelsize = cparams.fontsize)
        ax1.legend(fontsize = cparams.legend_fontsize, loc = 'best')
        """

        x = range(len(X_all))
        ins1 = ax3.plot(idx_train, t_train1, label = 'training', linestyle = '', marker = 'o', markersize = 8.0)
        ins3 = ax3.plot(x, mean,   label = 'mean',   color = 'black',   linewidth = 1.0)
#        ax3.plot(x, mean + std, color = 'blue', linewidth = 0.3)
#       ax3.plot(x, mean - std, color = 'blue', linewidth = 0.3)
        ins6 = ax3b.plot(x, score, label = f'score {cparams.score_mode}', color = 'red', linestyle = 'dashed', linewidth = 0.5)
        ins7 = ax3.plot(idx_best, mean[idx_best], label = 'best candidate', linestyle = '', marker = '*', markersize = 15.0)
        ax3.fill_between(x, mean_m_sigma, mean_p_sigma, color='b', alpha=.1)
        if targets_mode[i] == 'value':
            ax3.plot(ax3.get_xlim(), [0.0, 0.0], linestyle = 'dashed', linewidth = 0.5, color = 'red')

#        axis_inf.append({"label": f"{targets_label[i]}_L", "axis": ax3,  "x1": idx_train, "y1": t_train1, "x2": x, "y2": mean})
        axis_inf.append({"type": "score", "label": f"{targets_label[i]}_R", "axis": ax3b, 
                         "x1": x, "y1": score,
                         "x2": None, "y2": None,
                         "idx": idx_train, "descriptors": X_all, "target": t_all.T[i], 
                         "mean": mean, "std": std, "score": score})

        ax3.minorticks_on()
#        ax3.set_xticklabels(range(ndata))
        ax3.grid(which = "major", axis = "x", color = "green", alpha = 0.5, linestyle = '--', linewidth = 0.5)
        ax3.grid(which = "minor", axis = "x", color = "green", alpha = 0.5, linestyle = '--', linewidth = 0.1)
        ax3.set_xlabel( "index", fontsize = cparams.fontsize, fontname="MS Gothic")
        ax3.set_ylabel(targets_label[i], fontsize = cparams.fontsize, fontname="MS Gothic")
        ax3b.set_ylabel(f"score {cparams.score_mode}", fontsize = cparams.fontsize)
#        ax3b.set_yscale('log')
        ax3.tick_params( labelsize = cparams.fontsize)
        ax3b.tick_params(labelsize = cparams.fontsize)
        ins = ins1 + ins3 + ins6 + ins7
        ax3.legend(ins, [l.get_label() for l in ins], fontsize = cparams.legend_fontsize, loc = 'best')

#        ax1.set_ylim(ax3.get_ylim())

    plt.tight_layout()
#    set_winpos(plt, pos[0], pos[1])
#    pos = [pos[0] + 10, pos[1] + 10]
    plt.pause(0.1)


# マウスクリック時に最近接データの情報をコンソールに表示
    hit_itarget = None
    hit_axisinf = None

# マウスクリック時に最近接データを検索
    def find_nearest_data(type, x, y, xlist, ylist, idx_list, x2list, y2list):
        """
        Search the plotted data for the points nearest to the position (x, y).

        Distances are compared as squared Euclidean distances in data coordinates.

        Args:
            type: Kind of plot that was clicked: 'score', 'contour' or 'Pareto'.
            x, y: Position to search around.
            xlist, ylist: Primary data. For 'score' and 'Pareto' they are parallel
                sequences of point coordinates; for 'contour' they are the grid
                axes, and every (xi, yi) combination is visited with xlist as
                the outer loop.
            idx_list: Used only for 'contour': maps the running grid counter to
                the index that is returned.
            x2list, y2list: Secondary point coordinates, scanned only for
                'contour' and 'Pareto'.

        Returns:
            (ihit, minr2, ihitb, minr2b): index and squared distance of the nearest
            primary point, then the same for the nearest secondary point.
            An index stays None and its distance stays 1.0e300 when no point
            closer than that was found (e.g. empty list or list not scanned).

        An unknown type prints an error message and calls exit().
        """
        unset_r2 = 1.0e300

        def paired(xs, ys):
            # Coordinates taken index by index from two parallel sequences
            return ((xs[k], ys[k]) for k in range(len(xs)))

        def closest(points, labels = None):
            # Linear scan; the strict comparison keeps the first of equally near points.
            # The label is looked up only when a point becomes the new best one.
            best_label = None
            best_r2 = unset_r2
            for n, (px, py) in enumerate(points):
                dist2 = (x - px)**2 + (y - py)**2
                if best_r2 > dist2:
                    best_r2 = dist2
                    best_label = n if labels is None else labels[n]
            return best_label, best_r2

        ihit, minr2 = None, unset_r2
        ihitb, minr2b = None, unset_r2

        if type not in ('score', 'contour', 'Pareto'):
            print("")
            print(f"Error: Invalid type [{type}] in find_nearest_data()")
            exit()
        elif type == 'contour':
            # Grid nodes first (x outer, y inner), then the secondary points
            grid_nodes = ((gx, gy) for gx in xlist for gy in ylist)
            ihit, minr2 = closest(grid_nodes, idx_list)
            ihitb, minr2b = closest(paired(x2list, y2list))
        else:
            # 'score' and 'Pareto' share the scan over the primary points
            ihit, minr2 = closest(paired(xlist, ylist))
            if type == 'Pareto':
                ihitb, minr2b = closest(paired(x2list, y2list))

        return ihit, minr2, ihitb, minr2b

    def onclick(event):
        """
        Mouse-click callback: print information on the data nearest to the click.

        The clicked axis is looked up in axis_inf. Clicks outside every registered
        axis are ignored. For the matching entry, the nearest data point is searched
        with find_nearest_data() and a summary is printed to the console; its
        content depends on the entry's "type" ('score', 'contour' or 'Pareto').

        Args:
            event: Event object passed by the matplotlib "button_press_event"
                connection; its inaxes, xdata and ydata attributes are read.

        An unknown entry type prints an error message and calls exit().
        """
        # Identify the registered graph axis that was clicked
        for hit_axisinf in axis_inf:
            if event.inaxes == hit_axisinf["axis"]:
                break
        else:
            # The click was not inside any axis listed in axis_inf
            return

        # Get the nearest data (primary series and, where available, training data)
        xe, ye  = event.xdata, event.ydata
        idx, r2, idxb, r2b = find_nearest_data(hit_axisinf["type"], xe, ye, 
                            hit_axisinf["x1"], hit_axisinf["y1"], hit_axisinf["idx"],
                            hit_axisinf["x2"], hit_axisinf["y2"])

        # Show the information of the nearest data on the console
        print("")
        if hit_axisinf["type"] == 'score':
            # idx+2 is reported as the excel line of the data index
            print(f"clicked in {hit_axisinf['label']} near the data index={idx} / the excel line {idx+2}:")
            print(f"    descriptors = {hit_axisinf['descriptors'][idx]}   given objective value = {hit_axisinf['target'][idx]}")
            print(f"    predicted objective value = {hit_axisinf['mean'][idx]:10.4g} +- {hit_axisinf['std'][idx]:10.4g}")
            print(f"    score = {hit_axisinf['score'][idx]:10.4g}")
        elif hit_axisinf["type"] == 'contour':
            # indexes of descriptors for 2D plots
            idx_x = plot_indexes['x']
            idx_y = plot_indexes['y']
            print(f"clicked in {hit_axisinf['label']} at the data index={idx} near {descriptors[idx_x]}={xe:10.4g} and {descriptors[idx_y]}={ye:10.4g}")
            print(f"    descriptors = {hit_axisinf['descriptors'][idx]}")
            print(f"    predicted objective value = {hit_axisinf['mean'][idx]:10.4g} +- {hit_axisinf['std'][idx]:10.4g}")
            print(f"  Nearest training data:")
            print(f"    descriptors = {X_train[idxb]}  objective value={t_train[idxb]}")
        elif hit_axisinf["type"] == 'Pareto':
            # indexes of objective functions for Pareto plot
            idx_Pareto_x = targets_params.indexes[0]
            idx_Pareto_y = targets_params.indexes[1]
            print(f"clicked in {hit_axisinf['label']} at the data index={idx} near {targets_label[idx_Pareto_x]}={xe:10.4g} and {targets_label[idx_Pareto_y]}={ye:10.4g}")
            print(f"    descriptors = {hit_axisinf['descriptors'][idx]}")
            print(f"    predicted objective values:")
            for i in range(ntargets):
                print(f"    #{i+1} for {targets_label[i]}: {mean_list[i][idx]:10.4g} +- {std_list[i][idx]:10.4g}")
            print(f"  Nearest training data:")
            print(f"    descriptors = {X_train[idxb]}  objective value={t_train[idxb]}")
        else:
            # Report the offending plot type itself; a bare `type` here would be
            # the builtin and print "<class 'type'>"
            print("")
            print(f"Error: Invalid type [{hit_axisinf['type']}] in onclick()")
            exit()

    if fig_contour:
        fig_contour.canvas.mpl_connect("button_press_event", onclick)
    if fig_std_contour:
        fig_std_contour.canvas.mpl_connect("button_press_event", onclick)
    if fig_pareto:
        fig_pareto.canvas.mpl_connect("button_press_event", onclick)
    if fig_scores:
        fig_scores.canvas.mpl_connect("button_press_event", onclick)
#       fig_scores.canvas.mpl_connect("motion_notify_event", hover)

    if wait_by_input:
        app.terminate(usage = usage, pause = wait_by_input)
    else:
# Called by GUI
        pass

if __name__ == "__main__":
    # Script entry point: create the application object and the parameter
    # container, then let update_vars() adjust them before the run.
    app, cparams = initialize()
    update_vars(app, cparams)

    # The log file path is obtained from the input file path via app.replace_path().
    # Presumably output is then sent to both stdout and the log file (opened with
    # mode 'w') -- confirm against tkApplication.redirect().
    cparams.logfile = app.replace_path(cparams.infile)
    print(f"Open logfile [{cparams.logfile}]")
    app.redirect(targets = ["stdout", cparams.logfile], mode = 'w')

    # Main processing
    execute(app, cparams)

    # NOTE(review): `app` is passed explicitly as the first positional argument here,
    # unlike the app.terminate(usage = ..., pause = ...) call inside execute();
    # confirm against the tkApplication.terminate() signature.
    app.terminate(app, "", usage = None, pause = True)
