Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 74 additions & 8 deletions egomimic/algo/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
_to_minus1_1,
)
from egomimic.rldb.embodiment.embodiment import get_embodiment, get_embodiment_id
from egomimic.utils.action_utils import ConverterRegistry
from egomimic.utils.action_utils import (
PI05_CARTESIAN_ACTION_ENCODING_LEGACY,
PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D,
PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D,
ConverterRegistry,
)

logger = logging.getLogger(__name__)
# Ensure logger propagates to root logger and has appropriate level
Expand Down Expand Up @@ -70,6 +75,7 @@ def __init__(
state_num_bins: int = 256,
control_mode: dict[str, str] | None = None,
proprio_keys_for_prompt: list[str] | None = None,
action_encoding: str = PI05_CARTESIAN_ACTION_ENCODING_LEGACY,
**kwargs,
):
self.nets = nn.ModuleDict()
Expand Down Expand Up @@ -103,6 +109,7 @@ def __init__(
"pi_cam_keys", ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"]
)
self.config = config
self.action_encoding = action_encoding

self.ac_keys = ac_keys

Expand Down Expand Up @@ -291,6 +298,23 @@ def _tokenize_prompts(self, prompts: list[str]) -> dict:
"token_ar_mask": attention_mask.clone().requires_grad_(False),
}

def _action_stats(self, embodiment_id: int, ac_key: str) -> dict:
    """Return the normalization statistics stored for one action key.

    The lookup is ``self.norm_stats.norm_stats[embodiment_id][ac_key]``.

    Args:
        embodiment_id: Key into the first level of the stats mapping.
        ac_key: Action key looked up under that embodiment.

    Returns:
        The stats entry stored for ``(embodiment_id, ac_key)``.

    Raises:
        KeyError: If either level of the lookup is missing. The message names
            both the action key and the embodiment id, and the original
            ``KeyError`` is chained as the cause.
    """
    try:
        per_embodiment = self.norm_stats.norm_stats[embodiment_id]
        stats = per_embodiment[ac_key]
    except KeyError as lookup_error:
        # Re-raise with context so the failing key/embodiment pair is obvious.
        message = (
            f"Missing norm stats for action key {ac_key!r} "
            f"and embodiment id {embodiment_id}"
        )
        raise KeyError(message) from lookup_error
    return stats

def _unnormalize_action(
    self, action: torch.Tensor, embodiment_id: int, ac_key: str
):
    """Run a single action tensor through ``self.norm_stats.unnormalize``.

    A clone of ``action`` is wrapped in a dict under ``ac_key`` (together with
    an ``"embodiment"`` entry holding ``embodiment_id``) and handed to
    ``self.norm_stats.unnormalize``; the caller's tensor itself is never
    passed in.

    Args:
        action: Action tensor to convert.
        embodiment_id: Embodiment identifier forwarded to ``unnormalize``.
        ac_key: Key under which the action is stored in the payload and read
            back from the result.

    Returns:
        The ``ac_key`` entry of the ``unnormalize`` result, moved to the
        device of ``action``.
    """
    payload = {ac_key: action.clone(), "embodiment": embodiment_id}
    restored = self.norm_stats.unnormalize(payload, embodiment_id)
    # The stats object may hand back a tensor on another device; bring it home.
    return restored[ac_key].to(action.device)

@override
def process_batch_for_training(self, batch):
"""
Expand Down Expand Up @@ -446,17 +470,41 @@ def forward_eval(self, batch):
num_steps=self.num_steps,
)

pred_actions = pred_actions.clone()

predictions = OrderedDict()
ref = _batch[ac_key]
B, T, D = ref.shape

converter = self.action_registry.get(embodiment_id, ac_key)
pred_actions_orig = converter.from32(pred_actions)

pred = pred_actions_orig[:, :T, :D]
predictions[ac_key] = pred

unnorm_actions = self.norm_stats.unnormalize(predictions, embodiment_id)
if self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D:
pred_actions_orig = converter.from32_raw_rotation(
pred_actions,
stats=self._action_stats(embodiment_id, ac_key),
norm_mode=self.norm_stats.norm_mode,
unnormalize_non_rotation=True,
)
unnorm_actions = {ac_key: pred_actions_orig[:, :T, :D]}
elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D:
# Extract the normalized xyz+6D(+gripper) action, then
# unnormalize via the standard pipeline (stats were computed
# over the 6D representation) to get raw 6D actions.
pred_6d = converter.from32_norm_6d(pred_actions)
predictions[ac_key] = pred_6d[:, :T, :D]
unnorm_actions = self.norm_stats.unnormalize(
predictions, embodiment_id
)
elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY:
pred_actions_orig = converter.from32(pred_actions)
pred = pred_actions_orig[:, :T, :D]
predictions[ac_key] = pred
unnorm_actions = self.norm_stats.unnormalize(
predictions, embodiment_id
)
else:
raise ValueError(
f"Unsupported PI0.5 action_encoding: {self.action_encoding!r}"
)
for key in unnorm_actions:
unnorm_preds[f"{embodiment_name}_{key}"] = unnorm_actions[key]

Expand Down Expand Up @@ -531,7 +579,25 @@ def _robomimic_to_pi_data(

emb_id = get_embodiment_id(embodiment) # embodiment is a name string
converter = self.action_registry.get(emb_id, ac_key)
action32 = converter.to32(action)
if self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D:
raw_action = self._unnormalize_action(action, emb_id, ac_key)
action32 = converter.to32_raw_rotation(
raw_action,
normalized_actions=action,
stats=self._action_stats(emb_id, ac_key),
norm_mode=self.norm_stats.norm_mode,
)
elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D:
# Action is already a normalized xyz+6D(+gripper) chunk (the
# ypr->6D conversion happened in the CartesianYPRToRot6D data
# transform). Just pack it into the 32D vector.
action32 = converter.to32_norm_6d(action)
elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY:
action32 = converter.to32(action)
else:
raise ValueError(
f"Unsupported PI0.5 action_encoding: {self.action_encoding!r}"
)

# OpenPI expects a fixed camera tuple. Human datasets only provide
# `base_0_rgb`, so duplicate that view into the missing wrist slots and
Expand Down
55 changes: 54 additions & 1 deletion egomimic/rldb/embodiment/eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from egomimic.rldb.zarr.action_chunk_transforms import (
ActionChunkCoordinateFrameTransform,
BatchQuaternionPoseToYPR,
CartesianRot6DToYPR,
CartesianYPRToRot6D,
ConcatKeys,
DeleteKeys,
InterpolateLinear,
Expand All @@ -29,13 +31,31 @@ class Eva(Embodiment):
@staticmethod
def get_transform_list(
mode: Literal[
"cartesian", "cartesian_wristframe_ypr", "cartesian_wristframe_quat"
"cartesian",
"cartesian_6d",
"cartesian_wristframe_ypr",
"cartesian_wristframe_6d",
"cartesian_wristframe_quat",
],
) -> list[Transform]:
if mode == "cartesian":
return _build_eva_bimanual_transform_list(is_quat=True)
elif mode == "cartesian_6d":
# Camera-frame cartesian actions (14D total: xyz+ypr+gripper, 7 per arm)
# with the rotation re-expressed as the continuous 6D representation
# (20D total: xyz+6d+gripper, 10 per arm) for pi0.5 normalized-rot6d encoding.
return _build_eva_bimanual_transform_list(is_quat=True) + [
CartesianYPRToRot6D(action_key="actions_cartesian")
]
elif mode == "cartesian_wristframe_ypr":
return _build_eva_bimanual_eef_frame_transform_list(is_quat=False)
elif mode == "cartesian_wristframe_6d":
# Wrist-frame cartesian actions (14D total: xyz+ypr+gripper, 7 per arm)
# with the rotation re-expressed as the continuous 6D representation
# (20D total: xyz+6d+gripper, 10 per arm) for pi0.5 normalized-rot6d encoding.
return _build_eva_bimanual_eef_frame_transform_list(is_quat=False) + [
CartesianYPRToRot6D(action_key="actions_cartesian")
]
elif mode == "cartesian_wristframe_quat":
return _build_eva_bimanual_eef_frame_transform_list(is_quat=True)

Expand Down Expand Up @@ -131,6 +151,39 @@ def dinov3_keymap(cls):
}


def _build_eva_cartesian_revert_6d_transform_list(
    *,
    action_key: str = "actions_cartesian",
) -> list[Transform]:
    """Build the transform list that turns camera-frame 6D EVA actions into ypr.

    Intended for the cam-frame 6D evaluator. The action chunk produced by the
    ``cartesian_6d`` transform mode is already expressed in camera frame, so
    the only step needed is the rotation-representation change from
    xyz+6D (+gripper, 10/arm) to xyz+ypr (+gripper, 7/arm). After that,
    cam-frame MSE and the viz video see the same ypr layout as the plain
    ``cartesian`` mode.

    Args:
        action_key: Batch key of the action chunk to convert.

    Returns:
        A one-element list holding a ``CartesianRot6DToYPR`` for ``action_key``.
    """
    rot6d_to_ypr = CartesianRot6DToYPR(action_key=action_key)
    return [rot6d_to_ypr]


def _build_eva_cartesian_revert_6d_wristframe_transform_list(
    *,
    action_key: str = "actions_cartesian",
) -> list[Transform]:
    """Revert wrist-frame 6D-rotation EVA actions back to camera-frame ypr.

    Two stages for the cam-frame 6D wristframe evaluator: (1) convert the action
    rotation from xyz+6D (+gripper) back to xyz+ypr (+gripper) via
    ``CartesianRot6DToYPR``; (2) project the wrist-frame ypr actions back into
    camera frame using the standard eef-frame revert (which reads the proprio
    ``observations.state.ee_pose``, left untouched as ypr by the 6D transform).

    Args:
        action_key: Batch key of the action chunk. It is forwarded to BOTH
            stages so the eef-frame revert operates on the same entry that
            was just converted from 6D to ypr.

    Returns:
        The 6D->ypr transform followed by the eef-frame revert transforms.
    """
    return [
        CartesianRot6DToYPR(action_key=action_key),
        # Previously the revert stage silently fell back to its default key,
        # so a custom ``action_key`` was converted but never re-projected.
        *_build_eva_bimanual_revert_eef_frame_transform_list(
            action_key=action_key,
            is_quat=False,
        ),
    ]


def _build_eva_bimanual_revert_eef_frame_transform_list(
*,
action_key: str = "actions_cartesian",
Expand Down
105 changes: 99 additions & 6 deletions egomimic/rldb/zarr/action_chunk_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
_matrix_to_xyz,
_matrix_to_xyzwxyz,
_matrix_to_xyzypr,
_rot6d_to_ypr,
_xyz_to_matrix,
_xyzwxyz_to_matrix,
_xyzypr_to_matrix,
_ypr_to_rot6d,
wxyz_to_xyzw,
xyzw_to_wxyz,
)
Expand Down Expand Up @@ -387,6 +389,101 @@ def transform(self, batch: dict) -> dict:
return batch


class CartesianYPRToRot6D(Transform):
    """Convert a bimanual cartesian action chunk from per-arm xyz+ypr(+gripper)
    to per-arm xyz+rot6d(+gripper).

    ``rot6d`` is the continuous 6D rotation representation = the first two
    columns of the rotation matrix, packed as [col0(3), col1(3)] (see
    :func:`egomimic.utils.pose_utils._ypr_to_rot6d`). This matches the column
    convention of the ``to32``/``from32`` packers in
    ``egomimic.utils.action_utils``, so the resulting per-arm layout maps
    directly into the pi0.5 32D action blocks.

    Input layouts (last dim):
        12 -> [L xyz ypr,   R xyz ypr]     -> 18 [L xyz 6d,   R xyz 6d]
        14 -> [L xyz ypr g, R xyz ypr g]   -> 20 [L xyz 6d g, R xyz 6d g]

    Preserves the numpy/tensor type of the input (like ``PadGripperZeros``).
    Tensor inputs are detached before the numpy round trip and the result is
    placed back on the input tensor's device.
    """

    # Accepted sizes of the last dimension (two arms, without / with gripper).
    _VALID_DIMS = (12, 14)

    def __init__(
        self, action_key: str = "actions_cartesian", output_key: str | None = None
    ):
        """Store the batch keys to read from and write to.

        Args:
            action_key: Batch key holding the xyz+ypr(+gripper) action chunk.
            output_key: Batch key for the converted chunk; falls back to
                ``action_key`` (in-place replacement) when not given.
        """
        self.action_key = action_key
        self.output_key = output_key or action_key

    def transform(self, batch: dict) -> dict:
        """Replace ypr rotations in ``batch[action_key]`` with rot6d.

        Args:
            batch: Mapping that contains ``self.action_key``.

        Returns:
            The same ``batch`` object with ``self.output_key`` set to the
            converted chunk.

        Raises:
            ValueError: If the last dimension is neither 12 nor 14.
        """
        actions = batch[self.action_key]
        is_tensor = isinstance(actions, torch.Tensor)
        # detach() is required: .numpy() refuses tensors that track gradients.
        arr = actions.detach().cpu().numpy() if is_tensor else np.asarray(actions)
        D = arr.shape[-1]
        if D not in self._VALID_DIMS:
            raise ValueError(
                f"CartesianYPRToRot6D expects last-dim 12 or 14, got {arr.shape} "
                f"for '{self.action_key}'"
            )

        # Both arms share one layout: xyz(3), ypr(3), then an optional gripper
        # column (an empty slice when D == 12).
        arm_width = D // 2
        pieces = []
        for start in (0, arm_width):
            arm = arr[..., start : start + arm_width]
            pieces.extend(
                (arm[..., 0:3], _ypr_to_rot6d(arm[..., 3:6]), arm[..., 6:])
            )
        out = np.concatenate(pieces, axis=-1)

        if is_tensor:
            # from_numpy always yields a CPU tensor; restore the input device.
            out = torch.from_numpy(out).to(actions.device)
        batch[self.output_key] = out
        return batch


class CartesianRot6DToYPR(Transform):
    """Inverse of :class:`CartesianYPRToRot6D`: per-arm xyz+rot6d(+gripper) ->
    xyz+ypr(+gripper).

    Input layouts (last dim):
        18 -> [L xyz 6d,   R xyz 6d]     -> 12 [L xyz ypr,   R xyz ypr]
        20 -> [L xyz 6d g, R xyz 6d g]   -> 14 [L xyz ypr g, R xyz ypr g]

    The numpy/tensor type of the input is kept. Tensor inputs are detached
    before the numpy round trip and the result is placed back on the input
    tensor's device.
    """

    # Accepted sizes of the last dimension (two arms, without / with gripper).
    _VALID_DIMS = (18, 20)

    def __init__(
        self, action_key: str = "actions_cartesian", output_key: str | None = None
    ):
        """Store the batch keys to read from and write to.

        Args:
            action_key: Batch key holding the xyz+rot6d(+gripper) action chunk.
            output_key: Batch key for the converted chunk; falls back to
                ``action_key`` (in-place replacement) when not given.
        """
        self.action_key = action_key
        self.output_key = output_key or action_key

    def transform(self, batch: dict) -> dict:
        """Replace rot6d rotations in ``batch[action_key]`` with ypr.

        Args:
            batch: Mapping that contains ``self.action_key``.

        Returns:
            The same ``batch`` object with ``self.output_key`` set to the
            converted chunk.

        Raises:
            ValueError: If the last dimension is neither 18 nor 20.
        """
        actions = batch[self.action_key]
        is_tensor = isinstance(actions, torch.Tensor)
        # detach() is required: .numpy() refuses tensors that track gradients.
        arr = actions.detach().cpu().numpy() if is_tensor else np.asarray(actions)
        D = arr.shape[-1]
        if D not in self._VALID_DIMS:
            raise ValueError(
                f"CartesianRot6DToYPR expects last-dim 18 or 20, got {arr.shape} "
                f"for '{self.action_key}'"
            )

        # Both arms share one layout: xyz(3), rot6d(6), then an optional
        # gripper column (an empty slice when D == 18).
        arm_width = D // 2
        pieces = []
        for start in (0, arm_width):
            arm = arr[..., start : start + arm_width]
            pieces.extend(
                (arm[..., 0:3], _rot6d_to_ypr(arm[..., 3:9]), arm[..., 9:])
            )
        out = np.concatenate(pieces, axis=-1)

        if is_tensor:
            # from_numpy always yields a CPU tensor; restore the input device.
            out = torch.from_numpy(out).to(actions.device)
        batch[self.output_key] = out
        return batch


class CartesianWithGripperCoordinateTransform(Transform):
def __init__(
self,
Expand Down Expand Up @@ -535,12 +632,8 @@ def transform(self, batch: dict) -> dict:
)
pad_shape = (*arr.shape[:-1], 1)
pad = np.zeros(pad_shape, dtype=arr.dtype)
padded = np.concatenate(
(arr[..., :6], pad, arr[..., 6:], pad), axis=-1
)
batch[self.action_key] = (
torch.from_numpy(padded) if is_tensor else padded
)
padded = np.concatenate((arr[..., :6], pad, arr[..., 6:], pad), axis=-1)
batch[self.action_key] = torch.from_numpy(padded) if is_tensor else padded
return batch


Expand Down
Loading
Loading