Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"yourdfpy",
"trimesh",
"viser",
"pyliblzfse", # Need for viser.extras import in viser==0.2.23
"pyliblzfse", # Need for viser.extras import in viser==0.2.23
"flax>=0.10.7",
]

[project.optional-dependencies]
Expand Down
100 changes: 93 additions & 7 deletions src/pyronot/motion_generators/_trajopt.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
"""TrajoptMotionGenerator: Cartesian spline seeding + IK + SCO trajopt.
"""TrajoptMotionGenerator: configurable seeding + SCO trajopt.

Seeding modes:
- ``"cartesian_ik"``: Cartesian spline interpolation between control SE(3)
poses, then batched IK via MPPI to convert waypoint poses to joint configs.
- ``"linear_js"``: Solve IK only at start/goal poses, then linearly
interpolate in joint space (same strategy as cuRobo).

Full pipeline:
1. Cartesian spline interpolation between control SE(3) poses (configurable
mode: linear, cubic, bspline).
2. Batched IK via MPPI to convert waypoint poses to joint configs.
3. SCO trajectory optimization on the resulting batch.
1. Seed trajectories via the chosen ``seed_mode``.
2. Tile to [B, T, DOF] and add Gaussian noise.
3. Run SCO (or other) trajectory optimization on the batch.
"""

from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Literal

import jax
import jax.numpy as jnp
Expand All @@ -19,6 +25,8 @@
from jaxtyping import Float

from .._robot import Robot

SeedMode = Literal["cartesian_ik", "linear_js"]
from .._splines import (
SplineMode,
bspline_interpolate,
Expand Down Expand Up @@ -85,6 +93,7 @@ class TrajoptMotionGenerator:
n_batch: int = 25
noise_scale: float = 0.05
cartesian_spline_mode: SplineMode = "linear"
seed_mode: SeedMode = "cartesian_ik"

trajopt_cfg: ScoTrajOptConfig = field(default_factory=ScoTrajOptConfig)
ik_cfg: IKSeedConfig = field(default_factory=IKSeedConfig)
Expand Down Expand Up @@ -167,13 +176,13 @@ def _batch_ik(
continuity_weight=ik_cfg.continuity_weight,
)

def _seed_trajectories(
def _seed_cartesian_ik(
self,
control_poses: jaxlie.SE3,
key: Array,
prev_cfgs: Float[Array, "T DOF"] | None = None,
) -> tuple[Float[Array, "B T DOF"], Float[Array, "DOF"], Float[Array, "DOF"]]:
"""Build [B, T, DOF] batch via Cartesian spline + IK seeding + noise.
"""Seed via Cartesian spline interpolation + per-timestep IK.

Returns:
init_trajs: Batch of seeded trajectories. Shape [B, T, DOF].
Expand Down Expand Up @@ -203,6 +212,83 @@ def _seed_trajectories(
noise = jax.random.normal(key, trajs.shape) * self.noise_scale
return trajs + noise, start_cfg, goal_cfg

def _seed_linear_js(
self,
control_poses: jaxlie.SE3,
key: Array,
prev_cfgs: Float[Array, "T DOF"] | None = None,
start_cfg: Float[Array, "DOF"] | None = None,
goal_cfg: Float[Array, "DOF"] | None = None,
) -> tuple[Float[Array, "B T DOF"], Float[Array, "DOF"], Float[Array, "DOF"]]:
"""Seed via IK at start/goal only, then linear joint-space interpolation.

This mirrors cuRobo's seeding strategy: solve IK only at the endpoints,
then linearly interpolate in joint space. The resulting seeds are smooth
by construction but may pass through obstacles.

If ``start_cfg`` and/or ``goal_cfg`` are provided (e.g. from a problem
file), IK is skipped for that endpoint. The goal IK is warm-started
from the start config so both endpoints land on the same joint-space
branch.

Returns:
init_trajs: Batch of seeded trajectories. Shape [B, T, DOF].
start_cfg: Joint config at the start pose. Shape [DOF].
goal_cfg: Joint config at the goal pose. Shape [DOF].
"""
key, ik_key = jax.random.split(key)
ik_key_start, ik_key_goal = jax.random.split(ik_key)

# --- Start config ---
if start_cfg is None:
start_pose = jaxlie.SE3(control_poses.wxyz_xyz[0:1])
if prev_cfgs is not None:
start_prev = prev_cfgs[0:1]
else:
mid_cfg = (self.robot.joints.lower_limits + self.robot.joints.upper_limits) / 2.0
start_prev = mid_cfg[None]
start_cfg = self._batch_ik(start_pose, start_prev, ik_key_start)[0]

# --- Goal config (warm-started from start) ---
if goal_cfg is None:
goal_pose = jaxlie.SE3(control_poses.wxyz_xyz[-1:])
goal_prev = start_cfg[None]
goal_cfg = self._batch_ik(goal_pose, goal_prev, ik_key_goal)[0]

# Linear interpolation in joint space: q(t) = (1-α)*start + α*goal
alphas = jnp.linspace(0.0, 1.0, self.n_timesteps).reshape(-1, 1) # [T, 1]
base_traj = start_cfg * (1.0 - alphas) + goal_cfg * alphas # [T, DOF]

trajs = jnp.broadcast_to(
base_traj[None], (self.n_batch, self.n_timesteps, base_traj.shape[-1])
)
noise = jax.random.normal(key, trajs.shape) * self.noise_scale
return trajs + noise, start_cfg, goal_cfg

def _seed_trajectories(
self,
control_poses: jaxlie.SE3,
key: Array,
prev_cfgs: Float[Array, "T DOF"] | None = None,
start_cfg: Float[Array, "DOF"] | None = None,
goal_cfg: Float[Array, "DOF"] | None = None,
) -> tuple[Float[Array, "B T DOF"], Float[Array, "DOF"], Float[Array, "DOF"]]:
"""Build [B, T, DOF] seeded trajectory batch.

Dispatches to the seeding method specified by ``self.seed_mode``:
- ``"cartesian_ik"``: Cartesian spline + per-timestep IK.
- ``"linear_js"``: IK at endpoints + linear joint-space interpolation.

When ``start_cfg`` / ``goal_cfg`` are provided, ``linear_js`` skips IK
for those endpoints (ignored by ``cartesian_ik``).
"""
if self.seed_mode == "cartesian_ik":
return self._seed_cartesian_ik(control_poses, key, prev_cfgs)
elif self.seed_mode == "linear_js":
return self._seed_linear_js(control_poses, key, prev_cfgs, start_cfg, goal_cfg)
else:
raise ValueError(f"Unknown seed_mode: {self.seed_mode!r}")

def generate(
self,
start_pose: jaxlie.SE3,
Expand Down
2 changes: 2 additions & 0 deletions src/pyronot/optimization_engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@
from ._stomp_optimization import stomp_trajopt as stomp_trajopt
from ._ls_trajopt_optimization import LsTrajOptConfig as LsTrajOptConfig
from ._ls_trajopt_optimization import ls_trajopt as ls_trajopt
from ._lbfgs_trajopt_optimization import LbfgsTrajOptConfig as LbfgsTrajOptConfig
from ._lbfgs_trajopt_optimization import lbfgs_trajopt as lbfgs_trajopt
Loading
Loading