
Replace one-hot+linear with embedding lookup in RelativePositionEncoding#17

Open
longleo17 wants to merge 1 commit into bytedance:main from longleo17:pr/embedding-lookup-relposenc

Conversation

@longleo17

Summary

Replaces F.one_hot(idx, K).float() @ W.T with direct weight indexing W.T[idx] in RelativePositionEncoding, where W is the [c_z, K] weight of the no-bias linear layer. The original materialized three N_token x N_token one-hot tensors (66/66/6 classes each) and then multiplied them through the linear layer.
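A minimal sketch of the equivalence (illustrative shapes only; the real module uses class counts 66/66/6 and c_z from the model config):

```python
import torch
import torch.nn.functional as F

K, c_z = 66, 128
idx = torch.randint(0, K, (10, 10))  # [N, N] bucketed relative-position indices
W = torch.randn(c_z, K)              # no-bias linear weight, shape [c_z, K]

# Original path: materialize a [N, N, K] one-hot, then project through the linear layer.
via_onehot = F.one_hot(idx, K).float() @ W.t()  # [N, N, c_z]

# Optimized path: index rows of W.T directly -- an embedding-style lookup.
via_lookup = W.t()[idx]                          # [N, N, c_z]

assert torch.allclose(via_onehot, via_lookup, atol=1e-6)
```

Because the one-hot rows select exactly one row of W.T each, the two paths are bitwise-identical selections up to the matmul's float accumulation.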

Measured impact (synthetic benchmark)

  • 52% memory reduction on the RelPosEnc path
  • ~1.3x faster on small targets (0.9 ms → 0.7 ms)
  • Prevents OOM on large targets (3000+ tokens): saves ~5GB peak intermediate memory
  • Eliminates the chunked inference path entirely (was needed to handle huge one-hot tensors)

Changes

  • pxdesign/model/embedders.py: RelativePositionEncoding.forward()
  • Mathematically equivalent: one_hot(idx, K) @ W.T = W.T[idx] for the [c_z, K] linear weight W
  • Fully backward-compatible with existing checkpoints (same weights, different access pattern)

Test plan

  • pytest tests/test_embedding_relposenc.py -v
  • Verify numerical equivalence on varying target sizes (500, 2000, 4000 tokens)


Copilot AI left a comment


Pull request overview

This PR optimizes RelativePositionEncoding.forward() by replacing F.one_hot(...) @ W with direct weight indexing (embedding-style lookups), reducing intermediate memory pressure from large one-hot materializations and removing the prior chunked inference workaround.

Changes:

  • Replace one-hot + linear projection with direct indexing into LinearNoBias.weight slices in pxdesign/model/embedders.py.
  • Add a dedicated test module validating numerical equivalence, expected memory reduction (theoretical), and the presence of the new indexing pattern.

Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 1 comment.

File | Description
pxdesign/model/embedders.py | Switches RelPosEnc computation from one-hot+linear to weight indexing and summation.
tests/test_embedding_relposenc.py | Adds equivalence tests and guardrails around the new embedding-lookup implementation.
Comments suppressed due to low confidence (4)

pxdesign/model/embedders.py:298

  • self.linear_no_bias.weight.t().float() forces a float32 cast every forward, which can defeat AMP/mixed precision and adds an extra allocation/copy (even if small). If float32 is required, consider using to(dtype=...) based on the module/input dtype or relying on autocast; otherwise prefer keeping the weight’s existing dtype.
        W = self.linear_no_bias.weight.t().float()  # [input_dim, c_z]
        n_pos = 2 * (self.r_max + 1)
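A sketch of the dtype-preserving alternative the reviewer suggests (module name and dims are stand-ins from the diff, not the actual patch):

```python
import torch
import torch.nn as nn

class RelPosEncSketch(nn.Module):
    """Illustrative module: keep the weight's own dtype instead of forcing float32."""

    def __init__(self, input_dim: int = 138, c_z: int = 128):
        super().__init__()
        self.linear_no_bias = nn.Linear(input_dim, c_z, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # No .float(): the transposed view costs nothing and autocast (if active)
        # can still choose the compute dtype. Avoids a per-forward cast + copy.
        W = self.linear_no_bias.weight.t()  # [input_dim, c_z], same dtype as weight
        return W[idx]                        # [..., c_z]

m = RelPosEncSketch().half()
out = m(torch.randint(0, 138, (4, 4)))
assert out.dtype == torch.float16  # dtype follows the module, no hidden upcast
```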

tests/test_embedding_relposenc.py:197

  • The GPU speed test asserts time_optimized <= time_original, which is likely to be flaky on shared/variable GPU CI due to timing noise, clocking, and CUDA scheduling. Consider relaxing the assertion (e.g., allow a small slowdown tolerance), increasing iterations and using median, or marking this as a benchmark/optional test rather than a unit test gate.
    assert time_optimized <= time_original, (
        f"Embedding lookup ({time_optimized*1000:.1f}ms) should not be slower "
        f"than one-hot+linear ({time_original*1000:.1f}ms)"
    )
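One way to harden the timing assertion, per the reviewer's suggestion (a hedged sketch: helper names and the 10% tolerance are choices, not part of the PR):

```python
import statistics
import time
import torch

def median_time(fn, iters: int = 20, warmup: int = 3) -> float:
    """Median wall time over several runs; far less noise-sensitive than a
    single original-vs-optimized pair on a shared GPU."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # include queued kernels in the measurement
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)

# Hypothetical usage with a small slowdown tolerance instead of a strict <=:
# assert median_time(run_optimized) <= 1.10 * median_time(run_original)
```

Alternatively, the test could be marked with a benchmark/slow marker so it never gates CI.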

tests/test_embedding_relposenc.py:218

  • test_no_onehot_in_inference_path is brittle because it searches for hard-coded substrings in the source file. Small refactors (renaming variables, formatting, different indexing API) could break the test without changing behavior. Consider inspecting RelativePositionEncoding.forward via inspect.getsource and checking for one_hot usage within that method specifically (the file still contains F.one_hot elsewhere).
def test_no_onehot_in_inference_path():
    """Verify the inference path no longer uses F.one_hot."""
    from pathlib import Path

    source = (
        Path(__file__).parent.parent / "pxdesign" / "model" / "embedders.py"
    ).read_text()

    # Find the RelativePositionEncoding class and check its forward method
    # The forward should use W_pos[d_residue] pattern, not F.one_hot
    assert "W_pos[d_residue]" in source, (
        "RelativePositionEncoding should use embedding lookup (W_pos[d_residue])"
    )
    assert "W_token[d_token]" in source, (
        "RelativePositionEncoding should use embedding lookup (W_token[d_token])"
    )
    assert "W_chain[d_chain]" in source, (
        "RelativePositionEncoding should use embedding lookup (W_chain[d_chain])"
    )

tests/test_embedding_relposenc.py:137

  • The memory “peak bytes” calculation assumes the one-hot tensors are float32, but F.one_hot produces int64 and the .float() cast happens after concatenation. If the intent is to document/validate memory impact, consider adjusting the calculation/comments to reflect actual dtypes (or measure with torch.cuda.max_memory_allocated() in a CUDA-only test).
    # One-hot approach memory (intermediate tensors):
    # 3 one-hot tensors: N^2 x (66 + 66 + 6) = N^2 x 138 float32
    # Concatenated: N^2 x 139 float32
    # Total peak: N^2 x (138 + 139) = N^2 x 277 float32
    onehot_peak_bytes = n_tokens**2 * 277 * 4

    # Embedding approach memory (intermediate tensors):
    # 3 index tensors: N^2 x 3 int64
    # Output accumulator: N^2 x c_z float32
    emb_peak_bytes = n_tokens**2 * c_z * 4 + n_tokens**2 * 3 * 8
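One plausible accounting under the reviewer's dtype correction (illustrative: it mirrors the quoted test's 138/139 channel counts, but the true peak depends on allocator behavior and op ordering, which is why the reviewer also suggests measuring with torch.cuda.max_memory_allocated()):

```python
# F.one_hot produces int64 (8 bytes/element); the .float() cast happens after
# concatenation, so the int64 one-hots and a float32 copy coexist at peak.
n_tokens, c_z = 2000, 128

onehot_int64_bytes = n_tokens**2 * 138 * 8    # three int64 one-hot tensors
concat_float32_bytes = n_tokens**2 * 139 * 4  # float32 copy made after concat
onehot_peak_bytes = onehot_int64_bytes + concat_float32_bytes

# Embedding path: three int64 index tensors plus the float32 accumulator.
emb_peak_bytes = n_tokens**2 * c_z * 4 + n_tokens**2 * 3 * 8

assert emb_peak_bytes < onehot_peak_bytes
```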


Comment on lines +308 to +310
p = p + W_token[d_token] # [..., N, N, c_z]
p = p + b_same_entity.unsqueeze(-1).float() * W_entity # [..., N, N, c_z]
p = p + W_chain[d_chain] # [..., N, N, c_z]

Copilot AI Mar 3, 2026


The accumulation uses out-of-place additions (p = p + ...) and a full broadcasted mask multiply for the entity term. For large N, each p = p + ... allocates a new [..., N, N, c_z] tensor (and the entity multiply allocates another), which can spike peak memory and reintroduce OOM risk despite removing one-hot. Consider using in-place accumulation (e.g., p.add_(...)) and implementing the entity contribution via masked add/indexing to avoid materializing a full [..., N, N, c_z] intermediate.

Suggested change
p = p + W_token[d_token] # [..., N, N, c_z]
p = p + b_same_entity.unsqueeze(-1).float() * W_entity # [..., N, N, c_z]
p = p + W_chain[d_chain] # [..., N, N, c_z]
# In-place accumulation to avoid allocating new [..., N, N, c_z] tensors
p.add_(W_token[d_token]) # [..., N, N, c_z]
# Add entity contribution only at masked positions to avoid full broadcasted intermediate
b_same_entity_bool = b_same_entity.bool() # [..., N, N]
p[b_same_entity_bool] += W_entity # masked in-place add
p.add_(W_chain[d_chain]) # [..., N, N, c_z]
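A small self-contained check that the suggested masked in-place add matches the original broadcasted form (toy shapes; the real tensors are [..., N, N, c_z] with batch dims):

```python
import torch

N, c_z = 8, 16
W_entity = torch.randn(c_z)
b_same_entity = (torch.rand(N, N) > 0.5).float()
p = torch.randn(N, N, c_z)

# Broadcasted form from the diff: materializes a full [N, N, c_z] intermediate.
expected = p + b_same_entity.unsqueeze(-1) * W_entity

# Reviewer's masked form: in-place, touches only rows where the mask is set.
q = p.clone()
q[b_same_entity.bool()] += W_entity  # [M, c_z] view, broadcast add, no full temp

assert torch.allclose(q, expected, atol=1e-6)
```

One caveat worth checking before adopting the suggestion: in-place ops on tensors needed for autograd can raise errors, so the masked form is safest on inference-only paths.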

