"""
lsq_latt2.py スクリプトのフルシステムテストモジュール。
さまざまな結晶系(三斜晶、単斜晶、直方晶、正方晶、立方晶、三方晶、六方晶)
に対して、格子定数計算プログラム lsq_latt2.py が正しく動作するかを検証します。
結晶の反射データ(2θ)をシミュレートし、ノイズを付加した後、
lsq_latt2.py を実行し、その出力結果を期待される値と比較します。
関連リンク: :doc:`test_lsq_latt2_usage`
"""
from __future__ import annotations
import argparse
import math
import random
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple
[ドキュメント]
@dataclass
class CellCase:
"""
テストケースとして使用する結晶セルの情報を保持するデータクラス。
理論的な格子定数と結晶系タイプを定義します。
"""
ls: int
"""lsq_latt2.py におけるLSパラメータ(結晶系タイプコード)."""
name: str
"""結晶系の名称(例: "triclinic")."""
a: float
"""実格子定数 a (Å)."""
b: float
"""実格子定数 b (Å)."""
c: float
"""実格子定数 c (Å)."""
alpha: float
"""実格子定数 α (度)."""
beta: float
"""実格子定数 β (度)."""
gamma: float
"""実格子定数 γ (度)."""
[ドキュメント]
@dataclass
class FitResult:
"""
lsq_latt2.py の実行結果として得られた格子定数を保持するデータクラス。
解析された格子定数を含みます。
"""
a: float
"""解析された実格子定数 a (Å)."""
b: float
"""解析された実格子定数 b (Å)."""
c: float
"""解析された実格子定数 c (Å)."""
alpha: float
"""解析された実格子定数 α (度)."""
beta: float
"""解析された実格子定数 β (度)."""
gamma: float
"""解析された実格子定数 γ (度)."""
[ドキュメント]
def deg2rad(x: float) -> float:
"""
度をラジアンに変換します。
:param x: 角度(度)。
:returns: 角度(ラジアン)。
"""
return x * math.pi / 180.0
[ドキュメント]
def cell_metric(cell: CellCase) -> List[List[float]]:
"""
結晶セルの格子定数から実格子の計量テンソル (G) を計算します。
計量テンソルは、格子定数 a, b, c と角度 alpha, beta, gamma から導出されます。
このテンソルは逆格子空間の計算に用いられます。
:param cell: 格子定数情報を含む CellCase オブジェクト。
:returns: 3x3 の計量テンソル行列。
"""
a, b, c = cell.a, cell.b, cell.c
ca = math.cos(deg2rad(cell.alpha))
cb = math.cos(deg2rad(cell.beta))
cg = math.cos(deg2rad(cell.gamma))
return [
[a * a, a * b * cg, a * c * cb],
[a * b * cg, b * b, b * c * ca],
[a * c * cb, b * c * ca, c * c],
]
[ドキュメント]
def inv3(m: List[List[float]]) -> List[List[float]]:
"""
3x3 行列の逆行列を計算します。
サラスの公式を用いて行列式を計算し、各要素の余因子を求めて逆行列を構築します。
行列式がほぼゼロの場合、特異行列としてエラーを発生させます。
:param m: 3x3 の浮動小数点数行列。
:returns: 入力行列の逆行列。
:raises ValueError: 行列が特異である(行列式がほぼゼロ)場合。
"""
a = m
det = (
a[0][0] * (a[1][1] * a[2][2] - a[1][2] * a[2][1])
- a[0][1] * (a[1][0] * a[2][2] - a[1][2] * a[2][0])
+ a[0][2] * (a[1][0] * a[2][1] - a[1][1] * a[2][0])
)
if abs(det) < 1e-14:
raise ValueError("Singular metric tensor.")
inv = [[0.0] * 3 for _ in range(3)]
inv[0][0] = (a[1][1] * a[2][2] - a[1][2] * a[2][1]) / det
inv[0][1] = (a[0][2] * a[2][1] - a[0][1] * a[2][2]) / det
inv[0][2] = (a[0][1] * a[1][2] - a[0][2] * a[1][1]) / det
inv[1][0] = (a[1][2] * a[2][0] - a[1][0] * a[2][2]) / det
inv[1][1] = (a[0][0] * a[2][2] - a[0][2] * a[2][0]) / det
inv[1][2] = (a[0][2] * a[1][0] - a[0][0] * a[1][2]) / det
inv[2][0] = (a[1][0] * a[2][1] - a[1][1] * a[2][0]) / det
inv[2][1] = (a[0][1] * a[2][0] - a[0][0] * a[2][1]) / det
inv[2][2] = (a[0][0] * a[1][1] - a[0][1] * a[1][0]) / det
return inv
[ドキュメント]
def d_spacing(cell: CellCase, h: int, k: int, l: int) -> float:
"""
指定された結晶面 (hkl) の面間隔 (d値) を計算します。
実格子の計量テンソルから逆格子の計量テンソルを導出し、
それを用いて面間隔の逆数の二乗 (1/d^2) を計算します。
:param cell: 格子定数情報を含む CellCase オブジェクト。
:param h: ミラー指数 h。
:param k: ミラー指数 k。
:param l: ミラー指数 l。
:returns: 面間隔 d (Å)。
:raises ValueError: 計算された 1/d^2 が非正の場合。
"""
g = cell_metric(cell)
g_star = inv3(g)
v = [h, k, l]
dinv2 = 0.0
for i in range(3):
for j in range(3):
dinv2 += v[i] * g_star[i][j] * v[j]
if dinv2 <= 0.0:
raise ValueError(f"Non-positive 1/d^2 for hkl=({h},{k},{l})")
return 1.0 / math.sqrt(dinv2)
[ドキュメント]
def two_theta_from_d(d: float, wavelength: float) -> float | None:
"""
面間隔 (d) とX線波長から回折角 2θ を計算します。
ブラッグの法則 (nλ = 2d sinθ) を用いて計算します。
arcsin の引数が有効な範囲 (0 < x < 1) にない場合、None を返します。
:param d: 面間隔 (Å)。
:param wavelength: X線波長 (Å)。
:returns: 回折角 2θ(度)、または None(計算不能な場合)。
"""
x = wavelength / (2.0 * d)
if not (0.0 < x < 1.0):
return None
theta = math.asin(x)
return 2.0 * math.degrees(theta)
[ドキュメント]
def generate_reflections(
cell: CellCase,
wavelength: float,
n_lines: int = 15,
noise_sigma_deg: float = 0.02,
hmax: int = 6,
twotheta_min: float = 10.0,
twotheta_max: float = 140.0,
seed: int = 1234,
) -> List[Tuple[int, int, int, float, float]]:
"""
指定された結晶セルに対してX線回折ピークをシミュレートし、ノイズを付加します。
ミラー指数 (hkl) の範囲内で可能な反射を全て計算し、
2θ角でソートした後、指定された数のピークを選択します。
選択されたピークの 2θ 値にはガウスノイズが追加されます。
同一の 2θ が発生するピークは重複を除去します(小数点以下3桁で丸めて判定)。
:param cell: 基となる格子定数情報を含む CellCase オブジェクト。
:param wavelength: 使用するX線波長 (Å)。
:param n_lines: 生成する反射ピークの数。
:param noise_sigma_deg: 2θ 値に加えるノイズの標準偏差(度)。
:param hmax: ミラー指数 (h, k, l) の最大絶対値。
:param twotheta_min: 考慮する 2θ の最小値(度)。
:param twotheta_max: 考慮する 2θ の最大値(度)。
:param seed: 乱数生成器のシード値。
:returns: (h, k, l, 理論2θ, ノイズ付き2θ) のタプルのリスト。
:raises RuntimeError: 指定された範囲内で十分な数のピークが生成できなかった場合。
"""
rng = random.Random(seed)
peaks: List[Tuple[int, int, int, float]] = []
seen = set()
for h in range(0, hmax + 1):
for k in range(0, hmax + 1):
for l in range(0, hmax + 1):
if h == 0 and k == 0 and l == 0:
continue
d = d_spacing(cell, h, k, l)
tt = two_theta_from_d(d, wavelength)
if tt is None:
continue
if not (twotheta_min <= tt <= twotheta_max):
continue
key = round(tt, 3)
if key in seen:
continue
seen.add(key)
peaks.append((h, k, l, tt))
peaks.sort(key=lambda x: x[3])
if len(peaks) < n_lines:
raise RuntimeError(f"{cell.name}: not enough peaks ({len(peaks)})")
chosen = peaks[:n_lines]
noisy: List[Tuple[int, int, int, float, float]] = []
for h, k, l, tt in chosen:
tt_noisy = tt + rng.gauss(0.0, noise_sigma_deg)
noisy.append((h, k, l, tt, tt_noisy))
return noisy
[ドキュメント]
def build_input_text(
title: str,
ls: int,
wavelength: float,
peaks: List[Tuple[int, int, int, float, float]],
ip: int = 1,
) -> str:
"""
lsq_latt2.py プログラムへの入力ファイルコンテンツを生成します。
タイトル、LSパラメータ、波長、およびミラー指数とノイズ付き2θ値のリストを
指定されたフォーマットで文字列として構築します。
:param title: 入力ファイルのタイトル行。
:param ls: 結晶系タイプコード (LSパラメータ)。
:param wavelength: X線波長 (Å)。
:param peaks: (h, k, l, 理論2θ, ノイズ付き2θ) のタプルのリスト。
:param ip: 解析オプション (通常 1)。
:returns: lsq_latt2.py が読み込むためのフォーマットされた入力文字列。
"""
lines = [
title,
f"{ls} 0 0 0 0 0 2 4 {ip}",
f"{wavelength:.6f} 0.0",
]
for h, k, l, _tt, tt_noisy in peaks:
lines.append(f"{h:d} {k:d} {l:d} {tt_noisy:.6f}")
lines.append("1000 0 0 0")
lines.append("0")
return "\n".join(lines) + "\n"
DIRECT_CELL_BLOCK_OLD_RE = re.compile(
r"Direct cell constant\s+"
r".*?a.*?b.*?c\s+"
r"([0-9Ee+\-.]+)\(\s*[0-9Ee+\-.]+\)\s+"
r"([0-9Ee+\-.]+)\(\s*[0-9Ee+\-.]+\)\s+"
r"([0-9Ee+\-.]+)\(\s*[0-9Ee+\-.]+\)\s+"
r".*?alpha.*?beta.*?gamma\s+"
r"([0-9Ee+\-.]+)\(\s*[0-9Ee+\-.]+\)\s+"
r"([0-9Ee+\-.]+)\(\s*[0-9Ee+\-.]+\)\s+"
r"([0-9Ee+\-.]+)\(\s*[0-9Ee+\-.]+\)",
re.S,
)
DIRECT_SECTION_NEW_RE = re.compile(
r"\[Direct lattice constants / 実格子定数\](.*?)(?:\n\[|\Z)",
re.S,
)
VALUE_SIGMA_LINE_RE = re.compile(
r"^\s*([A-Za-z*()]+)\s*=\s*([0-9Ee+\-.]+)\s*±\s*([0-9Ee+\-.]+)",
re.M,
)
[ドキュメント]
def parse_output_file(outfile: Path) -> FitResult:
"""
lsq_latt2.py の出力ファイルを解析し、フィットされた格子定数を抽出します。
新しい出力フォーマット("[Direct lattice constants / 実格子定数]" ブロック)と、
古い出力フォーマット("Direct cell constant" ブロック)の両方に対応しています。
適切なブロックが見つからない場合、エラーを発生させます。
:param outfile: lsq_latt2.py の出力ファイルへのパス。
:returns: 解析された格子定数を含む FitResult オブジェクト。
:raises RuntimeError: 出力ファイルから格子定数ブロックを解析できなかった場合。
"""
text = outfile.read_text(encoding="utf-8", errors="replace")
# --- new format ---
msec = DIRECT_SECTION_NEW_RE.search(text)
if msec:
block = msec.group(1)
found = {}
for name, value, _sigma in VALUE_SIGMA_LINE_RE.findall(block):
key = name.strip().lower()
found[key] = float(value)
required = ["a", "b", "c", "alpha", "beta", "gamma"]
if all(k in found for k in required):
return FitResult(
found["a"],
found["b"],
found["c"],
found["alpha"],
found["beta"],
found["gamma"],
)
# --- old format (backward compatible) ---
mold = DIRECT_CELL_BLOCK_OLD_RE.search(text)
if mold:
vals = [float(x) for x in mold.groups()]
return FitResult(*vals)
raise RuntimeError(f"Could not parse direct lattice constants block: {outfile}")
[ドキュメント]
def all_cases() -> List[CellCase]:
"""
テスト対象の全ての結晶系とそれに対応する理論的な格子定数リストを返します。
各 CellCase オブジェクトは、lsq_latt2.py の LS パラメータと
対応する結晶系の名前、そして格子定数を含みます。
:returns: 事前に定義された CellCase オブジェクトのリスト。
"""
return [
CellCase(1, "triclinic", 4.83, 5.27, 6.11, 72.0, 83.0, 76.0),
CellCase(2, "monoclinic_b", 5.12, 6.35, 7.21, 90.0, 104.0, 90.0),
CellCase(3, "monoclinic_c", 5.74, 6.22, 7.48, 90.0, 90.0, 111.0),
CellCase(4, "orthorhombic", 4.91, 5.36, 6.02, 90.0, 90.0, 90.0),
CellCase(5, "tetragonal", 4.14, 4.14, 6.57, 90.0, 90.0, 90.0),
CellCase(6, "cubic", 4.07, 4.07, 4.07, 90.0, 90.0, 90.0),
CellCase(7, "trigonal", 5.21, 5.21, 5.21, 75.0, 75.0, 75.0),
CellCase(8, "hexagonal", 3.25, 3.25, 5.18, 90.0, 90.0, 120.0),
]
[ドキュメント]
def cell_to_dict(cell: CellCase | FitResult) -> dict:
"""
CellCase または FitResult オブジェクトの格子定数属性を辞書に変換します。
:param cell: 変換する CellCase または FitResult オブジェクト。
:returns: 格子定数名 (a, b, c, alpha, beta, gamma) をキーとする辞書。
"""
return {
"a": cell.a,
"b": cell.b,
"c": cell.c,
"alpha": cell.alpha,
"beta": cell.beta,
"gamma": cell.gamma,
}
[ドキュメント]
def judge_diffs(expected: CellCase, got: FitResult) -> tuple[bool, dict, dict]:
"""
期待される格子定数とフィット結果の間の差を評価します。
絶対誤差と相対誤差を計算し、定義された許容範囲内にあるかどうかを判定します。
- a, b, c の絶対誤差は 0.05 Å 未満
- alpha, beta, gamma の絶対誤差は 1.0 度 未満
:param expected: 期待される格子定数を含む CellCase オブジェクト。
:param got: lsq_latt2.py から得られた FitResult オブジェクト。
:returns: (テスト結果がOKかどうかの真偽値, 絶対誤差の辞書, 相対誤差の辞書)。
"""
exp = cell_to_dict(expected)
out = cell_to_dict(got)
abs_err = {k: abs(out[k] - exp[k]) for k in exp}
rel_err = {}
for k in exp:
denom = abs(exp[k])
rel_err[k] = abs_err[k] / denom if denom > 1e-12 else 0.0
ok = (
abs_err["a"] < 0.05 and
abs_err["b"] < 0.05 and
abs_err["c"] < 0.05 and
abs_err["alpha"] < 1.0 and
abs_err["beta"] < 1.0 and
abs_err["gamma"] < 1.0
)
return ok, abs_err, rel_err
[ドキュメント]
def infer_issue(expected: CellCase, got: FitResult, abs_err: dict, rel_err: dict) -> str:
"""
フィット結果の誤差から、考えられる問題点を推測します。
特に三方晶系 (trigonal) の場合、長さパラメータが sqrt(3) 倍になっている可能性をチェックします。
それ以外の場合は、最も誤差が大きいパラメータ(長さまたは角度)を特定します。
:param expected: 期待される格子定数を含む CellCase オブジェクト。
:param got: lsq_latt2.py から得られた FitResult オブジェクト。
:param abs_err: 各パラメータの絶対誤差の辞書。
:param rel_err: 各パラメータの相対誤差の辞書。
:returns: 問題点に関する説明文字列。
"""
if expected.name == "trigonal":
ratio_a = got.a / expected.a if abs(expected.a) > 1e-12 else float("nan")
if abs(ratio_a - math.sqrt(3.0)) < 0.02 and abs_err["alpha"] < 0.1:
return (
"length parameters only: output a,b,c are about sqrt(3) times expected while angles are correct; "
"this strongly suggests the LS=7 post-processing interprets P0 as 3*a*^2 and divides by 3 once too many or too few."
)
worst = max(abs_err, key=abs_err.get)
if worst in {"a", "b", "c"}:
return f"length mismatch is dominant, especially {worst}"
return f"angle mismatch is dominant, especially {worst}"
[ドキュメント]
def run_one_case(
lsq_script: Path,
workdir: Path,
cell: CellCase,
wavelength: float,
n_lines: int,
noise_sigma_deg: float,
seed: int,
) -> tuple[bool, FitResult | None, dict | None, dict | None, str]:
"""
単一の結晶系テストケースを実行します。
一時ディレクトリ内にテストケース固有のサブディレクトリを作成し、
反射データを生成して lsq_latt2.py への入力ファイルを構築します。
lsq_latt2.py を subprocess として実行し、その出力ファイルを解析して、
期待される結果と比較します。
:param lsq_script: テスト対象の lsq_latt2.py スクリプトへのパス。
:param workdir: 一時作業ディレクトリのルートパス。
:param cell: テスト対象の CellCase オブジェクト。
:param wavelength: シミュレーションに使用するX線波長 (Å)。
:param n_lines: 生成する反射ピークの数。
:param noise_sigma_deg: ピークに加えるノイズの標準偏差(度)。
:param seed: 反射ピーク生成に使用する乱数シード。
:returns:
(テストが成功したかの真偽値,
フィット結果の FitResult オブジェクトまたは None,
絶対誤差の辞書または None,
相対誤差の辞書または None,
エラーまたは問題点の説明文字列)。
"""
case_dir = workdir / f"LS{cell.ls}_{cell.name}"
case_dir.mkdir(parents=True, exist_ok=True)
peaks = generate_reflections(
cell=cell,
wavelength=wavelength,
n_lines=n_lines,
noise_sigma_deg=noise_sigma_deg,
seed=seed + cell.ls,
)
infile = case_dir / "input.BZ.K"
outfile = case_dir / "output.txt"
infile.write_text(
build_input_text(
title=f"TEST {cell.name}",
ls=cell.ls,
wavelength=wavelength,
peaks=peaks,
ip=1,
),
encoding="utf-8",
)
cmd = [sys.executable, str(lsq_script), str(infile), str(outfile), "--silent", "--no-show"]
proc = subprocess.run(
cmd,
input="",
text=True,
capture_output=True,
cwd=case_dir,
)
if proc.returncode != 0:
msg = f"execution failed: rc={proc.returncode}\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
return False, None, None, None, msg
if not outfile.exists():
msg = f"output file not created\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
return False, None, None, None, msg
got = parse_output_file(outfile)
ok, abs_err, rel_err = judge_diffs(cell, got)
issue = infer_issue(cell, got, abs_err, rel_err)
return ok, got, abs_err, rel_err, issue
[ドキュメント]
def main() -> int:
"""
lsq_latt2.py スクリプトのフルシステムテストを実行するメイン関数。
コマンドライン引数を解析し、定義済みの全ての結晶系テストケースをループして実行します。
各テストケースについて、反射データを生成し、lsq_latt2.py を実行し、結果を解析・評価します。
最後に、全テストケースのサマリーを表示し、失敗したケースを詳細に報告します。
一時作業ディレクトリの管理(作成と削除)も行います。
:returns: 全てのテストが成功した場合は 0、失敗したテストがある場合は 1、
スクリプトが見つからない場合は 2 を返します。
"""
parser = argparse.ArgumentParser(description="lsq_latt2.py full-system test with verbose diagnostics")
parser.add_argument("lsq_script", nargs="?", default="lsq_latt2.py", help="path to lsq_latt2.py")
parser.add_argument("--wavelength", type=float, default=1.5406)
parser.add_argument("--n-lines", type=int, default=15)
parser.add_argument("--noise-sigma", type=float, default=0.02)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--keep-workdir", action="store_true")
args = parser.parse_args()
src = Path(args.lsq_script).resolve()
if not src.exists():
print(f"ERROR: script not found: {src}", file=sys.stderr)
return 2
if args.keep_workdir:
tmp_root = Path(tempfile.mkdtemp(prefix="test_lsq_latt2_"))
cleanup = False
else:
tmpctx = tempfile.TemporaryDirectory(prefix="test_lsq_latt2_")
tmp_root = Path(tmpctx.name)
cleanup = True
try:
work_script = tmp_root / "lsq_latt2_under_test.py"
work_script.write_text(src.read_text(encoding="utf-8"), encoding="utf-8")
results = []
print(f"Work directory : {tmp_root}")
print(f"Target script : {src}")
print()
for cell in all_cases():
ok, got, abs_err, rel_err, issue = run_one_case(
lsq_script=work_script,
workdir=tmp_root,
cell=cell,
wavelength=args.wavelength,
n_lines=args.n_lines,
noise_sigma_deg=args.noise_sigma,
seed=args.seed,
)
results.append((cell, ok, got, abs_err, rel_err, issue))
if got is None:
print(f"[FAIL] LS={cell.ls} {cell.name}")
print(" issue : " + issue)
else:
print(f"[{'PASS' if ok else 'FAIL'}] LS={cell.ls} {cell.name}")
print(format_case_report(cell, got, abs_err, rel_err, issue))
print()
n_pass = sum(1 for _c, ok, _g, _a, _r, _i in results if ok)
n_total = len(results)
print(f"Summary: {n_pass}/{n_total} passed")
failed = [(c, g, a, r, i) for c, ok, g, a, r, i in results if not ok]
if failed:
print("\nFailed cases:")
for cell, got, abs_err, rel_err, issue in failed:
print(f"- LS={cell.ls} {cell.name}")
if got is None:
print(f" issue : {issue}")
else:
print(format_case_report(cell, got, abs_err, rel_err, issue))
print()
return 1
if args.keep_workdir:
print(f"Temporary files kept at: {tmp_root}")
return 0
finally:
if cleanup:
tmpctx.cleanup()
if __name__ == "__main__":
raise SystemExit(main())