Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions exercises/solved_notebooks/P5_mcmc/MCMC_1-intro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,11 @@ We can sample some values from this prior distribution and use them to plot the

# ╔═╡ 1080294d-8da9-4326-a53b-afe158dc2ab9
begin
scatter(times, observed_mutations, xlabel = "Time (My)",
ylabel = "Number of mutations", label = false, xlims = (0, 500),
ylims = (0, 400), title = "A priori relationship between t and mean N"
);
for α in rand(prior_alpha, 1000)
plot!(x -> α*x, color = :purple, alpha = 0.05, label = false);
plot(xlabel = "Time (My)", ylabel = "Number of mutations", xlims = (0, 500),
ylims = (0, 400), title = "A priori relationship between t and mean N");
scatter!(times, observed_mutations, label = "Cyt C", color = :deeppink);
for α in rand(prior_alpha, 200)
plot!(x -> α*x, color = :dodgerblue, alpha = 0.3, label = false);
end
plot!()
end
Expand Down Expand Up @@ -199,9 +198,6 @@ mutation_model = mutations(times);
# ╔═╡ d4dbc185-6087-4862-b6e5-b5e4f51c2866
md"And can then generate random samples of the output as we are used to:"

# ╔═╡ 8b0bf05f-92a0-4ce7-8042-08c790088688
mutation_model() # random sample of N

# ╔═╡ 3e4998e7-4981-4945-8bf2-ddb5afcb43b1
chain = sample(mutation_model, Prior(), 2000);

Expand All @@ -211,6 +207,9 @@ chain = sample(mutation_model, Prior(), 2000);
# ╔═╡ a9d54dfc-5337-412f-86b0-deeb5a0b6928
histogram(α_sp, title = "Sample of prior of α")

# ╔═╡ fdec279a-a7e3-4a57-a3f1-183e22558725
generated_quantities(mutation_model, chain) # random samples of N

# ╔═╡ 425b4b6c-76b8-4676-8d0a-fd26711400d6
md"### Inference"

Expand Down Expand Up @@ -240,8 +239,12 @@ md"""
We can verify that for our conditioned model, the value of `N` has been set as constant:
"""

# ╔═╡ 31d48207-4432-4d8f-a18a-40f01d4a0a6c
# ╠═╡ show_logs = false
conditioned_chain = sample(conditioned_model, Prior(), 5);

# ╔═╡ d03cef36-3e82-4de4-89e7-af9f772edd8d
conditioned_model() # always returns `observed_mutations`
generated_quantities(conditioned_model, conditioned_chain) # always returns `observed_mutations`

# ╔═╡ 371a48d5-daea-4d0b-968b-7e3056a65494
md"""
Expand Down Expand Up @@ -291,11 +294,13 @@ md"Plotting some sampled mutation rates from this distribution onto our data sho

# ╔═╡ 87e70d5a-7a45-4a3e-b6c4-a894cc78621b
begin
scatter(times, observed_mutations, xlabel = "Time (My)",
ylabel = "Number of mutations", label = false, xlims = (0, 500),
ylims = (0, 400), title = "A posteriori relationship between t and mean N"
);
plot!([x -> αᵢ*x for αᵢ in alpha_samples[1:10:end]], color = :purple, opacity = 0.1, label = false)
plot(xlabel = "Time (My)", ylabel = "Number of mutations", xlims = (0, 500),
ylims = (0, 400), title = "A posteriori relationship between t and mean N");
for α in alpha_samples[1:10:end]
plot!(x -> α*x, color = :dodgerblue, alpha = 0.1, label = false);
end
scatter!(times, observed_mutations, label = "Cyt C", color = :deeppink);
plot!()
end

# ╔═╡ 87059440-2919-4f6d-9d32-0df3ce75e2a2
Expand All @@ -320,7 +325,7 @@ md"""
To answer how old the ancestral seahorse fossil is, we need to update the model a little.
So far the fossil ages were considered to be known exactly and given as input to the model `ts`. Since the seahorse fossil's age is unknown, we add a parameter for it called `fossil_age`.

As prior knowledge we can use the fact that it must have evolved _after_ the ray-finned fish fossil (30 Ma after weird old fish), but _before_ modern seahorses (450 Ma after the bony fish fossil).
As prior knowledge we can use the fact that it must have evolved _after_ the ray-finned fish fossil (30 Ma after the bony fish fossil), but _before_ modern seahorses (450 Ma after the bony fish fossil).
"""

# ╔═╡ fd15afe1-72d7-4663-b2b6-afa0dd219db8
Expand Down Expand Up @@ -428,17 +433,18 @@ end
# ╟─31378eb3-51a5-4ad6-a713-7f77c7ceafcc
# ╠═48f6b7dc-13aa-4057-8468-97db047773ba
# ╟─d4dbc185-6087-4862-b6e5-b5e4f51c2866
# ╠═8b0bf05f-92a0-4ce7-8042-08c790088688
# ╠═3e4998e7-4981-4945-8bf2-ddb5afcb43b1
# ╠═3421987d-1aab-4cab-bf83-dd3653715bce
# ╟─a9d54dfc-5337-412f-86b0-deeb5a0b6928
# ╠═fdec279a-a7e3-4a57-a3f1-183e22558725
# ╟─425b4b6c-76b8-4676-8d0a-fd26711400d6
# ╟─b8adbdd4-2642-4375-9979-0cb8f52c5bc8
# ╠═70f9e94d-a4e6-47d6-8d19-b60f7011d572
# ╟─00c24cbb-88ba-49f2-9bf5-f44538fb2413
# ╟─2d0c969d-03a2-4e4c-ace4-e439f81c771b
# ╠═a35a43e2-e6b0-47ce-80b2-48148336274c
# ╟─c5f0dbb3-fba1-41f2-b7d2-740012603555
# ╠═31d48207-4432-4d8f-a18a-40f01d4a0a6c
# ╠═d03cef36-3e82-4de4-89e7-af9f772edd8d
# ╟─371a48d5-daea-4d0b-968b-7e3056a65494
# ╠═0d2c1359-434f-4f3d-8c04-c452c46d7ae8
Expand Down
56 changes: 33 additions & 23 deletions exercises/solved_notebooks/P5_mcmc/MCMC_2-basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Markdown
using InteractiveUtils

# ╔═╡ 94c6f31d-1a43-4221-b60c-1fa0ef8738b8
# ╠═╡ show_logs = false
using Pkg; Pkg.activate("..")

# ╔═╡ 45bc5b66-c81b-4afb-8a7e-51aff9609c62
Expand Down Expand Up @@ -70,8 +71,11 @@ end
# ╔═╡ 76dd814d-0d9b-4f7e-aff8-990da57d052b
molemodel = mole()

# ╔═╡ 611575ef-20ee-4731-9b5d-3fec90d8b038
molechain = sample(molemodel, Prior(), 2000)

# ╔═╡ b5bd5114-0a0c-4aa3-af00-4f3d05edc5e3
Y_samples = [molemodel() for i in 1:2000];
Y_samples = molechain[:Y];

# ╔═╡ 4e8b9000-cb61-4f01-9ba9-17276ad0335e
E_Y = mean(Y_samples)
Expand All @@ -83,7 +87,7 @@ md"### 3: Conditional expected value of X"
cond_mole = molemodel | (Y = 3,);

# ╔═╡ 014538da-b5ef-41c8-b799-2c000b4c9134
molechain = sample(cond_mole, NUTS(), 2000);
molechain_cond = sample(cond_mole, NUTS(), 2000);

# ╔═╡ b02dc714-e2bb-4ae2-acf9-c37a4389f953
plot(molechain)
Expand All @@ -95,7 +99,7 @@ X_samplescondY = molechain[:X];
E_XcondY = mean(molechain[:X])

# ╔═╡ 5521933a-a42e-4a67-94b9-84eab52ddf07
E_X = mean(X_prior)
E_X = mean(Exponential(100))

# ╔═╡ ca3e730c-a940-4c76-93eb-70ae4aa0e008
md"### 4: Conditional distribution of X"
Expand All @@ -107,9 +111,9 @@ histogram(molechain[:X])
md"### 5: Conditional distribution of X (with more data)"

# ╔═╡ b2b16c34-e12a-4bea-8098-313d01913bbf
@model function mole2()
@model function mole2(n)
X ~ Exponential(100)
Ys = zeros(4)
Ys = zeros(n)
for i in 1:length(Ys)
Ys[i] ~ Uniform(0, X)
end
Expand All @@ -119,7 +123,7 @@ end
Y_obs = [3.0, 1.5, 0.9, 5.7]

# ╔═╡ 2f5c14e9-709a-40ec-a6cb-ddd870cc1a60
mole_cond2 = mole2() | (Ys = Y_obs,)
mole_cond2 = mole2(4) | (Ys = Y_obs,)

# ╔═╡ 7f90e18d-9af8-42fa-984b-aaa7c8c458b5
molechain2 = sample(mole_cond2, NUTS(), 2000);
Expand Down Expand Up @@ -174,7 +178,7 @@ histogram(potato_chain[:N])
md"### 2: Probability"

# ╔═╡ 9dd6fa35-37d5-4ab4-ac2c-f4eefdabadd2
p_potato1 = mean(potato_chain[:N] .> 6 .&& potato_chain[:W] .> 175)
p_potato1 = mean((potato_chain[:N] .> 6) .&& (potato_chain[:W] .> 175))

# ╔═╡ 989b755b-2275-4aae-a3e2-3b107a72fc0b
md"### 3: Probability (with more data)"
Expand Down Expand Up @@ -283,23 +287,27 @@ md"""
"""

# ╔═╡ 8fc58fa4-b005-4f32-9eae-a8143582a1ae
@model function lights_censored(time_observed, n_working)
@model function lights_censored(time_observed)
μ ~ LogNormal(log(40), 0.5)


# we still have lifespan information for 2 lights
lifespans = zeros(2)
for i in 1:length(lifespans)
lifespans[i] ~ Exponential(μ)
end
p_stillworking = 1 - cdf(Exponential(μ), time_observed)
n_working ~ Binomial(n_working + length(lifespans), p_stillworking)
# or simply `2 ~ Binomial(4, p_stillworking)`

# use the information that 2 lights are still working
p_stillworking = 1 - cdf(Exponential(μ), time_observed)
# probability that μ > 30_000 => this is the success rate for a lamp to still work after 30_000 hours
n_working ~ Binomial(4, p_stillworking)
# the amount of lights still working (`n_workin`) follows a Binomial distribution: they can be seen as the successes of attempting to "not break" 4 times, each with the success rate for a lamp to still work after 30_000 hours
end

# ╔═╡ fe2958a7-e9dd-4eca-979d-a80df12f8735
lightmodel_cens = lights_censored(30, 2) | (lifespans = [16, 20],)
lightmodel_cens = lights_censored(30) | (lifespans = [16, 20], n_working = 2)

# ╔═╡ 8abafb6a-dc83-422c-82d1-a721a0e1eca0
lightschain_cens = sample(lightmodel_cens, MH(), 10_000)
lightschain_cens = sample(lightmodel_cens, PG(10), 1_000)

# ╔═╡ 5ba2886c-b2e1-49f4-90c6-549acc808f77
plot(lightschain_cens)
Expand Down Expand Up @@ -381,12 +389,12 @@ md"### 3🌟: Conditional expected value (spicy)"
fs1 ~ Uniform(0, 1) # fraction of species 1

fishlens = zeros(10)
isspecies1 = zeros(10)
isspecies1 = zeros(10) # for every fish: are you species 1?
for i in 1:length(fishlens)
isspecies1[i] ~ Bernoulli(fs1)
if isspecies1[i] == 1.0
isspecies1[i] ~ Bernoulli(fs1) # "are you species 1" is a binomial distr!
if isspecies1[i] == 1.0 # if species 1, length follows their distribution
fishlens[i] ~ Normal(90, 15)
else
else # otherwise length follows distribution of species 2
fishlens[i] ~ Normal(60, 10)
end
end
Expand All @@ -396,10 +404,10 @@ end
fishmodel🌟 = fishmixture🌟() | (fishlens = len_obs,)

# ╔═╡ bdb0d902-3056-43fb-abb0-15f2494fcf9d
fishchain🌟 = sample(fishmodel🌟, PG(20), 2000)
fishchain🌟 = sample(fishmodel🌟, MH(), 20_000) # PG convergences poorly without a lot of particles => takes a long time => to win time you can instead try the very fast (but inconsistent) `MH` with a criminally large amount of samples, just make sure to check convergence for this sampler!

# ╔═╡ 7372c616-05a8-428d-ab04-a7be9f65653d
plot(fishchain🌟)
plot(fishchain🌟) # MH happened to work well this time! 🥳

# ╔═╡ 40cc67c1-8627-4f1f-b135-a9770f916b53
p_fish4_is_species1 = mean(fishchain🌟["isspecies1[4]"])
Expand Down Expand Up @@ -475,14 +483,15 @@ md"### 1: All points"
R ~ Uniform(0, 50)

# three random points in polar coordinates

# define angles (radius has already been chosen)
θ1 ~ Flat()
# `Uniform(0, 2*pi)` is also possible but can get the sampler stuck
# at 0 or 2π!
# `Uniform(0, 2*pi)` is also possible but can get the sampler stuck at 0 or 2π! Another option here is a Uniform distribution with a very large domain, such as `Uniform(-1000, 1000)`
θ2 ~ Flat()
θ3 ~ Flat()

# P1
x1 ~ Normal(xC + R * cos(θ1), σ)
x1 ~ Normal(xC + R * cos(θ1), σ) # wait, it's all trigonometry?
y1 ~ Normal(yC + R * sin(θ1), σ)
# P2
x2 ~ Normal(xC + R * cos(θ2), σ)
Expand Down Expand Up @@ -567,6 +576,7 @@ end
# ╟─78eb7779-f182-4419-b5d8-79a2f5c5d6da
# ╠═3decb2ec-210a-4b2f-842d-6fd40dd3f77b
# ╠═76dd814d-0d9b-4f7e-aff8-990da57d052b
# ╠═611575ef-20ee-4731-9b5d-3fec90d8b038
# ╠═b5bd5114-0a0c-4aa3-af00-4f3d05edc5e3
# ╠═4e8b9000-cb61-4f01-9ba9-17276ad0335e
# ╟─8777b133-7d7c-4a85-b89c-2f00093e9984
Expand Down
22 changes: 13 additions & 9 deletions exercises/solved_notebooks/P5_mcmc/MCMC_3-advanced.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ macro bind(def, element)
end

# ╔═╡ 75581580-2fb2-4112-b397-2b775eb64630
# ╠═╡ show_logs = false
using Pkg; Pkg.activate("..")

# ╔═╡ e07a1ae5-43b7-4c12-831d-43e1738eeac0
Expand Down Expand Up @@ -158,11 +159,14 @@ plot(dropletdist)
r ~ LogNormal(0.0, 0.3)
K ~ Normal(1e5, 1e4)

logfun = t -> logistic(t, P0, r, K)
Pt = max(logfun(t_obs), 0)
# model number of bacteria at time `t_obs`
num_bacteria = logistic(t_obs, P0, r, K)
num_bacteria = max(num_bacteria, 0)
# set to 0 if negative to prevent possible errors (not required)
P_obs ~ Poisson(Pt)

P_obs ~ Poisson(num_bacteria)

# to answer the questions, it's useful to return the growth through time as a function
logfun = t -> logistic(t, P0, r, K)
return logfun
end

Expand All @@ -173,13 +177,13 @@ petrimodel = petrigrowth(5) | (P_obs = 21_000,)
petrichain = sample(petrimodel, PG(40), 2_000)

# ╔═╡ 9633f07a-d583-4680-b3cc-f5701540968f
plot(petrichain)
plot(petrichain) # convergence is not great unless you use a large number of particles (and even then it's not ideal)

# ╔═╡ 7d0e5f0b-2cdf-4946-a186-f70774e363bb
logfuns = generated_quantities(petrimodel, petrichain);
logfuns = generated_quantities(petrimodel, petrichain); # it's easy here to get a function describing growth through time because we returned it in our model!

# ╔═╡ c642574c-c01b-4398-aac9-43e514d7fa25
petri_samples = [logfun(8.0) for logfun in logfuns]
petri_samples = [logfun(8.0) for logfun in logfuns] # get population sizes at t = 8

# ╔═╡ d15fa091-839c-4239-96e2-eaf7335ce620
prob_splittable = mean((petri_samples .>= 1e4) .&& (petri_samples .<= 1e5))
Expand All @@ -188,7 +192,7 @@ prob_splittable = mean((petri_samples .>= 1e4) .&& (petri_samples .<= 1e5))
md"### 2"

# ╔═╡ 3d9b73ef-2aae-4c19-bddc-4081927ec92d
plot(logfuns[1:10:1000], xlims = (0, 12), legend = false, color = :skyblue, alpha = 0.5)
plot(logfuns, xlims = (0, 12), legend = false, color = :skyblue, alpha = 0.1)

# ╔═╡ 9ab88be4-4cf8-4747-ac36-3f1b82899be0
md"### 3🌟"
Expand All @@ -200,7 +204,7 @@ dropletdist🌟 = MixtureModel(
truncated(Normal(30, sqrt(30)), lower = 0.0)
],
[0.75, 0.25]
);
); # Normal distributions are the Poisson distributions of the continuous world (make sure to match mean and variance of both distributions) - ideally also use `truncated` to ensure the normal distributions are positive

# ╔═╡ 4e730df9-f619-464a-b8a3-57448132404b
begin
Expand Down
16 changes: 13 additions & 3 deletions exercises/solved_notebooks/P5_mcmc/MCMC_4-review.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.20.4
# v0.20.21

using Markdown
using InteractiveUtils
Expand Down Expand Up @@ -57,15 +57,23 @@ md"""
Where is the hornet nest located? You may assume the nest is somewhere within the plot's boundaries, and the flight speed of a wasp is around 5 to 10 m/s.
"""

# ╔═╡ 0ac905af-2798-41c2-8295-60e40f384f64
v_prior = LogNormal(log(8), 0.5)
# or something else positive with a peak around 5-10

# ╔═╡ 9bc4efa9-5e4c-41d5-86cb-997fb0788ea0
plot(v_prior, xlims = (0, 20)) # seems about right!

# ╔═╡ ac0496d3-aa62-4c07-8233-16fe67c52b24
@model function horenaars(xs, ys, ts)
σ ~ Exponential(10)
x_nest ~ Uniform(0, 1000)
y_nest ~ Uniform(0, 1000)
v_wasp ~ Gamma(8) # or something similar
v_wasp ~ LogNormal(log(8), 0.5)

for i in 1:length(ts)
dist = sqrt((xs[i] - x_nest)^2 + (ys[i] - y_nest)^2)
ts[i] ~ Normal(2*dist / v_wasp, 10)
ts[i] ~ Normal(2*dist / v_wasp, σ)
end
end

Expand Down Expand Up @@ -103,6 +111,8 @@ end
# ╠═c1850e64-9e2c-46fb-b7cd-22e8af81d3aa
# ╟─28cb3363-6856-4164-b60e-36ec2e88ed56
# ╟─484b56ec-57b2-496d-a5cf-1b1c0da97c58
# ╠═0ac905af-2798-41c2-8295-60e40f384f64
# ╠═9bc4efa9-5e4c-41d5-86cb-997fb0788ea0
# ╠═ac0496d3-aa62-4c07-8233-16fe67c52b24
# ╠═88e95234-4e43-4ead-bd0f-f32f9fc109c2
# ╠═c5a94b3e-a4b8-42e6-a8c7-79e942212852
Expand Down
Loading
Loading