diff --git a/src/pop2net/creator.py b/src/pop2net/creator.py index e5e7a65..3ee8d9e 100644 --- a/src/pop2net/creator.py +++ b/src/pop2net/creator.py @@ -198,7 +198,7 @@ def _get_mother_group_id(self, agent, dummy_location) -> str: n_mother_locations_found += 1 # Check if the number of mother locations is not 1 - if n_mother_locations_found > 1: + if n_mother_locations_found > 1 and self.model.enable_p2n_warnings: warnings.warn( f"""For agent {agent}, {n_mother_locations_found} locations of class @@ -521,25 +521,19 @@ def create_locations( # If nxgraph is used do some checks if dummy_location.nxgraph is not None: - if dummy_location.n_agents is not None: + if dummy_location.n_agents is not None and self.model.enable_p2n_warnings: msg = """You cannot define location.n_agents if location.nxgraph is used. It will be set to the number of nodes in location.nxgraph automatically.""" warnings.warn(msg) location_cls.n_agents = len(list(dummy_location.nxgraph.nodes)) dummy_location.n_agents = len(list(dummy_location.nxgraph.nodes)) - if dummy_location.overcrowding is True: + if dummy_location.overcrowding is True and self.model.enable_p2n_warnings: msg = """You cannot define location.overcrowding if location.nxgraph is used. It will be set to `False` automatically.""" warnings.warn(msg) location_cls.overcrowding = False - dummy_location.n_agents = False - - # if dummy_location.n_agents is not None and dummy_location.n_agents < 1: - # msg = ( - # f"""{str_location_cls}.n_agents must be `None` or an integer greater than 0.""" - # ) - # raise Exception(msg) + dummy_location.overcrowding = False # bridge if not dummy_location.melt(): @@ -554,7 +548,7 @@ def create_locations( if len(bridge_values) == 0: pass - elif len(bridge_values) == 1: + elif len(bridge_values) == 1 and self.model.enable_p2n_warnings: msg = f"""{str_location_cls}.bridge() returned only one unique value. {str_location_cls}.bridge() must return at least two unique values in order to create locations that bring together agents with different values on the @@ -562,7 +556,7 @@ def create_locations( """ warnings.warn(msg) - elif len(bridge_values) > 1: + elif len(bridge_values) > 1 and self.model.enable_p2n_warnings: if dummy_location.n_agents is not None: msg = f"""You cannot use {str_location_cls}.n_agents and {str_location_cls}.bridge() at the same time. {str_location_cls}.n_agents diff --git a/src/pop2net/model.py b/src/pop2net/model.py index 6572b5f..6e15192 100644 --- a/src/pop2net/model.py +++ b/src/pop2net/model.py @@ -25,7 +25,13 @@ class Model(ap.Model): :class:`agentpy.Model` for more information. """ - def __init__(self, parameters=None, _run_id=None, **kwargs): + def __init__( + self, + parameters=None, + _run_id=None, + enable_p2n_warnings=True, + **kwargs, + ): """Initiate a simulation. Args: @@ -37,6 +43,7 @@ def __init__(self, parameters=None, _run_id=None, **kwargs): """ super().__init__(parameters, _run_id, **kwargs) self.g = nx.Graph() + self.enable_p2n_warnings = enable_p2n_warnings @property def agents(self) -> AgentList: @@ -443,7 +450,7 @@ def disconnect_agents( warn = True break - if warn: + if warn and self.enable_p2n_warnings: msg = "There are other agents at the location from which you have removed agents." warnings.warn(msg) @@ -452,7 +459,7 @@ def disconnect_agents( if remove_locations: self.remove_location(location=location) - if warn: + if warn and self.enable_p2n_warnings: msg = "You have removed a location to which other agents were still connected." warnings.warn(msg) diff --git a/tests/test_model/test_model_enable_warnings.py b/tests/test_model/test_model_enable_warnings.py new file mode 100644 index 0000000..ba37f08 --- /dev/null +++ b/tests/test_model/test_model_enable_warnings.py @@ -0,0 +1,43 @@ +import warnings + +import networkx as nx +import pytest + +import pop2net as p2n + + +def test_enable_warnings_true(): + model = p2n.Model(enable_p2n_warnings=True) + creator = p2n.Creator(model) + + class LineLocation(p2n.MagicLocation): + nxgraph = nx.path_graph(10) + n_agents = 10 + + creator.create_agents(n=10) + + with pytest.warns(UserWarning) as record: + creator.create_locations( + location_classes=[LineLocation], + delete_magic_agent_attributes=False, + ) + assert len(record) > 0, "Expected a warning but none were raised." + + +def test_enable_warnings_false(): + model = p2n.Model(enable_p2n_warnings=False) + creator = p2n.Creator(model) + + class LineLocation(p2n.MagicLocation): + nxgraph = nx.path_graph(10) + n_agents = 10 + + creator.create_agents(n=10) + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + creator.create_locations( + location_classes=[LineLocation], + delete_magic_agent_attributes=False, + ) + assert len(record) == 0, "Expected no warnings but warnings were raised."