import numpy as np
import math
import sys
import os # For os.system('cls') / 'clear' to clear console

# --- Constants ---
MMAX_DEFINE = 300

# math.pi provides higher precision than hardcoded literal
PI_CONST = math.pi

# --- Global Arrays (sized for 1-based indexing, so size + 1) ---
x = np.zeros(1025, dtype=np.float64) # Max ND 1024, so 1025 elements (0-1024)
y = np.zeros(1025, dtype=np.float64) # Max ND 1024, so 1025 elements
fpe = np.zeros(MMAX_DEFINE + 1, dtype=np.float64)
r = np.zeros(MMAX_DEFINE + 1, dtype=np.float64)
rr = np.zeros(MMAX_DEFINE + 1, dtype=np.float64)
rfpe = np.zeros(MMAX_DEFINE + 1, dtype=np.float64)

# Global variables mirroring C++ structure
rm = 0.0
wnmin = 0.0
wnmax = 0.0
sum1 = 0.0
sum2 = 0.0
sum_val = 0.0 # Renamed 'sum' to 'sum_val' to avoid conflict with built-in sum()
sumn = 0.0
sumd = 0.0
wnint = 0.0

isw_global = 0
nd = 0
mmax_global = 0
ns = 0
minm = 0

dt = 0.0
pmm = 0.0
pm = 0.0
fpemin = 0.0
av = 0.0
ci_const = 2 * PI_CONST

# --- Function Definitions ---

# Helper to clear screen (similar to CLS in BASIC)
def clear_screen():
    os.system('cls' if os.name == 'nt' else 'clear')

# ## Function: gs1880
def gs1880():
    global minm, pmm, rfpe, r, mmax_global
    minm = mmax_global
    pmm = pm
    for i in range(1, mmax_global + 1):
        rfpe[i] = r[i]

# ## Function: gs1800
def gs1800(m):
    global fpe, nd, pm, fpemin, minm, pmm, rfpe, r
    # Prevent division by zero: (nd - m - 1)
    if (nd - m - 1) == 0:
        fpe[m] = 0.0 # Or float('inf')
    else:
        fpe[m] = float(nd + m + 1) / float(nd - m - 1) * pm
    
    if fpe[m] > fpemin: return
    fpemin = fpe[m]
    minm = m
    pmm = pm
    for i in range(1, m + 1):
        rfpe[i] = r[i]

# ## Function: gs1340 (Data Input)
def gs1340():
    global nd, x
    while True:
        nm = input(" *** データファイルメイ   =")
        try:
            with open(nm, 'r') as fp:
                for i in range(1, nd + 1):
                    line = fp.readline()
                    if not line:
                        print(f"警告: ファイルの終わりに到達しました。{i}行目が読み取れませんでした。")
                        x[i] = 0.0 # Assign default if line is missing
                        continue
                    try:
                        # Assuming the data file contains three values per line,
                        # like the BASIC code's INPUT#1,X(I),BBB,CCC
                        parts = line.strip().split()
                        if len(parts) >= 1:
                            x[i] = float(parts[0])
                        else:
                            print(f"警告: ファイルの行 {i} が期待される形式ではありません。")
                            x[i] = 0.0
                    except ValueError:
                        print(f"警告: ファイルの行 {i} のデータが数値ではありません。スキップします。")
                        x[i] = 0.0
                    print(f"I={i}  X={x[i]:g}") # %g for flexible formatting
            break # Exit loop if file processed without critical error
        except FileNotFoundError:
            print(f"エラー: ファイル '{nm}' が見つかりません。")
        except Exception as e:
            print(f"ファイル読み込み中に予期せぬエラーが発生しました: {e}")

# ## Function: gs1420 (AR Model Estimation - Burg's Method)
def gs1420():
    global sum_val, nd, x, y, av, pm, fpemin, fpe, mmax_global, sumn, sumd, rm, r, rr, isw_global, pmm

    # Calculate average and subtract it from data
    sum_val = 0.0
    for i in range(1, nd + 1): # Assuming x[1] to x[nd] contain data
        sum_val += x[i]
    av = sum_val / float(nd)

    # Detrend data and calculate sum of squares
    detrended_sum_sq = 0.0
    for i in range(1, nd + 1):
        z = x[i] - av
        x[i] = z # x[i] now holds detrended data
        if i > 0: # Ensure y[0] is accessible
            y[i-1] = z # y[0] to y[nd-1] will hold detrended data
        detrended_sum_sq += z * z # Recalculate sum of squares from detrended data

    pm = detrended_sum_sq / float(nd)

    if (nd - 1) == 0: # Avoid division by zero
        fpemin = 0.0
    else:
        fpemin = float(nd + 1) / float(nd - 1) * pm
    fpe[0] = fpemin

    for m in range(1, mmax_global + 1):
        sumn = 0.0
        sumd = 0.0
        for i in range(1, nd - m + 1): # Loop up to nd-m
            sumn += x[i] * y[i]
            sumd += x[i] * x[i] + y[i] * y[i]
        
        if sumd == 0.0: # Prevent division by zero
            rm = 0.0
        else:
            rm = -2.0 * sumn / sumd
        r[m] = rm
        pm *= (1.0 - rm * rm)

        if m > 1:
            for i in range(1, m): # Loop up to m-1
                r[i] = rr[i] + rm * rr[m - i]
        
        for i in range(1, m + 1):
            rr[i] = r[i]
        
        # Update forward and backward prediction errors (Burg's algorithm core)
        for i in range(1, nd - m): # Loop up to nd-m-1
            temp_xi = x[i]
            x[i] = x[i] + rm * y[i]
            y[i] = y[i+1] + rm * temp_xi # Corrected per C++ comments (using old x[i])

        if isw_global == 0: gs1800(m)
    
    if isw_global == 1: gs1880()

# ## Function: gs1960 (Power Spectrum Estimation)
def gs1960():
    global x, f, spec, pmm, dt, ci_const, wnmin, wnmax, wnint, minm, sum1, sum2

    # Using a static-like variable for ci_initialized (function attribute)
    if not hasattr(gs1960, 'ci_initialized'):
        gs1960.ci_initialized = False
    
    if not gs1960.ci_initialized:
        ci_const *= dt # Multiply ci_const by dt only once
        gs1960.ci_initialized = True
    
    i = 0
    # Iterate from wnmin to wnmax with step wnint
    # Using a small epsilon for float comparison to avoid precision issues
    # Added max(0.0, ...) for sqrt calls in case of negative results from float arithmetic
    for freq_val in np.arange(wnmin, wnmax + wnint / 2.0, wnint): # Renamed 'f' to 'freq_val'
        sum1 = 1.0
        sum2 = 0.0
        for j in range(1, minm + 1):
            sum1 += rfpe[j] * math.cos(float(j) * ci_const * freq_val)
            sum2 += rfpe[j] * math.sin(float(j) * ci_const * freq_val)
        
        denominator = (sum1 * sum1 + sum2 * sum2)
        if denominator == 0.0: # Prevent division by zero
            x[i] = float('inf') # Assign Python infinity
        else:
            x[i] = pmm * dt / denominator
        
        i += 1
        if i >= 1025: # Prevent out-of-bounds access for x array (0 to 1024)
            sys.stderr.write("警告: スペクトル点数が配列サイズを超過しました。処理を中断します。\n")
            break

# ## Function: gs2110 (Data Output)
def gs2110():
    global x, ns
    while True:
        nm = input("   Out Data File Name: ")
        try:
            with open(nm, 'w') as fp:
                for i in range(ns): # Assuming x[0] to x[ns-1] for spectrum data
                    fp.write(f"{x[i]:g}\n")
            print(f"データがファイル '{nm}' に保存されました。\n")
            break
        except IOError as e:
            sys.stderr.write(f"ファイル '{nm}' への書き込みに失敗しました: {e}\n")
            # Loop to allow user to try again

# ## Main Routine: mem
def mem():
    global nd, dt, isw_global, mmax_global, ns, wnmin, wnmax, wnint, pm
    
    clear_screen()
    print("      ***** MEM *****\n")

    while True:
        try:
            nd = int(input(" *** サンプル テンスウ    ="))
            if nd <= 0:
                print("サンプル点数は正の整数である必要があります。")
            else:
                break
        except ValueError:
            print("無効な入力です。数値を入力してください。")
    print()

    while True:
        try:
            dt = float(input(" *** サンプル カンカク    ="))
            if dt <= 0:
                print("サンプル間隔は正の数である必要があります。")
            else:
                break
        except ValueError:
            print("無効な入力です。数値を入力してください。")
    print()

    while True:
        an_char = input("モデル ジスウ ヲ アタエマスカ (Y/N)? ").upper()
        if an_char == "Y":
            isw_global = 1
            p_str = " *** モデル ジスウ        ="
            break
        elif an_char == "N":
            isw_global = 0
            p_str = " *** サイダイ モデル ジスウ   ="
            break
        else:
            print("YまたはNを入力してください。")
    print()

    while True:
        try:
            mmax_global = int(input(p_str))
            if mmax_global <= 0:
                print("モデル次数は正の整数である必要があります。")
            else:
                break
        except ValueError:
            print("無効な入力です。数値を入力してください。")
    print()

    gs1340() # Data input
    
    # Reset ci_initialized flag for gs1960 if mem() is called multiple times
    if hasattr(gs1960, 'ci_initialized'):
        gs1960.ci_initialized = False

    ns = nd
    # wnmin and wnmax are reset to nd as per original BASIC line 1190.
    # Note: If nd can be very large, this might make the initial spectrum range very wide.
    wnmin = 0.0
    wnmax = float(nd) 
    
    # Original BASIC had a graph display here, not implemented in this text-based version
    # print("Data Graph Display (Not implemented in text mode)")

    gs1420() # AR-model estimation (Burg's method)

    # Main loop for spectrum calculation and re-calculation
    while True:
        print("\n--- Spectrum Calculation Parameters ---")
        while True:
            try:
                wnmin = float(input("   サイショウ シュウハスウ ="))
                break
            except ValueError:
                print("無効な入力です。数値を入力してください。")

        while True:
            try:
                wnmax = float(input("   サイダイ シュウハスウ ="))
                if wnmax <= wnmin:
                    print("最大周波数は最小周波数より大きい必要があります。")
                else:
                    break
            except ValueError:
                print("無効な入力です。数値を入力してください。")
        print(f"wnmin,max:{wnmin:g} , {wnmax:g}")

        while True:
            try:
                ns = int(input("   スペクトル テンスウ  ="))
                if ns <= 0:
                    print("スペクトル点数は正の整数である必要があります。")
                else:
                    break
            except ValueError:
                print("無効な入力です。数値を入力してください。")
        
        if ns > 1025: # Check against x array size
            sys.stderr.write(f"警告: スペクトル点数 ({ns}) が許容範囲 (1025) を超えています。処理を中断します。\n")
            return 0 # Indicate error

        wnint = (wnmax - wnmin) / float(ns)
        print(f"ns:{ns}    wnint:{wnint:g}")
        
        # Reset ci_initialized for gs1960 for new spectrum calculation
        gs1960.ci_initialized = False 
        gs1960() # Spectrum calculation

        # Output graph (file output only for now)
        gs2110()

        while True:
            an_char = input("   モウイチド シマスカ (Y/N)? ").upper()
            if an_char == "Y":
                break # Loop back to spectrum parameter input
            elif an_char == "N":
                clear_screen()
                return 1 # Exit mem() indicating success
            else:
                print("YまたはNを入力してください。")
    
    # This part should be unreachable if the loops work as expected
    clear_screen()
    return 1 # Should indicate successful execution if it exits main loop

# ## Main Program Entry Point
def main_entry(): # Renamed to avoid direct conflict with 'main' from previous contexts
    if mem() == 1:
        print("プログラムが正常に終了しました。\n")
        sys.exit(0)
    else:
        print("プログラムがエラーで終了しました。\n")
        sys.exit(1)

if __name__ == "__main__":
    main_entry()