- Update defaults from test-driven optimization: BufferTime=200ms, DiscardTime=30ms, Sensitivity=3.0, DeadZone=0.95, MinSpeed=0.0, Damping=5.0
- Add ShotFired column to CSV recording for contamination analysis
- Rewrite Python optimizer with 6-parameter search (sensitivity, dead zone, min speed, damping, buffer time, discard time)
- Fix velocity weighting order bug in Python simulation
- Add dead zone, min speed threshold, and damping to Python sim
- Add shot contamination analysis (analyze_shots.py) to measure exact IMU perturbation duration per shot
- Support multi-file optimization with mean/worst_case strategies
- Add jitter and overshoot scoring metrics

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
367 lines
15 KiB
Python
367 lines
15 KiB
Python
"""
|
|
Shot Contamination Analyzer
|
|
============================
|
|
Analyzes the precise contamination zone around each shot event.
|
|
Shows speed/acceleration profiles before and after each shot to identify
|
|
the exact duration of IMU perturbation vs voluntary movement.
|
|
|
|
Usage:
|
|
python analyze_shots.py <csv_file> [--plot] [--window 100]
|
|
|
|
Protocol for best results:
|
|
1. Stay stable (no movement) for 2-3 seconds
|
|
2. Fire a single shot
|
|
3. Stay stable again for 2-3 seconds
|
|
4. Repeat 10+ times
|
|
This isolates the IMU shock from voluntary movement.
|
|
"""
|
|
|
|
import csv
|
|
import sys
|
|
import math
|
|
import os
|
|
import argparse
|
|
from dataclasses import dataclass
|
|
from typing import List, Tuple
|
|
|
|
|
|
# Shared shape of every 3-component vector stored on a Frame.
_Vec3 = Tuple[float, float, float]


@dataclass
class Frame:
    """One recorded sample holding the real and the predicted pose.

    Instances are plain value objects; all fields are set by the caller and
    nothing is validated or derived here.
    """

    # Sample time (presumably seconds; confirm against the recorder).
    timestamp: float
    # Measured position and aim vector.
    real_pos: _Vec3
    real_aim: _Vec3
    # Predicted position and aim vector.
    pred_pos: _Vec3
    pred_aim: _Vec3
    # Counters and extrapolation time reported alongside the sample.
    safe_count: int
    buffer_count: int
    extrap_time: float
    # True when a shot was fired on this sample; optional for older data.
    shot_fired: bool = False
|
|
|
|
|
|
def load_csv(path: str) -> List[Frame]:
    """Read a recording CSV and return one Frame per data row.

    Required columns: Timestamp, RealPos{X,Y,Z}, RealAim{X,Y,Z},
    PredPos{X,Y,Z}, PredAim{X,Y,Z}, SafeCount, BufferCount and
    ExtrapolationTime. The ShotFired column is optional: when it is absent
    or the cell is empty the frame is treated as "no shot".

    Args:
        path: Path of the CSV file to read.

    Returns:
        The frames in file order (empty if the file has no data rows).

    Raises:
        KeyError: A required column is missing.
        ValueError: A cell cannot be parsed as a number.
        OSError: The file cannot be opened.
    """
    def read_vec(row, prefix):
        # Columns are named <prefix>X / <prefix>Y / <prefix>Z.
        return (float(row[prefix + 'X']), float(row[prefix + 'Y']), float(row[prefix + 'Z']))

    frames = []
    # newline='' is what the csv module asks for so it can do its own line
    # handling; utf-8-sig drops a leading BOM that would otherwise end up
    # inside the first header name and break the 'Timestamp' lookup.
    with open(path, 'r', newline='', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            # A missing column, a short row (None) or an empty cell all mean
            # "no shot" instead of crashing in int().
            shot_raw = (row.get('ShotFired') or '').strip()
            frames.append(Frame(
                timestamp=float(row['Timestamp']),
                real_pos=read_vec(row, 'RealPos'),
                real_aim=read_vec(row, 'RealAim'),
                pred_pos=read_vec(row, 'PredPos'),
                pred_aim=read_vec(row, 'PredAim'),
                safe_count=int(row['SafeCount']),
                buffer_count=int(row['BufferCount']),
                extrap_time=float(row['ExtrapolationTime']),
                shot_fired=bool(shot_raw) and int(shot_raw) == 1,
            ))
    return frames
|
|
|
|
|
|
def vec_dist(a, b):
    """Return the Euclidean distance between points a and b.

    Components are paired with zip, so extra components of the longer
    argument are ignored.
    """
    squared_deltas = [(p - q) ** 2 for p, q in zip(a, b)]
    return math.sqrt(sum(squared_deltas))
|
|
|
|
|
|
def vec_sub(a, b):
    """Return the component-wise difference a - b as a tuple.

    Components are paired with zip, so the result has the length of the
    shorter argument.
    """
    difference = []
    for left, right in zip(a, b):
        difference.append(left - right)
    return tuple(difference)
|
|
|
|
|
|
def vec_len(a):
    """Return the Euclidean length (L2 norm) of vector a."""
    squares = [component * component for component in a]
    return math.sqrt(sum(squares))
|
|
|
|
|
|
def vec_normalize(a):
    """Return a scaled to unit length, or a zero vector if a is near zero.

    A vector shorter than 1e-10 has no reliable direction, so a zero vector
    is returned instead of dividing by a tiny length. The zero vector has as
    many components as a (floats), which keeps this helper
    dimension-agnostic like the other vec_* functions; for 3-component
    input it compares equal to the previous (0, 0, 0) result.

    Args:
        a: Sequence of numeric components (iterated more than once).

    Returns:
        Tuple of floats with the same number of components as a.
    """
    length = vec_len(a)
    if length < 1e-10:
        return tuple(0.0 for _ in a)
    return tuple(ai / length for ai in a)
|
|
|
|
|
|
def angle_between(a, b):
    """Return the angle in degrees between a and b.

    The dot product is used directly as the cosine, so both inputs are
    expected to be unit vectors already.
    """
    cosine = sum(x * y for x, y in zip(a, b))
    # Rounding can push the dot product slightly outside [-1, 1], where
    # acos is undefined; clamp the upper bound first, then the lower one.
    capped = min(1.0, cosine)
    clamped = max(-1.0, capped)
    return math.degrees(math.acos(clamped))
|
|
|
|
|
|
def compute_per_frame_metrics(frames):
    """Derive per-frame positional speed, aim angular speed and acceleration.

    Every value is a backward difference against the previous frame. Index 0
    and any frame whose time step is not greater than 1e-6 keep the value
    0.0 in all three lists.

    Returns:
        (pos_speed, aim_speed, pos_accel): three lists with len(frames)
        entries each. aim_speed is in degrees per time unit.
    """
    count = len(frames)
    pos_speed = [0.0] * count
    aim_speed = [0.0] * count
    pos_accel = [0.0] * count

    for i in range(1, count):
        previous, current = frames[i - 1], frames[i]
        dt = current.timestamp - previous.timestamp
        if not dt > 1e-6:
            # Duplicate or out-of-order timestamp: leave the zeros in place.
            continue

        pos_speed[i] = vec_dist(current.real_pos, previous.real_pos) / dt

        current_dir = vec_normalize(current.real_aim)
        previous_dir = vec_normalize(previous.real_aim)
        # vec_normalize yields a zero vector for degenerate aims; skip those.
        if vec_len(current_dir) > 0.5 and vec_len(previous_dir) > 0.5:
            aim_speed[i] = angle_between(current_dir, previous_dir) / dt  # deg/s

        # pos_speed[i - 1] is already final because frames are visited in order.
        pos_accel[i] = (pos_speed[i] - pos_speed[i - 1]) / dt

    return pos_speed, aim_speed, pos_accel
|
|
|
|
|
|
def _mean_and_std(values):
    """Return (mean, population standard deviation) of a non-empty list."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, math.sqrt(variance)


def analyze_single_shot(frames, shot_idx, pos_speed, aim_speed, pos_accel, window_ms=200.0):
    """Analyze contamination around a single shot event.

    Frames within window_ms before the shot form the baseline; frames from
    the shot up to window_ms after it are searched for speeds above the
    baseline. The window is found by walking outward from the shot by
    timestamp, so it is never cut short by a fixed frame count, whatever
    the sample rate or window size. Timestamps are assumed to be
    non-decreasing.

    Args:
        frames: Sequence of objects with a `timestamp` attribute (seconds).
        shot_idx: Index in frames of the frame on which the shot was fired.
        pos_speed: Per-frame positional speed, parallel to frames.
        aim_speed: Per-frame aim angular speed, parallel to frames.
        pos_accel: Per-frame positional acceleration, parallel to frames.
        window_ms: Half-width of the analysis window in milliseconds.

    Returns:
        None when no frame lies inside the window before the shot, otherwise
        a dict with:
          - 'shot_time'
          - 'baseline_pos_speed', 'baseline_aim_speed',
            'baseline_pos_std', 'baseline_aim_std'
          - 'pos_threshold', 'aim_threshold'
          - 'pos_contamination_ms', 'aim_contamination_ms',
            'max_contamination_ms': offset of the last post-shot sample
            above its threshold, 0.0 if none
          - 'max_pos_spike', 'max_aim_spike': largest excess over the
            baseline, never below 0.0
          - 'before', 'after': lists of
            (offset_ms, pos_speed, aim_speed, pos_accel) tuples
          - 'is_stable': True when the baseline is below 30.0 positional
            and 200.0 angular speed
    """
    shot_time = frames[shot_idx].timestamp
    frame_count = len(frames)

    def offset_ms(i):
        return (frames[i].timestamp - shot_time) * 1000.0

    # Grow the index range outward until the next frame falls outside the
    # time window (or the recording ends).
    first = shot_idx
    while first > 0 and offset_ms(first - 1) >= -window_ms:
        first -= 1
    last = shot_idx + 1
    while last < frame_count and offset_ms(last) <= window_ms:
        last += 1

    before = []  # (time_relative_ms, pos_speed, aim_speed, pos_accel)
    after = []
    for i in range(first, last):
        dt_ms = offset_ms(i)
        sample = (dt_ms, pos_speed[i], aim_speed[i], pos_accel[i])
        if -window_ms <= dt_ms < 0:
            before.append(sample)
        elif 0 <= dt_ms <= window_ms:
            after.append(sample)

    if not before:
        return None

    # Baseline: mean and spread of the speeds just before the shot.
    baseline_pos_speed, baseline_pos_std = _mean_and_std([s[1] for s in before])
    baseline_aim_speed, baseline_aim_std = _mean_and_std([s[2] for s in before])

    # A sample counts as contaminated while it exceeds baseline + 2 sigma;
    # the floor keeps a perfectly quiet baseline from flagging pure noise.
    pos_threshold = baseline_pos_speed + max(2.0 * baseline_pos_std, 5.0)  # at least 5 cm/s
    aim_threshold = baseline_aim_speed + max(2.0 * baseline_aim_std, 5.0)  # at least 5 deg/s

    pos_contamination_end_ms = 0.0
    aim_contamination_end_ms = 0.0
    max_pos_spike = 0.0
    max_aim_spike = 0.0

    for dt_ms, ps, ais, _ in after:
        # Keep overwriting so the LAST sample above threshold wins.
        if ps > pos_threshold:
            pos_contamination_end_ms = dt_ms
        if ais > aim_threshold:
            aim_contamination_end_ms = dt_ms
        max_pos_spike = max(max_pos_spike, ps - baseline_pos_speed)
        max_aim_spike = max(max_aim_spike, ais - baseline_aim_speed)

    return {
        'shot_time': shot_time,
        'baseline_pos_speed': baseline_pos_speed,
        'baseline_aim_speed': baseline_aim_speed,
        'baseline_pos_std': baseline_pos_std,
        'baseline_aim_std': baseline_aim_std,
        'pos_contamination_ms': pos_contamination_end_ms,
        'aim_contamination_ms': aim_contamination_end_ms,
        'max_contamination_ms': max(pos_contamination_end_ms, aim_contamination_end_ms),
        'max_pos_spike': max_pos_spike,
        'max_aim_spike': max_aim_spike,
        'pos_threshold': pos_threshold,
        'aim_threshold': aim_threshold,
        'before': before,
        'after': after,
        'is_stable': baseline_pos_speed < 30.0 and baseline_aim_speed < 200.0,
    }
|
|
|
|
|
|
def _print_stats(results):
    """Print contamination and spike statistics; return the P95 contamination (ms)."""
    contam = sorted(r['max_contamination_ms'] for r in results)
    pos_spikes = sorted(r['max_pos_spike'] for r in results)
    aim_spikes = sorted(r['max_aim_spike'] for r in results)
    # Nearest-rank style P95, clamped to the last element for small samples.
    p95 = contam[min(int(len(contam) * 0.95), len(contam) - 1)]
    print(f" Contamination: mean={sum(contam)/len(contam):.1f}ms, "
          f"median={contam[len(contam)//2]:.1f}ms, "
          f"p95={p95:.1f}ms, "
          f"max={contam[-1]:.1f}ms")
    print(f" Pos spike: mean={sum(pos_spikes)/len(pos_spikes):.1f}cm/s, "
          f"max={pos_spikes[-1]:.1f}cm/s")
    print(f" Aim spike: mean={sum(aim_spikes)/len(aim_spikes):.1f}deg/s, "
          f"max={aim_spikes[-1]:.1f}deg/s")
    return p95


def _plot_shot_grid(plt, results, suptitle, value_col, threshold_key, contam_key,
                    title_key, title_prefix, ylabel, out_path, show_legend=False):
    """Draw one subplot per shot for a single metric and save the figure to out_path."""
    n_shots = len(results)
    cols = min(4, n_shots)
    rows = math.ceil(n_shots / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    fig.suptitle(suptitle, fontsize=14)

    for idx, result in enumerate(results):
        r, c = divmod(idx, cols)
        ax = axes[r][c]

        # Samples are (offset_ms, pos_speed, aim_speed, pos_accel) tuples.
        for samples, line_style, label in ((result['before'], 'b-', 'Before'),
                                           (result['after'], 'r-', 'After')):
            if samples:
                ax.plot([s[0] for s in samples], [s[value_col] for s in samples],
                        line_style, linewidth=1, label=label)

        # Shot instant and contamination threshold.
        ax.axvline(x=0, color='red', linestyle='--', alpha=0.7, label='Shot')
        ax.axhline(y=result[threshold_key], color='orange', linestyle=':', alpha=0.5,
                   label='Threshold')

        # Shade the contamination zone.
        if result[contam_key] > 0:
            ax.axvspan(0, result[contam_key], alpha=0.15, color='red')

        stable_str = "STABLE" if result['is_stable'] else "MOVING"
        # Use the number shown in the console table so the plot and the table
        # agree even when an earlier shot was skipped.
        shot_number = result.get('shot_number', idx + 1)
        ax.set_title(f"Shot {shot_number} ({stable_str}) - {title_prefix}{result[title_key]:.0f}ms",
                     fontsize=9, color='green' if result['is_stable'] else 'orange')
        ax.set_xlabel('Time from shot (ms)', fontsize=8)
        ax.set_ylabel(ylabel, fontsize=8)
        ax.tick_params(labelsize=7)
        if show_legend and idx == 0:
            ax.legend(fontsize=6)

    # Hide unused subplots in the last row.
    for idx in range(n_shots, rows * cols):
        r, c = divmod(idx, cols)
        axes[r][c].set_visible(False)

    plt.tight_layout()
    plt.savefig(out_path, dpi=150)


def main():
    """Command-line entry point: analyze every shot in a recording CSV.

    Prints a per-shot table, summaries for all shots and for stable shots,
    and a recommended DiscardTime (P95 contamination of the stable shots,
    rounded up to a multiple of 5 ms). With --plot it also saves and shows
    per-shot position-speed and aim-speed figures next to the input file.

    Exits with status 1 when the file is missing, has no data rows, or
    contains no shots.
    """
    parser = argparse.ArgumentParser(description="Shot Contamination Analyzer")
    parser.add_argument("csv_file", help="CSV recording file with ShotFired column")
    parser.add_argument("--plot", action="store_true", help="Generate per-shot plots (requires matplotlib)")
    parser.add_argument("--window", type=float, default=200.0, help="Analysis window in ms before/after shot (default: 200)")
    args = parser.parse_args()

    if not os.path.exists(args.csv_file):
        print(f"Error: File not found: {args.csv_file}")
        sys.exit(1)

    frames = load_csv(args.csv_file)
    print(f"Loaded {len(frames)} frames from {os.path.basename(args.csv_file)}")

    # A header-only file would otherwise crash on frames[-1] below.
    if not frames:
        print("Error: CSV contains no data rows.")
        sys.exit(1)

    duration = frames[-1].timestamp - frames[0].timestamp
    fps = len(frames) / duration if duration > 0 else 0
    print(f"Duration: {duration:.1f}s | FPS: {fps:.0f}")

    shot_indices = [i for i, f in enumerate(frames) if f.shot_fired]
    print(f"Shots detected: {len(shot_indices)}")

    if not shot_indices:
        print("No shots found! Make sure the CSV has a ShotFired column.")
        sys.exit(1)

    pos_speed, aim_speed, pos_accel = compute_per_frame_metrics(frames)

    # Analyze each shot
    results = []
    print(f"\n{'=' * 90}")
    print(f"{'Shot':>4} {'Time':>8} {'Stable':>7} {'PosSpike':>10} {'AimSpike':>10} "
          f"{'PosContam':>10} {'AimContam':>10} {'MaxContam':>10}")
    print(f"{'':>4} {'(s)':>8} {'':>7} {'(cm/s)':>10} {'(deg/s)':>10} "
          f"{'(ms)':>10} {'(ms)':>10} {'(ms)':>10}")
    print(f"{'-' * 90}")

    for idx, si in enumerate(shot_indices):
        result = analyze_single_shot(frames, si, pos_speed, aim_speed, pos_accel, args.window)
        if result is None:
            # No baseline frames before this shot; nothing to compare against.
            continue
        result['shot_number'] = idx + 1
        results.append(result)

        stable_str = "YES" if result['is_stable'] else "no"
        print(f"{idx + 1:>4} {result['shot_time']:>8.2f} {stable_str:>7} "
              f"{result['max_pos_spike']:>10.1f} {result['max_aim_spike']:>10.1f} "
              f"{result['pos_contamination_ms']:>10.1f} {result['aim_contamination_ms']:>10.1f} "
              f"{result['max_contamination_ms']:>10.1f}")

    # Summary: only stable shots (user was not moving)
    stable_results = [r for r in results if r['is_stable']]

    print(f"\n{'=' * 90}")
    print(f"SUMMARY - ALL SHOTS ({len(results)} shots)")
    if results:
        _print_stats(results)

    # The figures quoted here mirror the 'is_stable' rule applied per shot
    # (baseline below 30 cm/s and 200 deg/s).
    print(f"\nSUMMARY - STABLE SHOTS ONLY ({len(stable_results)} shots, "
          f"baseline speed < 30cm/s, aim speed < 200deg/s)")
    if stable_results:
        p95_stable = _print_stats(stable_results)
        # Round up to the next multiple of 5 ms.
        recommended = math.ceil(p95_stable / 5.0) * 5.0
        print(f"\n >>> RECOMMENDED DiscardTime (from stable shots P95): {recommended:.0f}ms <<<")
    else:
        print(" No stable shots found! Make sure you stay still before firing.")
        print(" Shots where baseline speed > 30cm/s or aim speed > 200deg/s are excluded as 'not stable'.")

    # Plotting
    if args.plot:
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("\nmatplotlib not installed. Install with: pip install matplotlib")
        else:
            if not results:
                # Avoids a division by zero when sizing the subplot grid.
                print("\nNo analyzable shots to plot.")
            else:
                # splitext only touches the real extension, so the output can
                # never collide with the input file when it is not named *.csv.
                base_path = os.path.splitext(args.csv_file)[0]
                csv_name = os.path.basename(args.csv_file)

                plot_path = base_path + '_shots.png'
                _plot_shot_grid(plt, results,
                                suptitle=f'Per-Shot Speed Profile ({csv_name})',
                                value_col=1,
                                threshold_key='pos_threshold',
                                contam_key='pos_contamination_ms',
                                title_key='max_contamination_ms',
                                title_prefix='',
                                ylabel='Pos Speed (cm/s)',
                                out_path=plot_path,
                                show_legend=True)
                print(f"\nPlot saved: {plot_path}")
                plt.show()

                # Also plot aim speed
                plot_path2 = base_path + '_shots_aim.png'
                _plot_shot_grid(plt, results,
                                suptitle=f'Per-Shot Aim Angular Speed ({csv_name})',
                                value_col=2,
                                threshold_key='aim_threshold',
                                contam_key='aim_contamination_ms',
                                title_key='aim_contamination_ms',
                                title_prefix='Aim ',
                                ylabel='Aim Speed (deg/s)',
                                out_path=plot_path2)
                print(f"Plot saved: {plot_path2}")
                plt.show()

    print("\nDone.")
|
|
|
|
|
|
# Run the analyzer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|