From 98ea2e89305ac0bad8b3e96b0b76620b49a5883a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 23 Mar 2026 12:23:10 +0100 Subject: [PATCH] Try Enzyme tests with CUDA --- test/enzyme/eig.jl | 3 +++ test/enzyme/eigh.jl | 3 +++ test/enzyme/lq.jl | 3 +++ test/enzyme/orthnull.jl | 3 +++ test/enzyme/polar.jl | 3 +++ test/enzyme/projections.jl | 4 ++++ test/enzyme/qr.jl | 3 +++ test/enzyme/svd.jl | 3 +++ 8 files changed, 25 insertions(+) diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl index 898d773a8..a0915629c 100644 --- a/test/enzyme/eig.jl +++ b/test/enzyme/eig.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_eig(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl index d32db3dd5..5c74f0fdf 100644 --- a/test/enzyme/eigh.jl +++ b/test/enzyme/eigh.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_eigh(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/lq.jl b/test/enzyme/lq.jl index f7ae2ebf7..1f3f4823c 100644 --- a/test/enzyme/lq.jl +++ b/test/enzyme/lq.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_lq(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/orthnull.jl b/test/enzyme/orthnull.jl index eaeae8400..005680708 100644 --- a/test/enzyme/orthnull.jl +++ b/test/enzyme/orthnull.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_orthnull(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/polar.jl b/test/enzyme/polar.jl index 6ab965ac1..92282119e 100644 --- a/test/enzyme/polar.jl +++ b/test/enzyme/polar.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_polar(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/projections.jl b/test/enzyme/projections.jl index 52b222a52..cf3d90709 100644 --- a/test/enzyme/projections.jl +++ b/test/enzyme/projections.jl @@ -18,4 +18,8 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.test_enzyme_projections(T, (m, m); atol, rtol) TestSuite.test_enzyme_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol) end + if CUDA.functional() + TestSuite.test_enzyme_projections(CuMatrix{T}, (m, n); atol, rtol) + TestSuite.test_enzyme_projections(Diagonal{T, CuVector{T}}, (m, m); atol, rtol) + end end diff --git a/test/enzyme/qr.jl b/test/enzyme/qr.jl index 728e267d3..2cef8195f 100644 --- a/test/enzyme/qr.jl +++ b/test/enzyme/qr.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_qr(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index 6143f61e4..a218dfd7e 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -16,4 +16,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_svd(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end