Skip to content

merge#1

Open
masterkni6 wants to merge 578 commits into
masterkni6:masterfrom
LeelaChessZero:master
Open

merge#1
masterkni6 wants to merge 578 commits into
masterkni6:masterfrom
LeelaChessZero:master

Conversation

@masterkni6

Copy link
Copy Markdown
Owner

No description provided.

mooskagh and others added 30 commits March 2, 2026 22:08
Replace _eval_jit_cache global dict and _make_eval_jit factory with a
single @functools.partial(jax.jit, static_argnums=(0, 1)) function,
letting JAX cache per (graphdef, loss_fn) pair natively.
Replace static_argnums=(0,1) approach (which requires graphdef to be
hashable) with an id-keyed cache of JIT closures that capture graphdef
and loss_fn, matching the working ea35598 approach.
The per-instance _eval_jit field was a latent bug — it ignored graphdef
on subsequent calls if the instance was reused with a different graphdef.
Since _make_eval_jit already deduplicates by (graphdef, loss_fn) id at
module level, the field is redundant. Also tighten Any to a typed alias.
Allows freezing selected weights during training by wrapping the
gradient transformation with optax.masked, so frozen weights receive
no gradient updates.
- describe-training: --weight_paths lists all model weight paths in
  slash-separated format, sorted numerically
- migrate-checkpoint: --dump_source_paths / --dump_destination_paths
  print paths as proto rule stubs and exit early
…ozen params

optax.masked(tx, mask) does not zero updates for False-masked params —
it passes them through unchanged. This meant every "frozen" param got
param += raw_gradient each step, causing the activation explosions seen
during block-14 reset training.

Fix: chain the masked optimizer with optax.set_to_zero() so frozen
params receive exactly zero updates.
is_leaf stopped at any NamedTuple (MaskedState, etc.), preventing
traversal to inner ScaleByAdamState/ScaleByScheduleState nodes.
Now is_leaf only stops at known count-bearing types.

Also adds a post-update assertion that no state with a 'count' field
was missed, so new wrapper types fail loudly instead of silently
leaving step counters at 0.
tuple.count() is a builtin method on all tuples, so hasattr(x, "count")
matched MaskedState and other NamedTuples. Check _fields instead to
only match NamedTuples that have count as an actual data field.
Stop at any NamedTuple with a 'count' field via is_leaf, assert it's a
known type, and replace. Regular array leaves pass through unchanged.
Replaces the broken second-traversal approach.
Both the train command and daemon pipeline were missing the
training_steps field when constructing LeelaExportOptions.
This reverts commit beb94e2.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants