Skip to content

fix(datasets): use get_from_first_device to safely de-shard tfds payload arrays#366

Open
divyashreepathihalli wants to merge 1 commit intogoogle-deepmind:masterfrom
divyashreepathihalli:master
Open

fix(datasets): use get_from_first_device to safely de-shard tfds payload arrays#366
divyashreepathihalli wants to merge 1 commit intogoogle-deepmind:masterfrom
divyashreepathihalli:master

Conversation

@divyashreepathihalli
Copy link

Summary
This PR fixes a critical crash in acme/datasets/tfds.py introduced during the removal of the deprecated jax.config.pmap_shmap_merge flag.

Details
Currently, sample_and_postprocess in JaxInMemoryRandomSampleIterator unconditionally calls x.addressable_shards[0].data.squeeze(0) on all elements of the batched data. This causes an AttributeError to be thrown whenever the dataset contains 0-dimensional scalar arrays (which do not possess addressable_shards) or runs on un-sharded single-device configurations.

This PR replaces the hardcoded lambda function with acme.jax.utils.get_from_first_device(data, as_numpy=False), routing the data through the robust _unreplicate helper added recently. This ensures proper handling of SingleDeviceSharding, 0-dim scalars, and fully replicated parameters without crashing.

Changes:

  • Imported acme.jax.utils in tfds.py
  • Replaced the vulnerable jax.tree_util.tree_map lambda call with get_from_first_device to safely strip the pmap replica dimension while preserving non-sharded scalars.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant