diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json index 8412fe36..dddabe41 100644 --- a/configs/agents/rl/basic/cart_pole/train_config.json +++ b/configs/agents/rl/basic/cart_pole/train_config.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "eval_freq": 2, "save_freq": 200, "use_wandb": false, diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index ae5026b2..b7639c06 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -9,14 +9,15 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "enable_eval": true, "num_eval_envs": 16, "num_eval_episodes": 3, - "eval_freq": 2, + "eval_freq": 200, "save_freq": 200, "use_wandb": false, "wandb_project_name": "embodychain-push_cube", + "model_type": "standard", "events": { "eval": { "record_camera": { @@ -38,6 +39,7 @@ }, "policy": { "name": "actor_critic", + "action_dim": 8, "actor": { "type": "mlp", "network_cfg": { diff --git a/configs/agents/rl/vla_example/train_config.json b/configs/agents/rl/vla_example/train_config.json new file mode 100644 index 00000000..bc48b9f5 --- /dev/null +++ b/configs/agents/rl/vla_example/train_config.json @@ -0,0 +1,70 @@ +{ + "trainer": { + "exp_name": "vla_fine_tuning_ppo", + "gym_config": "configs/agents/rl/push_cube/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 32, + "iterations": 500, + "buffer_size": 2048, + "buffer_type": "vla", + "enable_eval": true, + "num_eval_envs": 8, + "num_eval_episodes": 3, + "eval_freq": 100, + "save_freq": 100, + "use_wandb": true, + "wandb_project_name": "embodychain-vla-training", + "model_type": "vla", + "events": { + "eval": { + "record_camera": { + "func": "record_camera_data_async", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "main_cam", + "resolution": [640, 480], + "eye": [-1.4, 1.4, 2.0], + "target": [0, 0, 0], + "up": [0, 0, 1], + "intrinsics": [600, 600, 320, 240], + "save_path": "./outputs/videos/vla_eval" + } + } + } + } + }, + "policy": { + "name": "vla", + "action_dim": 7, + "vla_config": { + "model_path": "checkpoints/pretrained_vla_model.pth", + "model_class": "vla_models.GPTVLAModel", + "model_config": { + "vision_encoder": "resnet50", + "language_model": "gpt2-medium", + "action_head_hidden_size": 512, + "freeze_vision_encoder": false, + "freeze_language_model": false + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 1e-5, + "n_epochs": 4, + "batch_size": 2048, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.001, + "vf_coef": 0.5, + "max_grad_norm": 1.0 + } + } +} diff --git a/docs/rl_training_guide.md b/docs/rl_training_guide.md new file mode 100644 index 00000000..3db3d072 --- /dev/null +++ b/docs/rl_training_guide.md @@ -0,0 +1,292 @@ +# RL Training Framework Guide + +TensorDict-based RL framework supporting standard PPO and asynchronous VLA training. 
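Throughout this guide, data moves between components as `TensorDict` objects with a `[T, N]` batch shape (`T` rollout steps, `N` parallel environments). As a quick orientation, a rollout produced by the collectors roughly follows the sketch below; the key names mirror the collector code, while the sizes are illustrative only:

```python
import torch
from tensordict import TensorDict

T, N, obs_dim, act_dim = 8, 4, 16, 7  # illustrative sizes
rollout = TensorDict(
    {
        "observation": torch.zeros(T, N, obs_dim),
        "action": torch.zeros(T, N, act_dim),
        "sample_log_prob": torch.zeros(T, N, 1),
        "value": torch.zeros(T, N, 1),
        "next": TensorDict(
            {
                "observation": torch.zeros(T, N, obs_dim),
                "reward": torch.zeros(T, N, 1),
                "done": torch.zeros(T, N, 1, dtype=torch.bool),
                "value": torch.zeros(T, N, 1),  # bootstrap value for GAE
            },
            batch_size=[T, N],
        ),
    },
    batch_size=[T, N],
)
print(rollout["next", "reward"].shape)  # torch.Size([8, 4, 1])
```

PPO later flattens this layout to `[T*N, ...]` via `rollout.reshape(-1)` before minibatching.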
+ +--- + +## Quick Start + +### Configuration + +```json +{ + "trainer": { + "buffer_size": 2048, + "model_type": "standard" // or "vla" + }, + "policy": {"name": "actor_critic"}, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 3e-4, + "gamma": 0.99, + "n_epochs": 10, + "batch_size": 64 + } + } +} +``` + +### Run Training + +```bash +python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json +``` + +--- + +## Architecture + +``` +Trainer → Collector (sync/async) → Buffer (standard/vla) → Algorithm (PPO) +``` + +**Components**: +- **Collector**: Gather data from environment (SyncCollector / AsyncCollector) +- **Buffer**: Store transitions (RolloutBuffer / VLABuffer) +- **Algorithm**: Update policy (PPO) +- **Trainer**: Coordinate training loop + +--- + +## Training Modes + +### Standard Mode (Default) + +**For**: Normal models (<100ms inference/step) + +``` +SyncCollector → Collect 2048 steps → Train → Clear buffer → Repeat +``` + +**Config**: `{"trainer": {"model_type": "standard"}}` + +**Pros**: Simple, stable, low memory, no staleness + +### VLA Async Mode + +**For**: Large models (>1 sec inference/step) + +``` +Background: AsyncCollector → Continuously collect → VLABuffer +Main: Wait for buffer full → Train → Repeat +``` + +**Config**: `{"trainer": {"model_type": "vla"}}` + +**Pros**: 2-3x speedup via parallel collection +**Cons**: Data staleness, higher memory + +--- + +## Collectors + +### SyncCollector + +Collects complete rollout synchronously: + +```python +from embodichain.agents.rl.collector import SyncCollector + +collector = SyncCollector(env, policy, device, callback) +rollout = collector.collect(num_steps=2048) # [T, N, ...] +``` + +### AsyncCollector + +Runs in background thread: + +```python +from embodichain.agents.rl.collector import AsyncCollector + +collector = AsyncCollector(env, policy, buffer, device, callback) +collector.start() # Begin background collection +# ... buffer fills automatically ... +collector.stop() # Stop collection +``` + +--- + +## Buffers + +### RolloutBuffer (Standard) + +Single-use buffer: + +```python +from embodichain.agents.rl.buffer import RolloutBuffer + +buffer = RolloutBuffer(buffer_size=2048, device=device) +buffer.add(rollout) # [T, N, ...] +data = buffer.get(flatten=True) # [T*N, ...], auto-clears +``` + +### VLABuffer (Async) + +Circular FIFO buffer: + +```python +from embodichain.agents.rl.buffer import VLABuffer + +buffer = VLABuffer(buffer_size=4096, device=device) +buffer.add(transition) # Single step +data = buffer.get(flatten=True) # [buffer_size, ...] when full +``` + +**Circular behavior**: `[T0,T1,T2,T3]` → add T4 → `[T4,T1,T2,T3]` (T0 overwritten) + +--- + +## VLA Integration + +### 1. Implement Model + +```python +class MyVLAModel(nn.Module): + def forward(self, obs: TensorDict) -> TensorDict: + # Add 'action', 'sample_log_prob', 'value' + ... + def get_value(self, obs: TensorDict) -> TensorDict: + # Add 'value' + ... + def evaluate_actions(self, obs: TensorDict) -> TensorDict: + # Add 'sample_log_prob', 'entropy', 'value' + ... +``` + +### 2. Implement Loading + +Edit `embodichain/agents/rl/models/vla_policy.py`: + +```python +def load_vla_model(model_path, model_class, model_config, device): + model = MyVLAModel(**model_config) + model.load_state_dict(torch.load(model_path)) + return model.to(device) +``` + +### 3. 
Configure

```json
{
  "trainer": {"model_type": "vla"},
  "policy": {
    "name": "vla",
    "vla_config": {
      "model_path": "checkpoints/vla.pt",
      "model_class": "MyVLAModel",
      "model_config": {}
    }
  }
}
```

---

## Common APIs

### Trainer

```python
from embodichain.agents.rl.utils import Trainer

trainer = Trainer(
    policy, env, algorithm,
    buffer_size=2048,
    model_type="standard",  # or "vla"
    ...
)
trainer.train(total_timesteps=1000000)
```

### Buffer Methods

```python
buffer.add(data)                  # Add data
data = buffer.get(flatten=True)   # Retrieve data
buffer.is_full()                  # Check ready status
buffer.clear()                    # Clear buffer
buffer.get_stats()                # Statistics
```

### Algorithm

```python
from embodichain.agents.rl.algo import PPO, PPOCfg

algorithm = PPO(PPOCfg(...), policy)
losses = algorithm.update(rollout)  # Returns loss dict
```

---

## FAQ

**Q: When should I use VLA mode?**
A: When policy inference is slow (>100 ms/step) while the gradient update itself is fast.

**Q: How large should the buffer be?**
A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity).

**Q: How much does data staleness hurt?**
A: Usually little. PPO tolerates mildly stale data, and the 2-3x collection speedup generally outweighs the penalty.

**Q: How do I debug the data flow?**
A: Use `buffer.get_stats()` or `_print_tensordict_tree(rollout)` in ppo.py.

---

## Workflows

### Standard

```python
collector = SyncCollector(env, policy, device, callback)
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
```

### VLA

```python
collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()
while step < total:
    while not buffer.is_full():
        time.sleep(0.1)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
collector.stop()
```

---

## File Structure

```
embodichain/agents/rl/
├── train.py                    # Entry point
├── algo/ppo.py                 # PPO algorithm
├── buffer/
│   ├── standard_buffer.py      # RolloutBuffer
│   └── vla_buffer.py           # VLABuffer
├── collector/
│   ├── base.py                 # BaseCollector
│   ├── sync_collector.py       # SyncCollector
│   └── async_collector.py      # AsyncCollector
├── models/
│   ├── actor_critic.py         # Standard policy
│   └── vla_policy.py           # VLA wrapper
└── utils/trainer.py            # Training coordinator
```

---

## References

- [TensorDict Docs](https://pytorch.org/tensordict/)
- [PPO Paper](https://arxiv.org/abs/1707.06347)
- Example configs: `configs/agents/rl/`

diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst
index cbc011b2..b0123c96 100644
--- a/docs/source/tutorial/rl.rst
+++ b/docs/source/tutorial/rl.rst
@@ -64,7 +64,7 @@ The ``runtime`` section controls experiment setup:
 - **cuda**: Whether to use GPU (default: true)
 - **headless**: Whether to run simulation in headless mode
 - **iterations**: Number of training iterations
-- **rollout_steps**: Steps per rollout (e.g., 1024)
+- **buffer_size**: Steps per rollout (e.g., 1024)
 - **eval_freq**: Frequency of evaluation (in steps)
 - **save_freq**: Frequency of checkpoint saving (in steps)
 - **use_wandb**: Whether to enable Weights & Biases logging (set in JSON config)
diff --git a/embodichain/agents/rl/ARCHITECTURE.md b/embodichain/agents/rl/ARCHITECTURE.md
new file mode 100644
index 00000000..c83e2ff1
--- /dev/null
+++ b/embodichain/agents/rl/ARCHITECTURE.md
@@ -0,0 +1,216 @@
# RL Training Framework Architecture

## Overall Flow

```
┌──────────────────────────────────────────────────────────────────┐
│                             Trainer                              │
│                  (overall training coordinator)                  │
│                                                                   │
│  ┌─────────────────────┐        ┌──────────────────────┐         │
│  │   Initialization    │        │    Training loop     │         │
│  │                     │        │                      │         │
│  │ 1. Build Policy     │───────▶│  while epoch:        │         │
│  │ 2. Build Algorithm  │        │    ├─ collect data   │         │
│  │ 3. Build Collector  │        │    ├─ update policy  │         │
│  │ 4. Build Env        │        │    └─ evaluate       │         │
│  └─────────────────────┘        └──────────────────────┘         │
│                                                                   │
└──────────────────────────────────────────────────────────────────┘
                               │
                ┌──────────────┼──────────────┐
                │              │              │
                ▼              ▼              ▼
         ┌──────────┐   ┌──────────┐   ┌──────────┐
         │ Collector│   │Algorithm │   │  Policy  │
         └──────────┘   └──────────┘   └──────────┘
```

## Core Components

### 1. Trainer
**Responsibility**: overall coordinator; wires all components together.
```
Training loop:
  for epoch in range(n_epochs):
      ├─ rollout = collector.collect(n_steps)   # collect data
      ├─ metrics = algorithm.update(rollout)    # update policy
      └─ eval_reward = evaluate(policy)         # evaluate performance
```

### 2. Collector
**Responsibility**: interact with the environment and gather experience.

```
┌──────────────────────────────────────────────┐
│               Collector types                │
├──────────────────────────────────────────────┤
│                                              │
│  ┌───────────────────┐  ┌─────────────────┐  │
│  │   SyncCollector   │  │  AsyncCollector │  │
│  │   (synchronous)   │  │  (asynchronous) │  │
│  │                   │  │                 │  │
│  │ For standard RL   │  │ For VLA models  │  │
│  │ algorithms        │  │ - continuous    │  │
│  │ - PPO             │  │   background    │  │
│  │ - SAC             │  │   collection    │  │
│  │                   │  │ - own thread    │  │
│  └───────────────────┘  └─────────────────┘  │
│                                              │
└──────────────────────────────────────────────┘

Workflow:
  obs = env.reset()
  for step in range(n_steps):
      ├─ policy.forward(obs, deterministic=False)   # sample action
      ├─ next_obs, reward, done = env.step(action)
      └─ store in TensorDict: (obs, action, reward, done, value)
  return rollout_tensordict  # [T, N] layout
```

### 3. Algorithm
**Responsibility**: policy-update logic.

```
┌──────────────────────────────────────────────┐
│               Algorithm types                │
├──────────────────────────────────────────────┤
│                                              │
│  ┌────────────────┐  ┌────────────────┐      │
│  │      PPO       │  │      SAC       │ ...  │
│  │                │  │                │      │
│  │ - GAE          │  │ - Q-learning   │      │
│  │ - clip loss    │  │ - soft updates │      │
│  │ - value loss   │  │ - entropy reg. │      │
│  └────────────────┘  └────────────────┘      │
│                                              │
└──────────────────────────────────────────────┘

Workflow:
  def update(rollout: TensorDict) -> dict:
      ├─ compute advantages (GAE)
      ├─ multi-epoch optimization loop
      │   ├─ policy.evaluate_actions(batch)   # recompute log_prob
      │   ├─ compute loss (clip + value + entropy)
      │   └─ optimizer.step()
      └─ return metrics
```

### 4. Policy
**Responsibility**: neural network that outputs actions and values.

```
┌──────────────────────────────────────────────┐
│                 Policy types                 │
├──────────────────────────────────────────────┤
│                                              │
│  ┌────────────────┐  ┌────────────────┐      │
│  │  ActorCritic   │  │   VLAPolicy    │      │
│  │                │  │                │      │
│  │ - MLP networks │  │ - vision +     │      │
│  │ - Gaussian     │  │   language     │      │
│  │   policy       │  │ - pretrained   │      │
│  │                │  │   model        │      │
│  └────────────────┘  └────────────────┘      │
│                                              │
└──────────────────────────────────────────────┘

Interface methods:
  1. forward(obs, deterministic=False)
     ├─ training: sample an action (deterministic=False)
     ├─ evaluation: deterministic action (deterministic=True)
     └─ returns: action, log_prob, value

  2. evaluate_actions(obs, action)
     └─ recompute log_prob and entropy for the given actions

  3. get_value(obs)
     └─ return the value estimate only
```

## Data Flow (TensorDict)

```
Environment ──▶ Collector ──▶ Algorithm ──▶ Policy
     │              │              │            │
     │          TensorDict     TensorDict   Parameter
     │           [T, N]         [batch]      updates
     │              │              │            │
     └──────────────┴──────────────┴────────────┘

TensorDict structure:
{
    "observation": Tensor or nested TensorDict,
    "action": Tensor[T, N, action_dim],
    "reward": Tensor[T, N, 1],
    "done": Tensor[T, N, 1],
    "value": Tensor[T, N, 1],
    "sample_log_prob": Tensor[T, N, 1],
    "advantage": Tensor[T, N, 1],      # added after GAE computation
    "return": Tensor[T, N, 1],         # added after GAE computation
}
```
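The `advantage` and return/target entries above are produced by a GAE pass before the PPO update (`compute_gae` in `embodichain.agents.rl.utils`, whose implementation is not shown in this diff). A minimal, illustrative sketch of that computation over a `[T, N]` rollout, assuming the key layout listed above plus the bootstrap `("next", "value")` entry written by the collectors:

```python
import torch
from tensordict import TensorDict


def gae_sketch(rollout: TensorDict, gamma: float = 0.99, lam: float = 0.95) -> TensorDict:
    """Illustrative GAE over a [T, N] rollout; the framework's own compute_gae may differ."""
    reward = rollout["next", "reward"].squeeze(-1)       # [T, N]
    done = rollout["next", "done"].squeeze(-1).float()   # [T, N]
    value = rollout["value"].squeeze(-1)                 # [T, N]
    next_value = rollout["next", "value"].squeeze(-1)    # [T, N] bootstrap values

    advantage = torch.zeros_like(reward)
    last_adv = torch.zeros_like(reward[0])
    for t in reversed(range(reward.shape[0])):
        not_done = 1.0 - done[t]
        delta = reward[t] + gamma * next_value[t] * not_done - value[t]
        last_adv = delta + gamma * lam * not_done * last_adv
        advantage[t] = last_adv

    rollout["advantage"] = advantage.unsqueeze(-1)
    rollout["value_target"] = (advantage + value).unsqueeze(-1)  # PPO reads "value_target"
    return rollout
```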
## Complete Training Flow Example

```python
# 1. Initialize components
trainer = Trainer(
    env=env,
    policy=ActorCritic(...),
    algorithm=PPO(...),
)

# 2. Create the Collector
collector = SyncCollector(
    env=env,
    policy=policy,
    device=device,
)

# 3. Training loop
for epoch in range(n_epochs):

    # 3.1 Collect data
    rollout = collector.collect(
        n_steps=2048,
        reset=True,
    )
    # rollout: TensorDict[T=2048, N=num_envs]

    # 3.2 Update the policy
    metrics = algorithm.update(rollout)
    # metrics: {"loss": ..., "clip_frac": ..., ...}

    # 3.3 Evaluate performance
    eval_reward = trainer.evaluate(
        n_episodes=10,
        deterministic=True,  # deterministic actions during evaluation
    )

    # 3.4 Logging
    print(f"Epoch {epoch}: reward={eval_reward}, loss={metrics['loss']}")
```

## Key Design Principles

### 1. Separation of responsibilities
- **Trainer**: coordinator only; no component-specific logic
- **Collector**: only collects data; never updates the policy
- **Algorithm**: only updates the policy; never collects data
- **Policy**: only runs the network forward pass; no training logic

### 2. Unified interface
- All components exchange data via **TensorDict**
- Policy exposes a uniform interface: `forward()`, `evaluate_actions()`, `get_value()`
- Implementations are easy to swap (ActorCritic ↔ VLAPolicy)

### 3. Flexible extension
- New algorithm: subclass `BaseAlgorithm` and implement `update()`
- New policy: subclass `Policy` and implement the three abstract methods
- New collector: subclass `BaseCollector` and implement `collect()`

### 4. Deterministic evaluation
```python
# Training (stochastic sampling, exploration)
policy.forward(obs, deterministic=False)  # uses dist.sample()

# Evaluation (deterministic, stable)
policy.forward(obs, deterministic=True)   # uses dist.mean
```
diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py
index 8d74a918..06058f46 100644
--- a/embodichain/agents/rl/algo/base.py
+++ b/embodichain/agents/rl/algo/base.py
@@ -18,35 +18,27 @@
 from typing import Dict, Any, Callable
 
 import torch
+from tensordict import TensorDict
 
 
 class BaseAlgorithm:
-    """Base class for RL algorithms.
+    """Base class for RL algorithms following TorchRL conventions.
 
-    Algorithms must implement buffer initialization, rollout collection, and
-    policy update. Trainer depends only on this interface to remain
-    algorithm-agnostic.
+    Algorithms implement policy updates using TensorDict.
+    Data collection is handled separately by Collector classes (SyncCollector/AsyncCollector).
     """
 
     device: torch.device
 
-    def initialize_buffer(
-        self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int
-    ) -> None:
-        """Initialize internal buffer(s) required by the algorithm."""
-        raise NotImplementedError
+    def update(self, rollout: TensorDict) -> Dict[str, float]:
+        """Update policy using collected rollout data.
 
-    def collect_rollout(
-        self,
-        env,
-        policy,
-        obs: torch.Tensor,
-        num_steps: int,
-        on_step_callback: Callable | None = None,
-    ) -> Dict[str, Any]:
-        """Collect trajectories and return logging info (e.g., reward components)."""
-        raise NotImplementedError
+        Args:
+            rollout: TensorDict containing collected rollout data from Collector
+                Expected batch_size format: [T, N] for on-policy algorithms
+                where T is trajectory length and N is number of environments
 
-    def update(self) -> Dict[str, float]:
-        """Update policy using collected data and return training losses."""
+        Returns:
+            Dictionary of training metrics (losses, learning stats, etc.)
+ """ raise NotImplementedError diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index f11fbe37..99671fa6 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -15,14 +15,52 @@ # ---------------------------------------------------------------------------- import torch -from typing import Dict, Any, Tuple, Callable - -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation -from embodichain.agents.rl.buffer import RolloutBuffer +from tensordict import TensorDict +from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae from embodichain.utils import configclass from .base import BaseAlgorithm +def _print_tensordict_tree(td, prefix="", is_last=True, name="TensorDict"): + """Recursively print TensorDict structure in tree format.""" + connector = "└── " if is_last else "├── " + + # Print current node + batch_info = ( + f"batch_size={list(td.batch_size)}" if hasattr(td, "batch_size") else "" + ) + device_info = f"device={td.device}" if hasattr(td, "device") else "" + meta_info = ", ".join(filter(None, [batch_info, device_info])) + print(f"{prefix}{connector}{name}: TensorDict ({meta_info})") + + # Prepare prefix for children + extension = " " if is_last else "│ " + new_prefix = prefix + extension + + # Get all keys + keys = sorted(td.keys()) if hasattr(td, "keys") else [] + + for i, key in enumerate(keys): + is_last_child = i == len(keys) - 1 + value = td[key] + + if isinstance(value, TensorDict): + # Recursively print nested TensorDict + _print_tensordict_tree(value, new_prefix, is_last_child, name=key) + elif isinstance(value, torch.Tensor): + # Print tensor info + child_connector = "└── " if is_last_child else "├── " + shape_str = "x".join(map(str, value.shape)) + dtype_str = str(value.dtype).replace("torch.", "") + print( + f"{new_prefix}{child_connector}{key}: Tensor([{shape_str}], {dtype_str})" + ) + else: + # Print other types + child_connector = "└── " if is_last_child else "├── " + print(f"{new_prefix}{child_connector}{key}: {type(value).__name__}") + + @configclass class PPOCfg(AlgorithmCfg): """Configuration for the PPO algorithm.""" @@ -34,126 +72,95 @@ class PPOCfg(AlgorithmCfg): class PPO(BaseAlgorithm): - """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).""" + """PPO algorithm using TensorDict for all data flow. + Data collection is handled by Collector classes (SyncCollector/AsyncCollector). + """ def __init__(self, cfg: PPOCfg, policy): self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None - # no per-rollout aggregation for dense logging - - def _compute_gae( - self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Internal method to compute GAE. 
Only called by collect_rollout.""" - T, N = rewards.shape - advantages = torch.zeros_like(rewards, device=self.device) - last_adv = torch.zeros(N, device=self.device) - for t in reversed(range(T)): - next_value = values[t + 1] if t < T - 1 else torch.zeros_like(values[0]) - not_done = (~dones[t]).float() - delta = rewards[t] + self.cfg.gamma * next_value * not_done - values[t] - last_adv = ( - delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * last_adv - ) - advantages[t] = last_adv - returns = advantages + values - return advantages, returns - - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ): - """Initialize the rollout buffer. Called by trainer before first rollout.""" - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect a rollout. Algorithm controls the data collection process.""" - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. Call initialize_buffer() first." - ) - policy.train() - self.buffer.step = 0 - current_obs = obs - - for t in range(num_steps): - # Get action from policy - actions, log_prob, value = policy.get_action( - current_obs, deterministic=False - ) + def update(self, rollout: TensorDict) -> dict: + """Update the policy using collected rollout TensorDict (TorchRL style). - # Wrap action as dict for env processing - action_type = getattr(env, "action_type", "delta_qpos") - action_dict = {action_type: actions} + Args: + rollout: TensorDict with batch_size=[T, N] from collect_rollout() + OR [size] from VLA buffer - # Step environment - result = env.step(action_dict) - next_obs, reward, terminated, truncated, env_info = result - done = terminated | truncated - # Light dtype normalization - reward = reward.float() - done = done.bool() + Returns: + Dictionary of training metrics + """ + # Ensure 2D format [T, N] for GAE computation + if len(rollout.batch_size) == 1: + rollout = rollout.unsqueeze(1) # [size] -> [size, 1] - # Flatten dict observation from ObservationManager if needed - if isinstance(next_obs, dict): - next_obs = flatten_dict_observation(next_obs) - - # Add to buffer - self.buffer.add(current_obs, actions, reward, done, value, log_prob) - - # Dense logging is handled in Trainer.on_step via info; no aggregation here - # Call callback for statistics and logging - if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) - - current_obs = next_obs - - # Compute advantages/returns and attach to buffer extras - adv, ret = self._compute_gae( - self.buffer.rewards, self.buffer.values, self.buffer.dones + # Compute GAE advantages and returns + rollout = compute_gae( + rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda ) - self.buffer.set_extras({"advantages": adv, "returns": ret}) - - # No aggregated logging results; Trainer performs dense per-step logging - return {} - def update(self) -> dict: - """Update the policy using the collected rollout buffer.""" - if self.buffer is None: - raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") + # Flatten to [T*N, ...] 
for training + flat_data = rollout.reshape(-1) + total_samples = flat_data.batch_size[0] - # Normalize advantages (optional, common default) - adv = self.buffer._extras.get("advantages") - adv = (adv - adv.mean()) / (adv.std() + 1e-8) + # Normalize advantages globally + advantages = flat_data["advantage"] + advantages_normalized = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) + flat_data["advantage"] = advantages_normalized total_actor_loss = 0.0 total_value_loss = 0.0 total_entropy = 0.0 total_steps = 0 - - for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - returns = batch["returns"] - advantages = ( - (batch["advantages"] - adv.mean()) / (adv.std() + 1e-8) - ).detach() - - logprobs, entropy, values = self.policy.evaluate_actions(obs, actions) + total_clip_fraction = 0.0 + total_approx_kl = 0.0 + + for epoch in range(self.cfg.n_epochs): + # Shuffle data each epoch + indices = torch.randperm(total_samples, device=self.device) + shuffled_data = flat_data[indices] + + # Iterate over minibatches + num_minibatches = total_samples // self.cfg.batch_size + for i in range(num_minibatches): + start_idx = i * self.cfg.batch_size + end_idx = start_idx + self.cfg.batch_size + batch_td = shuffled_data[start_idx:end_idx] + + # Extract data from TensorDict batch + old_logprobs = batch_td["sample_log_prob"] + returns = batch_td["value_target"] + advantages = batch_td[ + "advantage" + ] # Note: advantages are already normalized globally before shuffling + + # Evaluate actions with current policy + self.policy.evaluate_actions(batch_td) + + # Get updated values + logprobs = batch_td["sample_log_prob"] + entropy = batch_td["entropy"] + values = batch_td["value"] + + # Ensure shapes match (squeeze if needed) + if old_logprobs.dim() > 1: + old_logprobs = old_logprobs.squeeze(-1) + if logprobs.dim() > 1: + logprobs = logprobs.squeeze(-1) + if values.dim() > 1: + values = values.squeeze(-1) + if returns.dim() > 1: + returns = returns.squeeze(-1) + if advantages.dim() > 1: + advantages = advantages.squeeze(-1) + if entropy.dim() > 1: + entropy = entropy.squeeze(-1) + + # PPO loss computation ratio = (logprobs - old_logprobs).exp() surr1 = ratio * advantages surr2 = ( @@ -166,6 +173,13 @@ def update(self) -> dict: value_loss = torch.nn.functional.mse_loss(values, returns) entropy_loss = -entropy.mean() + # Diagnostics + with torch.no_grad(): + clip_fraction = ( + ((ratio - 1.0).abs() > self.cfg.clip_coef).float().mean() + ) + approx_kl = ((ratio - 1.0) - (logprobs - old_logprobs)).mean() + loss = ( actor_loss + self.cfg.vf_coef * value_loss @@ -179,14 +193,18 @@ def update(self) -> dict: ) self.optimizer.step() - bs = obs.shape[0] + bs = batch_td.batch_size[0] total_actor_loss += actor_loss.item() * bs total_value_loss += value_loss.item() * bs total_entropy += (-entropy_loss.item()) * bs + total_clip_fraction += clip_fraction.item() * bs + total_approx_kl += approx_kl.item() * bs total_steps += bs return { "actor_loss": total_actor_loss / max(1, total_steps), "value_loss": total_value_loss / max(1, total_steps), "entropy": total_entropy / max(1, total_steps), + "clip_fraction": total_clip_fraction / max(1, total_steps), + "approx_kl": total_approx_kl / max(1, total_steps), } diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index 8e6f6392..b68a7f49 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ 
b/embodichain/agents/rl/buffer/__init__.py @@ -14,6 +14,15 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .rollout_buffer import RolloutBuffer +""" +Buffer module for RL training. -__all__ = ["RolloutBuffer"] +Provides two buffer implementations: +- RolloutBuffer: Standard PPO buffer (single rollout, use and discard) +- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference) +""" + +from .vla_buffer import VLABuffer +from .standard_buffer import RolloutBuffer + +__all__ = ["RolloutBuffer", "VLABuffer"] diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py deleted file mode 100644 index d99a8966..00000000 --- a/embodichain/agents/rl/buffer/rollout_buffer.py +++ /dev/null @@ -1,106 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Dict, Iterator - -import torch - - -class RolloutBuffer: - """On-device rollout buffer for on-policy algorithms. - - Stores (obs, actions, rewards, dones, values, logprobs) over time. - After finalize(), exposes advantages/returns and minibatch iteration. - """ - - def __init__( - self, - num_steps: int, - num_envs: int, - obs_dim: int, - action_dim: int, - device: torch.device, - ): - self.num_steps = num_steps - self.num_envs = num_envs - self.obs_dim = obs_dim - self.action_dim = action_dim - self.device = device - - T, N = num_steps, num_envs - self.obs = torch.zeros(T, N, obs_dim, dtype=torch.float32, device=device) - self.actions = torch.zeros(T, N, action_dim, dtype=torch.float32, device=device) - self.rewards = torch.zeros(T, N, dtype=torch.float32, device=device) - self.dones = torch.zeros(T, N, dtype=torch.bool, device=device) - self.values = torch.zeros(T, N, dtype=torch.float32, device=device) - self.logprobs = torch.zeros(T, N, dtype=torch.float32, device=device) - - self.step = 0 - # Container for algorithm-specific extra fields (e.g., advantages, returns) - self._extras: dict[str, torch.Tensor] = {} - - def add( - self, - obs: torch.Tensor, - action: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - value: torch.Tensor, - logprob: torch.Tensor, - ) -> None: - t = self.step - self.obs[t].copy_(obs) - self.actions[t].copy_(action) - self.rewards[t].copy_(reward) - self.dones[t].copy_(done) - self.values[t].copy_(value) - self.logprobs[t].copy_(logprob) - self.step += 1 - - def set_extras(self, extras: dict[str, torch.Tensor]) -> None: - """Attach algorithm-specific tensors (shape [T, N, ...]) for batching. 
- - Examples: - {"advantages": adv, "returns": ret} - """ - self._extras = extras or {} - - def iterate_minibatches(self, batch_size: int) -> Iterator[Dict[str, torch.Tensor]]: - T, N = self.num_steps, self.num_envs - total = T * N - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - idx = indices[start : start + batch_size] - t_idx = idx // N - n_idx = idx % N - batch = { - "obs": self.obs[t_idx, n_idx], - "actions": self.actions[t_idx, n_idx], - "rewards": self.rewards[t_idx, n_idx], - "dones": self.dones[t_idx, n_idx], - "values": self.values[t_idx, n_idx], - "logprobs": self.logprobs[t_idx, n_idx], - } - # Slice extras if present and shape aligned to [T, N, ...] - for name, tensor in self._extras.items(): - try: - batch[name] = tensor[t_idx, n_idx] - except Exception: - # Skip misaligned extras silently - continue - yield batch diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py new file mode 100644 index 00000000..eea45bf9 --- /dev/null +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -0,0 +1,117 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict +from typing import Optional + + +class RolloutBuffer: + """Standard on-policy rollout buffer for PPO (matches mainstream implementations). + + Unlike VLA buffer which accumulates multiple rollouts with FIFO eviction, + this buffer follows standard PPO pattern: + - Stores exactly ONE rollout at a time + - After training, buffer is cleared (on-policy: use once and discard) + - Simple and efficient for normal-sized models + + Interface compatible with VLABuffer for easy switching. + """ + + def __init__(self, buffer_size: int, device: torch.device): + """Initialize standard rollout buffer. + + Args: + buffer_size: Buffer size from config (for interface compatibility with VLABuffer) + device: Device to store tensors on + """ + self.buffer_size = buffer_size + self.device = device + self._rollout: Optional[TensorDict] = None + + def add(self, rollout: TensorDict) -> None: + """Add a rollout to buffer, replacing any existing rollout. + + Args: + rollout: TensorDict with batch_size=[T, N, ...] + """ + # Standard PPO: replace existing rollout (not accumulate) + self._rollout = rollout.to(self.device) + + def get(self, flatten: bool = True) -> TensorDict: + """Get rollout from buffer and clear it (standard PPO behavior). + + Args: + flatten: If True, flatten to [batch_size, ...]. + If False, return as [T, N, ...]. 
+ + Returns: + TensorDict with rollout data + """ + if self._rollout is None: + raise ValueError("Buffer is empty") + + rollout = self._rollout + + # Clear after retrieval (on-policy: use once) + self._rollout = None + + if flatten: + # Flatten [T, N, ...] -> [T*N, ...] + return rollout.reshape(-1) + else: + return rollout + + def clear(self) -> None: + """Clear buffer.""" + self._rollout = None + + def is_full(self) -> bool: + """Check if buffer has a rollout ready for training. + + Returns: + True if buffer contains a rollout + """ + return self._rollout is not None + + def __len__(self) -> int: + """Return 1 if buffer has data, 0 otherwise.""" + return 1 if self._rollout is not None else 0 + + def get_num_rollouts(self) -> int: + """Return current number of rollouts in buffer (0 or 1).""" + return 1 if self._rollout is not None else 0 + + def get_num_transitions(self) -> int: + """Return total number of transitions stored.""" + if self._rollout is None: + return 0 + return self._rollout.batch_size[0] * self._rollout.batch_size[1] + + def get_stats(self) -> dict: + """Get buffer statistics for logging. + + Returns: + Dict with buffer stats + """ + return { + "buffer_size": 1 if self._rollout is not None else 0, + "buffer_capacity": self.buffer_size, + "total_transitions": self.get_num_transitions(), + "buffer_usage": 1.0 if self._rollout is not None else 0.0, + } diff --git a/embodichain/agents/rl/buffer/vla_buffer.py b/embodichain/agents/rl/buffer/vla_buffer.py new file mode 100644 index 00000000..d08252ff --- /dev/null +++ b/embodichain/agents/rl/buffer/vla_buffer.py @@ -0,0 +1,175 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict +from typing import Optional + + +class VLABuffer: + """FIFO rollout buffer for VLA RL with pre-allocated TensorDict storage. + + Uses a single pre-allocated TensorDict with circular indexing for efficient + high-frequency transition writes. Designed for async VLA scenarios where + model inference is slow but training is fast. + + Key characteristics: + - Pre-allocated memory: Zero-copy writes via direct indexing + - FIFO eviction: Circular buffer automatically overwrites oldest data + - Transition-level storage: Each step is a separate entry + - High-frequency writes: Optimized for async collection (no TensorDict creation overhead) + + Storage layout: Single TensorDict with shape [buffer_size, ...] + """ + + def __init__(self, buffer_size: int, device: torch.device): + """Initialize VLA buffer with lazy allocation. 
+ + Args: + buffer_size: Maximum number of transitions to store + device: Device to store tensors on + """ + self.buffer_size = buffer_size + self.device = device + self.buffer: Optional[TensorDict] = None # Lazy init on first add + self.write_pos = 0 # Current write position (circular) + self.size = 0 # Current valid data count + self._total_added = 0 + self._initialized = False + + def _initialize_buffer(self, template: TensorDict) -> None: + """Initialize buffer structure from first transition template. + + Args: + template: First transition TensorDict to infer structure from + """ + if self._initialized: + return + + # Pre-allocate buffer with buffer_size + # Template should be a single transition [key: shape] + self.buffer = template.expand(self.buffer_size).clone() + self._initialized = True + + def add(self, transition: TensorDict) -> None: + """Add a single transition to buffer (high-frequency async writes). + + Args: + transition: Single transition TensorDict (no batch dimension) + """ + # Lazy initialization on first add + if not self._initialized: + self._initialize_buffer(transition.to(self.device)) + + # Ensure transition is on correct device + transition = transition.to(self.device) + + # Direct index assignment (zero-copy write) + self.buffer[self.write_pos] = transition + + # Update circular index + self.write_pos = (self.write_pos + 1) % self.buffer_size + + # Update size (saturates at buffer_size) + self.size = min(self.size + 1, self.buffer_size) + self._total_added += 1 + + def add_batch(self, transitions: TensorDict) -> None: + """Add multiple transitions at once (batch write). + + Args: + transitions: Batch of transitions with shape [batch_size, ...] + """ + batch_size = transitions.batch_size[0] + + # Lazy initialization + if not self._initialized: + self._initialize_buffer(transitions[0].to(self.device)) + + transitions = transitions.to(self.device) + + # Handle circular write + for i in range(batch_size): + self.buffer[self.write_pos] = transitions[i] + self.write_pos = (self.write_pos + 1) % self.buffer_size + self.size = min(self.size + 1, self.buffer_size) + self._total_added += 1 + + def get(self, flatten: bool = True) -> TensorDict: + """Get valid data from buffer. + + Args: + flatten: If True, return flattened [size, ...]. Currently only supports True. + + Returns: + TensorDict with batch_size=[size, ...] 
containing valid data + """ + if not self._initialized or self.size == 0: + raise ValueError("Buffer is empty") + + if not flatten: + raise NotImplementedError("Only flatten=True is supported for VLABuffer") + + # Return first 'size' elements (valid data) + # Note: Data is in insertion order up to write_pos, then wraps + if self.size < self.buffer_size: + # Buffer not yet full, data is [0:size] + return self.buffer[: self.size] + else: + # Buffer full, need to rearrange to maintain temporal order + # Oldest data is at write_pos, newest at write_pos-1 + indices = ( + torch.arange( + self.write_pos, + self.write_pos + self.buffer_size, + device=self.device, + ) + % self.buffer_size + ) + return self.buffer[indices] + + def clear(self) -> None: + """Clear buffer (reset pointers, keep pre-allocated memory).""" + self.write_pos = 0 + self.size = 0 + # Keep buffer allocated for reuse + + def __len__(self) -> int: + """Return current number of valid transitions.""" + return self.size + + def is_full(self) -> bool: + """Check if buffer is at full buffer_size.""" + return self.size >= self.buffer_size + + def get_num_rollouts(self) -> int: + """Return 1 (buffer stores transitions, not rollouts).""" + return 1 if self.size > 0 else 0 + + def get_stats(self) -> dict: + """Get buffer statistics for logging.""" + return { + "buffer_size": self.size, + "buffer_capacity": self.buffer_size, + "total_transitions": self.size, + "total_added": self._total_added, + "buffer_usage": ( + self.size / self.buffer_size if self.buffer_size > 0 else 0.0 + ), + "write_pos": self.write_pos, + } diff --git a/embodichain/agents/rl/collector/__init__.py b/embodichain/agents/rl/collector/__init__.py new file mode 100644 index 00000000..eede4937 --- /dev/null +++ b/embodichain/agents/rl/collector/__init__.py @@ -0,0 +1,26 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base import BaseCollector +from .sync_collector import SyncCollector +from .async_collector import AsyncCollector, AsyncCollectorStats + +__all__ = [ + "BaseCollector", + "SyncCollector", + "AsyncCollector", + "AsyncCollectorStats", +] diff --git a/embodichain/agents/rl/collector/async_collector.py b/embodichain/agents/rl/collector/async_collector.py new file mode 100644 index 00000000..063c33a3 --- /dev/null +++ b/embodichain/agents/rl/collector/async_collector.py @@ -0,0 +1,296 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import threading +from typing import Callable, Optional +import torch +from tensordict import TensorDict +from collections import deque + +from ..utils.helper import dict_to_tensordict +from .base import BaseCollector + + +class AsyncCollector(BaseCollector): + """Asynchronous data collector for VLA RL scenarios. + + Runs in a background thread to continuously collect transitions while + the main thread performs model updates. Designed for scenarios where + model inference is slow (e.g., VLA) but training is fast. + + Key features: + - Background thread: Continuous data collection + - Thread-safe buffer: Lock-protected writes + - Step-level collection: Individual transitions added to buffer + - Episode statistics tracking: Rewards and lengths + + Usage: + collector = AsyncCollector(env, policy, buffer, device, ...) + collector.start() # Begin background collection + # ... main thread does training ... + collector.stop() # Stop collection + """ + + def __init__( + self, + env, + policy, + buffer, + device: torch.device, + on_step_callback: Optional[Callable] = None, + ): + """Initialize async collector. + + Args: + env: Environment to collect from + policy: Policy for action selection + buffer: VLABuffer instance (shared with Trainer) + device: Device for tensor operations + on_step_callback: Optional callback(transition, env_info) called after each step + """ + super().__init__(env, policy, device, on_step_callback) + self.buffer = buffer + + # Thread control + self._running = False + self._thread: Optional[threading.Thread] = None + self._lock = threading.Lock() + + # Episode statistics + self._episode_count = 0 + self._step_count = 0 + + def start(self): + """Start background collection thread.""" + if self._running: + raise RuntimeError("Collector is already running") + + self._running = True + self._thread = threading.Thread(target=self._collect_loop, daemon=True) + self._thread.start() + print("[AsyncCollector] Background collection started") + + def collect(self, **kwargs) -> TensorDict: + """For AsyncCollector, data is collected continuously in background. + + This method is just for interface compatibility with BaseCollector. + Actual data retrieval happens through buffer.get(). + + Returns: + Empty TensorDict (not used in async mode) + """ + raise NotImplementedError( + "AsyncCollector collects data in background thread. " + "Use buffer.get() to retrieve data instead." 
+ ) + + def stop(self): + """Stop background collection thread.""" + if not self._running: + return + + self._running = False + if self._thread is not None: + self._thread.join(timeout=5.0) + if self._thread.is_alive(): + print("[AsyncCollector] Warning: Thread did not stop cleanly") + + print( + f"[AsyncCollector] Stopped (collected {self._step_count} steps, {self._episode_count} episodes)" + ) + + def is_running(self) -> bool: + """Check if collector is currently running.""" + return self._running + + def get_stats(self) -> dict: + """Get collection statistics.""" + with self._lock: + return { + "steps_collected": self._step_count, + "episodes_collected": self._episode_count, + } + + def _collect_loop(self): + """Background thread main loop: continuously collect transitions. + + This method runs in a separate thread and continuously: + 1. Gets action from policy + 2. Steps environment + 3. Constructs transition TensorDict + 4. Adds to buffer (thread-safe) + 5. Updates statistics + """ + current_td = self.obs_tensordict + + while self._running: + try: + # Policy forward (no_grad for inference) + with torch.no_grad(): + self.policy.train() # Use stochastic policy + self.policy.forward(current_td) + + # Extract action + action = current_td["action"] + action_type = getattr(self.env, "action_type", "delta_qpos") + action_dict = {action_type: action} + + # Environment step + next_obs_dict, reward, terminated, truncated, env_info = self.env.step( + action_dict + ) + + # Convert observation to TensorDict + next_obs_td = dict_to_tensordict(next_obs_dict, self.device) + done = terminated | truncated + next_obs_for_td = next_obs_td["observation"] + batch_size = next_obs_td.batch_size[0] + + # Build "next" TensorDict + next_td = TensorDict( + { + "observation": next_obs_for_td, + "reward": ( + reward.float().unsqueeze(-1) + if reward.dim() == 1 + else reward.float() + ), + "done": ( + done.bool().unsqueeze(-1) + if done.dim() == 1 + else done.bool() + ), + "terminated": ( + terminated.bool().unsqueeze(-1) + if terminated.dim() == 1 + else terminated.bool() + ), + "truncated": ( + truncated.bool().unsqueeze(-1) + if truncated.dim() == 1 + else truncated.bool() + ), + }, + batch_size=torch.Size([batch_size]), + device=self.device, + ) + + # Compute next value for bootstrapping (GAE computation) + with torch.no_grad(): + next_value_td = TensorDict( + {"observation": next_obs_for_td}, + batch_size=next_td.batch_size, + device=self.device, + ) + self.policy.get_value(next_value_td) + next_td["value"] = next_value_td["value"] + + # Add "next" to current transition + current_td["next"] = next_td + + # Flatten transition for buffer (remove batch dimension for single-step storage) + # Current buffer expects transitions without batch dimension + # We need to add each parallel env's transition separately + for env_idx in range(batch_size): + transition = current_td[env_idx] # Extract single env's transition + + # Thread-safe buffer write + with self._lock: + self.buffer.add(transition) + self._step_count += 1 + + # Callback for statistics + if self.on_step_callback is not None: + self.on_step_callback(current_td, env_info) + + # Handle episode termination + if done.any(): + with self._lock: + self._episode_count += done.sum().item() + + # Prepare next observation + current_td = next_obs_td + + except Exception as e: + print(f"[AsyncCollector] Error in collection loop: {e}") + import traceback + + traceback.print_exc() + break + + print("[AsyncCollector] Collection loop exited") + + +class AsyncCollectorStats: 
+ """Helper class to track async collection statistics safely.""" + + def __init__(self, num_envs: int, device: torch.device): + self.device = device + self.num_envs = num_envs + + # Episode tracking (on device) + self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=device) + self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=device) + + # Completed episodes (CPU) + self.ret_window = deque(maxlen=100) + self.len_window = deque(maxlen=100) + self._lock = threading.Lock() + + def update(self, reward: torch.Tensor, done: torch.Tensor): + """Update episode statistics (thread-safe). + + Args: + reward: Reward tensor [N] or [N, 1] + done: Done tensor [N] or [N, 1] + """ + # Ensure correct shape + if reward.dim() > 1: + reward = reward.squeeze(-1) + if done.dim() > 1: + done = done.squeeze(-1) + + with self._lock: + # Update cumulative stats + self.curr_ret += reward + self.curr_len += 1 + + # Handle completed episodes + done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) + if done_idx.numel() > 0: + finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() + finished_len = self.curr_len[done_idx].detach().cpu().tolist() + self.ret_window.extend(finished_ret) + self.len_window.extend(finished_len) + + # Reset for finished episodes + self.curr_ret[done_idx] = 0 + self.curr_len[done_idx] = 0 + + def get_avg_stats(self) -> tuple[float, float]: + """Get average episode return and length (thread-safe). + + Returns: + (avg_return, avg_length) or (nan, nan) if no episodes completed + """ + with self._lock: + if len(self.ret_window) == 0: + return float("nan"), float("nan") + return float(sum(self.ret_window) / len(self.ret_window)), float( + sum(self.len_window) / len(self.len_window) + ) diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py new file mode 100644 index 00000000..3f49d1e0 --- /dev/null +++ b/embodichain/agents/rl/collector/base.py @@ -0,0 +1,64 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable, Optional +import torch +from tensordict import TensorDict + +from ..utils.helper import dict_to_tensordict + + +class BaseCollector(ABC): + """Abstract base class for data collectors. + + Defines the interface that all collectors must implement. + """ + + def __init__( + self, + env, + policy, + device: torch.device, + on_step_callback: Optional[Callable] = None, + ): + """Initialize base collector. 
+ + Args: + env: Environment to collect from + policy: Policy for action selection + device: Device for tensor operations + on_step_callback: Optional callback(tensordict, env_info) called after each step + """ + self.env = env + self.policy = policy + self.device = device + self.on_step_callback = on_step_callback + + # Initialize observation + obs_dict, _ = self.env.reset() + self.obs_tensordict = dict_to_tensordict(obs_dict, self.device) + + @abstractmethod + def collect(self, **kwargs) -> TensorDict: + """Collect data from environment. + + Returns: + TensorDict with collected data + """ + pass diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py new file mode 100644 index 00000000..4136096f --- /dev/null +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -0,0 +1,130 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +from ..utils.helper import dict_to_tensordict +from .base import BaseCollector + + +class SyncCollector(BaseCollector): + """Synchronous data collector for standard RL training. + + Collects a complete rollout of specified length, then returns it. + Used with RolloutBuffer for standard PPO training. + + Usage: + collector = SyncCollector(env, policy, device) + rollout = collector.collect(num_steps=2048) + buffer.add(rollout) + """ + + def collect(self, num_steps: int) -> TensorDict: + """Collect a synchronous rollout. 
+ + Args: + num_steps: Number of steps to collect + + Returns: + TensorDict with batch_size=[T, N] containing full rollout + """ + self.policy.train() + current_td = self.obs_tensordict + rollout_list = [] + + for t in range(num_steps): + # Policy forward: adds "action", "sample_log_prob", "value" to tensordict + self.policy.forward(current_td) + + # Extract action for environment step + action = current_td["action"] + action_type = getattr(self.env, "action_type", "delta_qpos") + action_dict = {action_type: action} + + # Environment step - returns tuple (env returns dict, not TensorDict) + next_obs, reward, terminated, truncated, env_info = self.env.step( + action_dict + ) + + # Convert env dict observation to TensorDict at boundary + next_obs_td = dict_to_tensordict(next_obs, self.device) + + # Build "next" TensorDict + done = terminated | truncated + next_obs_for_td = next_obs_td["observation"] + + # Ensure batch_size consistency - use next_obs_td's batch_size + batch_size = next_obs_td.batch_size[0] + + next_td = TensorDict( + { + "observation": next_obs_for_td, + "reward": ( + reward.float().unsqueeze(-1) + if reward.dim() == 1 + else reward.float() + ), + "done": ( + done.bool().unsqueeze(-1) if done.dim() == 1 else done.bool() + ), + "terminated": ( + terminated.bool().unsqueeze(-1) + if terminated.dim() == 1 + else terminated.bool() + ), + "truncated": ( + truncated.bool().unsqueeze(-1) + if truncated.dim() == 1 + else truncated.bool() + ), + }, + batch_size=torch.Size([batch_size]), + device=self.device, + ) + + # Compute next value for GAE (bootstrap value) + with torch.no_grad(): + next_value_td = TensorDict( + {"observation": next_obs_for_td}, + batch_size=next_td.batch_size, + device=self.device, + ) + self.policy.get_value(next_value_td) + next_td["value"] = next_value_td["value"] + + # Add "next" to current tensordict + current_td["next"] = next_td + + # Store complete transition + rollout_list.append(current_td.clone()) + + # Callback for statistics and logging + if self.on_step_callback is not None: + self.on_step_callback(current_td, env_info) + + # Prepare next iteration - use the converted TensorDict + current_td = next_obs_td + + # Update observation for next collection + self.obs_tensordict = current_td + + # Stack into [T, N, ...] TensorDict + rollout = torch.stack(rollout_list, dim=0) + + return rollout diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 669e2b33..e996d2d8 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -18,11 +18,11 @@ from typing import Dict, Type import torch -from gymnasium import spaces from .actor_critic import ActorCritic from .policy import Policy from .mlp import MLP +from .vla_policy import VLAPolicy, build_vla_policy, load_vla_model # In-module policy registry _POLICY_REGISTRY: Dict[str, Type[Policy]] = {} @@ -44,13 +44,26 @@ def get_policy_class(name: str) -> Type[Policy] | None: def build_policy( policy_block: dict, - obs_space: spaces.Space, - action_space: spaces.Space, + action_dim: int, device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, ) -> Policy: - """Build policy strictly from json-like block: { name: ..., cfg: {...} }""" + """Build policy from json-like block. + + With TensorDict architecture, we only need action_dim. + Observations are handled via TensorDict structure. 
+ + Args: + policy_block: Config dict with 'name' key + action_dim: Dimension of action space + device: Device to place policy on + actor: Actor network (required for actor_critic) + critic: Critic network (required for actor_critic) + + Returns: + Initialized Policy instance + """ name = policy_block["name"].lower() if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) @@ -63,9 +76,14 @@ def build_policy( raise ValueError( "ActorCritic policy requires external 'actor' and 'critic' modules." ) - return policy_cls(obs_space, action_space, device, actor=actor, critic=critic) + return policy_cls( + action_dim=action_dim, device=device, actor=actor, critic=critic + ) + elif name == "vla": + return build_vla_policy(policy_block, action_dim, device) else: - return policy_cls(obs_space, action_space, device) + # Other policies should also use action_dim signature + return policy_cls(action_dim=action_dim, device=device) def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: @@ -88,12 +106,16 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: # default registrations register_policy("actor_critic", ActorCritic) +register_policy("vla", VLAPolicy) __all__ = [ "ActorCritic", + "VLAPolicy", "register_policy", "get_registered_policy_names", "build_policy", + "build_vla_policy", + "load_vla_model", "build_mlp_from_cfg", "get_policy_class", "Policy", diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 1c40043a..f404d41e 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -16,11 +16,12 @@ from __future__ import annotations -from typing import Dict, Any, Tuple +from typing import Dict, Any import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict from .mlp import MLP from .policy import Policy @@ -28,31 +29,31 @@ class ActorCritic(Policy): """Actor-Critic with learnable log_std for Gaussian policy. - This is a placeholder implementation of the Policy interface that: - - Encapsulates MLP networks (actor + critic) that need to be trained by RL algorithms + Uses TensorDict for all data I/O following TorchRL conventions. + This implementation: + - Encapsulates MLP networks (actor + critic) trained by RL algorithms - Handles internal computation: MLP output → mean + learnable log_std → Normal distribution - - Provides a uniform interface for RL algorithms (PPO, SAC, etc.) + - Provides a uniform TensorDict-based interface for RL algorithms (PPO, SAC, etc.) This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code. 
Implements: - - get_action(obs, deterministic=False) -> (action, log_prob, value) - - get_value(obs) - - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + - forward(tensordict) -> tensordict (adds action, sample_log_prob, value) + - get_value(tensordict) -> tensordict (adds value) + - evaluate_actions(tensordict) -> tensordict (adds sample_log_prob, entropy, value) """ def __init__( self, - obs_space, - action_space, + action_dim: int, device: torch.device, actor: nn.Module, critic: nn.Module, ): super().__init__() - self.obs_dim = obs_space.shape[-1] - self.action_dim = action_space.shape[-1] + # Observation handling done via TensorDict - no need for obs_space + self.action_dim = action_dim self.device = device # Require external injection of actor and critic @@ -66,31 +67,137 @@ def __init__( self.log_std_min = -5.0 self.log_std_max = 2.0 + def _extract_obs_tensor(self, tensordict: TensorDict) -> torch.Tensor: + """Extract observation as flat tensor from TensorDict. + + For nested TensorDict observations, flattens all leaf tensors. + For plain tensor observations, returns as is. + + Args: + tensordict: Input TensorDict with "observation" key + + Returns: + Flattened observation tensor + """ + obs = tensordict["observation"] + + # Handle nested TensorDict by collecting all leaf tensors + obs_list = [] + + def _collect(item): + # Duck typing: if it has keys(), treat as TensorDict + if hasattr(item, "keys"): + for key in sorted(item.keys()): + _collect(item[key]) + else: + # Leaf tensor + obs_list.append(item.flatten(start_dim=1)) + + _collect(obs) + + if len(obs_list) == 0: + raise ValueError("No tensors found in observation") + elif len(obs_list) == 1: + return obs_list[0] + else: + return torch.cat(obs_list, dim=-1) + @torch.no_grad() - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + """Forward pass: sample action and compute value (in-place modification). 
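Editor's note: a standalone sketch of the flattening convention used by `_extract_obs_tensor` — leaf tensors of a nested observation are flattened per environment and concatenated in sorted key order (toy keys and shapes, not from this diff).

```python
import torch
from tensordict import TensorDict

N = 4
obs = TensorDict(
    {
        "object": TensorDict({"pos": torch.randn(N, 3)}, batch_size=[N]),
        "robot": TensorDict(
            {"ee_pos": torch.randn(N, 3), "qpos": torch.randn(N, 9)}, batch_size=[N]
        ),
    },
    batch_size=[N],
)

leaves = []

def collect(item):
    if hasattr(item, "keys"):          # nested TensorDict
        for key in sorted(item.keys()):
            collect(item[key])
    else:                              # leaf tensor
        leaves.append(item.flatten(start_dim=1))

collect(obs)
flat = torch.cat(leaves, dim=-1)
assert flat.shape == (N, 3 + 3 + 9)   # object/pos, robot/ee_pos, robot/qpos
```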
+ + Args: + tensordict: Must contain "observation" key + deterministic: If True, use mean instead of sampling + + Returns: + Same tensordict with added keys: + - "action": Sampled or deterministic action + - "sample_log_prob": Log probability of action + - "value": Value estimate + - "loc": Distribution mean + - "scale": Distribution std + """ + obs_tensor = self._extract_obs_tensor(tensordict) + + # Actor forward + mean = self.actor(obs_tensor) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) + dist = Normal(mean, std) - action = mean if deterministic else dist.sample() - log_prob = dist.log_prob(action).sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return action, log_prob, value + + # Sample action or use mean + if deterministic: + action = mean + else: + dist = Normal(mean, std) + action = dist.sample() + + # Compute log probability + log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) + + # Critic forward - keep shape [N, 1] for consistency with reward/done + value = self.critic(obs_tensor) + + # Add to tensordict (in-place) + tensordict["action"] = action + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = value + tensordict["loc"] = mean + tensordict["scale"] = std + + return tensordict @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return self.critic(obs).squeeze(-1) + def get_value(self, tensordict: TensorDict) -> TensorDict: + """Get value estimate for observations (in-place modification). + + Args: + tensordict: Must contain "observation" key + + Returns: + Same tensordict with added key: + - "value": Value estimate + """ + obs_tensor = self._extract_obs_tensor(tensordict) + value = self.critic(obs_tensor) # Keep shape [N, 1] + tensordict["value"] = value + return tensordict + + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + """Evaluate actions for policy gradient computation (in-place modification). 
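Editor's note: a hedged usage sketch of the TensorDict interface `ActorCritic` now exposes; the constructor arguments and key names follow this diff, while the `nn.Linear` networks and dimensions are invented stand-ins for the configured MLPs.

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from embodichain.agents.rl.models import ActorCritic

N, obs_dim, action_dim = 4, 15, 7
policy = ActorCritic(
    action_dim=action_dim,
    device=torch.device("cpu"),
    actor=nn.Linear(obs_dim, action_dim),
    critic=nn.Linear(obs_dim, 1),
)

td = TensorDict({"observation": torch.randn(N, obs_dim)}, batch_size=[N])
policy.forward(td)            # adds "action", "sample_log_prob", "value", "loc", "scale"
policy.evaluate_actions(td)   # recomputes "sample_log_prob", adds "entropy"
assert td["action"].shape == (N, action_dim)
assert td["value"].shape == (N, 1)
```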
- def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) + Args: + tensordict: Must contain "observation" and "action" keys + + Returns: + Same tensordict with added keys: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of action distribution + - "value": Value estimate + """ + obs_tensor = self._extract_obs_tensor(tensordict) + actions = tensordict["action"] + + # Actor forward + mean = self.actor(obs_tensor) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = dist.entropy().sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return log_prob, entropy, value + + # Evaluate given actions - keep shape [N, 1] for consistency + log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True) + entropy = dist.entropy().sum(dim=-1, keepdim=True) + + # Critic forward - keep shape [N, 1] + value = self.critic(obs_tensor) + + # Add to tensordict (in-place) + tensordict["sample_log_prob"] = log_prob + tensordict["entropy"] = entropy + tensordict["value"] = value + + return tensordict diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index cd21d0f7..68b909d1 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -19,13 +19,15 @@ This module defines an abstract Policy base class that all RL policies must inherit from. A Policy encapsulates the neural networks and exposes a uniform interface for RL algorithms (e.g., PPO, SAC) to interact with. + +All data I/O now uses TensorDict for structured, extensible data flow. """ from __future__ import annotations -from typing import Tuple from abc import ABC, abstractmethod import torch.nn as nn +from tensordict import TensorDict import torch @@ -37,6 +39,7 @@ class Policy(nn.Module, ABC): - Encapsulates neural networks that are trained by RL algorithms - Handles internal computations (e.g., network output → distribution) - Provides a uniform interface for algorithms (PPO, SAC, etc.) + - Uses TensorDict for all inputs and outputs (no tensor fallback) """ device: torch.device @@ -46,49 +49,54 @@ def __init__(self) -> None: super().__init__() @abstractmethod - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample an action from the policy. + def forward(self, tensordict: TensorDict) -> TensorDict: + """Forward pass that adds action to the input tensordict (in-place). + + This is the main inference method following TorchRL conventions. 
Args: - obs: Observation tensor of shape (batch_size, obs_dim) - deterministic: If True, return the mean action; otherwise sample + tensordict: Input TensorDict containing at minimum: + - "observation": Observation tensor or nested TensorDict Returns: - Tuple of (action, log_prob, value): - - action: Sampled action tensor of shape (batch_size, action_dim) - - log_prob: Log probability of the action, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + The same TensorDict (modified in-place) with added fields: + - "action": Sampled action tensor + - "sample_log_prob": Log probability of the sampled action + - "value": Value estimate (optional, for actor-critic) + - "loc": Distribution mean (optional, for continuous actions) + - "scale": Distribution std (optional, for continuous actions) """ raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: + def get_value(self, tensordict: TensorDict) -> TensorDict: """Get value estimate for given observations. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing: + - "observation": Observation data Returns: - Value estimate tensor of shape (batch_size,) + TensorDict with added field: + - "value": Value estimate tensor """ raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: """Evaluate actions and compute log probabilities, entropy, and values. + Used during policy updates to recompute action probabilities. + Args: - obs: Observation tensor of shape (batch_size, obs_dim) - actions: Action tensor of shape (batch_size, action_dim) + tensordict: Input TensorDict containing: + - "observation": Observation data + - "action": Actions to evaluate Returns: - Tuple of (log_prob, entropy, value): - - log_prob: Log probability of actions, shape (batch_size,) - - entropy: Entropy of the action distribution, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + TensorDict with added fields: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of the action distribution + - "value": Value estimate """ raise NotImplementedError diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py new file mode 100644 index 00000000..63bbeeab --- /dev/null +++ b/embodichain/agents/rl/models/vla_policy.py @@ -0,0 +1,238 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""VLA Policy for RL training with pretrained models. + +This module provides VLAPolicy that inherits from Policy base class, +just like ActorCritic. 
VLAPolicy loads pretrained VLA model components +and exposes the same interface as other policies. +""" + +from __future__ import annotations + +from typing import Optional +import torch +import torch.nn as nn +from tensordict import TensorDict + +from .policy import Policy + + +class VLAPolicy(Policy): + """VLA Policy that loads pretrained vision-language-action models. + + Similar to ActorCritic, this class inherits from Policy and implements + the required methods. The difference is that VLAPolicy loads pretrained + model components instead of training from scratch. + + VLA model components are loaded by the VLA team's implementation and + should provide the necessary interfaces for action generation and value + estimation. + """ + + def __init__( + self, + action_dim: int, + device: torch.device, + vla_model: nn.Module, + ): + """Initialize VLA policy with pretrained model. + + Args: + action_dim: Dimension of action space + device: Device to place policy on + vla_model: Pretrained VLA model (vision encoder, language model, + action head, value head, etc.) + """ + super().__init__() + self.action_dim = action_dim + self.device = device + + # Store VLA model + self.vla_model = vla_model + self.vla_model.to(self.device) + + @torch.no_grad() + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + """Forward pass: generate action and value from VLA model. + + Args: + tensordict: Must contain "observation" key with observation data + deterministic: If True, use deterministic actions (passed to VLA model) + + Returns: + Same tensordict with added keys: + - "action": Sampled or deterministic action + - "sample_log_prob": Log probability of action + - "value": Value estimate + """ + # VLA team should implement forward logic here + # This is a template - actual implementation depends on VLA model structure + obs = tensordict["observation"] + + # Example: VLA model generates action and value + action, log_prob, value = self.vla_model(obs, deterministic=deterministic) + + tensordict["action"] = action + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = value.squeeze(-1) + + return tensordict + + @torch.no_grad() + def get_value(self, tensordict: TensorDict) -> TensorDict: + """Get value estimate from VLA model. + + Args: + tensordict: Must contain "observation" key + + Returns: + Same tensordict with added "value" key + """ + obs = tensordict["observation"] + + # VLA team implements value computation + value = self.vla_model.get_value(obs) + + tensordict["value"] = value.squeeze(-1) + return tensordict + + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + """Evaluate actions using VLA model. + + Args: + tensordict: Must contain: + - "observation": Observation data + - "action": Actions to evaluate + + Returns: + Same tensordict with added keys: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of action distribution + - "value": Value estimate + """ + obs = tensordict["observation"] + actions = tensordict["action"] + + # VLA team implements action evaluation + log_prob, entropy, value = self.vla_model.evaluate_actions(obs, actions) + + tensordict["sample_log_prob"] = log_prob + tensordict["entropy"] = entropy + tensordict["value"] = value.squeeze(-1) + + return tensordict + + +def load_vla_model( + model_path: str, + model_class: Optional[str] = None, + model_config: Optional[dict] = None, + device: torch.device = torch.device("cpu"), +) -> nn.Module: + """Load VLA model from checkpoint. 
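Editor's note: a toy stand-in (not part of this diff) for the model interface `VLAPolicy` calls above — `forward`, `get_value`, and `evaluate_actions` returning plain tensors/tuples. A real VLA model would wrap vision/language backbones; flat tensor observations are assumed here purely for illustration.

```python
import torch
import torch.nn as nn
from torch.distributions.normal import Normal

class DummyVLAModel(nn.Module):
    """Toy stand-in exposing the interface VLAPolicy expects."""

    def __init__(self, obs_dim: int, action_dim: int):
        super().__init__()
        self.action_head = nn.Linear(obs_dim, action_dim)
        self.value_head = nn.Linear(obs_dim, 1)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def _dist(self, obs: torch.Tensor) -> Normal:
        mean = self.action_head(obs)
        return Normal(mean, self.log_std.exp().expand_as(mean))

    def forward(self, obs: torch.Tensor, deterministic: bool = False):
        dist = self._dist(obs)
        action = dist.mean if deterministic else dist.sample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        return action, log_prob, self.value_head(obs)

    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        return self.value_head(obs)

    def evaluate_actions(self, obs: torch.Tensor, actions: torch.Tensor):
        dist = self._dist(obs)
        log_prob = dist.log_prob(actions).sum(-1, keepdim=True)
        entropy = dist.entropy().sum(-1, keepdim=True)
        return log_prob, entropy, self.value_head(obs)
```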
+ + This function should be implemented by the VLA team to load their + pretrained VLA model (vision encoder, language model, action head, etc.). + + The returned module should have methods: + - forward(obs) -> (action, log_prob, value) + - get_value(obs) -> value + - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + + Args: + model_path: Path to checkpoint file + model_class: Fully qualified class name for VLA model + model_config: Configuration dict for model initialization + device: Device to load model on + + Returns: + Initialized VLA model module + + Example implementation by VLA team: + ```python + def load_vla_model(model_path, model_class, model_config, device): + import importlib + + # Import VLA model class + module_name, class_name = model_class.rsplit(".", 1) + module = importlib.import_module(module_name) + ModelClass = getattr(module, class_name) + + # Initialize model + model = ModelClass(**model_config) + + # Load checkpoint + checkpoint = torch.load(model_path, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + + model.to(device) + model.eval() + + return model + ``` + """ + raise NotImplementedError( + "load_vla_model() must be implemented. " + f"Model path: {model_path}, class: {model_class}, config: {model_config}" + ) + + +def build_vla_policy( + policy_block: dict, + action_dim: int, + device: torch.device, +) -> VLAPolicy: + """Build VLA policy from configuration. + + Args: + policy_block: Configuration dict + action_dim: Dimension of action space + device: Device to place policy on + + Returns: + Initialized VLAPolicy instance + """ + vla_config = policy_block.get("vla_config") + if vla_config is None: + raise ValueError("VLA policy requires 'vla_config' in policy block") + + model_path = vla_config.get("model_path") + if model_path is None: + raise ValueError("VLA config requires 'model_path'") + + model_class = vla_config.get("model_class") + model_config = vla_config.get("model_config", {}) + model_config["action_dim"] = action_dim + + # Load VLA model + vla_model = load_vla_model( + model_path=model_path, + model_class=model_class, + model_config=model_config, + device=device, + ) + + # Create VLAPolicy instance + policy = VLAPolicy( + action_dim=action_dim, + device=device, + vla_model=vla_model, + ) + + return policy diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 0f766954..a634f359 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -64,7 +64,8 @@ def train_from_config(config_path: str): seed = int(trainer_cfg.get("seed", 1)) device_str = trainer_cfg.get("device", "cpu") iterations = int(trainer_cfg.get("iterations", 250)) - rollout_steps = int(trainer_cfg.get("rollout_steps", 2048)) + buffer_size = int(trainer_cfg.get("buffer_size", 2048)) + model_type = trainer_cfg.get("model_type", "standard") enable_eval = bool(trainer_cfg.get("enable_eval", False)) eval_freq = int(trainer_cfg.get("eval_freq", 10000)) save_freq = int(trainer_cfg.get("save_freq", 50000)) @@ -175,13 +176,36 @@ def train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] - # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic) - if policy_name.lower() == "actor_critic": - # Get observation dimension from flattened observation space - # flattened_observation_space returns Box space for RL training - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] + # 
Get action_dim from config (required) + action_dim = policy_block.get("action_dim") + if action_dim is None: + raise ValueError( + "Missing 'action_dim' in policy config. " + "With TensorDict architecture, action dimension must be explicitly specified in config. " + 'Example: {"policy": {"name": "actor_critic", "action_dim": 7, ...}}' + ) + + # Infer obs_dim from environment sampling (no gym space dependency) + # Env returns dict, we process it to infer dimensions + sample_obs, _ = env.reset() + + # Get obs_dim by flattening observation structure (env returns dict) + obs_list = [] + + def _collect(item): + """Recursively collect tensors from dict or direct tensor.""" + if hasattr(item, "keys"): # It's a dict + for key in sorted(item.keys()): + _collect(item[key]) + else: # It's a Tensor + obs_list.append(item.flatten(start_dim=1)) + + _collect(sample_obs) + obs_dim = sum(t.shape[-1] for t in obs_list) + + # Build policy based on type + if policy_name.lower() == "actor_critic": actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -194,16 +218,19 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, - env.action_space, - device, + action_dim=action_dim, + device=device, actor=actor, critic=critic, ) - else: - policy = build_policy( - policy_block, env.flattened_observation_space, env.action_space, device + elif policy_name.lower() == "vla": + # VLA policy loads pretrained model from checkpoint + logger.info( + f"Loading VLA model from config: {policy_block.get('vla_config', {})}" ) + policy = build_policy(policy_block, action_dim=action_dim, device=device) + else: + policy = build_policy(policy_block, action_dim=action_dim, device=device) # Build Algorithm via factory algo_name = algo_block["name"].lower() @@ -254,7 +281,7 @@ def train_from_config(config_path: str): policy=policy, env=env, algorithm=algo, - num_steps=rollout_steps, + buffer_size=buffer_size, batch_size=algo_cfg["batch_size"], writer=writer, eval_freq=eval_freq if enable_eval else 0, # Disable eval if not enabled @@ -266,6 +293,7 @@ def train_from_config(config_path: str): event_cfg=train_event_cfg, eval_event_cfg=eval_event_cfg if enable_eval else {}, num_eval_episodes=num_eval_episodes, + model_type=model_type, ) logger.log_info("Generic training initialized") @@ -277,7 +305,7 @@ def train_from_config(config_path: str): f"Algorithm: {algo_name} (available: {get_registered_algo_names()})" ) - total_steps = int(iterations * rollout_steps * env.num_envs) + total_steps = int(iterations * buffer_size * env.num_envs) logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})") try: diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index e6f9e57a..852cdaf9 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -15,9 +15,12 @@ # ---------------------------------------------------------------------------- from .config import AlgorithmCfg -from .helper import flatten_dict_observation +from .helper import dict_to_tensordict, mean_scalar, pack_log_dict, compute_gae __all__ = [ "AlgorithmCfg", - "flatten_dict_observation", + "dict_to_tensordict", + "mean_scalar", + "pack_log_dict", + "compute_gae", ] diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 3021a31f..17919144 100644 --- a/embodichain/agents/rl/utils/helper.py +++ 
b/embodichain/agents/rl/utils/helper.py @@ -14,39 +14,167 @@ # limitations under the License. # ---------------------------------------------------------------------------- -import torch +"""Helper utilities for RL training. +This module provides utility functions for RL algorithms. +""" -def flatten_dict_observation(input_dict: dict) -> torch.Tensor: - """ - Flatten hierarchical dict observations from ObservationManager. +import torch +import numpy as np +from tensordict import TensorDict - Recursively traverse nested dicts, collect all tensor values, - flatten each to (num_envs, -1), and concatenate in sorted key order. + +def dict_to_tensordict(obs_dict: dict, device: torch.device) -> TensorDict: + """Convert nested dict observation to TensorDict recursively. Args: - input_dict: Nested dict structure, e.g. {"robot": {"qpos": tensor, "ee_pos": tensor}, "object": {...}} + obs_dict: Nested observation dictionary + device: Device to place tensors on Returns: - Concatenated flat tensor of shape (num_envs, total_dim) + TensorDict with nested structure preserved and "observation" key """ - obs_list = [] - def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested dicts in sorted order.""" - for key in sorted(d.keys()): - full_key = f"{prefix}/{key}" if prefix else key - value = d[key] + def _recursive_convert(d): + """Recursively convert dict to TensorDict-compatible structure.""" + result = {} + for key, value in d.items(): if isinstance(value, dict): - _collect_tensors(value, full_key) + # Recursively convert nested dicts + result[key] = _recursive_convert(value) elif isinstance(value, torch.Tensor): - # Flatten tensor to (num_envs, -1) shape - obs_list.append(value.flatten(start_dim=1)) + result[key] = value.to(device) + else: + result[key] = torch.tensor(value, device=device) + return result + + # Convert the observation dict structure + converted = _recursive_convert(obs_dict) + + # Infer batch_size from first tensor we find + def _get_first_tensor_batch_size(d): + """Find first tensor and get its batch dimension.""" + for value in d.values(): + if isinstance(value, torch.Tensor): + return value.shape[0] + elif isinstance(value, dict): + bs = _get_first_tensor_batch_size(value) + if bs is not None: + return bs + return None + + batch_size = _get_first_tensor_batch_size(converted) + if batch_size is None: + batch_size = 1 # Default if no tensors found + + # Wrap in TensorDict with explicit batch_size + obs_td = TensorDict(converted, batch_size=[batch_size], device=device) + + # Wrap observation in outer TensorDict with "observation" key + return TensorDict({"observation": obs_td}, batch_size=[batch_size], device=device) + + +def mean_scalar(x) -> float: + """Convert tensor or array to scalar float (mean if needed). + + Args: + x: Tensor, array, or scalar value + + Returns: + Float scalar value + """ + if hasattr(x, "detach"): + x = x.detach().cpu().numpy() + else: + x = np.asarray(x) + return float(np.mean(x)) + + +def pack_log_dict(prefix: str, data: dict) -> dict: + """Pack data dict into logging dict with prefix. 
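Editor's note: a quick sketch of what `dict_to_tensordict` produces for a nested environment observation (key names and shapes invented).

```python
import torch
from embodichain.agents.rl.utils.helper import dict_to_tensordict

obs_dict = {
    "robot": {"qpos": torch.randn(4, 9), "ee_pos": torch.randn(4, 3)},
    "object": {"pos": torch.randn(4, 3)},
}
td = dict_to_tensordict(obs_dict, torch.device("cpu"))

assert td.batch_size == torch.Size([4])
assert td["observation", "robot", "qpos"].shape == (4, 9)
```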
+ + Args: + prefix: Prefix for keys (e.g., "train", "eval") + data: Dictionary of values to pack + + Returns: + Dictionary with prefixed keys and scalar values + """ + if not isinstance(data, dict): + return {} + out = {} + for k, v in data.items(): + try: + out[f"{prefix}/{k}"] = mean_scalar(v) + except Exception: + continue + return out + + +def compute_gae( + rollout: TensorDict, + gamma: float, + gae_lambda: float, +) -> TensorDict: + """Compute Generalized Advantage Estimation (GAE) on rollout TensorDict. + + This follows the TorchRL convention where rollout has shape [T, N, ...]. + Computes advantage and value_target in-place and returns the modified TensorDict. + + Args: + rollout: TensorDict with batch_size=[T, N] containing: + - "value": Tensor[T, N, 1] - state values + - "next": TensorDict with: + - "reward": Tensor[T, N, 1] + - "done": Tensor[T, N, 1] + - "value": Tensor[T, N, 1] - next state values (bootstrapped) + gamma: Discount factor + gae_lambda: GAE lambda parameter + + Returns: + TensorDict with added keys: + - "advantage": Tensor[T, N, 1] + - "value_target": Tensor[T, N, 1] + """ + T, N = rollout.batch_size[:2] + device = rollout.device + + # Extract tensors - shape [T, N, 1] + values = rollout["value"] + rewards = rollout["next"]["reward"] + dones = rollout["next"]["done"].float() + + # Bootstrap values: use next state value from rollout["next"]["value"] + # This is computed during collection by evaluating policy on next_obs + if "value" in rollout["next"]: + bootstrap_values = rollout["next"]["value"] + else: + # If not provided, assume 0 (terminal state) + bootstrap_values = torch.zeros_like(values) + + # Compute GAE advantages using backward iteration + # advantage[t] = delta[t] + (gamma * gae_lambda) * (1 - done[t]) * advantage[t+1] + # where delta[t] = reward[t] + gamma * (1 - done[t]) * V(s_{t+1}) - V(s_t) + # V(s_{t+1}) comes from bootstrap_values[t] which was computed on next_obs[t] + + advantages = torch.zeros_like(values) + gae = torch.zeros(N, 1, device=device) + + # Iterate backwards through time + for t in reversed(range(T)): + # Compute TD error (delta) + # bootstrap_values[t] is V(s_{t+1}), the value of the next state after action at t + delta = rewards[t] + gamma * bootstrap_values[t] * (1.0 - dones[t]) - values[t] + + # Compute GAE recursively + gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae + advantages[t] = gae - _collect_tensors(input_dict) + # Compute value targets (for value function loss) + value_targets = advantages + values - if not obs_list: - raise ValueError("No tensors found in observation dict") + # Add to rollout TensorDict (in-place) + rollout["advantage"] = advantages + rollout["value_target"] = value_targets - result = torch.cat(obs_list, dim=-1) - return result + return rollout diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 40df6d74..1d8612d1 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -23,9 +23,11 @@ from torch.utils.tensorboard import SummaryWriter from collections import deque import wandb +from tensordict import TensorDict from embodichain.lab.gym.envs.managers.event_manager import EventManager -from .helper import flatten_dict_observation +from .helper import dict_to_tensordict, mean_scalar, pack_log_dict +from ..collector import SyncCollector, AsyncCollector class Trainer: @@ -36,7 +38,7 @@ def __init__( policy, env, algorithm, - num_steps: int, + buffer_size: int, batch_size: int, writer: SummaryWriter | None, 
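Editor's note: an illustrative call (not part of this diff) to `compute_gae` on a toy `[T, N]` rollout laid out per its docstring above; values are arbitrary.

```python
import torch
from tensordict import TensorDict
from embodichain.agents.rl.utils.helper import compute_gae

T, N = 5, 3
rollout = TensorDict(
    {
        "value": torch.zeros(T, N, 1),
        "next": TensorDict(
            {
                "reward": torch.ones(T, N, 1),
                "done": torch.zeros(T, N, 1, dtype=torch.bool),
                "value": torch.zeros(T, N, 1),
            },
            batch_size=[T, N],
        ),
    },
    batch_size=[T, N],
)

rollout = compute_gae(rollout, gamma=0.99, gae_lambda=0.95)
assert rollout["advantage"].shape == (T, N, 1)
assert rollout["value_target"].shape == (T, N, 1)
```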
eval_freq: int, @@ -48,12 +50,14 @@ def __init__( event_cfg=None, eval_event_cfg=None, num_eval_episodes: int = 5, + # Model type: "standard" (default PPO) or "vla" + model_type: str = "standard", ): self.policy = policy self.env = env self.eval_env = eval_env self.algorithm = algorithm - self.num_steps = num_steps + self.buffer_size = buffer_size self.batch_size = batch_size self.writer = writer self.eval_freq = eval_freq @@ -63,6 +67,29 @@ def __init__( self.use_wandb = use_wandb self.num_eval_episodes = num_eval_episodes + # Buffer setup (depends on model_type) + self.model_type = model_type + device = ( + algorithm.device + if hasattr(algorithm, "device") + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + if model_type == "vla": + # VLA model: accumulate multiple rollouts with FIFO buffer + from embodichain.agents.rl.buffer import VLABuffer + + self.buffer = VLABuffer(buffer_size=buffer_size, device=device) + elif model_type == "standard": + # Standard PPO model: single rollout, use and discard + from embodichain.agents.rl.buffer import RolloutBuffer + + self.buffer = RolloutBuffer(buffer_size=buffer_size, device=device) + else: + raise ValueError( + f"Unknown model_type: {model_type}. Use 'standard' or 'vla'." + ) + if event_cfg is not None: self.event_manager = EventManager(event_cfg, env=self.env) if eval_event_cfg is not None: @@ -75,85 +102,46 @@ def __init__( self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) - # initial obs (assume env returns torch tensors already on target device) + # Initialize observation - will be used by collectors obs, _ = self.env.reset() + self.obs_tensordict = dict_to_tensordict(obs, self.device) + num_envs = self.obs_tensordict.batch_size[0] - # Initialize algorithm's buffer - # Flatten dict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, dict): - obs_tensor = flatten_dict_observation(obs) - obs_dim = obs_tensor.shape[-1] - num_envs = obs_tensor.shape[0] - # Store flattened observation for RL training - self.obs = obs_tensor - - action_space = getattr(self.env, "action_space", None) - action_dim = action_space.shape[-1] if action_space else None - if action_dim is None: - raise RuntimeError( - "Env must expose action_space with shape for buffer initialization." 
- ) - - # Algorithm manages its own buffer - self.algorithm.initialize_buffer(num_steps, num_envs, obs_dim, action_dim) - - # episode stats tracked on device to avoid repeated CPU round-trips + # Episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) # ---- lightweight helpers for dense logging ---- - @staticmethod - def _mean_scalar(x) -> float: - if hasattr(x, "detach"): - x = x.detach().cpu().numpy() - else: - x = np.asarray(x) - return float(np.mean(x)) - def _log_scalar_dict(self, prefix: str, data: dict): if not self.writer or not isinstance(data, dict): return for k, v in data.items(): try: self.writer.add_scalar( - f"{prefix}/{k}", self._mean_scalar(v), self.global_step + f"{prefix}/{k}", mean_scalar(v), self.global_step ) except Exception: continue - def _pack_log_dict(self, prefix: str, data: dict) -> dict: - if not isinstance(data, dict): - return {} - out = {} - for k, v in data.items(): - try: - out[f"{prefix}/{k}"] = self._mean_scalar(v) - except Exception: - continue - return out - - def train(self, total_timesteps: int): - print(f"Start training, total steps: {total_timesteps}") - while self.global_step < total_timesteps: - self._collect_rollout() - losses = self.algorithm.update() - self._log_train(losses) - if ( - self.eval_freq > 0 - and self.eval_env is not None - and self.global_step % self.eval_freq == 0 - ): - self._eval_once(num_episodes=self.num_eval_episodes) - if self.global_step % self.save_freq == 0: - self.save_checkpoint() + def _create_step_callback(self) -> Callable: + """Create step callback for collectors. - @torch.no_grad() - def _collect_rollout(self): - """Collect a rollout. 
Algorithm controls the data collection process.""" + Returns: + Callback function compatible with both sync and async collectors + """ - # Callback function for statistics and logging - def on_step(obs, actions, reward, done, info, next_obs): + def on_step(tensordict: TensorDict, env_info: dict): """Callback called at each step during rollout collection.""" + # Extract reward and done from next subdictionary + reward = tensordict["next"]["reward"] + done = tensordict["next"]["done"] + + # Squeeze if needed + if reward.dim() > 1: + reward = reward.squeeze(-1) + if done.dim() > 1: + done = done.squeeze(-1) + # Episode stats (stay on device; convert only when episode ends) self.curr_ret += reward self.curr_len += 1 @@ -166,30 +154,113 @@ def on_step(obs, actions, reward, done, info, next_obs): self.curr_ret[done_idx] = 0 self.curr_len[done_idx] = 0 - # Update global step and observation - # next_obs is already flattened in algorithm's collect_rollout - self.obs = next_obs - self.global_step += next_obs.shape[0] - - if isinstance(info, dict): - rewards_dict = info.get("rewards") - metrics_dict = info.get("metrics") + # Log environment metrics + if isinstance(env_info, dict): + rewards_dict = env_info.get("rewards") + metrics_dict = env_info.get("metrics") self._log_scalar_dict("rewards", rewards_dict) self._log_scalar_dict("metrics", metrics_dict) log_dict = {} - log_dict.update(self._pack_log_dict("rewards", rewards_dict)) - log_dict.update(self._pack_log_dict("metrics", metrics_dict)) + log_dict.update(pack_log_dict("rewards", rewards_dict)) + log_dict.update(pack_log_dict("metrics", metrics_dict)) if log_dict and self.use_wandb: wandb.log(log_dict, step=self.global_step) - # Algorithm controls data collection - result = self.algorithm.collect_rollout( - env=self.env, - policy=self.policy, - obs=self.obs, - num_steps=self.num_steps, - on_step_callback=on_step, - ) + return on_step + + def train(self, total_timesteps: int): + print(f"Start training, total steps: {total_timesteps}") + print(f"Model type: {self.model_type}") + + if self.model_type == "vla": + collector = AsyncCollector( + env=self.env, + policy=self.policy, + buffer=self.buffer, + device=self.device, + on_step_callback=self._create_step_callback(), + ) + self._train_async(collector, total_timesteps) + else: + collector = SyncCollector( + env=self.env, + policy=self.policy, + device=self.device, + on_step_callback=self._create_step_callback(), + ) + self._train_sync(collector, total_timesteps) + + def _train_sync(self, collector: SyncCollector, total_timesteps: int): + """Synchronous training loop (standard PPO).""" + while self.global_step < total_timesteps: + # Collect rollout + rollout = collector.collect(num_steps=self.buffer_size) + + # Update global step (main thread only) + num_steps = rollout.batch_size[0] # T dimension + num_envs = rollout.batch_size[1] if len(rollout.batch_size) > 1 else 1 + self.global_step += num_steps * num_envs + + self.buffer.add(rollout) + + # Train when buffer is full + if self.buffer.is_full(): + data = self.buffer.get(flatten=True) + losses = self.algorithm.update(data) + self._log_train(losses) + + # Evaluation + if ( + self.eval_freq > 0 + and self.eval_env is not None + and self.global_step % self.eval_freq == 0 + ): + self._eval_once(num_episodes=self.num_eval_episodes) + + # Checkpoint + if self.global_step % self.save_freq == 0: + self.save_checkpoint() + + def _train_async(self, collector: AsyncCollector, total_timesteps: int): + """Asynchronous training loop (VLA mode).""" + 
collector.start() + print("[Trainer] Async collector started") + + try: + while self.global_step < total_timesteps: + # Wait for buffer to fill + while not self.buffer.is_full(): + time.sleep(0.1) + if not collector.is_running(): + raise RuntimeError("Async collector stopped unexpectedly") + + # Get data and train + data = self.buffer.get(flatten=True) + + # Update global step based on collected data (main thread only) + batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0 + self.global_step += batch_size + + losses = self.algorithm.update(data) + self._log_train(losses) + + # Evaluation (pause collector during eval) + if ( + self.eval_freq > 0 + and self.eval_env is not None + and self.global_step % self.eval_freq == 0 + ): + collector.stop() + self._eval_once(num_episodes=self.num_eval_episodes) + collector.start() + + # Checkpoint + if self.global_step % self.save_freq == 0: + self.save_checkpoint() + + finally: + collector.stop() + print("[Trainer] Async collector stopped") def _log_train(self, losses: Dict[str, float]): if self.writer: @@ -243,10 +314,10 @@ def _eval_once(self, num_episodes: int = 5): episode_lengths = [] for _ in range(num_episodes): - # Reset and initialize episode tracking + # Reset and initialize episode tracking - env returns dict, convert at boundary obs, _ = self.eval_env.reset() - obs = flatten_dict_observation(obs) - num_envs = obs.shape[0] if obs.ndim == 2 else 1 + obs = dict_to_tensordict(obs, self.device) + num_envs = obs.batch_size[0] done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( @@ -256,16 +327,20 @@ def _eval_once(self, num_episodes: int = 5): # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy - actions, _, _ = self.policy.get_action(obs, deterministic=True) + # Get deterministic actions for evaluation + obs_copy = obs.clone() + self.policy.forward(obs_copy, deterministic=True) + actions = obs_copy["action"] + action_type = getattr(self.eval_env, "action_type", "delta_qpos") action_dict = {action_type: actions} - # Environment step - obs, reward, terminated, truncated, info = self.eval_env.step( + # Environment step - env returns dict, convert to TensorDict at boundary + next_obs, reward, terminated, truncated, info = self.eval_env.step( action_dict ) - obs = flatten_dict_observation(obs) if isinstance(obs, dict) else obs + next_obs = dict_to_tensordict(next_obs, self.device) + obs = next_obs # Update statistics only for still-running environments done = terminated | truncated diff --git a/pyproject.toml b/pyproject.toml index 0b4624d7..84328fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", + "tensordict>=0.5.0", # For TensorDict-based RL data structures ] [project.optional-dependencies] diff --git a/tests/agents/test_rl.py b/tests/agents/test_rl.py index d12cc10f..a4951fae 100644 --- a/tests/agents/test_rl.py +++ b/tests/agents/test_rl.py @@ -70,7 +70,7 @@ def setup_method(self): test_train_config = train_config.copy() test_train_config["trainer"]["gym_config"] = self.temp_gym_config_path test_train_config["trainer"]["iterations"] = 2 - test_train_config["trainer"]["rollout_steps"] = 32 + test_train_config["trainer"]["buffer_size"] = 32 test_train_config["trainer"]["eval_freq"] = 1000000 # Disable eval test_train_config["trainer"]["save_freq"] = 1000000 # Disable save test_train_config["trainer"]["headless"] = True