If an empty batch is fed to one rank, the entire group will hang. This issue was most likely introduced in commit 173bf25, as no hangs occurred before this change.
import math
import os
from typing import List, cast
import torch
import torch.distributed as dist
import torch.nn as nn
from dynamicemb.planner import (DynamicEmbeddingEnumerator,
DynamicEmbeddingShardingPlanner,
DynamicEmbParameterConstraints)
from dynamicemb.shard import (DynamicEmbeddingBagCollectionSharder,
DynamicEmbeddingCollectionSharder)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_embedding_configs import SparseType
from torchrec import EmbeddingCollection, EmbeddingConfig, KeyedJaggedTensor
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.fbgemm_qcomm_codec import (CommType, QCommsConfig,
get_qcomm_codecs_registry)
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import Topology
from torchrec.distributed.planner.storage_reservations import \
HeuristicalStorageReservation
from torchrec.distributed.planner.types import ShardingPlan
from torchrec.distributed.types import (BoundsCheckMode, ModuleSharder,
ShardingType)
from dynamicemb import (DynamicEmbCheckMode, DynamicEmbInitializerArgs,
DynamicEmbInitializerMode, DynamicEmbScoreStrategy,
DynamicEmbTableOptions, FrequencyAdmissionStrategy,
KVCounter)
def get_sharder(lr, prefetch_pipeline, optimizer_type, is_ebc):
    """Build a dynamicemb sharder (bag or sequence flavour) with fused optimizer args.

    Args:
        lr: learning rate forwarded to the fused optimizer.
        prefetch_pipeline: whether enable prefetch for embedding lookup module.
        optimizer_type: fbgemm `OptimType` to fuse into the lookup kernels.
        is_ebc: True -> `DynamicEmbeddingBagCollectionSharder`,
                False -> `DynamicEmbeddingCollectionSharder` with index dedup.
    """
    # Optimizer args packed the same way torchrec expects them, merged with
    # the lookup-module fused params.
    fused_params = {
        "output_dtype": SparseType.FP32,  # data type of the output after lookup, and can differ from the stored.
        "prefetch_pipeline": prefetch_pipeline,  # whether enable prefetch for embedding lookup module
        "optimizer": optimizer_type,
        "learning_rate": lr,
        "beta1": 0.9,
        "beta2": 0.999,
        "weight_decay": 0,
        "eps": 1e-8,
    }
    # precision of all-to-all
    qcomm_codecs_registry = get_qcomm_codecs_registry(
        qcomms_config=QCommsConfig(
            forward_precision=CommType.FP32, backward_precision=CommType.FP32
        )
    )
    if is_ebc:
        return DynamicEmbeddingBagCollectionSharder(
            qcomm_codecs_registry=qcomm_codecs_registry, fused_params=fused_params
        )
    return DynamicEmbeddingCollectionSharder(
        qcomm_codecs_registry=qcomm_codecs_registry,
        fused_params=fused_params,
        use_index_dedup=True,
    )
def create_dynamic_optitons(config, hbm_cap, caching, admission_threshold):
    """Build the `DynamicEmbTableOptions` for one embedding table config.

    Note: the function name typo ("optitons") is kept because callers use it.
    """
    bucket_capacity = 1024 if caching else 128
    cache_ratio = 0.5  # assume we will use 50% of the HBM for cache
    # hash table needs embedding vector num to be power of 2
    capacity_pow2 = 2 ** math.ceil(math.log2(config.num_embeddings))
    # Require at least one bucket per rank at the given cache ratio; if the
    # table is smaller than that floor, round the floor up to a power of two.
    min_capacity = math.ceil(bucket_capacity * dist.get_world_size() / cache_ratio)
    if capacity_pow2 < min_capacity:
        capacity_pow2 = 2 ** math.ceil(math.log2(min_capacity))

    admit_strategy = None
    admission_counter = None
    # Setup admission strategy only when a positive threshold is given.
    if admission_threshold > 0:
        # Counter config only; the actual table is created during sharding.
        admission_counter = KVCounter(
            capacity=capacity_pow2,
            bucket_capacity=bucket_capacity,
            key_type=torch.int64,
        )
        admit_strategy = FrequencyAdmissionStrategy(
            threshold=admission_threshold,
            initializer_args=DynamicEmbInitializerArgs(
                mode=DynamicEmbInitializerMode.CONSTANT, value=0.0  # Initialize rejected keys to 0
            ),
        )

    return DynamicEmbTableOptions(
        safe_check_mode=DynamicEmbCheckMode.WARNING,
        init_capacity=1 << 24,
        global_hbm_for_values=hbm_cap * 0.5,  # total_hbm_need * (cache_ratio if caching else 1.0),
        initializer_args=DynamicEmbInitializerArgs(mode=DynamicEmbInitializerMode.UNIFORM),
        score_strategy=DynamicEmbScoreStrategy.STEP,
        caching=caching,
        training=True,
        admit_strategy=admit_strategy,
        admission_counter=admission_counter,
    )
def make_shard_constraints(configs, hbm_cap, caching, admission_threshold):
    """Map each table name to a row-wise `DynamicEmbParameterConstraints`."""
    return {
        config.name: DynamicEmbParameterConstraints(
            # dynamicemb embedding table only support to be sharded in row-wise.
            sharding_types=[ShardingType.ROW_WISE.value],
            bounds_check_mode=BoundsCheckMode.NONE,  # dynamic embedding has no bounding!
            use_dynamicemb=True,
            dynamicemb_options=create_dynamic_optitons(
                config,
                hbm_cap=hbm_cap,
                caching=caching,
                admission_threshold=admission_threshold,
            ),
        )
        for config in configs
    }
# Wrap all the Planner setup in one function.
def get_planner(device, configs, caching, admission_threshold):
    """Create a `DynamicEmbeddingShardingPlanner` over the cluster topology."""
    hbm_cap = torch.cuda.get_device_properties(0).total_memory
    topology = Topology(
        local_world_size=get_local_size(),
        world_size=dist.get_world_size(),
        compute_device=device.type,
        hbm_cap=hbm_cap,
        ddr_cap=512 * 1024 * 1024 * 1024,  # Assume a Node have 512GB memory
        intra_host_bw=450e9,  # Nvlink bandwidth
        inter_host_bw=25e9,  # NIC bandwidth
    )
    constraints = make_shard_constraints(
        configs,
        hbm_cap=hbm_cap,
        caching=caching,
        admission_threshold=admission_threshold,
    )
    # same usage of torchrec's EmbeddingEnumerator
    enumerator = DynamicEmbeddingEnumerator(topology=topology, constraints=constraints)
    # Almost the same usage as torchrec's EmbeddingShardingPlanner, except that
    # eb_configs is also passed in, as dynamicemb needs EmbeddingConfig info
    # to help it plan.
    return DynamicEmbeddingShardingPlanner(
        eb_configs=configs,
        topology=topology,
        constraints=constraints,
        enumerator=enumerator,
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
    )
def collect_configs(model):
    """Return the `EmbeddingConfig` list held by the model's `EmbeddingCollection`.

    Args:
        model: module exposing an `ec` attribute of type `EmbeddingCollection`.

    Returns:
        A new list of the collection's `EmbeddingConfig` objects (callers may
        mutate it without touching the module's own state).
    """
    ec = model.ec
    assert isinstance(ec, EmbeddingCollection)
    # list(...) instead of a manual append loop: same order, copies the
    # sequence so the caller owns the result.
    return list(ec.embedding_configs())
def apply_dmp_with_optimizer(model, pg, device: torch.device):
    """
    Shard *model* with DistributedModelParallel and build its dense optimizer.

    Initialization of the dynamicemb embedding lookup module mirrors torchrec:
      1. Table-global parameters are configured via `EmbeddingCollection`
         (already done inside the model).
      2. Build a `DynamicEmbeddingCollectionSharder` and generate a
         `ShardingPlan` from `DynamicEmbeddingShardingPlanner`.
      3. Pass everything to `DistributedModelParallel`, which handles the
         embedding sharding and initialization.
    """
    optimizer_type = OptimType.ADAM
    lr = 1e-4
    caching = True
    prefetch_pipeline = True
    admission_threshold = 0

    # The sharders create `ShardedDynamicEmbeddingCollection` instances: a
    # GPU-optimized scored hash table using both device and host memory, with
    # score-based (per key) automatic eviction, plus a customized input
    # distributor for deduplication when `use_index_dedup=True`. The actual
    # sharding happens while `ShardedDynamicEmbeddingCollection` initializes,
    # but the sharder's parameters drive that process.
    # `DynamicEmbeddingCollectionSharder` inherits `EmbeddingCollectionSharder`.
    bag_sharder = get_sharder(
        lr=lr, optimizer_type=optimizer_type, prefetch_pipeline=prefetch_pipeline, is_ebc=True
    )
    seq_sharder = get_sharder(
        lr=lr, optimizer_type=optimizer_type, prefetch_pipeline=prefetch_pipeline, is_ebc=False
    )
    sharders = [
        cast(ModuleSharder[torch.nn.Module], bag_sharder),
        cast(ModuleSharder[torch.nn.Module], seq_sharder),
    ]

    # Next, generate a `ParameterSharding` per table (for dynamic tables a
    # `DynamicEmbParameterSharding`, carrying the parameters required by the
    # embedding lookup module) via `DynamicEmbeddingShardingPlanner`.
    configs = collect_configs(model)
    planner = get_planner(
        device, configs=configs, caching=caching, admission_threshold=admission_threshold
    )
    # get plan for all ranks.
    # ShardingPlan is a dict, mapping table name to ParameterSharding/DynamicEmbParameterSharding.
    plan: ShardingPlan = planner.collective_plan(model, sharders=sharders, pg=pg)

    # Finally, feed the sharders and the `ShardingPlan` to
    # `DistributedModelParallel`, which implements the plan through the
    # sharders and holds the `ShardedDynamicEmbeddingCollection` after
    # sharding. Use `dmp` for training/evaluation just like an
    # `EmbeddingCollection`.
    model = DistributedModelParallel(module=model, device=device, sharders=sharders, plan=plan)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer
class ItemEmbeddingLookup(nn.Module):
    """Single-table item-id embedding lookup backed by an `EmbeddingCollection`."""

    def __init__(self, embedding_dim: int, vocab_size: int, dynamic: bool) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self._feat_name = "item_id"
        # Dynamic tables are materialized later by DMP, so start on "meta".
        table_device = torch.device("meta") if dynamic else None
        self.ec = EmbeddingCollection(
            tables=[
                EmbeddingConfig(
                    name="item_table",
                    embedding_dim=embedding_dim,
                    num_embeddings=vocab_size,
                    feature_names=[self._feat_name],
                )
            ],
            device=table_device,
        )

    @torch.cuda.nvtx.range("ItemEmbeddingLookup.forward")
    def forward(self, batch_item_ids: List[List[int]]):
        """Look up embeddings for a jagged batch of per-sample item-id lists."""
        device = next(self.ec.parameters()).device
        flat_ids = [item_id for ids in batch_item_ids for item_id in ids]
        lengths = [len(ids) for ids in batch_item_ids]
        kjt = KeyedJaggedTensor.from_lengths_sync(
            keys=[self._feat_name],
            values=torch.tensor(flat_ids, dtype=torch.int64, device=device),
            lengths=torch.tensor(lengths, dtype=torch.int64, device=device),
        )
        print(f"Input KeyedJaggedTensor: {kjt['item_id'].lengths()}, {kjt['item_id'].values()}")
        return self.ec(kjt)[self._feat_name].to_dense()
def init_dist():
    """Initialize the NCCL process group and pin this process to its local GPU.

    Returns the device string ("cuda:<idx>") for the current rank.
    """
    from datetime import timedelta

    # Short timeout so a collective hang surfaces quickly instead of stalling.
    torch.distributed.init_process_group(backend="nccl", timeout=timedelta(seconds=10))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.distributed.barrier()
    return f"cuda:{torch.cuda.current_device()}"
def main():
    """Repro driver: rank 0 feeds real ids, every other rank feeds empty lists.

    The empty batches on non-zero ranks are deliberate — they are the trigger
    for the reported group hang.
    """
    model = ItemEmbeddingLookup(embedding_dim=16, vocab_size=10000, dynamic=True)
    device = init_dist()
    pg = dist.group.WORLD
    print(model)
    model, optimizer = apply_dmp_with_optimizer(model, pg, torch.device(device))
    print(model)
    if torch.distributed.get_rank() == 0:
        batch = [[1, 2, 3], [4, 5], [6], [7, 8, 9, 10]]
    else:
        # Four samples, each with zero ids.
        batch = [[], [], [], []]
    out = model(batch)
    print(f"Output embeddings: {out}")


if __name__ == "__main__":
    main()
Describe the bug
If an empty batch is fed to one rank, the entire group will hang. This issue was most likely introduced in commit 173bf25, as no hangs occurred before this change.
Steps/Code to reproduce bug
minimal repro code:
Expected behavior
The forward pass should complete on every rank — returning an empty embedding tensor for the samples on ranks that received an empty batch — instead of hanging the entire process group.
Environment details (please complete the following information):
Click here to see environment details
Additional context
Add any other context about the problem here.
By submitting this issue, you agree to follow our code of conduct and our contributing guidelines.