
[Draft] Feature: inference embedding with C++ export #324

Open
geoffreyQiu wants to merge 11 commits into NVIDIA:main from geoffreyQiu:fea-inference_emb_export

Conversation

@geoffreyQiu
Collaborator

Add C++ export for inference embedding based on:

  • ScoredHashTable ops from dynamicemb
  • Exportable module based on LinearBucketTable.
  • LinearUVMEmbedding from NVEmbedding.

@greptile-apps
Contributor

greptile-apps bot commented Mar 12, 2026

Greptile Summary

This PR adds a C++-exportable inference embedding pipeline on top of dynamicemb, enabling the BatchedDynamicEmbeddingTablesV2 training path to be serialised and loaded as an AOTI artifact in C++. The main additions are: three new INFERENCE_EMB custom CUDA ops (table_lookup, expand_table_ids, get_table_range) with PyTorch dispatch bindings and fake/meta kernels for torch.export, a new InferenceLinearBucketTable and InferenceEmbeddingTable nn.Module wrapper that integrates the custom ops with an NVE LinearUVMEmbedding backend, an end-to-end Python export demo, and two C++ runner binaries that load and exercise the packaged AOTI artifact. The broader dynamicemb layer is also significantly refactored (flat-table storage layout, new LinearBucketTable, scored hash-table rewrite).

Key issues found:

  • Debug print in fake class: FakeLinearUVMEmbedding.__init__ contains a loud debug message ("FAKE LINEAR UVM EMBEDDING CTOR MOOOOOOOOOOOO") that will print on every torch.export trace pass.
  • break-outside-try in _load_nve_torch_bindings: The loop exits after the first existing path regardless of whether torch.classes.load_library succeeded, preventing fallback to subsequent search paths on load failure.
  • Unsafe -1 indices to NVE lookup: global_indices can hold -1 for hash-table misses, and those are passed verbatim to nve_embedding_.lookup() / lookup_with_pooling(). Whether the NVE C++ library handles -1 gracefully is undocumented and untested by contract.
  • Missing pooling_offsets guard: When pooling_mode_ ∈ {1, 2}, pooling_offsets=None is silently forwarded to the NVE C++ call, giving an opaque crash instead of a clear Python error.

Confidence Score: 2/5

  • Not safe to merge as-is; multiple logic bugs in the new inference embedding module can cause silent incorrect results or hard crashes at runtime.
  • The core new file (inference_embedding_impl.py) has four bugs: a debug print that will fire on every export trace, a library loader that stops searching after the first existing-but-unloadable path, unguarded -1 indices passed to an opaque C++ embedding lookup, and no validation that pooling_offsets is provided when a pooling mode is selected. Either of the last two can produce incorrect embeddings or a process crash in production inference.
  • Pay close attention to examples/hstu/inference/inference_embedding_impl.py — it is the single most changed new file and contains all four flagged issues.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `examples/hstu/inference/inference_embedding_impl.py` | New core module for exportable inference embedding: contains a debug print in the fake class constructor, a break-outside-try bug in `_load_nve_torch_bindings`, unsafe -1 index passing to NVE lookup for unfound keys, and a missing guard for `pooling_offsets=None` when pooling is enabled. |
| `examples/hstu/inference/test_export_demo.py` | New end-to-end test for non-pooled and sum-pooled export; logic is straightforward. Previous issues (dead `table_ranges`, `torch.index_select` with -1) have been removed/replaced in the reworked implementation. |
| `corelib/dynamicemb/src/table_operation/lookup_torch_binding.cu` | New CUDA/CPU dispatch binding for `INFERENCE_EMB::table_lookup`; CUDA precondition checks are thorough, CPU stub correctly rejects with `TORCH_CHECK(false)`. No major issues found. |
| `corelib/dynamicemb/dynamicemb/index_range_meta.py` | New file registering fake/meta kernels for `get_table_range` and `expand_table_ids` ops; shapes and dtypes are consistent with the real CUDA kernels. |
| `corelib/dynamicemb/dynamicemb/lookup_meta.py` | New file registering the fake/meta kernel for `INFERENCE_EMB::table_lookup`; output tensor shapes and dtypes correctly match the real kernel's contract. |
| `examples/hstu/inference/aoti_demo/inference_e2e.cpp` | New C++ demo loading the AOTI artifact; library loading, tensor loading, and comparison logic are clean. Gracefully handles missing CUDA and argument errors. |
| `corelib/dynamicemb/dynamicemb/scored_hashtable.py` | Major refactor of the scored hash-table Python layer; adds `LinearBucketTable`, `ScoredHashTable`, and related utilities supporting the new inference path. Extensive changes but structurally consistent with the rest of the module. |
| `corelib/dynamicemb/dynamicemb/key_value_table.py` | Large refactor switching from combined flat table storage to separate contiguous/emb/value flat-table layouts; changes are self-consistent and imports updated accordingly. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python (InferenceEmbeddingTable.forward)
    participant ExpandOp as INFERENCE_EMB::expand_table_ids (CUDA)
    participant LookupOp as INFERENCE_EMB::table_lookup (CUDA)
    participant NVE as nve::LinearUVMEmbedding (C++)

    Py->>ExpandOp: offsets, num_tables, num_elements
    ExpandOp-->>Py: table_ids (N,)

    Py->>LookupOp: table_storage, table_bucket_offsets, keys, table_ids
    LookupOp-->>Py: scores (N,), founds (N,), table_indices (N,)

    Py->>Py: global_indices = where(founds, table_indices + table_offsets[table_ids], table_indices)

    alt pooling_mode == -1
        Py->>NVE: lookup(global_indices)
        NVE-->>Py: embeddings (N, D)
    else pooling_mode == 1 or 2
        Py->>NVE: lookup_with_pooling(global_indices, pooling_offsets, None, mode)
        NVE-->>Py: pooled_embeddings (B, D)
    end
```

Last reviewed commit: "Add pooling mode sum..."
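The index math in the diagram — per-key table ids expanded from CSR offsets, then a miss-preserving shift into global storage — can be sketched in plain Python. This is a simplified illustration only: the real ops run on CUDA, and the `table_offsets` values here are made up.

```python
def expand_table_ids(offsets):
    """CSR segment t covers elements offsets[t]:offsets[t+1]; emit t once per element."""
    table_ids = []
    for t in range(len(offsets) - 1):
        table_ids += [t] * (offsets[t + 1] - offsets[t])
    return table_ids

offsets = [0, 3, 5]                    # 2 tables: 3 keys for table 0, 2 for table 1
table_ids = expand_table_ids(offsets)  # [0, 0, 0, 1, 1]

table_indices = [4, -1, 2, 0, 1]       # -1 marks a hash-table miss
table_offsets = [0, 100]               # per-table base offsets (toy values)

# where(founds, table_indices + table_offsets[table_ids], table_indices):
# found keys shift into global storage; misses keep their -1 sentinel
global_indices = [
    ti + table_offsets[tid] if ti >= 0 else ti
    for ti, tid in zip(table_indices, table_ids)
]
```

Note that the -1 sentinels survive into `global_indices`, which is exactly why the review flags what happens when they reach the NVE lookup.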

Comment on lines +280 to +290

```python
# Step 2: Expand table IDs from offsets
# expand_table_ids(offsets, table_offsets, num_tables, local_batch_size, num_elements)
# Returns (num_elements,) int64 table_ids indicating which table each element belongs to
num_features = offsets.shape[0] - 1
num_elements = indices.shape[0]

# Prepare table_offsets_in_feature: where in feature space each table starts
table_offsets = torch.arange(
    num_features + 1, dtype=torch.int64, device=self.device
)
```

Invalid indices from unfound keys passed to torch.index_select

When the hash table lookup does not find a key, the CUDA kernel writes -1 into table_indices for that position (as confirmed by the kernel code in kernels.cuh which sets indices[i] = -1 on miss). The forward then calls:

```python
embeddings = torch.index_select(self.linear_mem_table, 0, table_indices)
```

torch.index_select does not support negative indices — passing -1 will raise a RuntimeError: index -1 is out of bounds for dimension 0 with size N at runtime whenever any key is missing from the table.

The torch.where block below this line was written exactly to handle this case (zeroing out embeddings for unfound items), but it is commented out. At minimum, the indices should be clamped before the gather, and then the where-mask applied:

```python
safe_indices = table_indices.clamp(min=0)
embeddings = torch.index_select(self.linear_mem_table, 0, safe_indices)
embeddings = torch.where(
    founds.unsqueeze(-1),
    embeddings,
    torch.zeros_like(embeddings),
)
```

Comment on lines +255 to +262
```python
        )
        self.register_buffer("linear_mem_table", linear_mem_table)

    def forward(
        self,
        indices: torch.Tensor,  # (batch_size,) indices to lookup
        offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
    ) -> torch.Tensor:
```

Computed table_ranges result is never used

table_ranges is computed by calling torch.ops.INFERENCE_EMB.get_table_range(...), but is never referenced again in the forward() method — it's dead code. This is either an incomplete implementation (the range was intended to be used for per-feature routing) or the variable can be removed entirely. Leaving it in place will produce a graph node in the exported model that wastes computation.

Suggested change

```python
        )
        self.register_buffer("linear_mem_table", linear_mem_table)

    def forward(
        self,
        indices: torch.Tensor,  # (batch_size,) indices to lookup
        offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
    ) -> torch.Tensor:
        # Step 2: Expand table IDs from offsets
```

Comment on lines +46 to +55
```python
    if os.path.exists(_path):
        try:
            torch.ops.load_library(_path)
            print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
            _ops_loaded = True
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
        break  # stop after first found path, whether load succeeded or not

if not _ops_loaded:
```

Library loading stops at first found path even on failure

The break at line 55 exits the search loop after the first path where os.path.exists() returns True, regardless of whether torch.ops.load_library succeeds or fails. If the file exists but cannot be loaded (e.g., missing CUDA runtime symbols, wrong ABI), _ops_loaded remains False and none of the other fallback paths are tried.

Consider only break-ing when the load actually succeeds:

```python
for _path in _SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.ops.load_library(_path)
            print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
            _ops_loaded = True
            break  # stop only on success
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
```
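The retry-until-success search can be exercised in isolation from torch. In this sketch, `exists` and `load` are stand-ins for `os.path.exists` and `torch.ops.load_library`; the paths and the `fake_load` failure are invented for the test:

```python
def load_first_working(paths, exists, load):
    """Try each candidate path in order; stop only after a successful load."""
    for path in paths:
        if not exists(path):
            continue
        try:
            load(path)
            return path        # break only on success
        except Exception:
            continue           # existing-but-unloadable: keep searching
    return None

# First path exists but fails to load (e.g. wrong ABI); loader must fall through.
def fake_load(path):
    if path == "/a/bad.so":
        raise OSError("wrong ABI")

loaded = load_first_working(
    ["/a/bad.so", "/missing.so", "/b/good.so"],
    exists=lambda p: p != "/missing.so",
    load=fake_load,
)
```

With the buggy pattern (break outside the try), this call would return `None` and the good fallback path would never be tried.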

Comment on lines +255 to +260
```python
        )
        self.register_buffer("linear_mem_table", linear_mem_table)

    def forward(
        self,
        indices: torch.Tensor,  # (batch_size,) indices to lookup
```

Misleading comment about get_table_range output shape

The inline comment says the output shape is (num_features, 2), but the actual C++ implementation (index_calculation.cu) returns at::empty_like(feature_offsets) — a 1D tensor of shape (num_tables + 1,), the same shape as feature_offsets. This is also confirmed by the fake kernel in index_range_meta.py which returns feature_offsets.new_empty(feature_offsets.shape).

Suggested change

```python
        )
        self.register_buffer("linear_mem_table", linear_mem_table)

    def forward(
        self,
        indices: torch.Tensor,  # (batch_size,) indices to lookup
        table_ranges = torch.ops.INFERENCE_EMB.get_table_range(
            offsets, self.feature_offsets
        )  # (num_tables + 1,) – same shape as feature_offsets
```

Comment on lines 57 to +62

```diff
-def find_source_files(directory, extension_pattern, exclude_dirs=[]):
+def find_source_files(
+    directory,
+    extension_pattern,
+    exclude_dirs=[],
+    exclude_files=[],
```

Mutable default argument introduced for exclude_files

exclude_files=[] is a mutable default argument, a well-known Python footgun: the same list object is shared across all calls that rely on the default. While harmless in this particular setup.py usage (called at install time, never mutated), adopting None with an explicit fallback is the conventional safe pattern:

Suggested change

```diff
 def find_source_files(
     directory,
     extension_pattern,
-    exclude_dirs=[],
-    exclude_files=[],
+    exclude_dirs=None,
+    exclude_files=None,
 ):
+    if exclude_dirs is None:
+        exclude_dirs = []
+    if exclude_files is None:
+        exclude_files = []
```
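The footgun itself is easy to demonstrate standalone; the `"setup.py"` value appended here is purely illustrative:

```python
def bad(exclude_files=[]):
    # The default list object is created once at def time and shared by every call.
    exclude_files.append("setup.py")
    return exclude_files

def good(exclude_files=None):
    if exclude_files is None:
        exclude_files = []   # fresh list on every call
    exclude_files.append("setup.py")
    return exclude_files

first = bad()
second = bad()       # same shared list, now grown twice
fresh_a = good()
fresh_b = good()     # independent single-element lists
```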

Comment on lines +240 to +253
```python
    offsets = [0]
    prev = feature_table_map[0]
    for i, tid in enumerate(feature_table_map[1:], start=1):
        if tid != prev:
            offsets.append(i)
            prev = tid
    offsets.append(len(feature_table_map))
    return offsets


class InferenceLinearBucketTable(torch.nn.Module):
    """Simple exportable hash table wrapper for inference lookup using custom op.

    This is a minimal demo version that focuses on lookup-only, non-pooled inference.
```

P1 _derive_grouped_offsets silently produces wrong offsets for unsorted input

The function only records transitions between consecutive values. For a sorted input like [0, 0, 1, 2] this produces the correct [0, 2, 3, 4]. But if a caller ever passes an unsorted or interleaved map — e.g. [0, 1, 0] representing two features for table 0 and one for table 1 — the function produces [0, 1, 2, 3] (three single-item groups) instead of the semantically correct result, and no error is raised.

InferenceEmbeddingTable.__init__ only validates that each table-id is in [0, num_tables) but does not check that the map is sorted. Any downstream code that depends on contiguous grouping will silently produce incorrect embeddings.

Add a validation guard before computing offsets:

```python
if feature_table_map != sorted(feature_table_map):
    raise ValueError(
        "feature_table_map must be sorted (features for the same table must be contiguous). "
        f"Got: {feature_table_map}"
    )
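A standalone sketch of the helper with the proposed guard, reproducing both the correct sorted case and the silent-wrong unsorted case from the review (the function name `derive_grouped_offsets` is used here for illustration):

```python
def derive_grouped_offsets(feature_table_map):
    """Run-length offsets of consecutive equal table ids, with a sorted-input guard."""
    if feature_table_map != sorted(feature_table_map):
        raise ValueError(
            "feature_table_map must be sorted: features for the same table "
            "must be contiguous"
        )
    offsets = [0]
    prev = feature_table_map[0]
    for i, tid in enumerate(feature_table_map[1:], start=1):
        if tid != prev:
            offsets.append(i)
            prev = tid
    offsets.append(len(feature_table_map))
    return offsets

sorted_result = derive_grouped_offsets([0, 0, 1, 2])  # correct grouping

# Interleaved map: without the guard this returns [0, 1, 2, 3] silently.
try:
    derive_grouped_offsets([0, 1, 0])
    rejected = False
except ValueError:
    rejected = True
```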

Comment on lines +185 to +194
```python
for _path in _NVE_TORCH_SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.classes.load_library(_path)
            print(f"[INFO] Loaded libnve_torch.so from {_path}")
            _register_nve_fake_class()
            _nve_torch_loaded = True
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
        break
```

P1 break fires even when library load fails in _load_nve_torch_bindings

The break at line 194 sits outside the try/except block. As a result, as soon as a path exists the loop exits—regardless of whether torch.classes.load_library succeeded or threw. If loading fails, _nve_torch_loaded stays False, the remaining fallback paths are never tried, and because _nve_torch_load_attempted is already True any future call returns False immediately.

Suggested change

```diff
 for _path in _NVE_TORCH_SEARCH_PATHS:
     if os.path.exists(_path):
         try:
             torch.classes.load_library(_path)
             print(f"[INFO] Loaded libnve_torch.so from {_path}")
             _register_nve_fake_class()
             _nve_torch_loaded = True
+            break  # only break on success
         except Exception as _e:
             print(f"[WARN] Failed to load {_path}: {_e}")
-        break
```

Comment on lines +642 to +665
```python
def forward(
    self,
    keys: torch.Tensor,
    offsets: torch.Tensor,
    pooling_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run embedding lookup with optional pooling.

    Args:
        keys: (N,) int64 – flat lookup keys; may span multiple
            tables and pooling bags.
        offsets: (T*B+1,) int64 – CSR boundaries that map segments
            of ``keys`` to per-table feature slots; used to
            derive a per-key table id via
            ``INFERENCE_EMB::expand_table_ids``.
        pooling_offsets: (B+1,) int64 – CSR boundaries that map segments of
            ``keys`` to pooling bags; required when
            ``self.pooling_mode_ >= 0``, otherwise unused.

    Returns:
        - ``pooling_mode_ == -1``: ``(N, D)`` float tensor of per-key
          embeddings.
        - ``pooling_mode_ == 1 or 2``: ``(B, D)`` float tensor of pooled
          embeddings, where ``B = pooling_offsets.size(0) - 1``.
```
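The CSR pooling semantics the docstring describes can be sketched in plain Python for sum pooling (`pooling_mode == 1`); the embedding values and bag boundaries below are toy data:

```python
embeddings = [[1., 0.], [2., 0.], [3., 0.], [4., 0.], [5., 0.]]  # (N=5, D=2)
pooling_offsets = [0, 2, 5]   # B=2 bags: keys 0..1 in bag 0, keys 2..4 in bag 1

def sum_pool(rows, offs):
    """Sum-pool rows over CSR segments offs[b]:offs[b+1], yielding (B, D)."""
    pooled = []
    for b in range(len(offs) - 1):
        seg = rows[offs[b]:offs[b + 1]]
        pooled.append([sum(col) for col in zip(*seg)])
    return pooled

pooled = sum_pool(embeddings, pooling_offsets)
```

If `pooling_offsets` were `None` here, the slice would fail with a clear Python `TypeError` — whereas passing `None` straight into the C++ `lookup_with_pooling` gives no such safety net, which is the point of the guard below.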

P1 pooling_offsets=None silently passed to lookup_with_pooling when pooling is enabled

pooling_offsets defaults to None in the signature. For pooling_mode_ == 1 or 2, the pooled branch unconditionally passes pooling_offsets straight into self.nve_embedding_.lookup_with_pooling(...). If a caller forgets this argument (or passes None explicitly), the C++ library will receive a null/undefined tensor, producing a crash or undefined behaviour with no helpful Python-level error message.

Add a guard at the start of the pooled branch:

```python
if self.pooling_mode_ >= 0:
    if pooling_offsets is None:
        raise ValueError(
            "pooling_offsets must be provided when pooling_mode is 1 (sum) or 2 (mean)"
        )
    return self.nve_embedding_.lookup_with_pooling(...)
```

@shijieliu shijieliu mentioned this pull request Mar 31, 2026
3 tasks
@shijieliu
Collaborator

  1. add doc: guidance and benchmark

```python
    return score_out, founds, indices


class InferenceEmbeddingTable(torch.nn.Module):
```

do we want to move this into dynamicemb?


this needs to be added in CI


do we have a CPP aot inductor test? Or we want to add this demo in CI?
