Skip to content

Commit

Permalink
Merge pull request #125 from mariuzka/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
mariuzka authored Nov 6, 2024
2 parents 9182263 + e0e0186 commit 78758cf
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
18 changes: 6 additions & 12 deletions src/pop2net/creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _get_mother_group_id(self, agent, dummy_location) -> str:
n_mother_locations_found += 1

# Check if the number of mother locations is not 1
if n_mother_locations_found > 1:
if n_mother_locations_found > 1 and self.model.enable_p2n_warnings:
warnings.warn(
f"""For agent {agent},
{n_mother_locations_found} locations of class
Expand Down Expand Up @@ -521,25 +521,19 @@ def create_locations(

# If nxgraph is used do some checks
if dummy_location.nxgraph is not None:
if dummy_location.n_agents is not None:
if dummy_location.n_agents is not None and self.model.enable_p2n_warnings:
msg = """You cannot define location.n_agents if location.nxgraph is used.
It will be set to the number of nodes in location.nxgraph automatically."""
warnings.warn(msg)
location_cls.n_agents = len(list(dummy_location.nxgraph.nodes))
dummy_location.n_agents = len(list(dummy_location.nxgraph.nodes))

if dummy_location.overcrowding is True:
if dummy_location.overcrowding is True and self.model.enable_p2n_warnings:
msg = """You cannot define location.overcrowding if location.nxgraph is used.
It will be set to `False` automatically."""
warnings.warn(msg)
location_cls.overcrowding = False
dummy_location.n_agents = False

# if dummy_location.n_agents is not None and dummy_location.n_agents < 1:
# msg = (
# f"""{str_location_cls}.n_agents must be `None` or an integer greater than 0."""
# )
# raise Exception(msg)
dummy_location.overcrowding = False

# bridge
if not dummy_location.melt():
Expand All @@ -554,15 +548,15 @@ def create_locations(
if len(bridge_values) == 0:
pass

elif len(bridge_values) == 1:
elif len(bridge_values) == 1 and self.model.enable_p2n_warnings:
msg = f"""{str_location_cls}.bridge() returned only one unique value.
{str_location_cls}.bridge() must return at least two unique values in order
to create locations that bring together agents with different values on the
same attribute.
"""
warnings.warn(msg)

elif len(bridge_values) > 1:
elif len(bridge_values) > 1 and self.model.enable_p2n_warnings:
if dummy_location.n_agents is not None:
msg = f"""You cannot use {str_location_cls}.n_agents and
{str_location_cls}.bridge() at the same time. {str_location_cls}.n_agents
Expand Down
13 changes: 10 additions & 3 deletions src/pop2net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ class Model(ap.Model):
:class:`agentpy.Model` for more information.
"""

def __init__(self, parameters=None, _run_id=None, **kwargs):
def __init__(
self,
parameters=None,
_run_id=None,
enable_p2n_warnings=True,
**kwargs,
):
"""Initiate a simulation.
Args:
Expand All @@ -37,6 +43,7 @@ def __init__(self, parameters=None, _run_id=None, **kwargs):
"""
super().__init__(parameters, _run_id, **kwargs)
self.g = nx.Graph()
self.enable_p2n_warnings = enable_p2n_warnings

@property
def agents(self) -> AgentList:
Expand Down Expand Up @@ -443,7 +450,7 @@ def disconnect_agents(
warn = True
break

if warn:
if warn and self.enable_p2n_warnings:
msg = "There are other agents at the location from which you have removed agents."
warnings.warn(msg)

Expand All @@ -452,7 +459,7 @@ def disconnect_agents(
if remove_locations:
self.remove_location(location=location)

if warn:
if warn and self.enable_p2n_warnings:
msg = "You have removed a location to which other agents were still connected."
warnings.warn(msg)

Expand Down
43 changes: 43 additions & 0 deletions tests/test_model/test_model_enable_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import warnings

import networkx as nx
import pytest

import pop2net as p2n


def test_enable_warnings_true():
model = p2n.Model(enable_p2n_warnings=True)
creator = p2n.Creator(model)

class LineLocation(p2n.MagicLocation):
nxgraph = nx.path_graph(10)
n_agents = 10

creator.create_agents(n=10)

with pytest.warns(UserWarning) as record:
creator.create_locations(
location_classes=[LineLocation],
delete_magic_agent_attributes=False,
)
assert len(record) > 0, "Expected a warning but none were raised."


def test_enable_warnings_false():
model = p2n.Model(enable_p2n_warnings=False)
creator = p2n.Creator(model)

class LineLocation(p2n.MagicLocation):
nxgraph = nx.path_graph(10)
n_agents = 10

creator.create_agents(n=10)

with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always")
creator.create_locations(
location_classes=[LineLocation],
delete_magic_agent_attributes=False,
)
assert len(record) == 0, "Expected no warnings but warnings were raised."

0 comments on commit 78758cf

Please sign in to comment.