Conversation
Pull request overview
This pull request introduces a comprehensive refactoring of the RL training framework to use TensorDict-based data flow, replacing the previous tensor-based approach. The PR adds support for two training modes: standard synchronous PPO and asynchronous VLA training designed for scenarios with slow model inference.
Changes:
- Migrated entire RL pipeline to TensorDict-based architecture for structured, extensible data flow
- Introduced dual buffer system: RolloutBuffer (standard) and VLABuffer (async with FIFO)
- Added AsyncCollector for background data collection in VLA mode with thread-based parallelism
- Refactored Policy interface to use TensorDict inputs/outputs with in-place modifications
- Updated PPO algorithm to work with TensorDict rollouts and removed dependency on gym spaces
- Modified configuration to use `buffer_size` instead of `rollout_steps` and added an `action_dim` requirement
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| embodichain/agents/rl/utils/trainer.py | Refactored to support dual training modes (sync/async) with TensorDict |
| embodichain/agents/rl/utils/helper.py | Added dict_to_tensordict, compute_gae, and logging utilities |
| embodichain/agents/rl/utils/async_collector.py | New async data collector for VLA mode with background thread |
| embodichain/agents/rl/buffer/rollout_buffer.py | Renamed/refactored to VLABuffer with circular indexing |
| embodichain/agents/rl/buffer/standard_buffer.py | New RolloutBuffer for standard PPO mode |
| embodichain/agents/rl/buffer/__init__.py | Updated exports for dual buffer system |
| embodichain/agents/rl/algo/ppo.py | Refactored to use TensorDict data flow throughout |
| embodichain/agents/rl/algo/base.py | Updated base algorithm interface for TensorDict |
| embodichain/agents/rl/models/policy.py | Changed interface to TensorDict-based methods |
| embodichain/agents/rl/models/actor_critic.py | Implemented TensorDict-based policy with in-place modifications |
| embodichain/agents/rl/models/__init__.py | Removed gymnasium dependency, added action_dim parameter |
| embodichain/agents/rl/train.py | Added action_dim requirement, removed gym space dependency |
| tests/agents/test_rl.py | Updated test to use buffer_size parameter |
| configs/agents/rl/push_cube/train_config.json | Updated config with buffer_size, action_dim, and eval_freq |
| configs/agents/rl/basic/cart_pole/train_config.json | Updated config with buffer_size |
| docs/source/tutorial/rl.rst | Updated documentation to reference buffer_size |
| pyproject.toml | Added tensordict>=0.5.0 dependency |
Comments suppressed due to low confidence (1)
embodichain/agents/rl/train.py:289
- The `buffer_type` parameter is not read from the trainer config and is not passed to the Trainer constructor (lines 273-289). This means the VLA async mode introduced in this PR cannot be used, as it will always default to "standard" mode. Add `buffer_type = trainer_cfg.get("buffer_type", "standard")` before the Trainer initialization and pass it as `buffer_type=buffer_type` to the Trainer constructor.
```python
trainer = Trainer(
    policy=policy,
    env=env,
    algorithm=algo,
    buffer_size=buffer_size,
    batch_size=algo_cfg["batch_size"],
    writer=writer,
    eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
    save_freq=save_freq,
    checkpoint_dir=checkpoint_dir,
    exp_name=exp_name,
    use_wandb=use_wandb,
    eval_env=eval_env,  # None if enable_eval=False
    event_cfg=train_event_cfg,
    eval_event_cfg=eval_event_cfg if enable_eval else {},
    num_eval_episodes=num_eval_episodes,
)
```
```python
# Update global step
num_envs = tensordict.batch_size[0]
self.global_step += num_envs
```
The self.global_step variable is updated from the async collector thread (line 182 via callback) and potentially read from the main thread (lines 214, 244, 255). This creates a race condition. Consider using a thread-safe counter (e.g., threading.Lock protection or multiprocessing.Value) or tracking steps only in one thread.
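A minimal sketch of one of the suggested options, assuming the counter lives on the trainer and the collector callback increments it (class and attribute names are illustrative):

```python
import threading


class StepCounter:
    """Thread-safe step counter shared between the collector thread and the main thread."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._value = 0

    def add(self, n: int) -> None:
        with self._lock:
            self._value += n

    @property
    def value(self) -> int:
        with self._lock:
            return self._value


# In the callback running on the collector thread:
#     self.step_counter.add(num_envs)
# In the main thread (logging / eval / save-frequency checks):
#     step = self.step_counter.value
```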
Pull request overview
Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.
```python
def get(self, flatten: bool = True) -> TensorDict:
    """Get valid data from buffer.

    Args:
        flatten: If True, return flattened [size, ...]. Currently only supports True.

    Returns:
        TensorDict with batch_size=[size, ...] containing valid data
    """
    if not self._initialized or self.size == 0:
        raise ValueError("Buffer is empty")

    if not flatten:
        raise NotImplementedError("Only flatten=True is supported for VLABuffer")

    # Return first 'size' elements (valid data)
    # Note: Data is in insertion order up to write_pos, then wraps
    if self.size < self.buffer_size:
        # Buffer not yet full, data is [0:size]
        return self.buffer[: self.size]
    else:
        # Buffer full, need to rearrange to maintain temporal order
        # Oldest data is at write_pos, newest at write_pos-1
        indices = (
            torch.arange(
                self.write_pos,
                self.write_pos + self.buffer_size,
                device=self.device,
            )
            % self.buffer_size
        )
        return self.buffer[indices]

def clear(self) -> None:
    """Clear buffer (reset pointers, keep pre-allocated memory)."""
    self.write_pos = 0
    self.size = 0
    # Keep buffer allocated for reuse

def __len__(self) -> int:
    """Return current number of valid transitions."""
    return self.size

def is_full(self) -> bool:
    """Check if buffer is at full buffer_size."""
    return self.size >= self.buffer_size
```
The VLABuffer.get() and is_full() methods are called from the main thread while AsyncCollector writes to the buffer from a background thread, but these methods lack thread safety. The read of self.size and self.write_pos could return inconsistent values if a write is in progress. Additionally, buffer.get() performs complex operations (checking size, slicing buffer) that should be atomic with respect to concurrent writes. Consider adding thread synchronization or document that external locking is required.
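A minimal sketch of one way to address this, serializing reads and writes behind a single lock; the wrapper class and the `add()` writer-method name are illustrative, while `get(flatten=...)` and `is_full()` follow the code shown above:

```python
import threading

from tensordict import TensorDict


class LockedVLABuffer:
    """Illustrative wrapper: one lock guards write_pos, size, and buffer contents."""

    def __init__(self, buffer: "VLABuffer") -> None:
        self._buffer = buffer
        self._lock = threading.Lock()

    def add(self, td: TensorDict) -> None:
        # Writer side (AsyncCollector background thread); method name assumed.
        with self._lock:
            self._buffer.add(td)

    def get(self, flatten: bool = True) -> TensorDict:
        # Reader side (main training thread); atomic with respect to concurrent writes.
        with self._lock:
            return self._buffer.get(flatten=flatten)

    def is_full(self) -> bool:
        with self._lock:
            return self._buffer.is_full()
```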
```python
if deterministic:
    action = mean
else:
    dist = Normal(mean, std)
```
The distribution is created twice when deterministic=False. Line 130 creates dist = Normal(mean, std), then lines 136-137 create it again. This is wasteful. Consider refactoring to create the distribution once and use either dist.mean or dist.sample() based on the deterministic flag.
The second creation referenced above:

```python
dist = Normal(mean, std)
```
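A minimal sketch of the suggested refactor, assuming `mean` and `std` are the policy head outputs as in the snippet above:

```python
from torch.distributions import Normal

# Create the distribution once and branch only on how the action is chosen.
dist = Normal(mean, std)
action = dist.mean if deterministic else dist.sample()
```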
```python
next_value_td = TensorDict(
    {"observation": next_obs_for_td},
    batch_size=next_td.batch_size,
    device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]
```
The policy is accessed from both the background collector thread (lines 145-146, 200) and potentially from the main training thread during algorithm.update(). PyTorch tensors and models are not thread-safe by default. Concurrent access to the policy parameters during forward passes and gradient updates can lead to race conditions and corrupted gradients. Consider using locks to synchronize policy access, or ensure the policy is not being updated while the collector is running (e.g., by stopping collection during training).
Suggested change:

```python
# Protect policy access with lock to avoid races with training thread
with self._lock:
    next_value_td = TensorDict(
        {"observation": next_obs_for_td},
        batch_size=next_td.batch_size,
        device=self.device,
    )
    self.policy.get_value(next_value_td)
    next_td["value"] = next_value_td["value"]
```
```python
losses = self.algorithm.update(data)
self._log_train(losses)
```
After buffer.get() is called on line 238, the VLABuffer is not cleared (unlike RolloutBuffer which auto-clears). Since the buffer is full (size == buffer_size), the is_full() check on line 232 will immediately return True in the next iteration, causing the training loop to repeatedly train on the same data without waiting for new transitions. The buffer should be cleared after get(), or the is_full() logic should be modified to track whether data has been consumed.
Suggested change:

```python
# Clear async buffer after consumption to avoid retraining on stale data
if hasattr(self.buffer, "clear"):
    self.buffer.clear()
```
```python
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td
```
The collector does not handle episode resets when done=True. After an episode terminates (done flag is set), the environment should be reset to get a fresh initial observation for the next episode. Currently, the collector continues using next_obs even after termination, which could contain stale data. Most RL environments auto-reset on episode end, but this should be made explicit or documented as a requirement.
Suggested change:

```python
# Prepare next iteration:
# - if episode is done, reset env to get a fresh initial observation
# - otherwise, continue from next_obs_td
if done.any():
    reset_result = self.env.reset()
    # Support both Gym/Gymnasium-style (obs, info) and plain-obs resets
    if isinstance(reset_result, tuple):
        reset_obs = reset_result[0]
    else:
        reset_obs = reset_result
    current_td = dict_to_tensordict(reset_obs, self.device)
else:
    current_td = next_obs_td
```
```python
# Store complete transition
rollout_list.append(current_td.clone())
```
Calling .clone() on every transition creates a full copy of the TensorDict including all nested tensors, which can be memory-intensive for large rollouts. Since current_td is reassigned to next_obs_td on line 122 (which is a fresh TensorDict), the clone may be unnecessary. Consider whether a shallow copy or reference would suffice, or document why deep cloning is required here.
Suggested change:

```python
# Store complete transition (no clone needed: current_td is not mutated afterwards)
rollout_list.append(current_td)
```
```python
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size
```
The global_step update in async mode only counts batch_size from the returned data (line 242), not the actual number of environment steps taken. Since VLABuffer is continuously being written to by AsyncCollector (which tracks steps in _step_count), the global_step will not accurately reflect the total number of environment interactions. Consider synchronizing global_step with the collector's _step_count, or documenting this discrepancy.
Suggested change:

```python
# Update global step.
# Prefer the collector's step count (actual env interactions) if available,
# otherwise fall back to counting processed batch size.
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
steps_from_collector = getattr(collector, "_step_count", None)
if isinstance(steps_from_collector, int) and steps_from_collector > self.global_step:
    self.global_step = steps_from_collector
else:
    self.global_step += batch_size
```
```
@@ -166,30 +154,113 @@ def on_step(obs, actions, reward, done, info, next_obs):
    self.curr_ret[done_idx] = 0
    self.curr_len[done_idx] = 0

    # Update global step and observation
    # next_obs is already flattened in algorithm's collect_rollout
    self.obs = next_obs
    self.global_step += next_obs.shape[0]

    if isinstance(info, dict):
        rewards_dict = info.get("rewards")
        metrics_dict = info.get("metrics")
    # Log environment metrics
    if isinstance(env_info, dict):
        rewards_dict = env_info.get("rewards")
        metrics_dict = env_info.get("metrics")
        self._log_scalar_dict("rewards", rewards_dict)
        self._log_scalar_dict("metrics", metrics_dict)
        log_dict = {}
        log_dict.update(self._pack_log_dict("rewards", rewards_dict))
        log_dict.update(self._pack_log_dict("metrics", metrics_dict))
        log_dict.update(pack_log_dict("rewards", rewards_dict))
        log_dict.update(pack_log_dict("metrics", metrics_dict))
        if log_dict and self.use_wandb:
            wandb.log(log_dict, step=self.global_step)
```
The on_step callback modifies shared state (self.curr_ret, self.curr_len, self.ret_window, self.len_window, self.global_step) without thread synchronization. In async mode, this callback runs in the AsyncCollector background thread while the main thread could be accessing these same variables (e.g., in _log_train). This can cause race conditions and data corruption. Use threading.Lock to protect access to these shared variables, or ensure they're only accessed from one thread.
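A minimal sketch of the lock-based option: keep the episode statistics behind one lock so the collector callback (writer) and the logging path (reader) never observe partially updated state. The class and method names are illustrative:

```python
import threading


class EpisodeStats:
    """Illustrative: episode-return bookkeeping guarded by a single lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.global_step = 0
        self.ret_window: list[float] = []

    def on_step(self, episode_returns, dones) -> None:
        # Called from the AsyncCollector thread.
        with self._lock:
            self.global_step += len(dones)
            for ret, done in zip(episode_returns, dones):
                if done:
                    self.ret_window.append(float(ret))

    def snapshot(self) -> tuple[int, float]:
        # Called from the main thread when logging.
        with self._lock:
            mean_ret = sum(self.ret_window) / max(len(self.ret_window), 1)
            return self.global_step, mean_ret
```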
```python
def collect(self, **kwargs) -> TensorDict:
    """Collect data from environment.
```
The overridden method signature does not match the call, which passes too many arguments; the overriding method SyncCollector.collect matches the call.
The overridden method signature does not match the call, which passes an argument named 'num_steps'; the overriding method SyncCollector.collect matches the call.
Suggested change:

```python
def collect(self, num_steps: int, **kwargs) -> TensorDict:
    """Collect data from environment.

    Args:
        num_steps: Number of steps to collect.
```
```python
def collect(self, num_steps: int) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """
```
This method requires 2 positional arguments, whereas overridden BaseCollector.collect requires 1.
Suggested change:

```python
def collect(self, num_steps: int | None = None) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect.

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """
    if num_steps is None:
        raise TypeError("num_steps must be provided for SyncCollector.collect()")
```
RL Training Framework Guide
TensorDict-based RL framework supporting standard PPO and asynchronous VLA training.
Quick Start
Configuration
{ "trainer": { "buffer_size": 2048, "model_type": "standard" // or "vla" }, "policy": {"name": "actor_critic"}, "algorithm": { "name": "ppo", "cfg": { "learning_rate": 3e-4, "gamma": 0.99, "n_epochs": 10, "batch_size": 64 } } }Run Training
Architecture
Components:
Training Modes
Standard Mode (Default)
For: Normal models (<100ms inference/step)
Config:
{"trainer": {"model_type": "standard"}}Pros: Simple, stable, low memory, no staleness
VLA Async Mode
For: Large models (>1 sec inference/step)
Config:
{"trainer": {"model_type": "vla"}}Pros: 2-3x speedup via parallel collection
Cons: Data staleness, higher memory
Collectors
SyncCollector
Collects complete rollout synchronously:
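A usage sketch based on the `collect(num_steps)` signature shown in the review comments above; the constructor arguments are assumptions:

```python
# Constructor arguments are assumptions; the actual signature may differ.
collector = SyncCollector(env=env, policy=policy, device=device)

# Blocks until num_steps transitions have been collected from each environment.
rollout = collector.collect(num_steps=buffer_size)  # TensorDict with batch_size=[T, N]
losses = algorithm.update(rollout)                  # TensorDict-based PPO update
```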
AsyncCollector
Runs in background thread:
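A usage sketch of the async flow described in the review comments above. The constructor arguments and the `start()`/`stop()` lifecycle methods are hypothetical; `is_full()`, `get()`, and `clear()` follow the VLABuffer code shown earlier:

```python
# Hypothetical construction and lifecycle; the actual AsyncCollector API may differ.
collector = AsyncCollector(env=env, policy=policy, buffer=vla_buffer, device=device)
collector.start()  # background thread begins writing transitions into the VLABuffer

num_updates = 1_000  # illustrative
for _ in range(num_updates):
    if vla_buffer.is_full():
        data = vla_buffer.get()          # newest buffer_size transitions, flattened
        losses = algorithm.update(data)  # PPO update runs on the main thread
        vla_buffer.clear()               # avoid retraining on the same data (see review above)

collector.stop()
```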
Buffers
RolloutBuffer (Standard)
Single-use buffer:
VLABuffer (Async)
Circular FIFO buffer:
Circular behavior:
`[T0, T1, T2, T3]` → add T4 → `[T4, T1, T2, T3]` (T0 overwritten)
VLA Integration
1. Implement Model
2. Implement Loading
Edit `embodichain/agents/rl/models/vla_policy.py`:
3. Configure
{ "trainer": {"model_type": "vla"}, "policy": { "name": "vla", "vla_config": { "model_path": "checkpoints/vla.pt", "model_class": "MyVLAModel", "model_config": {} } } }Common APIs
Trainer
Buffer Methods
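A summary sketch of the buffer API visible in the diffs above. Signatures follow the VLABuffer code shown earlier; the constructor arguments and the write-method name are assumptions:

```python
buffer = VLABuffer(buffer_size=2048, device="cuda")  # constructor args assumed

len(buffer)           # number of valid transitions currently stored
buffer.is_full()      # True once size >= buffer_size
data = buffer.get()   # flattened TensorDict with batch_size=[size, ...] (flatten=True only)
buffer.clear()        # reset write_pos/size, keep pre-allocated storage
# buffer.add(td)      # write path used by AsyncCollector (method name assumed)
```

RolloutBuffer (standard mode) is described in the reviews as auto-clearing after `get()`, so no explicit `clear()` call is needed there.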
Algorithm
FAQ
Q: When should I use VLA mode?
A: When inference is >100 ms/step AND GPU training is fast
Q: Buffer size?
A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity)
Q: Data staleness impact?
A: Minor. PPO is robust to staleness; the 2-3x speedup outweighs the small penalty
Q: Debug data flow?
A: Use `buffer.get_stats()` or `_print_tensordict_tree(rollout)` in ppo.py
Workflows
Standard
VLA
File Structure
References
configs/agents/rl/