diff --git a/Project.toml b/Project.toml index d578dcb..0e5ece1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "PoissonRandom" uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab" -version = "0.4.7" +version = "0.4.8" [deps] LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index f8245c0..4a867d0 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -13,6 +13,15 @@ Random.rand(rng::PassthroughRNG) = rand() Random.randexp(rng::PassthroughRNG) = randexp() Random.randn(rng::PassthroughRNG) = randn() +# When an overlay method table (e.g. CUDA.jl's `@device_override +# Random.randexp(::AbstractRNG)`) shadows the methods above, the overlay body +# runs with rng::PassthroughRNG and may call `Random.rand(rng, UInt52Raw())` +# or `Random.rand(rng, T)`. The stdlib Sampler chain bottoms out at +# `_rand52(r, rng_native_52(r))` → `rand(r, UInt64)`; provide those so the +# chain still reaches bare rand(T) and the device-side default_rng path. +Random.rng_native_52(::PassthroughRNG) = UInt64 +Random.rand(rng::PassthroughRNG, ::Type{T}) where {T} = rand(T) + count_rand(λ::Real) = count_rand(Random.default_rng(), λ) function count_rand(rng::AbstractRNG, λ::Real) n = 0 diff --git a/test/runtests.jl b/test/runtests.jl index 4f16a85..3b9cd0d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,22 @@ end end end +@testset "PassthroughRNG dispatch" begin + using Random: Random, UInt52Raw + prng = PassthroughRNG() + # The CUDA.jl @device_override Random.randexp(::AbstractRNG) shadows our + # specific Random.randexp(::PassthroughRNG) on the GPU because Julia's + # OverlayMethodTable returns overlay matches without consulting the base + # table when the overlay fully covers the signature. The override body + # then calls these against PassthroughRNG; if they MethodError, kernel + # compilation fails with InvalidIRError on jl_f_throw_methoderror. + @test Random.rng_native_52(prng) === UInt64 + @test Random.rand(prng, UInt52Raw()) isa UInt64 + @test Random.rand(prng, UInt64) isa UInt64 + @test Random.rand(prng, Float32) isa Float32 + @test Random.rand(prng, Float64) isa Float64 +end + if get(ENV, "GROUP", "all") == "all" || get(ENV, "GROUP", "all") == "nopre" @testset "Allocation Tests" begin include("alloc_tests.jl")