Skip to content

Commit

Permalink
Merge pull request #150 from ilabcode/dev
Browse files Browse the repository at this point in the history
0.6.6
  • Loading branch information
PTWaade authored Oct 14, 2024
2 parents 1274236 + b2f197d commit 59c7dcd
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ActionModels"
uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
authors = ["Peter Thestrup Waade [email protected]", "Anna Hedvig Møller [email protected]", "Jacopo Comoglio [email protected]", "Christoph Mathys [email protected]"]
version = "0.6.5"
version = "0.6.6"

[deps]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
6 changes: 3 additions & 3 deletions src/fitting/agent_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
inputs_per_agent::Vector{I},
actions_per_agent::Vector{Vector{R}},
actions_flattened::Vector{R},
missing_actions::Nothing,
missing_actions::Val{false},
) where {D<:Dict,I<:Vector,R<:Real}

#TODO: Could use a list comprehension here to make it more efficient
Expand Down Expand Up @@ -58,7 +58,7 @@ end
inputs_per_agent::Vector{I},
actions_per_agent::Vector{Matrix{R}},
actions_flattened::Matrix{R},
missing_actions::Nothing,
missing_actions::Val{false},
) where {D<:Dict,I<:Vector,R<:Real}

#Initialize a vector for storing the action probability distributions
Expand Down Expand Up @@ -114,7 +114,7 @@ end
inputs_per_agent::Vector{I},
actions_per_agent::Vector{A},
actions_flattened::A2,
missing_actions::MissingActions,
missing_actions::Val{true},
) where {D<:Dict,I<:Vector,A<:Array,A2<:Array}

#For each agent
Expand Down
65 changes: 57 additions & 8 deletions src/fitting/create_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function create_model(
input_cols::Union{Vector{T1},T1},
action_cols::Union{Vector{T2},T3},
grouping_cols::Union{Vector{T3},T3},
check_parameter_rejections::Union{Nothing,CheckRejections} = nothing,
check_parameter_rejections::Bool = false,
verbose::Bool = true,
) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}}

Expand Down Expand Up @@ -93,10 +93,10 @@ function create_model(
## Determine whether any actions are missing ##
if actions isa Vector{A} where {R<:Real,A<:Array{Union{Missing,R}}}
#If there are missing actions
missing_actions = MissingActions()
missing_actions = true
elseif actions isa Vector{A} where {R<:Real,A<:Array{R}}
#If there are no missing actions
missing_actions = nothing
missing_actions = false
end

#Create a full model combining the agent model and the statistical model
Expand All @@ -106,8 +106,8 @@ function create_model(
inputs,
actions,
agent_ids,
missing_actions = missing_actions,
check_parameter_rejections = check_parameter_rejections,
Val(check_parameter_rejections),
Val(missing_actions),
)
end

Expand All @@ -119,9 +119,9 @@ end
population_model::DynamicPPL.Model,
inputs_per_agent::Vector{I},
actions_per_agent::Vector{A},
agent_ids::Vector{Symbol};
missing_actions::Union{Nothing,MissingActions} = MissingActions(),
check_parameter_rejections::Nothing = nothing,
agent_ids::Vector{Symbol},
check_parameter_rejections::Val{false},
missing_actions::Union{Val{false},Val{true}};
actions_flattened::A2 = vcat(actions_per_agent...),
) where {I<:Vector,R<:Real,A1<:Union{R,Union{Missing,R}},A<:Array{A1},A2<:Array}

Expand All @@ -142,3 +142,52 @@ end
#Return values fron the population model (agent parameters and oher values)
return population_values
end




###################################################
### VERSION WITH CHECK FOR PARAMETER REJECTIONS ###
###################################################
@model function full_model(
agent::Agent,
population_model::DynamicPPL.Model,
inputs_per_agent::Vector{I},
actions_per_agent::Vector{A},
agent_ids::Vector{Symbol},
check_parameter_rejections::Val{true},
missing_actions::Union{Val{false},Val{true}};
actions_flattened::A2 = vcat(actions_per_agent...),
) where {I<:Vector,R<:Real,A1<:Union{R,Union{Missing,R}},A<:Array{A1},A2<:Array}

#Check for errors
try
#Generate the agent parameters from the statistical model
@submodel population_values = population_model

#Generate the agent's behavior
@submodel agent_models(
agent,
agent_ids,
population_values.agent_parameters,
inputs_per_agent,
actions_per_agent,
actions_flattened,
missing_actions,
)

#Return values fron the population model (agent parameters and oher values)
return population_values

#If there are errors
catch err
#If the error is a rejection of parameters
if err isa RejectParameters
#Make Turing reject the sample
Turing.@addlogprob!(-Inf)
else
#Rethrow other errors
rethrow(err)
end
end
end
2 changes: 0 additions & 2 deletions src/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ end
PopulationModelReturn(agent_parameters::Vector{D}) where {D<:Dict} =
PopulationModelReturn(agent_parameters, nothing)

struct CheckRejections end
struct MissingActions end
mutable struct FitModelResults
chains::Chains
model::DynamicPPL.Model
Expand Down
42 changes: 41 additions & 1 deletion test/testsuite/create_model_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ using Turing: AutoReverseDiff

#Extract agent parameters
agent_parameters = extract_quantities(model, fitted_model)
estimates_df = get_estimates(agent_parameters)
estimates_df = get_estimates(agent_parameters, DataFrame)
estimates_dict = get_estimates(agent_parameters, Dict)
#estimates_chains = get_estimates(agent_parameters, Chains)

#Extract state trajectories
state_trajectories = get_trajectories(model, fitted_model, ["value", "action"])
Expand Down Expand Up @@ -398,4 +399,43 @@ using Turing: AutoReverseDiff
#Rename chains
renamed_model = rename_chains(fitted_model, model)
end

@testset "Check for parameter rejections" begin
#Action model with multiple actions
function action_with_errors(agent, input::R) where {R<:Real}

noise = agent.parameters["noise"]

if noise > 2.5
#Throw an error that will reject samples when fitted
throw(
RejectParameters(
"Rejected noise",
),
)
end

actiondist = Normal(input, noise)

return actiondist
end
#Create agent
new_agent = init_agent(action_with_errors, parameters = Dict("noise" => 1.0))

new_prior = Dict("noise" => truncated(Normal(0.0, 1.0), lower = 0, upper = 3.1))

#Create model
model = create_model(
new_agent,
new_prior,
data,
input_cols = [:inputs],
action_cols = [:actions],
grouping_cols = :id,
check_parameter_rejections = true,
)

#Fit model
fitted_model = sample(model, sampler, n_iterations; sampling_kwargs...)
end
end

2 comments on commit 59c7dcd

@PTWaade
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

#Minor changes
Added support for rejecting parameters

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/117276

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.6 -m "<description of version>" 59c7dcdee4f63069cfaa71187a6e4f4d906d9f34
git push origin v0.6.6

Please sign in to comment.