#!/usr/bin/env python3
"""
diagonalize2d.py

概要:
    2x2実数行列の対角化を可視化するスクリプト。

詳細説明:
    このスクリプトは以下の機能を提供します:
    1. 元の基底ベクトル (a1, a2) とそれらの行列Sによる像 (S a1, S a2) を描画します。
    2. 2x2実数行列Sの固有値と固有ベクトルを計算します。
    3. 行列が実数上で対角化可能でない場合、メッセージを表示して終了します。
    4. 行列が実数上で対角化可能である場合、実固有ベクトル (v1, v2) とそれらの像 (S v1, S v2) を描画します。
    5. 対角化前後の基底ベクトル間の角度を出力します。

解釈:
    - Sが対称行列の場合、対角化基底は直交するので、角度は90度を保ちます。
    - Sが非対称であるが実数上で対角化可能な場合、対角化基底は一般に直交しません。
    - 複素固有値が現れる場合、実数の主軸基底は存在しません。

Usage:
    python diagonalize2d.py S11 S12 S21 S22

関連リンク:
    :doc:`diagonalize2d_usage`
"""

from __future__ import annotations

import argparse
import math

import matplotlib.pyplot as plt
import numpy as np


EPS = 1e-10


def unit(v: np.ndarray) -> np.ndarray:
    """
    概要:
        与えられたベクトルを単位ベクトルに正規化する。

    詳細説明:
        ベクトルのL2ノルムが非常に小さい（EPS以下）場合、ゼロベクトルとみなしValueErrorを発生させる。

    :param v: 正規化する2Dベクトル（numpy.ndarray形式）。
    :returns: 正規化された単位ベクトル（numpy.ndarray形式）。
    :raises ValueError: ゼロベクトルを正規化しようとした場合。
    """
    n = np.linalg.norm(v)
    if n < EPS:
        raise ValueError("Zero vector cannot be normalized.")
    return v / n


def angle_deg(v: np.ndarray) -> float:
    """
    概要:
        2Dベクトルが+X軸となす角度を度数で計算する。

    詳細説明:
        角度の範囲は (-180, 180] 度となる。

    :param v: 角度を計算する2Dベクトル（numpy.ndarray形式）。
    :returns: +X軸となす角度（度数）。
    """
    return math.degrees(math.atan2(float(v[1]), float(v[0])))


def angle_between_deg(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    概要:
        2つの2Dベクトル間の最小角度を度数で計算する。

    詳細説明:
        計算される角度の範囲は [0, 180] 度となる。内積とアークコサインを用いて計算される。

    :param v1: 1つ目の2Dベクトル（numpy.ndarray形式）。
    :param v2: 2つ目の2Dベクトル（numpy.ndarray形式）。
    :returns: 2つのベクトル間の最小角度（度数）。
    """
    u1 = unit(v1)
    u2 = unit(v2)
    c = float(np.clip(np.dot(u1, u2), -1.0, 1.0))
    return math.degrees(math.acos(c))


def is_parallel(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-9) -> bool:
    """
    概要:
        2つの2Dベクトルが平行であるかをチェックする。

    詳細説明:
        2Dベクトル `(x1, y1)` と `(x2, y2)` が平行である場合、`x1*y2 - y1*x2` (2D外積のZ成分) は0になる。
        この値が指定された許容誤差 `tol` よりも小さい絶対値であれば、平行とみなす。

    :param v1: 1つ目の2Dベクトル（numpy.ndarray形式）。
    :param v2: 2つ目の2Dベクトル（numpy.ndarray形式）。
    :param tol: 平行とみなすための許容誤差。
    :returns: 2つのベクトルが平行であればTrue、そうでなければFalse。
    """
    return abs(float(v1[0] * v2[1] - v1[1] * v2[0])) < tol


def classify_real_diagonalizable(S: np.ndarray):
    """
    概要:
        2x2行列が実数上で対角化可能であるかを判定する。

    詳細説明:
        行列Sの固有値と固有ベクトルを計算し、以下の条件で対角化可能性を判定する。
        1. 固有値が複素数を含む場合、実数上で対角化不可能。
        2. 実数固有値のみの場合でも、固有ベクトルが線形独立でない場合、実数上で対角化不可能。
        それ以外の場合は、実数上で対角化可能と判定する。

    :param S: 判定対象の2x2実数行列（numpy.ndarray形式）。
    :returns: (tuple)
        - `diagonalizable_over_R` (bool): 実数上で対角化可能であればTrue、そうでなければFalse。
        - `message` (str): 対角化可能性に関する説明メッセージ。
        - `eigenvalues` (numpy.ndarray): 計算された固有値。複素数を含む場合がある。
        - `eigenvectors` (numpy.ndarray): 計算された固有ベクトル。複素数を含む場合がある。
    """
    vals, vecs = np.linalg.eig(S)

    if np.max(np.abs(vals.imag)) > 1e-9:
        return False, "Complex eigenvalues: no real principal axes.", vals, vecs

    vals = vals.real
    vecs = vecs.real

    rank = np.linalg.matrix_rank(vecs)
    if rank < 2:
        return False, "Real eigenvalues found, but eigenvectors are not independent: not diagonalizable over R.", vals, vecs

    return True, "Diagonalizable over the real numbers.", vals, vecs


def sort_eigensystem(vals: np.ndarray, vecs: np.ndarray):
    """
    概要:
        固有値と固有ベクトルをソートし、表示用に正規化する。

    詳細説明:
        固有値を昇順にソートし、それに対応するように固有ベクトルの順序も変更する。
        さらに、各固有ベクトルの最初の非ゼロ成分が正になるように符号を調整し、表示を統一する。

    :param vals: ソート対象の固有値（numpy.ndarray形式）。
    :param vecs: ソート対象の固有ベクトル（numpy.ndarray形式、各列が1つの固有ベクトル）。
    :returns: (tuple)
        - `vals` (numpy.ndarray): ソートされ、符号調整後の固有値。
        - `vecs` (numpy.ndarray): ソートされ、符号調整後の固有ベクトル。
    """
    idx = np.argsort(vals)
    vals = vals[idx]
    vecs = vecs[:, idx]

    for j in range(2):
        v = unit(vecs[:, j])
        if abs(float(v[0])) > EPS:
            if float(v[0]) < 0:
                v = -v
        else:
            if float(v[1]) < 0:
                v = -v
        vecs[:, j] = v

    return vals, vecs


def plot_vectors(
    ax,
    basis_vectors,
    image_vectors,
    basis_labels,
    image_labels,
    title,
):
    """
    概要:
        指定された軸上に基底ベクトルとそれらの像を描画する。

    詳細説明:
        プロットの軸範囲は自動的に調整され、全てのベクトルが見えるように設定される。
        基底ベクトルは黒い実線矢印、像ベクトルは赤い実線矢印で表示される。
        軸の目盛り、ラベル、タイトルは非表示になり、グリッドと枠線も描画されない。

    :param ax: 描画対象のMatplotlib Axesオブジェクト。
    :param basis_vectors: 基底ベクトルのリスト（各要素はnumpy.ndarray形式）。
    :param image_vectors: 基底ベクトルの像のリスト（各要素はnumpy.ndarray形式）。
    :param basis_labels: 基底ベクトルのラベルのリスト（文字列）。
    :param image_labels: 像ベクトルのラベルのリスト（文字列）。
    :param title: プロットのタイトル（未使用、コード内でコメントアウトされている）。
    :returns: なし
    """
    all_vecs = basis_vectors + image_vectors
    max_norm = max(np.linalg.norm(v) for v in all_vecs)
    lim = max(1.2, 1.25 * max_norm)

    for v, label in zip(basis_vectors, basis_labels):
        ax.arrow(
            0, 0, float(v[0]), float(v[1]),
            length_includes_head=True,
            head_width=0.06 * lim,
            head_length=0.08 * lim,
            linewidth=1.8,
            color='black',
        )
        ax.text(float(v[0]) * 1.05, float(v[1]) * 1.05, label, fontsize=11)

    for v, label in zip(image_vectors, image_labels):
        ax.arrow(
            0, 0, float(v[0]), float(v[1]),
            length_includes_head=True,
            head_width=0.03 * lim,
            head_length=0.04 * lim,
            linewidth=1.5,
#            linestyle="--",
            color='red',
        )
        ax.text(float(v[0]) * 1.05, float(v[1]) * 1.05, label, fontsize=11)

    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_aspect("equal", adjustable="box")
#    ax.set_title(title)
#    ax.set_xlabel("x")
#    ax.set_ylabel("y")
#    ax.grid(True, alpha=0.3)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    ax.grid(False)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title("")

    # 枠線（spines）を消す
    for spine in ax.spines.values():
        spine.set_visible(False)

    # 軸線（axhline, axvline）を消したい場合はコメントアウト
    # ax.axhline(0, color="0.75", lw=1)
    # ax.axvline(0, color="0.75", lw=1)


def main():
    """
    概要:
        2x2実数行列の対角化を可視化し、結果を出力するメイン関数。

    詳細説明:
        コマンドライン引数から行列の要素を受け取り、以下の処理を実行する。
        1. 元の基底ベクトルとその像を計算し、角度情報と共に表示する。
        2. 行列が実数上で対角化可能かを判定し、その結果と理由を表示する。
        3. 行列が対称であるかを判定する。
        4. 実数上で対角化可能であれば、固有値と正規化された固有ベクトルを計算・表示し、
           元の基底と固有ベクトルによる対角化基底の角度を比較する。
           さらに、対角化された行列 `P^{-1} S P` を表示する。
        5. 結果を2つのサブプロット（元の基底と対角化基底）または1つのサブプロット（対角化不可能時）に描画し、
           指定されたファイル名で保存して表示する。

    :returns: なし
    """
    parser = argparse.ArgumentParser(
        description="Visualize diagonalization of a 2x2 real matrix."
    )
    parser.add_argument("S11", type=float)
    parser.add_argument("S12", type=float)
    parser.add_argument("S21", type=float)
    parser.add_argument("S22", type=float)
    parser.add_argument(
        "--save",
        type=str,
        default="diagonalize2d.png",
        help="Output figure filename (default: diagonalize2d.png)",
    )
    args = parser.parse_args()

    S = np.array([[args.S11, args.S12], [args.S21, args.S22]], dtype=float)

    a1 = np.array([1.0, 0.0])
    a2 = np.array([0.0, 1.0])
    Sa1 = S @ a1
    Sa2 = S @ a2

    print("=== Input matrix S ===")
    print(S)
    print()

    print("=== Original basis ===")
    print(f"a1 = {a1}, angle = {angle_deg(a1):.6f} deg")
    print(f"a2 = {a2}, angle = {angle_deg(a2):.6f} deg")
    print(f"angle(a1, a2) = {angle_between_deg(a1, a2):.6f} deg")
    print()

    print("=== Images of original basis ===")
    print(f"S a1 = {Sa1}, parallel to a1? {is_parallel(a1, Sa1)}")
    print(f"S a2 = {Sa2}, parallel to a2? {is_parallel(a2, Sa2)}")
    print()

    diag_ok, msg, vals, vecs = classify_real_diagonalizable(S)
    print("=== Diagonalization test over R ===")
    print(msg)
    print()

    symmetric = np.allclose(S, S.T, atol=1e-10)
    print(f"Symmetric matrix? {symmetric}")
    print()

    if diag_ok:
        vals, vecs = sort_eigensystem(vals.real, vecs.real)
        v1 = vecs[:, 0]
        v2 = vecs[:, 1]
        Sv1 = S @ v1
        Sv2 = S @ v2

        print("=== Eigenvalues / eigenvectors ===")
        print(f"lambda1 = {vals[0]:.12g}")
        print(f"v1      = {v1}")
        print(f"angle(v1) = {angle_deg(v1):.6f} deg")
        print()
        print(f"lambda2 = {vals[1]:.12g}")
        print(f"v2      = {v2}")
        print(f"angle(v2) = {angle_deg(v2):.6f} deg")
        print()

        ang_orig = angle_between_deg(a1, a2)
        ang_diag = angle_between_deg(v1, v2)

        print("=== Angle comparison ===")
        print(f"angle(original basis)      = {ang_orig:.6f} deg")
        print(f"angle(diagonalizing basis) = {ang_diag:.6f} deg")
        if symmetric:
            print("For a symmetric matrix, this should stay 90 deg (orthogonal diagonalization).")
        else:
            print("For a non-symmetric matrix, the diagonalizing basis need not be orthogonal.")
        print()

        P = np.column_stack([v1, v2])
        D = np.linalg.inv(P) @ S @ P
        print("=== P^{-1} S P ===")
        print(D)
        print()

        fig, axes = plt.subplots(1, 2, figsize=(11, 5))
        ax_left, ax_right = axes

        plot_vectors(
            ax_left,
            basis_vectors=[a1, a2],
            image_vectors=[Sa1, Sa2],
            basis_labels=["a1", "a2"],
            image_labels=["S a1", "S a2"],
            title="Original basis",
        )
        plot_vectors(
            ax_right,
            basis_vectors=[v1, v2],
            image_vectors=[Sv1, Sv2],
            basis_labels=["v1", "v2"],
            image_labels=["S v1", "S v2"],
            title="Diagonalizing basis (real eigenvectors)",
        )
    else:
        print("=== Eigenvalues (possibly complex) ===")
        print(vals)
        print()

        fig, ax = plt.subplots(1, 1, figsize=(5.5, 5))
        plot_vectors(
            ax,
            basis_vectors=[a1, a2],
            image_vectors=[Sa1, Sa2],
            basis_labels=["a1", "a2"],
            image_labels=["S a1", "S a2"],
            title="Original basis (not diagonalizable over R)",
        )

    fig.suptitle("2D matrix visualization", fontsize=14)
    fig.tight_layout()
    fig.savefig(args.save, dpi=160)
    print(f"Saved figure: {args.save}")
    plt.show()


if __name__ == "__main__":
    main()