"""
Shot Contamination Analyzer
============================
Analyzes the precise contamination zone around each shot event.
Shows speed/acceleration profiles before and after each shot to identify
the exact duration of IMU perturbation vs voluntary movement.

Usage:
    python analyze_shots.py <recording.csv> [--plot] [--window 200]

Protocol for best results:
    1. Stay stable (no movement) for 2-3 seconds
    2. Fire a single shot
    3. Stay stable again for 2-3 seconds
    4. Repeat 10+ times

This isolates the IMU shock from voluntary movement.
"""

import argparse
import csv
import math
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

# A shot counts as "stable" when the mean speed over its pre-shot window is
# strictly below both limits. The tool labels position speed as cm/s and aim
# speed as deg/s; the actual position unit follows the recording.
STABLE_POS_SPEED_MAX = 30.0
STABLE_AIM_SPEED_MAX = 200.0

# Minimum margin above the baseline mean before a post-shot sample counts as
# contaminated (used when 2 * std of the baseline is smaller than this).
MIN_POS_THRESHOLD_MARGIN = 5.0  # cm/s
MIN_AIM_THRESHOLD_MARGIN = 5.0  # deg/s


@dataclass
class Frame:
    """One row of a recording CSV."""
    timestamp: float
    real_pos: Tuple[float, float, float]
    real_aim: Tuple[float, float, float]
    pred_pos: Tuple[float, float, float]
    pred_aim: Tuple[float, float, float]
    safe_count: int
    buffer_count: int
    extrap_time: float
    shot_fired: bool = False


def parse_flag(value: Optional[str]) -> bool:
    """Return True when a CSV flag cell holds the value 1.

    Missing columns/cells (None) and empty strings are treated as 0 instead of
    raising, so recordings with blank ShotFired cells still load.
    """
    if value is None:
        return False
    text = str(value).strip()
    if not text:
        return False
    return float(text) == 1.0


def load_csv(path: str) -> List[Frame]:
    """Load a recording CSV into one Frame per data row.

    Required columns: Timestamp, RealPos{X,Y,Z}, RealAim{X,Y,Z},
    PredPos{X,Y,Z}, PredAim{X,Y,Z}, SafeCount, BufferCount, ExtrapolationTime.
    The optional ShotFired column is parsed with parse_flag.

    Raises:
        KeyError: a required column is missing.
        ValueError: a numeric cell cannot be parsed.
    """
    def vec3(row: Dict[str, str], prefix: str) -> Tuple[float, float, float]:
        return (float(row[prefix + 'X']), float(row[prefix + 'Y']), float(row[prefix + 'Z']))

    frames = []
    # newline='' is what the csv module requires; utf-8-sig drops a leading
    # byte-order mark that would otherwise be glued onto the first header name.
    with open(path, 'r', newline='', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            frames.append(Frame(
                timestamp=float(row['Timestamp']),
                real_pos=vec3(row, 'RealPos'),
                real_aim=vec3(row, 'RealAim'),
                pred_pos=vec3(row, 'PredPos'),
                pred_aim=vec3(row, 'PredAim'),
                safe_count=int(row['SafeCount']),
                buffer_count=int(row['BufferCount']),
                extrap_time=float(row['ExtrapolationTime']),
                shot_fired=parse_flag(row.get('ShotFired')),
            ))
    return frames


def vec_dist(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def vec_sub(a, b):
    """Component-wise a - b."""
    return tuple(ai - bi for ai, bi in zip(a, b))


def vec_len(a):
    """Euclidean length of a vector."""
    return math.sqrt(sum(ai * ai for ai in a))


def vec_normalize(a):
    """Return a scaled to unit length, or (0, 0, 0) for a (near-)zero vector."""
    length = vec_len(a)
    if length < 1e-10:
        return (0, 0, 0)
    return tuple(ai / length for ai in a)


def angle_between(a, b):
    """Angle in degrees between two unit vectors.

    The dot product is clamped to [-1, 1] so rounding error cannot push it
    outside acos's domain.
    """
    dot = sum(ai * bi for ai, bi in zip(a, b))
    dot = max(-1.0, min(1.0, dot))
    return math.degrees(math.acos(dot))


def compute_per_frame_metrics(frames: Sequence[Frame]) -> Tuple[List[float], List[float], List[float]]:
    """Compute per-frame position speed, aim angular speed and position acceleration.

    Each metric at index i is a backward difference against frame i - 1:
      * pos_speed: distance between real_pos samples / dt (position units per second)
      * aim_speed: angle between normalized real_aim samples / dt (deg/s); left
        at 0 when either aim vector is degenerate (zero length)
      * pos_accel: change in pos_speed / dt

    Index 0 and frames whose dt is <= 1e-6 s keep the value 0.0.
    """
    n = len(frames)
    pos_speed = [0.0] * n
    aim_speed = [0.0] * n
    pos_accel = [0.0] * n

    for i in range(1, n):
        prev, cur = frames[i - 1], frames[i]
        dt = cur.timestamp - prev.timestamp
        if dt <= 1e-6:
            continue  # duplicate/out-of-order timestamp: no meaningful rate
        pos_speed[i] = vec_dist(cur.real_pos, prev.real_pos) / dt
        aim_a = vec_normalize(cur.real_aim)
        aim_b = vec_normalize(prev.real_aim)
        if vec_len(aim_a) > 0.5 and vec_len(aim_b) > 0.5:
            aim_speed[i] = angle_between(aim_a, aim_b) / dt  # deg/s
        # pos_speed[i - 1] was already filled in on the previous iteration.
        pos_accel[i] = (pos_speed[i] - pos_speed[i - 1]) / dt

    return pos_speed, aim_speed, pos_accel


def analyze_single_shot(frames, shot_idx, pos_speed, aim_speed, pos_accel, window_ms=200.0):
    """Analyze the contamination zone around a single shot event.

    Frames in ``[-window_ms, 0)`` ms relative to the shot form the baseline.
    Frames in ``[0, window_ms]`` ms are compared against per-metric thresholds
    of ``baseline_mean + max(2 * baseline_std, 5)``. The contamination time is
    the offset of the LAST post-shot frame above threshold (0.0 if none).

    Args:
        frames: recording frames; timestamps are assumed non-decreasing.
        shot_idx: index of the shot frame in ``frames``.
        pos_speed, aim_speed, pos_accel: per-frame metrics aligned with
            ``frames`` (see compute_per_frame_metrics).
        window_ms: half-width of the analysis window in milliseconds.

    Returns:
        A dict of per-shot statistics; 'before' and 'after' hold chronological
        ``(time_ms, pos_speed, aim_speed, pos_accel)`` samples. Returns None
        when no frame falls in the pre-shot window (no baseline available).
    """
    shot_time = frames[shot_idx].timestamp

    # Widen [lo, hi) outward from the shot while timestamps stay inside the
    # window, so the window follows window_ms rather than a fixed frame count
    # (a hard frame cap silently truncates long windows / high frame rates).
    lo = shot_idx
    while lo > 0 and (frames[lo - 1].timestamp - shot_time) * 1000.0 >= -window_ms:
        lo -= 1
    hi = shot_idx + 1
    while hi < len(frames) and (frames[hi].timestamp - shot_time) * 1000.0 <= window_ms:
        hi += 1

    before = []  # (time_relative_ms, pos_speed, aim_speed, pos_accel)
    after = []
    for i in range(lo, hi):
        dt_ms = (frames[i].timestamp - shot_time) * 1000.0
        sample = (dt_ms, pos_speed[i], aim_speed[i], pos_accel[i])
        if -window_ms <= dt_ms < 0:
            before.append(sample)
        elif 0 <= dt_ms <= window_ms:
            after.append(sample)

    if not before:
        return None

    # Baseline: mean and population std of each speed over the pre-shot window.
    n_before = len(before)
    baseline_pos_speed = sum(s[1] for s in before) / n_before
    baseline_aim_speed = sum(s[2] for s in before) / n_before
    if n_before > 1:
        baseline_pos_std = math.sqrt(sum((s[1] - baseline_pos_speed) ** 2 for s in before) / n_before)
        baseline_aim_std = math.sqrt(sum((s[2] - baseline_aim_speed) ** 2 for s in before) / n_before)
    else:
        baseline_pos_std = 0.0
        baseline_aim_std = 0.0

    # Threshold = baseline + 2 sigma, but never closer than the minimum margin,
    # so a perfectly still baseline (sigma ~ 0) does not flag sensor noise.
    pos_threshold = baseline_pos_speed + max(2.0 * baseline_pos_std, MIN_POS_THRESHOLD_MARGIN)
    aim_threshold = baseline_aim_speed + max(2.0 * baseline_aim_std, MIN_AIM_THRESHOLD_MARGIN)

    pos_contamination_end_ms = 0.0
    aim_contamination_end_ms = 0.0
    max_pos_spike = 0.0
    max_aim_spike = 0.0
    for dt_ms, ps, ais, _ in after:
        # `after` is chronological, so the last exceedance wins: short dips
        # back under the threshold do not end the contamination zone early.
        if ps > pos_threshold:
            pos_contamination_end_ms = dt_ms
        if ais > aim_threshold:
            aim_contamination_end_ms = dt_ms
        max_pos_spike = max(max_pos_spike, ps - baseline_pos_speed)
        max_aim_spike = max(max_aim_spike, ais - baseline_aim_speed)

    return {
        'shot_time': shot_time,
        'baseline_pos_speed': baseline_pos_speed,
        'baseline_aim_speed': baseline_aim_speed,
        'baseline_pos_std': baseline_pos_std,
        'baseline_aim_std': baseline_aim_std,
        'pos_contamination_ms': pos_contamination_end_ms,
        'aim_contamination_ms': aim_contamination_end_ms,
        'max_contamination_ms': max(pos_contamination_end_ms, aim_contamination_end_ms),
        'max_pos_spike': max_pos_spike,
        'max_aim_spike': max_aim_spike,
        'pos_threshold': pos_threshold,
        'aim_threshold': aim_threshold,
        'before': before,
        'after': after,
        'is_stable': (baseline_pos_speed < STABLE_POS_SPEED_MAX
                      and baseline_aim_speed < STABLE_AIM_SPEED_MAX),
    }


def percentile_nearest_rank(sorted_values: Sequence[float], fraction: float) -> float:
    """Return the nearest-rank percentile of an ascending sequence.

    ``fraction`` is in (0, 1], e.g. 0.95 for P95. For 20 values P95 is the
    19th smallest, not the maximum.

    Raises:
        ValueError: ``sorted_values`` is empty.
    """
    if not sorted_values:
        raise ValueError("percentile of empty sequence")
    n = len(sorted_values)
    # The epsilon keeps exact products such as 0.95 * 20 == 19 from being
    # bumped to the next rank by binary floating-point error.
    rank = math.ceil(fraction * n - 1e-9)
    return sorted_values[min(max(rank, 1), n) - 1]


def derive_output_path(csv_path: str, suffix: str) -> str:
    """Return ``csv_path`` with its final extension replaced by ``suffix``.

    Only the last extension is replaced. Directory names containing '.csv' are
    untouched, and an extensionless input gets the suffix appended, so the
    output can never overwrite the input file.
    """
    root, _ext = os.path.splitext(csv_path)
    return root + suffix


def _analyze_shots(frames, shot_indices, pos_speed, aim_speed, pos_accel, window_ms) -> List[Dict[str, Any]]:
    """Analyze every shot, print one table row per shot, and return the results.

    Each result gets a 1-based 'shot_number' (position in shot_indices), so
    plots label shots with the same number as the table even when some shots
    are skipped for lacking a baseline.
    """
    results = []
    print(f"\n{'=' * 90}")
    print(f"{'Shot':>4} {'Time':>8} {'Stable':>7} {'PosSpike':>10} {'AimSpike':>10} "
          f"{'PosContam':>10} {'AimContam':>10} {'MaxContam':>10}")
    print(f"{'':>4} {'(s)':>8} {'':>7} {'(cm/s)':>10} {'(deg/s)':>10} "
          f"{'(ms)':>10} {'(ms)':>10} {'(ms)':>10}")
    print(f"{'-' * 90}")

    for shot_number, shot_idx in enumerate(shot_indices, start=1):
        result = analyze_single_shot(frames, shot_idx, pos_speed, aim_speed, pos_accel, window_ms)
        if result is None:
            print(f"{shot_number:>4} {frames[shot_idx].timestamp:>8.2f}   "
                  f"skipped: no frames in the pre-shot window")
            continue
        result['shot_number'] = shot_number
        results.append(result)
        stable_str = "YES" if result['is_stable'] else "no"
        print(f"{shot_number:>4} {result['shot_time']:>8.2f} {stable_str:>7} "
              f"{result['max_pos_spike']:>10.1f} {result['max_aim_spike']:>10.1f} "
              f"{result['pos_contamination_ms']:>10.1f} {result['aim_contamination_ms']:>10.1f} "
              f"{result['max_contamination_ms']:>10.1f}")
    return results


def _print_stats(results: Sequence[Dict[str, Any]]) -> float:
    """Print contamination/spike statistics for non-empty results; return contamination P95."""
    contam = sorted(r['max_contamination_ms'] for r in results)
    pos_spikes = [r['max_pos_spike'] for r in results]
    aim_spikes = [r['max_aim_spike'] for r in results]
    p95 = percentile_nearest_rank(contam, 0.95)
    print(f"  Contamination: mean={sum(contam) / len(contam):.1f}ms, "
          f"median={contam[len(contam) // 2]:.1f}ms, "
          f"p95={p95:.1f}ms, "
          f"max={contam[-1]:.1f}ms")
    print(f"  Pos spike: mean={sum(pos_spikes) / len(pos_spikes):.1f}cm/s, "
          f"max={max(pos_spikes):.1f}cm/s")
    print(f"  Aim spike: mean={sum(aim_spikes) / len(aim_spikes):.1f}deg/s, "
          f"max={max(aim_spikes):.1f}deg/s")
    return p95


def _print_summaries(results: Sequence[Dict[str, Any]]) -> None:
    """Print the all-shots and stable-shots summaries plus the DiscardTime recommendation."""
    stable_results = [r for r in results if r['is_stable']]

    print(f"\n{'=' * 90}")
    print(f"SUMMARY - ALL SHOTS ({len(results)} shots)")
    if results:
        _print_stats(results)

    print(f"\nSUMMARY - STABLE SHOTS ONLY ({len(stable_results)} shots, "
          f"baseline speed < {STABLE_POS_SPEED_MAX:.0f}cm/s, "
          f"aim < {STABLE_AIM_SPEED_MAX:.0f}deg/s)")
    if stable_results:
        p95 = _print_stats(stable_results)
        # Round up to the next 5 ms step.
        recommended = math.ceil(p95 / 5.0) * 5.0
        print(f"\n  >>> RECOMMENDED DiscardTime (from stable shots P95): {recommended:.0f}ms <<<")
    else:
        print("  No stable shots found! Make sure you stay still before firing.")
        print(f"  Shots where baseline speed >= {STABLE_POS_SPEED_MAX:.0f}cm/s "
              f"or aim speed >= {STABLE_AIM_SPEED_MAX:.0f}deg/s are excluded as 'not stable'.")


def _plot_metric_grid(plt, results, csv_file, *, column, threshold_key, contamination_key,
                      title_key, title_prefix, suptitle, ylabel, output_suffix, labelled):
    """Draw one subplot per shot for a single speed metric, save it, and return the path.

    ``column`` selects the metric from each (time_ms, pos_speed, aim_speed,
    pos_accel) sample. When ``labelled`` is True the artists get legend labels
    and the first subplot shows a legend. ``results`` must be non-empty.
    """
    n_shots = len(results)
    cols = min(4, n_shots)
    rows = math.ceil(n_shots / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    fig.suptitle(f'{suptitle} ({os.path.basename(csv_file)})', fontsize=14)

    def label(text):
        return {'label': text} if labelled else {}

    for idx, result in enumerate(results):
        r, c = divmod(idx, cols)
        ax = axes[r][c]
        for samples, style, name in ((result['before'], 'b-', 'Before'),
                                     (result['after'], 'r-', 'After')):
            if samples:
                ax.plot([s[0] for s in samples], [s[column] for s in samples],
                        style, linewidth=1, **label(name))
        ax.axvline(x=0, color='red', linestyle='--', alpha=0.7, **label('Shot'))
        ax.axhline(y=result[threshold_key], color='orange', linestyle=':', alpha=0.5,
                   **label('Threshold'))
        if result[contamination_key] > 0:
            ax.axvspan(0, result[contamination_key], alpha=0.15, color='red')

        stable_str = "STABLE" if result['is_stable'] else "MOVING"
        shot_number = result.get('shot_number', idx + 1)
        ax.set_title(f"Shot {shot_number} ({stable_str}) - {title_prefix}{result[title_key]:.0f}ms",
                     fontsize=9, color='green' if result['is_stable'] else 'orange')
        ax.set_xlabel('Time from shot (ms)', fontsize=8)
        ax.set_ylabel(ylabel, fontsize=8)
        ax.tick_params(labelsize=7)
        if labelled and idx == 0:
            ax.legend(fontsize=6)

    # Hide unused subplots in the last row.
    for idx in range(n_shots, rows * cols):
        r, c = divmod(idx, cols)
        axes[r][c].set_visible(False)

    fig.tight_layout()
    plot_path = derive_output_path(csv_file, output_suffix)
    fig.savefig(plot_path, dpi=150)
    return plot_path


def _plot_results(results, csv_file) -> None:
    """Save and show per-shot position-speed and aim-speed figures (needs matplotlib)."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\nmatplotlib not installed. Install with: pip install matplotlib")
        return
    if not results:
        print("\nNo analyzable shots to plot.")
        return

    plot_path = _plot_metric_grid(
        plt, results, csv_file, column=1, threshold_key='pos_threshold',
        contamination_key='pos_contamination_ms', title_key='max_contamination_ms',
        title_prefix='', suptitle='Per-Shot Speed Profile', ylabel='Pos Speed (cm/s)',
        output_suffix='_shots.png', labelled=True)
    print(f"\nPlot saved: {plot_path}")
    plt.show()

    plot_path2 = _plot_metric_grid(
        plt, results, csv_file, column=2, threshold_key='aim_threshold',
        contamination_key='aim_contamination_ms', title_key='aim_contamination_ms',
        title_prefix='Aim ', suptitle='Per-Shot Aim Angular Speed', ylabel='Aim Speed (deg/s)',
        output_suffix='_shots_aim.png', labelled=False)
    print(f"Plot saved: {plot_path2}")
    plt.show()


def main():
    """Command-line entry point: load a recording, analyze shots, print summaries, optionally plot."""
    parser = argparse.ArgumentParser(description="Shot Contamination Analyzer")
    parser.add_argument("csv_file", help="CSV recording file with ShotFired column")
    parser.add_argument("--plot", action="store_true", help="Generate per-shot plots (requires matplotlib)")
    parser.add_argument("--window", type=float, default=200.0,
                        help="Analysis window in ms before/after shot (default: 200)")
    args = parser.parse_args()

    if args.window <= 0:
        print(f"Error: --window must be positive (got {args.window})")
        sys.exit(1)

    if not os.path.exists(args.csv_file):
        print(f"Error: File not found: {args.csv_file}")
        sys.exit(1)

    frames = load_csv(args.csv_file)
    print(f"Loaded {len(frames)} frames from {os.path.basename(args.csv_file)}")
    if not frames:
        print("Error: The CSV contains no data rows.")
        sys.exit(1)

    duration = frames[-1].timestamp - frames[0].timestamp
    fps = len(frames) / duration if duration > 0 else 0
    print(f"Duration: {duration:.1f}s | FPS: {fps:.0f}")

    shot_indices = [i for i, f in enumerate(frames) if f.shot_fired]
    print(f"Shots detected: {len(shot_indices)}")
    if not shot_indices:
        print("No shots found! Make sure the CSV has a ShotFired column.")
        sys.exit(1)

    pos_speed, aim_speed, pos_accel = compute_per_frame_metrics(frames)
    results = _analyze_shots(frames, shot_indices, pos_speed, aim_speed, pos_accel, args.window)
    _print_summaries(results)

    if args.plot:
        _plot_results(results, args.csv_file)

    print("\nDone.")


if __name__ == '__main__':
    main()