Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions src/pypcd4/pypcd4.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from enum import Enum
from io import BufferedReader
from pathlib import Path
from typing import TYPE_CHECKING, BinaryIO, List, Literal, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, BinaryIO, List, Literal, Optional, Sequence, Tuple, Union, cast

import lzf
import lzf # type: ignore
import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, NonNegativeInt, PositiveInt
Expand All @@ -21,8 +21,8 @@
)

if TYPE_CHECKING:
from sensor_msgs.msg import PointCloud2
from std_msgs.msg import Header
from sensor_msgs.msg import PointCloud2 # type: ignore
from std_msgs.msg import Header # type: ignore

PathLike = Union[str, Path]

Expand Down Expand Up @@ -207,7 +207,7 @@ def _parse_pc_data(fp: BufferedReader, metadata: MetaData) -> npt.NDArray:
def _compose_pc_data(
points: Union[npt.NDArray, Sequence[npt.NDArray]], metadata: MetaData
) -> npt.NDArray:
arrays: Sequence[npt.NDArray] = tuple(points.T) if isinstance(points, np.ndarray) else points
arrays = tuple(points.T) if isinstance(points, np.ndarray) else points

return np.rec.fromarrays(arrays, dtype=metadata.build_dtype())

Expand Down Expand Up @@ -616,30 +616,30 @@ def to_msg(self, header: Optional["Header"] = None) -> "PointCloud2":
"""
ROS_MSG_AVAILABLE = False
try:
from sensor_msgs.msg import PointCloud2, PointField
from std_msgs.msg import Header
from sensor_msgs.msg import PointCloud2, PointField # type: ignore
from std_msgs.msg import Header # type: ignore

ROS_MSG_AVAILABLE = True
try:
from builtin_interfaces.msg import Time # ROS2
from builtin_interfaces.msg import Time # ROS2 # type: ignore
except ImportError:
from std_msgs.msg import Time # ROS1
from std_msgs.msg import Time # ROS1 # type: ignore
except ImportError:
pass

if not ROS_MSG_AVAILABLE:
try:
# Fallback to rosbags
from rosbags.typesys.stores.latest import (
from rosbags.typesys.stores.latest import ( # type: ignore
builtin_interfaces__msg__Time as Time,
)
from rosbags.typesys.stores.latest import (
from rosbags.typesys.stores.latest import ( # type: ignore
sensor_msgs__msg__PointCloud2 as PointCloud2,
)
from rosbags.typesys.stores.latest import (
from rosbags.typesys.stores.latest import ( # type: ignore
sensor_msgs__msg__PointField as PointField,
)
from rosbags.typesys.stores.latest import (
from rosbags.typesys.stores.latest import ( # type: ignore
std_msgs__msg__Header as Header,
)

Expand Down Expand Up @@ -1037,12 +1037,11 @@ def __getitem__(
)
"""

points_list: list[npt.NDArray]
fields = self.fields
types = self.types
if isinstance(subscript, slice):
points_list = tuple(
self.pc_data[field][subscript]
cast(npt.NDArray, self.pc_data[field][subscript])
for field in self.pc_data.dtype.names # type: ignore[assignment,union-attr]
)
elif isinstance(subscript, np.ndarray):
Expand All @@ -1051,19 +1050,23 @@ def __getitem__(
raise ValueError(f"Mask array must be 1-dimensional but got {mask.ndim}")

points_list = tuple(
self.pc_data[field][mask]
cast(npt.NDArray, self.pc_data[field][mask])
for field in self.pc_data.dtype.names # type: ignore[assignment,union-attr]
)
elif isinstance(subscript, str) or all(isinstance(s, str) for s in subscript):
elif isinstance(subscript, str) or all(isinstance(s, str) for s in cast(tuple, subscript)):
if isinstance(subscript, str):
subscript = (subscript,)

subscript = cast(tuple, subscript)

if not np.isin(subscript, self.fields).all():
raise ValueError(f"Invalid field name(s): {subscript}")

points_list = [self.pc_data[field] for field in subscript]
points_list = tuple(cast(npt.NDArray, self.pc_data[field]) for field in subscript)
fields = tuple(subscript)
types = tuple(self.pc_data[field].dtype for field in subscript)
else:
raise ValueError(f"Invalid subscript type: {type(subscript).__name__}")

return PointCloud.from_points(points_list, fields, types)

Expand Down
Loading