Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions pygpcca/_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def _gram_schmidt_mod(X: ArrayLike, eta: ArrayLike) -> ArrayLike:
r"""
:math:`\eta`-orthonormalize Schur vectors.

This uses a modified, numerically stable version of Gram-Schmidt
Orthonormalization.
This uses Householder QR decomposition (LAPACK DGEQRF via
:func:`numpy.linalg.qr`) for backward-stable orthonormalization.

Parameters
----------
Expand All @@ -111,18 +111,18 @@ def _gram_schmidt_mod(X: ArrayLike, eta: ArrayLike) -> ArrayLike:
# Keep a copy of the original (Schur) vectors for the later sanity check.
Xc = np.copy(X)

# Initialize matrices.
n, m = X.shape
Q = np.zeros((n, m))
R = np.zeros((m, m))

# Search for the constant (Schur) vector, if explicitly present.
max_i = 0
for i in range(m):
vsum = np.sum(X[:, i])
dummy = np.ones(X[:, i].shape) * (vsum / n)
if np.allclose(X[:, i], dummy, rtol=1e-6, atol=1e-5):
max_i = i # TODO: check, if more than one vec fulfills this
# Find the column most aligned with sqrt(eta) (i.e., the stationary vector)
# by absolute cosine similarity, so that the sign of a column does not
# matter. This is more robust than checking element-wise constancy, which
# can fail when Schur vectors are only approximately constant due to
# numerical noise.
sqrt_eta = np.sqrt(eta)
sqrt_eta_normed = sqrt_eta / np.linalg.norm(sqrt_eta)
col_norms = np.linalg.norm(X, axis=0)
col_norms = np.where(col_norms > 0, col_norms, 1.0)
cosines = np.abs((X / col_norms).T @ sqrt_eta_normed)
max_i = int(np.argmax(cosines))

# Shift non-constant first (Schur) vector to the right.
X[:, max_i] = X[:, 0]
Expand All @@ -142,14 +142,11 @@ def _gram_schmidt_mod(X: ArrayLike, eta: ArrayLike) -> ArrayLike:
f"Number of clusters: {m}."
)

# eta-orthonormalization
for j in range(m):
v = X[:, j]
for i in range(j):
R[i, j] = np.dot(Q[:, i].conj(), v)
v = v - np.dot(R[i, j], Q[:, i])
R[j, j] = np.linalg.norm(v)
Q[:, j] = np.true_divide(v, R[j, j])
# Orthonormalize via Householder QR (backward stable, LAPACK DGEQRF).
Q, _ = np.linalg.qr(X)
# QR may flip the sign of columns; ensure column 0 aligns with sqrt(eta).
if Q[:, 0] @ np.sqrt(eta) < 0:
Q[:, 0] = -Q[:, 0]

# Raise if the subspace changed.
dummy = subspace_angles(Q, Xc)
Expand Down Expand Up @@ -258,18 +255,21 @@ def _do_schur(
if not np.allclose(Q.T.dot(Q * eta[:, None]), np.eye(Q.shape[1]), rtol=1e6 * EPS, atol=1e6 * EPS):
logging.debug("The Schur vectors aren't D-orthogonal so they are D-orthogonalized.")
Q = _gram_schmidt_mod(Q, eta)
# Recompute R in the new orthonormal basis to maintain P_bar @ Q ≈ Q @ R.
P_bar_dense = P_bar.toarray() if issparse(P_bar) else P_bar
R = Q.T @ P_bar_dense @ Q
# Transform the orthonormalized Schur vectors of P_bar back
# to orthonormalized Schur vectors X of P.
X = np.true_divide(Q, np.sqrt(eta)[:, None])
else:
# Search for the constant (Schur) vector, if explicitly present.
# Find the column most aligned with sqrt(eta) (i.e., the stationary vector),
# using the absolute cosine similarity so that a column's sign does not matter.
n, m = Q.shape
max_i = 0
for i in range(m):
vsum = np.sum(Q[:, i])
dummy = np.ones(Q[:, i].shape) * (vsum / n)
if np.allclose(Q[:, i], dummy, rtol=1e-6, atol=1e-5):
max_i = i # TODO: check, if more than one vec fulfills this
sqrt_eta = np.sqrt(eta)
sqrt_eta_normed = sqrt_eta / np.linalg.norm(sqrt_eta)
col_norms = np.linalg.norm(Q, axis=0)
col_norms = np.where(col_norms > 0, col_norms, 1.0)
cosines = np.abs((Q / col_norms).T @ sqrt_eta_normed)
max_i = int(np.argmax(cosines))

# Shift non-constant first (Schur) vector to the right.
Q[:, max_i] = Q[:, 0]
Expand Down
Loading