[Draft] Feat: dynamicemb table fusion #313
Conversation
Greptile Summary

This PR introduces a table-fusion refactor of the DynamicEmbedding library, consolidating what was previously a per-table object model into a single fused model.
Two P1 bugs and one P2 issue were identified.

Confidence Score: 3/5 — not safe to merge as-is: the two P1 bugs can cause silent data corruption (stale prefetch states after flush; CUSTOMIZED score guard bypassed).

1. (P1) `flush()` does not clear `_prefetch_states`, meaning slot indices produced before a cache flush will be consumed after it — silently returning wrong embeddings.
2. (P1) `_create_score` now pre-populates `self._scores[name] = 0` for CUSTOMIZED tables, so the `RuntimeError` guard that required users to call `set_score()` first is permanently bypassed.
3. (P2) An external-storage selection issue, plus an assertion-vs-ValueError nit, round out the findings.

The overall fusion architecture is sound and test coverage is good, but the two P1 correctness regressions need to be resolved before merging. Key files: `batched_dynamicemb_tables.py` (`flush` and `_create_score`); `key_value_table.py` (`external_storage` selection).

Important Files Changed
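The first P1 fix could look like the following — a minimal pure-Python sketch with invented names (`PrefetchStateQueue`), not the library's actual API: a flush must invalidate any queued prefetch state, since its slot indices point into the pre-flush cache.

```python
from collections import deque


class PrefetchStateQueue:
    """Toy model of the prefetch-state deque described in the summary.

    flush() must drop queued states: slot indices computed before a
    cache flush would otherwise be consumed after it and resolve to
    wrong embeddings.
    """

    def __init__(self):
        self._prefetch_states = deque()

    def append(self, state):
        self._prefetch_states.append(state)

    def pop(self):
        # FIFO: forward() consumes the oldest prefetch first.
        return self._prefetch_states.popleft()

    def flush(self):
        # Sketch of the P1 fix: invalidate stale states along with the cache.
        self._prefetch_states.clear()


q = PrefetchStateQueue()
q.append("state-0")
q.append("state-1")
q.flush()
print(len(q._prefetch_states))  # 0: no stale state survives a flush
```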
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant BDET as BatchedDynamicEmbTablesV2
    participant PrefetchQ as _prefetch_states (deque)
    participant PF as dynamicemb_prefetch
    participant FWD as DynamicEmbeddingFunction
    participant Cache as DynamicEmbCache
    participant Storage as DynamicEmbStorage
    User->>BDET: prefetch(indices, offsets)
    BDET->>Cache: set_score(fused_score), training=True
    BDET->>Storage: set_score(fused_score), training=True
    BDET->>PF: dynamicemb_prefetch(...)
    PF->>Storage: segmented_unique → unique_keys, table_ids
    PF->>Cache: lookup(unique_keys, table_ids)
    Cache-->>PF: founds, slot_indices
    PF->>Storage: find(miss_keys) [cache miss path]
    PF->>Cache: insert_and_evict(miss_keys)
    Cache-->>PF: new slot_indices, evicted_keys
    PF->>Storage: insert(evicted_keys, evicted_values)
    PF-->>BDET: PrefetchState(unique_keys, slot_indices, ...)
    BDET->>PrefetchQ: append(PrefetchState)
    User->>BDET: forward(indices, offsets)
    BDET->>PrefetchQ: popleft() → PrefetchState
    BDET->>FWD: apply(prefetch_state, ...)
    FWD->>Cache: load_from_flat(slot_indices) [CACHE mode]
    FWD-->>BDET: output_embs
    Note over FWD: outstanding_keys_ref -= num_prefetched_keys
    User->>BDET: backward(grads)
    FWD->>FWD: reduce_grads → unique_grads
    FWD->>Storage: fused_update_for_flat_table(unique_grads, update_slot_indices)
    FWD->>Cache: decrement_counter(update_slot_indices)
```
Force-pushed from 2c3371a to a36d744 (Compare)
```python
training,
EvictStrategy(evict_strategy.value) if evict_strategy else None,
lfu_accumulated_frequency_per_table,
```

```python
if prefetch_state is None:
```
What will happen when prefetch depth > 1 — is there any protection for the prefetch_state?
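One possible protection — a hedged sketch under assumed names, not the PR's actual mechanism — is to bound the queue to the configured prefetch depth and fail fast when `forward()` runs without a matching prefetch:

```python
from collections import deque

PREFETCH_DEPTH = 2  # assumed configuration value

states = deque()


def enqueue(state):
    # Refuse to queue more states than the configured depth.
    if len(states) >= PREFETCH_DEPTH:
        raise RuntimeError(
            f"prefetch depth exceeded: {len(states)} states already queued"
        )
    states.append(state)


def dequeue():
    # Refuse a forward pass that has no prefetched state to consume.
    if not states:
        raise RuntimeError("forward() called with no prefetched state")
    return states.popleft()


enqueue("a")
enqueue("b")
```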
```python
host_options: List[DynamicEmbTableOptions],
optimizer: BaseDynamicEmbeddingOptimizer,
):
    self._hbm = create_table_state(hbm_options, optimizer)
```
Will there be two duplicated table states in the hybrid case?
Force-pushed from a36d744 to c7590e7 (Compare)
Additional Comments (5)
Three consecutive statements are bare expressions whose values are immediately discarded:

```python
emb_dtype = storage.embedding_dtype()
storage.max_embedding_dim()  # return value is thrown away
cache is not None  # boolean is thrown away
```

These look like forgotten assignments; based on how the rest of the function uses these values, they should likely be bound to names. Even if the statements are intentional, the dead expressions should be removed.
When `flagged_compact(admit_mask, [missing_keys, missing_indices, missing_table_ids, None])` is called, a `None` entry is passed where a tensor is expected. This path is reachable any time `missing_scores` is `None`.

The fix is to build the list conditionally:

```python
tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)
_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None
```

Simply remove line 1714.
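The guarded pattern can be exercised with a toy `flagged_compact` that keeps only flagged positions — names and semantics are assumed from the snippet above, not taken from the real kernel:

```python
def flagged_compact(mask, tensors):
    # Toy stand-in for the real kernel: keep entries where mask is True.
    compacted = [[x for x, m in zip(t, mask) if m] for t in tensors]
    return sum(mask), mask, compacted


admit_mask = [True, False, True]
missing_keys = [10, 11, 12]
missing_indices = [0, 1, 2]
missing_table_ids = [0, 0, 1]
missing_scores = None  # some score strategies legitimately omit scores

# Guarded call: scores join the list only when present.
tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)
_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None
```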
The concrete implementations return 8 tuple elements, but the abstract signature declares only 7 — `missing_table_ids` is absent. Corrected 8-element annotation for `Storage.find`:

```python
) -> Tuple[
    int,
    torch.Tensor,            # missing_keys
    torch.Tensor,            # missing_indices
    torch.Tensor,            # missing_table_ids  ← add this
    Optional[torch.Tensor],  # missing_scores
    torch.Tensor,            # founds
    torch.Tensor,            # output_scores
    torch.Tensor,            # values
]:
```
Force-pushed from a4097d9 to 4bc15aa (Compare)
Additional Comments (5)
The call `flagged_compact(admit_mask, [missing_keys, missing_indices, missing_table_ids, missing_scores])` passes `missing_scores` unconditionally, even though it may be `None`. The fix is to guard the compact call and only include scores when they are available:

```python
tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)
_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None
```
The original computed:

```python
new_score = cur_score + self.num_prefetch_ahead - 1
```

The replacement no longer matches this arithmetic.
The line `int64_t table_id = NumRegions == 0 ? scalar_table_id : table_ids[emb_id];` selects the branch at runtime even though `NumRegions` is a compile-time parameter. With `if constexpr`, the dead branch (and its `table_ids` load) is guaranteed not to be compiled in:

```cpp
int64_t table_id;
if constexpr (NumRegions == 0) {
  table_id = scalar_table_id;
} else {
  table_id = table_ids[emb_id];
}
```

The same pattern appears in other kernels in this file.
Additional Comments (7)
The early-return path on lines 184–192 only bypasses this in part. An alternative is to substitute an empty placeholder tensor so the compact call always receives four inputs:

```python
scores_input = (
    missing_scores
    if missing_scores is not None
    else torch.empty(0, dtype=torch.int64, device=missing_keys.device)
)
(
    _,
    _,
    (
        keys_to_insert,
        positions_in_unique,
        table_ids_to_insert,
        scores_to_insert,
    ),
) = flagged_compact(
    admit_mask,
    [missing_keys, missing_indices, missing_table_ids, scores_input],
)
scores_to_insert = scores_to_insert if missing_scores is not None else None
```
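The empty-placeholder variant can be modeled without torch — a toy sketch under the same assumed semantics, showing that the tuple shape stays fixed at four entries while the scores result is discarded when no scores were supplied:

```python
def flagged_compact(mask, tensors):
    # Toy stand-in: keep entries of each list where mask is True.
    compacted = [[x for x, m in zip(t, mask) if m] for t in tensors]
    return sum(mask), mask, compacted


admit_mask = [True, True, False]
missing_keys = [7, 8, 9]
missing_indices = [0, 1, 2]
missing_table_ids = [1, 1, 0]
missing_scores = None

# Empty placeholder keeps the compacted tuple at four entries.
scores_input = missing_scores if missing_scores is not None else []
_, _, (keys, pos, tids, scores) = flagged_compact(
    admit_mask, [missing_keys, missing_indices, missing_table_ids, scores_input]
)
scores = scores if missing_scores is not None else None
```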
The old per-table path used a different approach. This issue also appears in other call sites.
Consider adding validation:

```python
if len({kv.bucket_capacity for kv in kv_counters}) > 1:
    raise ValueError(
        "All KVCounter configs must share the same bucket_capacity in fused mode."
    )
if len({kv.key_type for kv in kv_counters}) > 1:
    raise ValueError(
        "All KVCounter configs must share the same key_type in fused mode."
    )
```
When creating external PS storage in caching mode, the `external_storage` class is taken from a single table's options. Add validation to ensure consistency:

```python
ext_storages = {opt.external_storage for opt in storage_options}
if len(ext_storages) > 1:
    raise ValueError(
        "All tables must share the same external_storage class in fused caching mode."
    )
PS = storage_options[0].external_storage
```
This TODO should be resolved before the PR is merged.
Additional Comments (4)
* Avoid outstanding-keys overflow: decrement at the end of fwd
* Fix seq-emb's bw test; fix ref_counter increment/decrement bug
  1. Fix sequence embedding backward test.
     Issue: there are two forward calls and one backward call in one iteration, which increments the ref_counter twice but decrements it only once.
     Fix: switch to eval mode when only evaluating the model. (Alternative: move the ref_counter decrement to the end of fwd, but that may unlock a key too early when prefetch and backward overlap.)
  2. Fix ref_counter increment/decrement bug.
     Issue: the slot_indices argument starts at 0 for each table, but a flat index is needed; besides, the flat indices within one iteration are unique, since two keys can't share the same slot.
     Fix: flatten the slot_indices for each table using table_ids before incrementing/decrementing.
* Route to the correct ref_counter table in the insert kernel
* Update the score at the end of prefetch, and only update it for STEP
* Fix the expected score in tests, since the score is now updated during prefetch
* Remove the default value for table_ids in increment/decrement_counter and make it non-optional
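The flat-index fix described in the commit message can be sketched as follows — per-table slot indices are offset by each table's base so two tables can never alias the same ref-counter entry. Capacities and names here are assumed for illustration, not taken from the PR:

```python
# Per-table cache capacities (assumed for illustration).
table_capacities = [4, 4, 8]

# Base offset of each table's region in the flattened counter array.
table_bases = [0]
for cap in table_capacities[:-1]:
    table_bases.append(table_bases[-1] + cap)  # → [0, 4, 8]

ref_counter = [0] * sum(table_capacities)


def increment_counter(slot_indices, table_ids):
    # slot_indices start at 0 for every table; table_ids route each one
    # into the correct region of the flat counter, avoiding collisions.
    for slot, tid in zip(slot_indices, table_ids):
        ref_counter[table_bases[tid] + slot] += 1


# Slot 0 of table 0, slot 0 of table 1, slot 2 of table 2:
increment_counter([0, 0, 2], [0, 1, 2])
```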
Force-pushed from 2590bf7 to 8cc97ab (Compare)
I encountered a severe perf regression with hybrid storage mode (global_hbm_size = 0) — the new table_insert_and_evict_kernel is quite slow. (Timeline screenshots comparing before vs. after, and the profiling report, were attached.)
```python
def enable_prefetch(self, value: bool):
    self._enable_prefetch = value
    self.num_prefetch_ahead = 0
    self._prefetch_outstanding_keys.zero_()
```
It's not recommended to invoke self._prefetch_outstanding_keys.zero_() inside a setter. See this blog's recommendations:
- Use public attributes whenever appropriate, even if you expect the attribute to require functional behavior in the future.
- Avoid defining setter and getter methods for your attributes. You can always turn them into properties if needed.
- Use properties when you need to attach behavior to attributes and keep using them as regular attributes in your code.
- Avoid side effects in properties because no one would expect operations like assignments to cause any side effects.
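The blog's advice in code form — a hedged sketch, not the PR's implementation: keep the attribute plain and make the state reset an explicit method, so a simple assignment never clears counters behind the caller's back.

```python
class Tables:
    """Toy model of the setter-side-effect concern above."""

    def __init__(self):
        self.enable_prefetch = False  # plain attribute: assignment only
        self.num_prefetch_ahead = 3
        self._outstanding = 7         # stand-in for the outstanding-keys tensor

    def reset_prefetch(self, enable):
        # Explicit method: the caller can see that state is being cleared.
        self.enable_prefetch = enable
        self.num_prefetch_ahead = 0
        self._outstanding = 0


t = Tables()
t.enable_prefetch = True   # no hidden side effects
t.reset_prefetch(False)    # clearing counters is opt-in and visible
```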


Description

Checklist

- ci
- CI after fixed