# -*- coding: utf-8 -*-
"""
Created on Sat Nov 15 14:30:04 2025

@author: ersol
"""

# -*- coding: utf-8 -*-
"""
stft_h2_lr_combined_full_with_fund_smooth.py

Compute 2nd-harmonic distortion vs time and plot H2 + fundamental magnitude
in a single figure (two stacked subplots sharing the time axis), with minor grid lines.

Fundamental magnitude is now smoothed using SMOOTH_MS (same as H2 smoothing).

Paste into Spyder and run. No argparse required.

Requirements:
    pip install numpy scipy matplotlib soundfile
"""
# ---------------------------
# User configuration
# ---------------------------
# Path to the audio file: absolute, or relative to the current working directory.
FILENAME = "2025 gyro oc9 tacet no mat TACET tracking.wav"
# Fixed fundamental in Hz (e.g. 1000.0). When set, the bins nearest F0 and 2*F0
# are read in every frame. None = auto-detect the fundamental per frame.
F0_OVERRIDE = None
# Lower edge (Hz) of the band searched for the fundamental when auto-detecting.
F0_SEARCH_MIN = 200
# Upper edge (Hz) of the band searched for the fundamental when auto-detecting.
F0_SEARCH_MAX = 400
# STFT segment length and FFT size in samples (larger = finer frequency resolution).
NFFT = 65536
# STFT hop in samples (smaller = finer time resolution); overlap is NFFT - HOP.
HOP = 2048
# Window name handed to scipy.signal.get_window.
WINDOW = "hann"
# Half-width (Hz) of the band around 2*f0 searched for the H2 peak; typical
# values are 3-100. Only used in auto-detect mode (F0_OVERRIDE is None).
PEAK_TOL_HZ = 100
# Frames whose fundamental magnitude (absolute dB) is not above this value get
# a NaN H2 ratio.
MIN_FUND_MAG_DB = -120.0
# Moving-average length (ms) applied to the plotted H2 and fundamental curves;
# 0 = no smoothing.
SMOOTH_MS = 900
# When True, also draw a separate figure with the per-frame fundamental and
# 2nd-harmonic magnitudes (relative to the frame peak) for each channel.
PLOT_FRAME_MAG = False
# Zero-based channel indices to analyze/plot ([0]=left, [1]=right); indices
# beyond the file's channel count are skipped.
CHANNELS_TO_PLOT = [0, 1]

# User-configurable plot title. Set to a non-blank string to force the title,
# or None to auto-generate one from the file name and analyzed channels.
PLOT_TITLE = "2025 gyro oc9 tacet no mat TACET tracking.wav"
# ---------------------------

import os
import numpy as np
import soundfile as sf
import scipy.signal as sig
import matplotlib.pyplot as plt

# Magnitude floor that keeps log10 away from zero (1e-20 corresponds to -400 dB).
EPS = 1e-20


def dbr(x):
    """Return 20*log10(|x|) in dB, with |x| floored at EPS.

    Args:
        x: Scalar or array, real or complex.

    Returns:
        The level in dB, same shape as ``x``; exact zeros map to -400 dB.
    """
    magnitude = np.abs(x)
    floored = np.maximum(magnitude, EPS)
    return 20.0 * np.log10(floored)

def ensure_2d(x):
    """Return ``x`` as an array, promoting a 1-D input to a single column.

    Args:
        x: Array-like audio data.

    Returns:
        An ndarray. A 1-D input of length n becomes shape (n, 1); any other
        dimensionality is returned as converted by ``np.asarray``.
    """
    arr = np.asarray(x)
    return arr[:, np.newaxis] if arr.ndim == 1 else arr

def load_audio(path):
    """Read an audio file as float32 samples and print a one-line summary.

    Args:
        path: Path of the audio file to read.

    Returns:
        Tuple ``(data, sr, info)``: the samples passed through ``ensure_2d``,
        the sample rate returned by ``sf.read``, and the ``sf.info`` object.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    samples, rate = sf.read(path, dtype='float32')
    meta = sf.info(path)
    # A 1-D result from sf.read is reported as a single channel.
    channel_count = samples.shape[1] if samples.ndim > 1 else 1
    print(f"Loaded: {path}  sr={rate} channels={channel_count} frames={len(samples)} subtype={meta.subtype}")
    return ensure_2d(samples), rate, meta

def compute_stft_matrix(x, sr, nfft=NFFT, hop=HOP, window=WINDOW):
    """Compute one STFT matrix per channel of ``x``.

    Args:
        x: Audio samples; a 1-D input is treated as a single channel.
        sr: Sample rate in Hz.
        nfft: Segment length in samples (also the FFT size).
        hop: Hop between successive segments in samples.
        window: Window specification accepted by ``scipy.signal.get_window``.

    Returns:
        Tuple ``(freqs, times, stft_list)``: the rFFT bin frequencies for
        ``nfft`` and ``sr``, the frame times reported for the first channel
        (``None`` when there are no channels), and a list holding the complex
        STFT matrix of each channel.
    """
    channels = ensure_2d(x)
    taper = sig.get_window(window, nfft, fftbins=True)
    freqs = np.fft.rfftfreq(nfft, 1.0/sr)
    overlap = nfft - hop

    times = None
    stft_list = []
    for ch in range(channels.shape[1]):
        # No boundary extension or zero padding: only whole segments are analysed.
        _, frame_times, spectrum = sig.stft(
            channels[:, ch], fs=sr, window=taper, nperseg=nfft,
            noverlap=overlap, boundary=None, padded=False)
        stft_list.append(spectrum)
        if times is None:
            times = frame_times
    return freqs, times, stft_list

def find_bin_index(freqs, target_freq):
    """Return the index of the bin in ``freqs`` closest to ``target_freq``.

    Args:
        freqs: 1-D array of bin frequencies in ascending order.
        target_freq: Frequency to locate, or ``None``.

    Returns:
        The integer bin index, or ``None`` when ``target_freq`` is ``None`` or
        lies outside ``[freqs[0], freqs[-1]]`` (with a 1e-12 tolerance).
    """
    if target_freq is None:
        return None

    lowest_allowed = freqs[0] - 1e-12
    highest_allowed = freqs[-1] + 1e-12
    if target_freq < lowest_allowed or target_freq > highest_allowed:
        return None

    distances = np.abs(freqs - target_freq)
    return int(distances.argmin())

def autodetect_f0_per_frame(freqs, Zxx, fmin=F0_SEARCH_MIN, fmax=F0_SEARCH_MAX):
    """Pick the strongest bin inside ``[fmin, fmax]`` in every STFT frame.

    Args:
        freqs: 1-D array of bin frequencies matching the rows of ``Zxx``.
        Zxx: Complex STFT matrix indexed as ``[bin, frame]``.
        fmin: Lower edge of the search band in Hz (inclusive).
        fmax: Upper edge of the search band in Hz (inclusive).

    Returns:
        Tuple ``(peak_freqs, peak_mags)``: for each frame, the frequency of
        the largest-magnitude bin in the band and that bin's magnitude.

    Raises:
        ValueError: If no bin of ``freqs`` falls inside the search band.
    """
    band_bins = np.flatnonzero((freqs >= fmin) & (freqs <= fmax))
    if band_bins.size == 0:
        raise ValueError("F0 search range out of band for FFT frequency axis")

    magnitude = np.abs(Zxx)
    # argmax is relative to the band; map it back to absolute bin indices.
    strongest_bins = band_bins[magnitude[band_bins, :].argmax(axis=0)]
    frame_columns = np.arange(magnitude.shape[1])
    return freqs[strongest_bins], magnitude[strongest_bins, frame_columns]

def smooth_signal(x, sr_frames, smooth_ms):
    """Smooth a 1-D series with a moving average that is unbiased at the edges.

    The window is truncated at both ends of the series and each output sample
    is divided by the number of input samples that actually fell inside the
    window. A zero-padded ``mode='same'`` convolution would instead pull the
    first and last half-window toward 0, which is badly wrong for dB data.

    Args:
        x: 1-D series to smooth.
        sr_frames: Rate of the series in samples (frames) per second.
        smooth_ms: Window length in milliseconds; ``None`` or a value <= 0
            disables smoothing.

    Returns:
        ``x`` itself when smoothing is disabled, otherwise a new array with
        the same length as ``x``.
    """
    if not smooth_ms or smooth_ms <= 0:
        return x

    series = np.asarray(x)
    n_samples = series.shape[0]
    if n_samples == 0:
        # np.convolve rejects empty input; there is nothing to smooth anyway.
        return series

    kernel_len = max(1, int(round(smooth_ms * sr_frames / 1000.0)))
    # mode='same' returns max(len(x), len(kernel)) samples, so a kernel longer
    # than the series would change the output length; clamp it.
    kernel_len = min(kernel_len, n_samples)
    kernel = np.ones(kernel_len)

    window_sum = np.convolve(series, kernel, mode='same')
    # Number of real samples under the window at each position (equals
    # kernel_len in the interior and shrinks toward both ends).
    window_count = np.convolve(np.ones(n_samples), kernel, mode='same')
    return window_sum / window_count

def compute_h2_for_channel(freqs, times, Zxx, channel_idx=0):
    """Build per-frame fundamental and 2nd-harmonic series for one STFT matrix.

    With ``F0_OVERRIDE`` unset, the fundamental is the strongest bin in the
    F0 search band of each frame, and the 2nd harmonic is the strongest bin
    within ``PEAK_TOL_HZ`` of twice that frequency. With ``F0_OVERRIDE`` set,
    the bins nearest ``F0_OVERRIDE`` and ``2 * F0_OVERRIDE`` are read in every
    frame.

    Args:
        freqs: 1-D array of bin frequencies matching the rows of ``Zxx``.
        times: Frame times; passed through unchanged into the result.
        Zxx: Complex STFT matrix indexed as ``[bin, frame]``.
        channel_idx: Channel number; not used by the computation.

    Returns:
        Dict with keys ``times``, ``f0s`` (``None`` in override mode),
        ``fund_lin``, ``h2_lin`` (linear magnitudes plus EPS),
        ``fund_db_rel``, ``h2_db_rel`` (dB relative to each frame's peak bin),
        ``fund_abs_db``, ``h2_ratio_db`` (H2 relative to the fundamental in
        dB, NaN where the fundamental is not above ``MIN_FUND_MAG_DB``) and
        ``valid_mask``.

    Raises:
        ValueError: In override mode, if F0 or 2*F0 lies outside ``freqs``.
    """
    if F0_OVERRIDE is None:
        f0s, fund_mags = autodetect_f0_per_frame(freqs, Zxx, fmin=F0_SEARCH_MIN, fmax=F0_SEARCH_MAX)
        band_lo, band_hi = freqs[0], freqs[-1]
        h2_mags = np.zeros(Zxx.shape[1], dtype=float)
        for frame, target in enumerate(2.0 * f0s):
            if target <= band_lo or target > band_hi:
                # Harmonic outside the analysed band: magnitude stays 0.
                continue
            near_target = np.abs(freqs - target) <= PEAK_TOL_HZ
            if np.any(near_target):
                h2_mags[frame] = np.max(np.abs(Zxx[near_target, frame]))
                continue
            # No bin inside the tolerance: fall back to the nearest bin.
            nearest = find_bin_index(freqs, target)
            if nearest is not None:
                h2_mags[frame] = np.abs(Zxx[nearest, frame])
        fund_lin = fund_mags + EPS
        h2_lin = h2_mags + EPS
    else:
        f0s = None
        idx_f0 = find_bin_index(freqs, F0_OVERRIDE)
        idx_h2 = find_bin_index(freqs, 2.0 * F0_OVERRIDE)
        if idx_f0 is None or idx_h2 is None:
            raise ValueError("F0_OVERRIDE out of FFT frequency band; increase NFFT or reduce F0.")
        fund_lin = np.abs(Zxx[idx_f0, :]) + EPS
        h2_lin = np.abs(Zxx[idx_h2, :]) + EPS

    # Levels relative to the strongest bin of each frame.
    frame_peak = np.max(np.abs(Zxx), axis=0) + EPS
    fund_db_rel = 20*np.log10(fund_lin / frame_peak)
    h2_db_rel = 20*np.log10(h2_lin / frame_peak)

    h2_ratio_db = 20.0 * np.log10((h2_lin + EPS) / (fund_lin + EPS))
    fund_abs_db = 20*np.log10(fund_lin + EPS)
    # Frames with too weak a fundamental get a NaN ratio.
    valid_mask = fund_abs_db > MIN_FUND_MAG_DB

    return {
        'times': times,
        'f0s': f0s,
        'fund_lin': fund_lin,
        'h2_lin': h2_lin,
        'fund_db_rel': fund_db_rel,
        'h2_db_rel': h2_db_rel,
        'fund_abs_db': fund_abs_db,
        'h2_ratio_db': np.where(valid_mask, h2_ratio_db, np.nan),
        'valid_mask': valid_mask,
    }

# ---------------------------
# Main execution
# ---------------------------
if not FILENAME:
    raise SystemExit("Set FILENAME in the script header to your audio file path (or absolute path).")

# A relative FILENAME is resolved against the current working directory when
# the file exists there; otherwise the name is passed on unchanged and
# load_audio reports the missing file.
candidate = FILENAME
if not os.path.isabs(candidate):
    cand_cwd = os.path.join(os.getcwd(), candidate)
    if os.path.exists(cand_cwd):
        candidate = cand_cwd

print("Attempting to load:", candidate)
print("Exists:", os.path.exists(candidate))

x, sr, info = load_audio(candidate)

n_channels = x.shape[1]
print(f"Audio channels: {n_channels}")
freqs, times, stft_list = compute_stft_matrix(x, sr, nfft=NFFT, hop=HOP, window=WINDOW)
# A recording that yields a single STFT frame has no frame spacing; report NaN
# instead of failing with IndexError on times[1].
df = freqs[1] - freqs[0] if len(freqs) > 1 else float("nan")
dt = times[1] - times[0] if len(times) > 1 else float("nan")
print(f"STFT: n_bins={len(freqs)} n_frames={len(times)} df={df:.4f}Hz dt={dt:.4f}s")

# Per-channel analysis; results maps channel index -> dict of time series.
results = {}
for ch in CHANNELS_TO_PLOT:
    if ch >= n_channels:
        print(f"Channel {ch} not present in file; skipping")
        continue
    Zxx = stft_list[ch]
    print(f"Computing H2 time series for channel {ch}...")
    res = compute_h2_for_channel(freqs, times, Zxx, channel_idx=ch)
    results[ch] = res
    # Summary statistics over the frames with a usable (non-NaN) H2 ratio.
    h2_valid = res['h2_ratio_db'][np.isfinite(res['h2_ratio_db'])]
    if h2_valid.size > 0:
        print(f" Channel {ch} stats: mean H2 (dB) = {np.mean(h2_valid):.2f}, median = {np.median(h2_valid):.2f}")
    else:
        print(f" Channel {ch}: no valid frames (fundamental too weak or out of band).")

# Smoothing: add 'h2_smooth' and 'fund_smooth' to every channel's result.
# Frame rate of the STFT time axis (frames per second); 1.0 for a single frame.
sr_frames = 1.0 / (times[1] - times[0]) if len(times) > 1 else 1.0
smoothing_enabled = bool(SMOOTH_MS and SMOOTH_MS > 0)
for ch, res in results.items():
    h2_series = res['h2_ratio_db']
    fund_series = res['fund_abs_db']
    if not smoothing_enabled:
        res['h2_smooth'] = h2_series
        res['fund_smooth'] = fund_series
        continue
    # NaN frames in the H2 series are replaced by the mean of the finite
    # frames (0.0 if there are none) before smoothing.
    finite_vals = h2_series[np.isfinite(h2_series)]
    fill_value = np.nanmean(finite_vals) if finite_vals.size > 0 else 0.0
    res['h2_smooth'] = smooth_signal(np.nan_to_num(h2_series, nan=fill_value), sr_frames, SMOOTH_MS)
    # The fundamental series is smoothed directly.
    res['fund_smooth'] = smooth_signal(fund_series, sr_frames, SMOOTH_MS)

# Combined figure: H2 (top) and fundamental level (bottom) on a shared time axis.
fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(12, 7), sharex=True,
                                     gridspec_kw={'height_ratios': [1, 1.0]})

# Requested channels that were actually analyzed, in the requested order.
analyzed_channels = [ch for ch in CHANNELS_TO_PLOT if ch in results]

# One smoothed curve per channel on each subplot.
for ch in analyzed_channels:
    res = results[ch]
    ax_top.plot(res['times'], res['h2_smooth'], label=f"Ch{ch} H2 smoothed", linewidth=2.5)
    ax_bot.plot(res['times'], res['fund_smooth'], label=f"Ch{ch} Fundamental (dB, smoothed)", linewidth=2.5)

ax_top.set_ylabel("H2 (dB)  20*log10(A2/A1)")
ax_bot.set_xlabel("Time (s)")
ax_bot.set_ylabel("Fundamental magnitude (dB)  20*log10(A1)")

# Identical grid, minor-tick and legend styling for both subplots.
for ax in (ax_top, ax_bot):
    ax.grid(which='major', linestyle='-', linewidth=0.8, color='gray', alpha=0.9)
    ax.grid(which='minor', linestyle=':', linewidth=0.6, color='gray', alpha=0.5)
    ax.minorticks_on()
    ax.legend(loc='upper left')

# Title: a non-blank PLOT_TITLE string wins; otherwise build one from the
# file name and the analyzed channels.
forced_title = PLOT_TITLE.strip() if isinstance(PLOT_TITLE, str) else ""
if forced_title:
    title_text = forced_title
else:
    ch_list = ", ".join(str(ch) for ch in analyzed_channels)
    base_name = os.path.basename(candidate)
    title_text = f"{base_name} 2nd-harmonic + fundamental vs time (channels: {ch_list})"
fig.suptitle(title_text, fontsize=12)

# Leave room at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# Optional per-channel magnitude figure (with minor grid lines) if requested.
# Each subplot shows the unsmoothed fundamental and 2nd-harmonic levels
# relative to the strongest bin of every frame.
if PLOT_FRAME_MAG:
    if not results:
        # Without analyzed channels the figure would be zero-height and empty.
        print("No channels were analyzed; skipping per-channel magnitude figure.")
    else:
        n_plots = len(results)
        plt.figure(figsize=(12, 3 * n_plots))
        for idx, (ch, res) in enumerate(results.items(), start=1):
            ax = plt.subplot(n_plots, 1, idx)
            ax.plot(res['times'], res['fund_db_rel'], label=f"Ch{ch} Fundamental dB (rel frame peak)")
            ax.plot(res['times'], res['h2_db_rel'], label=f"Ch{ch} 2nd harmonic dB (rel frame peak)")
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("dB (rel frame peak)")
            ax.set_title(f"Channel {ch} per-frame magnitudes")
            ax.grid(which='major', linestyle='-', linewidth=0.8, color='gray', alpha=0.9)
            ax.grid(which='minor', linestyle=':', linewidth=0.6, color='gray', alpha=0.5)
            ax.minorticks_on()
            ax.legend(loc='upper left')
        plt.tight_layout()
        plt.show()