Skip to content

Conversation

@ayulockedin
Copy link
Contributor

@ayulockedin ayulockedin commented Jan 6, 2026

What does this PR do?

Description

This PR standardizes the documentation for the rngs argument across the flax.nnx.nn layers.

In the transition from Linen to NNX, rngs has changed from being a single JAX PRNGKey (array) to an nnx.Rngs container object. Previous docstrings incorrectly referred to it as an "rng key," which could confuse users trying to pass a raw JAX key.

Changes

  • Updated docstrings to refer to rngs as an "rngs object" instead of an "rng key".
  • Fixed inconsistency in the following modules:
    • flax.nnx.nn.linear
    • flax.nnx.nn.attention
    • flax.nnx.nn.normalization
    • flax.nnx.nn.lora
    • flax.nnx.nn.recurrent
    • flax.nnx.nn.stochastic

Related Issues

Fixes # (issue)

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@ayulockedin
Copy link
Contributor Author

Hi @cgarciae , thanks for merging my other PR today!

Since this one is just a documentation cleanup (standardizing the rngs docstrings for the Linen-to-NNX transition), I was wondering if you might have a quick moment to slip this one in as well? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant