Background
Currently, DynamicEmb has a custom input_dist implementation (RwSparseFeaturesDist in input_dist.py) but still relies on TorchRec's original output_dist implementation. This causes two problems:
- Performance issue: The unbucketize_permute operation in TorchRec's output distribution is slow, especially for non-contiguous distribution patterns (e.g., round-robin).
- Limited customization: The output distribution cannot be optimized without modifying TorchRec source code.
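To make the performance issue concrete: in row-wise sharding, ids are grouped by owning rank before the lookup, so per-id embeddings come back in bucketized order. For sequence (unpooled) outputs, an unbucketize permutation is then needed to restore the original input order. The pure-Python sketch below assumes a round-robin assignment (rank = id % world_size). round_robin_bucketize and unbucketize are hypothetical helpers, only the name unbucketize_permute comes from this issue, and this is not TorchRec's implementation.

```python
from typing import List, Sequence, Tuple


def round_robin_bucketize(
    ids: Sequence[int], world_size: int
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: group ids by owning rank (id % world_size),
    keeping input order within each bucket, and return the permutation
    that maps each original position to its position in bucketized order."""
    buckets: List[List[int]] = [[] for _ in range(world_size)]
    placement: List[Tuple[int, int]] = []  # (rank, slot in that rank's bucket)
    for x in ids:
        rank = x % world_size
        placement.append((rank, len(buckets[rank])))
        buckets[rank].append(x)

    # Start offset of each rank's bucket in the flattened, bucketized order.
    offsets = [0] * world_size
    for r in range(1, world_size):
        offsets[r] = offsets[r - 1] + len(buckets[r - 1])

    # unbucketize_permute[i] = bucketized position of original element i.
    unbucketize_permute = [offsets[rank] + slot for rank, slot in placement]
    return buckets, unbucketize_permute


def unbucketize(
    flat_outputs: Sequence[int], unbucketize_permute: Sequence[int]
) -> List[int]:
    """Gather rows back into input order (an index_select in tensor code)."""
    return [flat_outputs[p] for p in unbucketize_permute]


if __name__ == "__main__":
    ids = [5, 2, 7, 4, 1, 6]
    buckets, perm = round_robin_bucketize(ids, world_size=2)
    # Stand-in for the per-rank lookup: the "embedding" of id x is x * 10.
    flat = [x * 10 for bucket in buckets for x in bucket]
    print(buckets)                  # [[2, 4, 6], [5, 7, 1]]
    print(perm)                     # [3, 0, 4, 1, 5, 2]
    print(unbucketize(flat, perm))  # [50, 20, 70, 40, 10, 60]
```

With round-robin assignment, neighboring ids usually land on different ranks, so the permutation jumps back and forth across the whole bucketized buffer instead of reading contiguous runs. This is the non-contiguous access pattern the performance issue refers to.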
Objective
Port TorchRec's output distribution classes to the DynamicEmb library, enabling future performance optimizations.
Tasks
PR 1: Port output distribution classes to DynamicEmb
- Add dynamicemb/output_dist.py with RwSequenceEmbeddingDist and RwPooledEmbeddingDist
- Update dynamicemb/planner/rw_sharding.py to override the create_output_dist() methods
- Tests: test_sequence_embedding_fw.py, test_pooled_embedding_fw.py
PR 2: Optimize unbucketize permute with a custom kernel