Replace one-hot+linear with embedding lookup in RelativePositionEncoding #17
longleo17 wants to merge 1 commit into bytedance:main from
Conversation
Pull request overview
This PR optimizes RelativePositionEncoding.forward() by replacing F.one_hot(...) @ W with direct weight indexing (embedding-style lookups), reducing intermediate memory pressure from large one-hot materializations and removing the prior chunked-inference workaround.
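The identity behind the change can be checked numerically: for integer indices and an nn.Linear-style weight of shape [c_z, K], multiplying a one-hot matrix by the transposed weight selects the same rows as indexing it directly. A minimal sketch (shapes and names here are illustrative, not the PR's exact code):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: K one-hot classes, c_z output channels, N tokens.
K, c_z, N = 66, 16, 8
weight = torch.randn(c_z, K)          # nn.Linear-style weight [c_z, K]
idx = torch.randint(0, K, (N, N))     # pairwise relative-position indices

# Original path: materialize an [N, N, K] one-hot, then project.
p_onehot = F.one_hot(idx, K).float() @ weight.t()  # [N, N, c_z]

# Optimized path: index rows of the transposed weight directly.
p_lookup = weight.t()[idx]                         # [N, N, c_z]

assert torch.allclose(p_onehot, p_lookup)
```

The results match exactly because a one-hot matmul sums a single nonzero term per output element, while the lookup never materializes the [N, N, K] intermediate at all.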
Changes:
- Replace one-hot + linear projection with direct indexing into LinearNoBias.weight slices in pxdesign/model/embedders.py.
- Add a dedicated test module validating numerical equivalence, expected (theoretical) memory reduction, and the presence of the new indexing pattern.
Reviewed changes
Copilot reviewed 2 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| pxdesign/model/embedders.py | Switches RelPosEnc computation from one-hot+linear to weight indexing and summation. |
| tests/test_embedding_relposenc.py | Adds equivalence test and guardrails around the new embedding-lookup implementation. |
Comments suppressed due to low confidence (4)
pxdesign/model/embedders.py:298
self.linear_no_bias.weight.t().float() forces a float32 cast every forward, which can defeat AMP/mixed precision and adds an extra allocation/copy (even if small). If float32 is required, consider using to(dtype=...) based on the module/input dtype, or relying on autocast; otherwise prefer keeping the weight's existing dtype.
W = self.linear_no_bias.weight.t().float() # [input_dim, c_z]
n_pos = 2 * (self.r_max + 1)
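A dtype-preserving variant might look like the sketch below (names are stand-ins, not the module's actual attributes): the transpose is a view over the same storage, so dropping the .float() avoids both the cast and the copy.

```python
import torch

# Hypothetical weight in bf16, as under mixed-precision training.
weight = torch.randn(16, 72, dtype=torch.bfloat16)  # [c_z, input_dim]
idx = torch.randint(0, 72, (4, 4))                  # integer index tensor

W = weight.t()   # [input_dim, c_z] view: same storage, same dtype, no copy
p = W[idx]       # [4, 4, c_z] lookup, inherits the weight's bfloat16 dtype

assert p.dtype == weight.dtype
```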
tests/test_embedding_relposenc.py:197
- The GPU speed test asserts time_optimized <= time_original, which is likely to be flaky on shared/variable GPU CI due to timing noise, clocking, and CUDA scheduling. Consider relaxing the assertion (e.g., allow a small slowdown tolerance), increasing iterations and using the median, or marking this as a benchmark/optional test rather than a unit-test gate.
assert time_optimized <= time_original, (
f"Embedding lookup ({time_optimized*1000:.1f}ms) should not be slower "
f"than one-hot+linear ({time_original*1000:.1f}ms)"
)
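A noise-tolerant version of that gate could compare medians over several iterations with an explicit slack factor. A sketch, where the 1.2x tolerance and iteration count are arbitrary choices rather than values from the PR:

```python
import time
import statistics

def median_time(fn, iters=20):
    """Median wall-clock time over several runs, damping scheduler noise."""
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)

def assert_not_slower(optimized, original, tolerance=1.2):
    """Fail only if the optimized path exceeds tolerance x the original."""
    t_opt = median_time(optimized)
    t_orig = median_time(original)
    assert t_opt <= tolerance * t_orig, (
        f"optimized ({t_opt * 1000:.1f}ms) exceeded "
        f"{tolerance}x original ({t_orig * 1000:.1f}ms)"
    )
```

For CUDA workloads one would also synchronize (torch.cuda.synchronize()) around each timed call, since kernel launches return before the work finishes.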
tests/test_embedding_relposenc.py:218
test_no_onehot_in_inference_path is brittle because it searches for hard-coded substrings in the source file. Small refactors (renaming variables, formatting, a different indexing API) could break the test without changing behavior. Consider inspecting RelativePositionEncoding.forward via inspect.getsource and checking for one_hot usage within that method specifically (the file still contains F.one_hot elsewhere).
def test_no_onehot_in_inference_path():
"""Verify the inference path no longer uses F.one_hot."""
from pathlib import Path
source = (
Path(__file__).parent.parent / "pxdesign" / "model" / "embedders.py"
).read_text()
# Find the RelativePositionEncoding class and check its forward method
# The forward should use W_pos[d_residue] pattern, not F.one_hot
assert "W_pos[d_residue]" in source, (
"RelativePositionEncoding should use embedding lookup (W_pos[d_residue])"
)
assert "W_token[d_token]" in source, (
"RelativePositionEncoding should use embedding lookup (W_token[d_token])"
)
assert "W_chain[d_chain]" in source, (
"RelativePositionEncoding should use embedding lookup (W_chain[d_chain])"
)
tests/test_embedding_relposenc.py:137
- The memory "peak bytes" calculation assumes the one-hot tensors are float32, but F.one_hot produces int64 and the .float() cast happens after concatenation. If the intent is to document/validate memory impact, consider adjusting the calculation/comments to reflect the actual dtypes (or measure with torch.cuda.max_memory_allocated() in a CUDA-only test).
# One-hot approach memory (intermediate tensors):
# 3 one-hot tensors: N^2 x (66 + 66 + 6) = N^2 x 138 float32
# Concatenated: N^2 x 139 float32
# Total peak: N^2 x (138 + 139) = N^2 x 277 float32
onehot_peak_bytes = n_tokens**2 * 277 * 4
# Embedding approach memory (intermediate tensors):
# 3 index tensors: N^2 x 3 int64
# Output accumulator: N^2 x c_z float32
emb_peak_bytes = n_tokens**2 * c_z * 4 + n_tokens**2 * 3 * 8
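A recalculation along the lines the review suggests could account for the actual dtypes: F.one_hot yields int64 (8 bytes) and the float32 cast lands after concatenation. The channel counts (66 + 66 + 6 = 138, and 139 after concatenation) are taken from the test's own comments; n_tokens is an arbitrary example size, and which phase dominates is an assumption of this sketch:

```python
# Hedged peak-memory estimate with int64 one-hots and a post-concat cast.
n_tokens = 512  # illustrative sequence length

# Peak while concatenating: three int64 one-hots (138 channels total)
# coexist with the 139-channel int64 concat output.
concat_peak = n_tokens**2 * (138 + 139) * 8

# Peak while casting: the int64 concat plus its float32 copy.
cast_peak = n_tokens**2 * 139 * (8 + 4)

onehot_peak_bytes = max(concat_peak, cast_peak)  # concat phase dominates here
```

In practice a torch.cuda.max_memory_allocated() measurement would be more trustworthy than any hand-derived constant, since the allocator caches and reuses blocks.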
p = p + W_token[d_token]  # [..., N, N, c_z]
p = p + b_same_entity.unsqueeze(-1).float() * W_entity  # [..., N, N, c_z]
p = p + W_chain[d_chain]  # [..., N, N, c_z]
The accumulation uses out-of-place additions (p = p + ...) and a full broadcasted mask multiply for the entity term. For large N, each p = p + ... allocates a new [..., N, N, c_z] tensor (and the entity multiply allocates another), which can spike peak memory and reintroduce OOM risk despite removing one-hot. Consider using in-place accumulation (e.g., p.add_(...)) and implementing the entity contribution via masked add/indexing to avoid materializing a full [..., N, N, c_z] intermediate.
- p = p + W_token[d_token]  # [..., N, N, c_z]
- p = p + b_same_entity.unsqueeze(-1).float() * W_entity  # [..., N, N, c_z]
- p = p + W_chain[d_chain]  # [..., N, N, c_z]
+ # In-place accumulation to avoid allocating new [..., N, N, c_z] tensors
+ p.add_(W_token[d_token])  # [..., N, N, c_z]
+ # Add entity contribution only at masked positions to avoid full broadcasted intermediate
+ b_same_entity_bool = b_same_entity.bool()  # [..., N, N]
+ p[b_same_entity_bool] += W_entity  # masked in-place add
+ p.add_(W_chain[d_chain])  # [..., N, N, c_z]
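The masked in-place add in the suggestion is numerically equivalent to the broadcasted multiply it replaces, which a small standalone check can confirm (shapes here are illustrative):

```python
import torch

N, c_z = 6, 4
p = torch.randn(N, N, c_z)
W_entity = torch.randn(c_z)
b_same_entity = torch.randint(0, 2, (N, N))

# Broadcasted version: allocates a full [N, N, c_z] intermediate.
ref = p + b_same_entity.unsqueeze(-1).float() * W_entity

# Masked version: add W_entity only at positions where the mask is set.
out = p.clone()
out[b_same_entity.bool()] += W_entity

assert torch.allclose(ref, out)
```

One caveat: in-place ops on tensors needed for backward can break autograd, so this rewrite is safest on an inference path or on a freshly allocated accumulator.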
Summary
Replaces F.one_hot(idx, K) @ W with direct weight indexing W.T[idx] in RelativePositionEncoding. The original materialized three N_token x N_token one-hot tensors (66/66/6 classes each), then multiplied by a linear layer.
Measured impact (synthetic benchmark)
Changes
pxdesign/model/embedders.py — RelativePositionEncoding.forward(): one_hot(idx, K) @ W = W.T[idx]
Test plan
pytest tests/test_embedding_relposenc.py -v