diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 0f766954..e1f8ff29 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -32,7 +32,7 @@ from embodichain.agents.rl.utils.trainer import Trainer from embodichain.utils import logger from embodichain.lab.gym.envs.tasks.rl import build_env -from embodichain.lab.gym.utils.gym_utils import config_to_cfg +from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES from embodichain.utils.utility import load_json from embodichain.utils.module_utils import find_function_from_modules from embodichain.lab.sim import SimulationManagerCfg @@ -133,11 +133,9 @@ def train_from_config(config_path: str): logger.log_info(f"Current working directory: {Path.cwd()}") gym_config_data = load_json(str(gym_config_path)) - gym_env_cfg = config_to_cfg(gym_config_data) - - # Override num_envs from train config if provided - if num_envs is not None: - gym_env_cfg.num_envs = num_envs + gym_env_cfg = config_to_cfg( + gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES + ) # Ensure sim configuration mirrors runtime overrides if gym_env_cfg.sim_cfg is None: diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index fadf26d3..38f2c2d9 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -133,6 +133,9 @@ def __init__( self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device ) + self._task_success = torch.zeros( + self._num_envs, dtype=torch.bool, device=self.device + ) # The UIDs of objects that are detached from automatic reset. self._detached_uids_for_reset: List[str] = [] @@ -485,6 +488,20 @@ def get_reward( return rewards + def is_task_success(self, **kwargs) -> torch.Tensor: + """ + Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. + + Args: + **kwargs: Additional arguments for task-specific success criteria. + + Returns: + torch.Tensor: A boolean tensor indicating success for each environment in the batch. + """ + + return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + def _preprocess_action(self, action: EnvAction) -> EnvAction: """Preprocess action before sending to robot. @@ -534,6 +551,10 @@ def reset( "reset_ids", torch.arange(self.num_envs, dtype=torch.int32, device=self.device), ) + + # Save task success status before resetting objects + self._task_success = self.is_task_success() + self.sim.reset_objects_state( env_ids=reset_ids, excluded_uids=self._detached_uids_for_reset ) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 98355f62..ef339123 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -379,15 +379,13 @@ def _initialize_episode( if save_data and self.cfg.dataset: if "save" in self.dataset_manager.available_modes: - current_task_success = self.is_task_success() - # Filter to only save successful episodes successful_env_ids = [ env_id for env_id in env_ids_to_process if ( self.episode_success_status.get(env_id, False) - or current_task_success[env_id].item() + or self._task_success[env_id].item() ) ] @@ -589,20 +587,6 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None "The method 'create_demo_action_list' must be implemented in subclasses." ) - def is_task_success(self, **kwargs) -> torch.Tensor: - """ - Determine if the task is successfully completed. This is mainly used in the data generation process - of the imitation learning. - - Args: - **kwargs: Additional arguments for task-specific success criteria. - - Returns: - torch.Tensor: A boolean tensor indicating success for each environment in the batch. - """ - - return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) - def close(self) -> None: """Close the environment and release resources.""" # Finalize dataset if present diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 1cc1ba1e..3555322c 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -28,6 +28,16 @@ from embodichain.utils.utility import get_class_instance from dexsim.utility import log_debug, log_error +# Default manager modules for config parsing +DEFAULT_MANAGER_MODULES = [ + "embodichain.lab.gym.envs.managers.datasets", + "embodichain.lab.gym.envs.managers.randomization", + "embodichain.lab.gym.envs.managers.record", + "embodichain.lab.gym.envs.managers.events", + "embodichain.lab.gym.envs.managers.observations", + "embodichain.lab.gym.envs.managers.rewards", +] + def get_dtype_bounds(dtype: np.dtype): """Gets the min and max values of a given numpy type""" @@ -323,11 +333,13 @@ def cat_tensor_with_ids( return out -def config_to_cfg(config: dict) -> "EmbodiedEnvCfg": +def config_to_cfg(config: dict, manager_modules: list = None) -> "EmbodiedEnvCfg": """Parser configuration file into cfgs for env initialization. Args: config (dict): The configuration dictionary containing robot, sensor, light, background, and interactive objects. + manager_modules (list): List of module paths for dataset, event, observation, and reward managers. + If not provided, uses default module paths. Returns: EmbodiedEnvCfg: A configuration object for initializing the environment. @@ -437,13 +449,19 @@ class ComponentCfg: env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {})) + # Initialize manager_modules with defaults + default_manager_modules = DEFAULT_MANAGER_MODULES.copy() + + # Extend with user-provided modules, skipping duplicates + if manager_modules is not None: + for module in manager_modules: + if module not in default_manager_modules: + default_manager_modules.append(module) + + manager_modules = default_manager_modules + env_cfg.dataset = ComponentCfg() if "dataset" in config["env"]: - # Define modules to search for dataset functions - dataset_modules = [ - "embodichain.lab.gym.envs.managers.datasets", - ] - for dataset_name, dataset_params in config["env"]["dataset"].items(): dataset_params_modified = deepcopy(dataset_params) @@ -457,7 +475,7 @@ class ComponentCfg: # Find the function from multiple modules using the utility function dataset_func = find_function_from_modules( func_name, - dataset_modules, + manager_modules, raise_if_not_found=True, ) @@ -476,13 +494,6 @@ class ComponentCfg: env_cfg.events = ComponentCfg() if "events" in config["env"]: - # Define modules to search for event functions - event_modules = [ - "embodichain.lab.gym.envs.managers.randomization", - "embodichain.lab.gym.envs.managers.record", - "embodichain.lab.gym.envs.managers.events", - ] - # parser env events config for event_name, event_params in config["env"]["events"].items(): event_params_modified = deepcopy(event_params) @@ -500,7 +511,7 @@ class ComponentCfg: # Find the function from multiple modules using the utility function event_func = find_function_from_modules( - event_params["func"], event_modules, raise_if_not_found=True + event_params["func"], manager_modules, raise_if_not_found=True ) interval_step = event_params_modified.get("interval_step", 10) @@ -514,11 +525,6 @@ class ComponentCfg: env_cfg.observations = ComponentCfg() if "observations" in config["env"]: - # Define modules to search for observation functions - observation_modules = [ - "embodichain.lab.gym.envs.managers.observations", - ] - for obs_name, obs_params in config["env"]["observations"].items(): obs_params_modified = deepcopy(obs_params) @@ -531,7 +537,7 @@ class ComponentCfg: # Find the function from multiple modules using the utility function obs_func = find_function_from_modules( obs_params["func"], - observation_modules, + manager_modules, raise_if_not_found=True, ) @@ -546,11 +552,6 @@ class ComponentCfg: env_cfg.rewards = ComponentCfg() if "rewards" in config["env"]: - # Define modules to search for reward functions - reward_modules = [ - "embodichain.lab.gym.envs.managers.rewards", - ] - for reward_name, reward_params in config["env"]["rewards"].items(): reward_params_modified = deepcopy(reward_params) @@ -573,7 +574,7 @@ class ComponentCfg: # Find the function from multiple modules using the utility function reward_func = find_function_from_modules( reward_params["func"], - reward_modules, + manager_modules, raise_if_not_found=True, ) diff --git a/embodichain/lab/scripts/preview_env.py b/embodichain/lab/scripts/preview_env.py index fde071bd..09f7cda3 100644 --- a/embodichain/lab/scripts/preview_env.py +++ b/embodichain/lab/scripts/preview_env.py @@ -23,6 +23,7 @@ from embodichain.lab.gym.envs import EmbodiedEnvCfg from embodichain.lab.gym.utils.gym_utils import ( config_to_cfg, + DEFAULT_MANAGER_MODULES, ) from embodichain.utils.utility import load_json from embodichain.utils import logger @@ -86,7 +87,9 @@ ############################################################################################## # load gym config gym_config = load_json(args.gym_config) - cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg( + gym_config, manager_modules=DEFAULT_MANAGER_MODULES + ) cfg.filter_visual_rand = args.filter_visual_rand action_config = {} diff --git a/embodichain/lab/scripts/run_agent.py b/embodichain/lab/scripts/run_agent.py index e6546a59..f8e248ec 100644 --- a/embodichain/lab/scripts/run_agent.py +++ b/embodichain/lab/scripts/run_agent.py @@ -24,6 +24,7 @@ from embodichain.lab.gym.envs import EmbodiedEnvCfg from embodichain.lab.gym.utils.gym_utils import ( config_to_cfg, + DEFAULT_MANAGER_MODULES, ) from embodichain.utils.logger import log_warning, log_info, log_error from .run_env import main @@ -120,7 +121,9 @@ agent_config = load_json(args.agent_config) # Build environment configuration - cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg( + gym_config, manager_modules=DEFAULT_MANAGER_MODULES + ) cfg.filter_visual_rand = args.filter_visual_rand cfg.num_envs = args.num_envs cfg.sim_cfg = SimulationManagerCfg( diff --git a/embodichain/lab/scripts/run_env.py b/embodichain/lab/scripts/run_env.py index 74398233..6b3844b1 100644 --- a/embodichain/lab/scripts/run_env.py +++ b/embodichain/lab/scripts/run_env.py @@ -28,6 +28,7 @@ from embodichain.lab.gym.envs import EmbodiedEnvCfg from embodichain.lab.gym.utils.gym_utils import ( config_to_cfg, + DEFAULT_MANAGER_MODULES, ) from embodichain.utils.logger import log_warning, log_info, log_error @@ -191,7 +192,9 @@ def main(args, env, gym_config): # log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") gym_config = load_json(args.gym_config) - cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg( + gym_config, manager_modules=DEFAULT_MANAGER_MODULES + ) cfg.filter_visual_rand = args.filter_visual_rand action_config = {} diff --git a/pyproject.toml b/pyproject.toml index 0b4624d7..29a2b35d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "pin", "pin-pink", "casadi", - "qpsolvers==4.8.1", + "qpsolvers[osqp]==4.8.1", "py_opw_kinematics==0.1.6", "pytorch_kinematics==0.7.6", "polars==1.31.0", diff --git a/tests/gym/envs/test_embodied_env.py b/tests/gym/envs/test_embodied_env.py index 574fd60c..baae2700 100644 --- a/tests/gym/envs/test_embodied_env.py +++ b/tests/gym/envs/test_embodied_env.py @@ -22,7 +22,7 @@ from embodichain.lab.gym.envs import EmbodiedEnvCfg from embodichain.lab.sim.objects import RigidObject, Robot -from embodichain.lab.gym.utils.gym_utils import config_to_cfg +from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES from embodichain.lab.gym.utils.registration import register_env from embodichain.lab.sim import SimulationManager, SimulationManagerCfg from embodichain.data import get_data_path @@ -120,7 +120,9 @@ class EmbodiedEnvTest: """Shared test logic for CPU and CUDA.""" def setup_simulation(self, sim_device): - cfg: EmbodiedEnvCfg = config_to_cfg(METADATA) + cfg: EmbodiedEnvCfg = config_to_cfg( + METADATA, manager_modules=DEFAULT_MANAGER_MODULES + ) cfg.num_envs = NUM_ENVS cfg.sim_cfg = SimulationManagerCfg(headless=True, sim_device=sim_device)