Skip to content

[BUG] Fix hanging issue when a rank receives empty input #341

@gameofdimension

Description

@gameofdimension

Describe the bug

If an empty batch is fed to one rank, the entire group will hang. This issue was most likely introduced in commit 173bf25, as no hangs occurred before this change.

Steps/Code to reproduce bug

torchrun --nproc-per-node=4 --standalone repro.py

minimal repro code:

import math
import os
from typing import List, cast

import torch
import torch.distributed as dist
import torch.nn as nn
from dynamicemb.planner import (DynamicEmbeddingEnumerator,
                                DynamicEmbeddingShardingPlanner,
                                DynamicEmbParameterConstraints)
from dynamicemb.shard import (DynamicEmbeddingBagCollectionSharder,
                              DynamicEmbeddingCollectionSharder)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_embedding_configs import SparseType
from torchrec import EmbeddingCollection, EmbeddingConfig, KeyedJaggedTensor
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.fbgemm_qcomm_codec import (CommType, QCommsConfig,
                                                     get_qcomm_codecs_registry)
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import Topology
from torchrec.distributed.planner.storage_reservations import \
    HeuristicalStorageReservation
from torchrec.distributed.planner.types import ShardingPlan
from torchrec.distributed.types import (BoundsCheckMode, ModuleSharder,
                                        ShardingType)

from dynamicemb import (DynamicEmbCheckMode, DynamicEmbInitializerArgs,
                        DynamicEmbInitializerMode, DynamicEmbScoreStrategy,
                        DynamicEmbTableOptions, FrequencyAdmissionStrategy,
                        KVCounter)


def get_sharder(lr, prefetch_pipeline, optimizer_type, is_ebc):
    """Build a dynamicemb sharder with fused optimizer args.

    Args:
        lr: learning rate for the fused optimizer.
        prefetch_pipeline: whether to enable prefetch for the lookup module.
        optimizer_type: an ``EmbOptimType`` value (e.g. ADAM).
        is_ebc: True -> EmbeddingBagCollection sharder,
                False -> EmbeddingCollection sharder (with index dedup).
    """
    # Optimizer hyper-parameters, packed the same way torchrec expects them.
    optimizer_kwargs = {
        "optimizer": optimizer_type,
        "learning_rate": lr,
        "beta1": 0.9,
        "beta2": 0.999,
        "weight_decay": 0,
        "eps": 1e-8,
    }

    fused_params = {
        # data type of the output after lookup; may differ from the stored dtype
        "output_dtype": SparseType.FP32,
        # whether to enable prefetch for the embedding lookup module
        "prefetch_pipeline": prefetch_pipeline,
        **optimizer_kwargs,
    }

    # precision of the all-to-all communication
    qcomm_codecs_registry = get_qcomm_codecs_registry(
        qcomms_config=QCommsConfig(
            forward_precision=CommType.FP32, backward_precision=CommType.FP32
        )
    )

    if is_ebc:
        return DynamicEmbeddingBagCollectionSharder(
            qcomm_codecs_registry=qcomm_codecs_registry, fused_params=fused_params
        )
    return DynamicEmbeddingCollectionSharder(
        qcomm_codecs_registry=qcomm_codecs_registry,
        fused_params=fused_params,
        use_index_dedup=True,
    )


def create_dynamic_optitons(config, hbm_cap, caching, admission_threshold):
    """Build ``DynamicEmbTableOptions`` for a single embedding table config.

    NOTE(review): the function name carries a typo ("optitons"); it is kept
    unchanged for caller compatibility.
    """
    bucket_capacity = 1024 if caching else 128
    cache_ratio = 0.5  # assume we will use 50% of the HBM for cache

    # The hash table needs the embedding-vector count rounded up to a power of 2.
    rounded_capacity = 2 ** math.ceil(math.log2(config.num_embeddings))

    # Ensure the capacity is large enough for one bucket per rank, scaled
    # by the cache ratio; round the minimum up to a power of 2 as well.
    min_capacity = math.ceil((bucket_capacity * dist.get_world_size()) / cache_ratio)
    if rounded_capacity < min_capacity:
        rounded_capacity = 2 ** math.ceil(math.log2(min_capacity))

    # Frequency-based admission is only enabled for a positive threshold.
    admit_strategy = None
    admission_counter = None
    if admission_threshold > 0:
        # Counter config only — the actual table is created during sharding.
        admission_counter = KVCounter(
            capacity=rounded_capacity,
            bucket_capacity=bucket_capacity,
            key_type=torch.int64,
        )
        admit_strategy = FrequencyAdmissionStrategy(
            threshold=admission_threshold,
            # keys rejected by admission are initialized to a constant 0.0
            initializer_args=DynamicEmbInitializerArgs(
                mode=DynamicEmbInitializerMode.CONSTANT, value=0.0
            ),
        )

    return DynamicEmbTableOptions(
        safe_check_mode=DynamicEmbCheckMode.WARNING,
        init_capacity=1 << 24,
        global_hbm_for_values=hbm_cap * 0.5,  # total_hbm_need * (cache_ratio if caching else 1.0),
        initializer_args=DynamicEmbInitializerArgs(
            mode=DynamicEmbInitializerMode.UNIFORM
        ),
        score_strategy=DynamicEmbScoreStrategy.STEP,
        caching=caching,
        training=True,
        admit_strategy=admit_strategy,
        admission_counter=admission_counter,
    )


def make_shard_constraints(configs, hbm_cap, caching, admission_threshold):
    """Map each table name to row-wise ``DynamicEmbParameterConstraints``."""
    return {
        config.name: DynamicEmbParameterConstraints(
            # dynamicemb embedding tables only support row-wise sharding
            sharding_types=[ShardingType.ROW_WISE.value],
            # dynamic embedding has no bounds to check
            bounds_check_mode=BoundsCheckMode.NONE,
            use_dynamicemb=True,
            dynamicemb_options=create_dynamic_optitons(
                config,
                hbm_cap=hbm_cap,
                caching=caching,
                admission_threshold=admission_threshold,
            ),
        )
        for config in configs
    }


# use a function to wrap all the Planner code
def get_planner(device, configs, caching, admission_threshold):
    """Build a ``DynamicEmbeddingShardingPlanner`` for the given tables.

    Args:
        device: this rank's compute device (a ``torch.device``).
        configs: list of ``EmbeddingConfig`` describing the tables to shard.
        caching / admission_threshold: forwarded into the per-table options.
    """
    # Query the HBM capacity of the device this rank actually uses, rather
    # than hard-coding GPU 0 — matters on nodes with heterogeneous GPUs.
    hbm_cap = torch.cuda.get_device_properties(device).total_memory
    ddr_cap = 512 * 1024 * 1024 * 1024  # Assume a node has 512GB memory
    intra_host_bw = 450e9  # NVLink bandwidth
    inter_host_bw = 25e9  # NIC bandwidth

    topology = Topology(
        local_world_size=get_local_size(),
        world_size=dist.get_world_size(),
        compute_device=device.type,
        hbm_cap=hbm_cap,
        ddr_cap=ddr_cap,
        intra_host_bw=intra_host_bw,
        inter_host_bw=inter_host_bw,
    )

    constraints = make_shard_constraints(
        configs,
        hbm_cap=hbm_cap,
        caching=caching,
        admission_threshold=admission_threshold,
    )
    # same usage as torchrec's EmbeddingEnumerator
    enumerator = DynamicEmbeddingEnumerator(topology=topology, constraints=constraints)

    # Almost the same usage as torchrec's EmbeddingShardingPlanner, except that
    # eb_configs must be passed in: dynamicemb needs the EmbeddingConfig info
    # to help plan.
    return DynamicEmbeddingShardingPlanner(
        eb_configs=configs,
        topology=topology,
        constraints=constraints,
        enumerator=enumerator,
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
    )


def collect_configs(model):
    """Return the ``EmbeddingConfig`` list held by the model's EmbeddingCollection."""
    ec = model.ec
    assert isinstance(ec, EmbeddingCollection)
    return list(ec.embedding_configs())


def apply_dmp_with_optimizer(model, pg, device: torch.device):
    """Shard the model's embedding tables and wrap it in DistributedModelParallel.

    The initialization flow mirrors stock torchrec:
        1. The model already carries an `EmbeddingCollection` holding the
           global table configuration.
        2. Sharders (`DynamicEmbeddingCollectionSharder` /
           `DynamicEmbeddingBagCollectionSharder`) plus a
           `DynamicEmbeddingShardingPlanner` produce a `ShardingPlan`.
        3. `DistributedModelParallel` applies the plan via the sharders and
           holds the resulting sharded module.

    Returns:
        The DMP-wrapped model and an Adam optimizer over its parameters.
    """
    optimizer_type = OptimType.ADAM
    lr = 1e-4

    caching = True
    prefetch_pipeline = True
    admission_threshold = 0

    # `DynamicEmbeddingCollectionSharder` inherits torchrec's
    # `EmbeddingCollectionSharder`; its main job is to return a
    # `ShardedDynamicEmbeddingCollection` — a customized lookup module backed
    # by a GPU-optimized scored hash table that uses both device and host
    # memory, supports score-based automatic eviction, and (with
    # `use_index_dedup=True`) provides a customized input distributor for
    # deduplication. The actual sharding happens during initialization of the
    # sharded module, driven by the parameters the sharder was built with.
    ebc_sharder = get_sharder(
        lr=lr,
        optimizer_type=optimizer_type,
        prefetch_pipeline=prefetch_pipeline,
        is_ebc=True,
    )
    eb_sharder = get_sharder(
        lr=lr,
        optimizer_type=optimizer_type,
        prefetch_pipeline=prefetch_pipeline,
        is_ebc=False,
    )

    sharders = [
        cast(ModuleSharder[torch.nn.Module], ebc_sharder),
        cast(ModuleSharder[torch.nn.Module], eb_sharder),
    ]

    # The planner generates a `ParameterSharding` (here, a
    # `DynamicEmbParameterSharding`) per table, describing how that
    # parameter is sharded.
    configs = collect_configs(model)
    planner = get_planner(
        device,
        configs=configs,
        caching=caching,
        admission_threshold=admission_threshold,
    )
    # Collective call: every rank participates and receives the same plan —
    # a dict mapping table name to ParameterSharding/DynamicEmbParameterSharding.
    plan: ShardingPlan = planner.collective_plan(model, sharders=sharders, pg=pg)

    # DMP implements the plan through the sharders; afterwards the model is
    # used for training and evaluation just like a plain `EmbeddingCollection`.
    model = DistributedModelParallel(
        module=model, device=device, sharders=sharders, plan=plan
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer


class ItemEmbeddingLookup(nn.Module):
    """A single-table item-id embedding lookup built on ``EmbeddingCollection``."""

    def __init__(self, embedding_dim: int, vocab_size: int, dynamic: bool) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self._feat_name = "item_id"

        # For the dynamic case the table is created on the meta device and is
        # materialized later by DistributedModelParallel during sharding.
        self.ec = EmbeddingCollection(
            tables=[
                EmbeddingConfig(
                    name="item_table",
                    embedding_dim=embedding_dim,
                    num_embeddings=vocab_size,
                    feature_names=[self._feat_name],
                )
            ],
            device=torch.device("meta") if dynamic else None,
        )

    @torch.cuda.nvtx.range("ItemEmbeddingLookup.forward")
    def forward(self, batch_item_ids: List[List[int]]):
        """Look up embeddings for a ragged batch of per-sample item-id lists."""
        device = next(self.ec.parameters()).device
        # Flatten the ragged batch into values + per-row lengths for the KJT.
        flat_ids = [item_id for row in batch_item_ids for item_id in row]
        row_lengths = [len(row) for row in batch_item_ids]
        kjt = KeyedJaggedTensor.from_lengths_sync(
            keys=[self._feat_name],
            values=torch.tensor(flat_ids, dtype=torch.int64, device=device),
            lengths=torch.tensor(row_lengths, dtype=torch.int64, device=device),
        )
        print(f"Input KeyedJaggedTensor: {kjt['item_id'].lengths()}, {kjt['item_id'].values()}")
        value = self.ec(kjt)[self._feat_name].to_dense()
        return value


def init_dist():
    """Initialize the NCCL process group and pin this rank to its local GPU.

    Returns the device string (e.g. ``"cuda:0"``) for the current rank.
    The short 10-second timeout makes collective hangs fail fast in the repro.
    """
    from datetime import timedelta

    torch.distributed.init_process_group(
        backend="nccl", timeout=timedelta(seconds=10)
    )

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.distributed.barrier()

    return f"cuda:{torch.cuda.current_device()}"


def main():
    """Repro entry point: rank 0 feeds real ids, every other rank an empty batch."""
    model = ItemEmbeddingLookup(embedding_dim=16, vocab_size=10000, dynamic=True)

    device = init_dist()
    pg = dist.group.WORLD

    print(model)
    model, optimizer = apply_dmp_with_optimizer(model, pg, torch.device(device))
    print(model)

    # Only rank 0 receives non-empty input; the other ranks each get four
    # empty rows, which is the condition that triggers the reported hang.
    if torch.distributed.get_rank() == 0:
        batch = [
            [1, 2, 3],
            [4, 5],
            [6],
            [7, 8, 9, 10],
        ]
    else:
        batch = [[], [], [], []]

    out = model(batch)
    print(f"Output embeddings: {out}")


# Script entry point: launched on each rank by torchrun.
if __name__ == "__main__":
    main()

Expected behavior
The forward pass should complete on every rank — a rank fed an empty batch should return an empty (zero-row) embedding tensor — instead of the whole process group hanging until the collective times out.

Environment details (please complete the following information):

  • I build the image using the docker/Dockerfile at each commit.
Click here to see environment details
 **git***
 commit 173bf25fb5d5cb58fd691cd4db3b6d1bfdaf3a0d (HEAD)
 Author: shijie liu <aleliu@nvidia.com>
 Date:   Fri Feb 6 19:38:27 2026 +0800

 Optimize dedup indices and segmented unique (#293)

 * add segmented_unique

 * replace dedup_indices with segmented_unique_cuda

 * use segmented_unique_cuda in lookup forward

 * fix some perf issue

 * fix overflow issue

 * clean code
 **git submodules***
 -9c197a9c558d1e8285c2e50c1974f0f102826f11 third_party/HierarchicalKV
 7d49e6c7e2f8896c47f586706e67e1fb215529dc third_party/cutlass (v3.5.0)

 ***OS Information***
 DISTRIB_ID=Ubuntu
 DISTRIB_RELEASE=24.04
 DISTRIB_CODENAME=noble
 DISTRIB_DESCRIPTION="Ubuntu 24.04.2 LTS"
 PRETTY_NAME="Ubuntu 24.04.2 LTS"
 NAME="Ubuntu"
 VERSION_ID="24.04"
 VERSION="24.04.2 LTS (Noble Numbat)"
 VERSION_CODENAME=noble
 ID=ubuntu
 ID_LIKE=debian
 HOME_URL="https://www.ubuntu.com/"
 SUPPORT_URL="https://help.ubuntu.com/"
 BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
 PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
 UBUNTU_CODENAME=noble
 LOGO=ubuntu-logo
 Linux 8a33676f-ac22-47e8-8605-4f0ce8a65049 4.18.0-240.el8.x86_64 #1 SMP Fri Sep 25 19:48:47 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux

 ***GPU Information***
 Fri Apr  3 06:13:31 2026
 +-----------------------------------------------------------------------------------------+
 | NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.9     |
 |-----------------------------------------+------------------------+----------------------+
 | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
 | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
 |                                         |                        |               MIG M. |
 |=========================================+========================+======================|
 |   0  NVIDIA L20                     On  |   00000000:0E:00.0 Off |                    0 |
 | N/A   35C    P0             79W /  350W |    3393MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   1  NVIDIA L20                     On  |   00000000:0F:00.0 Off |                    0 |
 | N/A   38C    P0             76W /  350W |    2235MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   2  NVIDIA L20                     On  |   00000000:10:00.0 Off |                    0 |
 | N/A   40C    P0             77W /  350W |    2235MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   3  NVIDIA L20                     On  |   00000000:12:00.0 Off |                    0 |
 | N/A   36C    P0             77W /  350W |    2235MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   4  NVIDIA L20                     On  |   00000000:87:00.0 Off |                    0 |
 | N/A   27C    P8             34W /  350W |       0MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   5  NVIDIA L20                     On  |   00000000:88:00.0 Off |                    0 |
 | N/A   28C    P8             34W /  350W |       0MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   6  NVIDIA L20                     On  |   00000000:89:00.0 Off |                    0 |
 | N/A   31C    P8             34W /  350W |       0MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+
 |   7  NVIDIA L20                     On  |   00000000:8A:00.0 Off |                    0 |
 | N/A   30C    P8             37W /  350W |       0MiB /  46068MiB |      0%      Default |
 |                                         |                        |                  N/A |
 +-----------------------------------------+------------------------+----------------------+

 +-----------------------------------------------------------------------------------------+
 | Processes:                                                                              |
 |  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
 |        ID   ID                                                               Usage      |
 |=========================================================================================|
 +-----------------------------------------------------------------------------------------+

 ***CPU***
 Architecture:                    x86_64
 CPU op-mode(s):                  32-bit, 64-bit
 Address sizes:                   52 bits physical, 57 bits virtual
 Byte Order:                      Little Endian
 CPU(s):                          128
 On-line CPU(s) list:             0-127
 Vendor ID:                       GenuineIntel
 BIOS Vendor ID:                  Intel(R) Corporation
 Model name:                      Intel(R) Xeon(R) Gold 6430
 BIOS Model name:                 Intel(R) Xeon(R) Gold 6430  CPU @ 2.1GHz
 BIOS CPU family:                 179
 CPU family:                      6
 Model:                           143
 Thread(s) per core:              2
 Core(s) per socket:              32
 Socket(s):                       2
 Stepping:                        8
 CPU(s) scaling MHz:              100%
 CPU max MHz:                     2100.0000
 CPU min MHz:                     800.0000
 BogoMIPS:                        4200.00
 Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 wbnoinvd dtherm arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid cldemote movdiri movdir64b md_clear pconfig flush_l1d arch_capabilities
 Virtualization:                  VT-x
 L1d cache:                       3 MiB (64 instances)
 L1i cache:                       2 MiB (64 instances)
 L2 cache:                        128 MiB (64 instances)
 L3 cache:                        120 MiB (2 instances)
 NUMA node(s):                    2
 NUMA node0 CPU(s):               0-31,64-95
 NUMA node1 CPU(s):               32-63,96-127
 Vulnerability Itlb multihit:     Not affected
 Vulnerability L1tf:              Not affected
 Vulnerability Mds:               Not affected
 Vulnerability Meltdown:          Not affected
 Vulnerability Spec store bypass: Vulnerable
 Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
 Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled
 Vulnerability Srbds:             Not affected
 Vulnerability Tsx async abort:   Not affected

 ***CMake***
 /usr/local/bin/cmake
 cmake version 3.31.6

 CMake suite maintained and supported by Kitware (kitware.com/cmake).

 ***g++***
 /usr/bin/g++
 g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
 Copyright (C) 2023 Free Software Foundation, Inc.
 This is free software; see the source for copying conditions.  There is NO
 warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.


 ***nvcc***
 /usr/local/cuda/bin/nvcc
 nvcc: NVIDIA (R) Cuda compiler driver
 Copyright (c) 2005-2025 NVIDIA Corporation
 Built on Tue_May_27_02:21:03_PDT_2025
 Cuda compilation tools, release 12.9, V12.9.86
 Build cuda_12.9.r12.9/compiler.36037853_0

 ***Python***
 /usr/bin/python
 Python 3.12.3

 ***Environment Variables***
 PATH                            : /root/.nvm/versions/node/v25.8.1/bin:/root/.local/bin:/root/.local/bin:/usr/local/python3.7.4/bin:/usr/local/lib/python3.12/dist-packages/torch_tensorrt/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/mpi/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/ucx/bin:/opt/amazon/efa/bin:/opt/tensorrt/bin
 LD_LIBRARY_PATH                 : /usr/local/python3.7.4/lib:/usr/local/lib/python3.12/dist-packages/torch/lib:/usr/local/lib/python3.12/dist-packages/torch_tensorrt/lib:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
 NUMBAPRO_NVVM                   :
 NUMBAPRO_LIBDEVICE              :
 CONDA_PREFIX                    :
 PYTHON_PATH                     :

 conda not found
 ***pip packages***
 /usr/local/bin/pip
 Package                    Version                       Editable project location
 -------------------------- ----------------------------- ---------------------------
 absl-py                    2.3.0
 aiohappyeyeballs           2.6.1
 aiohttp                    3.12.7
 aiosignal                  1.3.2
 annotated-types            0.7.0
 anyio                      4.9.0
 apex                       0.1
 argon2-cffi                25.1.0
 argon2-cffi-bindings       21.2.0
 arrow                      1.3.0
 asciitree                  0.3.3
 asttokens                  3.0.0
 astunparse                 1.6.3
 async-lru                  2.0.5
 attrs                      25.3.0
 audioread                  3.0.1
 babel                      2.17.0
 beautifulsoup4             4.13.4
 black                      25.1.0
 bleach                     6.2.0
 blis                       0.7.11
 cachetools                 6.0.0
 catalogue                  2.0.10
 certifi                    2025.4.26
 cffi                       1.17.1
 cfgv                       3.5.0
 charset-normalizer         3.4.2
 click                      8.2.1
 cloudpathlib               0.21.1
 cloudpickle                3.1.1
 cmake                      3.31.6
 comm                       0.2.2
 confection                 0.1.5
 contourpy                  1.3.2
 cuda-bindings              12.9.0
 cuda-python                12.9.0
 cudf                       25.4.0
 cudf-polars                25.4.0
 cugraph                    25.4.0
 cugraph-service-client     25.4.0
 cugraph-service-server     25.4.0
 cuml                       25.4.0
 cupy-cuda12x               13.3.0
 cuvs                       25.4.0
 cycler                     0.12.1
 cymem                      2.0.11
 Cython                     3.1.1
 dask                       2025.2.0
 dask-cuda                  25.4.0
 dask-cudf                  25.4.0
 debugpy                    1.8.14
 decorator                  5.2.1
 defusedxml                 0.7.1
 dill                       0.4.0
 distlib                    0.4.0
 distributed                2025.2.0
 distributed-ucxx           0.43.0
 distro                     1.9.0
 dm-tree                    0.1.9
 docker                     7.1.0
 docstring_parser           0.17.0
 dynamicemb                 0.0.1+173bf25
 einops                     0.8.1
 execnet                    2.1.1
 executing                  2.2.0
 expecttest                 0.3.0
 fasteners                  0.19
 fastjsonschema             2.21.1
 fastrlock                  0.8.3
 fbgemm_gpu_nightly         2026.4.3
 filelock                   3.25.2
 flash_attn                 2.7.4.post1
 fonttools                  4.58.1
 fqdn                       1.5.1
 frozenlist                 1.6.0
 fsspec                     2025.5.1
 gast                       0.6.0
 gin-config                 0.5.0
 grpcio                     1.62.1
 h11                        0.16.0
 hstu_attn                  0.1.0+173bf25.cu12.9
 hstu_cuda_ops              0.0.0
 hstu-hopper                0.1.1+173bf25.cu12.9
 httpcore                   1.0.9
 httpx                      0.28.1
 hypothesis                 6.130.8
 identify                   2.6.18
 idna                       3.10
 importlib_metadata         8.7.0
 iniconfig                  2.1.0
 intel-openmp               2021.4.0
 iopath                     0.1.10
 ipykernel                  6.29.5
 ipython                    9.3.0
 ipython_pygments_lexers    1.1.1
 isoduration                20.11.0
 isort                      6.0.1
 jedi                       0.19.2
 Jinja2                     3.1.6
 joblib                     1.5.1
 json5                      0.12.0
 jsonpointer                3.0.0
 jsonschema                 4.24.0
 jsonschema-specifications  2025.4.1
 jupyter_client             8.6.3
 jupyter_core               5.8.1
 jupyter-events             0.12.0
 jupyter-lsp                2.2.5
 jupyter_server             2.16.0
 jupyter_server_terminals   0.5.3
 jupyterlab                 4.4.3
 jupyterlab_code_formatter  3.0.2
 jupyterlab_pygments        0.3.0
 jupyterlab_server          2.27.3
 jupyterlab_tensorboard_pro 4.0.0
 jupytext                   1.17.2
 kiwisolver                 1.4.8
 kvikio                     25.4.0
 langcodes                  3.5.0
 language_data              1.3.0
 lazy_loader                0.4
 libcudf                    25.4.0
 libcugraph                 25.4.0
 libcuml                    25.4.0
 libcuvs                    25.4.0
 libkvikio                  25.4.0
 libraft                    25.4.0
 librmm                     25.4.0
 librmm-cu12                25.4.0
 librosa                    0.11.0
 libucxx                    0.43.0
 lightning-thunder          0.2.3.dev0
 lightning-utilities        0.14.3
 lintrunner                 0.12.7
 llvmlite                   0.42.0
 locket                     1.0.0
 looseversion               1.3.0
 marisa-trie                1.2.1
 Markdown                   3.8
 markdown-it-py             3.0.0
 MarkupSafe                 3.0.2
 matplotlib                 3.10.3
 matplotlib-inline          0.1.7
 mdit-py-plugins            0.4.2
 mdurl                      0.1.2
 megatron-core              0.12.1                        /workspace/deps/megatron-lm
 mistune                    3.1.3
 mkl                        2021.1.1
 mkl-devel                  2021.1.1
 mkl-include                2021.1.1
 mock                       5.2.0
 mpmath                     1.3.0
 msgpack                    1.1.0
 multidict                  6.4.4
 murmurhash                 1.0.13
 mypy_extensions            1.1.0
 nbclient                   0.10.2
 nbconvert                  7.16.6
 nbformat                   5.10.4
 nest-asyncio               1.6.0
 networkx                   3.5
 ninja                      1.11.1.4
 nodeenv                    1.10.0
 notebook                   7.4.3
 notebook_shim              0.2.4
 numba                      0.59.1
 numba-cuda                 0.4.0
 numcodecs                  0.13.1
 numpy                      1.26.4
 nvdlfw_inspect             0.1.0
 nvfuser                    0.2.27a0+9bf5aca
 nvidia-cudnn-frontend      1.12.0
 nvidia-cutlass-dsl         4.3.0
 nvidia-dali-cuda120        1.50.0
 nvidia-ml-py               12.575.51
 nvidia-modelopt            0.29.0
 nvidia-modelopt-core       0.29.0
 nvidia-nvcomp-cu12         4.2.0.14
 nvidia-nvimgcodec-cu12     0.5.0.13
 nvidia-nvjpeg-cu12         12.4.0.16
 nvidia-nvjpeg2k-cu12       0.8.1.40
 nvidia-nvtiff-cu12         0.5.0.67
 nvidia-resiliency-ext      0.4.0
 nvtx                       0.2.11
 nx-cugraph                 25.4.0
 onnx                       1.17.0
 opt_einsum                 3.4.0
 optree                     0.16.0
 ordered-set                4.1.0
 orjson                     3.11.8
 overrides                  7.7.0
 packaging                  23.2
 pandas                     2.2.3
 pandocfilters              1.5.1
 parso                      0.8.4
 partd                      1.4.2
 pathspec                   0.12.1
 pexpect                    4.9.0
 pillow                     11.2.1
 pip                        25.1.1
 platformdirs               4.3.8
 pluggy                     1.6.0
 ply                        3.11
 polars                     1.25.2
 polygraphy                 0.49.20
 pooch                      1.8.2
 portalocker                3.2.0
 pre_commit                 4.5.1
 preshed                    3.0.10
 prometheus_client          0.22.1
 prompt_toolkit             3.0.51
 propcache                  0.3.1
 protobuf                   4.24.4
 psutil                     7.0.0
 ptyprocess                 0.7.0
 PuLP                       3.2.1
 pure_eval                  0.2.3
 pyarrow                    19.0.1
 pybind11                   2.13.6
 pybind11_global            2.13.6
 pycocotools                2.0+nv0.8.1
 pycparser                  2.22
 pydantic                   2.11.5
 pydantic_core              2.33.2
 Pygments                   2.19.1
 pylibcudf                  25.4.0
 pylibcugraph               25.4.0
 pylibcugraphops            25.4.0
 pylibraft                  25.4.0
 pylibwholegraph            25.4.0
 pynvjitlink                0.3.0
 pynvml                     12.0.0
 pyparsing                  3.2.3
 pyre-extensions            0.0.32
 pytest                     8.1.1
 pytest-flakefinder         1.1.0
 pytest-rerunfailures       15.1
 pytest-shard               0.1.2
 pytest-xdist               3.7.0
 python-dateutil            2.9.0.post0
 python-discovery           1.2.1
 python-hostlist            2.2.1
 python-json-logger         3.3.0
 pytorch-triton             3.3.0+git96316ce52.nvinternal
 pytz                       2023.4
 pyvers                     0.2.2
 PyYAML                     6.0.2
 pyzmq                      26.4.0
 raft-dask                  25.4.0
 rapids-dask-dependency     25.4.0a0
 rapids-logger              0.1.18
 referencing                0.36.2
 regex                      2024.11.6
 requests                   2.32.3
 rfc3339-validator          0.1.4
 rfc3986-validator          0.1.1
 rich                       14.0.0
 rmm                        25.4.0
 rpds-py                    0.25.1
 safetensors                0.5.3
 scikit-build               0.19.0
 scikit-learn               1.6.1
 scipy                      1.15.3
 Send2Trash                 1.8.3
 setuptools                 78.1.1
 setuptools-git-versioning  3.0.1
 shellingham                1.5.4
 six                        1.16.0
 smart-open                 7.1.0
 sniffio                    1.3.1
 sortedcontainers           2.4.0
 soundfile                  0.13.1
 soupsieve                  2.7
 soxr                       0.5.0.post1
 spacy                      3.7.5
 spacy-legacy               3.0.12
 spacy-loggers              1.0.5
 srsly                      2.5.1
 stack-data                 0.6.3
 sympy                      1.14.0
 tabulate                   0.9.0
 tbb                        2021.13.1
 tblib                      3.1.0
 tensorboard                2.16.2
 tensorboard-data-server    0.7.2
 tensordict                 0.11.0
 tensorrt                   10.11.0.33
 terminado                  0.18.1
 thinc                      8.2.5
 threadpoolctl              3.6.0
 thriftpy2                  0.5.2
 tinycss2                   1.4.0
 toolz                      1.0.0
 torch                      2.8.0a0+5228986c39.nv25.6
 torch_tensorrt             2.8.0a0
 torchao                    0.11.0+git
 torchmetrics               1.0.3
 torchprofile               0.0.4
 torchrec                   1.2.0+440b1c6
 torchvision                0.22.0a0+95f10a4e
 torchx                     0.7.0
 tornado                    6.5.1
 tqdm                       4.67.1
 traitlets                  5.14.3
 transformer_engine         2.4.0+3cd6870
 treelite                   4.4.1
 triton                     3.6.0
 typer                      0.16.0
 types-dataclasses          0.6.6
 types-python-dateutil      2.9.0.20250516
 typing_extensions          4.14.0
 typing-inspect             0.9.0
 typing-inspection          0.4.1
 tzdata                     2025.2
 ucx-py                     0.43.0
 ucxx                       0.43.0
 uri-template               1.3.0
 urllib3                    1.26.20
 virtualenv                 21.2.0
 wasabi                     1.1.3
 wcwidth                    0.2.13
 weasel                     0.4.1
 webcolors                  24.11.1
 webencodings               0.5.1
 websocket-client           1.8.0
 Werkzeug                   3.1.3
 wheel                      0.45.1
 wrapt                      1.17.2
 xdoctest                   1.0.2
 xgboost                    2.1.4
 yarl                       1.20.0
 zarr                       2.18.7
 zict                       3.0.0
 zipp                       3.22.0

Additional context
Add any other context about the problem here.


By submitting this issue, you agree to follow our code of conduct and our contributing guidelines.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions