[Draft] Feature: inference embedding with C++ export #324
geoffreyQiu wants to merge 11 commits into NVIDIA:main from
Conversation
Greptile Summary

This PR adds a C++-exportable inference embedding pipeline on top of
Key issues found:

Confidence Score: 2/5

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python (InferenceEmbeddingTable.forward)
    participant ExpandOp as INFERENCE_EMB::expand_table_ids (CUDA)
    participant LookupOp as INFERENCE_EMB::table_lookup (CUDA)
    participant NVE as nve::LinearUVMEmbedding (C++)
    Py->>ExpandOp: offsets, num_tables, num_elements
    ExpandOp-->>Py: table_ids (N,)
    Py->>LookupOp: table_storage, table_bucket_offsets, keys, table_ids
    LookupOp-->>Py: scores (N,), founds (N,), table_indices (N,)
    Py->>Py: global_indices = where(founds, table_indices + table_offsets[table_ids], table_indices)
    alt pooling_mode == -1
        Py->>NVE: lookup(global_indices)
        NVE-->>Py: embeddings (N, D)
    else pooling_mode == 1 or 2
        Py->>NVE: lookup_with_pooling(global_indices, pooling_offsets, None, mode)
        NVE-->>Py: pooled_embeddings (B, D)
    end
```
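The index-merge step in the diagram can be sketched in plain Python. This is an illustrative stand-in, not the PR's code; the function name and list-based types are assumptions, but the logic mirrors `where(founds, table_indices + table_offsets[table_ids], table_indices)`:

```python
def merge_global_indices(founds, table_indices, table_ids, table_offsets):
    """Compute per-key global row ids, mirroring the where(...) step above:
    found keys get their table-local index shifted by the owning table's
    base offset; misses keep the raw sentinel index (-1)."""
    return [
        ti + table_offsets[tid] if found else ti
        for found, ti, tid in zip(founds, table_indices, table_ids)
    ]

# Two tables whose rows start at global offsets 0 and 10.
print(merge_global_indices(
    [True, False, True],   # founds
    [2, -1, 0],            # table_indices (-1 marks a miss)
    [0, 0, 1],             # table_ids
    [0, 10],               # table_offsets
))  # -> [2, -1, 10]
```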
Last reviewed commit: "Add pooling mode sum..."
```python
# Step 2: Expand table IDs from offsets
# expand_table_ids(offsets, table_offsets, num_tables, local_batch_size, num_elements)
# Returns (num_elements,) int64 table_ids indicating which table each element belongs to
num_features = offsets.shape[0] - 1
num_elements = indices.shape[0]

# Prepare table_offsets_in_feature: where in feature space each table starts
table_offsets = torch.arange(
    num_features + 1, dtype=torch.int64, device=self.device
)
```
Invalid indices from unfound keys passed to `torch.index_select`

When the hash table lookup does not find a key, the CUDA kernel writes -1 into `table_indices` for that position (as confirmed by the kernel code in kernels.cuh, which sets `indices[i] = -1` on miss). The forward then calls:

```python
embeddings = torch.index_select(self.linear_mem_table, 0, table_indices)
```

`torch.index_select` does not support negative indices; passing -1 will raise `RuntimeError: index -1 is out of bounds for dimension 0 with size N` at runtime whenever any key is missing from the table.

The `torch.where` block below this line was written exactly to handle this case (zeroing out embeddings for unfound items), but it is commented out. At minimum, the indices should be clamped before the gather, and then the where-mask applied:

```python
safe_indices = table_indices.clamp(min=0)
embeddings = torch.index_select(self.linear_mem_table, 0, safe_indices)
embeddings = torch.where(
    founds.unsqueeze(-1),
    embeddings,
    torch.zeros_like(embeddings),
)
```

```python
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
    offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
) -> torch.Tensor:
```
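The clamp-then-mask pattern can be checked without torch. This list-based stand-in (hypothetical helper, for illustration only) shows rows for missed keys coming back as zeros instead of raising:

```python
def safe_gather(table, indices, founds):
    """Clamp-then-mask gather over a list-of-rows 'table': indices are
    clamped to 0 (so the -1 miss sentinel becomes a valid row id), then
    rows for unfound keys are masked to zeros."""
    dim = len(table[0])
    out = []
    for idx, found in zip(indices, founds):
        safe = max(idx, 0)  # clamp(min=0): -1 -> 0, valid ids unchanged
        out.append(list(table[safe]) if found else [0.0] * dim)
    return out

table = [[1.0, 2.0], [3.0, 4.0]]
print(safe_gather(table, [1, -1], [True, False]))  # -> [[3.0, 4.0], [0.0, 0.0]]
```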
Computed table_ranges result is never used

`table_ranges` is computed by calling `torch.ops.INFERENCE_EMB.get_table_range(...)`, but is never referenced again in the `forward()` method; it is dead code. This is either an incomplete implementation (the range was intended to be used for per-feature routing) or the variable can be removed entirely. Leaving it in place will produce a graph node in the exported model that wastes computation.
```python
if os.path.exists(_path):
    try:
        torch.ops.load_library(_path)
        print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
        _ops_loaded = True
    except Exception as _e:
        print(f"[WARN] Failed to load {_path}: {_e}")
    break  # stop after first found path, whether load succeeded or not

if not _ops_loaded:
```
Library loading stops at first found path even on failure

The `break` at line 55 exits the search loop after the first path where `os.path.exists()` returns True, regardless of whether `torch.ops.load_library` succeeds or fails. If the file exists but cannot be loaded (e.g., missing CUDA runtime symbols, wrong ABI), `_ops_loaded` remains False and none of the other fallback paths are tried.

Consider only break-ing when the load actually succeeds:

```python
for _path in _SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.ops.load_library(_path)
            print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
            _ops_loaded = True
            break  # stop only on success
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
```

```python
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
```
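The control-flow difference is easy to reproduce with a stand-in loop (no torch; all names here are illustrative):

```python
def first_loadable(paths, exists, load, break_on_failure):
    """Stand-in for the search loop: returns the first path that loads,
    or None. With break_on_failure=True it mimics the buggy placement of
    `break` outside the try/except."""
    for p in paths:
        if exists(p):
            try:
                load(p)
                return p  # loaded successfully
            except Exception:
                pass  # load failed
            if break_on_failure:
                break  # buggy: give up after the first existing path
    return None

exists = lambda p: p.endswith(".so")

def load(p):
    if "broken" in p:
        raise RuntimeError("undefined symbol")

# Buggy placement: the fallback path is never tried.
print(first_loadable(["broken.so", "good.so"], exists, load, True))   # -> None
# Fixed placement: search continues past the failing path.
print(first_loadable(["broken.so", "good.so"], exists, load, False))  # -> good.so
```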
Misleading comment about get_table_range output shape

The inline comment says the output shape is (num_features, 2), but the actual C++ implementation (index_calculation.cu) returns `at::empty_like(feature_offsets)`, a 1D tensor of shape (num_tables + 1,), the same shape as feature_offsets. This is also confirmed by the fake kernel in index_range_meta.py, which returns `feature_offsets.new_empty(feature_offsets.shape)`. Suggested comment:

```python
table_ranges = torch.ops.INFERENCE_EMB.get_table_range(
    offsets, self.feature_offsets
)  # (num_tables + 1,) – same shape as feature_offsets
```
```python
def find_source_files(
    directory,
    extension_pattern,
    exclude_dirs=[],
    exclude_files=[],
```
Mutable default argument introduced for exclude_files

`exclude_files=[]` is a mutable default argument, a well-known Python footgun: the same list object is shared across all calls that rely on the default. While harmless in this particular setup.py usage (called at install time, never mutated), adopting None with an explicit fallback is the conventional safe pattern:

```python
def find_source_files(
    directory,
    extension_pattern,
    exclude_dirs=None,
    exclude_files=None,
):
    if exclude_dirs is None:
        exclude_dirs = []
    if exclude_files is None:
        exclude_files = []
```
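The footgun itself is easy to demonstrate in isolation (the function names below are illustrative, not from the PR):

```python
def collect_bad(item, acc=[]):  # one shared list across all default calls
    acc.append(item)
    return acc

def collect_good(item, acc=None):  # conventional safe pattern
    if acc is None:
        acc = []
    acc.append(item)
    return acc

collect_bad("a")
print(collect_bad("b"))   # -> ['a', 'b']  (state leaked between calls)
collect_good("a")
print(collect_good("b"))  # -> ['b']       (fresh list per call)
```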
```python
offsets = [0]
prev = feature_table_map[0]
for i, tid in enumerate(feature_table_map[1:], start=1):
    if tid != prev:
        offsets.append(i)
        prev = tid
offsets.append(len(feature_table_map))
return offsets


class InferenceLinearBucketTable(torch.nn.Module):
    """Simple exportable hash table wrapper for inference lookup using custom op.

    This is a minimal demo version that focuses on lookup-only, non-pooled inference.
```
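Re-run standalone (a minimal re-implementation of the transition-based derivation above, wrapped in a function for illustration), the behaviour on sorted versus interleaved maps is visible directly:

```python
def derive_grouped_offsets(feature_table_map):
    """Transition-based grouping: a new group starts wherever the
    table id changes between consecutive entries."""
    offsets = [0]
    prev = feature_table_map[0]
    for i, tid in enumerate(feature_table_map[1:], start=1):
        if tid != prev:
            offsets.append(i)
            prev = tid
    offsets.append(len(feature_table_map))
    return offsets

print(derive_grouped_offsets([0, 0, 1, 2]))  # -> [0, 2, 3, 4]  (sorted input)
print(derive_grouped_offsets([0, 1, 0]))     # -> [0, 1, 2, 3]  (interleaved input)
```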
_derive_grouped_offsets silently produces wrong offsets for unsorted input

The function only records transitions between consecutive values. For a sorted input like [0, 0, 1, 2] this produces the correct [0, 2, 3, 4]. But if a caller ever passes an unsorted or interleaved map, e.g. [0, 1, 0] representing two features for table 0 and one for table 1, the function produces [0, 1, 2, 3] (three single-item groups) instead of the semantically correct result, and no error is raised.

`InferenceEmbeddingTable.__init__` only validates that each table-id is in [0, num_tables) but does not check that the map is sorted. Any downstream code that depends on contiguous grouping will silently produce incorrect embeddings.

Add a validation guard before computing offsets:

```python
if feature_table_map != sorted(feature_table_map):
    raise ValueError(
        "feature_table_map must be sorted (features for the same table must be contiguous). "
        f"Got: {feature_table_map}"
    )
```

```python
for _path in _NVE_TORCH_SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.classes.load_library(_path)
            print(f"[INFO] Loaded libnve_torch.so from {_path}")
            _register_nve_fake_class()
            _nve_torch_loaded = True
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
        break
```
break fires even when library load fails in _load_nve_torch_bindings

The `break` at line 194 sits outside the try/except block. As a result, as soon as a path exists the loop exits, regardless of whether `torch.classes.load_library` succeeded or threw. If loading fails, `_nve_torch_loaded` stays False, the remaining fallback paths are never tried, and because `_nve_torch_load_attempted` is already True any future call returns False immediately.

```python
for _path in _NVE_TORCH_SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.classes.load_library(_path)
            print(f"[INFO] Loaded libnve_torch.so from {_path}")
            _register_nve_fake_class()
            _nve_torch_loaded = True
            break  # only break on success
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
```
```python
def forward(
    self,
    keys: torch.Tensor,
    offsets: torch.Tensor,
    pooling_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run embedding lookup with optional pooling.

    Args:
        keys: (N,) int64 – flat lookup keys; may span multiple
            tables and pooling bags.
        offsets: (T*B+1,) int64 – CSR boundaries that map segments
            of ``keys`` to per-table feature slots; used to
            derive a per-key table id via
            ``INFERENCE_EMB::expand_table_ids``.
        pooling_offsets: (B+1,) int64 – CSR boundaries that map segments of
            ``keys`` to pooling bags; required when
            ``self.pooling_mode_ >= 0``, otherwise unused.

    Returns:
        - ``pooling_mode_ == -1``: ``(N, D)`` float tensor of per-key
          embeddings.
        - ``pooling_mode_ == 1 or 2``: ``(B, D)`` float tensor of pooled
          embeddings, where ``B = pooling_offsets.size(0) - 1``.
```
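The CSR semantics of `pooling_offsets` in the docstring can be illustrated with a plain-Python sum-pooling sketch (a hypothetical helper, not the library API):

```python
def pool_sum(embeddings, pooling_offsets):
    """Sum-pool N per-key embedding rows into B bags, where bag b covers
    rows pooling_offsets[b] .. pooling_offsets[b+1]-1 (CSR convention)."""
    dim = len(embeddings[0])
    pooled = []
    for b in range(len(pooling_offsets) - 1):
        start, end = pooling_offsets[b], pooling_offsets[b + 1]
        bag = [0.0] * dim
        for row in embeddings[start:end]:
            bag = [acc + x for acc, x in zip(bag, row)]
        pooled.append(bag)
    return pooled

# N=3 keys pooled into B=2 bags: rows [0, 1] form bag 0, row [2] forms bag 1.
print(pool_sum([[1.0, 0.0], [2.0, 0.0], [3.0, 5.0]], [0, 2, 3]))
# -> [[3.0, 0.0], [3.0, 5.0]]
```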
pooling_offsets=None silently passed to lookup_with_pooling when pooling is enabled

`pooling_offsets` defaults to None in the signature. For `pooling_mode_ == 1 or 2`, the pooled branch unconditionally passes `pooling_offsets` straight into `self.nve_embedding_.lookup_with_pooling(...)`. If a caller forgets this argument (or passes None explicitly), the C++ library will receive a null/undefined tensor, producing a crash or undefined behaviour with no helpful Python-level error message.

Add a guard at the start of the pooled branch:

```python
if self.pooling_mode_ >= 0:
    if pooling_offsets is None:
        raise ValueError(
            "pooling_offsets must be provided when pooling_mode is 1 (sum) or 2 (mean)"
        )
    return self.nve_embedding_.lookup_with_pooling(...)
```
```python
return score_out, founds, indices


class InferenceEmbeddingTable(torch.nn.Module):
```
do we want to move this into dynamicemb?
this needs to be added in CI
do we have a CPP aot inductor test? Or we want to add this demo in CI?
Add C++ export for Inference embedding, based on dynamicemb.