Code Review
This pull request introduces a deterministic mode, primarily for the JAX backend, to ensure reproducible behavior in training and checkpointing. The changes span across the API, backend implementation, and utility functions to support this feature.
My review focuses on a couple of opportunities to reduce newly introduced code duplication. Specifically, the test for deterministic checkpointing repeats its step-and-checkpoint logic, and the method for generating deterministic seeds is duplicated in two places in the API. Refactoring these into helper functions would improve the code's maintainability.
Overall, the implementation of the deterministic mode appears correct and well-tested.
```python
_ = training_client.forward_backward(processed_examples, "cross_entropy").result()
_ = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
sampling_path = training_client.save_weights_for_sampler(name="final").result().path
parsed = urlparse(sampling_path)
training_run_id = parsed.netloc
checkpoint_id = parsed.path.lstrip("/")

# Re-run the same train step from the same resume point and verify checkpoint bytes match.
training_client.load_state(resume_path)
_ = training_client.forward_backward(processed_examples, "cross_entropy").result()
_ = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
sampling_path_2 = training_client.save_weights_for_sampler(name="final_replayed").result().path
parsed_2 = urlparse(sampling_path_2)
training_run_id_2 = parsed_2.netloc
checkpoint_id_2 = parsed_2.path.lstrip("/")

rest_client = service_client.create_rest_client()
# Download the checkpoint
checkpoint_response = rest_client.get_checkpoint_archive_url(training_run_id, checkpoint_id).result()
with tempfile.NamedTemporaryFile() as tmp_archive:
    urllib.request.urlretrieve(checkpoint_response.url, tmp_archive.name)
    assert os.path.getsize(tmp_archive.name) > 0
with urllib.request.urlopen(checkpoint_response.url) as resp:
    checkpoint_bytes = resp.read()
assert len(checkpoint_bytes) > 0

checkpoint_response_2 = rest_client.get_checkpoint_archive_url(training_run_id_2, checkpoint_id_2).result()
with urllib.request.urlopen(checkpoint_response_2.url) as resp:
    checkpoint_bytes_2 = resp.read()
assert checkpoint_bytes == checkpoint_bytes_2
```
There's significant code duplication here for running a training step and fetching the resulting checkpoint. The logic in lines 190-211 is repeated with minor differences in lines 198-215.
To improve readability and maintainability, consider extracting this logic into a helper function. This function could take the training_client, rest_client, and a name for the checkpoint, and return the checkpoint bytes. This would make the test's intent clearer and reduce redundancy.
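One possible shape for such a helper (a sketch only — the function name, parameters, and the injectable `opener` are assumptions, not code from this PR; the client method names are taken from the test quoted above):

```python
import urllib.request
from urllib.parse import urlparse

def run_step_and_fetch_checkpoint(
    training_client, rest_client, processed_examples, name, adam_params,
    opener=urllib.request.urlopen,
):
    """Run one train step, save a sampler checkpoint, and return its bytes.

    `opener` defaults to urllib.request.urlopen but is injectable so the
    helper can be exercised without a live service.
    """
    training_client.forward_backward(processed_examples, "cross_entropy").result()
    training_client.optim_step(adam_params).result()
    # The saved path encodes the run id (netloc) and checkpoint id (path).
    sampling_path = training_client.save_weights_for_sampler(name=name).result().path
    parsed = urlparse(sampling_path)
    response = rest_client.get_checkpoint_archive_url(
        parsed.netloc, parsed.path.lstrip("/")
    ).result()
    with opener(response.url) as resp:
        return resp.read()
```

The test would then call this twice (once per resume) and compare the two returned byte strings directly.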
```python
seed = (
    request.lora_config.seed
    if request.lora_config.seed is not None
    else (0 if deterministic else random.randint(0, 2**31 - 1))
)
```
This logic for determining the seed based on the deterministic flag is duplicated from the SamplingParams.to_types method (lines 381-383).
To avoid code duplication and improve maintainability, consider extracting this logic into a shared helper function, for example:

```python
def _get_seed(provided_seed: int | None, deterministic: bool) -> int:
    """Return a seed, falling back to a fixed or random seed based on deterministic mode."""
    if provided_seed is not None:
        return provided_seed
    return 0 if deterministic else random.randint(0, 2**31 - 1)
```

You could then call this helper function here and in SamplingParams.to_types:

```python
seed = _get_seed(request.lora_config.seed, deterministic)
```
First cut at #1121 so we know what is required