Skip to content

Commit

Permalink
fix pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
mariuzka committed Jan 17, 2024
1 parent 39ccfb6 commit 74164a2
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 227 deletions.
2 changes: 1 addition & 1 deletion src/popy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, model, *args, **kwargs) -> None:
self.model = model
self.model.env.add_agent(self)
self.setup()
self._initial_locations = []


def setup(self) -> None:
"""Instantiate an Agent.
Expand Down
122 changes: 98 additions & 24 deletions src/popy/location.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base class to create Location objects."""
from __future__ import annotations

import math
from typing import Any

from agentpy.objects import Object
Expand All @@ -9,8 +10,6 @@
from . import agent as _agent
from . import model as _model

import math

class Location(Object):
"""Base class for location objects."""

Expand All @@ -26,14 +25,14 @@ def __init__(self, model: _model.Model) -> None:
self.subgroup_id: int | None = None
self.group_value: int | str | None = None
self.subgroup_value: int | str | None = None
self.size: int = 2

self.size: int | None = None
self.allow_overcrowding: bool = True
self.n_locations: int | None = None
self.static_weight: bool = False
self.round_function = round
self.multi_melt: bool = True
self.n_branches: int > 0 = 2
self.n_branches: int = 2

self.model.env.add_location(self)

Expand Down Expand Up @@ -105,8 +104,16 @@ def filter(self, agent: _agent.Agent) -> bool: # noqa: ARG002
True if the agent is allowed to join the location, False otherwise.
"""
return True

def find(self, agent: _agent.Agent) -> bool:

def find(self, agent: _agent.Agent) -> bool: # noqa: ARG002
"""_summary_.
Args:
agent (_agent.Agent): _description_
Returns:
bool: _description_
"""
return True

def split(self, agent: _agent.Agent) -> float | str | list | None: # noqa: ARG002
Expand All @@ -126,7 +133,15 @@ def split(self, agent: _agent.Agent) -> float | str | list | None: # noqa: ARG0
"""
return None

def subsplit(self, agent: _agent.Agent) -> Any:
def subsplit(self, agent: _agent.Agent) -> Any: # noqa: ARG002
"""_summary_.
Args:
agent (_agent.Agent): _description_
Returns:
Any: _description_
"""
return None

def is_affiliated(self, agent: _agent.Agent) -> bool:
Expand Down Expand Up @@ -197,49 +212,94 @@ def project_weights(self, agent1: _agent.Agent, agent2: _agent.Agent) -> float:
Combined edge weight.
"""
return min([self.get_weight(agent1), self.get_weight(agent2)])


def stick_together(self, agent: _agent.Agent) -> Any:
"""Sticks agents together by attribute.
Args:
agent (_agent.Agent): _description_
agent (_agent.Agent): An agent instance.
Returns:
Any: _description_
"""
return agent.id

# TODO: rename this method
def do_this_after_creation(self):
"""_summary_."""
pass

def nest(self):
return None

def melt(self) -> None | []:
def nest(self) -> Location | None:
"""_summary_.
Returns:
_type_: _description_
"""
return None

def melt(self) -> list | tuple:
"""_summary_.
Returns:
None | list | tuple: _description_
"""
return []


class LineLocation(Location):
"""A location that connects agents via a line network."""

def subsplit(self, agent):
"""_summary_.
Args:
agent (_type_): _description_
Returns:
_type_: _description_
"""
pos = self.group_agents.index(agent)
right = (pos + 1)
return [pos, right]


class RingLocation(Location):
"""A location that connects agents via a ring network."""

def subsplit(self, agent):
"""_summary_.
Args:
agent (_type_): _description_
Returns:
_type_: _description_
"""
pos = self.group_agents.index(agent)
right = (pos + 1) % len(self.group_agents)
return [pos, right]


class GridLocation(Location):
"""A location that connects agents via a grid network."""

def subsplit(self, agent):
row_len = math.ceil(math.sqrt(self.size))
"""_summary_.
Args:
agent (_type_): _description_
Returns:
_type_: _description_
"""
row_len = math.ceil(
math.sqrt(
self.size if self.size is not None else len(self.group_agents),
),
)
right_edge_positions = [row_len * i - 1 for i in range(row_len)]

pos = self.group_agents.index(agent)
right = pos + 1
left = pos - 1
Expand All @@ -250,15 +310,25 @@ def subsplit(self, agent):
return_list.append("-".join(sorted([str(pos), str(left)])))
return_list.append("-".join(sorted([str(pos), str(top)])))
return_list.append("-".join(sorted([str(pos), str(bottom)])))

if pos not in right_edge_positions:
return_list.append("-".join(sorted([str(pos), str(right)])))

return return_list


class TreeLocation(Location):
"""A location that connects agents via a tree network."""

def subsplit(self, agent):
"""_summary_.
Args:
agent (_type_): _description_
Returns:
_type_: _description_
"""
if isinstance(self, StarLocation):
self.n_branches = self.size - 1

Expand All @@ -267,7 +337,10 @@ def subsplit(self, agent):
agent._TEMP_contacts = [agent]

for a in self.group_agents:
if not hasattr(a, "_TEMP_infected") or hasattr(a, "_TEMP_infected") and not a._TEMP_infected:
if (
not hasattr(a, "_TEMP_infected")
or hasattr(a, "_TEMP_infected") and not a._TEMP_infected
):
if agent is not a:
agent._TEMP_contacts.append(a)
if hasattr(a, "_TEMP_contacts"):
Expand All @@ -277,11 +350,12 @@ def subsplit(self, agent):
a._TEMP_infected = True
if len(agent._TEMP_contacts) >= self.n_branches + 1:
break

location_ids = ["-".join(sorted([str(agent.id), str(a.id)])) for a in agent._TEMP_contacts]

return location_ids


class StarLocation(TreeLocation):
pass
"""A location that connects agents via a star network."""
pass
Loading

0 comments on commit 74164a2

Please sign in to comment.