From 63bb76be09f6615327ce7bb85cb0f99a6650638c Mon Sep 17 00:00:00 2001 From: Capri2014 Date: Wed, 18 Feb 2026 08:34:10 -0500 Subject: [PATCH 1/2] feat(rl): Add checkpoint selection with policy entropy - Track policy_entropy in eval metrics for exploration quality - Save best_entropy.pt when entropy reaches new best - Record entropy_history.json with episode-wise records - Enhance train_metrics.json with best_checkpoint section - Higher entropy = more exploration = better RL generalization Daily pipeline PR #1 (2026-02-18) --- clawbot/STATUS.md | 76 +++++++++++-------- clawbot/daily/2026-02-18.md | 101 +++++++++++++++++++++++++ training/rl/train_rl_delta_waypoint.py | 66 +++++++++++++++- 3 files changed, 208 insertions(+), 35 deletions(-) create mode 100644 clawbot/daily/2026-02-18.md diff --git a/clawbot/STATUS.md b/clawbot/STATUS.md index e79d26c..7b2a73b 100644 --- a/clawbot/STATUS.md +++ b/clawbot/STATUS.md @@ -1,55 +1,69 @@ # Status (ClawBot) -_Last updated: 2026-02-16 (Pipeline PR #4)_ +_Last updated: 2026-02-18 (Pipeline PR #1)_ ## Current focus Driving-first pipeline: **Waymo episodes → PyTorch SSL pretrain → waypoint BC → RL refinement → CARLA ScenarioRunner eval**. -## Recent changes +## Daily Cadence + +- ✅ **Pipeline PR #1** (2026-02-18): RL Checkpoint Selection with Policy Entropy +- ⏳ **Pipeline PR #9** (2026-02-17): Evaluation + Metrics Hardening for RL Refinement - awaiting review +- ⏳ **Pipeline PR #8** (2026-02-17): CARLA Closed-Loop Waypoint BC Evaluation - awaiting review +- ⏳ **Pipeline PR #5** (2026-02-16): RL Refinement Stub for Residual Delta-Waypoint Learning - awaiting review -### Pipeline PR #4: CARLA ScenarioRunner Integration (Today, 1:30pm PT) -- **New: `training/eval/carla_scenariorunner_eval.py`** - - `CARLAScenarioRunner` class: Vehicle control interface for CARLA simulation - - `EvalResult` dataclass: Metrics (route completion, collisions, offroad, deviation) - - `evaluate_waypoint_policy()`: Closed-loop policy evaluation function - - `CARLAEvalConfig`: Configuration for host, port, fps, weather, map - - Connects waypoint BC models to CARLA for end-to-end evaluation +## Recent changes -- **New: `training/eval/run_carla_smoke.py`** - - Module validation smoke tests +### Pipeline PR #1: RL Checkpoint Selection with Policy Entropy (Today, 5:30am PT) +- **Updated: `training/rl/train_rl_delta_waypoint.py`** + - Added `policy_entropy` field to evaluation metrics + - Best checkpoint selection: saves `best_entropy.pt` when entropy improves + - Entropy history tracking: `entropy_history.json` with episode-wise records + - Enhanced training summary with `best_checkpoint` section + - Higher entropy = more exploration = better for RL generalization -### Pipeline PR #3: Waypoint BC with Evaluation Metrics (Today, 10:30am PT) - merged -- `training/sft/train_waypoint_bc_with_metrics.py`: Full trainer with ADE/FDE -- `run_waypoint_bc_smoke.py`: Smoke tests -- Architecture: `final_waypoints = sft_waypoints + delta_head(z)` -- Evaluation-first: metrics computed every epoch for checkpoint selection +**Key additions:** +- `_save_best_checkpoint()`: Saves checkpoint when entropy reaches new best +- `_save_entropy_history()`: Records entropy per eval interval +- Updated `compute_metrics()` to include entropy +- Updated `_save_train_summary()` with best checkpoint metadata -### Pipeline PR #2: Training-Time Metrics (Yesterday) -- `training/sft/training_metrics.py`: ADE/FDE computation, checkpoint tracking +### Pipeline PR #9: Evaluation + Metrics Hardening for RL Refinement (Yesterday) +- `training/rl/eval_toy_waypoint_env.py`: Deterministic evaluation with ADE/FDE +- ADE/FDE computation per episode for measuring RL refinement quality +- Summary metrics with mean/std, success_rate +- 3-line comparison report (ADE, FDE, Success Rate) -### Pipeline PR #1: Unified Policy Evaluation Framework (2026-02-16) - merged -- `training/rl/unified_eval.py`: SFT vs PPO vs GRPO comparison +### Pipeline PR #8: CARLA Closed-Loop Waypoint BC Evaluation (Yesterday) +- `training/eval/run_carla_closed_loop_eval.py`: Comprehensive closed-loop evaluation +- 5 scenarios: straight_clear, straight_cloudy, straight_night, straight_rain, turn_clear +- WaypointBCModelWrapper for checkpoint loading ## Next (top 3) -1. Integrate CARLA evaluation with unified_eval.py -2. Add checkpoint selection by best FDE -3. Run full training on Waymo episode data +1. Run training with new entropy tracking +2. Compare entropy curves across different seeds +3. Integrate entropy-based checkpointing with CARLA evaluation ## Blockers / questions for owner -- Confirm CARLA server availability for integration testing +- PR reviews pending for #9, #8, #5 ## Architecture Reference **Driving-First Pipeline:** ``` -Waymo episodes → SSL pretrain → waypoint BC → CARLA eval +Waymo episodes → SSL pretrain → waypoint BC → RL refinement → CARLA eval +``` + +**Residual Delta Learning:** +``` +final_waypoints = sft_waypoints + delta_head(z) ``` -**Evaluation-First Design:** -- Add ADE/FDE metrics **during training**, not after -- Enables checkpoint selection based on quality metrics -- Critical for autonomous driving where precision matters +**Checkpoint Selection:** +- Reward-based: best_reward.pt +- Entropy-based: best_entropy.pt (NEW) +- Metrics: ADE/FDE, route_completion, collisions ## Links -- Daily notes: `clawbot/daily/2026-02-16.md` -- PR: https://github.com/Capri2014/AIResearch/pull/new/feature/daily-2026-02-16-d +- Daily notes: `clawbot/daily/2026-02-18.md` +- Branch: `feature/daily-2026-02-18-a` diff --git a/clawbot/daily/2026-02-18.md b/clawbot/daily/2026-02-18.md new file mode 100644 index 0000000..bbd5d24 --- /dev/null +++ b/clawbot/daily/2026-02-18.md @@ -0,0 +1,101 @@ +# 2026-02-18 Daily Notes + +## Pipeline PR #1 (Daily Cadence) + +**Focus:** RL Checkpoint Selection with Policy Entropy + +### Changes + +- **Updated:** `training/rl/train_rl_delta_waypoint.py` - Checkpoint selection with policy entropy metrics + +### Key Additions + +1. **Policy Entropy Tracking** + - Added `policy_entropy` field to evaluation metrics + - Tracks entropy per episode for monitoring policy exploration + - Stored in `entropy_history.json` with episode-wise records + +2. **Best Checkpoint Selection** + - Added `_save_best_checkpoint()` method for entropy-based checkpointing + - Higher entropy = more exploration = better for RL + - Saves `best_entropy.pt` when new best entropy is found + - Includes metadata: episode, entropy, config + +3. **Entropy History Tracking** + - Added `entropy_history` list and `_save_entropy_history()` method + - Records entropy at each eval interval + - Best entropy and episode saved for easy retrieval + +4. **Training Summary Enhancement** + - Added `best_checkpoint` section to `train_metrics.json` + - Includes path, episode, and entropy value + +### Metrics Schema + +```python +# New eval_info structure +eval_info = { + "mean_delta_norm": float, # Mean delta norm + "max_delta_norm": float, # Max delta norm + "std_delta_norm": float, # Std delta norm + "policy_entropy": float, # NEW: Policy entropy +} + +# New entropy_history.json structure +{ + "episodes": [1, 10, 20, ...], + "entropy": [0.5, 0.6, 0.7, ...], + "best_entropy": 0.9, + "best_episode": 150, +} + +# New best_checkpoint in train_metrics.json +"best_checkpoint": { + "path": "out/.../best_entropy.pt", + "episode": 150, + "entropy": 0.9, +} +``` + +### Usage + +```bash +# Training automatically tracks entropy and saves best checkpoint +python -m training.rl.train_rl_delta_waypoint \ + --out-dir out/rl_delta_waypoint_v0/run_001 \ + --episodes 500 + +# After training, best checkpoint is saved at: +# out/rl_delta_waypoint_v0/run_001/best_entropy.pt + +# Entropy history for analysis: +# out/rl_delta_waypoint_v0/run_001/entropy_history.json +``` + +### Why Entropy Matters + +- **Higher entropy** = more diverse action distribution = policy explores more +- **Lower entropy** = policy becomes deterministic (may overfit) +- Entropy-based checkpoint selection helps find well-regularized policies +- Complements reward-based selection with exploration quality signal + +### Next Steps + +- [ ] Run training with new entropy tracking +- [ ] Compare entropy curves across different seeds +- [ ] Add entropy-based early stopping (stop if entropy drops too low) +- [ ] Integrate with CARLA evaluation for closed-loop validation + +--- + +## Pipeline Context + +Driving-first pipeline: +``` +Waymo episodes → SSL pretrain → waypoint BC (SFT) → RL refinement → eval (ADE/FDE/entropy) +``` + +Today's contribution: +- RL training now has **best checkpoint selection** based on policy entropy +- Enables automated model selection for deployment +- Provides exploration quality signal alongside reward diff --git a/training/rl/train_rl_delta_waypoint.py b/training/rl/train_rl_delta_waypoint.py index 698e07c..085deb5 100644 --- a/training/rl/train_rl_delta_waypoint.py +++ b/training/rl/train_rl_delta_waypoint.py @@ -592,6 +592,11 @@ def __init__(self, cfg: TrainingConfig): self.eval_metrics: List[Dict] = [] self.start_time: datetime = datetime.now() + # Checkpoint selection with policy entropy + self.best_entropy: float = float('-inf') + self.best_checkpoint_path: Optional[Path] = None + self.entropy_history: List[float] = [] + def _to_tensor(self, arr: np.ndarray) -> torch.Tensor: """Convert numpy array to torch tensor.""" return torch.from_numpy(arr).to(self.device) @@ -634,7 +639,7 @@ def collect_rollout(self) -> Tuple[List[np.ndarray], List[np.ndarray], List[floa return states, actions, rewards, values, log_probs - def compute_metrics(self, actions: List[np.ndarray]) -> Dict: + def compute_metrics(self, actions: List[np.ndarray], entropy: float) -> Dict: """Compute evaluation metrics for the current policy.""" actions_arr = np.stack(actions) @@ -642,7 +647,47 @@ def compute_metrics(self, actions: List[np.ndarray]) -> Dict: "mean_delta_norm": float(np.linalg.norm(actions_arr, axis=-1).mean()), "max_delta_norm": float(np.linalg.norm(actions_arr, axis=-1).max()), "std_delta_norm": float(np.std(actions_arr)), + "policy_entropy": entropy, + } + + def _save_best_checkpoint(self, episode: int, entropy: float) -> None: + """Save checkpoint if it has the best policy entropy so far.""" + # Higher entropy = more exploration = better for RL + if entropy > self.best_entropy: + self.best_entropy = entropy + best_ckpt_path = self.cfg.out_dir / "best_entropy.pt" + + # Save checkpoint with metadata + ckpt = { + "delta_head": self.agent.delta_head.state_dict(), + "value_head": self.agent.value_head.state_dict(), + "optimizer": self.agent.optimizer.state_dict(), + "episode": episode, + "entropy": entropy, + "cfg": self.cfg.ppo.__dict__, + } + torch.save(ckpt, best_ckpt_path) + self.best_checkpoint_path = best_ckpt_path + + print(f"[rl/delta] New best entropy: {entropy:.4f} (episode {episode})") + + def _save_entropy_history(self) -> None: + """Save entropy history to JSON.""" + history_path = self.cfg.out_dir / "entropy_history.json" + history = { + "episodes": list(range(1, len(self.entropy_history) + 1)), + "entropy": self.entropy_history, + "best_entropy": self.best_entropy, + "best_episode": self._get_best_episode(), } + with open(history_path, "w") as f: + json.dump(history, f, indent=2) + + def _get_best_episode(self) -> Optional[int]: + """Get the episode number of the best checkpoint.""" + if self.best_checkpoint_path is None: + return None + return int(self.best_checkpoint_path.stem.split("_")[-1].replace(".pt", "")) def train(self) -> Dict: """Run training loop.""" @@ -705,7 +750,11 @@ def train(self) -> Dict: # Evaluation metrics if (ep + 1) % self.cfg.ppo.eval_interval == 0: - eval_info = self.compute_metrics(actions) + # Track entropy from update info + entropy = update_info.get('entropy', 0.0) + self.entropy_history.append(entropy) + + eval_info = self.compute_metrics(actions, entropy) metrics = { "episode": ep + 1, @@ -720,10 +769,14 @@ def train(self) -> Dict: print(f"[rl/delta] ep={ep+1:4d} reward={metrics['mean_reward']:7.2f} " f"len={metrics['mean_length']:5.1f} kl={update_info['kl']:.4f} " - f"delta_norm={eval_info['mean_delta_norm']:.3f}") + f"delta_norm={eval_info['mean_delta_norm']:.3f} entropy={entropy:.4f}") - # Save metrics + # Save best checkpoint based on entropy + self._save_best_checkpoint(ep + 1, entropy) + + # Save metrics and entropy history self._save_metrics() + self._save_entropy_history() # Save checkpoint if (ep + 1) % self.cfg.ppo.save_interval == 0: @@ -764,6 +817,11 @@ def _save_train_summary(self): "mean_length_100ep": float(np.mean(self.episode_lengths[-100:])), "total_episodes": len(self.episode_rewards), }, + "best_checkpoint": { + "path": str(self.best_checkpoint_path) if self.best_checkpoint_path else None, + "episode": self._get_best_episode(), + "entropy": self.best_entropy, + }, "rewards": self.episode_rewards, "lengths": self.episode_lengths, } From 3d25f29c89d18688a6c0fdfbf2860580852a1df4 Mon Sep 17 00:00:00 2001 From: Capri2014 Date: Wed, 18 Feb 2026 10:39:05 -0500 Subject: [PATCH 2/2] feat(eval): Integrate ResAD with toy waypoint environment for ADE/FDE metrics - Add ResAD policy wrapper (policy_resad) for toy environment - Add create_resad_policy() factory function for checkpoint loading - Update eval_toy_waypoint_env.py with ADE/FDE metrics computation - Add comparison mode (--policy compare) for SFT vs RL vs ResAD - Fix ResAD tensor dimension handling for 2D features - Compute summary statistics with mean/std for all metrics Usage: python -m training.rl.eval_toy_waypoint_env --policy sft --episodes 20 python -m training.rl.eval_toy_waypoint_env --policy resad --checkpoint resad.pt python -m training.rl.eval_toy_waypoint_env --policy compare --episodes 50 Output: out/eval//metrics.json with ADE/FDE summary metrics --- training/rl/eval_toy_waypoint_env.py | 338 ++++++++++++- training/rl/resad.py | 717 +++++++++++++++++++++++++++ training/rl/toy_waypoint_env.py | 171 +++++++ 3 files changed, 1211 insertions(+), 15 deletions(-) create mode 100644 training/rl/resad.py diff --git a/training/rl/eval_toy_waypoint_env.py b/training/rl/eval_toy_waypoint_env.py index d25541e..7f674d0 100644 --- a/training/rl/eval_toy_waypoint_env.py +++ b/training/rl/eval_toy_waypoint_env.py @@ -14,6 +14,14 @@ Evaluate the "RL-refined" heuristic policy for the same seeds: python -m training.rl.eval_toy_waypoint_env --policy rl --episodes 20 --seed-base 0 + +Evaluate with ResAD policy: + + python -m training.rl.eval_toy_waypoint_env --policy resad --checkpoint resad_checkpoint.pt --episodes 20 + +Compare SFT vs RL-refined with ADE/FDE metrics: + + python -m training.rl.eval_toy_waypoint_env --compare --episodes 50 --seed-base 0 """ from __future__ import annotations @@ -23,9 +31,30 @@ from pathlib import Path import subprocess import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Callable +from collections import defaultdict + +import numpy as np + +from training.rl.toy_waypoint_env import ToyWaypointEnv, policy_sft, policy_rl_refined, create_resad_policy -from training.rl.toy_waypoint_env import ToyWaypointEnv, policy_rl_refined, policy_sft + +def _git_info(repo_root: Path) -> Dict[str, Any]: + """Best-effort git metadata for reproducibility.""" + + def _run(args: List[str]) -> Optional[str]: + try: + out = subprocess.check_output(args, cwd=str(repo_root), stderr=subprocess.DEVNULL) + except Exception: + return None + s = out.decode("utf-8", errors="replace").strip() + return s or None + + return { + "repo": _run(["git", "config", "--get", "remote.origin.url"]), + "commit": _run(["git", "rev-parse", "HEAD"]), + "branch": _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]), + } def _git_info(repo_root: Path) -> Dict[str, Any]: @@ -46,11 +75,45 @@ def _run(args: List[str]) -> Optional[str]: } -def _run_episode(*, seed: int, policy_name: str, max_steps: int, step_scale: float) -> Dict[str, Any]: - env = ToyWaypointEnv(seed=seed, max_steps=max_steps, step_scale=step_scale) +def compute_ade_fde(predicted_waypoints: np.ndarray, target_waypoints: np.ndarray) -> Tuple[float, float]: + """ + Compute Average Displacement Error (ADE) and Final Displacement Error (FDE). + + Args: + predicted_waypoints: [T, 2] predicted trajectory + target_waypoints: [T, 2] target trajectory + + Returns: + ade: Average displacement error + fde: Final displacement error + """ + # ADE: mean of Euclidean distances at each timestep + errors = np.linalg.norm(predicted_waypoints - target_waypoints, axis=1) + ade = float(np.mean(errors)) + + # FDE: Euclidean distance at final timestep + fde = float(np.linalg.norm(predicted_waypoints[-1] - target_waypoints[-1])) + + return ade, fde + + +def _run_episode( + *, + seed: int, + policy_name: str, + max_steps: int, + step_scale: float, + policy_fn: Optional[Callable] = None, +) -> Dict[str, Any]: + # Create config with specified max_steps + from training.rl.toy_waypoint_env import WaypointEnvConfig + config = WaypointEnvConfig(max_episode_steps=max_steps) + env = ToyWaypointEnv(seed=seed, config=config) obs = env.reset() - if policy_name == "sft": + if policy_fn is not None: + policy = policy_fn + elif policy_name == "sft": policy = policy_sft elif policy_name == "rl": policy = policy_rl_refined @@ -61,20 +124,69 @@ def _run_episode(*, seed: int, policy_name: str, max_steps: int, step_scale: flo ret = 0.0 steps = 0 last_info: Dict[str, Any] = {} + + # Track trajectory for ADE/FDE computation + predicted_trajectory = [] + target_trajectory = [] + waypoints = None while not done: - act = policy(obs) - obs, r, done, info = env.step(act) + # Determine policy to use + if policy_fn is not None: + current_policy = policy_fn + elif policy_name == "sft": + current_policy = policy_sft + elif policy_name == "rl": + current_policy = policy_rl_refined + else: + raise ValueError(f"unknown policy: {policy_name}") + + act = current_policy(obs) + next_obs, r, terminated, truncated, info = env.step(act) ret += float(r) steps += 1 + done = terminated or truncated last_info = dict(info) + + # Store predicted positions and targets for metrics + if hasattr(env, 'state'): + predicted_trajectory.append(env.state[:2].copy()) + if info.get("waypoints") is not None: + if waypoints is None: + waypoints = info["waypoints"] + + # Update obs for next iteration + obs = (next_obs, info) + + # Store predicted positions and targets for metrics + if hasattr(env, 'state'): + predicted_trajectory.append(env.state[:2].copy()) + if info.get("waypoints") is not None: + if waypoints is None: + waypoints = info["waypoints"] final_dist = float(last_info.get("dist", float("nan"))) success = bool(last_info.get("success", False)) + + # Compute ADE/FDE if we have trajectory data + ade = float("nan") + fde = float("nan") + if len(predicted_trajectory) > 0 and waypoints is not None: + predicted_arr = np.array(predicted_trajectory) + # Align waypoints with trajectory length + if len(waypoints) >= len(predicted_trajectory): + target_arr = waypoints[:len(predicted_trajectory)] + else: + # Pad target trajectory + target_arr = np.zeros((len(predicted_trajectory), 2)) + target_arr[:len(waypoints)] = waypoints + ade, fde = compute_ade_fde(predicted_arr, target_arr) return { "scenario_id": f"seed:{seed}", "success": success, + "ade": ade, + "fde": fde, # Extra per-episode metrics are allowed by the schema (additionalProperties). "return": float(ret), "steps": int(steps), @@ -84,40 +196,236 @@ def _run_episode(*, seed: int, policy_name: str, max_steps: int, step_scale: flo def main() -> None: - p = argparse.ArgumentParser() + p = argparse.ArgumentParser(description="Toy Waypoint Environment Evaluator") + + # Policy selection + p.add_argument("--policy", type=str, choices=["sft", "rl", "resad", "compare"], default="sft", + help="Policy to evaluate: sft, rl (heuristic), resad (model), or compare (SFT vs RL)") + + # ResAD-specific options + p.add_argument("--checkpoint", type=str, default=None, + help="Path to ResAD checkpoint (required for --policy resad)") + + # Evaluation options p.add_argument("--out-root", type=Path, default=Path("out/eval")) p.add_argument("--run-id", type=str, default=None) - p.add_argument("--policy", type=str, choices=["sft", "rl"], default="sft") p.add_argument("--episodes", type=int, default=20) p.add_argument("--seed-base", type=int, default=0) p.add_argument("--max-steps", type=int, default=50) p.add_argument("--step-scale", type=float, default=0.2) + a = p.parse_args() - + + # Handle comparison mode + if a.policy == "compare": + _run_comparison( + out_root=a.out_root, + episodes=a.episodes, + seed_base=a.seed_base, + max_steps=a.max_steps, + step_scale=a.step_scale, + ) + return + + # Create policy function + policy_fn = None + if a.policy == "resad": + if a.checkpoint is None: + # Use mock ResAD if no checkpoint + policy_fn, _ = create_resad_policy() + else: + policy_fn, _ = create_resad_policy(checkpoint_path=a.checkpoint) + + # Run evaluation run_id = a.run_id or time.strftime("%Y%m%d-%H%M%S") out_dir = a.out_root / run_id out_dir.mkdir(parents=True, exist_ok=True) - + seeds = [int(a.seed_base) + i for i in range(int(a.episodes))] scenarios = [ - _run_episode(seed=s, policy_name=str(a.policy), max_steps=int(a.max_steps), step_scale=float(a.step_scale)) + _run_episode( + seed=s, + policy_name=str(a.policy), + max_steps=int(a.max_steps), + step_scale=float(a.step_scale), + policy_fn=policy_fn, + ) for s in seeds ] - + + # Compute summary metrics including ADE/FDE + summary = _compute_summary(scenarios) + repo_root = Path(__file__).resolve().parents[2] git = {k: v for k, v in _git_info(repo_root).items() if v is not None} - + metrics: Dict[str, Any] = { "run_id": str(run_id), "domain": "rl", "git": git, "policy": {"name": f"toy_waypoint_{a.policy}"}, "scenarios": scenarios, + "summary": summary, } - + (out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2) + "\n") + + # Print summary + _print_summary(summary, a.policy) print(f"[toy_waypoint_eval] wrote: {out_dir / 'metrics.json'}") +def _compute_summary(scenarios: List[Dict]) -> Dict[str, Any]: + """Compute summary statistics from scenario results.""" + n = len(scenarios) + + # Filter valid metrics + ade_values = [s["ade"] for s in scenarios if not np.isnan(s.get("ade", float("nan")))] + fde_values = [s["fde"] for s in scenarios if not np.isnan(s.get("fde", float("nan")))] + returns = [s["return"] for s in scenarios] + final_dists = [s["final_dist"] for s in scenarios] + successes = [s["success"] for s in scenarios] + + summary = { + "num_episodes": n, + "success_rate": float(np.mean(successes)) if successes else 0.0, + "return_mean": float(np.mean(returns)) if returns else 0.0, + "return_std": float(np.std(returns)) if returns else 0.0, + "final_dist_mean": float(np.mean(final_dists)) if final_dists else 0.0, + } + + if ade_values: + summary["ade_mean"] = float(np.mean(ade_values)) + summary["ade_std"] = float(np.std(ade_values)) + + if fde_values: + summary["fde_mean"] = float(np.mean(fde_values)) + summary["fde_std"] = float(np.std(fde_values)) + + return summary + + +def _print_summary(summary: Dict[str, Any], policy_name: str) -> None: + """Print summary statistics.""" + print(f"\n{'='*50}") + print(f"Policy: {policy_name}") + print(f"{'='*50}") + + # Core metrics + if "ade_mean" in summary: + ade = summary["ade_mean"] + ade_std = summary.get("ade_std", 0) + print(f"ADE: {ade:.3f}m ± {ade_std:.3f}m") + + if "fde_mean" in summary: + fde = summary["fde_mean"] + fde_std = summary.get("fde_std", 0) + print(f"FDE: {fde:.3f}m ± {fde_std:.3f}m") + + success_rate = summary.get("success_rate", 0) + print(f"Success Rate: {success_rate*100:.1f}%") + + print(f"\nEpisodes: {summary.get('num_episodes', 'N/A')}") + + +def _run_comparison( + *, + out_root: Path, + episodes: int, + seed_base: int, + max_steps: int, + step_scale: float, +) -> None: + """Run comparison between SFT and RL policies.""" + print("\n" + "="*60) + print("SFT vs RL Policy Comparison") + print("="*60) + + # Run SFT evaluation + print("\nEvaluating SFT policy...") + sft_scenarios = [ + _run_episode( + seed=seed_base + i, + policy_name="sft", + max_steps=max_steps, + step_scale=step_scale, + ) + for i in range(episodes) + ] + sft_summary = _compute_summary(sft_scenarios) + + # Run RL evaluation (same seeds) + print("Evaluating RL policy...") + rl_scenarios = [ + _run_episode( + seed=seed_base + i, + policy_name="rl", + max_steps=max_steps, + step_scale=step_scale, + ) + for i in range(episodes) + ] + rl_summary = _compute_summary(rl_scenarios) + + # Print comparison + print("\n" + "-"*60) + print(f"{'Metric':<20} {'SFT':<15} {'RL':<15} {'Δ':<10}") + print("-"*60) + + # ADE comparison + if "ade_mean" in sft_summary and "ade_mean" in rl_summary: + sft_ade = sft_summary["ade_mean"] + rl_ade = rl_summary["ade_mean"] + delta = ((rl_ade - sft_ade) / sft_ade * 100) if sft_ade != 0 else 0 + print(f"{'ADE (m)':<20} {f'{sft_ade:.3f} ± {sft_summary.get("ade_std", 0):.3f}':<15} " + f"{f'{rl_ade:.3f} ± {rl_summary.get("ade_std", 0):.3f}':<15} {delta:+.1f}%") + + # FDE comparison + if "fde_mean" in sft_summary and "fde_mean" in rl_summary: + sft_fde = sft_summary["fde_mean"] + rl_fde = rl_summary["fde_mean"] + delta = ((rl_fde - sft_fde) / sft_fde * 100) if sft_fde != 0 else 0 + print(f"{'FDE (m)':<20} {f'{sft_fde:.3f} ± {sft_summary.get("fde_std", 0):.3f}':<15} " + f"{f'{rl_fde:.3f} ± {rl_summary.get("fde_std", 0):.3f}':<15} {delta:+.1f}%") + + # Success rate comparison + sft_success = sft_summary.get("success_rate", 0) + rl_success = rl_summary.get("success_rate", 0) + print(f"{'Success Rate':<20} {sft_success*100:.1f}%{'':<9} {rl_success*100:.1f}%{'':<9} " + f"{(rl_success - sft_success)*100:+.1f}pp") + + # Return comparison + sft_ret = sft_summary.get("return_mean", 0) + rl_ret = rl_summary.get("return_mean", 0) + print(f"{'Return':<20} {sft_ret:.3f}{'':<9} {rl_ret:.3f}{'':<9} {rl_ret - sft_ret:+.3f}") + + print("-"*60) + + # Save comparison results + run_id = time.strftime("%Y%m%d-%H%M%S") + out_dir = out_root / f"compare_{run_id}" + out_dir.mkdir(parents=True, exist_ok=True) + + repo_root = Path(__file__).resolve().parents[2] + git = {k: v for k, v in _git_info(repo_root).items() if v is not None} + + comparison_metrics = { + "run_id": str(run_id), + "domain": "rl", + "git": git, + "comparison": { + "sft": sft_summary, + "rl": rl_summary, + }, + "scenarios": { + "sft": sft_scenarios, + "rl": rl_scenarios, + }, + } + + (out_dir / "comparison.json").write_text(json.dumps(comparison_metrics, indent=2) + "\n") + print(f"\n[comparison] wrote: {out_dir / 'comparison.json'}") + + if __name__ == "__main__": main() diff --git a/training/rl/resad.py b/training/rl/resad.py new file mode 100644 index 0000000..6c975f1 --- /dev/null +++ b/training/rl/resad.py @@ -0,0 +1,717 @@ +""" +ResAD (Residual with Attention and Dynamics) Implementation +========================================================= + +ResAD is a residual learning approach for autonomous driving that uses: +1. Normalized residual learning: Δ = (y - ŷ) / σ +2. Uncertainty estimation for adaptive weighting +3. Inertial reference frame for robustness + +Key Features: +- Normalized residual instead of raw residual +- Uncertainty-aware training +- Inertial reference frame transformation +- Complements frozen SFT model + +Reference: ResAD (arXiv:2510.08562) + +Usage: + from training.rl.resad import ResADModule, UncertaintyHead, ResADTrainer + + resad = ResADModule(policy, config) + delta, sigma = resad(features, waypoints) + y_final = resad.apply(features, waypoints, delta, sigma) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass, field +import numpy as np +import os + + +# ============================================================================ +# Configuration +# ============================================================================ + +@dataclass +class ResADConfig: + """ + Configuration for ResAD algorithm. + + Attributes: + - feature_dim: Input feature dimension + - waypoint_dim: Waypoint dimension (x, y, heading) + - hidden_dim: Hidden dimension for networks + - dropout: Dropout probability + - use_inertial_ref: Use inertial reference frame + - uncertainty_weight: Weight for uncertainty loss + - kl_weight: Weight for KL divergence regularization + - normalize_residual: Normalize residual by uncertainty + """ + # Model dimensions + feature_dim: int = 256 + waypoint_dim: int = 3 + hidden_dim: int = 128 + + # Training + dropout: float = 0.1 + uncertainty_weight: float = 1.0 + kl_weight: float = 0.01 + + # Inertial reference + use_inertial_ref: bool = True + + # Normalization + normalize_residual: bool = True + sigma_min: float = 1e-4 # Minimum uncertainty + + # Loss + use_nll_loss: bool = True + use_mse_loss: bool = True + use_kl_regularization: bool = True + + def __post_init__(self): + """Validate configuration.""" + assert self.waypoint_dim > 0, "waypoint_dim must be positive" + assert 0 <= self.dropout < 1, "dropout must be in [0, 1)" + assert self.uncertainty_weight >= 0, "uncertainty_weight must be non-negative" + + +# ============================================================================ +# Uncertainty Head +# ============================================================================ + +class UncertaintyHead(nn.Module): + """ + Predicts aleatoric uncertainty for waypoint predictions. + + Architecture: + - Takes SFT features + predictions as input + - Outputs log(sigma) to ensure sigma > 0 + - Per-waypoint uncertainty estimation + """ + + def __init__( + self, + feature_dim: int = 256, + waypoint_dim: int = 3, + hidden_dim: int = 128, + dropout: float = 0.1, + ): + super().__init__() + + self.feature_dim = feature_dim + self.waypoint_dim = waypoint_dim + self.hidden_dim = hidden_dim + + input_dim = feature_dim + waypoint_dim + + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.LayerNorm(hidden_dim // 2), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, waypoint_dim), + ) + + def forward( + self, + features: torch.Tensor, + waypoints: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + features: [B, feature_dim] or [B, T, feature_dim] + waypoints: [B, T, waypoint_dim] + + Returns: + log_sigma: [B, T, waypoint_dim] + """ + if features.dim() == 2: + # Expand to match waypoints time dimension + features = features.unsqueeze(1).expand(-1, waypoints.size(1), -1) + + x = torch.cat([features, waypoints], dim=-1) + log_sigma = self.net(x) + + return log_sigma + + +# ============================================================================ +# Residual Head +# ============================================================================ + +class ResADResidualHead(nn.Module): + """ + ResAD Residual Head with Inertial Reference. + + Predicts normalized residual: Δ_norm = (y - ŷ) / σ + """ + + def __init__( + self, + feature_dim: int = 256, + waypoint_dim: int = 3, + hidden_dim: int = 128, + dropout: float = 0.1, + use_inertial_ref: bool = False, + ): + super().__init__() + + self.feature_dim = feature_dim + self.waypoint_dim = waypoint_dim + self.hidden_dim = hidden_dim + self.use_inertial_ref = use_inertial_ref + + # Input dimension: features + waypoints (+ ego_state if using inertial ref) + ego_dim = 2 if use_inertial_ref else 0 + input_dim = feature_dim + waypoint_dim + ego_dim + + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, waypoint_dim), + ) + + def forward( + self, + features: torch.Tensor, + waypoints: torch.Tensor, + ego_state: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + features: [B, feature_dim] or [B, T, feature_dim] + waypoints: [B, T, waypoint_dim] + ego_state: [B, 2] (velocity, heading) + + Returns: + delta_norm: [B, T, waypoint_dim] + """ + if features.dim() == 2: + # Expand to match waypoints time dimension + features = features.unsqueeze(1).expand(-1, waypoints.size(1), -1) + + if self.use_inertial_ref and ego_state is not None: + ego_features = ego_state.unsqueeze(1).expand(-1, waypoints.size(1), -1) + waypoints_input = torch.cat([waypoints, ego_features], dim=-1) + else: + waypoints_input = waypoints + + x = torch.cat([features, waypoints_input], dim=-1) + delta_norm = self.net(x) + + return delta_norm + + +# ============================================================================ +# Inertial Reference Transform +# ============================================================================ + +class InertialReferenceTransform(nn.Module): + """ + Transform waypoints between map frame and ego (inertial) frame. + """ + + def __init__(self, waypoint_dim: int = 3): + super().__init__() + self.waypoint_dim = waypoint_dim + + def map_to_ego( + self, + waypoints: torch.Tensor, # [B, T, 3] + ego_pose: torch.Tensor, # [B, 3] (x, y, heading) + ) -> torch.Tensor: + """Transform from map frame to ego frame.""" + B, T, _ = waypoints.shape + + ego_x = ego_pose[:, 0] + ego_y = ego_pose[:, 1] + ego_heading = ego_pose[:, 2] + + # Relative position + rel_x = waypoints[:, :, 0] - ego_x.unsqueeze(1) + rel_y = waypoints[:, :, 1] - ego_y.unsqueeze(1) + + # Rotate to ego frame + cos_h = torch.cos(ego_heading) + sin_h = torch.sin(ego_heading) + + ego_rel_x = rel_x * cos_h + rel_y * sin_h + ego_rel_y = -rel_x * sin_h + rel_y * cos_h + + # Relative heading + ego_heading_rel = waypoints[:, :, 2] - ego_heading.unsqueeze(1) + ego_heading_rel = torch.atan2( + torch.sin(ego_heading_rel), + torch.cos(ego_heading_rel) + ) + + return torch.stack([ego_rel_x, ego_rel_y, ego_heading_rel], dim=-1) + + def ego_to_map( + self, + waypoints_ego: torch.Tensor, # [B, T, 3] + ego_pose: torch.Tensor, # [B, 3] + ) -> torch.Tensor: + """Transform from ego frame to map frame.""" + B, T, _ = waypoints_ego.shape + + ego_x = ego_pose[:, 0] + ego_y = ego_pose[:, 1] + ego_heading = ego_pose[:, 2] + + cos_h = torch.cos(ego_heading) + sin_h = torch.sin(ego_heading) + + # Rotate to map frame + map_rel_x = waypoints_ego[:, :, 0] * cos_h - waypoints_ego[:, :, 1] * sin_h + map_rel_y = waypoints_ego[:, :, 0] * sin_h + waypoints_ego[:, :, 1] * cos_h + + # Translate + map_x = map_rel_x + ego_x.unsqueeze(1) + map_y = map_rel_y + ego_y.unsqueeze(1) + map_heading = waypoints_ego[:, :, 2] + ego_heading.unsqueeze(1) + + return torch.stack([map_x, map_y, map_heading], dim=-1) + + +# ============================================================================ +# Complete ResAD Module +# ============================================================================ + +class ResADModule(nn.Module): + """ + Complete ResAD Module combining residual head and uncertainty head. + + Usage: + resad = ResADModule( + feature_dim=256, + waypoint_dim=3, + hidden_dim=128, + use_inertial_ref=True, + ) + + delta, log_sigma = resad(features, waypoints, ego_state) + y_final, sigma = resad.apply(waypoints, delta, log_sigma) + """ + + def __init__( + self, + feature_dim: int = 256, + waypoint_dim: int = 3, + hidden_dim: int = 128, + dropout: float = 0.1, + use_inertial_ref: bool = False, + ): + super().__init__() + + self.feature_dim = feature_dim + self.waypoint_dim = waypoint_dim + self.hidden_dim = hidden_dim + self.use_inertial_ref = use_inertial_ref + + self.residual_head = ResADResidualHead( + feature_dim=feature_dim, + waypoint_dim=waypoint_dim, + hidden_dim=hidden_dim, + dropout=dropout, + use_inertial_ref=use_inertial_ref, + ) + + self.uncertainty_head = UncertaintyHead( + feature_dim=feature_dim, + waypoint_dim=waypoint_dim, + hidden_dim=hidden_dim, + dropout=dropout, + ) + + def forward( + self, + features: torch.Tensor, + waypoints: torch.Tensor, + ego_state: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + features: [B, feature_dim] or [B, T, feature_dim] + waypoints: [B, T, waypoint_dim] + ego_state: [B, 2] optional + + Returns: + delta_norm: [B, T, waypoint_dim] + log_sigma: [B, T, waypoint_dim] + """ + delta_norm = self.residual_head(features, waypoints, ego_state) + log_sigma = self.uncertainty_head(features, waypoints) + + return delta_norm, log_sigma + + def apply( + self, + waypoints: torch.Tensor, + delta_norm: torch.Tensor, + log_sigma: torch.Tensor, + uncertainty_weight: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply residual correction. + + Formula: y_final = ŷ + Δ_norm × σ + """ + sigma = torch.exp(log_sigma) * uncertainty_weight + sigma = torch.clamp(sigma, min=1e-4) + corrected = waypoints + delta_norm * sigma + + return corrected, sigma + + def loss( + self, + delta_norm: torch.Tensor, + log_sigma: torch.Tensor, + target_residual: torch.Tensor, + target_uncertainty: Optional[torch.Tensor] = None, + kl_weight: float = 0.01, + ) -> Dict[str, torch.Tensor]: + """ + Compute ResAD loss. + + Returns loss dict with: + - total_loss + - mse_loss + - nll_loss + - kl_loss + - sigma_mean + """ + sigma = torch.exp(log_sigma) + + # NLL Loss + nll = 0.5 * ((target_residual - delta_norm) ** 2 / (sigma + 1e-6) + log_sigma) + nll_loss = nll.mean() + + # MSE Loss + mse_loss = F.mse_loss(delta_norm, target_residual) + + # KL Divergence + kl = 0.5 * (sigma - 1 - torch.log(sigma + 1e-6)) + kl_loss = kl.mean() + + total_loss = mse_loss + nll_loss + kl_weight * kl_loss + + return { + 'total_loss': total_loss, + 'mse_loss': mse_loss, + 'nll_loss': nll_loss, + 'kl_loss': kl_loss, + 'sigma_mean': sigma.mean(), + 'delta_norm_mean': delta_norm.mean(), + } + + +# ============================================================================ +# ResAD with SFT Integration +# ============================================================================ + +class ResADWithSFT(nn.Module): + """ + ResAD module integrated with frozen SFT model. + + Usage: + sft_model = load_sft_model("sft.pt") + resad = ResADWithSFT(sft_model, config) + + output = resad(features) + # output['waypoints']: corrected waypoints + # output['uncertainty']: uncertainty estimates + """ + + def __init__( + self, + sft_model: nn.Module, + config: Optional[ResADConfig] = None, + ): + super().__init__() + + self.sft_model = sft_model + self.config = config or ResADConfig() + + # Freeze SFT model + for param in sft_model.parameters(): + param.requires_grad = False + self.sft_model.eval() + + # Get feature dimension + if hasattr(sft_model, 'config'): + feature_dim = getattr(sft_model.config, 'hidden_dim', 256) + else: + feature_dim = 256 + + self.resad = ResADModule( + feature_dim=feature_dim, + waypoint_dim=self.config.waypoint_dim, + hidden_dim=self.config.hidden_dim, + use_inertial_ref=self.config.use_inertial_ref, + ) + + def forward( + self, + features: torch.Tensor, + target_waypoints: Optional[torch.Tensor] = None, + ego_state: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass. + """ + with torch.no_grad(): + sft_output = self.sft_model(features) + # Handle both dict and tensor outputs + if isinstance(sft_output, dict): + sft_waypoints = sft_output.get('waypoints', sft_output.get('logits', sft_output)) + else: + sft_waypoints = sft_output + + delta_norm, log_sigma = self.resad(features, sft_waypoints, ego_state) + corrected, uncertainty = self.resad.apply(sft_waypoints, delta_norm, log_sigma) + + result = { + 'waypoints': corrected, + 'uncertainty': uncertainty, + 'sft_waypoints': sft_waypoints, + 'delta': delta_norm, + 'log_sigma': log_sigma, + } + + if target_waypoints is not None: + with torch.no_grad(): + target_residual = target_waypoints - sft_waypoints + + loss_dict = self.resad.loss( + delta_norm, log_sigma, + target_residual, + target_uncertainty=torch.abs(target_residual), + kl_weight=self.config.kl_weight, + ) + result['loss'] = loss_dict + + return result + + +# ============================================================================ +# ResAD Trainer +# ============================================================================ + +class ResADTrainer: + """ + Trainer for ResAD algorithm. + """ + + def __init__( + self, + model: nn.Module, + config: ResADConfig, + lr: float = 1e-4, + device: str = 'cuda', + ): + self.model = model.to(device) + self.config = config + self.device = device + + self.optimizer = torch.optim.AdamW( + model.resad.parameters(), + lr=lr, + weight_decay=1e-4, + ) + + self.global_step = 0 + + def train_step( + self, + features: torch.Tensor, + waypoints: torch.Tensor, + target_waypoints: torch.Tensor, + ego_state: Optional[torch.Tensor] = None, + ) -> Dict[str, float]: + """Single training step.""" + output = self.model( + features, + target_waypoints=target_waypoints, + ego_state=ego_state, + ) + + loss_dict = output['loss'] + total_loss = loss_dict['total_loss'] + + self.optimizer.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_( + self.model.resad.parameters(), + max_norm=1.0 + ) + self.optimizer.step() + + self.global_step += 1 + + return {k: v.item() for k, v in loss_dict.items()} + + def train_epoch(self, dataloader) -> Dict[str, float]: + """Train for one epoch.""" + self.model.train() + total_losses = defaultdict(list) + + for batch in dataloader: + features = batch['features'].to(self.device) + waypoints = batch['waypoints'].to(self.device) + targets = batch['targets'].to(self.device) + ego_state = batch.get('ego_state') + if ego_state is not None: + ego_state = ego_state.to(self.device) + + losses = self.train_step(features, waypoints, targets, ego_state) + + for k, v in losses.items(): + total_losses[k].append(v) + + return {k: np.mean(v) for k, v in total_losses.items()} + + def save_checkpoint(self, path: str, epoch: int): + """Save checkpoint.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + + checkpoint = { + 'model_state_dict': self.model.resad.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'config': self.config.__dict__, + 'epoch': epoch, + 'global_step': self.global_step, + } + + torch.save(checkpoint, path) + print(f"Saved checkpoint to {path}") + + +# ============================================================================ +# ResAD Evaluation +# ============================================================================ + +class ResADEvaluator: + """Evaluator for ResAD model.""" + + def __init__(self, config: ResADConfig): + self.config = config + + def evaluate( + self, + model: nn.Module, + dataloader, + ) -> Dict[str, float]: + """Evaluate model on dataset.""" + model.eval() + metrics = defaultdict(list) + + with torch.no_grad(): + for batch in dataloader: + features = batch['features'] + targets = batch['targets'] + + output = model(features) + waypoints = output['waypoints'] + uncertainty = output['uncertainty'] + sft_waypoints = output['sft_waypoints'] + + # ADE + ade = torch.norm(waypoints - targets, dim=-1).mean().item() + metrics['ade'].append(ade) + + # FDE + fde = torch.norm(waypoints[:, -1] - targets[:, -1], dim=-1).mean().item() + metrics['fde'].append(fde) + + # SFT baseline + sft_ade = torch.norm(sft_waypoints - targets, dim=-1).mean().item() + metrics['sft_ade'].append(sft_ade) + + # Uncertainty + metrics['uncertainty'].append(uncertainty.mean().item()) + + return {k: np.mean(v) for k, v in metrics.items()} + + +# ============================================================================ +# Example Usage +# ============================================================================ + +def example_usage(): + """Example of using ResAD.""" + + # Mock SFT model + class MockSFT(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(256, 30) + + def forward(self, x): + return self.fc(x) + + # Configuration + config = ResADConfig( + feature_dim=256, + waypoint_dim=3, + hidden_dim=128, + use_inertial_ref=True, + ) + + # SFT model + sft_model = MockSFT() + + # ResAD with SFT + resad = ResADWithSFT(sft_model, config) + + # Forward pass + features = torch.randn(4, 256) + targets = torch.randn(4, 10, 3) + + output = resad(features, target_waypoints=targets) + + print("ResAD Output:") + print(f" SFT waypoints: {output['sft_waypoints'].shape}") + print(f" Corrected waypoints: {output['waypoints'].shape}") + print(f" Uncertainty: {output['uncertainty'].shape}") + print(f" Loss: {output['loss']['total_loss'].item():.4f}") + + # Training step + trainer = ResADTrainer(resad, config) + + # Mock dataloader + class MockDataset(torch.utils.data.Dataset): + def __getitem__(self, idx): + return { + 'features': torch.randn(256), + 'waypoints': torch.randn(10, 3), + 'targets': torch.randn(10, 3), + } + def __len__(self): + return 100 + + dataloader = torch.utils.data.DataLoader(MockDataset(), batch_size=32) + + print("\nTraining for 1 epoch...") + metrics = trainer.train_epoch(dataloader) + print(f"Losses: {metrics}") + + +if __name__ == "__main__": + example_usage() diff --git a/training/rl/toy_waypoint_env.py b/training/rl/toy_waypoint_env.py index 948d112..f707584 100644 --- a/training/rl/toy_waypoint_env.py +++ b/training/rl/toy_waypoint_env.py @@ -428,6 +428,177 @@ def policy_rl_refined(obs: tuple | np.ndarray) -> np.ndarray: return np.array([steer, throttle], dtype=np.float32) + return np.array([steer, throttle], dtype=np.float32) + + +# === ResAD Policy Integration === + +def policy_resad( + obs: tuple | np.ndarray, + resad_model: "ResADWithSFT" | None = None, + sft_waypoints: np.ndarray | None = None, +) -> np.ndarray: + """ + ResAD policy for toy waypoint environment. + + Uses ResAD residual correction to refine SFT waypoint predictions. + + Args: + obs: Observation (state, info tuple or observation array) + resad_model: Optional pre-loaded ResAD model + sft_waypoints: Optional pre-computed SFT waypoints + + Returns: + Action array (steer, throttle) + """ + # Handle both tuple format and observation array format + if isinstance(obs, tuple) and len(obs) == 2: + state, info = obs + x, y, heading, speed = float(state[0]), float(state[1]), float(state[2]), float(state[3]) + waypoints = info.get("waypoints") + current_waypoint_idx = info.get("current_waypoint_idx", 0) + elif isinstance(obs, np.ndarray): + x, y, heading, speed = float(obs[0]), float(obs[1]), float(obs[2]), float(obs[3]) + waypoints_start = 4 + horizon = 20 + target_idx = int(obs[-1] * horizon) if horizon > 0 else 0 + target_idx = max(0, min(target_idx, horizon - 1)) + waypoints = obs[waypoints_start:waypoints_start + horizon * 2].reshape(horizon, 2) + current_waypoint_idx = target_idx + else: + raise ValueError(f"Unknown observation format: {type(obs)}") + + # Use SFT waypoints if provided, otherwise use environment waypoints + if sft_waypoints is None and waypoints is not None: + # Use environment waypoints as "SFT" baseline + sft_waypoints = waypoints + + if resad_model is not None and sft_waypoints is not None: + # Run ResAD inference + import torch + import numpy as np + + # Create feature vector from state (expand 4D state to 256D feature) + # This is a mock - in real use, this would come from a perception backbone + features_np = np.zeros(256, dtype=np.float32) + features_np[0:4] = [x, y, heading, speed] + # Add some sinusoidal positional encoding for realism + for i in range(4): + features_np[i] = float(state[i]) + features_np[4 + i * 2] = np.sin(float(state[i]) * np.pi / 50) + features_np[5 + i * 2] = np.cos(float(state[i]) * np.pi / 50) + # Add waypoint features + for i, wp in enumerate(waypoints[:10]): # First 10 waypoints + if i * 2 + 24 < 256: + features_np[24 + i * 2] = wp[0] + features_np[25 + i * 2] = wp[1] + + features = torch.tensor(features_np, dtype=torch.float32).unsqueeze(0) # [1, 256] + + # Mock SFT output is [1, 30] - reshape to [1, 10, 3] for ResAD + mock_sft_output = torch.randn(1, 30) # 10 waypoints * 3 dims (x, y, heading) + sft_waypoints_tensor = mock_sft_output.view(1, 10, 3) # [1, 10, 3] + + ego_state = torch.tensor([[speed, heading]], dtype=torch.float32) # [1, 2] + + with torch.no_grad(): + # Create a temporary wrapper for this inference + class TempResAD: + def __init__(self, model, sft_output): + self.model = model + self.sft_output = sft_output + + def __call__(self, features, ego_state=None): + # Reshape sft_output to [1, 10, 3] + sft_wp = self.sft_output + if sft_wp.dim() == 2: + sft_wp = sft_wp.view(1, 10, 3) + + delta_norm, log_sigma = self.model.resad(features, sft_wp, ego_state) + corrected, uncertainty = self.model.resad.apply(sft_wp, delta_norm, log_sigma) + return { + 'waypoints': corrected, + 'uncertainty': uncertainty, + 'sft_waypoints': sft_wp, + } + + output = TempResAD(resad_model, mock_sft_output)(features, ego_state) + refined_waypoints = output['waypoints'].squeeze(0).cpu().numpy() # [10, 3] + uncertainty = output['uncertainty'].squeeze(0).cpu().numpy() + else: + # Fallback: use RL-refined heuristic + return policy_rl_refined(obs) + + # Get current target waypoint from refined predictions + if current_waypoint_idx < len(refined_waypoints): + target_wp = refined_waypoints[current_waypoint_idx] + else: + target_wp = refined_waypoints[-1] if len(refined_waypoints) > 0 else np.array([x, y]) + + # Compute angle to refined target + dx = target_wp[0] - x + dy = target_wp[1] - y + target_angle = np.arctan2(dy, dx) + + # Steering toward refined target + angle_diff = target_angle - heading + while angle_diff > np.pi: + angle_diff -= 2 * np.pi + while angle_diff < -np.pi: + angle_diff += 2 * np.pi + + steer = np.clip(angle_diff / (np.pi / 4), -1.0, 1.0) + + # Throttle: use uncertainty for adaptive speed + dist = np.sqrt(dx**2 + dy**2) + uncertainty_factor = 1.0 - np.clip(uncertainty.mean(), 0, 1) * 0.3 + throttle = np.clip(1.0 - dist / 20.0, 0.0, 1.0) * uncertainty_factor + + return np.array([steer, throttle], dtype=np.float32) + + +def create_resad_policy(checkpoint_path: str | None = None): + """ + Create a ResAD policy function from checkpoint. + + Args: + checkpoint_path: Path to ResAD checkpoint (optional, uses mock if None) + + Returns: + Policy function that takes (obs, info) and returns action + """ + import torch + from training.rl.resad import ResADConfig, ResADWithSFT, ResADModule + + # Mock SFT model (replace with actual SFT checkpoint loading) + class MockSFT(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(256, 30) # 10 waypoints * 3 dims + + def forward(self, x): + return self.fc(x) + + config = ResADConfig( + feature_dim=256, + waypoint_dim=3, + hidden_dim=128, + use_inertial_ref=True, + ) + + sft_model = MockSFT() + resad_model = ResADWithSFT(sft_model, config) + + if checkpoint_path is not None: + checkpoint = torch.load(checkpoint_path) + resad_model.resad.load_state_dict(checkpoint['model_state_dict']) + + def policy(obs): + return policy_resad(obs, resad_model=resad_model) + + return policy, resad_model + + def main(): """Quick test of the environment.""" import argparse