Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions alf/experience_replayers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self,
mp_context=None,
keep_episodic_info=None,
record_episodic_return=False,
compute_episodic_return_on_last_step=False,
default_return=-1000.,
gamma=.99,
reward_clip=None,
Expand Down Expand Up @@ -113,6 +114,27 @@ def __init__(self,
``ReplayBuffer.reward_clip=(-1,1)``.
2) Discount ``gamma`` needs to be set consistent with ``TDLoss.gamma``.
3) Assumes ``keep_episodic_info`` to be True.
compute_episodic_return_on_last_step (bool): If True, compute episodic
return when a LAST step is encountered regardless of the discount factor.
Default behavior when False is to compute the return only when a discount
factor of 0 is encountered; otherwise, steps are populated with a return of
``default_return``. Useful for infinite horizon RL formulations that need
the discounted return. If True, ``keep_episodic_info`` and
``record_episodic_return`` must also be True.
NOTE:
If the discount factor is not 0, then the computed MC return can
be biased because future rewards beyond the last step are implicitly
assumed to be zero (i.e., the return is truncated at the episode boundary).

For infinite-horizon RL, this results in a missing term of the form
γ^k V(s_T), where s_T is the last state. Therefore:

- If the true value of the last state is positive, the MC return is a lower bound.
- If the true value is negative, the MC return is an upper bound.

This bias can be significant when episodes end due to time limits rather
than true terminal conditions, and the resulting discounted return should
be used cautiously.
default_return (float): The default values of ``discounted_return``
when the episode has not ended. For value target lower bounding,
default_return should not be bigger than the smallest possible
Expand Down Expand Up @@ -148,6 +170,9 @@ def __init__(self,
self._record_episodic_return = record_episodic_return
if record_episodic_return:
assert keep_episodic_info
self._compute_episodic_return_on_last_step = compute_episodic_return_on_last_step
if compute_episodic_return_on_last_step:
assert record_episodic_return
self._default_return = default_return
self._gamma = gamma
self._reward_clip = reward_clip
Expand Down Expand Up @@ -353,7 +378,10 @@ def add_batch(self, batch, env_ids=None, blocking=False):
# This has the advantage of start storing episodic return earlier,
# but the disadvantage of having to compute episodic return a few times
# per episode, repeatedly for some of the earlier steps in the episode.
disc_0, = torch.where(batch.discount == 0)
compute_mask = batch.discount == 0
if self._compute_episodic_return_on_last_step:
compute_mask |= step_types == ds.StepType.LAST
disc_0, = torch.where(compute_mask)
# Backfill episodic returns for episodes which ended
if disc_0.nelement() > 0:
self._compute_store_episodic_return(env_ids[disc_0])
Expand Down Expand Up @@ -546,7 +574,7 @@ def _set_default_return(self, env_ids):
self._episodic_discounted_return[ind] = self._default_return

def _compute_store_episodic_return(self, env_ids):
# Always pass in env_ids whose discount is 0, to save computation.
# Pass in env_ids marking the end of episodes to save computation.
current_pos = self._current_pos[env_ids]
current_pos -= 1

Expand Down
27 changes: 27 additions & 0 deletions alf/experience_replayers/replay_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,33 @@ def test_gather_all_with_num_earliest_frames_ignored(self):
self.assertEqual(torch.tensor([[8, 9, 10, 11, 12, 13, 14]] * 4),
experience.step_type)

def test_compute_episodic_return_on_last_step(self):
    """Episodic return is backfilled at a LAST step despite discount == 1.

    Plays one 4-step episode whose LAST step keeps a nonzero discount and
    verifies that ``compute_episodic_return_on_last_step=True`` still
    triggers the return computation at the episode boundary.
    """
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=1,
        max_length=10,
        keep_episodic_info=True,
        record_episodic_return=True,
        compute_episodic_return_on_last_step=True)
    step_types = [
        ds.StepType.FIRST,
        ds.StepType.MID,
        ds.StepType.MID,
        ds.StepType.LAST,
    ]
    for i, step_type in enumerate(step_types):
        batch = get_exp_batch([0], self.dim, t=step_type, x=0.1 * i)
        if step_type == ds.StepType.LAST:
            # Force a nonzero discount on the LAST step so that only the
            # step type — not a discount of 0 — can trigger the backfill.
            batch.discount[:] = 1.0
        replay_buffer.add_batch(batch, batch.env_id)
    # Only the leading positions of the episode receive backfilled
    # discounted returns; every remaining buffer slot keeps the
    # default return of -1000.
    expected = torch.tensor(
        [[-2.9701, -1.99, -1., -1000., -1000., -1000., -1000., -1000.,
          -1000., -1000.]],
        dtype=torch.float32)
    self.assertTrue(
        torch.allclose(replay_buffer._episodic_discounted_return,
                       expected))


def _write_to_buffer(buffer):
bat = torch.zeros((
Expand Down
Loading