Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioCarta committed Apr 17, 2024
1 parent 1128d43 commit 447f168
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
1 change: 1 addition & 0 deletions avalanche/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Most of these protocols are checked dynamically at runtime, so it is often not
necessary to inherit explicit from them or implement all the methods.
"""

from abc import ABC
from typing import Any, TypeVar, Generic, Protocol, runtime_checkable
from typing import TYPE_CHECKING
Expand Down
12 changes: 9 additions & 3 deletions avalanche/evaluation/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def update(self, res, *, stream=None):
else:
self.metrics_res[k] = [v]

def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None):
def get(
self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None
):
"""Returns a metric value given its name and aggregation method.
:param name: name of the metric.
Expand All @@ -102,9 +104,13 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=N
assert time_reduce in {None, "last", "mean"}
assert exp_reduce in {None, "sample_mean", "experience_mean", "weighted_sum"}
if exp_reduce == "weighted_sum":
assert weights is not None, "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`."
assert (
weights is not None
), "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`."
else:
assert weights is None, "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`"
assert (
weights is None
), "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`"

if stream is not None:
name = f"{stream.name}/{name}"
Expand Down
7 changes: 5 additions & 2 deletions avalanche/models/dynamic_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
colors[None] = colors["END"]


@experimental("New dynamic optimizers. The API may slightly change in the next versions.")
@experimental(
"New dynamic optimizers. The API may slightly change in the next versions."
)
class DynamicOptimizer(Adaptable):
"""Avalanche dynamic optimizer.
Expand Down Expand Up @@ -64,6 +66,7 @@ class DynamicOptimizer(Adaptable):
# first model.pre_adapt, then optimizer.pre_adapt
agent.pre_adapt(experience)
"""

def __init__(self, optim):
self.optim = optim

Expand All @@ -78,7 +81,7 @@ def pre_adapt(self, agent: Agent, exp: CLExperience):
update_optimizer(
self.optim,
new_params=dict(agent.model.named_parameters()),
optimized_params=dict(agent.model.named_parameters())
optimized_params=dict(agent.model.named_parameters()),
)


Expand Down
12 changes: 6 additions & 6 deletions tests/evaluation/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def __len__(self):
time_reduce=None,
exp_reduce="weighted_sum",
weights=[1, 2],
stream=fake_stream
stream=fake_stream,
)
np.testing.assert_array_almost_equal(
v, [(1 + 3*2), (5 + 7*2), (11 + 13*2)]
v, [(1 + 3 * 2), (5 + 7 * 2), (11 + 13 * 2)]
)

# time = "last"
Expand All @@ -87,9 +87,9 @@ def __len__(self):
time_reduce="last",
exp_reduce="weighted_sum",
stream=fake_stream,
weights=[1, 2]
weights=[1, 2],
)
self.assertAlmostEqual(v, 11 + 13*2)
self.assertAlmostEqual(v, 11 + 13 * 2)

# time_reduce = "mean"
v = mc.get(
Expand All @@ -115,9 +115,9 @@ def __len__(self):
time_reduce="mean",
exp_reduce="weighted_sum",
stream=fake_stream,
weights=[1, 2]
weights=[1, 2],
)
self.assertAlmostEqual(v, ((1 + 3*2) + (5 + 7*2) + (11 + 13*2)) / 3)
self.assertAlmostEqual(v, ((1 + 3 * 2) + (5 + 7 * 2) + (11 + 13 * 2)) / 3)


if __name__ == "__main__":
Expand Down

0 comments on commit 447f168

Please sign in to comment.