Skip to content

Commit

Permalink
Run code format
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiahpslewis committed Jan 13, 2025
1 parent 2d25c54 commit b5347d6
Show file tree
Hide file tree
Showing 24 changed files with 928 additions and 578 deletions.
21 changes: 11 additions & 10 deletions src/common/CircularArraySARTSATraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,28 @@ const CircularArraySARTSATraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
},
}

function CircularArraySARTSATraces(;
capacity::Int,
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces))
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) =
minimum(map(capacity, t.traces))
19 changes: 10 additions & 9 deletions src/common/CircularArraySARTSTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,28 @@ const CircularArraySARTSTraces = Traces{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
},
}

function CircularArraySARTSTraces(;
capacity::Int,
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
Traces(
action = CircularArrayBuffer{action_eltype}(action_size..., capacity),
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces))
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) =
minimum(map(capacity, t.traces))
26 changes: 16 additions & 10 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ const CircularArraySLARTTraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
},
}

function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
legal_actions_mask = Bool => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
Expand All @@ -26,12 +26,18 @@ function CircularArraySLARTTraces(;
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{LL′}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{LL′}(
CircularArrayBuffer{legal_actions_mask_eltype}(
legal_actions_mask_size...,
capacity + 1,
),
) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces))
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) =
minimum(map(capacity, t.traces))
13 changes: 9 additions & 4 deletions src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@ struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
default_priority::Float32
end

function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
function CircularPrioritizedTraces(
traces::AbstractTraces{names,Ts};
default_priority,
) where {names,Ts}
new_names = (:key, :priority, names...)
new_Ts = Tuple{Int,Float32,Ts.parameters...}
c = capacity(traces)
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
CircularVectorBuffer{Int}(c),
SumTree(c),
traces,
default_priority
default_priority,
)
end

Expand Down Expand Up @@ -60,6 +63,8 @@ function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
end
end

Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} =
NamedTuple{names}(map(k -> t[k][i], names))

capacity(t::CircularPrioritizedTraces) = ReinforcementLearningTrajectories.capacity(t.traces)
capacity(t::CircularPrioritizedTraces) =
ReinforcementLearningTrajectories.capacity(t.traces)
15 changes: 7 additions & 8 deletions src/common/ElasticArraySARTSATraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ const ElasticArraySARTSATraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
},
}

function ElasticArraySARTSATraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
Expand All @@ -24,8 +24,7 @@ function ElasticArraySARTSATraces(;
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end

20 changes: 10 additions & 10 deletions src/common/ElasticArraySARTSTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ const ElasticArraySARTSTraces = Traces{
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
},
}

function ElasticArraySARTSTraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ())

state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)

state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
Traces(
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + Traces(
action = ElasticArray{action_eltype}(undef, action_size..., 0),
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
20 changes: 11 additions & 9 deletions src/common/ElasticArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ const ElasticArraySLARTTraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
},
}

function ElasticArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
legal_actions_mask = Bool => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
Expand All @@ -26,10 +26,12 @@ function ElasticArraySLARTTraces(;
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) +
MultiplexTraces{LL′}(
ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0),
) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
6 changes: 3 additions & 3 deletions src/common/sum_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function correct_sample(t::SumTree, leaf_ind)
p = t.tree[leaf_ind]
# walk backwards until p != 0 or until leftmost leaf reached
tmp_ind = leaf_ind
while iszero(p) && (tmp_ind-1)*2 > length(t.tree)
while iszero(p) && (tmp_ind - 1) * 2 > length(t.tree)
tmp_ind -= 1
p = t.tree[tmp_ind]
end
Expand All @@ -151,7 +151,7 @@ function correct_sample(t::SumTree, leaf_ind)
end
return p, tmp_ind
end


function Base.get(t::SumTree, v)
parent_ind = 1
Expand Down Expand Up @@ -185,7 +185,7 @@ Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t)

function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n)
for i in 1:n
for i = 1:n
v = (i - 1 + rand(rng, T)) / n
ind, p = get(t, v * t.tree[1])
inds[i] = ind
Expand Down
20 changes: 11 additions & 9 deletions src/controllers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export InsertSampleRatioController, AsyncInsertSampleRatioController, EpisodeSampleRatioController
export InsertSampleRatioController,
AsyncInsertSampleRatioController, EpisodeSampleRatioController

"""
InsertSampleRatioController(;ratio=1., threshold=1)
Expand Down Expand Up @@ -43,18 +44,19 @@ end
function AsyncInsertSampleRatioController(
ratio,
threshold,
; ch_in_sz=1,
ch_out_sz=1,
n_inserted=0,
n_sampled=0
;
ch_in_sz = 1,
ch_out_sz = 1,
n_inserted = 0,
n_sampled = 0,
)
AsyncInsertSampleRatioController(
ratio,
threshold,
n_inserted,
n_sampled,
Channel(ch_in_sz),
Channel(ch_out_sz)
Channel(ch_out_sz),
)
end

Expand All @@ -75,14 +77,14 @@ end

function on_insert!(c::EpisodeSampleRatioController, n::Int, x::NamedTuple)
if n > 0
c.n_episodes += sum(x.terminal)
c.n_episodes += sum(x.terminal)
end
end

function on_sample!(c::EpisodeSampleRatioController)
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
c.n_sampled += 1
return true
end
return false
end
end
Loading

0 comments on commit b5347d6

Please sign in to comment.