Skip to content

Commit

Permalink
ShiftTo and PositiveDefinite no longer ignore state from one of t…
Browse files Browse the repository at this point in the history
…he two calls to the underlying `model`
  • Loading branch information
nicholaskl97 committed Jan 24, 2025
1 parent 00d75b9 commit 9f6629e
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/layers/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ inputs.
## Returns
- The output of the positive definite model
- The state of the positive definite model. If the underlying model changes it state, the
state will be updated according to the call with the input `x`, not with the call using
`x0`.
state will be updated first according to the call with the input `x0`, then according to
the call with the input `x`.
## States
- `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value
Expand All @@ -47,11 +47,11 @@ inputs.
r <: Function

function PositiveDefinite(model, x0::AbstractVector; ψ=Base.Fix1(sum, abs2),
r=Base.Fix1(sum, abs2) -)
r=Base.Fix1(sum, abs2) -)
return PositiveDefinite(model, (rng, in_dims) -> copy(x0), length(x0), ψ, r)
end
function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2),
r=Base.Fix1(sum, abs2) -)
r=Base.Fix1(sum, abs2) -)
return PositiveDefinite(model, zeros32, in_dims, ψ, r)
end
end
Expand All @@ -66,16 +66,16 @@ function (pd::PositiveDefinite)(x::AbstractVector, ps, st)
end

function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st)
ϕ0, _ = pd.model(st.x0, ps, st.model)
ϕx, new_model_st = pd.model(x, ps, st.model)
ϕ0, new_model_st = pd.model(st.x0, ps, st.model)
ϕx, final_model_st = pd.model(x, ps, new_model_st)
ϕx_cols = eachcol(ϕx)
return (
permutedims(
mapreduce(vcat, zip(eachcol(x), ϕx_cols); init=empty(first(ϕx_cols))) do (x, ϕx)
pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0)
end
),
merge(st, (; model=new_model_st))
merge(st, (; model=final_model_st))
)
end

Expand All @@ -99,8 +99,8 @@ where `Δϕ = out_val - ϕ(in_val, ps, st)`.
## Returns
- The output of the shifted model
- The state of the shifted model. If the underlying model changes it state, the
state will be updated according to the call with the input `x`, not the call using
`in_val`.
state will be updated first according to the call with the input `in_val`, then
according to the call with the input `x`.
## States
- `st`: a `NamedTuple` containing the state of the underlying `model` and the `in_val` and
Expand Down Expand Up @@ -134,11 +134,11 @@ function (s::ShiftTo)(x::AbstractVector, ps, st)
end

function (s::ShiftTo)(x::AbstractMatrix, ps, st)
ϕ0, _ = s.model(st.in_val, ps, st.model)
ϕ0, new_model_st = s.model(st.in_val, ps, st.model)
Δϕ = st.out_val .- ϕ0
ϕx, new_model_st = s.model(x, ps, st.model)
ϕx, final_model_st = s.model(x, ps, new_model_st)
return (
ϕx .+ Δϕ,
merge(st, (; model=new_model_st))
merge(st, (; model=final_model_st))
)
end

0 comments on commit 9f6629e

Please sign in to comment.