Skip to content

Commit

Permalink
typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Jan 9, 2025
1 parent 51b2102 commit 83e6f46
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 22 deletions.
38 changes: 24 additions & 14 deletions spynnaker/pyNN/models/common/local_only_2d_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from math import ceil, log2, floor
from typing import Final, Tuple, Union
from collections import namedtuple, defaultdict
from typing import Dict, Final, List, NamedTuple, Tuple, Union
from collections import defaultdict
from pacman.model.graphs.application import (
ApplicationVertex, ApplicationVirtualVertex)
from pacman.model.graphs.common.slice import Slice
from pacman.model.graphs.common.mdslice import MDSlice
from pacman.model.routing_info import AppVertexRoutingInfo
from spinn_front_end_common.utilities.constants import BYTES_PER_WORD
from spynnaker.pyNN.models.projection import Projection
from spynnaker.pyNN.data.spynnaker_data_view import SpynnakerDataView
from spynnaker.pyNN.models.neuron import AbstractPopulationVertex
from spynnaker.pyNN.utilities.utility_calls import get_n_bits
Expand All @@ -36,11 +38,11 @@
KEY_INFO_SIZE: Final[int] = 4 * BYTES_PER_WORD

#: A source
Source = namedtuple(
"Source", ["projection", "local_delay", "delay_stage"])
Source = NamedTuple("Source",
[("projection", Projection), ("local_delay", int),
("delay_stage", int)])


def get_div_const(value):
def get_div_const(value: int) -> int:
""" Get the values used to perform fast division by an integer constant
:param int value: The value to be divided by
Expand All @@ -56,7 +58,8 @@ def get_div_const(value):
+ (sh1 << BITS_PER_SHORT) + m)


def get_delay_for_source(incoming):
def get_delay_for_source(
incoming: Projection) -> Tuple[ApplicationVertex, int, int ,str]:
""" Get the vertex which will send data from a given source projection,
along with the delay stage and locally-handled delay value
Expand All @@ -67,19 +70,24 @@ def get_delay_for_source(incoming):
# pylint: disable=protected-access
app_edge = incoming._projection_edge
s_info = incoming._synapse_information
delay = s_info.synapse_dynamics.delay
# TODO do we need to support None float delay?
delay = float(s_info.synapse_dynamics.delay)
steps = delay * SpynnakerDataView.get_simulation_time_step_per_ms()
max_delay = app_edge.post_vertex.splitter.max_support_delay()
local_delay = steps % max_delay
local_delay = int(steps % max_delay)
delay_stage = 0
pre_vertex = app_edge.pre_vertex
pre_vertex: ApplicationVertex = app_edge.pre_vertex
if steps > max_delay:
delay_stage = (steps // max_delay) - 1
pre_vertex = app_edge.delay_edge.pre_vertex
delay_stage = int(steps // max_delay) - 1
edge = app_edge.delay_edge
assert edge is not None
pre_vertex = edge.pre_vertex
return pre_vertex, local_delay, delay_stage, s_info.partition_id


def get_rinfo_for_spike_source(pre_vertex, partition_id):
def get_rinfo_for_spike_source(
pre_vertex: ApplicationVertex,
partition_id: str) -> Tuple[AppVertexRoutingInfo, int, int]:
"""
Get the routing information for the source of a projection in the
given partition.
Expand All @@ -95,6 +103,7 @@ def get_rinfo_for_spike_source(pre_vertex, partition_id):
r_info = routing_info.get_info_from(
pre_vertex, partition_id)

assert isinstance(r_info, AppVertexRoutingInfo)
n_cores = len(r_info.vertex.splitter.get_out_going_vertices(partition_id))

# If there is 1 core, we don't use the core mask
Expand All @@ -107,7 +116,8 @@ def get_rinfo_for_spike_source(pre_vertex, partition_id):
return r_info, core_mask, mask_shift


def get_sources_for_target(app_vertex: AbstractPopulationVertex):
def get_sources_for_target(app_vertex: AbstractPopulationVertex) -> Dict[
Tuple[ApplicationVertex, str], List[Source]]:
"""
Get all the application vertex sources that will hit the given application
vertex.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, delay: Weight_Delay_In_Types = None):

# Store the sources to avoid recalculation
self.__cached_sources: Dict[ApplicationVertex, Dict[
Tuple[PopulationApplicationVertex, str],
Tuple[ApplicationVertex, str],
List[Source]]] = dict()

@property
Expand Down Expand Up @@ -158,6 +158,7 @@ def write_parameters(
weight_data = list()
for (pre_vertex, part_id), source_infos in sources.items():

assert isinstance(pre_vertex, PopulationApplicationVertex)
# Add connectors as needed
first_conn_index = len(connector_data)
for source in source_infos:
Expand Down Expand Up @@ -250,7 +251,7 @@ def write_parameters(

def __get_sources_for_target(
self, app_vertex: AbstractPopulationVertex) -> Dict[
Tuple[PopulationApplicationVertex, str], List[Source]]:
Tuple[ApplicationVertex, str], List[Source]]:
"""
Get all the application vertex sources that will hit the given
application vertex.
Expand All @@ -261,11 +262,13 @@ def __get_sources_for_target(
information
:rtype: dict(tuple(PopulationApplicationVertex, str), list(Source))
"""
sources = self.__cached_sources.get(app_vertex)
if sources is None:
_sources = self.__cached_sources.get(app_vertex)
if _sources is None:
sources = get_sources_for_target(app_vertex)
self.__cached_sources[app_vertex] = sources
return sources
return sources
else:
return _sources

@staticmethod
def __connector(projection: Projection) -> ConvolutionConnector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def A_plus(self) -> float:
return self.__a_plus

@A_plus.setter
def A_plus(self, new_value) -> None:
def A_plus(self, new_value: float) -> None:
self.__a_plus = new_value

@property
def A_minus(self):
def A_minus(self) -> float:
r"""
:math:`A^-`
Expand All @@ -106,7 +106,7 @@ def A_minus(self):
return self.__a_minus

@A_minus.setter
def A_minus(self, new_value) -> None:
def A_minus(self, new_value: float) -> None:
self.__a_minus = new_value

@overrides(AbstractTimingDependence.is_same_as)
Expand Down

0 comments on commit 83e6f46

Please sign in to comment.