"""
phyphox_preprocessor.py

Standalone preprocessor for phyphox gyroscope files that leaves wf_core.py and
wf_analyze.py untouched.

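What it does:
1) Parses phyphox gyroscope data (one axis, converted to RPM)
2) Selects a stable middle segment (first/last seconds excluded + stability filter)
3) Synthesizes a 2-channel 24-bit WAV (96 kHz by default, L == R, mono information)
   in which the speed deviation frequency-modulates a fixed carrier tone
4) Runs the existing wf_analyze pipeline (analysis and plots) on the generated WAV
5) Writes metrics to a .txt report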

Usage examples:
  python phyphox_preprocessor.py input.phyphox
  python phyphox_preprocessor.py input.phyphox --rsp 2
  python phyphox_preprocessor.py input.phyphox --rpm 45
  python phyphox_preprocessor.py input.phyphox --axis z --trim-start-s 12 --trim-end-s 10
"""

from __future__ import annotations

import argparse
import os
import xml.etree.ElementTree as ET
from pathlib import Path

import numpy as np
import soundfile as sf

import wf_analyze


# Angular velocity (rad/s) -> revolutions per minute: one revolution is
# 2*pi rad and one minute is 60 s, so the factor is 60 / (2*pi) == 30 / pi.
RAD_S_TO_RPM = 30.0 / np.pi


def _normalize_rpm_option(raw_rpm: float) -> float:
    """
    Normalize user RPM input to canonical platter speeds.

    Accepted families:
      - 33.*  -> 33.3333333333
      - 45.*  -> 45.0
      - 78.*  -> 78.0

    This allows convenient inputs such as 33, 33.0, 33.33, etc.
    """
    if not np.isfinite(raw_rpm):
        raise ValueError('RPM must be a finite number')

    if 32.0 <= raw_rpm < 34.0:
        return 33.3333333333
    if 44.0 <= raw_rpm < 46.0:
        return 45.0
    if 77.0 <= raw_rpm < 79.0:
        return 78.0

    raise ValueError(
        'RPM must indicate one of these families: 33, 45, 78. '
        'Examples: --rpm 33, --rpm 33.33, --rpm 45, --rpm 78'
    )


def _rpm_from_rsp(rsp: int) -> float:
    """
    Map Rotational SPeed selector to canonical RPM.

      --rsp 1 -> 33 1/3 RPM
      --rsp 2 -> 45 RPM
      --rsp 3 -> 78 RPM
    """
    mapping = {
        1: 100.0 / 3.0,
        2: 45.0,
        3: 78.0,
    }
    if rsp not in mapping:
        raise ValueError('RSP must be one of: 1 (33 1/3), 2 (45), 3 (78)')
    return mapping[rsp]


def _parse_phyphox(path: Path, axis: str):
    root = ET.parse(path).getroot()
    containers_node = root.find('data-containers')
    if containers_node is None:
        raise ValueError('phyphox file has no data-containers section')

    containers = {}
    for container in containers_node.findall('container'):
        key = (container.text or '').strip()
        if key:
            containers[key] = container

    def arr(key: str) -> np.ndarray:
        c = containers.get(key)
        if c is None:
            return np.array([], dtype=np.float64)
        init = c.attrib.get('init', '')
        if not init:
            return np.array([], dtype=np.float64)
        return np.fromstring(init, sep=',', dtype=np.float64)

    t = arr('gyr_time')
    gx = arr('gyrX')
    gy = arr('gyrY')
    gz = arr('gyrZ')
    ga = arr('gyr')

    if len(t) < 2:
        raise ValueError('phyphox file has fewer than 2 time samples')

    axis_map = {}
    if len(gx) > 0:
        axis_map['x'] = gx
    if len(gy) > 0:
        axis_map['y'] = gy
    if len(gz) > 0:
        axis_map['z'] = gz
    if len(ga) > 0:
        axis_map['abs'] = ga

    axis = axis.lower()
    if axis not in {'auto', 'x', 'y', 'z', 'abs'}:
        raise ValueError('axis must be one of auto/x/y/z/abs')

    if axis == 'auto':
        preferred = [k for k in ['x', 'y', 'z'] if k in axis_map]
        if preferred:
            axis_used = max(preferred, key=lambda k: abs(float(np.mean(axis_map[k]))))
            w = axis_map[axis_used]
        elif 'abs' in axis_map:
            axis_used = 'abs'
            w = axis_map['abs']
        else:
            raise ValueError('No gyroscope channels available in phyphox file')
    else:
        if axis not in axis_map:
            raise ValueError(f"Requested axis '{axis}' not present in phyphox file")
        axis_used = axis
        w = axis_map[axis]

    n = min(len(t), len(w))
    t = t[:n]
    w = w[:n]

    valid = np.isfinite(t) & np.isfinite(w)
    t = t[valid]
    w = w[valid]
    if len(t) < 2:
        raise ValueError('No valid gyroscope samples after filtering')

    order = np.argsort(t)
    t = t[order]
    w = w[order]

    # enforce positive rotation speed magnitude
    rpm = np.abs(w) * RAD_S_TO_RPM

    dt = np.diff(t)
    dt = dt[(dt > 0) & np.isfinite(dt)]
    if len(dt) == 0:
        raise ValueError('phyphox timestamps are not strictly increasing')
    fs = 1.0 / np.median(dt)

    return t, rpm, axis_used, fs


def _contiguous_regions(mask: np.ndarray):
    idx = np.flatnonzero(mask)
    if len(idx) == 0:
        return []
    splits = np.where(np.diff(idx) > 1)[0]
    starts = np.r_[idx[0], idx[splits + 1]]
    ends = np.r_[idx[splits], idx[-1]]
    return list(zip(starts.tolist(), ends.tolist()))


def _select_stable_segment(t: np.ndarray, rpm: np.ndarray,
                           trim_start_s: float, trim_end_s: float,
                           stable_window_s: float, stable_factor: float,
                           min_stable_s: float):
    t0 = float(t[0] + trim_start_s)
    t1 = float(t[-1] - trim_end_s)
    if t1 <= t0:
        raise ValueError('Trim settings remove all data. Reduce trim_start_s/trim_end_s.')

    mask_trim = (t >= t0) & (t <= t1)
    t_trim = t[mask_trim]
    rpm_trim = rpm[mask_trim]
    if len(t_trim) < 10:
        raise ValueError('Too few samples after edge trim')

    dt = np.median(np.diff(t_trim))
    fs = 1.0 / dt
    win = max(5, int(round(stable_window_s * fs)))
    if win % 2 == 0:
        win += 1

    kernel = np.ones(win, dtype=np.float64) / float(win)
    mean = np.convolve(rpm_trim, kernel, mode='same')
    var = np.convolve((rpm_trim - mean) ** 2, kernel, mode='same')
    std = np.sqrt(np.maximum(var, 0.0))

    med_std = float(np.median(std))
    thr = med_std * stable_factor
    stable_mask = std <= thr

    regions = _contiguous_regions(stable_mask)
    if not regions:
        # fallback: keep full trimmed region
        return t_trim, rpm_trim, (float(t_trim[0]), float(t_trim[-1])), med_std, thr

    min_len = int(round(min_stable_s * fs))
    candidates = [(s, e) for (s, e) in regions if (e - s + 1) >= min_len]
    if not candidates:
        candidates = regions

    s_best, e_best = max(candidates, key=lambda se: se[1] - se[0])
    t_sel = t_trim[s_best:e_best + 1]
    rpm_sel = rpm_trim[s_best:e_best + 1]

    return t_sel, rpm_sel, (float(t_sel[0]), float(t_sel[-1])), med_std, thr


def _synthesize_wav_from_rpm(t: np.ndarray, rpm: np.ndarray,
                             wav_sr: int, carrier_hz: float,
                             out_wav: Path):
    if len(t) < 2:
        raise ValueError('Not enough samples to synthesize WAV')

    duration = float(t[-1] - t[0])
    if duration <= 0:
        raise ValueError('Selected segment has non-positive duration')

    t_rel = t - t[0]
    n_out = int(np.floor(duration * wav_sr))
    tw = np.arange(n_out, dtype=np.float64) / float(wav_sr)

    # Encode speed deviation as FM around a fixed carrier.
    rpm_mean = float(np.mean(rpm))
    dev_frac = (rpm - rpm_mean) / rpm_mean
    dev_w = np.interp(tw, t_rel, dev_frac)

    inst_f = carrier_hz * (1.0 + dev_w)
    phase = 2.0 * np.pi * np.cumsum(inst_f) / float(wav_sr)
    mono = 0.8 * np.sin(phase)

    # 2-channel mono payload (L == R => L-R ~ 0)
    stereo = np.column_stack([mono, mono]).astype(np.float64)
    sf.write(str(out_wav), stereo, samplerate=wav_sr, subtype='PCM_24')

    return rpm_mean, duration, n_out


def _write_metrics_report(path: Path, result: dict, preprocess_info: dict):
    s = result['standard']
    ns = result['non_standard']

    lines = []
    lines.append('Wow & Flutter Metrics Report')
    lines.append('===========================')
    lines.append('')
    lines.append('Preprocessing')
    lines.append('-------------')
    for k, v in preprocess_info.items():
        lines.append(f'{k}: {v}')
    lines.append('')
    lines.append('Analysis Metrics (wf_analyze / wf_core)')
    lines.append('----------------------------------------')
    lines.append(f"mean_frequency_hz: {result['f_mean']:.8f}")
    lines.append(f"peak_to_peak_pct: {result['wf_peak_to_peak']:.8f}")
    lines.append(f"unweighted_peak_2sigma_pct: {s['peak_unweighted']:.8f}")
    lines.append(f"unweighted_rms_pct: {s['rms_unweighted']:.8f}")
    lines.append(f"weighted_peak_2sigma_pct: {s['peak_weighted']:.8f}")
    lines.append(f"weighted_rms_pct: {s['rms_weighted']:.8f}")
    lines.append(f"weighted_wow_rms_pct: {s['wow_rms']:.8f}")
    lines.append(f"weighted_flutter_rms_pct: {s['flutter_rms']:.8f}")
    lines.append(f"drift_rms_pct: {ns['drift_rms']:.8f}")
    lines.append(f"unweighted_wow_rms_pct: {ns['wow_rms']:.8f}")
    lines.append(f"unweighted_flutter_rms_pct: {ns['flutter_rms']:.8f}")

    path.write_text('\n'.join(lines) + '\n', encoding='utf-8')


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; --rsp and --rpm are mutually exclusive."""
    parser = argparse.ArgumentParser(
        description='Preprocess phyphox gyroscope data to stable WAV and run existing wf_analyze pipeline')
    parser.add_argument('input', help='Input .phyphox file')
    speed = parser.add_mutually_exclusive_group()
    speed.add_argument('--rsp', type=int, default=None, choices=[1, 2, 3],
                       help='Rotational SPeed selector: 1=33 1/3, 2=45, 3=78')
    speed.add_argument('--rpm', type=float, default=None,
                       help='Fallback RPM input (accepted families only: 33.*, 45.*, 78.*)')
    parser.add_argument('--axis', type=str, default='auto', choices=['auto', 'x', 'y', 'z', 'abs'],
                        help='Gyroscope axis selection from phyphox (default auto)')
    parser.add_argument('--trim-start-s', type=float, default=8.0,
                        help='Seconds to remove from start before stability selection')
    parser.add_argument('--trim-end-s', type=float, default=8.0,
                        help='Seconds to remove from end before stability selection')
    parser.add_argument('--stable-window-s', type=float, default=2.0,
                        help='Rolling window in seconds for local stability')
    parser.add_argument('--stable-factor', type=float, default=1.35,
                        help='Stability threshold multiplier over median local std')
    parser.add_argument('--min-stable-s', type=float, default=20.0,
                        help='Minimum accepted stable segment length in seconds')
    parser.add_argument('--carrier-hz', type=float, default=1000.0,
                        help='Carrier frequency for synthesized FG WAV')
    parser.add_argument('--out-prefix', type=str, default=None,
                        help='Output prefix (default: input basename + _preprocessed)')
    parser.add_argument('--wav-sr', type=int, default=96000,
                        help='WAV sample rate (default 96000)')
    parser.add_argument('--motor-slots', type=int, default=None,
                        help='Pass-through to wf_analyze for harmonic labels')
    parser.add_argument('--motor-poles', type=int, default=None,
                        help='Pass-through to wf_analyze for harmonic labels')
    parser.add_argument('--drive-ratio', type=float, default=1.0,
                        help='Pass-through to wf_analyze')
    return parser


def _resolve_target_rpm(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """
    Return ``(display_string, canonical_rpm)`` from --rsp / --rpm.

    Defaults to rsp=1 (33 1/3) when neither is given. An --rpm outside the
    accepted families is reported via ``parser.error`` (usage, exit 2).
    """
    if args.rsp is not None:
        return f'rsp={args.rsp}', _rpm_from_rsp(args.rsp)
    if args.rpm is not None:
        try:
            rpm = _normalize_rpm_option(args.rpm)
        except ValueError as exc:
            parser.error(str(exc))  # raises SystemExit
        return f'rpm={args.rpm}', rpm
    return 'default rsp=1', _rpm_from_rsp(1)


def main():
    """
    Command-line entry point.

    Parses arguments, extracts a stable RPM segment from the phyphox file,
    synthesizes ``<prefix>.wav`` next to the input, runs
    ``wf_analyze.analyze`` on it, writes ``<prefix>_metrics.txt`` and then
    calls ``wf_analyze.plot_results``. ``<prefix>`` defaults to the input
    stem plus ``_preprocessed``.

    Argument problems (--rsp together with --rpm, an RPM outside the accepted
    families, a non-.phyphox input, a non-positive --wav-sr) are reported via
    argparse (usage message, exit status 2). Problems with the file's data
    propagate as ValueError from the helper functions.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    rpm_input_display, rpm_normalized = _resolve_target_rpm(args, parser)

    input_path = Path(args.input)
    if input_path.suffix.lower() != '.phyphox':
        parser.error('Input must be a .phyphox file')
    if args.wav_sr <= 0:
        parser.error('--wav-sr must be a positive integer')

    prefix = args.out_prefix or (input_path.stem + '_preprocessed')
    out_wav = input_path.with_name(prefix + '.wav')
    out_txt = input_path.with_name(prefix + '_metrics.txt')

    # Parse + stable filtering
    t, rpm, axis_used, src_fs = _parse_phyphox(input_path, args.axis)
    t_sel, rpm_sel, stable_interval, med_std, thr = _select_stable_segment(
        t, rpm,
        trim_start_s=args.trim_start_s,
        trim_end_s=args.trim_end_s,
        stable_window_s=args.stable_window_s,
        stable_factor=args.stable_factor,
        min_stable_s=args.min_stable_s,
    )

    # Synthesize compliant 2ch / 24-bit WAV for existing pipeline
    rpm_mean, dur, n_out = _synthesize_wav_from_rpm(
        t_sel, rpm_sel,
        wav_sr=args.wav_sr,
        carrier_hz=args.carrier_hz,
        out_wav=out_wav,
    )

    print(f'Generated WAV: {out_wav.name}')
    print(f'  WAV format: 2ch PCM_24 @ {args.wav_sr} Hz')
    print(f'  Axis used: {axis_used}')
    print(f'  RPM input: {rpm_input_display} -> normalized: {rpm_normalized}')
    print(f'  Stable interval: {stable_interval[0]:.3f}s .. {stable_interval[1]:.3f}s')
    print(f'  Segment duration: {dur:.3f}s')

    # Run original analysis pipeline (unchanged scripts)
    result = wf_analyze.analyze(
        str(out_wav),
        rpm=rpm_normalized,
        motor_slots=args.motor_slots,
        motor_poles=args.motor_poles,
        drive_ratio=args.drive_ratio,
    )

    preprocess_info = {
        'input_phyphox': input_path.name,
        'output_wav': out_wav.name,
        'axis_requested': args.axis,
        'axis_used': axis_used,
        'source_sample_rate_hz': f'{src_fs:.6f}',
        'target_rpm_input': rpm_input_display,
        'target_rpm_normalized': f'{rpm_normalized:.12g}',
        'stable_interval_start_s': f'{stable_interval[0]:.6f}',
        'stable_interval_end_s': f'{stable_interval[1]:.6f}',
        'stable_segment_mean_rpm': f'{rpm_mean:.6f}',
        'stable_std_median': f'{med_std:.8f}',
        'stable_threshold': f'{thr:.8f}',
        'wav_sample_rate_hz': str(args.wav_sr),
        'wav_channels': '2',
        'wav_subtype': 'PCM_24',
        'wav_samples_per_channel': str(n_out),
        'note': 'First/last portions are excluded and stable segment is auto-selected to avoid edge/transient bias.',
    }
    # Persist the metrics before plotting: plot_results' behaviour is not
    # visible from this module (it may block on a window or raise), and the
    # report does not depend on it.
    _write_metrics_report(out_txt, result, preprocess_info)
    print(f'Metrics report: {out_txt.name}')

    wf_analyze.plot_results(
        result,
        motor_slots=args.motor_slots,
        motor_poles=args.motor_poles,
        rpm=rpm_normalized,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
