import sys
import random
import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# Number of optimization parameters (2 or 4). With 4 the objective is
# Ackley(x0, x1) + 2 * Ackley(x2, x3); with 2 it is Ackley(x0, x1).
nparameters = 4   # 2 or 4

method = "ga"

# Search range [low, high] of each parameter (one entry per parameter).
# Kept as mutable lists so the command-line overrides below can edit them.
if nparameters == 4:
    bounds = [[-10, 10], [-10, 10], [-10, 10], [-10, 10]]
else:
    bounds = [[-10, 10], [-10, 10]]
# Initial guess (used by the simplex method).
if nparameters == 4:
    initial_guess = [5, 5, 5, 5]
else:
    initial_guess = [5, 5]

# Population size: number of GA individuals, PSO particles or REMC replicas.
nparents = 50

nmaxiter = 100
tol = 1.0e-3

print_level = 0


def apply_command_line(argv):
    """Override the module-level settings from positional arguments.

    ``argv`` follows the ``sys.argv`` layout::

        argv[1] method     argv[2] nparents   argv[3] tol   argv[4] nmaxiter
        argv[5] lower bound   argv[6] upper bound   argv[7] print_level

    Missing trailing arguments leave the corresponding defaults unchanged.
    The lower/upper bounds are applied to every parameter, not only the
    first two.
    """
    global method, nparents, tol, nmaxiter, print_level

    nargs = len(argv)
    if nargs > 1:
        method = argv[1]
    if nargs > 2:
        nparents = int(argv[2])
    if nargs > 3:
        tol = float(argv[3])
    if nargs > 4:
        nmaxiter = int(argv[4])
    if nargs > 5:
        low = float(argv[5])
        for b in bounds:
            b[0] = low
    if nargs > 6:
        high = float(argv[6])
        for b in bounds:
            b[1] = high
    if nargs > 7:
        print_level = int(argv[7])


# Parse the command line only when run as a script, so importing this module
# does not reinterpret some other program's argv.
if __name__ == '__main__':
    apply_command_line(sys.argv)


def AckleyFunc(X, Y):
    """Two-dimensional Ackley function (global minimum 0 at X = Y = 0).

    Evaluated element-wise with numpy, so X and Y may be scalars or
    broadcast-compatible arrays.
    """
    radial_term = -20 * np.exp(-0.2 * np.sqrt(0.5 * (X**2 + Y**2)))
    cosine_term = -np.exp(0.5 * (np.cos(2 * np.pi * X) + np.cos(2 * np.pi * Y)))
    return 20 + radial_term + np.e + cosine_term

icall = 0  # number of objective evaluations; reset by optimize()
def minimize_func(xk):
    """Objective shared by all methods; increments the global ``icall``.

    With nparameters == 4: Ackley(xk[0], xk[1]) + 2 * Ackley(xk[2], xk[3]).
    Otherwise: Ackley(xk[0], xk[1]).
    """
    global icall
    icall += 1

    value = AckleyFunc(xk[0], xk[1])
    if nparameters != 4:
        return value
    return value + 2.0 * AckleyFunc(xk[2], xk[3])


iter = 0  # number of callback invocations; reset by optimize()
def callback(xk):
    """Per-iteration hook: print progress when print_level is set, then count it."""
    global iter

    if print_level:
        fmin = minimize_func(xk)
        print(f"iter={iter}: fmin=", fmin, "  xk:", xk)

    iter = iter + 1

def usage():
    """Print the command-line synopsis.

    All arguments are positional and optional; argv[1] is taken verbatim as
    the method name, so it must be given without a ``method=`` prefix.
    """
    print(f"\nUsage: python {sys.argv[0]} method nparents tol nmaxiter lower upper print_level\n"
          "  method: ga | pso | remc | simplex | all (all arguments are optional)\n")


def sampling(bounds):
    """Draw one point uniformly from the box described by ``bounds``.

    ``bounds`` is a sequence of (low, high) pairs; draws come from the stdlib
    ``random`` generator. Returns a 1-D numpy array, one coordinate per pair.
    """
    coords = []
    for low, high in bounds:
        coords.append(random.uniform(low, high))
    return np.array(coords)

#######################
# GA
#######################
def mutate(individual, bounds, mutation_rate = 0.01):
    """Mutate ``individual`` in place; returns None.

    Each gene i is, with probability ``mutation_rate``, replaced by a fresh
    uniform draw from [bounds[i][0], bounds[i][1]].
    """
    for gene, _ in enumerate(individual):
        if not random.random() < mutation_rate:
            continue
        individual[gene] = random.uniform(bounds[gene][0], bounds[gene][1])

def crossover(parent1, parent2):
    """One-point crossover of two equal-length 1-D arrays.

    A cut point c is drawn uniformly from 1 .. len(parent1) - 1.
    child1 = parent1[:c] + parent2[c:], and child2 = parent2[:c] + parent1[c:].
    Parents with fewer than two genes cannot be cut, because
    random.randint(1, 0) would raise ValueError. In that case independent
    copies of the parents are returned.

    Returns:
        (child1, child2) as new numpy arrays.
    """
    if len(parent1) < 2:
        return np.array(parent1), np.array(parent2)
    crossover_point = random.randint(1, len(parent1) - 1)
    child1 = np.concatenate((parent1[:crossover_point], parent2[crossover_point:]))
    child2 = np.concatenate((parent2[:crossover_point], parent1[crossover_point:]))
    return child1, child2

def select_best_from_random_group(population, scores, k = 3):
    """Tournament selection: return the best of ``k`` randomly drawn individuals.

    Indices are drawn with replacement via random.choices. "Best" is the lowest
    score, where scores[i] belongs to population[i]. Ties go to the lowest
    population index.

    Sampling indices directly avoids mapping the drawn individuals back to
    their scores by value equality. That mapping is O(len(population) * k) and
    finds nothing for individuals containing NaN (NaN != NaN).

    Returns:
        The selected element of ``population`` (not a copy).
    """
    picked = random.choices(range(len(population)), k = k)
    best = min(sorted(set(picked)), key = lambda i: scores[i])
    return population[best]

def genetic_algorithm(objective_function, bounds, population_size, callback = None, tol = 1.0e-5, nmaxiter = 1000, mutation_rate = 0.01):
    """Minimize ``objective_function`` over the box ``bounds`` with a simple GA.

    Each generation is built by tournament selection
    (select_best_from_random_group), one-point crossover and per-gene
    mutation. The best individual seen so far is tracked across generations.

    Args:
        objective_function: callable mapping a 1-D array to a scalar score.
        bounds: sequence of [low, high] pairs, one per parameter.
        population_size: number of individuals per generation (>= 1).
        callback: optional callable, called with the best individual after
            every generation.
        tol: stop once a generation's best score drops below this value.
        nmaxiter: maximum number of generations.
        mutation_rate: per-gene mutation probability.

    Returns:
        (best_individual, best_score)

    Raises:
        ValueError: if population_size < 1.
    """
    if population_size < 1:
        raise ValueError("population_size must be at least 1")

    population = [sampling(bounds) for _ in range(population_size)]
    # Each individual is scored exactly once; the scores are reused both for
    # selection and for tracking the best individual.
    scores = [objective_function(ind) for ind in population]
    best_index = min(range(population_size), key=scores.__getitem__)
    best_individual, best_score = population[best_index], scores[best_index]

    for _ in range(nmaxiter):
        new_population = []
        # Children are produced in pairs: keep breeding until the population
        # is full, then drop the surplus child when population_size is odd.
        while len(new_population) < population_size:
            parent1 = select_best_from_random_group(population, scores)
            parent2 = select_best_from_random_group(population, scores)
            child1, child2 = crossover(parent1, parent2)
            mutate(child1, bounds, mutation_rate)
            mutate(child2, bounds, mutation_rate)
            new_population.extend([child1, child2])

        population = new_population[:population_size]
        scores = [objective_function(ind) for ind in population]

        current_index = min(range(population_size), key=scores.__getitem__)
        current_best_score = scores[current_index]
        if current_best_score < best_score:
            best_individual, best_score = population[current_index], current_best_score

        if callback: callback(best_individual)

        if current_best_score < tol:
            print("Converged\n")
            return best_individual, best_score

    print("Not converged\n")
    return best_individual, best_score


##########################
# Particle swarm optimization (PSO)
##########################
class Particle:
    """A single PSO particle: position, velocity and its personal best.

    Random numbers come from numpy's global generator (np.random). The start
    position is uniform inside ``bounds``. The velocity starts at zero and the
    personal best score at +inf.
    """

    def __init__(self, bounds):
        start = []
        for low, high in bounds:
            start.append(np.random.uniform(low, high))
        self.position = np.array(start)
        self.velocity = np.zeros_like(self.position)
        self.best_position = np.copy(self.position)
        self.best_score = float('inf')

    def update_velocity(self, global_best_position, inertia, cognitive, social):
        """v <- inertia*v + cognitive*r1*(pbest - x) + social*r2*(gbest - x).

        r1 and r2 are two scalars from np.random.rand, shared by all dimensions.
        """
        rand_pair = np.random.rand(2)
        pull_to_own_best = cognitive * rand_pair[0] * (self.best_position - self.position)
        pull_to_swarm_best = social * rand_pair[1] * (global_best_position - self.position)
        self.velocity = inertia * self.velocity + pull_to_own_best + pull_to_swarm_best

    def update_position(self, bounds):
        """Move by the current velocity (in place), then clamp each coordinate."""
        pos = self.position
        pos += self.velocity
        for dim, (low, high) in enumerate(bounds):
            if pos[dim] < low:
                pos[dim] = low
            elif pos[dim] > high:
                pos[dim] = high

def particle_swarm_optimization(objective_function, bounds, num_particles, callback = None, tol = 1.0e-5, nmaxiter = 1000, inertia = 0.5, cognitive = 1.5, social = 1.5):
    """Minimize ``objective_function`` over ``bounds`` with particle swarm optimization.

    Args:
        objective_function: callable mapping a 1-D array to a scalar score.
        bounds: sequence of [low, high] pairs, one per parameter.
        num_particles: swarm size (>= 1).
        callback: optional callable, called once per iteration with the
            global best position.
        tol: stop once the global best score drops below this value.
        nmaxiter: maximum number of iterations.
        inertia, cognitive, social: velocity-update coefficients
            (see Particle.update_velocity).

    Returns:
        (global_best_position, global_best_score)

    Raises:
        ValueError: if num_particles < 1.
    """
    if num_particles < 1:
        raise ValueError("num_particles must be at least 1")

    particles = [Particle(bounds) for _ in range(num_particles)]
    global_best_position = particles[0].position.copy()
    global_best_score = float('inf')

    for _ in range(nmaxiter):
        # Score every particle at its current position; refresh personal and
        # global bests.
        for particle in particles:
            score = objective_function(particle.position)
            if score < particle.best_score:
                particle.best_score = score
                particle.best_position = particle.position.copy()
            if score < global_best_score:
                global_best_score = score
                global_best_position = particle.position.copy()

        # Report once per iteration, like the other optimizers, rather than on
        # every improvement, so callback-based iteration counts are comparable.
        if callback: callback(global_best_position)

        for particle in particles:
            particle.update_velocity(global_best_position, inertia, cognitive, social)
            particle.update_position(bounds)

        if global_best_score < tol:
            print("Converged\n")
            return global_best_position, global_best_score

    print("Not converged\n")
    return global_best_position, global_best_score

##############################
# Replica Exchange MonteCarlo
##############################
def exchange_acceptance(delta, temp1, temp2):
    """Probability of swapping two replicas held at temperatures temp1 and temp2.

    ``delta`` is E(replica at temp2) - E(replica at temp1). The replica-exchange
    Metropolis criterion is min(1, exp((1/temp1 - 1/temp2) * (E1 - E2))),
    which is min(1, exp(-delta * (1/temp1 - 1/temp2))). For temp1 < temp2, a
    lower-energy configuration at the hotter temperature is therefore always
    moved to the colder one.

    The exponent is capped at 0, so the result is a probability in [0, 1] and
    np.exp cannot overflow.
    """
    exponent = -delta * (1.0 / temp1 - 1.0 / temp2)
    return float(np.exp(min(exponent, 0.0)))

def replica_exchange_monte_carlo(objective_function, bounds, num_replicas, temperatures, callback = None, tol = 1.0e-5, nmaxiter = 1000):
    """Minimize ``objective_function`` with replica-exchange Monte Carlo.

    Each replica i is held at temperatures[i]. A local move proposes a fresh
    uniform sample from ``bounds``, independent of the current configuration.
    The move is accepted with probability min(1, exp(-delta / temperatures[i])).
    After the local moves, neighbouring replicas attempt a swap, accepted with
    exchange_acceptance().

    Args:
        objective_function: callable mapping a 1-D array to a scalar energy.
        bounds: sequence of [low, high] pairs, one per parameter.
        num_replicas: number of replicas (>= 1).
        temperatures: indexable, at least num_replicas temperatures.
        callback: optional callable, called once per iteration with the
            best configuration found so far.
        tol: stop once the best current energy drops below this value.
        nmaxiter: maximum number of iterations.

    Returns:
        (best_individual, best_score)

    Raises:
        ValueError: if num_replicas < 1.
    """
    if num_replicas < 1:
        raise ValueError("num_replicas must be at least 1")

    replicas = [sampling(bounds) for _ in range(num_replicas)]
    # Cache each replica's energy so every configuration is evaluated once.
    energies = [objective_function(rep) for rep in replicas]
    best_index = min(range(num_replicas), key=energies.__getitem__)
    best_individual, best_score = replicas[best_index], energies[best_index]

    for _ in range(nmaxiter):
        # Metropolis step for every replica at its own temperature.
        for i in range(num_replicas):
            candidate = sampling(bounds)
            candidate_energy = objective_function(candidate)
            delta = candidate_energy - energies[i]
            if delta < 0 or np.random.rand() < np.exp(-delta / temperatures[i]):
                replicas[i] = candidate
                energies[i] = candidate_energy

        # Exchange step between neighbouring temperatures; energies travel
        # with their configurations.
        for i in range(num_replicas - 1):
            delta = energies[i+1] - energies[i]
            if np.random.rand() < exchange_acceptance(delta, temperatures[i], temperatures[i+1]):
                replicas[i], replicas[i+1] = replicas[i+1], replicas[i]
                energies[i], energies[i+1] = energies[i+1], energies[i]

        current_index = min(range(num_replicas), key=energies.__getitem__)
        current_best_score = energies[current_index]
        # Record the new best before testing convergence, so the configuration
        # that converged is the one returned.
        if current_best_score < best_score:
            best_individual, best_score = replicas[current_index], current_best_score

        if callback: callback(best_individual)

        if current_best_score < tol:
            print("Converged\n")
            return best_individual, best_score

    print("Not converged\n")
    return best_individual, best_score


def _initial_simplex(guess, bounds):
    """Build the Nelder-Mead starting simplex around ``guess`` (any dimension).

    Vertex 0 is the guess itself. Vertex i+1 equals the guess, except that
    coordinate i is moved to its lower bound when i is even and to its upper
    bound when i is odd. For 2 parameters this gives
    [g0, g1], [low0, g1], [g0, high1].
    """
    vertices = [list(guess)]
    for i in range(len(guess)):
        vertex = list(guess)
        vertex[i] = bounds[i][i % 2]
        vertices.append(vertex)
    return np.array(vertices, dtype=float)


def optimize(method):
    """Run one optimizer on ``minimize_func`` and print its result.

    ``method`` is 'ga', 'pso', 'remc' or 'simplex'. Anything else prints an
    error plus the usage text and exits with status 1. The global ``iter`` and
    ``icall`` counters are reset first, so the printed counts cover this run
    only.
    """
    global iter, icall

    iter = 0
    icall = 0

    print()
    print(f"method: {method}")

    if method == 'ga':
        best_params, best_score = genetic_algorithm(minimize_func, bounds, nparents,
                                        callback = callback, tol = tol, nmaxiter = nmaxiter)
    elif method == 'simplex':
        initial_simplex = _initial_simplex(initial_guess, bounds)
        result = minimize(minimize_func, initial_guess, method = 'Nelder-Mead', tol = tol,
                        callback = callback,
                        options = {'initial_simplex': initial_simplex, "maxiter": nmaxiter})
        best_params = result.x
        best_score = result.fun
        if result.success:
            print("Converged\n")
        else:
            print("Not converged\n")
    elif method == 'pso':
        best_params, best_score = particle_swarm_optimization(minimize_func, bounds, nparents,
                                    callback = callback, tol = tol, nmaxiter = nmaxiter)
    elif method == 'remc':
        temperatures = np.linspace(1, 10, nparents)
        best_params, best_score = replica_exchange_monte_carlo(minimize_func, bounds, nparents, temperatures,
                            callback = callback, tol = tol, nmaxiter = nmaxiter)
    else:
        print(f"Invalid method [{method}]")
        usage()
        # sys.exit, not the interactive-only site.exit(); nonzero status on error.
        sys.exit(1)

    print(f"Best parameters: {best_params}")
    print(f"Best score: {best_score}")
    print(f"icall={icall}  iter={iter}")
    print()

def plot():
    """Draw the objective as a 3-D surface over [-5, 5] x [-5, 5].

    minimize_func is evaluated on a 400 x 400 grid. This requires
    nparameters == 2, because otherwise minimize_func reads xk[2] and xk[3].
    plt.pause briefly runs the GUI event loop so the figure is displayed.
    """
    axis = np.linspace(-5, 5, 400)
    grid_x, grid_y = np.meshgrid(axis, axis)
    grid_z = minimize_func([grid_x, grid_y])

    ax = plt.figure().add_subplot(111, projection='3d')
    ax.plot_surface(grid_x, grid_y, grid_z, cmap='viridis')

    for set_label, text in ((ax.set_xlabel, 'X'), (ax.set_ylabel, 'Y'), (ax.set_zlabel, 'Z')):
        set_label(text)

    plt.pause(0.01)


def main():
    """Print the run configuration, then run the selected method(s).

    method == 'all' runs ga, pso, remc and simplex in turn.
    """
    print()
    print(f"method: {method}")
    print("bounds:", bounds)
    print("initial guess:", initial_guess)
    print(f"nmaxiter: {nmaxiter}")

    selected = ['ga', 'pso', 'remc', 'simplex'] if method == 'all' else [method]
    for name in selected:
        optimize(name)


if __name__ == '__main__':
    # The surface plot is only drawn for the 2-parameter objective.
    if nparameters == 2:
        plot()

    main()
    usage()

    # Wait for ENTER before the script exits.
    input("\nPress ENTER to terminate>>\n")
    
    