diff --git a/configs/agents/rl/push_ball/gym_config.json b/configs/agents/rl/push_ball/gym_config.json new file mode 100644 index 00000000..4123889d --- /dev/null +++ b/configs/agents/rl/push_ball/gym_config.json @@ -0,0 +1,181 @@ +{ + "id": "PushBallRL", + "max_episodes": 5, + "env": { + "events": { + "randomize_ball": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": { + "uid": "soccer_ball" + }, + "position_range": [ + [-0.2, -0.2, 0.0], + [0.2, 0.2, 0.0] + ], + "relative_position": true + } + }, + "randomize_goal": { + "func": "randomize_target_pose", + "mode": "reset", + "params": { + "position_range": [ + [0.65, -0.2, 0.05], + [0.95, 0.2, 0.05] + ], + "relative_position": false, + "store_key": "goal_pose" + } + } + }, + "observations": { + "robot_qpos": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [0, 1, 2, 3, 4, 5] + } + }, + "robot_ee_pos": { + "func": "get_robot_eef_pose", + "mode": "add", + "name": "robot/ee_pos", + "params": { + "part_name": "arm" + } + }, + "ball_pose": { + "func": "get_rigid_object_pose", + "mode": "add", + "name": "object/ball_pose", + "params": { + "entity_cfg": {"uid": "soccer_ball"} + } + }, + "goal_pos": { + "func": "target_position", + "mode": "add", + "name": "object/goal_pos", + "params": { + "target_pose_key": "goal_pose" + } + } + }, + "rewards": { + "reaching_reward": { + "func": "reaching_behind_object", + "mode": "add", + "weight": 0.1, + "params": { + "object_cfg": { + "uid": "soccer_ball" + }, + "target_pose_key": "goal_pose", + "behind_offset": 0.02, + "height_offset": 0.02, + "distance_scale": 5.0, + "part_name": "arm" + } + }, + "push_reward": { + "func": "incremental_distance_to_target", + "mode": "add", + "weight": 1.0, + "params": { + "source_entity_cfg": { + "uid": "soccer_ball" + }, + "target_pose_key": "goal_pose", + "tanh_scale": 10.0, + "positive_weight": 2.0, + "negative_weight": 0.5, + "use_xy_only": true + } + }, + "action_penalty": { + "func": "action_smoothness_penalty", + "mode": "add", + "weight": 0.01, + "params": {} + }, + "success_bonus": { + "func": "success_reward", + "mode": "add", + "weight": 10.0, + "params": {} + } + }, + "extensions": { + "action_type": "delta_qpos", + "episode_length": 100, + "action_scale": 0.1, + "success_threshold": 0.1 + } + }, + "robot": { + "uid": "ur10", + "urdf_cfg": { + "components": [ + { + "component_type": "arm", + "urdf_path": "UniversalRobots/UR10/UR10.urdf" + } + ] + }, + "init_pos": [0.0, 0.0, 0.0], + "init_rot": [0.0, 0.0, 0.0], + "init_qpos": [0.0, -1.57079, 1.57079, -1.57079, -1.57079, 0.0], + "drive_pros": { + "drive_type": "force", + "stiffness": 100000.0, + "damping": 1000.0, + "max_velocity": 2.0, + "max_effort": 500.0 + }, + "solver_cfg": { + "arm": { + "class_type": "PytorchSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + "tcp": [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.16], + [0.0, 0.0, 0.0, 1.0] + ] + } + }, + "control_parts": { + "arm": ["JOINT[1-6]"] + } + }, + "rigid_object": [ + { + "uid": "soccer_ball", + "shape": { + "shape_type": "Sphere", + "radius": 0.05 + }, + "body_type": "dynamic", + "init_pos": [0.35, 0.0, 0.05], + "attrs": { + "mass": 1.0, + "static_friction": 3.0, + "dynamic_friction": 2.5, + "linear_damping": 1.0, + "angular_damping": 1.0, + "restitution": 0.3, + "max_linear_velocity": 2.0, + "max_angular_velocity": 2.0 + } + } + ], + "sensor": [], + "light": {}, + "background": [], + "rigid_object_group": [], + "articulation": [] +} diff --git a/configs/agents/rl/push_ball/train_config.json b/configs/agents/rl/push_ball/train_config.json new file mode 100644 index 00000000..9c5dadc7 --- /dev/null +++ b/configs/agents/rl/push_ball/train_config.json @@ -0,0 +1,67 @@ +{ + "trainer": { + "exp_name": "push_ball_ppo", + "gym_config": "configs/agents/rl/push_ball/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 64, + "iterations": 1000, + "rollout_steps": 1024, + "eval_freq": 200, + "save_freq": 200, + "use_wandb": false, + "wandb_project_name": "embodychain-push_ball", + "events": { + "eval": { + "record_camera": { + "func": "record_camera_data_async", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "main_cam", + "resolution": [640, 480], + "eye": [-1.4, 1.4, 2.0], + "target": [0, 0, 0], + "up": [0, 0, 1], + "intrinsics": [600, 600, 320, 240], + "save_path": "./outputs/videos/eval" + } + } + } + } + }, + "policy": { + "name": "actor_critic", + "actor": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + }, + "critic": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 8192, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.01, + "vf_coef": 0.5, + "max_grad_norm": 0.5 + } + } +} diff --git a/embodichain/lab/gym/envs/tasks/rl/__init__.py b/embodichain/lab/gym/envs/tasks/rl/__init__.py index be52afc3..cc668926 100644 --- a/embodichain/lab/gym/envs/tasks/rl/__init__.py +++ b/embodichain/lab/gym/envs/tasks/rl/__init__.py @@ -19,7 +19,7 @@ from copy import deepcopy from embodichain.lab.gym.utils import registration as env_registry from embodichain.lab.gym.envs.embodied_env import EmbodiedEnvCfg - +from embodichain.lab.gym.envs.tasks.rl import push_ball def build_env(env_id: str, base_env_cfg: EmbodiedEnvCfg): """Create env from registry id, auto-inferring cfg class (EnvName -> EnvNameCfg).""" diff --git a/embodichain/lab/gym/envs/tasks/rl/push_ball.py b/embodichain/lab/gym/envs/tasks/rl/push_ball.py new file mode 100644 index 00000000..569e9fd5 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/rl/push_ball.py @@ -0,0 +1,69 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +from typing import Dict, Any, Tuple + +from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.gym.envs.rl_env import RLEnv +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.sim.types import EnvObs + + +@register_env("PushBallRL", max_episode_steps=100, override=True) +class PushBallEnv(RLEnv): + """Push Ball Gate Task Environment. + + The robot must push a soccer ball into a goal area. + Success is defined by the ball being within a distance threshold of the goal. + """ + + def __init__(self, cfg=None, **kwargs): + if cfg is None: + cfg = EmbodiedEnvCfg() + super().__init__(cfg, **kwargs) + + def compute_task_state( + self, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """Compute task-specific state: success, fail, and metrics.""" + ball = self.sim.get_rigid_object("soccer_ball") + ball_pos = ball.body_data.pose[:, :3] + + if self.goal_pose is not None: + goal_pos = self.goal_pose[:, :3, 3] + xy_distance = torch.norm(ball_pos[:, :2] - goal_pos[:, :2], dim=1) + is_success = xy_distance < self.success_threshold + else: + xy_distance = torch.zeros(self.num_envs, device=self.device) + is_success = torch.zeros( + self.num_envs, device=self.device, dtype=torch.bool + ) + + is_fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) + metrics = { + "distance_to_goal": xy_distance, + "ball_height": ball_pos[:, 2], + } + + return is_success, is_fail, metrics + + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + is_timeout = self._elapsed_steps >= self.episode_length + ball = self.sim.get_rigid_object("soccer_ball") + ball_pos = ball.body_data.pose[:, :3] + is_fallen = ball_pos[:, 2] < -0.1 + return is_timeout | is_fallen