From 74874ac8636e3aabc59c448e52980cd0ab3fb734 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 11 Aug 2022 13:11:23 -0800 Subject: [PATCH] Move predicate filtering out of datamodel Part of the quest to remove the implementation details of predicates out of DataModel and into the things that actually care about them. This slightly changes the behavior in the test because we don't do any filtering either way, so we use ALL predicates from the variable definitions --- dedupe/datamodel.py | 11 +++-------- dedupe/labeler.py | 20 ++++++++++++++++++-- tests/test_training.py | 2 +- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/dedupe/datamodel.py b/dedupe/datamodel.py index f815413af..bb4335658 100644 --- a/dedupe/datamodel.py +++ b/dedupe/datamodel.py @@ -69,17 +69,12 @@ def _field_comparators( yield (var.field, comparator, start, stop) start = stop - def predicates(self, canopies: bool = True) -> set[Predicate]: + @property + def predicates(self) -> set[Predicate]: predicates = set() for var in self.primary_variables: for predicate in var.predicates: - if hasattr(predicate, "index"): - is_canopy = hasattr(predicate, "canopy") - if is_canopy == canopies: - predicates.add(predicate) - else: - predicates.add(predicate) - + predicates.add(predicate) return predicates def distances( diff --git a/dedupe/labeler.py b/dedupe/labeler.py index 734dfd801..0dbaaf7c8 100644 --- a/dedupe/labeler.py +++ b/dedupe/labeler.py @@ -225,6 +225,20 @@ def _sample_indices(self, sample_size: int) -> Iterable[RecordIDPair]: return sample_ids +def _filter_canopy_predicates( + predicates: Iterable[Predicate], canopies: bool +) -> set[Predicate]: + result = set() + for predicate in predicates: + if hasattr(predicate, "index"): + is_canopy = hasattr(predicate, "canopy") + if is_canopy == canopies: + result.add(predicate) + else: + result.add(predicate) + return result + + class DedupeBlockLearner(BlockLearner): def __init__( self, @@ -239,7 +253,8 @@ def __init__( index_data = sample_records(data, 50000) sampled_records = sample_records(index_data, N_SAMPLED_RECORDS) - preds = self.data_model.predicates() + preds = self.data_model.predicates + preds = _filter_canopy_predicates(preds, canopies=True) self.block_learner = training.DedupeBlockLearner( preds, sampled_records, index_data @@ -293,7 +308,8 @@ def __init__( index_data = sample_records(data_2, 50000) sampled_records_2 = sample_records(index_data, N_SAMPLED_RECORDS) - preds = self.data_model.predicates(canopies=False) + preds = self.data_model.predicates + preds = _filter_canopy_predicates(preds, canopies=False) self.block_learner = training.RecordLinkBlockLearner( preds, sampled_records_1, sampled_records_2, index_data diff --git a/tests/test_training.py b/tests/test_training.py index 545e7d6e6..1485a107f 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -34,7 +34,7 @@ def setUp(self): self.block_learner = training.BlockLearner self.block_learner.blocker = dedupe.blocking.Fingerprinter( - self.data_model.predicates() + self.data_model.predicates ) self.block_learner.blocker.index_all( {i: x for i, x in enumerate(self.training_records)}