Skip to content

Commit

Permalink
Merge pull request #113 from mariuzka/mk/improve_weighting
Browse files Browse the repository at this point in the history
Mk/improve weighting
  • Loading branch information
mariuzka authored Oct 21, 2024
2 parents 9fdff44 + 42d9057 commit 7a9310b
Show file tree
Hide file tree
Showing 7 changed files with 563 additions and 96 deletions.
1 change: 0 additions & 1 deletion docs/Introduction/introduction_creator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,6 @@
" overcrowding: bool | None = None\n",
" only_exact_n_agents: bool = False\n",
" n_locations: int | None = None\n",
" static_weight: bool = False\n",
" recycle: bool = True\n",
"\n",
" def filter(self, agent: p2n.Agent) -> bool:\n",
Expand Down
63 changes: 32 additions & 31 deletions docs/Introduction/introduction_simulations.ipynb

Large diffs are not rendered by default.

44 changes: 35 additions & 9 deletions src/pop2net/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,26 @@ def shared_locations(self, agent, location_classes: list | None = None):
location_classes=location_classes,
)

def add_location(self, location: _location.Location) -> None:
def add_location(self, location: _location.Location, weight: float | None = None) -> None:
"""Add this Agent to a given location.
Args:
location: Add agent to this location.
weight (float | None): The edge weight between the agent and the location.
Defaults to None.
"""
self.model.add_agent_to_location(self, location)
self.model.add_agent_to_location(agent=self, location=location, weight=weight)

def add_locations(self, locations: list) -> None:
def add_locations(self, locations: list, weight: float | None = None) -> None:
"""Add this agent to multiple locations.
Args:
locations (list): Add the agent to these locations.
weight (float | None): The edge weight between the agent and the location.
Defaults to None.
"""
for location in locations:
self.add_location(location)
self.add_location(location=location, weight=weight)

def remove_location(self, location: _location.Location) -> None:
"""Remove this Agent from a given location.
Expand All @@ -86,6 +95,11 @@ def remove_location(self, location: _location.Location) -> None:
self.model.remove_agent_from_location(self, location)

def remove_locations(self, locations: list) -> None:
"""Remove this Agent from the given locations.
Args:
locations (list): A list of location instances.
"""
for location in locations:
self.remove_location(location)

Expand All @@ -99,12 +113,14 @@ def locations(self) -> _sequences.LocationList:
return self.model.locations_of_agent(self)

def get_agent_weight(self, agent: Agent, location_classes: list | None = None) -> float:
"""Return the contact weight between this agent and a given other agent.
"""Return the edge weight between this agent and a given other agent.
This is summed over all shared locations.
Args:
agent_v: The other agent.
agent: The other agent.
location_classes (list): A list of location classes to specify the type of locations
which are considered.
Returns:
A weight of the contact between the two agents.
Expand All @@ -115,16 +131,26 @@ def get_agent_weight(self, agent: Agent, location_classes: list | None = None) -
return weight

def get_location_weight(self, location) -> float:
"""Return the edge weight between this agent and a given location.
Args:
location (_type_): A location.
Returns:
float: The edge weight.
"""
return self.model.get_weight(agent=self, location=location)

def connect(self, agent: Agent, location_cls: type):
def connect(self, agent: Agent, location_cls: type, weight: float | None = None):
"""Connects this agent with a given other agent via an instance of a given location class.
Args:
agents (list): An agent to connect with.
agent (list): An agent to connect with.
location_cls (type): The location class that is used to create a location instance.
weight(float | None): The edge weight between the agents and the location.
Defaults to None.
"""
self.model.connect_agents(agents=[self, agent], location_cls=location_cls)
self.model.connect_agents(agents=[self, agent], location_cls=location_cls, weight=weight)

def disconnect(
self,
Expand Down
60 changes: 26 additions & 34 deletions src/pop2net/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,25 @@ def agents(self) -> AgentList:
"""
return self.model.agents_of_location(self)

def add_agent(self, agent: _agent.Agent) -> None:
def add_agent(self, agent: _agent.Agent, weight: float | None = None) -> None:
"""Assigns the given agent to this location.
Args:
agent: The agent that should be added to the location.
weight: The edge weight between the agent and the location. Defaults to 1.
"""
self.model.add_agent_to_location(self, agent)
self.model.add_agent_to_location(self, agent=agent, weight=weight)

def add_agents(self, agents: list) -> None:
def add_agents(self, agents: list, weight: float | None = None) -> None:
"""Add multiple agents at once.
Args:
agents (list): An iterable over agents.
weight(float | None): The edge weight between the agents and the location.
Defaults to None.
"""
for agent in agents:
self.add_agent(agent)
self.add_agent(agent=agent, weight=weight)

def remove_agent(self, agent: _agent.Agent) -> None:
"""Removes the given agent from this location.
Expand Down Expand Up @@ -84,14 +87,20 @@ def neighbors(self, agent: _agent.Agent) -> AgentList:
agents.remove(agent)
return agents

def set_weight(self, agent, weight) -> None:
def set_weight(self, agent, weight: float | None = None) -> None:
"""Set the weight of an agent at the current location.
If weight is None the method location.weight() will be used to generate a weight.
Args:
agent (Agent): The agent.
weight (float): The weight.
"""
self.model.set_weight(agent=agent, location=self, weight=weight)
self.model.set_weight(
agent=agent,
location=self,
weight=weight,
)

def get_weight(self, agent: _agent.Agent) -> float:
"""Return the edge weight between an agent and the location.
Expand All @@ -104,6 +113,17 @@ def get_weight(self, agent: _agent.Agent) -> float:
"""
return self.model.get_weight(agent=agent, location=self)

def weight(self, agent) -> float: # noqa: ARG002
"""Generates the edge weight between a given agent and the location instance.
Args:
agent (_type_): An agent.
Returns:
float: The weight between the given agent and the location.
"""
return 1

def project_weights(self, agent1: _agent.Agent, agent2: _agent.Agent) -> float:
"""Calculates the edge weight between two agents assigned to the same location instance.
Expand Down Expand Up @@ -194,21 +214,6 @@ def split(self, agent: _agent.Agent) -> float | str | list | None: # noqa: ARG0
"""
return None

def weight(self, agent: _agent.Agent) -> float | None: # noqa: ARG002
"""Defines the edge weight between the agent and the location instance.
Defines how the edge weight between an agent and the location is determined.
This is a boilerplate implementation of this method which always returns 1; i.e. all
edge weights will be 1. Override this method in your own implementations as you seem fit.
Args:
agent: The agent that is currently processed by the Creator.
Returns:
The edge weight.
"""
return None

def stick_together(self, agent: _agent.Agent) -> float | str:
"""Assigns agents with a shared value on an attribute to the same location instance.
Expand Down Expand Up @@ -246,19 +251,6 @@ def refine(self):
"""An action that is performed after all location instances have been created."""
pass

def _update_weight(self, agent: _agent.Agent) -> None:
"""Create or update the agent-speific weight.
Args:
agent: The agent to be updated.
"""
self.set_weight(agent, self.weight(agent))

def _update_weights(self) -> None:
"""Update the weight of every agent on this location."""
for agent_ in self.agents:
self._update_weight(agent_)

def _subsplit(self, agent: _agent.Agent) -> str | float | list | None: # noqa: ARG002
"""Splits a location instance into sub-instances to create a certain network structure.
Expand Down
51 changes: 30 additions & 21 deletions src/pop2net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,6 @@ def __init__(self, parameters=None, _run_id=None, **kwargs):
super().__init__(parameters, _run_id, **kwargs)
self.g = nx.Graph()

def sim_step(self) -> None:
"""Do 1 step in the simulation."""
self.t += 1

# TODO: Rethink the following:
for location in self.locations:
if hasattr(location, "static_weight") and hasattr(location, "_update_weights"):
if not location.static_weight:
location._update_weights()

self.step()
self.update()

if self.t >= self._steps: # type: ignore
self.running = False

@property
def agents(self) -> AgentList:
"""Show a iterable view of all agents in the environment.
Expand Down Expand Up @@ -142,7 +126,7 @@ def add_agent_to_location(
self,
location: _location.Location,
agent: _agent.Agent,
weight: float = 1,
weight: float | None = None,
**kwargs,
) -> None:
"""Add an agent to a specific location.
Expand Down Expand Up @@ -376,15 +360,19 @@ def agents_between_locations(self, location1, location2, agent_classes: list | N
objs=self._objects_between_objects(location1, location2, agent_classes),
)

def set_weight(self, agent, location, weight) -> None:
def set_weight(self, agent, location, weight: float | None = None) -> None:
"""Set the weight of an agent at a location.
If weight is None the method location.weight() will be used to generate a weight.
Args:
agent (Agent): The agent.
location (Location): The location.
weight (int): The weight
"""
self.g[agent.id][location.id]["weight"] = 1 if weight is None else weight
self.g[agent.id][location.id]["weight"] = (
location.weight(agent) if weight is None else weight
)

def get_weight(self, agent, location) -> int:
"""Get the weight of an agent at a location.
Expand All @@ -398,15 +386,17 @@ def get_weight(self, agent, location) -> int:
"""
return self.g[agent.id][location.id]["weight"]

def connect_agents(self, agents: list, location_cls: type):
def connect_agents(self, agents: list, location_cls: type, weight: float | None = None):
"""Connects multiple agents via an instance of a given location class.
Args:
agents (list): A list of agents.
location_cls (type): The location class that is used to create a location instance.
weight (float | None): The edge weight between the agents and the location.
Defaults to None.
"""
location = location_cls(model=self)
location.add_agents(agents)
location.add_agents(agents=agents, weight=weight)

def disconnect_agents(
self,
Expand Down Expand Up @@ -526,3 +516,22 @@ def export_agent_network(
graph.add_edge(agent.id, agent_v.id, weight=weight)

return graph

def update_weights(self, location_classes: list | None = None) -> None:
"""Updates the edge weights between agents and locations.
If you only want to update the weights of specific types of locations
specify those types in location_classes.
Args:
location_classes (list | None, optional): A list of location classes that specifiy for
which location types the weights should be updated.
If location_classes is None all locations are considered. Defaults to None.
"""
for location in (
self.locations
if location_classes is None
else [location for location in self.locations if type(location) in location_classes]
):
for agent in location.agents:
location.set_weight(agent=agent, weight=location.weight(agent=agent))
Loading

0 comments on commit 7a9310b

Please sign in to comment.