Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions embodichain/lab/gym/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from embodichain.lab.sim.types import EnvObs, EnvAction
from embodichain.lab.sim import SimulationManagerCfg, SimulationManager
from embodichain.lab.sim.objects import Robot
from embodichain.lab.sim.sensors import BaseSensor
from embodichain.lab.sim.sensors import BaseSensor, Camera
from embodichain.lab.gym.utils import gym_utils
from embodichain.utils import configclass
from embodichain.utils import logger, set_seed
Expand Down Expand Up @@ -219,6 +219,16 @@ def get_sensor(self, name: str, **kwargs) -> BaseSensor:

return self.sensors[name]

def add_camera_group_id(self, group_id: int) -> None:
"""Add a camera group ID for rendering.

Args:
group_id: The camera group ID to be added.
"""
if not hasattr(self, "_camera_group_ids"):
self._camera_group_ids: List[int] = []
self._camera_group_ids.append(group_id)

def _setup_scene(self, **kwargs):
# Init sim manager.
# we want to open gui window when the scene is setup, so init sim manager in headless mode first.
Expand All @@ -245,6 +255,12 @@ def _setup_scene(self, **kwargs):

self.sensors = self._setup_sensors(**kwargs)

# Setup camera groups for rendering.
self._camera_group_ids: List[int] = []
for sensor in self.sensors.values():
if isinstance(sensor, Camera):
self._camera_group_ids.append(sensor.group_id)

def _setup_robot(self, **kwargs) -> Robot:
"""Load the robot agent, setup the controller and action space.

Expand Down Expand Up @@ -337,7 +353,7 @@ def _get_sensor_obs(self, **kwargs) -> Dict[str, any]:
fetch_only = False
if self.sim.is_rt_enabled:
fetch_only = True
self.sim.render_camera_group()
self.sim.render_camera_group(self._camera_group_ids)

for sensor_name, sensor in self.sensors.items():
sensor.update(fetch_only=fetch_only)
Expand Down
7 changes: 5 additions & 2 deletions embodichain/lab/gym/envs/managers/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from dexsim.utility import images_to_video
from embodichain.lab.gym.envs.managers import Functor, FunctorCfg
from embodichain.lab.sim.sensors.camera import CameraCfg
from embodichain.lab.sim.sensors.camera import CameraCfg, Camera

if TYPE_CHECKING:
from embodichain.lab.gym.envs import EmbodiedEnv
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
"intrinsics", (600, 600, int(resolution[0] / 2), int(resolution[1] / 2))
)

self.camera = env.sim.add_sensor(
self.camera: Camera = env.sim.add_sensor(
sensor_cfg=CameraCfg(
uid=self._name,
width=resolution[0],
Expand All @@ -79,6 +79,9 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
)
)

# Add this camera's group ID to the environment for batch rendering.
env.add_camera_group_id(self.camera.group_id)

self._current_episode = 0
self._frames: List[np.ndarray] = []

Expand Down
9 changes: 9 additions & 0 deletions embodichain/lab/sim/sensors/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ def is_rt_enabled(self) -> bool:
"""
return is_rt_enabled()

@cached_property
def group_id(self) -> int:
"""Get the camera group ID in the dexsim world.

Returns:
int: The camera group ID.
"""
return self._frame_buffer.get_group_id()

def update(self, **kwargs) -> None:
"""Update the sensor data.

Expand Down
7 changes: 5 additions & 2 deletions embodichain/lab/sim/sim_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,17 @@ def init_gpu_physics(self) -> None:

self._is_initialized_gpu_physics = True

def render_camera_group(self) -> None:
def render_camera_group(self, group_ids: list[int]) -> None:
"""Render all camera group in the simulation.

Args:
group_ids (list[int]): The list of camera group ids to render.

Note: This interface is only valid when Ray Tracing rendering backend is enabled.
"""

if self.is_rt_enabled:
self._world.render_camera_group()
self._world.render_camera_group(group_ids)
else:
logger.log_warning(
"This interface is only valid when Ray Tracing rendering backend is enabled."
Expand Down
Loading