Optimize adaptive extrapolation defaults from real-world test data
- Update defaults from test-driven optimization: BufferTime=200ms, DiscardTime=30ms, Sensitivity=3.0, DeadZone=0.95, MinSpeed=0.0, Damping=5.0 - Add ShotFired column to CSV recording for contamination analysis - Rewrite Python optimizer with 6-parameter search (sensitivity, dead zone, min speed, damping, buffer time, discard time) - Fix velocity weighting order bug in Python simulation - Add dead zone, min speed threshold, and damping to Python sim - Add shot contamination analysis (analyze_shots.py) to measure exact IMU perturbation duration per shot - Support multi-file optimization with mean/worst_case strategies - Add jitter and overshoot scoring metrics Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4e9c33778c
commit
cd097e4e55
@ -1,27 +1,32 @@
|
|||||||
"""
|
"""
|
||||||
Anti-Recoil Prediction Analyzer
|
Anti-Recoil Parameter Optimizer
|
||||||
================================
|
================================
|
||||||
Reads CSV files recorded by the EBBarrel CSV recording feature and finds
|
Reads CSV files recorded by the EBBarrel CSV recording feature and finds
|
||||||
optimal AdaptiveSensitivity and scaling factor for the Adaptive Extrapolation mode.
|
optimal parameters for the Adaptive Extrapolation mode.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python analyze_antirecoil.py <path_to_csv>
|
python analyze_antirecoil.py <csv_file> [csv_file2 ...] [options]
|
||||||
python analyze_antirecoil.py <path_to_csv> --plot (requires matplotlib)
|
|
||||||
|
Options:
|
||||||
|
--plot Generate comparison plots (requires matplotlib)
|
||||||
|
--grid Use grid search instead of differential evolution
|
||||||
|
--strategy <s> Multi-file aggregation: mean (default), worst_case
|
||||||
|
--max-iter <n> Max optimizer iterations (default: 200)
|
||||||
|
|
||||||
The script:
|
The script:
|
||||||
1. Loads per-frame data (real position/aim vs predicted position/aim)
|
1. Loads per-frame data (real position/aim vs predicted position/aim)
|
||||||
2. Computes prediction error at each frame
|
2. Simulates adaptive extrapolation offline (matching C++ exactly)
|
||||||
3. Simulates different AdaptiveSensitivity values offline
|
3. Optimizes all 4 parameters: Sensitivity, DeadZone, MinSpeed, Damping
|
||||||
4. Finds the value that minimizes total prediction error
|
4. Reports recommended parameters with per-file breakdown
|
||||||
5. Reports recommended parameters
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import sys
|
import sys
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -34,13 +39,38 @@ class Frame:
|
|||||||
safe_count: int
|
safe_count: int
|
||||||
buffer_count: int
|
buffer_count: int
|
||||||
extrap_time: float
|
extrap_time: float
|
||||||
|
shot_fired: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdaptiveParams:
|
||||||
|
sensitivity: float = 3.0
|
||||||
|
dead_zone: float = 0.95
|
||||||
|
min_speed: float = 0.0
|
||||||
|
damping: float = 5.0
|
||||||
|
buffer_time_ms: float = 200.0
|
||||||
|
discard_time_ms: float = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoreResult:
|
||||||
|
pos_mean: float
|
||||||
|
pos_p95: float
|
||||||
|
aim_mean: float
|
||||||
|
aim_p95: float
|
||||||
|
jitter: float
|
||||||
|
overshoot: float
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
def load_csv(path: str) -> List[Frame]:
|
def load_csv(path: str) -> List[Frame]:
|
||||||
frames = []
|
frames = []
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
|
has_shot_col = False
|
||||||
for row in reader:
|
for row in reader:
|
||||||
|
if not has_shot_col and 'ShotFired' in row:
|
||||||
|
has_shot_col = True
|
||||||
frames.append(Frame(
|
frames.append(Frame(
|
||||||
timestamp=float(row['Timestamp']),
|
timestamp=float(row['Timestamp']),
|
||||||
real_pos=(float(row['RealPosX']), float(row['RealPosY']), float(row['RealPosZ'])),
|
real_pos=(float(row['RealPosX']), float(row['RealPosY']), float(row['RealPosZ'])),
|
||||||
@ -50,10 +80,13 @@ def load_csv(path: str) -> List[Frame]:
|
|||||||
safe_count=int(row['SafeCount']),
|
safe_count=int(row['SafeCount']),
|
||||||
buffer_count=int(row['BufferCount']),
|
buffer_count=int(row['BufferCount']),
|
||||||
extrap_time=float(row['ExtrapolationTime']),
|
extrap_time=float(row['ExtrapolationTime']),
|
||||||
|
shot_fired=int(row.get('ShotFired', 0)) == 1,
|
||||||
))
|
))
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
# --- Vector math helpers ---
|
||||||
|
|
||||||
def vec_dist(a, b):
|
def vec_dist(a, b):
|
||||||
return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
|
return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
|
||||||
|
|
||||||
@ -88,6 +121,8 @@ def angle_between(a, b):
|
|||||||
return math.degrees(math.acos(dot))
|
return math.degrees(math.acos(dot))
|
||||||
|
|
||||||
|
|
||||||
|
# --- Prediction error from recorded data ---
|
||||||
|
|
||||||
def compute_prediction_error(frames: List[Frame]) -> dict:
|
def compute_prediction_error(frames: List[Frame]) -> dict:
|
||||||
"""Compute error between predicted and actual (real) positions/aims."""
|
"""Compute error between predicted and actual (real) positions/aims."""
|
||||||
pos_errors = []
|
pos_errors = []
|
||||||
@ -122,217 +157,599 @@ def compute_prediction_error(frames: List[Frame]) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def simulate_adaptive(frames: List[Frame], sensitivity: float) -> List[Tuple[float, float]]:
|
# --- Shot contamination analysis ---
|
||||||
|
|
||||||
|
def analyze_shot_contamination(frames: List[Frame], analysis_window_ms: float = 200.0):
|
||||||
"""
|
"""
|
||||||
Simulate the adaptive extrapolation offline with given sensitivity.
|
Analyze how shots contaminate the tracking data.
|
||||||
Uses separate pos/aim confidences with relative variance (matching C++ code).
|
For each shot, measure the velocity/acceleration spike and how long it takes
|
||||||
Returns list of (position_error, aim_error) per frame.
|
to return to baseline. This tells us the minimum discard_time needed.
|
||||||
|
|
||||||
|
Returns a dict with analysis results, or None if no shots found.
|
||||||
|
"""
|
||||||
|
shot_indices = [i for i, f in enumerate(frames) if f.shot_fired]
|
||||||
|
if not shot_indices:
|
||||||
|
return None
|
||||||
|
|
||||||
|
analysis_window_s = analysis_window_ms / 1000.0
|
||||||
|
|
||||||
|
# Compute per-frame speeds
|
||||||
|
speeds = [0.0]
|
||||||
|
for i in range(1, len(frames)):
|
||||||
|
dt = frames[i].timestamp - frames[i - 1].timestamp
|
||||||
|
if dt > 1e-6:
|
||||||
|
d = vec_dist(frames[i].real_pos, frames[i - 1].real_pos)
|
||||||
|
speeds.append(d / dt)
|
||||||
|
else:
|
||||||
|
speeds.append(speeds[-1] if speeds else 0.0)
|
||||||
|
|
||||||
|
# For each shot, measure the speed profile before and after
|
||||||
|
contamination_durations = []
|
||||||
|
speed_spikes = []
|
||||||
|
|
||||||
|
for si in shot_indices:
|
||||||
|
# Baseline speed: average speed in 100ms BEFORE the shot
|
||||||
|
baseline_speeds = []
|
||||||
|
for j in range(si - 1, -1, -1):
|
||||||
|
if frames[si].timestamp - frames[j].timestamp > 0.1:
|
||||||
|
break
|
||||||
|
baseline_speeds.append(speeds[j])
|
||||||
|
|
||||||
|
if not baseline_speeds:
|
||||||
|
continue
|
||||||
|
|
||||||
|
baseline_mean = sum(baseline_speeds) / len(baseline_speeds)
|
||||||
|
baseline_std = math.sqrt(sum((s - baseline_mean) ** 2 for s in baseline_speeds) / len(baseline_speeds)) if len(baseline_speeds) > 1 else baseline_mean * 0.1
|
||||||
|
|
||||||
|
# Threshold: speed is "contaminated" if it deviates by more than 3 sigma from baseline
|
||||||
|
threshold = baseline_mean + max(3.0 * baseline_std, 10.0) # at least 10 cm/s spike
|
||||||
|
|
||||||
|
# Find how long after the shot the speed stays above threshold
|
||||||
|
max_speed = 0.0
|
||||||
|
last_contaminated_time = 0.0
|
||||||
|
for j in range(si, len(frames)):
|
||||||
|
dt_from_shot = frames[j].timestamp - frames[si].timestamp
|
||||||
|
if dt_from_shot > analysis_window_s:
|
||||||
|
break
|
||||||
|
if speeds[j] > threshold:
|
||||||
|
last_contaminated_time = dt_from_shot
|
||||||
|
if speeds[j] > max_speed:
|
||||||
|
max_speed = speeds[j]
|
||||||
|
|
||||||
|
contamination_durations.append(last_contaminated_time * 1000.0) # in ms
|
||||||
|
speed_spikes.append(max_speed - baseline_mean)
|
||||||
|
|
||||||
|
if not contamination_durations:
|
||||||
|
return None
|
||||||
|
|
||||||
|
contamination_durations.sort()
|
||||||
|
return {
|
||||||
|
'num_shots': len(shot_indices),
|
||||||
|
'contamination_mean_ms': sum(contamination_durations) / len(contamination_durations),
|
||||||
|
'contamination_p95_ms': contamination_durations[int(len(contamination_durations) * 0.95)],
|
||||||
|
'contamination_max_ms': contamination_durations[-1],
|
||||||
|
'speed_spike_mean': sum(speed_spikes) / len(speed_spikes) if speed_spikes else 0,
|
||||||
|
'speed_spike_max': max(speed_spikes) if speed_spikes else 0,
|
||||||
|
'recommended_discard_ms': math.ceil(contamination_durations[int(len(contamination_durations) * 0.95)] / 5.0) * 5.0, # round up to 5ms
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Offline adaptive extrapolation simulation (matches C++ exactly) ---
|
||||||
|
|
||||||
|
def simulate_adaptive(frames: List[Frame], params: AdaptiveParams) -> Tuple[List[float], List[float]]:
|
||||||
|
"""
|
||||||
|
Simulate the adaptive extrapolation offline with given parameters.
|
||||||
|
Matches the C++ PredictAdaptiveExtrapolation algorithm exactly.
|
||||||
|
Optimized for speed: pre-extracts arrays, inlines math, avoids allocations.
|
||||||
"""
|
"""
|
||||||
pos_errors = []
|
pos_errors = []
|
||||||
aim_errors = []
|
aim_errors = []
|
||||||
|
|
||||||
# Sliding window for velocity estimation (matches C++ safe window ~18 samples)
|
n_frames = len(frames)
|
||||||
window_size = 12
|
if n_frames < 4:
|
||||||
|
return pos_errors, aim_errors
|
||||||
|
|
||||||
|
# Pre-extract into flat arrays for speed
|
||||||
|
ts = [f.timestamp for f in frames]
|
||||||
|
px = [f.real_pos[0] for f in frames]
|
||||||
|
py = [f.real_pos[1] for f in frames]
|
||||||
|
pz = [f.real_pos[2] for f in frames]
|
||||||
|
ax = [f.real_aim[0] for f in frames]
|
||||||
|
ay = [f.real_aim[1] for f in frames]
|
||||||
|
az = [f.real_aim[2] for f in frames]
|
||||||
|
|
||||||
|
buffer_s = params.buffer_time_ms / 1000.0
|
||||||
|
discard_s = params.discard_time_ms / 1000.0
|
||||||
|
sensitivity = params.sensitivity
|
||||||
|
dead_zone = params.dead_zone
|
||||||
|
min_speed = params.min_speed
|
||||||
|
damping = params.damping
|
||||||
SMALL = 1e-10
|
SMALL = 1e-10
|
||||||
|
_sqrt = math.sqrt
|
||||||
|
_exp = math.exp
|
||||||
|
_acos = math.acos
|
||||||
|
_degrees = math.degrees
|
||||||
|
_pow = pow
|
||||||
|
|
||||||
for i in range(window_size + 1, len(frames) - 1):
|
for i in range(2, n_frames - 1):
|
||||||
# Build velocity pairs from recent real positions and aims
|
ct = ts[i]
|
||||||
pos_vels = []
|
safe_cutoff = ct - discard_s
|
||||||
aim_vels = []
|
oldest_allowed = ct - buffer_s
|
||||||
for j in range(1, min(window_size, i)):
|
|
||||||
dt = frames[i - j].timestamp - frames[i - j - 1].timestamp
|
|
||||||
if dt > 1e-6:
|
|
||||||
pos_vel = vec_scale(vec_sub(frames[i - j].real_pos, frames[i - j - 1].real_pos), 1.0 / dt)
|
|
||||||
aim_vel = vec_scale(vec_sub(frames[i - j].real_aim, frames[i - j - 1].real_aim), 1.0 / dt)
|
|
||||||
pos_vels.append(pos_vel)
|
|
||||||
aim_vels.append(aim_vel)
|
|
||||||
|
|
||||||
if len(pos_vels) < 4:
|
# Collect safe sample indices (backward scan, then reverse)
|
||||||
|
safe = []
|
||||||
|
for j in range(i, -1, -1):
|
||||||
|
t = ts[j]
|
||||||
|
if t < oldest_allowed:
|
||||||
|
break
|
||||||
|
if t <= safe_cutoff:
|
||||||
|
safe.append(j)
|
||||||
|
safe.reverse()
|
||||||
|
|
||||||
|
ns = len(safe)
|
||||||
|
if ns < 2:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
n = len(pos_vels)
|
# Build velocity pairs inline
|
||||||
|
vpx = []; vpy = []; vpz = []
|
||||||
|
vax = []; vay = []; vaz = []
|
||||||
|
for k in range(1, ns):
|
||||||
|
p, c = safe[k - 1], safe[k]
|
||||||
|
dt = ts[c] - ts[p]
|
||||||
|
if dt > 1e-6:
|
||||||
|
inv_dt = 1.0 / dt
|
||||||
|
vpx.append((px[c] - px[p]) * inv_dt)
|
||||||
|
vpy.append((py[c] - py[p]) * inv_dt)
|
||||||
|
vpz.append((pz[c] - pz[p]) * inv_dt)
|
||||||
|
vax.append((ax[c] - ax[p]) * inv_dt)
|
||||||
|
vay.append((ay[c] - ay[p]) * inv_dt)
|
||||||
|
vaz.append((az[c] - az[p]) * inv_dt)
|
||||||
|
|
||||||
# Weighted average velocity (recent samples weighted more, matching C++)
|
nv = len(vpx)
|
||||||
total_w = 0.0
|
if nv < 2:
|
||||||
avg_pos_vel = (0.0, 0.0, 0.0)
|
continue
|
||||||
avg_aim_vel = (0.0, 0.0, 0.0)
|
|
||||||
for k in range(n):
|
|
||||||
w = (k + 1) ** 2
|
|
||||||
avg_pos_vel = vec_add(avg_pos_vel, vec_scale(pos_vels[k], w))
|
|
||||||
avg_aim_vel = vec_add(avg_aim_vel, vec_scale(aim_vels[k], w))
|
|
||||||
total_w += w
|
|
||||||
avg_pos_vel = vec_scale(avg_pos_vel, 1.0 / total_w)
|
|
||||||
avg_aim_vel = vec_scale(avg_aim_vel, 1.0 / total_w)
|
|
||||||
|
|
||||||
# Deceleration detection: recent speed (last 25%) vs average speed
|
# Weighted average velocity (quadratic weights, oldest=index 0)
|
||||||
recent_start = max(0, n - max(1, n // 4))
|
tw = 0.0
|
||||||
recent_count = n - recent_start
|
apx = apy = apz = 0.0
|
||||||
recent_pos_vel = (0.0, 0.0, 0.0)
|
aax = aay = aaz = 0.0
|
||||||
recent_aim_vel = (0.0, 0.0, 0.0)
|
for k in range(nv):
|
||||||
for k in range(recent_start, n):
|
w = (k + 1) * (k + 1)
|
||||||
recent_pos_vel = vec_add(recent_pos_vel, pos_vels[k])
|
apx += vpx[k] * w; apy += vpy[k] * w; apz += vpz[k] * w
|
||||||
recent_aim_vel = vec_add(recent_aim_vel, aim_vels[k])
|
aax += vax[k] * w; aay += vay[k] * w; aaz += vaz[k] * w
|
||||||
recent_pos_vel = vec_scale(recent_pos_vel, 1.0 / recent_count)
|
tw += w
|
||||||
recent_aim_vel = vec_scale(recent_aim_vel, 1.0 / recent_count)
|
inv_tw = 1.0 / tw
|
||||||
|
apx *= inv_tw; apy *= inv_tw; apz *= inv_tw
|
||||||
|
aax *= inv_tw; aay *= inv_tw; aaz *= inv_tw
|
||||||
|
|
||||||
avg_pos_speed = vec_len(avg_pos_vel)
|
# Recent velocity (last 25%, unweighted)
|
||||||
avg_aim_speed = vec_len(avg_aim_vel)
|
rs = max(0, nv - max(1, nv // 4))
|
||||||
recent_pos_speed = vec_len(recent_pos_vel)
|
rc = nv - rs
|
||||||
recent_aim_speed = vec_len(recent_aim_vel)
|
rpx = rpy = rpz = 0.0
|
||||||
|
rax = ray = raz = 0.0
|
||||||
|
for k in range(rs, nv):
|
||||||
|
rpx += vpx[k]; rpy += vpy[k]; rpz += vpz[k]
|
||||||
|
rax += vax[k]; ray += vay[k]; raz += vaz[k]
|
||||||
|
inv_rc = 1.0 / rc
|
||||||
|
rpx *= inv_rc; rpy *= inv_rc; rpz *= inv_rc
|
||||||
|
rax *= inv_rc; ray *= inv_rc; raz *= inv_rc
|
||||||
|
|
||||||
# Confidence = (recentSpeed / avgSpeed) ^ sensitivity
|
avg_ps = _sqrt(apx*apx + apy*apy + apz*apz)
|
||||||
pos_confidence = 1.0
|
avg_as = _sqrt(aax*aax + aay*aay + aaz*aaz)
|
||||||
if avg_pos_speed > SMALL:
|
rec_ps = _sqrt(rpx*rpx + rpy*rpy + rpz*rpz)
|
||||||
pos_ratio = min(recent_pos_speed / avg_pos_speed, 1.0)
|
rec_as = _sqrt(rax*rax + ray*ray + raz*raz)
|
||||||
pos_confidence = pos_ratio ** sensitivity
|
|
||||||
|
|
||||||
aim_confidence = 1.0
|
# Position confidence
|
||||||
if avg_aim_speed > SMALL:
|
pc = 1.0
|
||||||
aim_ratio = min(recent_aim_speed / avg_aim_speed, 1.0)
|
if avg_ps > min_speed:
|
||||||
aim_confidence = aim_ratio ** sensitivity
|
ratio = rec_ps / avg_ps
|
||||||
|
if ratio > 1.0: ratio = 1.0
|
||||||
|
if ratio < dead_zone:
|
||||||
|
rm = ratio / dead_zone if dead_zone > SMALL else 0.0
|
||||||
|
if rm > 1.0: rm = 1.0
|
||||||
|
pc = _pow(rm, sensitivity)
|
||||||
|
|
||||||
|
# Aim confidence
|
||||||
|
ac = 1.0
|
||||||
|
if avg_as > min_speed:
|
||||||
|
ratio = rec_as / avg_as
|
||||||
|
if ratio > 1.0: ratio = 1.0
|
||||||
|
if ratio < dead_zone:
|
||||||
|
rm = ratio / dead_zone if dead_zone > SMALL else 0.0
|
||||||
|
if rm > 1.0: rm = 1.0
|
||||||
|
ac = _pow(rm, sensitivity)
|
||||||
|
|
||||||
# Extrapolation time
|
# Extrapolation time
|
||||||
extrap_dt = frames[i].extrap_time if frames[i].extrap_time > 0 else 0.011
|
lsi = safe[-1]
|
||||||
|
edt = ct - ts[lsi]
|
||||||
|
if edt <= 0: edt = 0.011
|
||||||
|
|
||||||
# Predict from last safe position
|
# Damping
|
||||||
pred_pos = vec_add(frames[i - 1].real_pos, vec_scale(avg_pos_vel, extrap_dt * pos_confidence))
|
ds = _exp(-damping * edt) if damping > 0.0 else 1.0
|
||||||
pred_aim_raw = vec_add(frames[i - 1].real_aim, vec_scale(avg_aim_vel, extrap_dt * aim_confidence))
|
|
||||||
pred_aim = vec_normalize(pred_aim_raw)
|
|
||||||
|
|
||||||
# Errors
|
# Predict
|
||||||
pos_errors.append(vec_dist(pred_pos, frames[i].real_pos))
|
m = edt * pc * ds
|
||||||
real_aim_n = vec_normalize(frames[i].real_aim)
|
ppx = px[lsi] + apx * m
|
||||||
if vec_len(pred_aim) > 0.5 and vec_len(real_aim_n) > 0.5:
|
ppy = py[lsi] + apy * m
|
||||||
aim_errors.append(angle_between(pred_aim, real_aim_n))
|
ppz = pz[lsi] + apz * m
|
||||||
|
|
||||||
|
ma = edt * ac * ds
|
||||||
|
pax_r = ax[lsi] + aax * ma
|
||||||
|
pay_r = ay[lsi] + aay * ma
|
||||||
|
paz_r = az[lsi] + aaz * ma
|
||||||
|
pa_len = _sqrt(pax_r*pax_r + pay_r*pay_r + paz_r*paz_r)
|
||||||
|
|
||||||
|
# Position error
|
||||||
|
dx = ppx - px[i]; dy = ppy - py[i]; dz = ppz - pz[i]
|
||||||
|
pos_errors.append(_sqrt(dx*dx + dy*dy + dz*dz))
|
||||||
|
|
||||||
|
# Aim error
|
||||||
|
if pa_len > 0.5:
|
||||||
|
inv_pa = 1.0 / pa_len
|
||||||
|
pax_n = pax_r * inv_pa; pay_n = pay_r * inv_pa; paz_n = paz_r * inv_pa
|
||||||
|
ra_len = _sqrt(ax[i]*ax[i] + ay[i]*ay[i] + az[i]*az[i])
|
||||||
|
if ra_len > 0.5:
|
||||||
|
inv_ra = 1.0 / ra_len
|
||||||
|
dot = pax_n * ax[i] * inv_ra + pay_n * ay[i] * inv_ra + paz_n * az[i] * inv_ra
|
||||||
|
if dot > 1.0: dot = 1.0
|
||||||
|
if dot < -1.0: dot = -1.0
|
||||||
|
aim_errors.append(_degrees(_acos(dot)))
|
||||||
|
|
||||||
return pos_errors, aim_errors
|
return pos_errors, aim_errors
|
||||||
|
|
||||||
|
|
||||||
def find_optimal_parameters(frames: List[Frame]) -> dict:
|
# --- Scoring ---
|
||||||
"""Search for optimal AdaptiveSensitivity (power curve exponent for deceleration detection)."""
|
|
||||||
print("\nSearching for optimal AdaptiveSensitivity (power exponent)...")
|
|
||||||
print("-" * 60)
|
|
||||||
|
|
||||||
best_score = float('inf')
|
def compute_score(pos_errors: List[float], aim_errors: List[float]) -> ScoreResult:
|
||||||
best_sensitivity = 1.0
|
"""Compute a combined score from position and aim errors, including stability metrics."""
|
||||||
|
|
||||||
# Search grid — exponent range 0.1 to 5.0
|
|
||||||
sensitivities = [0.1, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for sens in sensitivities:
|
|
||||||
pos_errors, aim_errors = simulate_adaptive(frames, sens)
|
|
||||||
if not pos_errors:
|
if not pos_errors:
|
||||||
continue
|
return ScoreResult(0, 0, 0, 0, 0, 0, float('inf'))
|
||||||
pos_errors_s = sorted(pos_errors)
|
|
||||||
aim_errors_s = sorted(aim_errors) if aim_errors else [0]
|
pos_sorted = sorted(pos_errors)
|
||||||
|
aim_sorted = sorted(aim_errors) if aim_errors else [0]
|
||||||
|
|
||||||
pos_mean = sum(pos_errors) / len(pos_errors)
|
pos_mean = sum(pos_errors) / len(pos_errors)
|
||||||
pos_p95 = pos_errors_s[int(len(pos_errors_s) * 0.95)]
|
pos_p95 = pos_sorted[int(len(pos_sorted) * 0.95)]
|
||||||
aim_mean = sum(aim_errors) / len(aim_errors) if aim_errors else 0
|
aim_mean = sum(aim_errors) / len(aim_errors) if aim_errors else 0
|
||||||
aim_p95 = aim_errors_s[int(len(aim_errors_s) * 0.95)] if aim_errors else 0
|
aim_p95 = aim_sorted[int(len(aim_sorted) * 0.95)] if aim_errors else 0
|
||||||
|
|
||||||
# Score: combine position and aim errors (aim weighted more since it's the main issue)
|
# Jitter: standard deviation of frame-to-frame error change
|
||||||
score = pos_mean * 0.3 + pos_p95 * 0.2 + aim_mean * 0.3 + aim_p95 * 0.2
|
jitter = 0.0
|
||||||
|
if len(pos_errors) > 1:
|
||||||
|
deltas = [abs(pos_errors[i] - pos_errors[i - 1]) for i in range(1, len(pos_errors))]
|
||||||
|
delta_mean = sum(deltas) / len(deltas)
|
||||||
|
jitter = math.sqrt(sum((d - delta_mean) ** 2 for d in deltas) / len(deltas))
|
||||||
|
|
||||||
results.append((sens, pos_mean, pos_p95, aim_mean, aim_p95, score))
|
# Overshoot: percentage of frames where error spikes above 2x mean
|
||||||
|
overshoot = 0.0
|
||||||
|
if pos_mean > 0:
|
||||||
|
overshoot_count = sum(1 for e in pos_errors if e > 2.0 * pos_mean)
|
||||||
|
overshoot = overshoot_count / len(pos_errors)
|
||||||
|
|
||||||
|
# Combined score
|
||||||
|
score = (pos_mean * 0.25 + pos_p95 * 0.15 +
|
||||||
|
aim_mean * 0.25 + aim_p95 * 0.15 +
|
||||||
|
jitter * 0.10 + overshoot * 0.10)
|
||||||
|
|
||||||
|
return ScoreResult(pos_mean, pos_p95, aim_mean, aim_p95, jitter, overshoot, score)
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_scores(per_file_scores: List[Tuple[str, ScoreResult]],
|
||||||
|
strategy: str = "mean") -> float:
|
||||||
|
"""Aggregate scores across multiple files."""
|
||||||
|
scores = [s.score for _, s in per_file_scores]
|
||||||
|
if not scores:
|
||||||
|
return float('inf')
|
||||||
|
if strategy == "worst_case":
|
||||||
|
return max(scores)
|
||||||
|
else: # mean
|
||||||
|
return sum(scores) / len(scores)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Optimizer ---
|
||||||
|
|
||||||
|
def objective(x, all_frames, strategy):
|
||||||
|
"""Objective function for the optimizer."""
|
||||||
|
params = AdaptiveParams(
|
||||||
|
sensitivity=x[0],
|
||||||
|
dead_zone=x[1],
|
||||||
|
min_speed=x[2],
|
||||||
|
damping=x[3],
|
||||||
|
buffer_time_ms=x[4],
|
||||||
|
discard_time_ms=x[5]
|
||||||
|
)
|
||||||
|
per_file_scores = []
|
||||||
|
for name, frames in all_frames:
|
||||||
|
pos_errors, aim_errors = simulate_adaptive(frames, params)
|
||||||
|
score_result = compute_score(pos_errors, aim_errors)
|
||||||
|
per_file_scores.append((name, score_result))
|
||||||
|
return aggregate_scores(per_file_scores, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_differential_evolution(all_frames, strategy="mean", max_iter=200, min_discard_ms=10.0):
|
||||||
|
"""Find optimal parameters using scipy differential evolution."""
|
||||||
|
try:
|
||||||
|
from scipy.optimize import differential_evolution
|
||||||
|
except ImportError:
|
||||||
|
print("ERROR: scipy is required for optimization.")
|
||||||
|
print("Install with: pip install scipy")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
bounds = [
|
||||||
|
(0.1, 5.0), # sensitivity
|
||||||
|
(0.0, 0.95), # dead_zone
|
||||||
|
(0.0, 200.0), # min_speed
|
||||||
|
(0.0, 50.0), # damping
|
||||||
|
(100.0, 500.0), # buffer_time_ms
|
||||||
|
(max(10.0, min_discard_ms), 100.0), # discard_time_ms (floor from contamination analysis)
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"\nRunning differential evolution (maxiter={max_iter}, popsize=25, min_discard={min_discard_ms:.0f}ms)...")
|
||||||
|
print("This may take a few minutes...\n")
|
||||||
|
|
||||||
|
result = differential_evolution(
|
||||||
|
objective,
|
||||||
|
bounds,
|
||||||
|
args=(all_frames, strategy),
|
||||||
|
maxiter=max_iter,
|
||||||
|
seed=42,
|
||||||
|
tol=1e-4,
|
||||||
|
popsize=25,
|
||||||
|
disp=True,
|
||||||
|
workers=1
|
||||||
|
)
|
||||||
|
|
||||||
|
best_params = AdaptiveParams(
|
||||||
|
sensitivity=round(result.x[0], 2),
|
||||||
|
dead_zone=round(result.x[1], 3),
|
||||||
|
min_speed=round(result.x[2], 1),
|
||||||
|
damping=round(result.x[3], 1),
|
||||||
|
buffer_time_ms=round(result.x[4], 0),
|
||||||
|
discard_time_ms=round(result.x[5], 0)
|
||||||
|
)
|
||||||
|
return best_params, result.fun
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_grid_search(all_frames, strategy="mean", min_discard_ms=10.0):
|
||||||
|
"""Find optimal parameters using grid search (slower but no scipy needed)."""
|
||||||
|
print(f"\nRunning grid search over 6 parameters (min_discard={min_discard_ms:.0f}ms)...")
|
||||||
|
|
||||||
|
sensitivities = [1.0, 2.0, 3.0, 4.0]
|
||||||
|
dead_zones = [0.7, 0.8, 0.9]
|
||||||
|
min_speeds = [0.0, 30.0]
|
||||||
|
dampings = [5.0, 10.0, 15.0]
|
||||||
|
buffer_times = [300.0, 400.0, 500.0, 600.0, 800.0]
|
||||||
|
discard_times = [d for d in [20.0, 40.0, 60.0, 100.0, 150.0, 200.0] if d >= min_discard_ms]
|
||||||
|
if not discard_times:
|
||||||
|
discard_times = [min_discard_ms]
|
||||||
|
|
||||||
|
total = (len(sensitivities) * len(dead_zones) * len(min_speeds) *
|
||||||
|
len(dampings) * len(buffer_times) * len(discard_times))
|
||||||
|
print(f"Total combinations: {total}")
|
||||||
|
|
||||||
|
best_score = float('inf')
|
||||||
|
best_params = AdaptiveParams()
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for sens in sensitivities:
|
||||||
|
for dz in dead_zones:
|
||||||
|
for ms in min_speeds:
|
||||||
|
for damp in dampings:
|
||||||
|
for bt in buffer_times:
|
||||||
|
for dt in discard_times:
|
||||||
|
count += 1
|
||||||
|
if count % 500 == 0:
|
||||||
|
print(f" Progress: {count}/{total} ({100 * count / total:.0f}%) best={best_score:.4f}")
|
||||||
|
|
||||||
|
params = AdaptiveParams(sens, dz, ms, damp, bt, dt)
|
||||||
|
per_file_scores = []
|
||||||
|
for name, frames in all_frames:
|
||||||
|
pos_errors, aim_errors = simulate_adaptive(frames, params)
|
||||||
|
score_result = compute_score(pos_errors, aim_errors)
|
||||||
|
per_file_scores.append((name, score_result))
|
||||||
|
|
||||||
|
score = aggregate_scores(per_file_scores, strategy)
|
||||||
if score < best_score:
|
if score < best_score:
|
||||||
best_score = score
|
best_score = score
|
||||||
best_sensitivity = sens
|
best_params = params
|
||||||
|
|
||||||
# Print results
|
return best_params, best_score
|
||||||
results.sort(key=lambda x: x[5])
|
|
||||||
print(f"\n{'Sensitivity':>12} {'Pos Mean':>10} {'Pos P95':>10} {'Aim Mean':>10} {'Aim P95':>10} {'Score':>10}")
|
|
||||||
print("-" * 65)
|
|
||||||
for sens, pm, pp, am, ap, score in results[:10]:
|
|
||||||
marker = " <-- BEST" if sens == best_sensitivity else ""
|
|
||||||
print(f"{sens:>12.2f} {pm:>10.3f} {pp:>10.3f} {am:>10.3f} {ap:>10.3f} {score:>10.3f}{marker}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
'sensitivity': best_sensitivity,
|
# --- Main ---
|
||||||
'score': best_score,
|
|
||||||
}
|
def print_file_stats(name: str, frames: List[Frame]):
|
||||||
|
"""Print basic stats for a CSV file."""
|
||||||
|
duration = frames[-1].timestamp - frames[0].timestamp
|
||||||
|
avg_fps = len(frames) / duration if duration > 0 else 0
|
||||||
|
avg_safe = sum(f.safe_count for f in frames) / len(frames)
|
||||||
|
avg_buffer = sum(f.buffer_count for f in frames) / len(frames)
|
||||||
|
avg_extrap = sum(f.extrap_time for f in frames) / len(frames) * 1000
|
||||||
|
num_shots = sum(1 for f in frames if f.shot_fired)
|
||||||
|
print(f" {os.path.basename(name)}: {len(frames)} frames, {avg_fps:.0f}fps, "
|
||||||
|
f"{duration:.1f}s, safe={avg_safe:.1f}, extrap={avg_extrap:.1f}ms, shots={num_shots}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_score_detail(name: str, score: ScoreResult):
|
||||||
|
"""Print detailed score for a file."""
|
||||||
|
print(f" {os.path.basename(name):30s} Pos: mean={score.pos_mean:.3f}cm p95={score.pos_p95:.3f}cm | "
|
||||||
|
f"Aim: mean={score.aim_mean:.3f}deg p95={score.aim_p95:.3f}deg | "
|
||||||
|
f"jitter={score.jitter:.3f} overshoot={score.overshoot:.1%} | "
|
||||||
|
f"score={score.score:.4f}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if len(sys.argv) < 2:
|
parser = argparse.ArgumentParser(
|
||||||
print("Usage: python analyze_antirecoil.py <csv_file> [--plot]")
|
description="Anti-Recoil Parameter Optimizer - finds optimal AdaptiveExtrapolation parameters"
|
||||||
print("\nCSV files are saved by the EBBarrel component in:")
|
)
|
||||||
print(" <Project>/Saved/Logs/AntiRecoil_*.csv")
|
parser.add_argument("csv_files", nargs="+", help="One or more CSV recording files")
|
||||||
sys.exit(1)
|
parser.add_argument("--plot", action="store_true", help="Generate comparison plots (requires matplotlib)")
|
||||||
|
parser.add_argument("--grid", action="store_true", help="Use grid search instead of differential evolution")
|
||||||
csv_path = sys.argv[1]
|
parser.add_argument("--strategy", choices=["mean", "worst_case"], default="mean",
|
||||||
do_plot = '--plot' in sys.argv
|
help="Multi-file score aggregation strategy (default: mean)")
|
||||||
|
parser.add_argument("--max-iter", type=int, default=200, help="Max optimizer iterations (default: 200)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load all CSV files
|
||||||
|
all_frames = []
|
||||||
|
for csv_path in args.csv_files:
|
||||||
if not os.path.exists(csv_path):
|
if not os.path.exists(csv_path):
|
||||||
print(f"Error: File not found: {csv_path}")
|
print(f"Error: File not found: {csv_path}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print(f"Loading: {csv_path}")
|
|
||||||
frames = load_csv(csv_path)
|
frames = load_csv(csv_path)
|
||||||
print(f"Loaded {len(frames)} frames")
|
|
||||||
|
|
||||||
if len(frames) < 50:
|
if len(frames) < 50:
|
||||||
print("Warning: very few frames. Record at least a few seconds of movement for good results.")
|
print(f"Warning: {csv_path} has only {len(frames)} frames (need at least 50 for good results)")
|
||||||
|
all_frames.append((csv_path, frames))
|
||||||
|
|
||||||
# Basic stats
|
print(f"\nLoaded {len(all_frames)} file(s)")
|
||||||
duration = frames[-1].timestamp - frames[0].timestamp
|
print("=" * 70)
|
||||||
avg_fps = len(frames) / duration if duration > 0 else 0
|
|
||||||
print(f"Duration: {duration:.1f}s | Avg FPS: {avg_fps:.1f}")
|
|
||||||
print(f"Avg safe samples: {sum(f.safe_count for f in frames) / len(frames):.1f}")
|
|
||||||
print(f"Avg buffer samples: {sum(f.buffer_count for f in frames) / len(frames):.1f}")
|
|
||||||
print(f"Avg extrapolation time: {sum(f.extrap_time for f in frames) / len(frames) * 1000:.1f}ms")
|
|
||||||
|
|
||||||
# Current prediction error (as recorded)
|
# Per-file stats
|
||||||
print("\n=== CURRENT PREDICTION ERROR (as recorded) ===")
|
print("\n=== FILE STATISTICS ===")
|
||||||
current_err = compute_prediction_error(frames)
|
for name, frames in all_frames:
|
||||||
print(f"Position error - Mean: {current_err['pos_mean']:.3f}cm | P95: {current_err['pos_p95']:.3f}cm | Max: {current_err['pos_max']:.3f}cm")
|
print_file_stats(name, frames)
|
||||||
print(f"Aim error - Mean: {current_err['aim_mean']:.3f}deg | P95: {current_err['aim_p95']:.3f}deg | Max: {current_err['aim_max']:.3f}deg")
|
|
||||||
|
|
||||||
# Find optimal parameters
|
# Shot contamination analysis
|
||||||
print("\n=== PARAMETER OPTIMIZATION ===")
|
has_shots = any(any(f.shot_fired for f in frames) for _, frames in all_frames)
|
||||||
optimal = find_optimal_parameters(frames)
|
if has_shots:
|
||||||
|
print("\n=== SHOT CONTAMINATION ANALYSIS ===")
|
||||||
|
max_recommended_discard = 0.0
|
||||||
|
for name, frames in all_frames:
|
||||||
|
result = analyze_shot_contamination(frames)
|
||||||
|
if result:
|
||||||
|
print(f" {os.path.basename(name)}:")
|
||||||
|
print(f" Shots detected: {result['num_shots']}")
|
||||||
|
print(f" Speed spike: mean={result['speed_spike_mean']:.1f} cm/s, max={result['speed_spike_max']:.1f} cm/s")
|
||||||
|
print(f" Contamination duration: mean={result['contamination_mean_ms']:.1f}ms, "
|
||||||
|
f"p95={result['contamination_p95_ms']:.1f}ms, max={result['contamination_max_ms']:.1f}ms")
|
||||||
|
print(f" Recommended discard_time: >= {result['recommended_discard_ms']:.0f}ms")
|
||||||
|
max_recommended_discard = max(max_recommended_discard, result['recommended_discard_ms'])
|
||||||
|
else:
|
||||||
|
print(f" {os.path.basename(name)}: no shots detected")
|
||||||
|
|
||||||
print(f"\n{'=' * 50}")
|
if max_recommended_discard > 0:
|
||||||
print(f" RECOMMENDED SETTINGS:")
|
print(f"\n >>> MINIMUM SAFE DiscardTime across all files: {max_recommended_discard:.0f}ms <<<")
|
||||||
print(f" AdaptiveSensitivity = {optimal['sensitivity']:.2f}")
|
else:
|
||||||
print(f"{'=' * 50}")
|
print("\n (No ShotFired data in CSV - record with updated plugin to get contamination analysis)")
|
||||||
|
|
||||||
|
# Baseline: current default parameters
|
||||||
|
default_params = AdaptiveParams()
|
||||||
|
print(f"\n=== BASELINE (defaults: sens={default_params.sensitivity}, dz={default_params.dead_zone}, "
|
||||||
|
f"minspd={default_params.min_speed}, damp={default_params.damping}, "
|
||||||
|
f"buf={default_params.buffer_time_ms}ms, disc={default_params.discard_time_ms}ms) ===")
|
||||||
|
|
||||||
|
baseline_scores = []
|
||||||
|
for name, frames in all_frames:
|
||||||
|
pos_errors, aim_errors = simulate_adaptive(frames, default_params)
|
||||||
|
score = compute_score(pos_errors, aim_errors)
|
||||||
|
baseline_scores.append((name, score))
|
||||||
|
print_score_detail(name, score)
|
||||||
|
|
||||||
|
baseline_agg = aggregate_scores(baseline_scores, args.strategy)
|
||||||
|
print(f"\n Aggregate score ({args.strategy}): {baseline_agg:.4f}")
|
||||||
|
|
||||||
|
# Also show recorded prediction error (as-is from the engine)
|
||||||
|
print(f"\n=== RECORDED PREDICTION ERROR (as captured in-engine) ===")
|
||||||
|
for name, frames in all_frames:
|
||||||
|
err = compute_prediction_error(frames)
|
||||||
|
print(f" {os.path.basename(name):30s} Pos: mean={err['pos_mean']:.3f}cm p95={err['pos_p95']:.3f}cm | "
|
||||||
|
f"Aim: mean={err['aim_mean']:.3f}deg p95={err['aim_p95']:.3f}deg")
|
||||||
|
|
||||||
|
# Compute minimum safe discard time from shot contamination analysis
|
||||||
|
min_discard_ms = 10.0 # absolute minimum
|
||||||
|
if has_shots:
|
||||||
|
for name, frames in all_frames:
|
||||||
|
result = analyze_shot_contamination(frames)
|
||||||
|
if result and result['recommended_discard_ms'] > min_discard_ms:
|
||||||
|
min_discard_ms = result['recommended_discard_ms']
|
||||||
|
|
||||||
|
# Optimize
|
||||||
|
print(f"\n=== OPTIMIZATION ({args.strategy}) ===")
|
||||||
|
|
||||||
|
if args.grid:
|
||||||
|
best_params, best_score = optimize_grid_search(all_frames, args.strategy, min_discard_ms)
|
||||||
|
else:
|
||||||
|
best_params, best_score = optimize_differential_evolution(all_frames, args.strategy, args.max_iter, min_discard_ms)
|
||||||
|
|
||||||
|
# Results
|
||||||
|
print(f"\n{'=' * 70}")
|
||||||
|
print(f" BEST PARAMETERS FOUND:")
|
||||||
|
print(f" AdaptiveSensitivity = {best_params.sensitivity}")
|
||||||
|
print(f" AdaptiveDeadZone = {best_params.dead_zone}")
|
||||||
|
print(f" AdaptiveMinSpeed = {best_params.min_speed}")
|
||||||
|
print(f" ExtrapolationDamping = {best_params.damping}")
|
||||||
|
print(f" AntiRecoilBufferTimeMs = {best_params.buffer_time_ms}")
|
||||||
|
print(f" AntiRecoilDiscardTimeMs= {best_params.discard_time_ms}")
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
|
||||||
|
# Per-file breakdown with optimized params
|
||||||
|
print(f"\n=== OPTIMIZED RESULTS ===")
|
||||||
|
opt_scores = []
|
||||||
|
for name, frames in all_frames:
|
||||||
|
pos_errors, aim_errors = simulate_adaptive(frames, best_params)
|
||||||
|
score = compute_score(pos_errors, aim_errors)
|
||||||
|
opt_scores.append((name, score))
|
||||||
|
print_score_detail(name, score)
|
||||||
|
|
||||||
|
opt_agg = aggregate_scores(opt_scores, args.strategy)
|
||||||
|
print(f"\n Aggregate score ({args.strategy}): {opt_agg:.4f}")
|
||||||
|
|
||||||
|
# Improvement
|
||||||
|
print(f"\n=== IMPROVEMENT vs BASELINE ===")
|
||||||
|
for (name, baseline), (_, optimized) in zip(baseline_scores, opt_scores):
|
||||||
|
pos_pct = ((baseline.pos_mean - optimized.pos_mean) / baseline.pos_mean * 100) if baseline.pos_mean > 0 else 0
|
||||||
|
aim_pct = ((baseline.aim_mean - optimized.aim_mean) / baseline.aim_mean * 100) if baseline.aim_mean > 0 else 0
|
||||||
|
score_pct = ((baseline.score - optimized.score) / baseline.score * 100) if baseline.score > 0 else 0
|
||||||
|
print(f" {os.path.basename(name):30s} Pos: {pos_pct:+.1f}% | Aim: {aim_pct:+.1f}% | Score: {score_pct:+.1f}%")
|
||||||
|
|
||||||
|
total_pct = ((baseline_agg - opt_agg) / baseline_agg * 100) if baseline_agg > 0 else 0
|
||||||
|
print(f" {'TOTAL':30s} Score: {total_pct:+.1f}%")
|
||||||
|
|
||||||
# Plotting
|
# Plotting
|
||||||
if do_plot:
|
if args.plot:
|
||||||
try:
|
try:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)
|
n_files = len(all_frames)
|
||||||
|
fig, axes = plt.subplots(n_files, 3, figsize=(18, 5 * n_files), squeeze=False)
|
||||||
|
|
||||||
|
for row, (name, frames) in enumerate(all_frames):
|
||||||
timestamps = [f.timestamp - frames[0].timestamp for f in frames]
|
timestamps = [f.timestamp - frames[0].timestamp for f in frames]
|
||||||
|
short_name = os.path.basename(name)
|
||||||
|
|
||||||
|
# Baseline errors
|
||||||
|
bl_pos, bl_aim = simulate_adaptive(frames, default_params)
|
||||||
|
# Optimized errors
|
||||||
|
op_pos, op_aim = simulate_adaptive(frames, best_params)
|
||||||
|
|
||||||
|
# Time axis for simulated errors (offset by window_size)
|
||||||
|
t_start = window_size = 12
|
||||||
|
sim_timestamps = [frames[i].timestamp - frames[0].timestamp
|
||||||
|
for i in range(t_start + 1, t_start + 1 + len(bl_pos))]
|
||||||
|
|
||||||
# Position error
|
# Position error
|
||||||
pos_errors = [vec_dist(f.pred_pos, f.real_pos) for f in frames]
|
ax = axes[row][0]
|
||||||
axes[0].plot(timestamps, pos_errors, 'r-', alpha=0.5, linewidth=0.5)
|
if len(sim_timestamps) == len(bl_pos):
|
||||||
axes[0].set_ylabel('Position Error (cm)')
|
ax.plot(sim_timestamps, bl_pos, 'r-', alpha=0.4, linewidth=0.5, label='Baseline')
|
||||||
axes[0].set_title('Prediction Error Over Time')
|
ax.plot(sim_timestamps, op_pos, 'g-', alpha=0.4, linewidth=0.5, label='Optimized')
|
||||||
axes[0].axhline(y=current_err['pos_mean'], color='r', linestyle='--', alpha=0.5, label=f"Mean: {current_err['pos_mean']:.2f}cm")
|
ax.set_ylabel('Position Error (cm)')
|
||||||
axes[0].legend()
|
ax.set_title(f'{short_name} - Position Error')
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
# Aim error
|
# Aim error
|
||||||
aim_errors = []
|
ax = axes[row][1]
|
||||||
for f in frames:
|
if len(sim_timestamps) >= len(bl_aim):
|
||||||
a = vec_normalize(f.pred_aim)
|
t_aim = sim_timestamps[:len(bl_aim)]
|
||||||
b = vec_normalize(f.real_aim)
|
ax.plot(t_aim, bl_aim, 'r-', alpha=0.4, linewidth=0.5, label='Baseline')
|
||||||
if vec_len(a) > 0.5 and vec_len(b) > 0.5:
|
if len(sim_timestamps) >= len(op_aim):
|
||||||
aim_errors.append(angle_between(a, b))
|
t_aim = sim_timestamps[:len(op_aim)]
|
||||||
else:
|
ax.plot(t_aim, op_aim, 'g-', alpha=0.4, linewidth=0.5, label='Optimized')
|
||||||
aim_errors.append(0)
|
ax.set_ylabel('Aim Error (deg)')
|
||||||
axes[1].plot(timestamps, aim_errors, 'b-', alpha=0.5, linewidth=0.5)
|
ax.set_title(f'{short_name} - Aim Error')
|
||||||
axes[1].set_ylabel('Aim Error (degrees)')
|
ax.legend()
|
||||||
axes[1].axhline(y=current_err['aim_mean'], color='b', linestyle='--', alpha=0.5, label=f"Mean: {current_err['aim_mean']:.2f}deg")
|
|
||||||
axes[1].legend()
|
|
||||||
|
|
||||||
# Speed (from real positions)
|
# Speed profile
|
||||||
|
ax = axes[row][2]
|
||||||
speeds = [0]
|
speeds = [0]
|
||||||
for i in range(1, len(frames)):
|
for i in range(1, len(frames)):
|
||||||
dt = frames[i].timestamp - frames[i - 1].timestamp
|
dt = frames[i].timestamp - frames[i - 1].timestamp
|
||||||
@ -341,12 +758,13 @@ def main():
|
|||||||
speeds.append(d / dt)
|
speeds.append(d / dt)
|
||||||
else:
|
else:
|
||||||
speeds.append(speeds[-1])
|
speeds.append(speeds[-1])
|
||||||
axes[2].plot(timestamps, speeds, 'g-', alpha=0.7, linewidth=0.5)
|
ax.plot(timestamps, speeds, 'b-', alpha=0.7, linewidth=0.5)
|
||||||
axes[2].set_ylabel('Speed (cm/s)')
|
ax.set_ylabel('Speed (cm/s)')
|
||||||
axes[2].set_xlabel('Time (s)')
|
ax.set_xlabel('Time (s)')
|
||||||
|
ax.set_title(f'{short_name} - Speed Profile')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plot_path = csv_path.replace('.csv', '_analysis.png')
|
plot_path = args.csv_files[0].replace('.csv', '_optimizer.png')
|
||||||
plt.savefig(plot_path, dpi=150)
|
plt.savefig(plot_path, dpi=150)
|
||||||
print(f"\nPlot saved: {plot_path}")
|
print(f"\nPlot saved: {plot_path}")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
366
Tools/analyze_shots.py
Normal file
366
Tools/analyze_shots.py
Normal file
@ -0,0 +1,366 @@
|
|||||||
|
"""
|
||||||
|
Shot Contamination Analyzer
|
||||||
|
============================
|
||||||
|
Analyzes the precise contamination zone around each shot event.
|
||||||
|
Shows speed/acceleration profiles before and after each shot to identify
|
||||||
|
the exact duration of IMU perturbation vs voluntary movement.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python analyze_shots.py <csv_file> [--plot] [--window 100]
|
||||||
|
|
||||||
|
Protocol for best results:
|
||||||
|
1. Stay stable (no movement) for 2-3 seconds
|
||||||
|
2. Fire a single shot
|
||||||
|
3. Stay stable again for 2-3 seconds
|
||||||
|
4. Repeat 10+ times
|
||||||
|
This isolates the IMU shock from voluntary movement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Frame:
|
||||||
|
timestamp: float
|
||||||
|
real_pos: Tuple[float, float, float]
|
||||||
|
real_aim: Tuple[float, float, float]
|
||||||
|
pred_pos: Tuple[float, float, float]
|
||||||
|
pred_aim: Tuple[float, float, float]
|
||||||
|
safe_count: int
|
||||||
|
buffer_count: int
|
||||||
|
extrap_time: float
|
||||||
|
shot_fired: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def load_csv(path: str) -> List[Frame]:
|
||||||
|
frames = []
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
frames.append(Frame(
|
||||||
|
timestamp=float(row['Timestamp']),
|
||||||
|
real_pos=(float(row['RealPosX']), float(row['RealPosY']), float(row['RealPosZ'])),
|
||||||
|
real_aim=(float(row['RealAimX']), float(row['RealAimY']), float(row['RealAimZ'])),
|
||||||
|
pred_pos=(float(row['PredPosX']), float(row['PredPosY']), float(row['PredPosZ'])),
|
||||||
|
pred_aim=(float(row['PredAimX']), float(row['PredAimY']), float(row['PredAimZ'])),
|
||||||
|
safe_count=int(row['SafeCount']),
|
||||||
|
buffer_count=int(row['BufferCount']),
|
||||||
|
extrap_time=float(row['ExtrapolationTime']),
|
||||||
|
shot_fired=int(row.get('ShotFired', 0)) == 1,
|
||||||
|
))
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def vec_dist(a, b):
|
||||||
|
return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
|
||||||
|
|
||||||
|
|
||||||
|
def vec_sub(a, b):
|
||||||
|
return tuple(ai - bi for ai, bi in zip(a, b))
|
||||||
|
|
||||||
|
|
||||||
|
def vec_len(a):
|
||||||
|
return math.sqrt(sum(ai * ai for ai in a))
|
||||||
|
|
||||||
|
|
||||||
|
def vec_normalize(a):
|
||||||
|
l = vec_len(a)
|
||||||
|
if l < 1e-10:
|
||||||
|
return (0, 0, 0)
|
||||||
|
return tuple(ai / l for ai in a)
|
||||||
|
|
||||||
|
|
||||||
|
def angle_between(a, b):
|
||||||
|
dot = sum(ai * bi for ai, bi in zip(a, b))
|
||||||
|
dot = max(-1.0, min(1.0, dot))
|
||||||
|
return math.degrees(math.acos(dot))
|
||||||
|
|
||||||
|
|
||||||
|
def compute_per_frame_metrics(frames):
|
||||||
|
"""Compute speed, acceleration, and aim angular speed per frame."""
|
||||||
|
n = len(frames)
|
||||||
|
pos_speed = [0.0] * n
|
||||||
|
aim_speed = [0.0] * n
|
||||||
|
pos_accel = [0.0] * n
|
||||||
|
|
||||||
|
for i in range(1, n):
|
||||||
|
dt = frames[i].timestamp - frames[i - 1].timestamp
|
||||||
|
if dt > 1e-6:
|
||||||
|
pos_speed[i] = vec_dist(frames[i].real_pos, frames[i - 1].real_pos) / dt
|
||||||
|
|
||||||
|
aim_a = vec_normalize(frames[i].real_aim)
|
||||||
|
aim_b = vec_normalize(frames[i - 1].real_aim)
|
||||||
|
if vec_len(aim_a) > 0.5 and vec_len(aim_b) > 0.5:
|
||||||
|
aim_speed[i] = angle_between(aim_a, aim_b) / dt # deg/s
|
||||||
|
|
||||||
|
for i in range(1, n):
|
||||||
|
dt = frames[i].timestamp - frames[i - 1].timestamp
|
||||||
|
if dt > 1e-6:
|
||||||
|
pos_accel[i] = (pos_speed[i] - pos_speed[i - 1]) / dt
|
||||||
|
|
||||||
|
return pos_speed, aim_speed, pos_accel
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_single_shot(frames, shot_idx, pos_speed, aim_speed, pos_accel, window_ms=200.0):
|
||||||
|
"""Analyze contamination around a single shot event."""
|
||||||
|
window_s = window_ms / 1000.0
|
||||||
|
shot_time = frames[shot_idx].timestamp
|
||||||
|
|
||||||
|
# Collect frames in window before and after shot
|
||||||
|
before = [] # (time_relative_ms, pos_speed, aim_speed, pos_accel)
|
||||||
|
after = []
|
||||||
|
|
||||||
|
for i in range(max(0, shot_idx - 100), min(len(frames), shot_idx + 100)):
|
||||||
|
dt_ms = (frames[i].timestamp - shot_time) * 1000.0
|
||||||
|
if -window_ms <= dt_ms < 0:
|
||||||
|
before.append((dt_ms, pos_speed[i], aim_speed[i], pos_accel[i]))
|
||||||
|
elif dt_ms >= 0 and dt_ms <= window_ms:
|
||||||
|
after.append((dt_ms, pos_speed[i], aim_speed[i], pos_accel[i]))
|
||||||
|
|
||||||
|
if not before:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Baseline: average speed in the window before the shot
|
||||||
|
baseline_pos_speed = sum(s for _, s, _, _ in before) / len(before)
|
||||||
|
baseline_aim_speed = sum(s for _, _, s, _ in before) / len(before)
|
||||||
|
baseline_pos_std = math.sqrt(sum((s - baseline_pos_speed) ** 2 for _, s, _, _ in before) / len(before)) if len(before) > 1 else 0.0
|
||||||
|
baseline_aim_std = math.sqrt(sum((s - baseline_aim_speed) ** 2 for _, _, s, _ in before) / len(before)) if len(before) > 1 else 0.0
|
||||||
|
|
||||||
|
# Find contamination end: when speed returns to within 2 sigma of baseline
|
||||||
|
pos_threshold = baseline_pos_speed + max(2.0 * baseline_pos_std, 5.0) # at least 5 cm/s
|
||||||
|
aim_threshold = baseline_aim_speed + max(2.0 * baseline_aim_std, 5.0) # at least 5 deg/s
|
||||||
|
|
||||||
|
pos_contamination_end_ms = 0.0
|
||||||
|
aim_contamination_end_ms = 0.0
|
||||||
|
max_pos_spike = 0.0
|
||||||
|
max_aim_spike = 0.0
|
||||||
|
|
||||||
|
for dt_ms, ps, ais, _ in after:
|
||||||
|
if ps > pos_threshold:
|
||||||
|
pos_contamination_end_ms = dt_ms
|
||||||
|
if ais > aim_threshold:
|
||||||
|
aim_contamination_end_ms = dt_ms
|
||||||
|
max_pos_spike = max(max_pos_spike, ps - baseline_pos_speed)
|
||||||
|
max_aim_spike = max(max_aim_spike, ais - baseline_aim_speed)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'shot_time': shot_time,
|
||||||
|
'baseline_pos_speed': baseline_pos_speed,
|
||||||
|
'baseline_aim_speed': baseline_aim_speed,
|
||||||
|
'baseline_pos_std': baseline_pos_std,
|
||||||
|
'baseline_aim_std': baseline_aim_std,
|
||||||
|
'pos_contamination_ms': pos_contamination_end_ms,
|
||||||
|
'aim_contamination_ms': aim_contamination_end_ms,
|
||||||
|
'max_contamination_ms': max(pos_contamination_end_ms, aim_contamination_end_ms),
|
||||||
|
'max_pos_spike': max_pos_spike,
|
||||||
|
'max_aim_spike': max_aim_spike,
|
||||||
|
'pos_threshold': pos_threshold,
|
||||||
|
'aim_threshold': aim_threshold,
|
||||||
|
'before': before,
|
||||||
|
'after': after,
|
||||||
|
'is_stable': baseline_pos_speed < 30.0 and baseline_aim_speed < 200.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Shot Contamination Analyzer")
|
||||||
|
parser.add_argument("csv_file", help="CSV recording file with ShotFired column")
|
||||||
|
parser.add_argument("--plot", action="store_true", help="Generate per-shot plots (requires matplotlib)")
|
||||||
|
parser.add_argument("--window", type=float, default=200.0, help="Analysis window in ms before/after shot (default: 200)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.csv_file):
|
||||||
|
print(f"Error: File not found: {args.csv_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
frames = load_csv(args.csv_file)
|
||||||
|
print(f"Loaded {len(frames)} frames from {os.path.basename(args.csv_file)}")
|
||||||
|
|
||||||
|
duration = frames[-1].timestamp - frames[0].timestamp
|
||||||
|
fps = len(frames) / duration if duration > 0 else 0
|
||||||
|
print(f"Duration: {duration:.1f}s | FPS: {fps:.0f}")
|
||||||
|
|
||||||
|
shot_indices = [i for i, f in enumerate(frames) if f.shot_fired]
|
||||||
|
print(f"Shots detected: {len(shot_indices)}")
|
||||||
|
|
||||||
|
if not shot_indices:
|
||||||
|
print("No shots found! Make sure the CSV has a ShotFired column.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
pos_speed, aim_speed, pos_accel = compute_per_frame_metrics(frames)
|
||||||
|
|
||||||
|
# Analyze each shot
|
||||||
|
results = []
|
||||||
|
print(f"\n{'=' * 90}")
|
||||||
|
print(f"{'Shot':>4} {'Time':>8} {'Stable':>7} {'PosSpike':>10} {'AimSpike':>10} "
|
||||||
|
f"{'PosContam':>10} {'AimContam':>10} {'MaxContam':>10}")
|
||||||
|
print(f"{'':>4} {'(s)':>8} {'':>7} {'(cm/s)':>10} {'(deg/s)':>10} "
|
||||||
|
f"{'(ms)':>10} {'(ms)':>10} {'(ms)':>10}")
|
||||||
|
print(f"{'-' * 90}")
|
||||||
|
|
||||||
|
for idx, si in enumerate(shot_indices):
|
||||||
|
result = analyze_single_shot(frames, si, pos_speed, aim_speed, pos_accel, args.window)
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
stable_str = "YES" if result['is_stable'] else "no"
|
||||||
|
print(f"{idx + 1:>4} {result['shot_time']:>8.2f} {stable_str:>7} "
|
||||||
|
f"{result['max_pos_spike']:>10.1f} {result['max_aim_spike']:>10.1f} "
|
||||||
|
f"{result['pos_contamination_ms']:>10.1f} {result['aim_contamination_ms']:>10.1f} "
|
||||||
|
f"{result['max_contamination_ms']:>10.1f}")
|
||||||
|
|
||||||
|
# Summary: only stable shots (user was not moving)
|
||||||
|
stable_results = [r for r in results if r['is_stable']]
|
||||||
|
all_results = results
|
||||||
|
|
||||||
|
print(f"\n{'=' * 90}")
|
||||||
|
print(f"SUMMARY - ALL SHOTS ({len(all_results)} shots)")
|
||||||
|
if all_results:
|
||||||
|
contam_all = sorted([r['max_contamination_ms'] for r in all_results])
|
||||||
|
pos_spikes = sorted([r['max_pos_spike'] for r in all_results])
|
||||||
|
aim_spikes = sorted([r['max_aim_spike'] for r in all_results])
|
||||||
|
p95_idx = int(len(contam_all) * 0.95)
|
||||||
|
print(f" Contamination: mean={sum(contam_all)/len(contam_all):.1f}ms, "
|
||||||
|
f"median={contam_all[len(contam_all)//2]:.1f}ms, "
|
||||||
|
f"p95={contam_all[min(p95_idx, len(contam_all)-1)]:.1f}ms, "
|
||||||
|
f"max={contam_all[-1]:.1f}ms")
|
||||||
|
print(f" Pos spike: mean={sum(pos_spikes)/len(pos_spikes):.1f}cm/s, "
|
||||||
|
f"max={pos_spikes[-1]:.1f}cm/s")
|
||||||
|
print(f" Aim spike: mean={sum(aim_spikes)/len(aim_spikes):.1f}deg/s, "
|
||||||
|
f"max={aim_spikes[-1]:.1f}deg/s")
|
||||||
|
|
||||||
|
print(f"\nSUMMARY - STABLE SHOTS ONLY ({len(stable_results)} shots, baseline speed < 20cm/s)")
|
||||||
|
if stable_results:
|
||||||
|
contam_stable = sorted([r['max_contamination_ms'] for r in stable_results])
|
||||||
|
pos_spikes_s = sorted([r['max_pos_spike'] for r in stable_results])
|
||||||
|
aim_spikes_s = sorted([r['max_aim_spike'] for r in stable_results])
|
||||||
|
p95_idx = int(len(contam_stable) * 0.95)
|
||||||
|
print(f" Contamination: mean={sum(contam_stable)/len(contam_stable):.1f}ms, "
|
||||||
|
f"median={contam_stable[len(contam_stable)//2]:.1f}ms, "
|
||||||
|
f"p95={contam_stable[min(p95_idx, len(contam_stable)-1)]:.1f}ms, "
|
||||||
|
f"max={contam_stable[-1]:.1f}ms")
|
||||||
|
print(f" Pos spike: mean={sum(pos_spikes_s)/len(pos_spikes_s):.1f}cm/s, "
|
||||||
|
f"max={pos_spikes_s[-1]:.1f}cm/s")
|
||||||
|
print(f" Aim spike: mean={sum(aim_spikes_s)/len(aim_spikes_s):.1f}deg/s, "
|
||||||
|
f"max={aim_spikes_s[-1]:.1f}deg/s")
|
||||||
|
|
||||||
|
recommended = math.ceil(contam_stable[min(p95_idx, len(contam_stable)-1)] / 5.0) * 5.0
|
||||||
|
print(f"\n >>> RECOMMENDED DiscardTime (from stable shots P95): {recommended:.0f}ms <<<")
|
||||||
|
else:
|
||||||
|
print(" No stable shots found! Make sure you stay still before firing.")
|
||||||
|
print(" Shots where baseline speed > 20cm/s are excluded as 'not stable'.")
|
||||||
|
|
||||||
|
# Plotting
|
||||||
|
if args.plot:
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# Plot each shot individually
|
||||||
|
n_shots = len(results)
|
||||||
|
cols = min(4, n_shots)
|
||||||
|
rows = math.ceil(n_shots / cols)
|
||||||
|
fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
|
||||||
|
fig.suptitle(f'Per-Shot Speed Profile ({os.path.basename(args.csv_file)})', fontsize=14)
|
||||||
|
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
r, c = divmod(idx, cols)
|
||||||
|
ax = axes[r][c]
|
||||||
|
|
||||||
|
# Before shot
|
||||||
|
if result['before']:
|
||||||
|
t_before = [b[0] for b in result['before']]
|
||||||
|
s_before = [b[1] for b in result['before']]
|
||||||
|
ax.plot(t_before, s_before, 'b-', linewidth=1, label='Before')
|
||||||
|
|
||||||
|
# After shot
|
||||||
|
if result['after']:
|
||||||
|
t_after = [a[0] for a in result['after']]
|
||||||
|
s_after = [a[1] for a in result['after']]
|
||||||
|
ax.plot(t_after, s_after, 'r-', linewidth=1, label='After')
|
||||||
|
|
||||||
|
# Shot line
|
||||||
|
ax.axvline(x=0, color='red', linestyle='--', alpha=0.7, label='Shot')
|
||||||
|
|
||||||
|
# Threshold
|
||||||
|
ax.axhline(y=result['pos_threshold'], color='orange', linestyle=':', alpha=0.5, label='Threshold')
|
||||||
|
|
||||||
|
# Contamination zone
|
||||||
|
if result['pos_contamination_ms'] > 0:
|
||||||
|
ax.axvspan(0, result['pos_contamination_ms'], alpha=0.15, color='red')
|
||||||
|
|
||||||
|
stable_str = "STABLE" if result['is_stable'] else "MOVING"
|
||||||
|
ax.set_title(f"Shot {idx+1} ({stable_str}) - {result['max_contamination_ms']:.0f}ms",
|
||||||
|
fontsize=9, color='green' if result['is_stable'] else 'orange')
|
||||||
|
ax.set_xlabel('Time from shot (ms)', fontsize=8)
|
||||||
|
ax.set_ylabel('Pos Speed (cm/s)', fontsize=8)
|
||||||
|
ax.tick_params(labelsize=7)
|
||||||
|
if idx == 0:
|
||||||
|
ax.legend(fontsize=6)
|
||||||
|
|
||||||
|
# Hide unused subplots
|
||||||
|
for idx in range(n_shots, rows * cols):
|
||||||
|
r, c = divmod(idx, cols)
|
||||||
|
axes[r][c].set_visible(False)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plot_path = args.csv_file.replace('.csv', '_shots.png')
|
||||||
|
plt.savefig(plot_path, dpi=150)
|
||||||
|
print(f"\nPlot saved: {plot_path}")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Also plot aim speed
|
||||||
|
fig2, axes2 = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
|
||||||
|
fig2.suptitle(f'Per-Shot Aim Angular Speed ({os.path.basename(args.csv_file)})', fontsize=14)
|
||||||
|
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
r, c = divmod(idx, cols)
|
||||||
|
ax = axes2[r][c]
|
||||||
|
|
||||||
|
if result['before']:
|
||||||
|
t_before = [b[0] for b in result['before']]
|
||||||
|
a_before = [b[2] for b in result['before']] # aim_speed
|
||||||
|
ax.plot(t_before, a_before, 'b-', linewidth=1)
|
||||||
|
|
||||||
|
if result['after']:
|
||||||
|
t_after = [a[0] for a in result['after']]
|
||||||
|
a_after = [a[2] for a in result['after']] # aim_speed
|
||||||
|
ax.plot(t_after, a_after, 'r-', linewidth=1)
|
||||||
|
|
||||||
|
ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
|
||||||
|
ax.axhline(y=result['aim_threshold'], color='orange', linestyle=':', alpha=0.5)
|
||||||
|
|
||||||
|
if result['aim_contamination_ms'] > 0:
|
||||||
|
ax.axvspan(0, result['aim_contamination_ms'], alpha=0.15, color='red')
|
||||||
|
|
||||||
|
stable_str = "STABLE" if result['is_stable'] else "MOVING"
|
||||||
|
ax.set_title(f"Shot {idx+1} ({stable_str}) - Aim {result['aim_contamination_ms']:.0f}ms",
|
||||||
|
fontsize=9, color='green' if result['is_stable'] else 'orange')
|
||||||
|
ax.set_xlabel('Time from shot (ms)', fontsize=8)
|
||||||
|
ax.set_ylabel('Aim Speed (deg/s)', fontsize=8)
|
||||||
|
ax.tick_params(labelsize=7)
|
||||||
|
|
||||||
|
for idx in range(n_shots, rows * cols):
|
||||||
|
r, c = divmod(idx, cols)
|
||||||
|
axes2[r][c].set_visible(False)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plot_path2 = args.csv_file.replace('.csv', '_shots_aim.png')
|
||||||
|
plt.savefig(plot_path2, dpi=150)
|
||||||
|
print(f"Plot saved: {plot_path2}")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("\nmatplotlib not installed. Install with: pip install matplotlib")
|
||||||
|
|
||||||
|
print("\nDone.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -206,7 +206,7 @@ void UEBBarrel::TickComponent(float DeltaTime, ELevelTick TickType, FActorCompon
|
|||||||
if (CSVFileHandle)
|
if (CSVFileHandle)
|
||||||
{
|
{
|
||||||
bCSVFileOpen = true;
|
bCSVFileOpen = true;
|
||||||
FString Header = TEXT("Timestamp,RealPosX,RealPosY,RealPosZ,RealAimX,RealAimY,RealAimZ,PredPosX,PredPosY,PredPosZ,PredAimX,PredAimY,PredAimZ,SafeCount,BufferCount,ExtrapolationTime\n");
|
FString Header = TEXT("Timestamp,RealPosX,RealPosY,RealPosZ,RealAimX,RealAimY,RealAimZ,PredPosX,PredPosY,PredPosZ,PredAimX,PredAimY,PredAimZ,SafeCount,BufferCount,ExtrapolationTime,ShotFired\n");
|
||||||
auto HeaderUtf8 = StringCast<ANSICHAR>(*Header);
|
auto HeaderUtf8 = StringCast<ANSICHAR>(*Header);
|
||||||
CSVFileHandle->Write((const uint8*)HeaderUtf8.Get(), HeaderUtf8.Length());
|
CSVFileHandle->Write((const uint8*)HeaderUtf8.Get(), HeaderUtf8.Length());
|
||||||
|
|
||||||
@ -231,15 +231,17 @@ void UEBBarrel::TickComponent(float DeltaTime, ELevelTick TickType, FActorCompon
|
|||||||
}
|
}
|
||||||
double ExtrapTime = (SafeN > 0) ? (GetWorld()->GetTimeSeconds() - TransformHistory[SafeN - 1].Timestamp) : 0.0;
|
double ExtrapTime = (SafeN > 0) ? (GetWorld()->GetTimeSeconds() - TransformHistory[SafeN - 1].Timestamp) : 0.0;
|
||||||
|
|
||||||
FString Line = FString::Printf(TEXT("%.6f,%.4f,%.4f,%.4f,%.6f,%.6f,%.6f,%.4f,%.4f,%.4f,%.6f,%.6f,%.6f,%d,%d,%.6f\n"),
|
FString Line = FString::Printf(TEXT("%.6f,%.4f,%.4f,%.4f,%.6f,%.6f,%.6f,%.4f,%.4f,%.4f,%.6f,%.6f,%.6f,%d,%d,%.6f,%d\n"),
|
||||||
GetWorld()->GetTimeSeconds(),
|
GetWorld()->GetTimeSeconds(),
|
||||||
RealPos.X, RealPos.Y, RealPos.Z,
|
RealPos.X, RealPos.Y, RealPos.Z,
|
||||||
RealAim.X, RealAim.Y, RealAim.Z,
|
RealAim.X, RealAim.Y, RealAim.Z,
|
||||||
Location.X, Location.Y, Location.Z,
|
Location.X, Location.Y, Location.Z,
|
||||||
Aim.X, Aim.Y, Aim.Z,
|
Aim.X, Aim.Y, Aim.Z,
|
||||||
SafeN, TransformHistory.Num(), ExtrapTime);
|
SafeN, TransformHistory.Num(), ExtrapTime,
|
||||||
|
bShotFiredThisFrame ? 1 : 0);
|
||||||
auto LineUtf8 = StringCast<ANSICHAR>(*Line);
|
auto LineUtf8 = StringCast<ANSICHAR>(*Line);
|
||||||
CSVFileHandle->Write((const uint8*)LineUtf8.Get(), LineUtf8.Length());
|
CSVFileHandle->Write((const uint8*)LineUtf8.Get(), LineUtf8.Length());
|
||||||
|
bShotFiredThisFrame = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (bCSVFileOpen)
|
else if (bCSVFileOpen)
|
||||||
@ -605,6 +607,8 @@ void UEBBarrel::SpawnBullet(AActor* Owner, FVector InLocation, FVector InAim, in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bShotFiredThisFrame = true;
|
||||||
|
|
||||||
if (ReplicateShotFiredEvents) {
|
if (ReplicateShotFiredEvents) {
|
||||||
ShotFiredMulticast();
|
ShotFiredMulticast();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -116,6 +116,7 @@ public:
|
|||||||
bool bCSVFileOpen = false;
|
bool bCSVFileOpen = false;
|
||||||
FString CSVFilePath;
|
FString CSVFilePath;
|
||||||
IFileHandle* CSVFileHandle = nullptr;
|
IFileHandle* CSVFileHandle = nullptr;
|
||||||
|
bool bShotFiredThisFrame = false;
|
||||||
|
|
||||||
// Debug HUD state (written by const prediction functions, read by TickComponent)
|
// Debug HUD state (written by const prediction functions, read by TickComponent)
|
||||||
mutable float DbgPosRatio = 0.0f;
|
mutable float DbgPosRatio = 0.0f;
|
||||||
@ -147,10 +148,10 @@ public:
|
|||||||
EAntiRecoilMode AntiRecoilMode = EAntiRecoilMode::ARM_AdaptiveExtrapolation;
|
EAntiRecoilMode AntiRecoilMode = EAntiRecoilMode::ARM_AdaptiveExtrapolation;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Total time window (ms) of tracker history to keep. Determines how far back in time samples are stored. Must be greater than DiscardTime. Example: 200ms at 60fps stores ~12 samples.", ClampMin = "5"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Total time window (ms) of tracker history to keep. Determines how far back in time samples are stored. Must be greater than DiscardTime. Example: 200ms at 60fps stores ~12 samples.", ClampMin = "5"))
|
||||||
float AntiRecoilBufferTimeMs = 300.0f;
|
float AntiRecoilBufferTimeMs = 200.0f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Time window (ms) of most recent samples to exclude as potentially contaminated by IMU recoil shock. The prediction algorithms only use samples older than this. Increase if the shock lasts longer. Safe window = BufferTime - DiscardTime.", ClampMin = "0.0"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Time window (ms) of most recent samples to exclude as potentially contaminated by IMU recoil shock. The prediction algorithms only use samples older than this. Increase if the shock lasts longer. Safe window = BufferTime - DiscardTime.", ClampMin = "0.0"))
|
||||||
float AntiRecoilDiscardTimeMs = 40.0f;
|
float AntiRecoilDiscardTimeMs = 30.0f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Controls how the weight curve grows across safe samples in regression modes. 1.0 = linear growth, >1.0 = recent samples weighted much more heavily (convex curve), <1.0 = more uniform weighting (concave curve), 0.0 = all samples weighted equally. Formula: weight = pow(sampleIndex+1, exponent).", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_WeightedRegression || AntiRecoilMode == EAntiRecoilMode::ARM_WeightedLinearRegression", ClampMin = "0.0", ClampMax = "5.0"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Controls how the weight curve grows across safe samples in regression modes. 1.0 = linear growth, >1.0 = recent samples weighted much more heavily (convex curve), <1.0 = more uniform weighting (concave curve), 0.0 = all samples weighted equally. Formula: weight = pow(sampleIndex+1, exponent).", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_WeightedRegression || AntiRecoilMode == EAntiRecoilMode::ARM_WeightedLinearRegression", ClampMin = "0.0", ClampMax = "5.0"))
|
||||||
float RegressionWeightExponent = 2.0f;
|
float RegressionWeightExponent = 2.0f;
|
||||||
@ -162,16 +163,16 @@ public:
|
|||||||
float KalmanMeasurementNoise = 0.01f;
|
float KalmanMeasurementNoise = 0.01f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Power curve exponent for deceleration detection. Controls how aggressively slowing down reduces extrapolation. confidence = (remappedRatio)^sensitivity. 1.0 = linear (gentle). 2.0 = quadratic (aggressive). 0.5 = square root (very gentle). During steady movement, ratio is ~1 so confidence is always 1 regardless of this value.", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_AdaptiveExtrapolation", ClampMin = "0.1", ClampMax = "5.0"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Power curve exponent for deceleration detection. Controls how aggressively slowing down reduces extrapolation. confidence = (remappedRatio)^sensitivity. 1.0 = linear (gentle). 2.0 = quadratic (aggressive). 0.5 = square root (very gentle). During steady movement, ratio is ~1 so confidence is always 1 regardless of this value.", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_AdaptiveExtrapolation", ClampMin = "0.1", ClampMax = "5.0"))
|
||||||
float AdaptiveSensitivity = 1.5f;
|
float AdaptiveSensitivity = 3.0f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Dead zone for deceleration detection. Speed ratios (recent/avg) above this value are treated as 1.0 (no correction). Only ratios below trigger extrapolation reduction. Higher = more tolerant to natural speed fluctuations (less false positives). Lower = more sensitive to deceleration. 0.8 = ignore normal jitter, only react to real braking.", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_AdaptiveExtrapolation", ClampMin = "0.0", ClampMax = "0.95"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Dead zone for deceleration detection. Speed ratios (recent/avg) above this value are treated as 1.0 (no correction). Only ratios below trigger extrapolation reduction. Higher = more tolerant to natural speed fluctuations (less false positives). Lower = more sensitive to deceleration. 0.8 = ignore normal jitter, only react to real braking.", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_AdaptiveExtrapolation", ClampMin = "0.0", ClampMax = "0.95"))
|
||||||
float AdaptiveDeadZone = 0.8f;
|
float AdaptiveDeadZone = 0.95f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Minimum average speed (cm/s) required for deceleration detection. Below this threshold, speed ratios are unreliable due to noise, so confidence stays at 1.0 (full extrapolation). Prevents false deceleration detection during slow/small movements. 0 = disabled.", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_AdaptiveExtrapolation", ClampMin = "0.0", ClampMax = "200.0"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Minimum average speed (cm/s) required for deceleration detection. Below this threshold, speed ratios are unreliable due to noise, so confidence stays at 1.0 (full extrapolation). Prevents false deceleration detection during slow/small movements. 0 = disabled.", EditCondition = "AntiRecoilMode == EAntiRecoilMode::ARM_AdaptiveExtrapolation", ClampMin = "0.0", ClampMax = "200.0"))
|
||||||
float AdaptiveMinSpeed = 30.0f;
|
float AdaptiveMinSpeed = 0.0f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Velocity damping during extrapolation. 0 = disabled (default). Higher values cause extrapolated velocity to decay exponentially toward zero over the discard window. Reduces overshoot on fast draw-aim-fire sequences where the user stops moving before firing. Applies to all prediction modes except Buffer. Typical range: 5-15.", ClampMin = "0.0", ClampMax = "50.0"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AntiRecoil", meta = (ToolTip = "Velocity damping during extrapolation. 0 = disabled (default). Higher values cause extrapolated velocity to decay exponentially toward zero over the discard window. Reduces overshoot on fast draw-aim-fire sequences where the user stops moving before firing. Applies to all prediction modes except Buffer. Typical range: 5-15.", ClampMin = "0.0", ClampMax = "50.0"))
|
||||||
float ExtrapolationDamping = 8.0f;
|
float ExtrapolationDamping = 5.0f;
|
||||||
|
|
||||||
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AimStabilization", meta = (ToolTip = "Angular dead zone in degrees. If the aim direction changes by less than this angle since the last stable aim, the change is ignored (aim stays locked). Eliminates micro-jitter from VR tracker vibrations. 0 = disabled. Typical: 0.1 to 0.5 degrees.", ClampMin = "0.0", ClampMax = "5.0"))
|
UPROPERTY(BlueprintReadWrite, EditAnywhere, Category = "AimStabilization", meta = (ToolTip = "Angular dead zone in degrees. If the aim direction changes by less than this angle since the last stable aim, the change is ignored (aim stays locked). Eliminates micro-jitter from VR tracker vibrations. 0 = disabled. Typical: 0.1 to 0.5 degrees.", ClampMin = "0.0", ClampMax = "5.0"))
|
||||||
float AimDeadZoneDegrees = 0.0f;
|
float AimDeadZoneDegrees = 0.0f;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user