Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix;
return Dd, V
end

function MatrixAlgebraKit.householder_qr!(
function MatrixAlgebraKit.qr_householder!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down Expand Up @@ -97,7 +97,7 @@ function MatrixAlgebraKit.householder_qr!(
return Q, R
end

function MatrixAlgebraKit.householder_qr_null!(
function MatrixAlgebraKit.qr_null_householder!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down
37 changes: 19 additions & 18 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,44 +103,45 @@ for f! in (:geev!, :geevx!)
end

# driver dispatch
@inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
qr_iteration_eig_full!(driver, A, Dd, V; kwargs...)
@inline qr_iteration_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) =
qr_iteration_eig_vals!(driver, A, D, V; kwargs...)
@inline eig_full_qr_iteration!(A, DV; driver::Driver = DefaultDriver(), kwargs...) =
eig_full_qr_iteration!(driver, A, DV; kwargs...)
@inline eig_vals_qr_iteration!(A, D; driver::Driver = DefaultDriver(), kwargs...) =
eig_vals_qr_iteration!(driver, A, D; kwargs...)

@inline qr_iteration_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) =
qr_iteration_eig_full!(default_driver(QRIteration, A), A, Dd, V; kwargs...)
@inline qr_iteration_eig_vals!(::DefaultDriver, A, D, V; kwargs...) =
qr_iteration_eig_vals!(default_driver(QRIteration, A), A, D, V; kwargs...)
@inline eig_full_qr_iteration!(::DefaultDriver, A, DV; kwargs...) =
eig_full_qr_iteration!(default_driver(QRIteration, A), A, DV; kwargs...)
@inline eig_vals_qr_iteration!(::DefaultDriver, A, D; kwargs...) =
eig_vals_qr_iteration!(default_driver(QRIteration, A), A, D; kwargs...)

# Implementation
function qr_iteration_eig_full!(
driver::Driver, A, Dd, V;
function eig_full_qr_iteration!(
driver::Driver, A, DV;
fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true
)
D, V = DV
Dd = diagview(D)
(scale & permute) ? geev!(driver, A, Dd, V) : geevx!(driver, A, Dd, V; scale, permute)
fixgauge && gaugefix!(eig_full!, V)
return Dd, V
return DV
end
function qr_iteration_eig_vals!(
driver::Driver, A, D, V;
function eig_vals_qr_iteration!(
driver::Driver, A, D;
fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true
)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
(scale & permute) ? geev!(driver, A, D, V) : geevx!(driver, A, D, V; scale, permute)
return D
end

# Top-level QRIteration dispatch
function eig_full!(A::AbstractMatrix, DV, alg::QRIteration)
check_input(eig_full!, A, DV, alg)
D, V = DV
qr_iteration_eig_full!(A, diagview(D), V; alg.kwargs...)
return D, V
eig_full_qr_iteration!(A, DV; alg.kwargs...)
return DV
end
function eig_vals!(A::AbstractMatrix, D, alg::QRIteration)
check_input(eig_vals!, A, D, alg)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
qr_iteration_eig_vals!(A, D, V; alg.kwargs...)
eig_vals_qr_iteration!(A, D; alg.kwargs...)
return D
end

Expand Down
37 changes: 19 additions & 18 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,46 +115,47 @@ for (f, f_lapack!, Alg) in (
(:bisection, :heevx!, :Bisection),
(:jacobi, :heevj!, :Jacobi),
)
f_eigh_full! = Symbol(f, :_eigh_full!)
f_eigh_vals! = Symbol(f, :_eigh_vals!)
eigh_full_f! = Symbol(:eigh_full_, f, :!)
eigh_vals_f! = Symbol(:eigh_vals_, f, :!)

# MatrixAlgebraKit wrappers
@eval begin
function eigh_full!(A::AbstractMatrix, DV, alg::$Alg)
check_input(eigh_full!, A, DV, alg)
D, V = DV
Dd, V = $f_eigh_full!(A, D.diag, V; alg.kwargs...)
return D, V
$eigh_full_f!(A, DV; alg.kwargs...)
return DV
end
function eigh_vals!(A::AbstractMatrix, D, alg::$Alg)
check_input(eigh_vals!, A, D, alg)
V = similar(A, (size(A, 1), 0))
$f_eigh_vals!(A, D, V; alg.kwargs...)
$eigh_vals_f!(A, D; alg.kwargs...)
return D
end
end

# driver dispatch
@eval begin
@inline $f_eigh_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
$f_eigh_full!(driver, A, Dd, V; kwargs...)
@inline $f_eigh_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) =
$f_eigh_vals!(driver, A, D, V; kwargs...)
@inline $eigh_full_f!(A, DV; driver::Driver = DefaultDriver(), kwargs...) =
$eigh_full_f!(driver, A, DV; kwargs...)
@inline $eigh_vals_f!(A, D; driver::Driver = DefaultDriver(), kwargs...) =
$eigh_vals_f!(driver, A, D; kwargs...)

@inline $f_eigh_full!(::DefaultDriver, A, Dd, V; kwargs...) =
$f_eigh_full!(default_driver($Alg, A), A, Dd, V; kwargs...)
@inline $f_eigh_vals!(::DefaultDriver, A, D, V; kwargs...) =
$f_eigh_vals!(default_driver($Alg, A), A, D, V; kwargs...)
@inline $eigh_full_f!(::DefaultDriver, A, DV; kwargs...) =
$eigh_full_f!(default_driver($Alg, A), A, DV; kwargs...)
@inline $eigh_vals_f!(::DefaultDriver, A, D; kwargs...) =
$eigh_vals_f!(default_driver($Alg, A), A, D; kwargs...)
end

# Implementation
@eval begin
function $f_eigh_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...)
function $eigh_full_f!(driver::Driver, A, DV; fixgauge::Bool = default_fixgauge(), kwargs...)
D, V = DV
Dd = diagview(D)
$f_lapack!(driver, A, Dd, V; kwargs...)
fixgauge && gaugefix!(eigh_full!, V)
return Dd, V
return DV
end
function $f_eigh_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...)
function $eigh_vals_f!(driver::Driver, A, D; fixgauge::Bool = default_fixgauge(), kwargs...)
V = similar(A, (size(A, 1), 0))
$f_lapack!(driver, A, D, V; kwargs...)
return D
end
Expand Down
44 changes: 22 additions & 22 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ end
# -----------
function lq_full!(A, LQ, alg::Householder)
check_input(lq_full!, A, LQ, alg)
return householder_lq!(A, LQ...; alg.kwargs...)
return lq_householder!(A, LQ...; alg.kwargs...)
end
function lq_compact!(A, LQ, alg::Householder)
check_input(lq_compact!, A, LQ, alg)
return householder_lq!(A, LQ...; alg.kwargs...)
return lq_householder!(A, LQ...; alg.kwargs...)
end
function lq_null!(A, Nᴴ, alg::Householder)
check_input(lq_null!, A, Nᴴ, alg)
return householder_lq_null!(A, Nᴴ; alg.kwargs...)
return lq_null_householder!(A, Nᴴ; alg.kwargs...)
end

# dispatch helpers
Expand All @@ -123,13 +123,13 @@ for f in (:gelqt!, :gemlqt!, :gelqf!, :unglq!, :unmlq!)
end
end

@inline householder_lq!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) =
householder_lq!(driver, A, L, Q; kwargs...)
householder_lq!(::DefaultDriver, A, L, Q; kwargs...) =
householder_lq!(default_driver(Householder, A), A, L, Q; kwargs...)
householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
@inline lq_householder!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) =
lq_householder!(driver, A, L, Q; kwargs...)
lq_householder!(::DefaultDriver, A, L, Q; kwargs...) =
lq_householder!(default_driver(Householder, A), A, L, Q; kwargs...)
lq_householder!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
function householder_lq!(
function lq_householder!(
driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive = true, pivoted = false, blocksize::Int = 0
)
Expand Down Expand Up @@ -186,7 +186,7 @@ function householder_lq!(
end
return L, Q
end
function householder_lq!(
function lq_householder!(
driver::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down Expand Up @@ -229,13 +229,13 @@ function householder_lq!(
return L, Q
end

@inline householder_lq_null!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) =
householder_lq_null!(driver, A, Nᴴ; kwargs...)
householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) =
householder_lq_null!(default_driver(Householder, A), A, Nᴴ; kwargs...)
householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) =
@inline lq_null_householder!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) =
lq_null_householder!(driver, A, Nᴴ; kwargs...)
lq_null_householder!(::DefaultDriver, A, Nᴴ; kwargs...) =
lq_null_householder!(default_driver(Householder, A), A, Nᴴ; kwargs...)
lq_null_householder!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) =
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
function householder_lq_null!(
function lq_null_householder!(
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand All @@ -260,7 +260,7 @@ function householder_lq_null!(
end
return Nᴴ
end
function householder_lq_null!(
function lq_null_householder!(
driver::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down Expand Up @@ -343,21 +343,21 @@ end
function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
_diagonal_lq!(A, L, Q; alg.kwargs...)
lq_diagonal!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
_diagonal_lq!(A, L, Q; alg.kwargs...)
lq_diagonal!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm)
check_input(lq_null!, A, N, alg)
return _diagonal_lq_null!(A, N; alg.kwargs...)
return lq_null_diagonal!(A, N; alg.kwargs...)
end

function _diagonal_lq!(
function lq_diagonal!(
A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive::Bool = true
)
# note: Ad and Qd might share memory here so order of operations is important
Expand All @@ -374,7 +374,7 @@ function _diagonal_lq!(
return L, Q
end

_diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool = true) = N
lq_null_diagonal!(A::AbstractMatrix, N; positive::Bool = true) = N

# Deprecations
# ------------
Expand Down
12 changes: 6 additions & 6 deletions src/implementations/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)
check_input(left_polar!, A, WP, alg)
W, P = WP
if isempty(P)
W = _left_polarnewton!(A, W, P; alg.kwargs...)
W = left_polar_newton!(A, W, P; alg.kwargs...)
return W, P
else
W = _left_polarnewton!(copy(A), W, P; alg.kwargs...)
W = left_polar_newton!(copy(A), W, P; alg.kwargs...)
# we still need `A` to compute `P`
P = project_hermitian!(mul!(P, W', A))
return W, P
Expand All @@ -114,18 +114,18 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarNewton)
check_input(right_polar!, A, PWᴴ, alg)
P, Wᴴ = PWᴴ
if isempty(P)
Wᴴ = _right_polarnewton!(A, Wᴴ, P; alg.kwargs...)
Wᴴ = right_polar_newton!(A, Wᴴ, P; alg.kwargs...)
return P, Wᴴ
else
Wᴴ = _right_polarnewton!(copy(A), Wᴴ, P; alg.kwargs...)
Wᴴ = right_polar_newton!(copy(A), Wᴴ, P; alg.kwargs...)
# we still need `A` to compute `P`
P = project_hermitian!(mul!(P, A, Wᴴ'))
return P, Wᴴ
end
end

# these methods only compute W and destroy A in the process
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
function left_polar_newton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
m, n = size(A) # we must have m >= n
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
if m > n # initial QR
Expand Down Expand Up @@ -165,7 +165,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
return W
end

function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
function right_polar_newton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
m, n = size(A) # we must have m <= n
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
if m < n # initial QR
Expand Down
30 changes: 15 additions & 15 deletions src/implementations/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ end
# -----------
function qr_full!(A, QR, alg::Householder)
check_input(qr_full!, A, QR, alg)
return householder_qr!(A, QR...; alg.kwargs...)
return qr_householder!(A, QR...; alg.kwargs...)
end
function qr_compact!(A, QR, alg::Householder)
check_input(qr_compact!, A, QR, alg)
return householder_qr!(A, QR...; alg.kwargs...)
return qr_householder!(A, QR...; alg.kwargs...)
end
function qr_null!(A, N, alg::Householder)
check_input(qr_null!, A, N, alg)
return householder_qr_null!(A, N; alg.kwargs...)
return qr_null_householder!(A, N; alg.kwargs...)
end


Expand All @@ -125,11 +125,11 @@ for f in (:geqrt!, :gemqrt!, :geqp3!, :geqrf!, :ungqr!, :unmqr!)
end
end

@inline householder_qr!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) =
householder_qr!(driver, A, Q, R; kwargs...)
householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
householder_qr!(default_driver(Householder, A), A, Q, R; kwargs...)
function householder_qr!(
@inline qr_householder!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) =
qr_householder!(driver, A, Q, R; kwargs...)
qr_householder!(::DefaultDriver, A, Q, R; kwargs...) =
qr_householder!(default_driver(Householder, A), A, Q, R; kwargs...)
function qr_householder!(
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false,
blocksize::Int = 0
Expand Down Expand Up @@ -213,7 +213,7 @@ function householder_qr!(
end
return Q, R
end
function householder_qr!(
function qr_householder!(
driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down Expand Up @@ -256,11 +256,11 @@ function householder_qr!(
return Q, R
end

@inline householder_qr_null!(A, N; driver::Driver = DefaultDriver(), kwargs...) =
householder_qr_null!(driver, A, N; kwargs...)
householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
householder_qr_null!(default_driver(Householder, A), A, N; kwargs...)
function householder_qr_null!(
@inline qr_null_householder!(A, N; driver::Driver = DefaultDriver(), kwargs...) =
qr_null_householder!(driver, A, N; kwargs...)
qr_null_householder!(::DefaultDriver, A, N; kwargs...) =
qr_null_householder!(default_driver(Householder, A), A, N; kwargs...)
function qr_null_householder!(
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down Expand Up @@ -288,7 +288,7 @@ function householder_qr_null!(
end
return N
end
function householder_qr_null!(
function qr_null_householder!(
driver::Native, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
Expand Down
Loading
Loading