Skip to content

async collect buffer for VLA RL#122

Open
yhnsu wants to merge 3 commits intomainfrom
yhn/rl_vla
Open

async collect buffer for VLA RL#122
yhnsu wants to merge 3 commits intomainfrom
yhn/rl_vla

Conversation

@yhnsu
Copy link
Collaborator

@yhnsu yhnsu commented Feb 6, 2026

RL Training Framework Guide

TensorDict-based RL framework supporting standard PPO and asynchronous VLA training.


Quick Start

Configuration

{
  "trainer": {
    "buffer_size": 2048,
    "model_type": "standard"  // or "vla"
  },
  "policy": {"name": "actor_critic"},
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 3e-4,
      "gamma": 0.99,
      "n_epochs": 10,
      "batch_size": 64
    }
  }
}

Run Training

python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json

Architecture

Trainer → Collector (sync/async) → Buffer (standard/vla) → Algorithm (PPO)

Components:

  • Collector: Gather data from environment (SyncCollector / AsyncCollector)
  • Buffer: Store transitions (RolloutBuffer / VLABuffer)
  • Algorithm: Update policy (PPO)
  • Trainer: Coordinate training loop

Training Modes

Standard Mode (Default)

For: Normal models (<100ms inference/step)

SyncCollector → Collect 2048 steps → Train → Clear buffer → Repeat

Config: {"trainer": {"model_type": "standard"}}

Pros: Simple, stable, low memory, no staleness

VLA Async Mode

For: Large models (>1 sec inference/step)

Background: AsyncCollector → Continuously collect → VLABuffer
Main:       Wait for buffer full → Train → Repeat

Config: {"trainer": {"model_type": "vla"}}

Pros: 2-3x speedup via parallel collection
Cons: Data staleness, higher memory


Collectors

SyncCollector

Collects complete rollout synchronously:

from embodichain.agents.rl.collector import SyncCollector

collector = SyncCollector(env, policy, device, callback)
rollout = collector.collect(num_steps=2048)  # [T, N, ...]

AsyncCollector

Runs in background thread:

from embodichain.agents.rl.collector import AsyncCollector

collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()   # Begin background collection
# ... buffer fills automatically ...
collector.stop()    # Stop collection

Buffers

RolloutBuffer (Standard)

Single-use buffer:

from embodichain.agents.rl.buffer import RolloutBuffer

buffer = RolloutBuffer(buffer_size=2048, device=device)
buffer.add(rollout)  # [T, N, ...]
data = buffer.get(flatten=True)  # [T*N, ...], auto-clears

VLABuffer (Async)

Circular FIFO buffer:

from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition)  # Single step
data = buffer.get(flatten=True)  # [buffer_size, ...] when full

Circular behavior: [T0,T1,T2,T3] → add T4 → [T4,T1,T2,T3] (T0 overwritten)


VLA Integration

1. Implement Model

class MyVLAModel(nn.Module):
    def forward(self, obs: TensorDict) -> TensorDict:
        # Add 'action', 'sample_log_prob', 'value'
        ...
    def get_value(self, obs: TensorDict) -> TensorDict:
        # Add 'value'
        ...
    def evaluate_actions(self, obs: TensorDict) -> TensorDict:
        # Add 'sample_log_prob', 'entropy', 'value'
        ...

2. Implement Loading

Edit embodichain/agents/rl/models/vla_policy.py:

def load_vla_model(model_path, model_class, model_config, device):
    model = MyVLAModel(**model_config)
    model.load_state_dict(torch.load(model_path))
    return model.to(device)

3. Configure

{
  "trainer": {"model_type": "vla"},
  "policy": {
    "name": "vla",
    "vla_config": {
      "model_path": "checkpoints/vla.pt",
      "model_class": "MyVLAModel",
      "model_config": {}
    }
  }
}

Common APIs

Trainer

from embodichain.agents.rl.utils import Trainer

trainer = Trainer(
    policy, env, algorithm,
    buffer_size=2048,
    model_type="standard",  # or "vla"
    ...
)
trainer.train(total_timesteps=1000000)

Buffer Methods

buffer.add(data)            # Add data
data = buffer.get(flatten=True)  # Retrieve data
buffer.is_full()            # Check ready status
buffer.clear()              # Clear buffer
buffer.get_stats()          # Statistics

Algorithm

from embodichain.agents.rl.algo import PPO, PPOCfg

algorithm = PPO(PPOCfg(...), policy)
losses = algorithm.update(rollout)  # Returns loss dict

FAQ

Q: When use VLA mode?
A: Inference >100ms/step AND GPU training fast

Q: Buffer size?
A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity)

Q: Data staleness impact?
A: Minor. PPO robust to staleness. 2-3x speedup >> small penalty

Q: Debug data flow?
A: buffer.get_stats() or _print_tensordict_tree(rollout) in ppo.py


Workflows

Standard

collector = SyncCollector(env, policy, device, callback)
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)

VLA

collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()
while step < total:
    while not buffer.is_full():
        time.sleep(0.1)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
collector.stop()

File Structure

embodichain/agents/rl/
├── train.py              # Entry point
├── algo/ppo.py          # PPO algorithm
├── buffer/
│   ├── standard_buffer.py  # RolloutBuffer
│   └── vla_buffer.py       # VLABuffer
├── collector/
│   ├── base.py             # BaseCollector
│   ├── sync_collector.py   # SyncCollector
│   └── async_collector.py  # AsyncCollector
├── models/
│   ├── actor_critic.py     # Standard policy
│   └── vla_policy.py       # VLA wrapper
└── utils/trainer.py     # Training coordinator

References

Copilot AI review requested due to automatic review settings February 6, 2026 04:22
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request introduces a comprehensive refactoring of the RL training framework to use TensorDict-based data flow, replacing the previous tensor-based approach. The PR adds support for two training modes: standard synchronous PPO and asynchronous VLA training designed for scenarios with slow model inference.

Changes:

  • Migrated entire RL pipeline to TensorDict-based architecture for structured, extensible data flow
  • Introduced dual buffer system: RolloutBuffer (standard) and VLABuffer (async with FIFO)
  • Added AsyncCollector for background data collection in VLA mode with thread-based parallelism
  • Refactored Policy interface to use TensorDict inputs/outputs with in-place modifications
  • Updated PPO algorithm to work with TensorDict rollouts and removed dependency on gym spaces
  • Modified configuration to use buffer_size instead of rollout_steps and added action_dim requirement

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 15 comments.

Show a summary per file
File Description
embodichain/agents/rl/utils/trainer.py Refactored to support dual training modes (sync/async) with TensorDict
embodichain/agents/rl/utils/helper.py Added dict_to_tensordict, compute_gae, and logging utilities
embodichain/agents/rl/utils/async_collector.py New async data collector for VLA mode with background thread
embodichain/agents/rl/buffer/rollout_buffer.py Renamed/refactored to VLABuffer with circular indexing
embodichain/agents/rl/buffer/standard_buffer.py New RolloutBuffer for standard PPO mode
embodichain/agents/rl/buffer/init.py Updated exports for dual buffer system
embodichain/agents/rl/algo/ppo.py Refactored to use TensorDict data flow throughout
embodichain/agents/rl/algo/base.py Updated base algorithm interface for TensorDict
embodichain/agents/rl/models/policy.py Changed interface to TensorDict-based methods
embodichain/agents/rl/models/actor_critic.py Implemented TensorDict-based policy with in-place modifications
embodichain/agents/rl/models/init.py Removed gymnasium dependency, added action_dim parameter
embodichain/agents/rl/train.py Added action_dim requirement, removed gym space dependency
tests/agents/test_rl.py Updated test to use buffer_size parameter
configs/agents/rl/push_cube/train_config.json Updated config with buffer_size, action_dim, and eval_freq
configs/agents/rl/basic/cart_pole/train_config.json Updated config with buffer_size
docs/source/tutorial/rl.rst Updated documentation to reference buffer_size
pyproject.toml Added tensordict>=0.5.0 dependency
Comments suppressed due to low confidence (1)

embodichain/agents/rl/train.py:289

  • The buffer_type parameter is not read from the trainer config and not passed to the Trainer constructor (line 273-289). This means the VLA async mode introduced in this PR cannot be used, as it will always default to "standard" mode. Add buffer_type = trainer_cfg.get("buffer_type", "standard") before the Trainer initialization and pass it as buffer_type=buffer_type to the Trainer constructor.
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if enable_eval else {},
        num_eval_episodes=num_eval_episodes,
    )

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


# Update global step
num_envs = tensordict.batch_size[0]
self.global_step += num_envs
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The self.global_step variable is updated from the async collector thread (line 182 via callback) and potentially read from the main thread (lines 214, 244, 255). This creates a race condition. Consider using a thread-safe counter (e.g., threading.Lock protection or multiprocessing.Value) or tracking steps only in one thread.

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings February 6, 2026 07:51
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +113 to +158
def get(self, flatten: bool = True) -> TensorDict:
"""Get valid data from buffer.

Args:
flatten: If True, return flattened [size, ...]. Currently only supports True.

Returns:
TensorDict with batch_size=[size, ...] containing valid data
"""
if not self._initialized or self.size == 0:
raise ValueError("Buffer is empty")

if not flatten:
raise NotImplementedError("Only flatten=True is supported for VLABuffer")

# Return first 'size' elements (valid data)
# Note: Data is in insertion order up to write_pos, then wraps
if self.size < self.buffer_size:
# Buffer not yet full, data is [0:size]
return self.buffer[: self.size]
else:
# Buffer full, need to rearrange to maintain temporal order
# Oldest data is at write_pos, newest at write_pos-1
indices = (
torch.arange(
self.write_pos,
self.write_pos + self.buffer_size,
device=self.device,
)
% self.buffer_size
)
return self.buffer[indices]

def clear(self) -> None:
"""Clear buffer (reset pointers, keep pre-allocated memory)."""
self.write_pos = 0
self.size = 0
# Keep buffer allocated for reuse

def __len__(self) -> int:
"""Return current number of valid transitions."""
return self.size

def is_full(self) -> bool:
"""Check if buffer is at full buffer_size."""
return self.size >= self.buffer_size
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VLABuffer.get() and is_full() methods are called from the main thread while AsyncCollector writes to the buffer from a background thread, but these methods lack thread safety. The read of self.size and self.write_pos could return inconsistent values if a write is in progress. Additionally, buffer.get() performs complex operations (checking size, slicing buffer) that should be atomic with respect to concurrent writes. Consider adding thread synchronization or document that external locking is required.

Copilot uses AI. Check for mistakes.
if deterministic:
action = mean
else:
dist = Normal(mean, std)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distribution is created twice when deterministic=False. Line 130 creates dist = Normal(mean, std), then lines 136-137 create it again. This is wasteful. Consider refactoring to create the distribution once and use either dist.mean or dist.sample() based on the deterministic flag.

Suggested change
dist = Normal(mean, std)

Copilot uses AI. Check for mistakes.
Comment on lines +195 to +201
next_value_td = TensorDict(
{"observation": next_obs_for_td},
batch_size=next_td.batch_size,
device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The policy is accessed from both the background collector thread (lines 145-146, 200) and potentially from the main training thread during algorithm.update(). PyTorch tensors and models are not thread-safe by default. Concurrent access to the policy parameters during forward passes and gradient updates can lead to race conditions and corrupted gradients. Consider using locks to synchronize policy access, or ensure the policy is not being updated while the collector is running (e.g., by stopping collection during training).

Suggested change
next_value_td = TensorDict(
{"observation": next_obs_for_td},
batch_size=next_td.batch_size,
device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]
# Protect policy access with lock to avoid races with training thread
with self._lock:
next_value_td = TensorDict(
{"observation": next_obs_for_td},
batch_size=next_td.batch_size,
device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]

Copilot uses AI. Check for mistakes.

losses = self.algorithm.update(data)
self._log_train(losses)

Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After buffer.get() is called on line 238, the VLABuffer is not cleared (unlike RolloutBuffer which auto-clears). Since the buffer is full (size == buffer_size), the is_full() check on line 232 will immediately return True in the next iteration, causing the training loop to repeatedly train on the same data without waiting for new transitions. The buffer should be cleared after get(), or the is_full() logic should be modified to track whether data has been consumed.

Suggested change
# Clear async buffer after consumption to avoid retraining on stale data
if hasattr(self.buffer, "clear"):
self.buffer.clear()

Copilot uses AI. Check for mistakes.
Comment on lines +121 to +122
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The collector does not handle episode resets when done=True. After an episode terminates (done flag is set), the environment should be reset to get a fresh initial observation for the next episode. Currently, the collector continues using next_obs even after termination, which could contain stale data. Most RL environments auto-reset on episode end, but this should be made explicit or documented as a requirement.

Suggested change
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td
# Prepare next iteration:
# - if episode is done, reset env to get a fresh initial observation
# - otherwise, continue from next_obs_td
if done.any():
reset_result = self.env.reset()
# Support both Gym/Gymnasium-style (obs, info) and plain-obs resets
if isinstance(reset_result, tuple):
reset_obs = reset_result[0]
else:
reset_obs = reset_result
current_td = dict_to_tensordict(reset_obs, self.device)
else:
current_td = next_obs_td

Copilot uses AI. Check for mistakes.
Comment on lines +114 to +115
# Store complete transition
rollout_list.append(current_td.clone())
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling .clone() on every transition creates a full copy of the TensorDict including all nested tensors, which can be memory-intensive for large rollouts. Since current_td is reassigned to next_obs_td on line 122 (which is a fresh TensorDict), the clone may be unnecessary. Consider whether a shallow copy or reference would suffice, or document why deep cloning is required here.

Suggested change
# Store complete transition
rollout_list.append(current_td.clone())
# Store complete transition (no clone needed: current_td is not mutated afterwards)
rollout_list.append(current_td)

Copilot uses AI. Check for mistakes.
Comment on lines +240 to +242
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The global_step update in async mode only counts batch_size from the returned data (line 242), not the actual number of environment steps taken. Since VLABuffer is continuously being written to by AsyncCollector (which tracks steps in _step_count), the global_step will not accurately reflect the total number of environment interactions. Consider synchronizing global_step with the collector's _step_count, or documenting this discrepancy.

Suggested change
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size
# Update global step.
# Prefer the collector's step count (actual env interactions) if available,
# otherwise fall back to counting processed batch size.
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
steps_from_collector = getattr(collector, "_step_count", None)
if isinstance(steps_from_collector, int) and steps_from_collector > self.global_step:
self.global_step = steps_from_collector
else:
self.global_step += batch_size

Copilot uses AI. Check for mistakes.
Comment on lines 146 to 167
@@ -166,30 +154,113 @@ def on_step(obs, actions, reward, done, info, next_obs):
self.curr_ret[done_idx] = 0
self.curr_len[done_idx] = 0

# Update global step and observation
# next_obs is already flattened in algorithm's collect_rollout
self.obs = next_obs
self.global_step += next_obs.shape[0]

if isinstance(info, dict):
rewards_dict = info.get("rewards")
metrics_dict = info.get("metrics")
# Log environment metrics
if isinstance(env_info, dict):
rewards_dict = env_info.get("rewards")
metrics_dict = env_info.get("metrics")
self._log_scalar_dict("rewards", rewards_dict)
self._log_scalar_dict("metrics", metrics_dict)
log_dict = {}
log_dict.update(self._pack_log_dict("rewards", rewards_dict))
log_dict.update(self._pack_log_dict("metrics", metrics_dict))
log_dict.update(pack_log_dict("rewards", rewards_dict))
log_dict.update(pack_log_dict("metrics", metrics_dict))
if log_dict and self.use_wandb:
wandb.log(log_dict, step=self.global_step)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The on_step callback modifies shared state (self.curr_ret, self.curr_len, self.ret_window, self.len_window, self.global_step) without thread synchronization. In async mode, this callback runs in the AsyncCollector background thread while the main thread could be accessing these same variables (e.g., in _log_train). This can cause race conditions and data corruption. Use threading.Lock to protect access to these shared variables, or ensure they're only accessed from one thread.

Copilot uses AI. Check for mistakes.
Comment on lines +58 to +60
def collect(self, **kwargs) -> TensorDict:
"""Collect data from environment.

Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overridden method signature does not match call, where it is passed too many arguments. Overriding method method SyncCollector.collect matches the call.
Overridden method signature does not match call, where it is passed an argument named 'num_steps'. Overriding method method SyncCollector.collect matches the call.

Suggested change
def collect(self, **kwargs) -> TensorDict:
"""Collect data from environment.
def collect(self, num_steps: int, **kwargs) -> TensorDict:
"""Collect data from environment.
Args:
num_steps: Number of steps to collect.

Copilot uses AI. Check for mistakes.
Comment on lines +38 to +46
def collect(self, num_steps: int) -> TensorDict:
"""Collect a synchronous rollout.

Args:
num_steps: Number of steps to collect

Returns:
TensorDict with batch_size=[T, N] containing full rollout
"""
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method requires 2 positional arguments, whereas overridden BaseCollector.collect requires 1.

Suggested change
def collect(self, num_steps: int) -> TensorDict:
"""Collect a synchronous rollout.
Args:
num_steps: Number of steps to collect
Returns:
TensorDict with batch_size=[T, N] containing full rollout
"""
def collect(self, num_steps: int | None = None) -> TensorDict:
"""Collect a synchronous rollout.
Args:
num_steps: Number of steps to collect.
Returns:
TensorDict with batch_size=[T, N] containing full rollout
"""
if num_steps is None:
raise TypeError("num_steps must be provided for SyncCollector.collect()")

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant