Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Save/Resume backend to JLD2 #152

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ActionModels"
uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
authors = ["Peter Thestrup Waade [email protected]", "Anna Hedvig Møller [email protected]", "Jacopo Comoglio [email protected]", "Christoph Mathys [email protected]"]
authors = ["Peter Thestrup Waade [email protected]", "Anna Hedvig Møller [email protected]", "Jacopo Comoglio [email protected]", "Luke Ring [email protected]", "Christoph Mathys [email protected]"]
version = "0.6.6"

[deps]
Expand All @@ -9,9 +9,8 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCMCChainsStorage = "51a256e2-afd8-4c38-88d8-a98ba8ad53ca"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -24,12 +23,11 @@ DataFrames = "1.6"
Distributed = "1"
Distributions = "0.25"
ForwardDiff = "0.10"
HDF5 = "0.17"
julia = "1.10"
JLD2 = "0.5"
Logging = "1"
MCMCChainsStorage = "0.1"
ProgressMeter = "1"
RecipesBase = "1.3"
Reexport = "1"
ReverseDiff = "1.15"
Turing = "0.34"
julia = "1.10"
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ After this introduction, you will be presented with a detailed step-by-step guid

Defning a premade agent

````@example Introduction
````julia @example Introduction
using ActionModels
````

Find premade agent, and define agent with default parameters

````@example Introduction
````julia @example Introduction
premade_agent("help")

agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

Set inputs and give inputs to agent

````@example Introduction
````julia @example Introduction
inputs = [1,0,0,0,1,1,1,1,0,1,0,1,0,1,1]
actions = give_inputs!(agent,inputs)

Expand All @@ -44,26 +44,25 @@ plot_trajectory(agent, "action_probability")

Fit learning rate. Start by setting prior

````@example Introduction
````julia @example Introduction
using Distributions
priors = Dict("learning_rate" => Normal(0.5, 0.5))
````

Run model

````@example Introduction
````julia @example Introduction
chains = fit_model(agent, priors, inputs, actions, n_chains = 1, n_iterations = 10)
````

Plot prior and posterior

````@example Introduction
````julia @example Introduction
plot_parameter_distribution(chains,priors)
````

Get posteriors from chains

````@example Introduction
````julia @example Introduction
get_posteriors(chains)
````

2 changes: 1 addition & 1 deletion src/ActionModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module ActionModels
using Reexport
using Turing, ReverseDiff, DataFrames, AxisArrays, RecipesBase, Logging
using ProgressMeter, Distributed #TODO: get rid of this (only needed for parameter recovery)
using MCMCChainsStorage, HDF5
using JLD2
@reexport using Distributions
using Turing: DynamicPPL, ForwardDiff, AutoReverseDiff, AbstractMCMC
#Export functions
Expand Down
10 changes: 3 additions & 7 deletions src/fitting/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function validate_saved_sampling_state!(
last_seg = 0
n_segs = 0
for cur_seg in 1:n_segments
if isfile(joinpath(save_resume.path, "$(save_resume.chain_prefix)_c$(chain)_s$(cur_seg).h5"))
if isfile(joinpath(save_resume.path, "$(save_resume.chain_prefix)_c$(chain)_s$(cur_seg).jld2"))
last_seg = cur_seg
n_segs += 1
end
Expand All @@ -85,9 +85,7 @@ function load_segment(
segment::Int,
)
# load the chain
chain = h5open(joinpath(save_resume.path, "$(save_resume.chain_prefix)_c$(chain_n)_s$(segment).h5"), "r") do file
read(file, Chains)
end
chain = JLD2.load_object(joinpath(save_resume.path, "$(save_resume.chain_prefix)_c$(chain_n)_s$(segment).jld2"))
# extra validation?
return chain
end
Expand All @@ -99,9 +97,7 @@ function save_segment(
seg_n::Int,
)
# save the chain
h5open(joinpath(save_resume.path, "$(save_resume.chain_prefix)_c$(chain_n)_s$(seg_n).h5"), "w") do file
write(file, seg)
end
JLD2.save_object(joinpath(save_resume.path, "$(save_resume.chain_prefix)_c$(chain_n)_s$(seg_n).jld2"), seg)
end

function combine_segments(
Expand Down
4 changes: 2 additions & 2 deletions test/testsuite/create_model_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ using Turing: AutoReverseDiff

noise = agent.parameters["noise"]

if noise > 2.5
if noise > 3.0
#Throw an error that will reject samples when fitted
throw(
RejectParameters(
Expand Down Expand Up @@ -438,4 +438,4 @@ using Turing: AutoReverseDiff
#Fit model
fitted_model = sample(model, sampler, n_iterations; sampling_kwargs...)
end
end
end
30 changes: 28 additions & 2 deletions test/testsuite/fit_model_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using ActionModels, DataFrames
using Distributed
using Turing: AutoReverseDiff, NUTS

@testset "fit model" begin

Expand Down Expand Up @@ -31,6 +32,9 @@ using Distributed
n_chains = 2
sampling_kwargs = (; progress = false)

# this way we keep tempdir
save_resume = ChainSaveResume(path = mktempdir())

@testset "basic run" begin

#Create model
Expand All @@ -55,7 +59,6 @@ using Distributed
end

@testset "basic run - save_resume" begin

#Create model
model = create_model(
agent,
Expand All @@ -65,7 +68,7 @@ using Distributed
action_cols = :actions,
grouping_cols = :ID,
)
save_resume = ChainSaveResume(path = mktempdir())

results = fit_model(
model;
sampler = sampler,
Expand All @@ -78,6 +81,29 @@ using Distributed
@test results isa ActionModels.FitModelResults
end

@testset "Continuing from save_resume state" begin
#Create model
model = create_model(
agent,
prior,
data,
input_cols = :inputs,
action_cols = :actions,
grouping_cols = :ID,
)

results = fit_model(
model;
sampler = sampler,
n_iterations = n_iterations * 2, # bump up the iterations to continue
n_chains = n_chains,
save_resume=save_resume,
sampling_kwargs...,
)

@test results isa ActionModels.FitModelResults
end


@testset "parallelized" begin
addprocs(4)
Expand Down