import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import time

# サンプルサイズ
sample_size = 10000
num_chains = 4  # 複数チェーン

# 正規分布の乱数生成（numpy.random.normal）
def generate_normal_samples(sample_size, mean=0, std=1):
    return np.random.normal(mean, std, sample_size)

# Box-Muller法
def box_muller(sample_size):
    u1 = np.random.rand(sample_size // 2)
    u2 = np.random.rand(sample_size // 2)
    z1 = np.sqrt(-2 * np.log(u1)) * np.cos(2 * np.pi * u2)
    z2 = np.sqrt(-2 * np.log(u1)) * np.sin(2 * np.pi * u2)
    return np.concatenate([z1, z2])

# Hastings法 (Metropolis-Hastings法) - 複数チェーンを使用
def metropolis_hastings_multiple_chains(sample_size, target_mean=0, target_std=1, burn_in=1000, thinning=10, num_chains=4):
    chains = []
    
    # target_func (目標分布) の対数を使う
    target_log_pdf = lambda x: -0.5 * ((x - target_mean) / target_std) ** 2  # 対数正規分布

    for _ in range(num_chains):  # 複数チェーン
        current_sample = np.random.normal(target_mean, target_std)
        samples = []
        
        for i in range(sample_size + burn_in):
            proposed_sample = np.random.normal(target_mean, target_std)
            
            # 受容比率（対数確率密度の差で計算）
            log_acceptance_ratio = target_log_pdf(proposed_sample) - target_log_pdf(current_sample)
            acceptance_ratio = min(1, np.exp(log_acceptance_ratio))  # expを使って確率に戻す

            if np.random.rand() < acceptance_ratio:
                current_sample = proposed_sample

            # バーニング期間を過ぎた後、間引きをしてサンプルを取ります
            if i >= burn_in and (i - burn_in) % thinning == 0:
                samples.append(current_sample)
        
        chains.append(np.array(samples))
    
    # 複数のチェーンからサンプルを統合
    combined_samples = np.concatenate(chains)
    return combined_samples

# 時間計測を関数化
def time_execution(func, *args):
    start_time = time.time()
    result = func(*args)
    end_time = time.time()
    return result, end_time - start_time

# ベンチマーク計測（実行時間の計測）
normal_samples, normal_time = time_execution(generate_normal_samples, sample_size)
box_muller_samples, box_muller_time = time_execution(box_muller, sample_size)

# Metropolis-Hastingsの引数を位置引数として渡す
mh_samples, mh_time = time_execution(metropolis_hastings_multiple_chains, sample_size, 0, 1, 1000, 10, num_chains)

# 理論的な正規分布
x = np.linspace(-5, 5, 1000)
theoretical_normal = norm.pdf(x, 0, 1)

# ベンチマーク結果を表示
print(f"Time for Normal Distribution (numpy): {normal_time:.6f} seconds")
print(f"Time for Box-Muller: {box_muller_time:.6f} seconds")
print(f"Time for Metropolis-Hastings (multiple chains): {mh_time:.6f} seconds")

# ヒストグラムを折れ線グラフとして表示
plt.figure(figsize=(10, 6))
# 通常の正規分布サンプル
plt.hist(normal_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Normal Distribution (numpy)", color='blue')
# Box-Mullerサンプル
plt.hist(box_muller_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Box-Muller", color='green')
# Metropolis-Hastingsサンプル (複数チェーン)
plt.hist(mh_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Metropolis-Hastings (multiple chains)", color='red')

# 理論正規分布を折れ線グラフとして表示
plt.plot(x, theoretical_normal, 'k--', label="Theoretical Normal Distribution", linewidth=2)

# グラフの設定
plt.title("Comparison of Normal Distribution Sampling Methods")
plt.xlabel("Value")
plt.ylabel("Density")
plt.legend()
plt.grid(True)
plt.show()
