Skip to content

Commit

Permalink
monkey-patch class name and class repr
Browse files Browse the repository at this point in the history
  • Loading branch information
arashbm committed Oct 15, 2024
1 parent 1805b72 commit f968c26
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 20 deletions.
2 changes: 2 additions & 0 deletions src/common_edge_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ nb::class_<EdgeT> define_basic_edge_concept(nb::module_& m) {
return fmt::format("{}", a);
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<EdgeT>{}());
}).def_static("__class_name__", []() {
return type_str<EdgeT>{}();
}).def_static("vertex_type", []() {
return types::handle_for<typename EdgeT::VertexType>();
});
Expand Down
9 changes: 9 additions & 0 deletions src/components.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ struct declare_component_types {
return types::handle_for<typename Component::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Component>{}());
})
.def_static("__class_name__", []() {
return type_str<Component>{}();
});

nb::implicitly_convertible<
Expand All @@ -87,6 +90,9 @@ struct declare_component_types {
return types::handle_for<typename ComponentSize::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<ComponentSize>{}());
})
.def_static("__class_name__", []() {
return type_str<ComponentSize>{}();
});

using ComponentSizeEstimate =
Expand All @@ -103,6 +109,9 @@ struct declare_component_types {
return types::handle_for<typename ComponentSizeEstimate::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<ComponentSizeEstimate>{}());
})
.def_static("__class_name__", []() {
return type_str<ComponentSizeEstimate>{}();
});
}
};
Expand Down
29 changes: 29 additions & 0 deletions src/distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ struct declare_integral_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>",
type_str<std::geometric_distribution<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<std::geometric_distribution<ResultType>>{}();
});

nb::class_<reticula::delta_distribution<ResultType>>(m,
Expand All @@ -38,6 +41,9 @@ struct declare_integral_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>",
type_str<reticula::delta_distribution<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<reticula::delta_distribution<ResultType>>{}();
});

nb::class_<std::uniform_int_distribution<ResultType>>(m,
Expand All @@ -55,6 +61,9 @@ struct declare_integral_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>",
type_str<std::uniform_int_distribution<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<std::uniform_int_distribution<ResultType>>{}();
});
}
};
Expand All @@ -74,6 +83,9 @@ struct declare_floating_point_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>",
type_str<std::exponential_distribution<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<std::exponential_distribution<ResultType>>{}();
});

nb::class_<reticula::power_law_with_specified_mean<ResultType>>(m,
Expand All @@ -94,6 +106,10 @@ struct declare_floating_point_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<
reticula::power_law_with_specified_mean<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<
reticula::power_law_with_specified_mean<ResultType>>{}();
});

nb::class_<reticula::residual_power_law_with_specified_mean<ResultType>>(m,
Expand All @@ -117,6 +133,10 @@ struct declare_floating_point_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<
reticula::residual_power_law_with_specified_mean<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<
reticula::residual_power_law_with_specified_mean<ResultType>>{}();
});

nb::class_<reticula::hawkes_univariate_exponential<ResultType>>(m,
Expand All @@ -143,6 +163,9 @@ struct declare_floating_point_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<
reticula::hawkes_univariate_exponential<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<reticula::hawkes_univariate_exponential<ResultType>>{}();
});

nb::class_<reticula::delta_distribution<ResultType>>(m,
Expand All @@ -157,6 +180,9 @@ struct declare_floating_point_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>",
type_str<reticula::delta_distribution<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<reticula::delta_distribution<ResultType>>{}();
});

nb::class_<std::uniform_real_distribution<ResultType>>(m,
Expand All @@ -174,6 +200,9 @@ struct declare_floating_point_distributions {
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>",
type_str<std::uniform_real_distribution<ResultType>>{}());
})
.def_static("__class_name__", []() {
return type_str<std::uniform_real_distribution<ResultType>>{}();
});
}
};
Expand Down
2 changes: 2 additions & 0 deletions src/implicit_event_graphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ struct declare_implicit_event_graph_class {
return types::handle_for<typename Net::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Net>{}());
}).def_static("__class_name__", []() {
return type_str<Net>{}();
});
}
};
Expand Down
3 changes: 3 additions & 0 deletions src/interval_sets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ struct declare_interval_set_types {
return types::handle_for<typename IntSet::ValueType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<IntSet>{}());
})
.def_static("__class_name__", []() {
return type_str<IntSet>{}();
});
}
};
Expand Down
2 changes: 2 additions & 0 deletions src/networks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ struct declare_network_class {
return fmt::format("{}", a);
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Net>{}());
}).def_static("__class_name__", []() {
return type_str<Net>{}();
}).def_static("edge_type", []() {
return types::handle_for<typename Net::EdgeType>();
}).def_static("vertex_type", []() {
Expand Down
2 changes: 2 additions & 0 deletions src/random_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ void declare_random_states(nb::module_& m) {
"seed"_a, nb::call_guard<nb::gil_scoped_release>())
.def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<std::mt19937_64>{}());
}).def_static("__class_name__", []() {
return type_str<std::mt19937_64>{}();
});
}
30 changes: 16 additions & 14 deletions src/reticula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@
from . import temporal_adjacency
from . import microcanonical_reference_models

# import scalar type "tags"
from ._reticula_ext import (
int64, double, string)

# patch the nb_type_0 metaclass to show the "correct" name for classes
setattr(type(int64), "__repr__", lambda kls: kls.__class_repr__())

_simple_vertex_types = set(_reticula_ext.types.simple_vertex_types)
pair = _generic_attribute(
attr_prefix="pair",
arg_names=("vertex_type", "vertex_type"),
options={(i, j)
for i in _simple_vertex_types for j in _simple_vertex_types},
function_module=_reticula_ext,
api_module_name=__name__)

_static_edge_prefixes = [
"directed_edge", "undirected_edge",
"directed_hyperedge", "undirected_hyperedge"]
Expand All @@ -29,10 +45,6 @@
function_module=_reticula_ext,
api_module_name=__name__))

# import scalar type "tags"
from ._reticula_ext import (
int64, double, string)

from ._reticula_ext import (
mersenne_twister)

Expand Down Expand Up @@ -166,16 +178,6 @@
api_module_name=__name__)


_simple_vertex_types = set(_reticula_ext.types.simple_vertex_types)
pair = _generic_attribute(
attr_prefix="pair",
arg_names=("vertex_type", "vertex_type"),
options={(i, j)
for i in _simple_vertex_types for j in _simple_vertex_types},
function_module=_reticula_ext,
api_module_name=__name__)


_random_network_generic_attrs = [
"random_gnp_graph",
"random_directed_gnp_graph",
Expand Down
19 changes: 13 additions & 6 deletions src/reticula/generic_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,28 @@ def __getitem__(self, keys):
else:
raise AttributeError(
"Provided template type is not a valid "
"option. Valid options are:\n" + "\n".join(
["\t[" + ", ".join([t.__name__ for t in type_list]) + "]"
for type_list in self.options]))
"option. Valid options are:\n\n" +
self.options_list())
return self.function_module.__getattribute__(attr_name)

def options_list(self):
opts = []
for type_list in self.options:
opts.append(
"[" + ", ".join([t.__class_name__() for t in type_list]) + "]")
opts = sorted(opts, key=lambda s: (s.count("["), s))

return "\n".join(opts)

def __call__(self, *args, **kwargs):
raise TypeError(
"No type information was paased to a generic function or type.\n"
"This usually means that you forgot to add square brackets\n"
"and type information before parentheses, e.g.:\n\n"
f" {self.api_module_name}.{self.attr_prefix}"
f"[{', '.join(self.arg_names)}]"
"\n\nValid options are:\n\n" + "\n".join(
[" [" + ", ".join([t.__name__ for t in type_list]) + "]"
for type_list in self.options]))
"\n\nValid options are:\n\n" +
self.options_list())
def __repr__(self) -> str:
return f"{self.api_module_name}.{self.attr_prefix}"\
f"[{", ".join(self.arg_names)}]"
3 changes: 3 additions & 0 deletions src/scalar_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ struct declare_scalar_types {
python_type_str<T>().c_str())
.def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<T>{}());
})
.def_static("__class_name__", []() {
return type_str<T>{}();
});
}
};
Expand Down
8 changes: 8 additions & 0 deletions src/temporal_adjacency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct declare_temporal_adjacency_class {
return types::handle_for<typename Simple::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Simple>{}());
}).def_static("__class_name__", []() {
return type_str<Simple>{}();
});

using LWT = reticula::temporal_adjacency::limited_waiting_time<EdgeT>;
Expand All @@ -52,6 +54,8 @@ struct declare_temporal_adjacency_class {
return types::handle_for<typename LWT::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<LWT>{}());
}).def_static("__class_name__", []() {
return type_str<LWT>{}();
});

if constexpr (std::is_floating_point_v<typename EdgeT::TimeType>) {
Expand All @@ -76,6 +80,8 @@ struct declare_temporal_adjacency_class {
return types::handle_for<typename Exp::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Exp>{}());
}).def_static("__class_name__", []() {
return type_str<Exp>{}();
});
}

Expand All @@ -101,6 +107,8 @@ struct declare_temporal_adjacency_class {
return types::handle_for<typename Geom::VertexType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Geom>{}());
}).def_static("__class_name__", []() {
return type_str<Geom>{}();
});
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/temporal_clusters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ struct declare_temporal_cluster_types {
return types::handle_for<typename Cluster::AdjacencyType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<Cluster>{}());
}).def_static("__class_name__", []() {
return type_str<Cluster>{}();
});

using ClusterSize = reticula::temporal_cluster_size<EdgeT, AdjT>;
Expand All @@ -104,6 +106,8 @@ struct declare_temporal_cluster_types {
return types::handle_for<typename ClusterSize::AdjacencyType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<ClusterSize>{}());
}).def_static("__class_name__", []() {
return type_str<ClusterSize>{}();
});


Expand All @@ -130,6 +134,8 @@ struct declare_temporal_cluster_types {
return types::handle_for<typename ClusterSizeEstimate::AdjacencyType>();
}).def_static("__class_repr__", []() {
return fmt::format("<class '{}'>", type_str<ClusterSizeEstimate>{}());
}).def_static("__class_name__", []() {
return type_str<ClusterSizeEstimate>{}();
});
}
};
Expand Down

0 comments on commit f968c26

Please sign in to comment.