Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions embodichain/agents/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
Expand Down Expand Up @@ -133,11 +133,9 @@ def train_from_config(config_path: str):
logger.log_info(f"Current working directory: {Path.cwd()}")

gym_config_data = load_json(str(gym_config_path))
gym_env_cfg = config_to_cfg(gym_config_data)

# Override num_envs from train config if provided
if num_envs is not None:
gym_env_cfg.num_envs = num_envs
gym_env_cfg = config_to_cfg(
gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
)

Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The num_envs override logic has been removed but not replaced. Previously, if num_envs was specified in the trainer config, it would override the value from the gym config. Now this override is lost, which means the trainer config's num_envs setting is ignored.

The removed code was:

if num_envs is not None:
    gym_env_cfg.num_envs = num_envs

This should be re-added after the config_to_cfg call to maintain the ability to override num_envs from the training configuration.

Suggested change
# Allow trainer config to override number of environments
num_envs = trainer_cfg.get("num_envs")
if num_envs is not None:
gym_env_cfg.num_envs = num_envs

Copilot uses AI. Check for mistakes.
# Ensure sim configuration mirrors runtime overrides
if gym_env_cfg.sim_cfg is None:
Expand Down
21 changes: 21 additions & 0 deletions embodichain/lab/gym/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def __init__(
self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device
)

self._task_success = torch.zeros(
self._num_envs, dtype=torch.bool, device=self.device
)
# The UIDs of objects that are detached from automatic reset.
self._detached_uids_for_reset: List[str] = []

Expand Down Expand Up @@ -485,6 +488,20 @@ def get_reward(

return rewards

def is_task_success(self, **kwargs) -> torch.Tensor:
    """Report per-environment task success status.

    Base-class default: every environment is considered successful.
    Subclasses override this with task-specific success criteria; it is
    mainly consulted during demonstration-data generation for imitation
    learning.

    Args:
        **kwargs: Task-specific arguments (unused by this default).

    Returns:
        torch.Tensor: Boolean tensor of shape ``(num_envs,)``, one success
        flag per environment in the batch (all ``True`` here).
    """
    batch_size = self.num_envs
    return torch.full((batch_size,), True, dtype=torch.bool, device=self.device)

def _preprocess_action(self, action: EnvAction) -> EnvAction:
"""Preprocess action before sending to robot.

Expand Down Expand Up @@ -534,6 +551,10 @@ def reset(
"reset_ids",
torch.arange(self.num_envs, dtype=torch.int32, device=self.device),
)

# Save task success status before resetting objects
self._task_success = self.is_task_success()

self.sim.reset_objects_state(
env_ids=reset_ids, excluded_uids=self._detached_uids_for_reset
)
Expand Down
18 changes: 1 addition & 17 deletions embodichain/lab/gym/envs/embodied_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,15 +379,13 @@ def _initialize_episode(
if save_data and self.cfg.dataset:
if "save" in self.dataset_manager.available_modes:

current_task_success = self.is_task_success()

# Filter to only save successful episodes
successful_env_ids = [
env_id
for env_id in env_ids_to_process
if (
self.episode_success_status.get(env_id, False)
or current_task_success[env_id].item()
or self._task_success[env_id].item()
)
]

Expand Down Expand Up @@ -589,20 +587,6 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None
"The method 'create_demo_action_list' must be implemented in subclasses."
)

def is_task_success(self, **kwargs) -> torch.Tensor:
"""
Determine if the task is successfully completed. This is mainly used in the data generation process
of the imitation learning.

Args:
**kwargs: Additional arguments for task-specific success criteria.

Returns:
torch.Tensor: A boolean tensor indicating success for each environment in the batch.
"""

return torch.ones(self.num_envs, dtype=torch.bool, device=self.device)

def close(self) -> None:
"""Close the environment and release resources."""
# Finalize dataset if present
Expand Down
55 changes: 28 additions & 27 deletions embodichain/lab/gym/utils/gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
from embodichain.utils.utility import get_class_instance
from dexsim.utility import log_debug, log_error

# Default module search paths used by config_to_cfg when resolving manager
# functions (dataset, randomization/record/event, observation, and reward
# handlers) that gym config files reference by name.
DEFAULT_MANAGER_MODULES = [
    "embodichain.lab.gym.envs.managers.datasets",
    "embodichain.lab.gym.envs.managers.randomization",
    "embodichain.lab.gym.envs.managers.record",
    "embodichain.lab.gym.envs.managers.events",
    "embodichain.lab.gym.envs.managers.observations",
    "embodichain.lab.gym.envs.managers.rewards",
]


def get_dtype_bounds(dtype: np.dtype):
"""Gets the min and max values of a given numpy type"""
Expand Down Expand Up @@ -323,11 +333,13 @@ def cat_tensor_with_ids(
return out


def config_to_cfg(config: dict) -> "EmbodiedEnvCfg":
def config_to_cfg(config: dict, manager_modules: list = None) -> "EmbodiedEnvCfg":
"""Parser configuration file into cfgs for env initialization.

Args:
config (dict): The configuration dictionary containing robot, sensor, light, background, and interactive objects.
manager_modules (list): List of module paths for dataset, event, observation, and reward managers.
If not provided, uses default module paths.

Returns:
EmbodiedEnvCfg: A configuration object for initializing the environment.
Expand Down Expand Up @@ -437,13 +449,19 @@ class ComponentCfg:
env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4)
env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {}))

# Initialize manager_modules with defaults
default_manager_modules = DEFAULT_MANAGER_MODULES.copy()

# Extend with user-provided modules, skipping duplicates
if manager_modules is not None:
for module in manager_modules:
if module not in default_manager_modules:
default_manager_modules.append(module)

manager_modules = default_manager_modules

env_cfg.dataset = ComponentCfg()
if "dataset" in config["env"]:
# Define modules to search for dataset functions
dataset_modules = [
"embodichain.lab.gym.envs.managers.datasets",
]

for dataset_name, dataset_params in config["env"]["dataset"].items():
dataset_params_modified = deepcopy(dataset_params)

Expand All @@ -457,7 +475,7 @@ class ComponentCfg:
# Find the function from multiple modules using the utility function
dataset_func = find_function_from_modules(
func_name,
dataset_modules,
manager_modules,
raise_if_not_found=True,
)

Expand All @@ -476,13 +494,6 @@ class ComponentCfg:

env_cfg.events = ComponentCfg()
if "events" in config["env"]:
# Define modules to search for event functions
event_modules = [
"embodichain.lab.gym.envs.managers.randomization",
"embodichain.lab.gym.envs.managers.record",
"embodichain.lab.gym.envs.managers.events",
]

# parser env events config
for event_name, event_params in config["env"]["events"].items():
event_params_modified = deepcopy(event_params)
Expand All @@ -500,7 +511,7 @@ class ComponentCfg:

# Find the function from multiple modules using the utility function
event_func = find_function_from_modules(
event_params["func"], event_modules, raise_if_not_found=True
event_params["func"], manager_modules, raise_if_not_found=True
)
interval_step = event_params_modified.get("interval_step", 10)

Expand All @@ -514,11 +525,6 @@ class ComponentCfg:

env_cfg.observations = ComponentCfg()
if "observations" in config["env"]:
# Define modules to search for observation functions
observation_modules = [
"embodichain.lab.gym.envs.managers.observations",
]

for obs_name, obs_params in config["env"]["observations"].items():
obs_params_modified = deepcopy(obs_params)

Expand All @@ -531,7 +537,7 @@ class ComponentCfg:
# Find the function from multiple modules using the utility function
obs_func = find_function_from_modules(
obs_params["func"],
observation_modules,
manager_modules,
raise_if_not_found=True,
)

Expand All @@ -546,11 +552,6 @@ class ComponentCfg:

env_cfg.rewards = ComponentCfg()
if "rewards" in config["env"]:
# Define modules to search for reward functions
reward_modules = [
"embodichain.lab.gym.envs.managers.rewards",
]

for reward_name, reward_params in config["env"]["rewards"].items():
reward_params_modified = deepcopy(reward_params)

Expand All @@ -573,7 +574,7 @@ class ComponentCfg:
# Find the function from multiple modules using the utility function
reward_func = find_function_from_modules(
reward_params["func"],
reward_modules,
manager_modules,
raise_if_not_found=True,
)

Expand Down
5 changes: 4 additions & 1 deletion embodichain/lab/scripts/preview_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from embodichain.lab.gym.envs import EmbodiedEnvCfg
from embodichain.lab.gym.utils.gym_utils import (
config_to_cfg,
DEFAULT_MANAGER_MODULES,
)
from embodichain.utils.utility import load_json
from embodichain.utils import logger
Expand Down Expand Up @@ -86,7 +87,9 @@
##############################################################################################
# load gym config
gym_config = load_json(args.gym_config)
cfg: EmbodiedEnvCfg = config_to_cfg(gym_config)
cfg: EmbodiedEnvCfg = config_to_cfg(
gym_config, manager_modules=DEFAULT_MANAGER_MODULES
)
cfg.filter_visual_rand = args.filter_visual_rand

action_config = {}
Expand Down
5 changes: 4 additions & 1 deletion embodichain/lab/scripts/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from embodichain.lab.gym.envs import EmbodiedEnvCfg
from embodichain.lab.gym.utils.gym_utils import (
config_to_cfg,
DEFAULT_MANAGER_MODULES,
)
from embodichain.utils.logger import log_warning, log_info, log_error
from .run_env import main
Expand Down Expand Up @@ -120,7 +121,9 @@
agent_config = load_json(args.agent_config)

# Build environment configuration
cfg: EmbodiedEnvCfg = config_to_cfg(gym_config)
cfg: EmbodiedEnvCfg = config_to_cfg(
gym_config, manager_modules=DEFAULT_MANAGER_MODULES
)
cfg.filter_visual_rand = args.filter_visual_rand
cfg.num_envs = args.num_envs
cfg.sim_cfg = SimulationManagerCfg(
Expand Down
5 changes: 4 additions & 1 deletion embodichain/lab/scripts/run_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from embodichain.lab.gym.envs import EmbodiedEnvCfg
from embodichain.lab.gym.utils.gym_utils import (
config_to_cfg,
DEFAULT_MANAGER_MODULES,
)
from embodichain.utils.logger import log_warning, log_info, log_error

Expand Down Expand Up @@ -191,7 +192,9 @@ def main(args, env, gym_config):
# log_error(f"Currently only support num_envs=1, but got {args.num_envs}.")

gym_config = load_json(args.gym_config)
cfg: EmbodiedEnvCfg = config_to_cfg(gym_config)
cfg: EmbodiedEnvCfg = config_to_cfg(
gym_config, manager_modules=DEFAULT_MANAGER_MODULES
)
cfg.filter_visual_rand = args.filter_visual_rand

action_config = {}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"pin",
"pin-pink",
"casadi",
"qpsolvers==4.8.1",
"qpsolvers[osqp]==4.8.1",
"py_opw_kinematics==0.1.6",
"pytorch_kinematics==0.7.6",
"polars==1.31.0",
Expand Down
6 changes: 4 additions & 2 deletions tests/gym/envs/test_embodied_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from embodichain.lab.gym.envs import EmbodiedEnvCfg
from embodichain.lab.sim.objects import RigidObject, Robot
from embodichain.lab.gym.utils.gym_utils import config_to_cfg
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.lab.gym.utils.registration import register_env
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.data import get_data_path
Expand Down Expand Up @@ -120,7 +120,9 @@ class EmbodiedEnvTest:
"""Shared test logic for CPU and CUDA."""

def setup_simulation(self, sim_device):
cfg: EmbodiedEnvCfg = config_to_cfg(METADATA)
cfg: EmbodiedEnvCfg = config_to_cfg(
METADATA, manager_modules=DEFAULT_MANAGER_MODULES
)
cfg.num_envs = NUM_ENVS
cfg.sim_cfg = SimulationManagerCfg(headless=True, sim_device=sim_device)

Expand Down