merge#1
Open
masterkni6 wants to merge 578 commits into
Open
Conversation
Replace _eval_jit_cache global dict and _make_eval_jit factory with a single @functools.partial(jax.jit, static_argnums=(0, 1)) function, letting JAX cache per (graphdef, loss_fn) pair natively.
…l cache" This reverts commit 76e3d9c.
This reverts commit ac23c8b.
…al cache" This reverts commit b8da2dd.
Replace static_argnums=(0,1) approach (which requires graphdef to be hashable) with an id-keyed cache of JIT closures that capture graphdef and loss_fn, matching the working ea35598 approach.
The per-instance _eval_jit field was a latent bug — it ignored graphdef on subsequent calls if the instance was reused with a different graphdef. Since _make_eval_jit already deduplicates by (graphdef, loss_fn) id at module level, the field is redundant. Also tighten Any to a typed alias.
Allows freezing selected weights during training by wrapping the gradient transformation with optax.masked, so frozen weights receive no gradient updates.
- describe-training: --weight_paths lists all model weight paths in slash-separated format, sorted numerically - migrate-checkpoint: --dump_source_paths / --dump_destination_paths print paths as proto rule stubs and exit early
…ozen params optax.masked(tx, mask) does not zero updates for False-masked params — it passes them through unchanged. This meant every "frozen" param got param += raw_gradient each step, causing the activation explosions seen during block-14 reset training. Fix: chain the masked optimizer with optax.set_to_zero() so frozen params receive exactly zero updates.
is_leaf stopped at any NamedTuple (MaskedState, etc.), preventing traversal to inner ScaleByAdamState/ScaleByScheduleState nodes. Now is_leaf only stops at known count-bearing types. Also adds a post-update assertion that no state with a 'count' field was missed, so new wrapper types fail loudly instead of silently leaving step counters at 0.
tuple.count() is a builtin method on all tuples, so hasattr(x, "count") matched MaskedState and other NamedTuples. Check _fields instead to only match NamedTuples that have count as an actual data field.
Stop at any NamedTuple with a 'count' field via is_leaf, assert it's a known type, and replace. Regular array leaves pass through unchanged. Replaces the broken second-traversal approach.
Both the train command and daemon pipeline were missing the training_steps field when constructing LeelaExportOptions.
This reverts commit beb94e2.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.