alf/algorithms/td_loss.py (8 changes: 7 additions & 1 deletion)
@@ -35,6 +35,7 @@ def __init__(self,
                  td_lambda: float = 0.95,
                  normalize_target: bool = False,
                  default_return: Optional[float] = None,
+                 bootstrap_only: bool = False,
                  debug_summaries: bool = False,
                  name: str = "TDLoss"):
         r"""
@@ -89,6 +90,9 @@ def __init__(self,
             default_return: The default values of ``discounted_return`` used in
                 ``ReplayBuffer`` when the episode has not ended. It is used for
                 summarizing the actual Monte Carlo return (MC-return) values.
+            bootstrap_only: If True, ignore the MC-returns even when present and
+                rely solely on bootstrapping. If MC-returns are not present,
+                this flag has no effect.
             debug_summaries: True if debug summaries should be created.
             name: The name of this loss.
         """
@@ -102,6 +106,7 @@ def __init__(self,
         self._normalize_target = normalize_target
         self._target_normalizer = None
         self._default_return = default_return
+        self._bootstrap_only = bootstrap_only

     @property
     def gamma(self):
@@ -158,7 +163,8 @@ def compute_td_target(self, info: namedtuple, target_value: torch.Tensor):

         if hasattr(info, "discounted_return") and info.discounted_return != ():
             discounted_return = info.discounted_return[:-1]
-            returns = torch.max(returns, discounted_return)
+            if not self._bootstrap_only:
+                returns = torch.max(returns, discounted_return)
             with alf.summary.scope(self._name):
                 mask = info.step_type[:-1] != StepType.LAST
                 episode_ended = discounted_return != self._default_return
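
To illustrate what the changed hunk does: by default, the bootstrapped TD(lambda) target is lower-bounded element-wise by the stored Monte Carlo return via ``torch.max``, and ``bootstrap_only=True`` skips that step. A minimal sketch with hypothetical tensors and shapes, not ALF code:

```python
import torch

# Hypothetical [T, B] bootstrapped TD(lambda) returns and the Monte
# Carlo returns recorded alongside them in the replay buffer.
td_returns = torch.tensor([[1.0, 0.5], [0.8, 0.9]])
mc_returns = torch.tensor([[1.2, 0.4], [0.7, 1.1]])

bootstrap_only = False
if not bootstrap_only:
    # Default: the MC return acts as an element-wise lower bound
    # on the TD target.
    targets = torch.max(td_returns, mc_returns)  # [[1.2, 0.5], [0.8, 1.1]]
else:
    # bootstrap_only=True leaves the purely bootstrapped target untouched.
    targets = td_returns
```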
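
For completeness, a hedged usage sketch of the new flag. It assumes ``TDLoss`` can be constructed with only the arguments visible in this diff, all of which have defaults:

```python
from alf.algorithms.td_loss import TDLoss

# Default behavior: if MC-returns are stored in the replay buffer, the
# TD(lambda) target is lower-bounded by them via torch.max.
loss = TDLoss(td_lambda=0.95)

# New behavior: ignore MC-returns even when present and rely solely on
# bootstrapping for the target.
bootstrap_loss = TDLoss(td_lambda=0.95, bootstrap_only=True)
```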