diff --git a/Project.toml b/Project.toml index 97374eac..5367ad5e 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" @@ -42,7 +41,6 @@ DifferentiationInterface = "0.7" Distributions = "0.25" DistributionsAD = "0.6" FillArrays = "1" -ForwardDiff = "1" LinearAlgebra = "1" Lux = "1" LuxCore = "1" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index f35c328f..94cce68e 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -4,8 +4,6 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -17,8 +15,6 @@ BenchmarkTools = "1" ComponentArrays = "0.15" DifferentiationInterface = "0.7" Distributions = "0.25" -ForwardDiff = "1" -Lux = "1" LuxCore = "1" PkgBenchmark = "0.2" StableRNGs = "1" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 37312736..05a38d05 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -3,8 +3,6 @@ import ADTypes, ComponentArrays, DifferentiationInterface, Distributions, - ForwardDiff, - Lux, LuxCore, PkgBenchmark, StableRNGs, @@ -19,18 +17,8 @@ r = rand(rng, data_dist, ndimension, ndata) r = convert.(Float32, r) nvars = size(r, 1) -naugs = nvars + 1 -n_in = nvars + naugs - -nn = Lux.Chain( - Lux.Dense(n_in => (2 * n_in + 1), tanh), - Lux.Dense((2 * n_in + 1) => n_in, tanh), -) - -icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng) - -icnf2 = - ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng, inplace = true) +icnf = ContinuousNormalizingFlows.ICNF(; nvars, rng) +icnf2 = ContinuousNormalizingFlows.ICNF(; nvars, rng, inplace = true) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/examples/usage.jl b/examples/usage.jl index 27312bfb..6e4508ee 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -31,7 +31,7 @@ using ContinuousNormalizingFlows, # To use gpu, add related packages # using LuxCUDA -nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh)) +nn = Chain(Dense(n_in + 1 => n_in, tanh)) icnf = ICNF(; nn = nn, nvars = nvars, # number of variables @@ -45,6 +45,7 @@ icnf = ICNF(; # device = gpu_device(), # process data by GPU cond = false, # not conditioning on auxiliary input inplace = false, # not using the inplace version of functions + autonomous = false, # using non-autonomous flow compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote sol_kwargs = (; save_everystep = false, diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 24a2082c..dca2b40b 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -8,7 +8,6 @@ import ADTypes, Distributions, DistributionsAD, FillArrays, - ForwardDiff, LinearAlgebra, Lux, LuxCore, diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 7391b984..6ece873d 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -18,6 +18,7 @@ struct ICNF{ CM <: ComputeMode, INPLACE, COND, + AUTONOMOUS, AUGMENTED, STEER, NORM_Z, @@ -54,13 +55,14 @@ function ICNF(; compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), inplace::Bool = false, cond::Bool = false, + autonomous::Bool = false, device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), tspan::NTuple{2} = (zero(data_type), one(data_type)), nvars::Int = 1, naugmented::Int = nvars + 1, nn::LuxCore.AbstractLuxLayer = Lux.Chain( - Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), + Lux.Dense(nvars + naugmented + !autonomous => nvars + naugmented, tanh), ), steer_rate::AbstractFloat = convert(data_type, 1.0e-1), λ₁::AbstractFloat = convert(data_type, 1.0e-2), @@ -92,6 +94,7 @@ function ICNF(; typeof(compute_mode), inplace, cond, + autonomous, !iszero(naugmented), !iszero(steer_rate), !iszero(λ₁), @@ -131,14 +134,15 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIVectorMode, false, COND, AUGMENTED, STEER, NORM_Z}, + t::Any, + icnf::ICNF{T, <:DIVectorMode, false, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z}, mode::TestMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, J = icnf_jacobian(icnf, mode, snn, z) @@ -156,14 +160,15 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIVectorMode, true, COND, AUGMENTED, STEER, NORM_Z}, + t::Any, + icnf::ICNF{T, <:DIVectorMode, true, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z}, mode::TestMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, J = icnf_jacobian(icnf, mode, snn, z) @@ -181,14 +186,15 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:MatrixMode, false, COND, AUGMENTED, STEER, NORM_Z}, + t::Any, + icnf::ICNF{T, <:MatrixMode, false, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z}, mode::TestMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, J = icnf_jacobian(icnf, mode, snn, z) @@ -212,14 +218,15 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:MatrixMode, true, COND, AUGMENTED, STEER, NORM_Z}, + t::Any, + icnf::ICNF{T, <:MatrixMode, true, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z}, mode::TestMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, J = icnf_jacobian(icnf, mode, snn, z) @@ -237,14 +244,25 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIVecJacVectorMode, + false, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -266,14 +284,25 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIVecJacVectorMode, + true, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -295,14 +324,25 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIJacVecVectorMode, + false, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -324,14 +364,25 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIJacVecVectorMode, + true, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -353,14 +404,25 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIVecJacMatrixMode, + false, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -386,14 +448,25 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIVecJacMatrixMode, + true, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -415,14 +488,25 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIJacVecMatrixMode, + false, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -448,14 +532,25 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:DIJacVecMatrixMode, + true, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -477,14 +572,25 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:LuxVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:LuxVecJacMatrixMode, + false, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -510,14 +616,25 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:LuxVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:LuxVecJacMatrixMode, + true, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -539,14 +656,25 @@ end function augmented_f( u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:LuxJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:LuxJacVecMatrixMode, + false, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ) @@ -572,14 +700,25 @@ function augmented_f( du::Any, u::Any, p::Any, - ::Any, - icnf::ICNF{T, <:LuxJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + t::Any, + icnf::ICNF{ + T, + <:LuxJacVecMatrixMode, + true, + COND, + AUTONOMOUS, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + }, mode::TrainMode{REG}, nn::LuxCore.AbstractLuxLayer, st::NamedTuple, ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG} +) where {T <: AbstractFloat, COND, AUTONOMOUS, AUGMENTED, STEER, NORM_Z, NORM_J, REG} n_aug = n_augment(icnf, mode) + nn = ifelse(AUTONOMOUS, nn, CondLayer(nn, t)) snn = LuxCore.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ) diff --git a/src/layers/cond_layer.jl b/src/layers/cond_layer.jl index e97bf9fa..7bca5194 100644 --- a/src/layers/cond_layer.jl +++ b/src/layers/cond_layer.jl @@ -4,6 +4,28 @@ struct CondLayer{NN <: LuxCore.AbstractLuxLayer, AT <: Any} <: ys::AT end -function (m::CondLayer)(z::AbstractVecOrMat, ps::Any, st::NamedTuple) +function (m::CondLayer{<:LuxCore.AbstractLuxLayer, <:AbstractArray})( + z::AbstractVecOrMat, + ps::Any, + st::NamedTuple, +) return LuxCore.apply(m.nn, vcat(z, m.ys), ps, st) end + +function (m::CondLayer{<:LuxCore.AbstractLuxLayer, <:Number})( + z::AbstractVector, + ps::Any, + st::NamedTuple, +) + return LuxCore.apply(m.nn, vcat(z, m.ys), ps, st) +end + +function (m::CondLayer{<:LuxCore.AbstractLuxLayer, <:Number})( + z::AbstractMatrix, + ps::Any, + st::NamedTuple, +) + ts = similar(z, 1, size(z, 2)) + ChainRulesCore.@ignore_derivatives fill!(ts, m.ys) + return LuxCore.apply(m.nn, vcat(z, ts), ps, st) +end diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl index 5682ac00..893abb9c 100644 --- a/src/layers/planar_layer.jl +++ b/src/layers/planar_layer.jl @@ -3,68 +3,60 @@ Implementation of Planar Layer from [Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) """ -struct PlanarLayer{use_bias, cond, F1, F2, F3, NVARS <: Int} <: LuxCore.AbstractLuxLayer +struct PlanarLayer{USE_BIAS, F1, F2, F3, NVARS <: Int} <: LuxCore.AbstractLuxLayer + in_dims::NVARS + out_dims::NVARS activation::F1 - nvars::NVARS init_weight::F2 init_bias::F3 - n_cond::NVARS end function PlanarLayer( - nvars::Int, + mapping::Pair{<:Int, <:Int}, activation::Any = identity; init_weight::Any = WeightInitializers.glorot_uniform, init_bias::Any = WeightInitializers.zeros32, use_bias::Bool = true, - n_cond::Int = 0, ) return PlanarLayer{ use_bias, - !iszero(n_cond), typeof(activation), typeof(init_weight), typeof(init_bias), - typeof(nvars), + typeof(first(mapping)), }( + first(mapping), + last(mapping), activation, - nvars, init_weight, init_bias, - n_cond, ) end function LuxCore.initialparameters( rng::Random.AbstractRNG, - layer::PlanarLayer{use_bias, cond}, -) where {use_bias, cond} + layer::PlanarLayer{USE_BIAS}, +) where {USE_BIAS} return ifelse( - use_bias, + USE_BIAS, ( - u = layer.init_weight(rng, layer.nvars), - w = layer.init_weight( - rng, - ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), - ), + u = layer.init_weight(rng, layer.out_dims), + w = layer.init_weight(rng, layer.in_dims), b = layer.init_bias(rng, 1), ), ( - u = layer.init_weight(rng, layer.nvars), - w = layer.init_weight( - rng, - ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), - ), + u = layer.init_weight(rng, layer.out_dims), + w = layer.init_weight(rng, layer.in_dims), ), ) end -function LuxCore.parameterlength(m::PlanarLayer{use_bias, cond}) where {use_bias, cond} - return m.nvars + ifelse(cond, (m.nvars + m.n_cond), m.nvars) + ifelse(use_bias, 1, 0) +function LuxCore.parameterlength(m::PlanarLayer{USE_BIAS}) where {USE_BIAS} + return m.out_dims + m.in_dims + USE_BIAS end function LuxCore.outputsize(m::PlanarLayer, ::Any, ::Random.AbstractRNG) - return (m.nvars,) + return (m.out_dims,) end function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index f4ea9a24..36e6de5c 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -7,15 +7,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test r = convert.(Float32, r) nvars = size(r, 1) - naugs = nvars + 1 - n_in = nvars + naugs - - nn = Lux.Chain( - Lux.Dense(n_in => (2 * n_in + 1), tanh), - Lux.Dense((2 * n_in + 1) => n_in, tanh), - ) - - icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng) + icnf = ContinuousNormalizingFlows.ICNF(; nvars, rng) df = DataFrames.DataFrame(transpose(r), :auto) model = ContinuousNormalizingFlows.ICNFModel(; diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index 4d143a38..f0a021e5 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -106,17 +106,21 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be planar, Lux.Chain( ContinuousNormalizingFlows.PlanarLayer( - nvars * 2 + 1, - tanh; - n_cond = nvars, + nvars * 3 + 1 + 1 => nvars * 2 + 1, + tanh, ), ), - Lux.Chain(Lux.Dense(nvars * 3 + 1 => nvars * 2 + 1, tanh)), + Lux.Chain(Lux.Dense(nvars * 3 + 1 + 1 => nvars * 2 + 1, tanh)), ), ifelse( planar, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2 + 1, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), + Lux.Chain( + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2 + 1 + 1 => nvars * 2 + 1, + tanh, + ), + ), + Lux.Chain(Lux.Dense(nvars * 2 + 1 + 1 => nvars * 2 + 1, tanh)), ), ) icnf = ContinuousNormalizingFlows.ICNF(; @@ -199,7 +203,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl index 93355f01..46616028 100644 --- a/test/ci_tests/speed_tests.jl +++ b/test/ci_tests/speed_tests.jl @@ -32,7 +32,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be Test.@testset verbose = true showtiming = true failfast = false "$compute_mode" for compute_mode in compute_modes - @show compute_mode rng = StableRNGs.StableRNG(1) @@ -43,21 +42,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be r = convert.(Float32, r) nvars = size(r, 1) - naugs = nvars + 1 - n_in = nvars + naugs - - nn = Lux.Chain( - Lux.Dense(n_in => (2 * n_in + 1), tanh), - Lux.Dense((2 * n_in + 1) => n_in, tanh), - ) - - icnf = ContinuousNormalizingFlows.ICNF(; - nn, - nvars, - naugmented = naugs, - rng, - compute_mode, - ) + icnf = ContinuousNormalizingFlows.ICNF(; nvars, rng, compute_mode) df = DataFrames.DataFrame(transpose(r), :auto) model = ContinuousNormalizingFlows.ICNFModel(; diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index d5cca152..4132bc66 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -89,17 +89,21 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg planar, Lux.Chain( ContinuousNormalizingFlows.PlanarLayer( - nvars * 2 + 1, - tanh; - n_cond = nvars, + nvars * 3 + 1 + 1 => nvars * 2 + 1, + tanh, ), ), - Lux.Chain(Lux.Dense(nvars * 3 + 1 => nvars * 2 + 1, tanh)), + Lux.Chain(Lux.Dense(nvars * 3 + 1 + 1 => nvars * 2 + 1, tanh)), ), ifelse( planar, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2 + 1, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), + Lux.Chain( + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2 + 1 + 1 => nvars * 2 + 1, + tanh, + ), + ), + Lux.Chain(Lux.Dense(nvars * 2 + 1 + 1 => nvars * 2 + 1, tanh)), ), ) icnf = ContinuousNormalizingFlows.ICNF(;