Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion synchros2/synchros2/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def wait_for_outcome(
Returns:
whether the action finalized before the timeout expired.
"""
return wait_for_future(action.finalization, timeout_sec=timeout_sec)
return bool(wait_for_future(action.finalization, timeout_sec=timeout_sec))


def unwrap_outcome(
Expand Down
276 changes: 251 additions & 25 deletions synchros2/synchros2/futures.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
# Copyright (c) 2023 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
from threading import Event
from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable
import threading
from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, FIRST_EXCEPTION
from typing import (
Any,
Awaitable,
Callable,
Iterable,
Iterator,
NamedTuple,
Optional,
Protocol,
Set,
TypeVar,
Union,
cast,
overload,
runtime_checkable,
)

from rclpy.clock import Clock
from rclpy.context import Context
from rclpy.duration import Duration
from rclpy.utilities import get_default_context

from synchros2.clock import wait_for

T = TypeVar("T", covariant=True)


@runtime_checkable
class FutureLike(Awaitable[T], Protocol[T]):
"""A future-like awaitable object.

Expand Down Expand Up @@ -61,58 +79,266 @@ def as_proper_future(instance: AnyFuture) -> FutureLike:
return instance


class WaitResult(NamedTuple):
"""Result of waiting for multiple futures.

A named tuple with 'done' and 'not_done' sets of futures.
"""

ok: bool
done: Set[FutureLike]
not_done: Set[FutureLike]

def __bool__(self) -> bool:
"""Equivalent to result.ok."""
return self.ok


@overload
def wait_for_future(
future: AnyFuture,
timeout_sec: Optional[float] = None,
*,
clock: Optional[Clock] = None,
context: Optional[Context] = None,
) -> bool:
"""Block while waiting for a future to become done
) -> WaitResult:
...


@overload
def wait_for_future(
future: Iterable[AnyFuture],
timeout_sec: Optional[float] = None,
*,
return_when: str = ALL_COMPLETED,
clock: Optional[Clock] = None,
context: Optional[Context] = None,
) -> WaitResult:
...


def wait_for_future(
future: Union[AnyFuture, Iterable[AnyFuture]],
timeout_sec: Optional[float] = None,
*,
clock: Optional[Clock] = None,
context: Optional[Context] = None,
return_when: str = ALL_COMPLETED,
) -> WaitResult:
"""Block while waiting for future(s) to become done.

Args:
future (Future): The future to be waited on
timeout_sec (Optional[float]): An optional timeout for how long to wait
clock (Optional[Clock]): An optional clock to use for timeout waits,
defaults to the clock of the current scope if any, otherwise the system clock
context (Optional[Context]): Current context (will use the default if none is given)
future: A single future or an iterable of futures to wait on
timeout_sec: An optional timeout for how long to wait
clock: An optional clock to use for timeout waits,
defaults to the clock of the current scope if any, otherwise the system clock
context: Current context (will use the default if none is given)
return_when: One of FIRST_COMPLETED, FIRST_EXCEPTION, or ALL_COMPLETED.
Only applies when waiting for multiple futures. Defaults to ALL_COMPLETED.

Returns:
bool: True if successful, False if the timeout was triggered
A result object indicating which futures are done and which are not,
and whether the wait was successful (i.e. not timed out).

Examples:
Single future:
>>> result = wait_for_future(my_future, timeout_sec=5.0)
>>> if result:
... value = my_future.result()

Multiple futures:
>>> result = wait_for_future([f1, f2, f3], return_when=FIRST_COMPLETED)
>>> for future in result.done:
... print(future.result())
"""
if return_when not in {FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED}:
raise ValueError(f"Invalid return_when value: {return_when}")

if context is None:
context = get_default_context()

if clock is None:
import synchros2.scope

clock = synchros2.scope.clock()
event = Event()

done_futures: Set[FutureLike] = set()
if not isinstance(future, (FutureConvertible, FutureLike)):
pending_futures = {as_proper_future(f) for f in future}
else:
pending_futures = {as_proper_future(future)}

if not pending_futures:
return WaitResult(ok=True, done=set(), not_done=set())

lock = threading.Lock()
event = threading.Event()

def _done_callback(future: FutureLike) -> None:
with lock:
if future in pending_futures:
pending_futures.remove(future)
done_futures.add(future)

should_return = False
if return_when == FIRST_COMPLETED:
should_return = True
elif return_when == FIRST_EXCEPTION:
exception_occurred = future.exception() is not None
should_return = exception_occurred or not pending_futures
elif return_when == ALL_COMPLETED:
should_return = not pending_futures

if should_return:
event.set()

context.on_shutdown(event.set)
proper_future = as_proper_future(future)
proper_future.add_done_callback(lambda _: event.set())
if proper_future.cancelled():
event.set()
wait_for(event, clock=clock, timeout_sec=timeout_sec)
return proper_future.done()
for future in list(pending_futures):
future.add_done_callback(_done_callback)
if future.cancelled():
_done_callback(future)

if not event.is_set():
wait_for(event, clock=clock, timeout_sec=timeout_sec)

with lock:
return WaitResult(ok=event.is_set(), done=done_futures.copy(), not_done=pending_futures.copy())


@overload
def unwrap_future(
future: AnyFuture,
timeout_sec: Optional[float] = None,
*,
clock: Optional[Clock] = None,
context: Optional[Context] = None,
) -> Any:
"""Fetch future result when it is done.
...


@overload
def unwrap_future(
future: Iterable[AnyFuture],
timeout_sec: Optional[float] = None,
*,
clock: Optional[Clock] = None,
context: Optional[Context] = None,
strict: bool = False,
) -> Iterator[Any]:
...


def unwrap_future(
future: Union[AnyFuture, Iterable[AnyFuture]],
timeout_sec: Optional[float] = None,
*,
clock: Optional[Clock] = None,
context: Optional[Context] = None,
strict: bool = False,
) -> Union[Any, Iterator[Any]]:
"""Fetch future result(s) when done.

For a single future, blocks until the future is done and returns its result.
For multiple futures, returns a generator that yields results as futures complete
(like concurrent.futures.as_completed).

Note: This function may block and may raise if a future raises or it times out
waiting. See wait_for_future() documentation for further reference on arguments.

Args:
future: A single future or an iterable of futures
timeout_sec: An optional timeout for how long to wait
clock: An optional clock to use for timeout waits
context: Current context (will use the default if none is given)
strict: If True, yield results in order regardless of completion order.
If False (default), yield results as they complete.
Irrelevant when a single future is provided.

Returns:
the result(s) of the future(s) when they are done.

Raises:
ValueError: If timeout occurs before future(s) complete

Examples:
Single future:
>>> result = unwrap_future(my_future, timeout_sec=5.0)

Note this function may block and may raise if the future does or it times out
waiting for it. See wait_for_future() documentation for further reference on
arguments taken.
Multiple futures (non-strict, as completed):
>>> for result in unwrap_future([f1, f2, f3], timeout_sec=10.0):
... process(result)

Multiple futures (strict, in order):
>>> for result in unwrap_future([f1, f2, f3], timeout_sec=10.0, strict=True):
... process(result)
"""
proper_future = as_proper_future(future)
if not wait_for_future(proper_future, timeout_sec, clock=clock, context=context):
raise ValueError("cannot unwrap future that is not done")
return proper_future.result()
if context is None:
context = get_default_context()

if clock is None:
import synchros2.scope

clock = synchros2.scope.clock()

if isinstance(future, (FutureConvertible, FutureLike)):
proper_future = as_proper_future(future)
if not wait_for_future(proper_future, timeout_sec, clock=clock, context=context):
raise ValueError("cannot unwrap future that is not done")
return proper_future.result()

def _result_generator() -> Any:
nonlocal future
future = cast(Iterable[AnyFuture], future)
pending_futures = [as_proper_future(f) for f in future]
if not pending_futures:
return

deadline = None
if timeout_sec is not None:
assert clock is not None
deadline = clock.now() + Duration(seconds=timeout_sec)

if strict:
for future in pending_futures:
remaining_timeout_sec = None
if deadline is not None:
assert clock is not None
remaining_duration = deadline - clock.now()
if remaining_duration.nanoseconds <= 0:
raise ValueError("timeout waiting for futures")
remaining_timeout_sec = remaining_duration.nanoseconds / 1e9

if not wait_for_future(future, timeout_sec=remaining_timeout_sec, clock=clock, context=context):
raise ValueError("timeout waiting for futures")
yield future.result()
return

while pending_futures:
remaining_timeout_sec = None
if deadline is not None:
assert clock is not None
remaining_duration = deadline - clock.now()
if remaining_duration.nanoseconds <= 0:
raise ValueError("timeout waiting for futures")
remaining_timeout_sec = remaining_duration.nanoseconds / 1e9

result = wait_for_future(
pending_futures,
timeout_sec=remaining_timeout_sec,
clock=clock,
context=context,
return_when=FIRST_COMPLETED,
)

if not result:
raise ValueError("timeout waiting for futures")

for future in result.done:
if future in pending_futures:
pending_futures.remove(future)
yield future.result()

return _result_generator()


wait_and_return_result = unwrap_future
Expand Down
Loading
Loading