Givens orthogonal layer #57

Open

VolodyaCO wants to merge 1 commit into dwavesystems:main from VolodyaCO:givens-rotation

Conversation

@VolodyaCO (Collaborator)

This PR adds an orthogonal layer parameterized by Givens rotations, using the parallel algorithm described by Firas in https://arxiv.org/abs/2106.00003, which gives O(n) forward complexity and O(n log n) backward complexity, even though there are O(n^2) rotations.

This PR is still in draft. I wrote it for even n. Some more unit tests are probably needed, but I am quite lazy (I will add them after all the math is checked for odd n).
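For intuition, here is a minimal sketch (an illustration of the block structure described above, not the PR's code): the n(n - 1)/2 rotations are grouped into n - 1 blocks of n/2 rotations acting on disjoint index pairs, so each block can be applied in one vectorized step.

import torch

def apply_block(u: torch.Tensor, pairs: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Apply one block of n // 2 disjoint Givens rotations to the columns of u.

    pairs:  (n // 2, 2) integer tensor of disjoint index pairs (hypothetical layout).
    angles: (n // 2,) rotation angles, one per pair.
    """
    c, s = torch.cos(angles), torch.sin(angles)
    i_idx, j_idx = pairs[:, 0], pairs[:, 1]
    out = u.clone()
    # Because the pairs are disjoint, the n // 2 rotations commute and the
    # whole block is applied in a single vectorized update.
    out[:, i_idx] = c * u[:, i_idx] + s * u[:, j_idx]
    out[:, j_idx] = -s * u[:, i_idx] + c * u[:, j_idx]
    return out

Composing the n - 1 blocks then builds the full orthogonal matrix.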

@VolodyaCO VolodyaCO requested a review from kevinchern December 19, 2025 00:01
@VolodyaCO VolodyaCO self-assigned this Dec 19, 2025
@VolodyaCO VolodyaCO added the enhancement New feature or request label Dec 19, 2025
@VolodyaCO (Collaborator, Author)

I somehow broke @kevinchern's tests, what the hell...

Comment thread tests/helper_models.py Outdated
Comment thread tests/test_nn.py
def test_store_config(self):
    with self.subTest("Simple case"):

        class MyModel(torch.nn.Module):
@kevinchern (Collaborator) Dec 19, 2025

Remove formatting changes. Is this "black" formatting?

Collaborator (Author)

Yes. I have it enabled by default in my VS Code.

Collaborator

(bump)

Collaborator (Author)

Done

@kevinchern (Collaborator)

kevinchern commented Dec 19, 2025

> I somehow broke @kevinchern's tests, what the hell...

@VolodyaCO which tests? I'm seeing test_forward_agreement and test_backward_agreement failures on this CI test

@VolodyaCO (Collaborator, Author)

> I somehow broke @kevinchern's tests, what the hell...
>
> @VolodyaCO which tests? I'm seeing test_forward_agreement and test_backward_agreement failures on this CI test

I forgot to update my tests to float64 precision. Now that I've done it, it's weird that all of the current failing tests are failing on

  File "/Users/distiller/project/tests/test_nn.py", line 144, in test_LinearBlock
    self.assertTrue(model_probably_good(model, (din,), (dout,)))

@kevinchern (Collaborator)

> I somehow broke @kevinchern's tests, what the hell...
>
> @VolodyaCO which tests? I'm seeing test_forward_agreement and test_backward_agreement failures on this CI test
>
> I forgot to update my tests to float64 precision. Now that I've done it, it's weird that all of the current failing tests are failing on
>
>   File "/Users/distiller/project/tests/test_nn.py", line 144, in test_LinearBlock
>     self.assertTrue(model_probably_good(model, (din,), (dout,)))

Ahhhhhh. OK, Theo also flagged this at #50. It's a poorly-written test... you can ignore it.

@VolodyaCO VolodyaCO marked this pull request as ready for review December 26, 2025 17:28
Returns:
    list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.

Note:
Contributor

Collaborator (Author)

Where should I put this? In the release notes, or in the docstring itself?

Contributor

Simply change the Note: to

.. note::

    Lorem ipsum...

which would render a note box if we generate docs with Sphinx.

Collaborator (Author)

I have done this now

Comment on lines +81 to +85
angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations
    between pairs of dimensions.
blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
    specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
    pairs of independent rotations.
Contributor

Code formatting?

Suggested change
angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations
    between pairs of dimensions.
blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
    specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
    pairs of independent rotations.
angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all rotations
    between pairs of dimensions.
blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices that
    specify rotations between pairs of dimensions. Each of the n-1 blocks contains n // 2
    pairs of independent rotations.

Collaborator (Author)

Done, thanks.

Comment thread dwave/plugins/torch/nn/modules/orthogonal.py
Comment on lines +122 to +127
angles, blocks, Ufwd_saved = ctx.saved_tensors
Ufwd = Ufwd_saved.clone()
M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n)
n = M.size(1)
block_size = n // 2
A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
Contributor

Same here re lowercase for Ufwd, M, and A. Avoids incorrect colour highlighting in themes.

Collaborator (Author)

Hmmm, I didn't read this about the incorrect colour highlighting before I made my previous comment. I still think that it is easier to read the algorithm alongside the code if the use of lower/upper case matches. For example, lowercase m is usually used for an integer variable, not a tensor.

return U

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
Contributor

Missing return type hint.

Collaborator (Author)

I added the type hint as well as a longer explanation of what this return is.

U = self._create_rotation_matrix()
rotated_x = einsum(x, U, "... i, o i -> ... o")
if self.bias is not None:
rotated_x = rotated_x + self.bias
Contributor

Suggested change
rotated_x = rotated_x + self.bias
rotated_x += self.bias

Collaborator (Author)

Done.

Comment thread tests/helper_models.py Outdated
from einops import einsum


class NaiveGivensRotationLayer(nn.Module):
Contributor

I'm not very keen on having a full-on separate implementation here just to compare with/test the GivensRotationLayer. If this NaiveGivensRotationLayer is useful, should it be part of the package instead?

Collaborator (Author)

We discussed this in our one-on-one but, just for the record, there is no difference between NaiveGivensRotationLayer and GivensRotationLayer in the forward or backward passes. The naïve implementation is there to make sure that the forward and backward passes indeed match. GivensRotationLayer should always be used because it has substantially better runtime complexity; the naïve implementation is only useful as a sanity check.
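For the record, a naive reference along the lines described above can be sketched as follows (hypothetical names, not the exact helper from this PR): build U as an explicit product of all n(n - 1)/2 Givens matrices, one at a time, and let autograd differentiate through the product.

import itertools
import torch
from torch import nn

class NaiveGivens(nn.Module):
    """Sequential O(n^2)-rotation reference used only as a sanity check."""

    def __init__(self, n: int):
        super().__init__()
        self.n = n
        self.pairs = list(itertools.combinations(range(n), 2))  # all n * (n - 1) // 2 pairs
        self.angles = nn.Parameter(torch.randn(len(self.pairs)))

    def rotation_matrix(self) -> torch.Tensor:
        u = torch.eye(self.n, dtype=self.angles.dtype)
        for (i, j), theta in zip(self.pairs, self.angles):
            g = torch.eye(self.n, dtype=self.angles.dtype)
            c, s = torch.cos(theta), torch.sin(theta)
            g[i, i], g[j, j] = c, c
            g[i, j], g[j, i] = s, -s
            u = g @ u  # one rotation at a time; autograd tracks the whole product
        return u

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.rotation_matrix().t()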

Collaborator

I think this class should go directly into the test file instead of creating a helper_models.py module. The naive module is only ever used in these tests.

Collaborator (Author)

I addressed this by moving the class to the test file.

Comment thread tests/test_nn.py Outdated
Comment on lines +91 to +92
@parameterized.expand([(n, bias) for n in [4, 5, 6, 9, 10] for bias in [True, False]])
def test_forward_agreement(self, n, bias):
Contributor

These tests do seem a bit too complex. Better to try and test more minimal aspects of the class, if possible. I'd much rather have separate integration-like tests that assert that the model behaves as expected, while having these be strictly small-scale, isolated unit tests.

Collaborator (Author)

I added some tests for invalid inputs too. These forward and backward tests check that the correct input/output is produced when compared to the naïve implementation. The model_probably_good test is done as a unit test.

Collaborator (Author)

I added other unit tests where I test incorrect inputs as well. In ML models, the forward and backward passes should be what one expects them to be, and this module gives the opportunity to test this directly. I do agree that we should separate out other tests that (at least) I wrote, which have to do with training a model to see whether the intended final trained state is reached. However, the tests I present in this PR are not the result of training but explicit comparisons with the naïve approach; I don't know if we could regard those as integration tests.
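For illustration, an agreement test of this kind can be as small as the following sketch, assuming hypothetical layers fast (the parallel implementation) and naive (the reference) that share the same angle ordering; float64 keeps the tolerances tight, as noted earlier in this thread.

import torch

def check_agreement(fast, naive, n: int, atol: float = 1e-10) -> None:
    fast, naive = fast.double(), naive.double()  # float64 for tight tolerances
    with torch.no_grad():
        naive.angles.copy_(fast.angles)  # identical parameters in both layers
    x = torch.randn(8, n, dtype=torch.float64)

    y_fast, y_naive = fast(x), naive(x)
    assert torch.allclose(y_fast, y_naive, atol=atol)  # forward agreement

    y_fast.sum().backward()
    y_naive.sum().backward()
    assert torch.allclose(fast.angles.grad, naive.angles.grad, atol=atol)  # backward agreement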

@VolodyaCO (Collaborator, Author)

After a bit of git wrangling, I was able to clean my whole mess of merge commits 😆.

@anahitamansouri (Collaborator) left a comment

This is a nice PR, Vlad. It took me a while to go over the paper and this PR :) The only thing left is the failing tests. Thanks for the great work.

self.n = n
self.n_angles = n * (n - 1) // 2
self.angles = nn.Parameter(torch.randn(self.n_angles))
blocks_edges = _get_blocks_edges(n)
Collaborator

You could directly return torch.LongTensor from get_blocks_edges to avoid the conversion.

Collaborator (Author)

I made _get_blocks_edges a private function, so it shouldn't make a difference whether I convert the list to a tensor in the orthogonal module or in the function itself.

Collaborator (Author)

I implemented your suggestion

@kevinchern (Collaborator) left a comment

Did a quick pass to provide some feedback before taking a deep dive into the paper.



def _get_blocks_edges(n: int) -> list[list[tuple[int, int]]]:
    """Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens
Collaborator

Suggested change
"""Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens
"""Uses the circle method for round-robin pairing to create blocks of edges for parallel Givens

(and other occurrences)
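For reference, the circle method itself is only a few lines. An illustrative stand-alone version for even n (the PR's helper also handles the odd-n dummy index):

def round_robin_blocks(n: int) -> list[list[tuple[int, int]]]:
    """Return n - 1 blocks, each pairing all n indices into n // 2 disjoint edges."""
    assert n % 2 == 0
    seq = list(range(n))
    blocks = []
    for _ in range(n - 1):
        # Pair the first half of the circle against the reversed second half.
        blocks.append([(seq[i], seq[n - 1 - i]) for i in range(n // 2)])
        # Keep seq[0] fixed and rotate the remaining indices one position.
        seq = [seq[0], seq[-1]] + seq[1:-1]
    return blocks

# Every pair of indices meets exactly once across the n - 1 blocks:
edges = sorted(tuple(sorted(p)) for block in round_robin_blocks(4) for p in block)
assert edges == [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]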

Collaborator

Should _get_blocks_edges be a method of GivensRotation instead? The orthogonal module is general, while this function is a helper bespoke to GivensRotation.
cc @thisac

Collaborator (Author)

Maybe... though what would be the attribute of GivensRotation used in _get_blocks_edges? n only?

Collaborator

Only n, or make it a @staticmethod.

Collaborator (Author)

I took your suggestion

return grad_theta, None, None


class GivensRotationLayer(nn.Module):
Collaborator

Can we rename it to GivensRotation (parallel to nn.Linear)?

Collaborator (Author)

Sounds good. Did that too

Comment on lines +44 to +48
if n % 2 != 0:
    n += 1  # Add a dummy dimension for odd n
    is_odd = True
else:
    is_odd = False
Collaborator

Could be cleaner like this 😛

Suggested change
if n % 2 != 0:
    n += 1  # Add a dummy dimension for odd n
    is_odd = True
else:
    is_odd = False
odd = n % 2 != 0
if odd:
    n += 1

or

odd = n % 2 != 0
n += odd

but this is less obvious.. (edit: not a big fan of n += odd notation 😆)

Collaborator (Author)

It is cleaner! (the first suggestion, not the n+=odd 😆 )

Collaborator (Author)

Done

ignored.
"""
if n % 2 != 0:
    n += 1  # Add a dummy dimension for odd n
Collaborator

Rule-of-thumb for comments: explain the "why" or motivation as opposed to "what" (which is clear in this context)

Collaborator (Author)

I will try to always adopt this rule of thumb!

for _ in range(n - 1):
    pairs = circle_method(sequence)
    if is_odd:
        # Remove pairs involving the dummy dimension:
Collaborator

Suggested change
# Remove pairs involving the dummy dimension:
# Remove pairs involving the dummy dimension

Collaborator (Author)

I was gonna ask: why remove the colon?


@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
    """Computes the VJP needed for backward propagation.
Collaborator

Suggested change
"""Computes the VJP needed for backward propagation.
"""Computes the vector-Jacobian product needed for backward propagation.

Collaborator (Author)

Done
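Independently of the docstring wording, a custom backward like this can be checked against finite differences with torch.autograd.gradcheck. A sketch, assuming the Function is the _RoundRobinGivens discussed later in this review and that its apply signature is (angles, blocks, n), as the (grad_theta, None, None) return suggests:

import torch

# n = 4 gives n * (n - 1) // 2 = 6 angles; blocks has shape (n - 1, n // 2, 2).
angles = torch.randn(6, dtype=torch.float64, requires_grad=True)
blocks = torch.tensor([[[0, 3], [1, 2]], [[0, 2], [3, 1]], [[0, 1], [2, 3]]])
# _RoundRobinGivens is assumed to be in scope; gradcheck compares the analytic
# VJP against finite differences in float64.
assert torch.autograd.gradcheck(lambda a: _RoundRobinGivens.apply(a, blocks, 4), (angles,))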

idx_block = torch.arange(block_size, device=angles.device)
for b, block in enumerate(blocks):
    # angles is of shape (n_angles,) containing all angles for contiguous blocks.
    angles_in_block = angles[idx_block + b * blocks.size(1)]  # shape (n/2,)
Collaborator

Suggested change
angles_in_block = angles[idx_block + b * blocks.size(1)] # shape (n/2,)
angles_in_block = angles[idx_block + b * block_size] # shape (n/2,)

If I understand correctly, blocks.size(1) will be block_size

Collaborator (Author)

Ah yes, while writing the algorithm I thought blocks could have different sizes if n is odd, but that is not true. All blocks have the same block size.

Comment on lines +99 to +104
c = torch.cos(angles_in_block)
s = torch.sin(angles_in_block)
i_idx = block[:, 0]
j_idx = block[:, 1]
r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
Collaborator

Unsqueeze once in the beginning

Suggested change
c = torch.cos(angles_in_block)
s = torch.sin(angles_in_block)
i_idx = block[:, 0]
j_idx = block[:, 1]
r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
c = torch.cos(angles_in_block).unsqueeze(0)
s = torch.sin(angles_in_block).unsqueeze(0)
i_idx = block[:, 0]
j_idx = block[:, 1]
r_i = c * U[:, i_idx] + s * U[:, j_idx]
r_j = -s * U[:, i_idx] + c * U[:, j_idx]

Collaborator (Author)

Done

U = torch.eye(n, device=angles.device, dtype=angles.dtype)
block_size = n // 2
idx_block = torch.arange(block_size, device=angles.device)
for b, block in enumerate(blocks):
Collaborator

If we commit to using paper variable names here, we should be consistent and use, e.g., B instead of blocks.
If that's the case, I'd prefer to be a little more wasteful and have B = blocks to keep the input argument blocks instead of B. This inconsistency makes me lean towards named variables more (with a look-up table in the docstring).

Collaborator (Author)

Changed internally to use B instead.

Comment on lines +103 to +104
r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
Collaborator

Are r_i and r_j backwards? I think it should be:

  • $\cos - \sin$ for i, and
  • $\sin + \cos$ for j.

Not sure if this has a significant impact on the validity of the method. If it does, then tests should be revised first to see why this error was not detected.

Collaborator (Author)

Yes... well... in the paper the rotation matrices were written the other way around, I think. I did the math separately and this way everything is consistent.

Collaborator

Worth adding a # NOTE: here to highlight this distinction from the paper, since we're also using variable names identical to the paper's.

Collaborator (Author)

I added a .. note:: at the top of the class to explain that the paper performs rotations using the rows of U, but it is more standard to use the columns of U. It does not matter in the end because U is orthogonal, and using the rows or the columns is completely equivalent.
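A quick illustrative check of that equivalence: if U is orthogonal and G is a single Givens rotation, then both G @ U (row convention) and U @ G (column convention) remain orthogonal, and the two parameterizations are related by transposition.

import torch

n = 5
theta = torch.tensor(0.3, dtype=torch.float64)
c, s = torch.cos(theta), torch.sin(theta)
i, j = 1, 3
g = torch.eye(n, dtype=torch.float64)
g[i, i], g[i, j], g[j, i], g[j, j] = c, s, -s, c  # one Givens rotation in the (i, j) plane

u = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64)).Q  # a random orthogonal matrix
for v in (g @ u, u @ g):  # row convention vs column convention
    assert torch.allclose(v @ v.t(), torch.eye(n, dtype=torch.float64), atol=1e-12)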

@VolodyaCO VolodyaCO force-pushed the givens-rotation branch 3 times, most recently from 1146ebb to df400fb on January 14, 2026 14:41
@kevinchern (Collaborator)

@VolodyaCO can you rebase on main?

Returns:
    torch.Tensor: The nxn rotation matrix.
"""
# Blocks is of shape (n_blocks, n/2, 2) containing indices for angles
Collaborator

Can you annotate with comments the corresponding equations from the paper? This will make it easier to maintain and understand; e.g., the ESS PR.

Collaborator (Author)

Some of the equations do not have equation numbers. I improved the backward pass code with in-line comments, since it is the part that's actually difficult to follow, mapping code variables to the variables used in the paper for improved readability.

@kevinchern (Collaborator) Apr 30, 2026

> do not have equation numbers

If it doesn't have a number, then describe it unambiguously.

e.g., https://arxiv.org/pdf/2106.00003 (5) has 3 equations -> I'd just say # second equation of (5).

Another example: # the second unnumbered equation following (6)

A last example: # the expression following the sentence "relevant storage of the gradient"

Collaborator (Author)

This was done mainly with Copilot. I went through the comments and fixed what was wrong (mainly that some references to the paper stated the wrong location, e.g. "after" instead of "before", or the wrong equation number, e.g. "after equation (11)" instead of "after equation (12)").



@VolodyaCO VolodyaCO force-pushed the givens-rotation branch 3 times, most recently from 6587b89 to 97c2b8e on April 27, 2026 21:33
@kevinchern kevinchern closed this Apr 27, 2026
@kevinchern kevinchern reopened this Apr 27, 2026
@VolodyaCO (Collaborator, Author) left a comment

@kevinchern I have addressed all comments. Should be ready to merge now.



Comment on lines +93 to +95
tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
angles. No calculation of gradients with respect to blocks or n is needed (cf.
forward method), so None is returned for these.
Contributor

Returns shouldn't be indented. Same in other places in this file.

Suggested change
        tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
            angles. No calculation of gradients with respect to blocks or n is needed (cf.
            forward method), so None is returned for these.
    tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
        angles. No calculation of gradients with respect to blocks or n is needed (cf.
        forward method), so None is returned for these.

Collaborator (Author)

Fixed.

__all__ = ["GivensRotation"]


class _RoundRobinGivens(torch.autograd.Function):
Contributor

The documentation is quite thorough for a hidden class. Would it perhaps make sense to move this to an nn.functions namespace and remove the underscore? @kevinchern


@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
    """Computes the vector-Jacobian product needed for backward propagation.
Contributor

Suggested change
"""Computes the vector-Jacobian product needed for backward propagation.
"""Computes the vector-Jacobian product needed for backward propagation.

Collaborator (Author)

Fixed

Comment on lines +98 to +101
# Initialize U^fwd from forward pass output U. Mathematically, U^fwd represents U^{1:k-1}
# at block k, defined in equation (11). It is post-multiplied by G_bk^T at each block
# iteration to "remove the effect of the block's rotations" (Section 4.1, paragraph before
# equation (12)). This corresponds to the update from U^{1:k} back to U^{1:k-1}.
Contributor

I haven't looked at the reference link, but these comments seem a bit messy (for lack of a better term). If the algorithm in the link describes the process in a clear way, I'd perhaps shorten these comments a bit, although references to equations are always good.

Also, separation by empty lines always makes things look a bit tidier.

Collaborator (Author)

I have added separation by empty lines to make things look tidier. I added these inline comments at @kevinchern's request. I do agree that reading the paper is enough for the interested reader, but I think it's helpful to have the inline comments too.
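To make the quoted comment concrete, the unwind step it describes looks roughly like this sketch (hypothetical helper, not the PR's exact code): each backward iteration recovers U^{1:k-1} from U^{1:k} by post-multiplying with that block's transpose, so no intermediate matrices need to be stored.

import torch

def unwind_block(u_fwd: torch.Tensor, pairs: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Compute U^{1:k-1} = U^{1:k} @ G_bk.T for one block of disjoint rotations."""
    c, s = torch.cos(angles), torch.sin(angles)
    i_idx, j_idx = pairs[:, 0], pairs[:, 1]
    out = u_fwd.clone()
    # Post-multiplying by the block's transpose applies the inverse rotations
    # to the column pairs, removing that block's effect.
    out[:, i_idx] = c * u_fwd[:, i_idx] - s * u_fwd[:, j_idx]
    out[:, j_idx] = s * u_fwd[:, i_idx] + c * u_fwd[:, j_idx]
    return out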

Address orthogonal module PR feedback.

Labels

enhancement New feature or request


4 participants