
Merge eric/dev#175

Draft
ealt wants to merge 43 commits into main from eric/dev

Conversation

Collaborator

@ealt ealt commented Mar 4, 2026

No description provided.

casperlchristensen and others added 30 commits December 9, 2025 10:16
* pylint (#98)

* add compatibility for factored states

* concrete examples and alternating process

* tweaks to vocab sizes

* update naming

* lock

* full merge, renaming

* test factored representation

* finalise gen-process PR

* update after merge

* static analysis

* static analysis tweaks

* arg name

* better test coverage

* factor input args

* ruff

* better linting

* bind i

* ellipsis to protocol

* simplify protocol

* format

* Minor fixes

* Minor fixes

* jnp.ndarray -> jax.Array

* Fix JIT compilation issue

Previous code extracted values from JAX arrays and converted them to Python ints at runtime. This fails when the function is JIT-compiled, because JAX arrays become tracers during compilation and calling int() on a tracer raises an error. The vocab_sizes parameter must be provided to __init__ for this method to work with JIT.

* Refactor generative process config tests to use a helper method for creating factored process configurations. Added parameterized tests for valid and invalid configurations, improving test coverage and maintainability.

* Add docstrings

* Add match strings to value errors in tests

* add better factor handling and allow regression to individual factors

* pass device

* static analysis

* better output format

* to_factor in validation

* update returns and concatenations

* tuple handling

* fix typehint

* improve test coverage

---------

Co-authored-by: ealt <ealt@users.noreply.github.com>
Co-authored-by: Eric Alt <ericallenalt@gmail.com>
* Enhance PyTorch training with metric tracking and update configuration

- Introduced `TrainingMetricTracker` for stateful metric tracking during PyTorch training, allowing for detailed monitoring of loss, learning rates, and parameter updates.
- Updated `train_pytorch_model` to integrate the metric tracker, enabling automatic logging of training metrics.
- Added new metrics to track cumulative and instantaneous values, including loss averages and parameter norms.
- Modified `pyproject.toml` to include `reportUnnecessaryEllipsis` setting and added `diff-cover` as a development dependency.
- Expanded the README with documentation on the new `TrainingMetricTracker` and its usage.
- Added tests for the metric tracker to ensure accurate reporting of metrics during training.
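A rough illustration of the stateful pattern described above; the class name and metric keys mirror this description, but the implementation is an assumption:

```python
class TrainingMetricTracker:
    """Hypothetical minimal tracker: cumulative tokens plus an EMA of the loss."""

    def __init__(self, ema_decay: float = 0.99):
        self.num_tokens = 0
        self.loss_ema: float | None = None
        self.ema_decay = ema_decay

    def update(self, loss: float, batch_tokens: int) -> dict[str, float]:
        # Accumulate state, then report instantaneous and cumulative values.
        self.num_tokens += batch_tokens
        if self.loss_ema is None:
            self.loss_ema = loss
        else:
            d = self.ema_decay
            self.loss_ema = d * self.loss_ema + (1 - d) * loss
        return {
            "loss": loss,
            "loss/ema": self.loss_ema,
            "tokens/raw/cumulative": self.num_tokens,
        }

tracker = TrainingMetricTracker()
m1 = tracker.update(loss=2.0, batch_tokens=128)
m2 = tracker.update(loss=1.0, batch_tokens=128)
```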

* pylint (#98)

* add compatibility for factored states

* concrete examples and alternating process

* tweaks to vocab sizes

* Create design doc

* Implement solution

* Add plotly support

* Disable too many instance attributes in configs

* Replace dict with {}

* Import altair in a normal way

* Remove init

* Reorganize altair dependency in pyproject.toml

* Fix demo imports

* Refactor metric tracker

* Update metrics

* Add current loss metric enhancements

- Introduced additional metrics for tracking loss: minimum loss, moving average (MA), and exponential moving average (EMA).
- Updated the `compute` method to return these new metrics alongside the current loss.
- Enhanced the distance from initialization metric to track the maximum distance encountered during training.

* Fix bugs with metrics tracker

* Fix loss metrics

* update naming

* Rename metric tracker

* Refactor MetricTracker and metrics initialization

- Removed initial_loss and optimal_loss parameters from MetricTracker constructor.
- Introduced metric_kwargs to pass additional parameters for metrics initialization.
- Updated the _initialize_context and _initialize_metrics methods to accommodate changes.
- Enhanced CurrentLossMetric and LossProgressMetric to use kwargs for initialization, improving flexibility.

* Refactor MetricTracker and MetricContext to unify named parameters handling

- Renamed and consolidated handling of named parameters in MetricTracker and MetricContext.
- Updated methods to use a single `named_parameters` attribute instead of separate current and previous parameters.
- Adjusted metrics computations to reflect the new structure, ensuring consistency across metrics that rely on named parameters.

* Refactor MetricTracker and MetricContext to use unified token count

- Renamed `batch_tokens` and `total_tokens` to `num_tokens` in MetricContext and MetricTracker.
- Updated metrics calculations in TokensMetric, LearningRateWeightedTokensMetric, and GradientWeightedTokensMetric to reflect the new naming convention.
- Enhanced cumulative token tracking for improved clarity and consistency.

* Refactor metrics to use update method and improve computation

- Updated the `compute` method in various metrics to remove context dependency and introduced an `update` method for state management.
- Enhanced metrics such as TokensMetric, LearningRateMetric, and GradientWeightedTokensMetric to maintain internal state for more efficient calculations.
- Added new utility functions for L2 norm calculations across collections of tensors, improving performance and clarity in metric computations.

* Refactor LossProgressMetric to separate update and compute methods

- Introduced an `update` method to manage the current loss state, enhancing clarity and separation of concerns.
- Updated the `compute` method to calculate progress based on the current loss, improving the metric's functionality.

* Update TokensMetric to rename token metrics for clarity

- Changed metric keys from "tokens/batch" and "tokens/total" to "tokens/raw" and "tokens/raw/cumulative" to better reflect their purpose and improve consistency in naming conventions.

* Clear gradients and learning rates after metric computation in GradientWeightedTokensMetric and FisherInformationMetric for improved state management.

* Refactor MetricTracker to enhance metric group handling and requirements management

- Updated MetricTracker to initialize metric groups and requirement flags more efficiently.
- Modified the update method to support group-specific requirements for learning rates, gradients, and named parameters.
- Simplified the initialization of metrics by consolidating logic and improving clarity in the code structure.
- Added `update_every_step` attribute to several metrics for better state management during updates.

* Add logging for missing update keys in MetricTracker

- Introduced logging to warn when required update keys are missing for metric groups.
- Enhanced metric group handling by adding a method to identify missing update keys based on the `update_every_step` attribute.
- Improved clarity in the metric initialization process by consolidating logic for required metrics.

* Refactor L2 norm computation in metrics.py

- Simplified the docstring for the _tensor_collection_l2_norms function to focus on its core functionality.
- Removed unnecessary casting to CPU in the _named_tensor_distance function to streamline tensor operations.

* Refactor metric computations to utilize new utility functions

- Replaced internal L2 norm and distance calculations in metrics.py with calls to the newly defined tensor_collection_l2_norm and named_tensor_distance functions from pytorch_utils.py.
- Updated docstrings for clarity and removed redundant comments to streamline the codebase.

* Refactor MetricTracker and metrics protocol for improved clarity

- Renamed the TrainingMetric protocol to Metric for better alignment with its purpose.
- Updated the MetricTracker's _initialize_metrics method to utilize the new Metric protocol, enhancing type consistency and clarity in metric initialization.

* Refactor metrics to utilize tensor_stack_l2_norm for improved efficiency

- Replaced instances of tensor_collection_l2_norm with tensor_stack_l2_norm in various metrics for optimized L2 norm calculations.
- Simplified the update and compute methods in GradientWeightedTokensMetric, CumulativeParameterUpdateMetric, and FisherInformationMetric to enhance state management and clarity.
- Removed redundant internal functions for L2 norm and distance calculations, streamlining the codebase.

* Remove metric tracker

* add activation analysis work

* reformat tests

* move protocol and inherit

* add example

* less jax conversions

* jax first

* claude feedback

* better types

* fix tests

* fix initialisation from config

* use protocol only for duck-typing

* error handling

* simplified docstrings, unused variables

* pyright protocol

* analyses tweaks

* typing

* refactor: Split `MetricTracker.update` into `step` and `update_metrics`, and optimize tensor operations in `named_tensor_distance`, gradient extraction, and parameter snapshots by removing CPU transfers and vectorizing calculations.
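The "no CPU transfers, vectorized" distance can be sketched as follows; the function name comes from the commit, but the body is an assumption:

```python
import torch

def named_tensor_distance(current: dict[str, torch.Tensor],
                          previous: dict[str, torch.Tensor]) -> torch.Tensor:
    # Per-tensor L2 norms are stacked and combined entirely on-device;
    # no .cpu() or .item() calls inside the loop.
    per_tensor = torch.stack([
        torch.linalg.vector_norm(current[name] - previous[name])
        for name in current
    ])
    return torch.linalg.vector_norm(per_tensor)

cur = {"w": torch.tensor([3.0, 0.0]), "b": torch.tensor([0.0, 4.0])}
prev = {"w": torch.zeros(2), "b": torch.zeros(2)}
dist = named_tensor_distance(cur, prev)  # sqrt(3^2 + 4^2) = 5
```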

* Add configs and metric tracker in run management

* update pr

* fix uv

* pin for transformer_lens compatibility

* add layerwise analysis classes

* use correct return class

* dataclass access notation

* fix lock

* ruff format

* pylint

* remove unused arg

* fix tests after refactor

* linter happiness

* separate responsibilities of generative processes

* revert

* add activation tracker test

* add activation tracker test

* final feedback

* proper instantiate

* no aliasing

* remove unneeded sklearn

* simplify last token

* fix tests after refactor

* remove unused dict and handle div by 0

* add tests to analysis functions

* better coverage

* make pyright happy

* add config coverage

* pull out normalization functions

* be more explicit about missing data/features

* PR feedback

* missing docstrings

* make methods public and document

* prepare options

* unnecessary conversion

* missing docstring

* formatting

* use explicit typehints

* use prepare options in tests

* change tests to JNP

* unused import

* update test coverage

* wip activation visualization

* merge

* add lock

* run with scalars

* temporary commit for merge

* static analysis checks

* update after facet-plots

* mute final pylint warnings

* fix final static analyses

* get rid of type alias

* small e2e test only

* use activation tracker in e2e

* fix yaml structure

* remove unused config

* fix end to end tests

* add more coverage

* add more tests

* refactor to more modularity

* training config

* add schedulers, proper bos handling in loop

* remove unnecessary comment

* add schedulers to e2e configs

* handle bos-token behaviour in tests

* get rid of large docs file

* add LR scheduler tests

* add LR schedulers for exact recreations of Adam's plots

* fix pyright

* only little test

* fix colour test

* make altair optional

* make altair obligatory

* simplify conversions

* consolidations

* Delete tests/end_to_end/configs/demo_config_with_visuals.yaml

* Delete tests/end_to_end/configs/demo_config_with_visuals.py

* consolidation (again)

* address PR feedback

* better typing

---------

Co-authored-by: Eric Alt <ericallenalt@gmail.com>
Co-authored-by: ealt <ealt@users.noreply.github.com>
Co-authored-by: Casper Lutzhoft Christensen <casper@g488.voltagepark.net>
* Refactor regression code to incorporate optional computation of pairwise subspace orthogonality metrics

* Refine regression API and add comprehensive orthogonality tests

- Separate coeffs/intercept in return structure (omit intercept key when
  fit_intercept=False)
- Rename to_factors → concat_belief_states for clarity
- Add 9 orthogonality tests with principled numerical thresholds
  (safety_factor=10)
- Test orthogonal, aligned, contained subspaces; multi-factor scenarios;
  edge cases
- Update validators and existing tests for new parameter structure
- Add informative assertion messages for debugging numerical precision

* Organize imports

* Fix lint issues

* Fix slices

* Simplify lr kwarg validation

* Add return type

* Add pylint ignore

* Fix potential division by zero

* Fix potential log(0) issue

* Enhance subspace orthogonality computation by adding a check for multiple belief states. Log a warning if only one belief state is present, preventing unnecessary calculations.

* Fix docstring inconsistency

* Update docstring

* Fix lint issues

* Refactor linear regression kwargs validation and improve logging. Temporarily disable pylint checks during AST traversal to avoid crashes related to package imports.

* Fix merge conflict

* Amended unseen merge conflict in linear_regression tests

* Rename to_factors parameter to concat_belief_states in activation analyses

* Update activation analysis tests for concat_belief_states semantics

* Fix validator error message and fix linting issues

* Add check requiring 2+ factors in _handle_factored_regression and remove redundant orthogonality computations warning

* Add proper spacing to warning messages

* Fix dictionary equivalence check in test_linear_regression and add blank line after docstring in test_layerwise_analysis

* Refactor subspace orthogonality computation for JIT compatibility

* Fix conditional callback execution using jax.lax.cond
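The jit-compatible conditional can be sketched with jax.lax.cond; names like do_nothing_branch come from the surrounding commits, and the branch bodies are assumptions:

```python
import jax
import jax.numpy as jnp

def maybe_compute(n_states):
    # Under jit, a Python `if` on a traced value fails; jax.lax.cond traces
    # both branches and selects one at runtime. The debug callback fires
    # only when the warning branch is taken.
    def warn_branch(_):
        jax.debug.print("only one belief state; skipping orthogonality")
        return jnp.float32(0.0)

    def do_nothing_branch(_):
        return jnp.float32(1.0)

    return jax.lax.cond(n_states <= 1, warn_branch, do_nothing_branch, None)

ok = jax.jit(maybe_compute)(jnp.int32(3))       # takes do_nothing_branch
skipped = jax.jit(maybe_compute)(jnp.int32(1))  # takes warn_branch
```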

* Fix linting and formatting issues

* Fix formatting issues

* Disable too-many-locals linting issue in test_linear_regression.py

* Change name of return dict from singular_values -> arrays for clarity

* Add docstring describing return values for _compute_all_pairwise_orthogonality function

* Add docstring describing relevance of the do_nothing_branch function

* Refactor key removal method in kwarg validator and fix docstring format

* Temporarily disable pylint checks during AST traversal in linear_regression.py to prevent crashes. Remove deprecated layer_linear_regression_svd function for cleaner code and encourage use of layer_linear_regression with use_svd=True.

* Refactor linear regression analysis registration to use partial application of layer_linear_regression with use_svd=True, removing the deprecated layer_linear_regression_svd function for improved clarity and consistency.

* Fix tests

* Add detailed docstring to _compute_subspace_orthogonality function, specifying return values and their meanings for improved clarity and documentation.

* Add todo

* Fix kwarg validation

* Fix tests

* Add validator decorator for linear_regression_svd to enforce use_svd=True and exclude it from output. Enhance tests to validate behavior.

* Fix test

* Add get_robust_basis for robust orthonormal basis extraction

* Pass pair of bases instead of coefficient matrices to _compute_subspace_orthogonality

* Compute full rank and orthonormal basis of coeff matrices before passing bases to subspace analysis
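Under the usual principal-angles formulation, this pipeline can be sketched as follows (get_robust_basis is named in the commit; the tolerance rule and shapes are assumptions):

```python
import jax.numpy as jnp

def get_robust_basis(coeffs, rtol=1e-6):
    # SVD-based orthonormal basis of the column space, keeping only
    # directions whose singular values are non-negligible (assumed rule).
    u, s, _ = jnp.linalg.svd(coeffs, full_matrices=False)
    rank = int(jnp.sum(s > rtol * s[0]))
    return u[:, :rank]

def subspace_orthogonality(basis_a, basis_b):
    # Singular values of Qa^T Qb are the cosines of the principal angles:
    # near 0 for orthogonal subspaces, near 1 for aligned/contained ones.
    return jnp.linalg.svd(basis_a.T @ basis_b, compute_uv=False)

qa = get_robust_basis(jnp.array([[1.0], [0.0], [0.0]]))
qb = get_robust_basis(jnp.array([[0.0], [1.0], [0.0]]))
qc = get_robust_basis(jnp.array([[2.0], [0.0], [0.0]]))  # same span as qa
```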

* Fix formatting and docstring

* Update comment

* Fix issues due to API changes in activation and dataframe tests

* Fix formatting issues

---------

Co-authored-by: Eric Alt <ericallenalt@gmail.com>
…nalysis and LinearRegressionSVDAnalysis (#140)

* Enhance PyTorch training with metric tracking and update configuration

- Introduced `TrainingMetricTracker` for stateful metric tracking during PyTorch training, allowing for detailed monitoring of loss, learning rates, and parameter updates.
- Updated `train_pytorch_model` to integrate the metric tracker, enabling automatic logging of training metrics.
- Added new metrics to track cumulative and instantaneous values, including loss averages and parameter norms.
- Modified `pyproject.toml` to include `reportUnnecessaryEllipsis` setting and added `diff-cover` as a development dependency.
- Expanded the README with documentation on the new `TrainingMetricTracker` and its usage.
- Added tests for the metric tracker to ensure accurate reporting of metrics during training.

* pylint (#98)

* add compatibility for factored states

* concrete examples and alternating process

* tweaks to vocab sizes

* Refactor metric tracker

* Update metrics

* Add current loss metric enhancements

- Introduced additional metrics for tracking loss: minimum loss, moving average (MA), and exponential moving average (EMA).
- Updated the `compute` method to return these new metrics alongside the current loss.
- Enhanced the distance from initialization metric to track the maximum distance encountered during training.

* Fix bugs with metrics tracker

* Fix loss metrics

* update naming

* Rename metric tracker

* Refactor MetricTracker and metrics initialization

- Removed initial_loss and optimal_loss parameters from MetricTracker constructor.
- Introduced metric_kwargs to pass additional parameters for metrics initialization.
- Updated the _initialize_context and _initialize_metrics methods to accommodate changes.
- Enhanced CurrentLossMetric and LossProgressMetric to use kwargs for initialization, improving flexibility.

* Refactor MetricTracker and MetricContext to unify named parameters handling

- Renamed and consolidated handling of named parameters in MetricTracker and MetricContext.
- Updated methods to use a single `named_parameters` attribute instead of separate current and previous parameters.
- Adjusted metrics computations to reflect the new structure, ensuring consistency across metrics that rely on named parameters.

* Refactor MetricTracker and MetricContext to use unified token count

- Renamed `batch_tokens` and `total_tokens` to `num_tokens` in MetricContext and MetricTracker.
- Updated metrics calculations in TokensMetric, LearningRateWeightedTokensMetric, and GradientWeightedTokensMetric to reflect the new naming convention.
- Enhanced cumulative token tracking for improved clarity and consistency.

* Refactor metrics to use update method and improve computation

- Updated the `compute` method in various metrics to remove context dependency and introduced an `update` method for state management.
- Enhanced metrics such as TokensMetric, LearningRateMetric, and GradientWeightedTokensMetric to maintain internal state for more efficient calculations.
- Added new utility functions for L2 norm calculations across collections of tensors, improving performance and clarity in metric computations.

* Refactor LossProgressMetric to separate update and compute methods

- Introduced an `update` method to manage the current loss state, enhancing clarity and separation of concerns.
- Updated the `compute` method to calculate progress based on the current loss, improving the metric's functionality.

* Update TokensMetric to rename token metrics for clarity

- Changed metric keys from "tokens/batch" and "tokens/total" to "tokens/raw" and "tokens/raw/cumulative" to better reflect their purpose and improve consistency in naming conventions.

* Clear gradients and learning rates after metric computation in GradientWeightedTokensMetric and FisherInformationMetric for improved state management.

* Refactor MetricTracker to enhance metric group handling and requirements management

- Updated MetricTracker to initialize metric groups and requirement flags more efficiently.
- Modified the update method to support group-specific requirements for learning rates, gradients, and named parameters.
- Simplified the initialization of metrics by consolidating logic and improving clarity in the code structure.
- Added `update_every_step` attribute to several metrics for better state management during updates.

* Add logging for missing update keys in MetricTracker

- Introduced logging to warn when required update keys are missing for metric groups.
- Enhanced metric group handling by adding a method to identify missing update keys based on the `update_every_step` attribute.
- Improved clarity in the metric initialization process by consolidating logic for required metrics.

* Refactor L2 norm computation in metrics.py

- Simplified the docstring for the _tensor_collection_l2_norms function to focus on its core functionality.
- Removed unnecessary casting to CPU in the _named_tensor_distance function to streamline tensor operations.

* Refactor metric computations to utilize new utility functions

- Replaced internal L2 norm and distance calculations in metrics.py with calls to the newly defined tensor_collection_l2_norm and named_tensor_distance functions from pytorch_utils.py.
- Updated docstrings for clarity and removed redundant comments to streamline the codebase.

* Refactor MetricTracker and metrics protocol for improved clarity

- Renamed the TrainingMetric protocol to Metric for better alignment with its purpose.
- Updated the MetricTracker's _initialize_metrics method to utilize the new Metric protocol, enhancing type consistency and clarity in metric initialization.

* Refactor metrics to utilize tensor_stack_l2_norm for improved efficiency

- Replaced instances of tensor_collection_l2_norm with tensor_stack_l2_norm in various metrics for optimized L2 norm calculations.
- Simplified the update and compute methods in GradientWeightedTokensMetric, CumulativeParameterUpdateMetric, and FisherInformationMetric to enhance state management and clarity.
- Removed redundant internal functions for L2 norm and distance calculations, streamlining the codebase.

* Remove metric tracker

* refactor: Split `MetricTracker.update` into `step` and `update_metrics`, and optimize tensor operations in `named_tensor_distance`, gradient extraction, and parameter snapshots by removing CPU transfers and vectorizing calculations.

* Add configs and metric tracker in run management

* Simplify

* Refactor metrics and tracker

* Rename step group

* Renames

* Update metric tracker config validation

* Make metric tracker context non-private

* Get initial loss from context

* Add metric tracker to e2e test

* Remove example

* Fix config name

* Change dict to mapping to handle DictConfig

* Fix bug in updating lr

* Remove unused return value, simplify method call

* Refactor metric naming conventions for consistency and clarity. Update metric keys to include context and step information, and rename CurrentLossMetric to LossMetric for better understanding.

* Add loss progress to LossMetric

* Refactor requirements formatting in metrics for improved readability and consistency

* Enhance ParameterNormMetric to compute both parameter and weight norms, consolidating metrics into a single return statement. Remove WeightNormMetric class as its functionality is now integrated.

* Rename keys, merge fisher proxy into grad weighted tokens

* Update names

* Enhance MetricTracker and LossMetric to support custom step values, improving flexibility in metric tracking and loss computation.

* Remove step from context

* Add eval metric tracker to training

* Remove weights norm

* Check if metric names is a list config

* add instance to metric tracker keys

* Disable databricks.sdk info logs

* Configure devices to be the same

* Rename experiment/run names

* Add tokens per second metrics

* Detach loss before converting to float

* Create full training configs

* Update uv.lock

* ruff format

* Avoid div by zero

* lock

* full merge, renaming

* Fix training test

* test factored representation

* Fix device mismatch

* Device mismatch pt 2

* finalise gen-process PR

* update after merge

* static analysis

* static analysis tweaks

* arg name

* better test coverage

* factor input args

* ruff

* better linting

* bind i

* ellipsis to protocol

* simplify protocol

* format

* hack to get training working again

* Simplify components key

* Change metrics returns

* Update optimizer handling to log warnings for multiple optimizers and return None instead of the first optimizer.

* Create tests for requirements

* learning rates metric test

* Tokens metric test

* lr weighted tokens test

* gradient weighted tokens test

* parameter update test

* Have loss progress approach zero instead of one

* loss metric test

* param norm test

* parameter distance test

* uv sync

* Test pytorch utils

* Create metric groups property

* Create metric tracker tests

* add xavier's leaky RRXOR (#130)

* Update workflows to support dev branch ruleset standards

* Update GitHub workflows to correctly reference pull request base branches in conditions

* feat: Add `compute_subspace_orthogonality` option to `LinearRegressionAnalysis` and `LinearRegressionSVDAnalysis` to expose subspace metrics, along with corresponding tests.

---------

Co-authored-by: Casper Lutzhoft Christensen <clu@corti.ai>
Co-authored-by: Casper Lützhøft Christensen <61698286+casperlchristensen@users.noreply.github.com>
* fix slider rendering

* fix reference

* update tests post bug-fix

* static analysis
* Add simplexity-multirun CLI for parallel experiment execution

Add a new CLI tool for running multiple Hydra experiments in parallel
across GPUs or CPU workers with proper device isolation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix pylint and ruff linting issues

- Add pylint disable comments for too-many-arguments, too-many-locals, etc.
- Initialize variables before conditional to fix possibly-used-before-assignment
- Use raw docstring (r""") for backslash escapes
- Add strict=True to zip() call

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Refactor run_parallel to separate job generation from dispatch

- Add Job dataclass with to_cmd() method for rendering commands
- Extract generate_jobs() as a pure function for testability
- Extract dispatch_jobs() to encapsulate ProcessPoolExecutor logic
- Simplify main() to two-phase structure: generate then dispatch
- Dry-run now exits before dispatch instead of passing through executor
- Add tests for Job and generate_jobs() (GPU round-robin, sweep expansion)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add missing docstrings to test methods

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: adamimos <adam@g093.voltagepark.net>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
* save more path-specific visualizations

* update test path
* generic d_vocab resolution

* rename and test

* static analysis
* Abbreviate linear regression scalar metric names

* Add method to make layer names more compact and to construct layer-specific metric key names

* Integrate layer formatting methods into LayerwiseAnalysis

* Update visualization key lookup to use new {analysis}/{layer} format

Change projection and scalar key resolution to use the new naming convention
where keys follow {analysis}/{layer_spec} format (e.g., "pca/L0.resid.pre")
instead of the old {layer}_{analysis} format (e.g., "layer_0_pca").

Key changes:
- Update _lookup_projection_array and _lookup_scalar_value to match keys
  by prefix (analysis/) rather than suffix (_analysis)
- Add _key_matches_layer helper to handle factor-suffixed keys like
  "projected/layer_0-F0" when given pattern "projected/F0"
- Update _expand_projection_key_pattern to extract factor suffixes from
  new format and reconstruct pattern-matchable keys
- Update _expand_scalar_pattern_keys to properly handle analysis prefix
  for patterns with internal slashes

Update all test files to use new key format in mock data.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Format layer names in visualization lookups and update test key assertions

- Add format_layer_spec to field_resolution.py for converting layer names
  (e.g., blocks.0.hook_resid_pre → L0.resid.pre) before key matching
- Update dataframe_builders.py to format layer names in scalar series
  inference and DataFrame construction
- Update test_linear_regression.py assertions to new key format:
  - factor_X/metric → metric/FX
  - orthogonality_X_Y/metric → orth/metric_short/FX,Y
  - concat/metric → metric/Fcat
- Update test_layerwise_analysis.py assertions to new key format:
  - layer_metric → metric/layer
- Update with_visuals.yaml config templates to match new format
- Update test_activation_tracker_config.py key assertion

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Abbreviate PCA metric names for consistency

- variance_explained → var_exp
- n_components_{pct}pct → nc_{pct}
- cumvar_{idx} unchanged

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix formatting and linting

* Format layer names in pattern expansion for projection key matching

The pattern expansion logic was using unformatted layer names (e.g.,
'blocks.0.hook_resid_pre') to match against projection keys that have
formatted layer names (e.g., 'projected/L0.resid.pre'). This caused
pattern matching to fail when expanding projection key patterns like
'projected/F*' for non-concatenated layers.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Eric/improve-metric-naming-for-length-and-readability (#156)

* simplify metric_keys.py

* Update field resolution

* Remove test

* Simplify pattern expansion

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: ealt <ealt@users.noreply.github.com>
* add xavier's leaky RRXOR (#130)

* reduce number of metrics returned from variance analysis

* rename

* Update simplexity/activations/visualization/pattern_expansion.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* abbreviate

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…#167)

Add support for:
- Top-level hooks (hook_embed → embed)
- Block component hooks (blocks.N.{comp}.hook_X → LN.{comp}.X)
- ln_final hooks (ln_final.hook_X → ln_final.X)
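The three rules above can be sketched as a small mapping function; the name format_layer_spec appears in an earlier commit, while the regexes here are assumptions:

```python
import re

def format_layer_spec(name: str) -> str:
    # Hypothetical sketch of the abbreviation rules listed above.
    if re.fullmatch(r"hook_\w+", name):  # top-level hooks: hook_embed -> embed
        return name[len("hook_"):]
    m = re.fullmatch(r"blocks\.(\d+)\.(\w+)\.hook_(\w+)", name)
    if m:  # block component hooks: blocks.N.{comp}.hook_X -> LN.{comp}.X
        return f"L{m.group(1)}.{m.group(2)}.{m.group(3)}"
    m = re.fullmatch(r"blocks\.(\d+)\.hook_(\w+)", name)
    if m:  # direct block hooks: blocks.0.hook_resid_pre -> L0.resid.pre
        return f"L{m.group(1)}.{m.group(2).replace('_', '.')}"
    m = re.fullmatch(r"ln_final\.hook_(\w+)", name)
    if m:  # ln_final hooks: ln_final.hook_X -> ln_final.X
        return f"ln_final.{m.group(1)}"
    return name  # unrecognized names pass through unchanged
```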

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>

* Add IndependentFactoredGenerativeProcess for frozen factor support

Introduces a new generative process subclass that samples emissions from
each factor independently and supports "frozen" factors whose sequences
are identical across batch samples. This enables generating datasets where
k factors share realizations while (n-k) factors vary independently.

Key features:
- Per-factor independent emission sampling (not from joint distribution)
- Frozen factors specified via frozen_factor_indices and frozen_key
- Dual key stream approach: frozen factors use shared key, unfrozen use per-sample keys
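The dual key stream idea can be illustrated with a toy sampler; the names and shapes here are assumptions, not the class's actual API:

```python
import jax
import jax.numpy as jnp

def sample_factors(key, frozen_key, frozen_idx, n_factors, batch, vocab=4):
    # Frozen factors draw from the single shared frozen_key, so their
    # realization is identical across the batch; the remaining factors
    # draw from the per-run key and vary across samples.
    samples = []
    for f in range(n_factors):
        if f in frozen_idx:
            fk = jax.random.fold_in(frozen_key, f)
            row = jax.random.randint(fk, (1,), 0, vocab)
            samples.append(jnp.broadcast_to(row, (batch,)))
        else:
            fk = jax.random.fold_in(key, f)
            samples.append(jax.random.randint(fk, (batch,), 0, vocab))
    return jnp.stack(samples, axis=1)  # (batch, n_factors)

out = sample_factors(jax.random.PRNGKey(0), jax.random.PRNGKey(1),
                     frozen_idx={0}, n_factors=2, batch=8)
```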

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Simplify IndependentFactoredGenerativeProcess implementation

- Remove _generate_with_frozen method, merge logic into single generate method
- Move None handling to edges (emit_observation and generate) so
  _emit_observation_per_factor always receives valid arrays
- Remove unnecessary super().generate() delegation
- Cleaner code structure with fewer methods and one code path

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix formatting

* Fix pylint warnings in IndependentFactoredGenerativeProcess tests

Replace unnecessary lambdas with direct method references and add
pylint disable for too-few-public-methods on TestStateTransitions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
* return targets (#163)

* Apply Eric's review suggestions

- Use tuple(map(int, vocab_sizes)) in factored_generative_process.py
- Use math.prod(vocab_sizes) in noisy_channel.py for cleaner code

Co-authored-by: Casper Lützhøft Christensen <casperlchristensen@users.noreply.github.com>

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Casper Lützhøft Christensen <casperlchristensen@users.noreply.github.com>
casperlchristensen and others added 13 commits February 25, 2026 01:16
  - clarify FullyConditional semantics as a product-of-conditionals approximation
  - add constructor validation for control-map count/shape, vocab sizes, and fallback params
  - compute PoC normalization in log space for better numerical stability
  - add configurable zero-mass fallback (uniform or epsilon_smooth)
  - extract shared other-factor indexing helpers and reuse in ConditionalTransitions
  - expand fully-conditional tests for zero-mass fallback, control-map validation, and distortion behavior
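Log-space normalization of a product of conditionals (PoC) is typically done with logsumexp; a generic sketch under that assumption:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def normalize_product_of_conditionals(log_probs):
    # Sum per-factor log-conditionals, then subtract logsumexp to
    # normalize: this avoids the underflow that multiplying many small
    # probabilities directly would cause.
    joint_log = jnp.sum(log_probs, axis=0)   # (vocab,) unnormalized log mass
    return joint_log - logsumexp(joint_log)  # normalized log distribution

lp = jnp.log(jnp.array([[0.5, 0.5], [0.9, 0.1]]))  # two factors, two symbols
probs = jnp.exp(normalize_product_of_conditionals(lp))
```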

4 participants