diff --git a/ktrain/text/eda.py b/ktrain/text/eda.py index e4030f88..40e227cf 100755 --- a/ktrain/text/eda.py +++ b/ktrain/text/eda.py @@ -369,9 +369,10 @@ def filter(self, obj): "Length of obj is not consistent with the number of documents " + "supplied to get_topic_model" ) - #obj = np.array(obj) if isinstance(obj, list) else obj - #return obj[self.bool_array] + # obj = np.array(obj) if isinstance(obj, list) else obj + # return obj[self.bool_array] from itertools import compress + return list(compress(obj, self.bool_array)) def get_docs(self, topic_ids=[], doc_ids=[], rank=False):