diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index b313f9b2f..a87319082 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index 800dbaa05..e39f68316 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 0f05f85ab..9ffdc730d 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_lq(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 09e3a28cc..370454b55 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_orthnull(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index 1faf3c104..d6c089098 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -17,5 +17,10 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index 17415e8df..ed4db1925 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl index d2d40df42..f096fdb8e 100644 --- a/test/mooncake/svd.jl +++ b/test/mooncake/svd.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end