
[WIP] [tx] Add deterministic mode#1125

Open
pcmoritz wants to merge 1 commit into NovaSky-AI:main from pcmoritz:tx-determinism

Conversation

Collaborator

@pcmoritz pcmoritz commented Feb 15, 2026

First cut at #1121 so we know what is required



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a deterministic mode, primarily for the JAX backend, to ensure reproducible behavior in training and checkpointing. The changes span across the API, backend implementation, and utility functions to support this feature.

My review focuses on a couple of opportunities to reduce the code duplication introduced by this change. Specifically, the test for deterministic checkpointing contains repeated logic, and the method for generating deterministic seeds is duplicated in two places in the API. Refactoring these into helper functions would improve maintainability.

Overall, the implementation of the deterministic mode appears correct and well-tested.

Comment on lines +190 to +216
_ = training_client.forward_backward(processed_examples, "cross_entropy").result()
_ = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
sampling_path = training_client.save_weights_for_sampler(name="final").result().path
parsed = urlparse(sampling_path)
training_run_id = parsed.netloc
checkpoint_id = parsed.path.lstrip("/")

# Re-run the same train step from the same resume point and verify checkpoint bytes match.
training_client.load_state(resume_path)
_ = training_client.forward_backward(processed_examples, "cross_entropy").result()
_ = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
sampling_path_2 = training_client.save_weights_for_sampler(name="final_replayed").result().path
parsed_2 = urlparse(sampling_path_2)
training_run_id_2 = parsed_2.netloc
checkpoint_id_2 = parsed_2.path.lstrip("/")

rest_client = service_client.create_rest_client()
# Download the checkpoint
checkpoint_response = rest_client.get_checkpoint_archive_url(training_run_id, checkpoint_id).result()
with tempfile.NamedTemporaryFile() as tmp_archive:
    urllib.request.urlretrieve(checkpoint_response.url, tmp_archive.name)
    assert os.path.getsize(tmp_archive.name) > 0
with urllib.request.urlopen(checkpoint_response.url) as resp:
    checkpoint_bytes = resp.read()
assert len(checkpoint_bytes) > 0

checkpoint_response_2 = rest_client.get_checkpoint_archive_url(training_run_id_2, checkpoint_id_2).result()
with urllib.request.urlopen(checkpoint_response_2.url) as resp:
    checkpoint_bytes_2 = resp.read()
assert checkpoint_bytes == checkpoint_bytes_2


medium

There's significant code duplication here for running a training step and fetching the resulting checkpoint. The logic in lines 190-211 is repeated with minor differences in lines 198-215.

To improve readability and maintainability, consider extracting this logic into a helper function. This function could take the training_client, rest_client, and a name for the checkpoint, and return the checkpoint bytes. This would make the test's intent clearer and reduce redundancy.
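A minimal sketch of such a helper, assuming the client API shown in the diff above (the function name `run_step_and_fetch_checkpoint` and the exact parameter list are hypothetical, not existing code):

```python
import urllib.request
from urllib.parse import urlparse


def run_step_and_fetch_checkpoint(training_client, rest_client, processed_examples, adam_params, name):
    """Run one train step, save a sampler checkpoint under `name`, and return
    its raw bytes. Hypothetical helper; client methods are as used in the diff."""
    # One forward/backward pass plus an optimizer step.
    training_client.forward_backward(processed_examples, "cross_entropy").result()
    training_client.optim_step(adam_params).result()

    # Save the checkpoint and split its URI into run id and checkpoint id.
    sampling_path = training_client.save_weights_for_sampler(name=name).result().path
    parsed = urlparse(sampling_path)
    training_run_id = parsed.netloc
    checkpoint_id = parsed.path.lstrip("/")

    # Resolve the archive URL and download the checkpoint bytes.
    response = rest_client.get_checkpoint_archive_url(training_run_id, checkpoint_id).result()
    with urllib.request.urlopen(response.url) as resp:
        checkpoint_bytes = resp.read()
    assert len(checkpoint_bytes) > 0
    return checkpoint_bytes
```

With this helper, the test body shrinks to one call per checkpoint, a `load_state(resume_path)` between them, and a single equality assertion on the two returned byte strings.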

Comment on lines +623 to +627
seed = (
request.lora_config.seed
if request.lora_config.seed is not None
else (0 if deterministic else random.randint(0, 2**31 - 1))
)


medium

This logic for determining the seed based on the deterministic flag is duplicated from the SamplingParams.to_types method (lines 381-383).

To avoid code duplication and improve maintainability, consider extracting this logic into a shared helper function, for example:

def _get_seed(provided_seed: int | None, deterministic: bool) -> int:
    """Return a seed, falling back to a fixed or random seed based on deterministic mode."""
    if provided_seed is not None:
        return provided_seed
    return 0 if deterministic else random.randint(0, 2**31 - 1)

You could then call this helper function here and in SamplingParams.to_types.

    seed = _get_seed(request.lora_config.seed, deterministic)


@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 7 additional findings.


@pcmoritz pcmoritz added the tx label Feb 16, 2026