import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import time


# サンプルサイズ
burn_in = 1000
thinning = 1
sample_size = 1000
num_chains = 4  # 複数チェーン


# 逆関数法による正規分布乱数生成
def inverse_transform_sampling_normal(n_samples, mean=0, std=1, eps=1e-10):
    """逆関数法で正規分布乱数を生成（クリップ付き）"""
    u = np.random.uniform(0, 1, n_samples)
    u = np.clip(u, eps, 1 - eps)  # 0と1を避けるためにクリップ
    samples = norm.ppf(u)  # 標準正規分布の逆累積分布関数
    samples = mean + std * samples  # 平均と標準偏差を適用
    return samples

# 正規分布の乱数生成（numpy.random.normal）
def generate_normal_samples(sample_size, mean=0, std=1):
    return np.random.normal(mean, std, sample_size)

def target_dist(theta, mean=0, std=1):  # 目標分布（事後分布に対応）
    return np.exp(-(theta - mean)**2 / (2.0 * std**2))

# Box-Muller法
def box_muller(sample_size, mean=0, std=1):
    u1 = np.random.rand(sample_size // 2)
    u2 = np.random.rand(sample_size // 2)
    z1 = np.sqrt(-2 * np.log(u1)) * np.cos(2 * np.pi * u2)
    z2 = np.sqrt(-2 * np.log(u1)) * np.sin(2 * np.pi * u2)
    samples = np.concatenate([z1, z2])
    samples = mean + std * samples  # 平均と標準偏差を適用
    return samples

# Hastings法 (Metropolis-Hastings法) - 複数チェーンを使用
def metropolis_hastings_multiple_chains(sample_size, target_mean=0, target_std=1, burn_in=1000, thinning=10, num_chains=4):
    samples = []
    proposal_std = 1.0  # 提案分布の標準偏差
    
    for _ in range(num_chains):
        chain_samples = []
        theta = target_mean  # 各チェーンの初期値
        for _ in range(sample_size):
            # 提案分布に従って次のサンプルを生成
            theta_p = theta + np.random.normal(0, proposal_std)
            r = target_dist(theta_p, target_mean, target_std) / target_dist(theta, target_mean, target_std)  # 受容比率
            if np.random.rand() < min(1, r):
                theta = theta_p
            chain_samples.append(theta)
        
        # チェーンごとにバーンインと間引き
        chain_samples = np.array(chain_samples)
        chain_samples = chain_samples[burn_in:]  # バーンイン
        chain_samples = chain_samples[::thinning]  # 間引き

        samples.append(chain_samples)
    
    # 複数チェーンのサンプルを統合
    return np.concatenate(samples)

# 高精度な時間計測を関数化
def time_execution(func, *args):
    start_time = time.perf_counter()  # 高精度タイマー
    result = func(*args)
    end_time = time.perf_counter()
    return result, end_time - start_time

# メインの処理
def main():
    # ベンチマーク計測（実行時間の計測）
    normal_samples, normal_time = time_execution(generate_normal_samples, sample_size)
    box_muller_samples, box_muller_time = time_execution(box_muller, sample_size)
    inverse_transform_samples, inverse_transform_time = time_execution(inverse_transform_sampling_normal, sample_size)
    mh_samples, mh_time = time_execution(metropolis_hastings_multiple_chains, sample_size, 0, 1, burn_in, thinning, num_chains)

    # 理論的な正規分布
    x = np.linspace(-5, 5, 1000)
    theoretical_normal = norm.pdf(x, 0, 1)

    # ベンチマーク結果を表示
    print(f"Time for Normal Distribution (numpy): {normal_time*1000:.6f} ms")
    print(f"Time for Box-Muller: {box_muller_time*1000:.6f} ms")
    print(f"Time for Inverse Transform Sampling: {inverse_transform_time*1000:.6f} ms")
    print(f"Time for Metropolis-Hastings (multiple chains): {mh_time*1000:.6f} ms")

    # ヒストグラムを折れ線グラフとして表示
    plt.figure(figsize=(10, 6))
    # 通常の正規分布サンプル
    plt.hist(normal_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Normal Distribution (numpy)", color='blue')
    # Box-Mullerサンプル
    plt.hist(box_muller_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Box-Muller", color='green')
    # 逆関数法サンプル
    plt.hist(inverse_transform_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Inverse Transform Sampling", color='orange')
    # Metropolis-Hastingsサンプル (複数チェーン)
    plt.hist(mh_samples, bins=50, density=True, alpha=0.6, histtype='step', label="Metropolis-Hastings (multiple chains)", color='red')

    # 理論正規分布を折れ線グラフとして表示
    plt.plot(x, theoretical_normal, 'k--', label="Theoretical Normal Distribution", linewidth=2)

    # グラフの設定
    plt.title("Comparison of Normal Distribution Sampling Methods")
    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.legend()
    plt.grid(True)
    plt.show()

# ライブラリとしても使用可能にするため、以下のコードを追加
if __name__ == "__main__":
    main()
