Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Expand All @@ -17,6 +16,5 @@ DifferentiationInterface = "0.7"
Distributions = "0.25"
LuxCore = "1"
PkgBenchmark = "0.2"
StableRNGs = "1"
Zygote = "0.7"
julia = "1.10"
16 changes: 7 additions & 9 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@ import ADTypes,
Distributions,
LuxCore,
PkgBenchmark,
StableRNGs,
Zygote,
ContinuousNormalizingFlows

rng = StableRNGs.StableRNG(1)
ndata = 2^10
ndimension = 1
ndimensions = 1
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
r = rand(rng, data_dist, ndimension, ndata)
r = rand(data_dist, ndimensions, ndata)
r = convert.(Float32, r)

nvars = size(r, 1)
icnf = ContinuousNormalizingFlows.ICNF(; nvars, rng)
icnf2 = ContinuousNormalizingFlows.ICNF(; nvars, rng, inplace = true)
nvariables = size(r, 1)
icnf = ContinuousNormalizingFlows.ICNF(; nvariables)
icnf2 = ContinuousNormalizingFlows.ICNF(; nvariables, inplace = true)

ps, st = LuxCore.setup(icnf.rng, icnf)
ps = ComponentArrays.ComponentArray(ps)
Expand All @@ -35,7 +33,7 @@ end
function diff_loss_tt(x::Any)
return ContinuousNormalizingFlows.loss(
icnf,
ContinuousNormalizingFlows.TestMode{true}(),
ContinuousNormalizingFlows.TestMode(),
r,
x,
st,
Expand All @@ -54,7 +52,7 @@ end
function diff_loss_tt2(x::Any)
return ContinuousNormalizingFlows.loss(
icnf2,
ContinuousNormalizingFlows.TestMode{true}(),
ContinuousNormalizingFlows.TestMode(),
r,
x,
st,
Expand Down
1 change: 0 additions & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Expand Down
76 changes: 45 additions & 31 deletions examples/usage.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
# Switch To MKL For Faster Computation
using MKL

## Enable Logging
using Logging, TerminalLoggers
global_logger(TerminalLogger())

## Data
using Distributions
ndata = 1024
ndimension = 1
ndimensions = 1
data_dist = Beta{Float32}(2.0f0, 4.0f0)
r = rand(data_dist, ndimension, ndata)
r = rand(data_dist, ndimensions, ndata)
r = convert.(Float32, r)

## Parameters
nvars = size(r, 1)
naugs = nvars + 1
n_in = nvars + naugs
nvariables = size(r, 1)
naugments = nvariables + 1
n_in = nvariables + naugments + 1 # add time concatenation
n_out = nvariables + naugments
n_hidden = n_in * 4

## Model
using ContinuousNormalizingFlows,
Expand All @@ -26,32 +25,35 @@ using ContinuousNormalizingFlows,
SciMLSensitivity,
ADTypes,
Zygote,
# ForwardDiff, # to use JVP
# LuxCUDA, # To use gpu
MLDataDevices

# To use gpu, add related packages
# using LuxCUDA

nn = Chain(Dense(n_in + 1 => n_in, tanh))
icnf = ICNF(;
nn = nn,
nvars = nvars, # number of variables
naugmented = naugs, # number of augmented dimensions
nn = Chain(
Dense(n_in => n_hidden, softplus),
Dense(n_hidden => n_hidden, softplus),
Dense(n_hidden => n_out),
),
nvariables = nvariables, # number of variables
naugments = naugments, # number of augmented dimensions
nconditions = 0, # number of conditioning inputs
λ₁ = 1.0f-2, # regulate flow
λ₂ = 1.0f-2, # regulate volume change
λ₃ = 1.0f-2, # regulate augmented dimensions
steer_rate = 1.0f-1, # add random noise to end of the time span
tspan = (0.0f0, 1.0f0), # time span
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
cond = false, # not conditioning on auxiliary input
inplace = false, # not using the inplace version of functions
autonomous = false, # using non-autonomous flow
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
inplace = false, # not using the inplace version of functions
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use VJP via Zygote
# compute_mode = LuxJacVecMatrixMode(AutoForwardDiff()), # process data in batches and use JVP via ForwardDiff
sol_kwargs = (;
save_everystep = false,
maxiters = typemax(Int),
reltol = 1.0f-4,
abstol = 1.0f-8,
reltol = sqrt(eps(Float32)),
abstol = sqrt(eps(Float32)),
alg = VCABM(; thread = True()),
sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true),
), # pass to the solver
Expand All @@ -60,24 +62,36 @@ icnf = ICNF(;
## Fit It
using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers

icnf_mach_fn = "icnf_mach.jls"
if ispath(icnf_mach_fn)
mach = machine(icnf_mach_fn) # load it
else
df = DataFrame(transpose(r), :auto)
function opt_callback(state::Any, l::Any)
if isone(state.iter % 64) # log the loss at each 64 iterations
println("Iteration: $(state.iter) | Loss: $l")
end
return false
end

icnf_mach_fn = "icnf-machine.jls"
if !isfile(icnf_mach_fn)
df = DataFrame(permutedims(r), :auto)
model = ICNFModel(;
icnf,
optimizers = (OptimiserChain(WeightDecay(), Adam()),),
optimizers = (
OptimiserChain(
ClipNorm(1.0f0, 2.0f0; throw = true),
WeightDecay(; lambda = 1.0f-2),
Adam(; eta = 1.0f-3, beta = (9.0f-1, 9.99f-1), epsilon = eps(Float32)),
),
),
batchsize = 1024,
adtype = AutoZygote(),
sol_kwargs = (; epochs = 300, progress = true), # pass to the solver
sol_kwargs = (; epochs = 300, progress = true, callback = opt_callback), # pass to the solver
)
mach = machine(model, df)
fit!(mach)
# CUDA.@allowscalar fit!(mach) # needed for gpu

MLJBase.save(icnf_mach_fn, mach) # save it
end
mach = machine(icnf_mach_fn) # load it

## Use It
d = ICNFDist(mach, TestMode())
Expand All @@ -97,8 +111,8 @@ display(res_df)
using CairoMakie
f = Figure()
ax = Axis(f[1, 1]; title = "Result")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "actual")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "estimated")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "Actual")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "Estimated")
axislegend(ax)
save("result-fig.svg", f)
save("result-fig.png", f)
save("result-figure.svg", f)
save("result-figure.png", f)
Loading
Loading