PS_Ballistics/Tools/analyze_shots.py
j.foucher cd097e4e55 Optimize adaptive extrapolation defaults from real-world test data
- Update defaults from test-driven optimization:
  BufferTime=200ms, DiscardTime=30ms, Sensitivity=3.0,
  DeadZone=0.95, MinSpeed=0.0, Damping=5.0
- Add ShotFired column to CSV recording for contamination analysis
- Rewrite Python optimizer with 6-parameter search (sensitivity,
  dead zone, min speed, damping, buffer time, discard time)
- Fix velocity weighting order bug in Python simulation
- Add dead zone, min speed threshold, and damping to Python sim
- Add shot contamination analysis (analyze_shots.py) to measure
  exact IMU perturbation duration per shot
- Support multi-file optimization with mean/worst_case strategies
- Add jitter and overshoot scoring metrics

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 18:33:14 +01:00

367 lines
15 KiB
Python

"""
Shot Contamination Analyzer
============================
Analyzes the precise contamination zone around each shot event.
Shows position-speed and aim angular-speed profiles before and after each shot to
identify the exact duration of IMU perturbation vs voluntary movement.
Usage:
python analyze_shots.py <csv_file> [--plot] [--window 100]
Protocol for best results:
1. Stay stable (no movement) for 2-3 seconds
2. Fire a single shot
3. Stay stable again for 2-3 seconds
4. Repeat 10+ times
This isolates the IMU shock from voluntary movement.
"""
import csv
import sys
import math
import os
import argparse
from dataclasses import dataclass
from typing import List, Tuple
@dataclass
class Frame:
    """One recorded sample (one CSV row) of the extrapolation log.

    Each vector field is an (X, Y, Z) tuple built from the matching ``...X``,
    ``...Y`` and ``...Z`` columns. ``Real*`` columns hold the recorded values and
    ``Pred*`` columns the predicted ones (judging by the column names). This
    script only reads ``timestamp``, ``real_pos``, ``real_aim`` and
    ``shot_fired``; the other fields are carried along unchanged.
    """

    timestamp: float  # seconds; differences are multiplied by 1000 to get ms
    real_pos: Tuple[float, float, float]
    real_aim: Tuple[float, float, float]  # aim direction; normalized before use
    pred_pos: Tuple[float, float, float]
    pred_aim: Tuple[float, float, float]
    safe_count: int
    buffer_count: int
    extrap_time: float
    shot_fired: bool = False  # True when the row's ShotFired cell equals 1


def _read_vec3(row, prefix):
    """Return the (X, Y, Z) float tuple stored in columns ``prefix + 'X'/'Y'/'Z'``."""
    return (float(row[prefix + 'X']), float(row[prefix + 'Y']), float(row[prefix + 'Z']))


def _parse_shot_flag(value):
    """Interpret a ShotFired cell: True only when it holds the number 1.

    ``None`` (column absent or row too short) and blank cells mean "no shot",
    instead of crashing as ``int('')`` / ``int(None)`` would. Numeric spellings
    such as ``"1.0"`` are accepted as well as ``"1"``.
    """
    if value is None:
        return False
    text = value.strip()
    if not text:
        return False
    return float(text) == 1.0


def load_csv(path: str) -> List[Frame]:
    """Load a recording CSV into a list of :class:`Frame`, in file order.

    Required columns: Timestamp, RealPos{X,Y,Z}, RealAim{X,Y,Z},
    PredPos{X,Y,Z}, PredAim{X,Y,Z}, SafeCount, BufferCount, ExtrapolationTime.
    ShotFired is optional; missing or blank cells are read as "no shot".

    Raises:
        KeyError: if a required column is missing.
        ValueError: if a numeric cell cannot be parsed.
    """
    frames = []
    # newline='' is required by the csv module for correct line handling;
    # 'utf-8-sig' strips a leading BOM that would otherwise be glued to the
    # first header name and make row['Timestamp'] raise KeyError.
    with open(path, 'r', newline='', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            frames.append(Frame(
                timestamp=float(row['Timestamp']),
                real_pos=_read_vec3(row, 'RealPos'),
                real_aim=_read_vec3(row, 'RealAim'),
                pred_pos=_read_vec3(row, 'PredPos'),
                pred_aim=_read_vec3(row, 'PredAim'),
                safe_count=int(row['SafeCount']),
                buffer_count=int(row['BufferCount']),
                extrap_time=float(row['ExtrapolationTime']),
                shot_fired=_parse_shot_flag(row.get('ShotFired')),
            ))
    return frames
def vec_dist(a, b):
    """Return the Euclidean distance between coordinate sequences *a* and *b*.

    Components are paired with ``zip``, so surplus trailing components of the
    longer sequence are ignored.
    """
    squared_diffs = ((p - q) ** 2 for p, q in zip(a, b))
    return math.sqrt(sum(squared_diffs))
def vec_sub(a, b):
    """Return the component-wise difference ``a - b`` as a tuple.

    Like ``zip``, ``map`` over two iterables stops at the shorter one.
    """
    return tuple(map(lambda p, q: p - q, a, b))
def vec_len(a):
    """Return the Euclidean length (L2 norm) of vector *a*."""
    squared_norm = sum(map(lambda c: c * c, a))
    return math.sqrt(squared_norm)
def vec_normalize(a):
    """Return *a* scaled to unit length.

    Vectors shorter than 1e-10 have no meaningful direction; for them the
    integer tuple ``(0, 0, 0)`` is returned, whatever the input's dimension.
    """
    length = vec_len(a)
    if length < 1e-10:
        return (0, 0, 0)
    return tuple(map(lambda c: c / length, a))
def angle_between(a, b):
    """Return the angle, in degrees, between direction vectors *a* and *b*.

    The dot product is used directly as the cosine, so both inputs are expected
    to be unit length already. It is clamped to [-1, 1] so rounding error cannot
    push it outside ``acos``'s domain.
    """
    cosine = max(-1.0, min(1.0, sum(map(lambda p, q: p * q, a, b))))
    return math.degrees(math.acos(cosine))
def compute_per_frame_metrics(frames):
    """Derive per-frame kinematics from consecutive recorded frames.

    Returns a tuple of three lists, each ``len(frames)`` long and aligned with
    ``frames``:

    * position speed: distance between consecutive ``real_pos`` samples divided
      by the timestamp delta (the rest of the script labels it cm/s);
    * aim angular speed in degrees per second between consecutive normalized
      ``real_aim`` directions (left at 0.0 if either direction is degenerate);
    * position acceleration: change in position speed divided by the delta.

    Index 0, and any frame whose delta to its predecessor is not above 1e-6 s,
    keeps 0.0 in all three lists.
    """
    count = len(frames)
    pos_speed = [0.0] * count
    aim_speed = [0.0] * count
    pos_accel = [0.0] * count
    for cur in range(1, count):
        prev = cur - 1
        delta = frames[cur].timestamp - frames[prev].timestamp
        if delta > 1e-6:
            pos_speed[cur] = vec_dist(frames[cur].real_pos, frames[prev].real_pos) / delta
            dir_cur = vec_normalize(frames[cur].real_aim)
            dir_prev = vec_normalize(frames[prev].real_aim)
            # vec_normalize returns (0, 0, 0) for near-zero vectors; skip those.
            if vec_len(dir_cur) > 0.5 and vec_len(dir_prev) > 0.5:
                aim_speed[cur] = angle_between(dir_cur, dir_prev) / delta  # deg/s
            # pos_speed[prev] is already final here, so one pass suffices.
            pos_accel[cur] = (pos_speed[cur] - pos_speed[prev]) / delta
    return pos_speed, aim_speed, pos_accel
def analyze_single_shot(frames, shot_idx, pos_speed, aim_speed, pos_accel, window_ms=200.0,
                        *, stable_pos_speed=30.0, stable_aim_speed=200.0):
    """Measure how long motion stays perturbed after a single shot event.

    The pre-shot window ``[-window_ms, 0)`` supplies a baseline (mean and
    population standard deviation) for position speed and aim angular speed.
    A post-shot frame in ``[0, window_ms]`` counts as contaminated when its
    speed exceeds ``baseline + max(2 * std, 5)``. The contamination duration is
    the time of the LAST such frame (0.0 if none).

    Args:
        frames: recorded frames, assumed ordered by non-decreasing timestamp
            (seconds). Only ``.timestamp`` is read.
        shot_idx: index in ``frames`` of the frame flagged as the shot.
        pos_speed, aim_speed, pos_accel: per-frame metric lists aligned with
            ``frames``. ``pos_accel`` is only copied into the returned samples.
        window_ms: half-width of the analysis window in milliseconds.
        stable_pos_speed, stable_aim_speed: the shot is flagged ``is_stable``
            when the baseline position speed and baseline aim speed are both
            strictly below these limits.

    Returns:
        A dict of per-shot statistics plus the raw ``before``/``after`` samples,
        each a ``(ms_from_shot, pos_speed, aim_speed, pos_accel)`` tuple.
        Returns None when no frame falls in the pre-shot window (e.g. the shot
        is on the first frame).
    """
    shot_time = frames[shot_idx].timestamp

    def offset_ms(i):
        return (frames[i].timestamp - shot_time) * 1000.0

    # Scan outward by time, not a fixed frame count, so the whole window is
    # covered regardless of frame rate or window length.
    first = shot_idx
    while first > 0 and offset_ms(first - 1) >= -window_ms:
        first -= 1
    last = shot_idx
    while last + 1 < len(frames) and offset_ms(last + 1) <= window_ms:
        last += 1

    before = []  # (time_relative_ms, pos_speed, aim_speed, pos_accel)
    after = []
    for i in range(first, last + 1):
        dt_ms = offset_ms(i)
        sample = (dt_ms, pos_speed[i], aim_speed[i], pos_accel[i])
        if -window_ms <= dt_ms < 0:
            before.append(sample)
        elif 0 <= dt_ms <= window_ms:
            after.append(sample)
    if not before:
        return None

    def mean_std(values):
        # Population standard deviation; defined as 0.0 for a single sample.
        mean = sum(values) / len(values)
        if len(values) > 1:
            std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
        else:
            std = 0.0
        return mean, std

    baseline_pos_speed, baseline_pos_std = mean_std([s[1] for s in before])
    baseline_aim_speed, baseline_aim_std = mean_std([s[2] for s in before])

    # Threshold = baseline + 2 sigma, with a floor of 5 so a perfectly still
    # baseline (sigma == 0) does not flag sensor noise as contamination.
    pos_threshold = baseline_pos_speed + max(2.0 * baseline_pos_std, 5.0)  # at least 5 cm/s
    aim_threshold = baseline_aim_speed + max(2.0 * baseline_aim_std, 5.0)  # at least 5 deg/s

    pos_contamination_end_ms = 0.0
    aim_contamination_end_ms = 0.0
    max_pos_spike = 0.0
    max_aim_spike = 0.0
    for dt_ms, ps, ais, _ in after:
        # 'after' is in time order, so the last exceedance wins.
        if ps > pos_threshold:
            pos_contamination_end_ms = dt_ms
        if ais > aim_threshold:
            aim_contamination_end_ms = dt_ms
        max_pos_spike = max(max_pos_spike, ps - baseline_pos_speed)
        max_aim_spike = max(max_aim_spike, ais - baseline_aim_speed)

    return {
        'shot_time': shot_time,
        'baseline_pos_speed': baseline_pos_speed,
        'baseline_aim_speed': baseline_aim_speed,
        'baseline_pos_std': baseline_pos_std,
        'baseline_aim_std': baseline_aim_std,
        'pos_contamination_ms': pos_contamination_end_ms,
        'aim_contamination_ms': aim_contamination_end_ms,
        'max_contamination_ms': max(pos_contamination_end_ms, aim_contamination_end_ms),
        'max_pos_spike': max_pos_spike,
        'max_aim_spike': max_aim_spike,
        'pos_threshold': pos_threshold,
        'aim_threshold': aim_threshold,
        'before': before,
        'after': after,
        'is_stable': baseline_pos_speed < stable_pos_speed and baseline_aim_speed < stable_aim_speed,
    }
def _print_distribution(results):
    """Print contamination and spike statistics for a non-empty list of shot results.

    The median averages the two middle values when the count is even. P95 is
    the sorted value at index ``int(n * 0.95)``, capped at ``n - 1``.

    Returns:
        The P95 contamination duration in milliseconds.
    """
    contam = sorted(r['max_contamination_ms'] for r in results)
    pos_spikes = sorted(r['max_pos_spike'] for r in results)
    aim_spikes = sorted(r['max_aim_spike'] for r in results)
    n = len(contam)
    median = (contam[(n - 1) // 2] + contam[n // 2]) / 2.0
    p95 = contam[min(int(n * 0.95), n - 1)]
    print(f" Contamination: mean={sum(contam)/n:.1f}ms, "
          f"median={median:.1f}ms, "
          f"p95={p95:.1f}ms, "
          f"max={contam[-1]:.1f}ms")
    print(f" Pos spike: mean={sum(pos_spikes)/len(pos_spikes):.1f}cm/s, "
          f"max={pos_spikes[-1]:.1f}cm/s")
    print(f" Aim spike: mean={sum(aim_spikes)/len(aim_spikes):.1f}deg/s, "
          f"max={aim_spikes[-1]:.1f}deg/s")
    return p95


def _plot_results(results, csv_file):
    """Draw per-shot position-speed and aim-speed grids, save them as PNGs next
    to ``csv_file``, and show them.

    Each result must carry a ``'shot_number'`` key (set by ``main``) so subplot
    titles match the printed table. Prints a message and returns without
    plotting when matplotlib is unavailable or ``results`` is empty.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\nmatplotlib not installed. Install with: pip install matplotlib")
        return
    if not results:
        # cols would be 0 below, making the rows computation divide by zero.
        print("\nNo analyzable shots to plot.")
        return

    n_shots = len(results)
    cols = min(4, n_shots)
    rows = math.ceil(n_shots / cols)
    csv_name = os.path.basename(csv_file)
    # Derive output names from the file stem: str.replace('.csv', ...) returns
    # the input path unchanged for names without '.csv' (e.g. '.CSV'), which
    # would make savefig overwrite the recording itself.
    stem = os.path.splitext(csv_file)[0]

    # Figure 1: position speed around each shot.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    fig.suptitle(f'Per-Shot Speed Profile ({csv_name})', fontsize=14)
    for idx, result in enumerate(results):
        r, c = divmod(idx, cols)
        ax = axes[r][c]
        if result['before']:
            t_before = [b[0] for b in result['before']]
            s_before = [b[1] for b in result['before']]
            ax.plot(t_before, s_before, 'b-', linewidth=1, label='Before')
        if result['after']:
            t_after = [a[0] for a in result['after']]
            s_after = [a[1] for a in result['after']]
            ax.plot(t_after, s_after, 'r-', linewidth=1, label='After')
        ax.axvline(x=0, color='red', linestyle='--', alpha=0.7, label='Shot')
        ax.axhline(y=result['pos_threshold'], color='orange', linestyle=':', alpha=0.5, label='Threshold')
        if result['pos_contamination_ms'] > 0:
            ax.axvspan(0, result['pos_contamination_ms'], alpha=0.15, color='red')
        stable_str = "STABLE" if result['is_stable'] else "MOVING"
        ax.set_title(f"Shot {result['shot_number']} ({stable_str}) - {result['max_contamination_ms']:.0f}ms",
                     fontsize=9, color='green' if result['is_stable'] else 'orange')
        ax.set_xlabel('Time from shot (ms)', fontsize=8)
        ax.set_ylabel('Pos Speed (cm/s)', fontsize=8)
        ax.tick_params(labelsize=7)
        if idx == 0:
            ax.legend(fontsize=6)
    for idx in range(n_shots, rows * cols):
        r, c = divmod(idx, cols)
        axes[r][c].set_visible(False)
    fig.tight_layout()
    plot_path = stem + '_shots.png'
    fig.savefig(plot_path, dpi=150)
    print(f"\nPlot saved: {plot_path}")
    plt.show()

    # Figure 2: aim angular speed around each shot.
    fig2, axes2 = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    fig2.suptitle(f'Per-Shot Aim Angular Speed ({csv_name})', fontsize=14)
    for idx, result in enumerate(results):
        r, c = divmod(idx, cols)
        ax = axes2[r][c]
        if result['before']:
            t_before = [b[0] for b in result['before']]
            a_before = [b[2] for b in result['before']]  # aim_speed
            ax.plot(t_before, a_before, 'b-', linewidth=1)
        if result['after']:
            t_after = [a[0] for a in result['after']]
            a_after = [a[2] for a in result['after']]  # aim_speed
            ax.plot(t_after, a_after, 'r-', linewidth=1)
        ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
        ax.axhline(y=result['aim_threshold'], color='orange', linestyle=':', alpha=0.5)
        if result['aim_contamination_ms'] > 0:
            ax.axvspan(0, result['aim_contamination_ms'], alpha=0.15, color='red')
        stable_str = "STABLE" if result['is_stable'] else "MOVING"
        ax.set_title(f"Shot {result['shot_number']} ({stable_str}) - Aim {result['aim_contamination_ms']:.0f}ms",
                     fontsize=9, color='green' if result['is_stable'] else 'orange')
        ax.set_xlabel('Time from shot (ms)', fontsize=8)
        ax.set_ylabel('Aim Speed (deg/s)', fontsize=8)
        ax.tick_params(labelsize=7)
    for idx in range(n_shots, rows * cols):
        r, c = divmod(idx, cols)
        axes2[r][c].set_visible(False)
    fig2.tight_layout()
    plot_path2 = stem + '_shots_aim.png'
    fig2.savefig(plot_path2, dpi=150)
    print(f"Plot saved: {plot_path2}")
    plt.show()


def main():
    """Command-line entry point.

    Loads a recording, analyzes every frame flagged as a shot, prints a per-shot
    table plus summaries for all shots and for stable shots, recommends a
    DiscardTime from the stable-shot P95, and optionally plots per-shot
    profiles. Exits with status 1 when the file is missing, has no frames, or
    has no shot-flagged frame.
    """
    parser = argparse.ArgumentParser(description="Shot Contamination Analyzer")
    parser.add_argument("csv_file", help="CSV recording file with ShotFired column")
    parser.add_argument("--plot", action="store_true", help="Generate per-shot plots (requires matplotlib)")
    parser.add_argument("--window", type=float, default=200.0, help="Analysis window in ms before/after shot (default: 200)")
    args = parser.parse_args()

    if not os.path.exists(args.csv_file):
        print(f"Error: File not found: {args.csv_file}")
        sys.exit(1)

    frames = load_csv(args.csv_file)
    print(f"Loaded {len(frames)} frames from {os.path.basename(args.csv_file)}")
    if not frames:
        # A header-only file would otherwise crash on frames[-1] below.
        print("Error: No frames found in the CSV file.")
        sys.exit(1)
    duration = frames[-1].timestamp - frames[0].timestamp
    fps = len(frames) / duration if duration > 0 else 0
    print(f"Duration: {duration:.1f}s | FPS: {fps:.0f}")

    shot_indices = [i for i, f in enumerate(frames) if f.shot_fired]
    print(f"Shots detected: {len(shot_indices)}")
    if not shot_indices:
        print("No shots found! Make sure the CSV has a ShotFired column.")
        sys.exit(1)

    pos_speed, aim_speed, pos_accel = compute_per_frame_metrics(frames)

    # Per-shot table
    results = []
    print(f"\n{'=' * 90}")
    print(f"{'Shot':>4} {'Time':>8} {'Stable':>7} {'PosSpike':>10} {'AimSpike':>10} "
          f"{'PosContam':>10} {'AimContam':>10} {'MaxContam':>10}")
    print(f"{'':>4} {'(s)':>8} {'':>7} {'(cm/s)':>10} {'(deg/s)':>10} "
          f"{'(ms)':>10} {'(ms)':>10} {'(ms)':>10}")
    print(f"{'-' * 90}")
    for shot_number, si in enumerate(shot_indices, start=1):
        result = analyze_single_shot(frames, si, pos_speed, aim_speed, pos_accel, args.window)
        if result is None:
            continue
        # Keep the table's numbering so plots label shots identically even
        # when an earlier shot had no pre-shot data and was skipped.
        result['shot_number'] = shot_number
        results.append(result)
        stable_str = "YES" if result['is_stable'] else "no"
        print(f"{shot_number:>4} {result['shot_time']:>8.2f} {stable_str:>7} "
              f"{result['max_pos_spike']:>10.1f} {result['max_aim_spike']:>10.1f} "
              f"{result['pos_contamination_ms']:>10.1f} {result['aim_contamination_ms']:>10.1f} "
              f"{result['max_contamination_ms']:>10.1f}")

    # Summaries. "Stable" uses analyze_single_shot's default limits
    # (baseline pos speed < 30 cm/s and aim speed < 200 deg/s).
    stable_results = [r for r in results if r['is_stable']]
    print(f"\n{'=' * 90}")
    print(f"SUMMARY - ALL SHOTS ({len(results)} shots)")
    if results:
        _print_distribution(results)
    print(f"\nSUMMARY - STABLE SHOTS ONLY ({len(stable_results)} shots, "
          f"baseline speed < 30cm/s and aim speed < 200deg/s)")
    if stable_results:
        p95 = _print_distribution(stable_results)
        # Round up to the next 5 ms step.
        recommended = math.ceil(p95 / 5.0) * 5.0
        print(f"\n >>> RECOMMENDED DiscardTime (from stable shots P95): {recommended:.0f}ms <<<")
    else:
        print(" No stable shots found! Make sure you stay still before firing.")
        print(" Shots where baseline speed >= 30cm/s or aim speed >= 200deg/s are excluded as 'not stable'.")

    if args.plot:
        _plot_results(results, args.csv_file)
    print("\nDone.")
if __name__ == '__main__':
    # Run the command-line analysis only when executed as a script, not on import.
    main()