From 8e8144e8dca2b769b69b2e43cde4b293363e55d9 Mon Sep 17 00:00:00 2001
From: Mengyao Xu
Date: Tue, 19 Jul 2022 08:10:08 -0700
Subject: [PATCH] Hashed cross (#587)

* Add hashed cross.

* Only hashed cross.

Co-authored-by: mengyao
Co-authored-by: Marc Romeyn
Co-authored-by: Gabriel Moreira
---
 merlin/models/tf/__init__.py                |   2 +
 merlin/models/tf/core/transformations.py    | 151 ++++++++++++
 tests/unit/tf/core/test_transformations.py  | 263 +++++++++++++++++++++
 3 files changed, 416 insertions(+)

diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py
index c71322a2a4..ce90ebf6d2 100644
--- a/merlin/models/tf/__init__.py
+++ b/merlin/models/tf/__init__.py
@@ -26,6 +26,7 @@
     AsSparseFeatures,
     CategoricalOneHot,
     ExpandDims,
+    HashedCross,
     LabelToOneHot,
 )
@@ -151,6 +152,7 @@
     "AsRaggedFeatures",
     "AsSparseFeatures",
     "CategoricalOneHot",
+    "HashedCross",
     "ElementwiseSum",
     "ElementwiseSumItemMulti",
     "AsTabular",
diff --git a/merlin/models/tf/core/transformations.py b/merlin/models/tf/core/transformations.py
index 327992d3f2..37834aaa90 100644
--- a/merlin/models/tf/core/transformations.py
+++ b/merlin/models/tf/core/transformations.py
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import warnings
 from typing import Dict, Optional, Sequence, Union
 
 import tensorflow as tf
+from keras.layers.preprocessing import preprocessing_utils
 
 from merlin.models.config.schema import requires_schema
 from merlin.models.tf.core.base import Block, PredictionOutput
@@ -639,3 +641,152 @@ def _check_items_cardinality(self, item_freq_probs):
                 f"(expected {cardinalities[item_id_feature_name]}"
                 f", got {tf.shape(item_freq_probs)[0]})"
             )
+
+
+@Block.registry.register("hashed_cross")
+@tf.keras.utils.register_keras_serializable(package="merlin.models")
+class HashedCross(TabularBlock):
+    """A transformation block which crosses categorical features using the "hashing trick".
+    Conceptually, the transformation can be thought of as: hash(concatenation of features) %
+    num_bins
+    Example usage::
+        model_body = ParallelBlock(
+            TabularBlock.from_schema(schema=cross_schema, pre=ml.HashedCross(cross_schema,
+            num_bins = 1000)),
+            is_input=True).connect(ml.MLPBlock([64, 32]))
+        model = ml.Model(model_body, ml.BinaryClassificationTask("click"))
+    Parameters
+    ----------
+    schema : Schema
+        The `Schema` with the input features
+    num_bins : int
+        Number of hash bins.
+    output_mode: string
+        Specification for the output of the layer. Defaults to
+        `"int"`. Values can be `"int"` or `"one_hot"`, configuring the layer as
+        follows:
+        - `"int"`: Return the integer bin indices directly.
+        - `"one_hot"`: Encodes each individual element in the input into an
+            array the same size as `num_bins`, containing a 1 at the input's bin
+            index.
+    sparse : bool
+        Boolean. Only applicable to `"one_hot"` mode. If True, returns a
+        `SparseTensor` instead of a dense `Tensor`. Defaults to False.
+    output_name : string
+        Name of the output feature. If not specified, defaults to
+        cross_<feature_name>_<feature_name>_<...>
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        num_bins: int,
+        sparse: bool = False,
+        output_mode: str = "int",
+        output_name: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if output_mode not in ["int", "one_hot"]:
+            raise ValueError("output_mode must be 'int' or 'one_hot'")
+        self.schema = schema
+        self.num_bins = num_bins
+        self.output_mode = output_mode
+        self.sparse = sparse
+        if not output_name:
+            self.output_name = "cross"
+            for name in self.schema.column_names:
+                self.output_name = self.output_name + "_" + name
+        else:
+            self.output_name = output_name
+
+    def call(self, inputs):
+        self._check_at_least_two_inputs()
+        _inputs = {}
+        for name in self.schema.column_names:
+            _inputs[name] = inputs[name]
+            rank = _inputs[name].shape.rank
+            if rank < 2:
+                _inputs[name] = tf.expand_dims(_inputs[name], -1)
+            if rank < 1:
+                _inputs[name] = tf.expand_dims(_inputs[name], -1)
+
+        # Perform the cross and convert to dense
+        output = tf.sparse.cross_hashed(list(_inputs.values()), self.num_bins)
+        output = tf.sparse.to_dense(output)
+
+        # Fix output shape and downrank to match input rank.
+        if rank == 2:
+            # tf.sparse.cross_hashed output shape will always be None on the last
+            # dimension. Given our input shape restrictions, we want to force shape 1
+            # instead.
+            output = tf.reshape(output, [-1, 1])
+        elif rank == 1:
+            output = tf.reshape(output, [-1])
+        elif rank == 0:
+            output = tf.reshape(output, [])
+
+        # Encode outputs.
+        outputs = {}
+        outputs[self.output_name] = preprocessing_utils.encode_categorical_inputs(
+            output,
+            output_mode=self.output_mode,
+            depth=self.num_bins,
+            sparse=self.sparse,
+        )
+        return outputs
+
+    def compute_output_shape(self, input_shapes):
+        self._check_at_least_two_inputs()
+        self._check_input_shape_and_type(input_shapes)
+        output_shape = {}
+        one_input = list(input_shapes.values())[0]
+        output_shape[self.output_name] = preprocessing_utils.compute_shape_for_encode_categorical(
+            shape=one_input, output_mode=self.output_mode, depth=self.num_bins
+        )
+        return output_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_bins": self.num_bins,
+                "output_mode": self.output_mode,
+                "sparse": self.sparse,
+                "output_name": self.output_name,
+            }
+        )
+        if self.schema:
+            config["schema"] = schema_utils.schema_to_tensorflow_metadata_json(self.schema)
+        return config
+
+    def _check_at_least_two_inputs(self):
+        if len(self.schema) < 2:
+            raise ValueError(
+                "`HashedCrossing` should be called on at least two features. "
+                f"Received: {len(self.schema)} features"
+            )
+        for name, column_schema in self.schema.column_schemas.items():
+            if Tags.CATEGORICAL not in column_schema.tags:
+                warnings.warn(
+                    "All `HashedCrossing` inputs should be categorical, but feature "
+                    f"{name} has no categorical tag"
+                )
+
+    def _check_input_shape_and_type(self, inputs_shapes):
+        _inputs_shapes = []
+        for name in self.schema.column_names:
+            _inputs_shapes.append(inputs_shapes[name])
+        first_shape = _inputs_shapes[0].as_list()
+        rank = len(first_shape)
+        if rank > 2 or (rank == 2 and first_shape[-1] != 1):
+            raise ValueError(
+                "All `HashedCrossing` inputs should have shape `[]`, `[batch_size]` "
+                f"or `[batch_size, 1]`. Received: input {name} with shape={first_shape}"
+            )
+        if not all(x.as_list() == first_shape for x in _inputs_shapes):
+            raise ValueError(
+                "All `HashedCrossing` inputs should have equal shape. 
" + f"Received: inputs={_inputs_shapes}" + ) diff --git a/tests/unit/tf/core/test_transformations.py b/tests/unit/tf/core/test_transformations.py index 5e53d31e28..3ca91c55cd 100644 --- a/tests/unit/tf/core/test_transformations.py +++ b/tests/unit/tf/core/test_transformations.py @@ -16,9 +16,14 @@ import tempfile +import pytest import tensorflow as tf +from tensorflow.test import TestCase import merlin.models.tf as ml +from merlin.io import Dataset +from merlin.models.tf.core.combinators import ParallelBlock, TabularBlock +from merlin.models.tf.utils import testing_utils from merlin.models.utils.schema_utils import create_categorical_column, create_continuous_column from merlin.schema import Schema, Tags @@ -191,3 +196,261 @@ def test_items_weight_tying_with_different_domain_name(): _ = model(inputs) weight_tying_embeddings = model.blocks[2].context.get_embedding("joint_item_id") assert weight_tying_embeddings.shape == (101, 64) + + +def test_hashedcross_scalars(): + test_case = TestCase() + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=3), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=3), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant("A") + inputs["cat2"] = tf.constant(101) + hashed_cross_op = ml.HashedCross(schema=schema, num_bins=10) + outputs = hashed_cross_op(inputs) + output_name, output_value = outputs.popitem() + + assert output_name == "cross_cat1_cat2" + assert output_value.shape.as_list() == [] + test_case.assertAllClose(output_value, 1) + + +def test_hashedcross_1d(): + test_case = TestCase() + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant(["A", "B", "A", "B", "A"]) + inputs["cat2"] = tf.constant([101, 101, 101, 102, 102]) + hashed_cross_op = ml.HashedCross(schema=schema, num_bins=10) + outputs = hashed_cross_op(inputs) + _, output_value = outputs.popitem() + + assert output_value.shape.as_list() == [5] + test_case.assertAllClose(output_value, [1, 4, 1, 6, 3]) + + +def test_hashedcross_2d(): + test_case = TestCase() + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]) + inputs["cat2"] = tf.constant([[101], [101], [101], [102], [102]]) + hashed_cross_op = ml.HashedCross(schema=schema, num_bins=10) + outputs = hashed_cross_op(inputs) + _, output_value = outputs.popitem() + + assert output_value.shape.as_list() == [5, 1] + test_case.assertAllClose(output_value, [[1], [4], [1], [6], [3]]) + + +def test_hashedcross_output_shape(): + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs_shape = {} + inputs_shape["cat1"] = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]).shape + inputs_shape["cat2"] = tf.constant([[101], [101], [101], [102], [102]]).shape + hashed_cross = ml.HashedCross(schema=schema, num_bins=10) + outputs = hashed_cross.compute_output_shape(inputs_shape) + _, output_shape = outputs.popitem() + + assert output_shape == [5, 1] + + +def test_hashedcross_output_shape_one_hot(): + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), 
+ create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs_shape = {} + inputs_shape["cat1"] = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]).shape + inputs_shape["cat2"] = tf.constant([[101], [101], [101], [102], [102]]).shape + output_name = "cross_out" + hashed_cross = ml.HashedCross( + schema=schema, num_bins=10, output_mode="one_hot", output_name=output_name + ) + outputs = hashed_cross.compute_output_shape(inputs_shape) + _output_name, output_shape = outputs.popitem() + + assert output_shape == [5, 10] + assert _output_name == output_name + + +def test_hashedcross_less_bins(): + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant([["A"], ["B"], ["C"], ["D"], ["A"], ["B"], ["A"]]) + inputs["cat2"] = tf.constant([[101], [102], [101], [101], [101], [102], [103]]) + hashed_cross_op = ml.HashedCross(schema=schema, num_bins=4, output_mode="one_hot", sparse=True) + outputs = hashed_cross_op(inputs) + _, output_value = outputs.popitem() + output_value = tf.sparse.to_dense(output_value) + + assert output_value.shape.as_list() == [7, 4] + + +def test_hashedcross_onehot_output(): + test_case = TestCase() + + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]) + inputs["cat2"] = tf.constant([[101], [101], [101], [102], [102]]) + hashed_cross_op = ml.HashedCross(schema=schema, num_bins=5, output_mode="one_hot", sparse=True) + outputs = hashed_cross_op(inputs) + _, output_value = outputs.popitem() + output_value = tf.sparse.to_dense(output_value) + + assert output_value.shape.as_list() == [5, 5] + test_case.assertAllClose( + output_value, + [ + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + ], + ) + + +def test_hashed_cross_single_input_fails(): + test_case = TestCase() + schema = Schema([create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20)]) + with test_case.assertRaisesRegex(ValueError, "at least two features"): + ml.HashedCross(num_bins=10, schema=schema)([tf.constant(1)]) + + +def test_hashedcross_from_config(): + test_case = TestCase() + schema = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]) + inputs["cat2"] = tf.constant([[101], [101], [101], [102], [102]]) + hashed_cross_op = ml.HashedCross(schema=schema, num_bins=5, output_mode="one_hot", sparse=False) + cloned_hashed_cross_op = ml.HashedCross.from_config(hashed_cross_op.get_config()) + original_outputs = hashed_cross_op(inputs) + cloned_outputs = cloned_hashed_cross_op(inputs) + _, original_output_value = original_outputs.popitem() + _, cloned_output_value = cloned_outputs.popitem() + + test_case.assertAllEqual(cloned_output_value, original_output_value) + + +def test_hashedcrosses(): + test_case = TestCase() + + schema_0 = Schema( + [ + create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20), + create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20), + ] + ) + schema_1 = Schema( + [ + create_categorical_column("cat1", 
tags=[Tags.CATEGORICAL], num_items=2), + create_categorical_column("cat3", tags=[Tags.CATEGORICAL], num_items=3), + ] + ) + inputs = {} + inputs["cat1"] = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]) + inputs["cat2"] = tf.constant([[101], [101], [101], [102], [102]]) + inputs["cat3"] = tf.constant([[1], [0], [1], [2], [2]]) + hashed_cross_0 = ml.HashedCross( + schema=schema_0, num_bins=5, output_mode="one_hot", sparse=True, output_name="cross_0" + ) + hashed_cross_1 = ml.HashedCross( + schema=schema_1, num_bins=10, output_mode="one_hot", sparse=True, output_name="cross_1" + ) + hashed_crosses = ParallelBlock([hashed_cross_0, hashed_cross_1]) + outputs = hashed_crosses(inputs) + output_value_0 = outputs["cross_0"] + output_value_0 = tf.sparse.to_dense(output_value_0) + + assert output_value_0.shape.as_list() == [5, 5] + test_case.assertAllClose( + output_value_0, + [ + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + ], + ) + output_value_1 = outputs["cross_1"] + output_value_1 = tf.sparse.to_dense(output_value_1) + + assert output_value_1.shape.as_list() == [5, 10] + test_case.assertAllClose( + output_value_1, + [ + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + ], + ) + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_hashedcross_as_pre(ecommerce_data: Dataset, run_eagerly): + cross_schema = ecommerce_data.schema.select_by_name(names=["user_categories", "item_category"]) + body = ParallelBlock( + TabularBlock.from_schema( + schema=cross_schema, pre=ml.HashedCross(cross_schema, num_bins=1000) + ), + is_input=True, + ).connect(ml.MLPBlock([64])) + model = ml.Model(body, ml.BinaryClassificationTask("click")) + + model.compile(optimizer="adam", run_eagerly=run_eagerly) + testing_utils.model_test(model, ecommerce_data) + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_hashedcross_in_model(ecommerce_data: Dataset, run_eagerly): + cross_schema = ecommerce_data.schema.select_by_name(names=["user_categories", "item_category"]) + branches = { + "cross_product": ml.HashedCross(cross_schema, num_bins=1000, is_input=True), + "features": ml.InputBlock(ecommerce_data.schema), + } + body = ParallelBlock(branches, is_input=True).connect(ml.MLPBlock([64])) + model = ml.Model(body, ml.BinaryClassificationTask("click")) + + model.compile(optimizer="adam", run_eagerly=run_eagerly) + testing_utils.model_test(model, ecommerce_data)
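
For quick reference, below is a minimal standalone sketch of the block added in this patch. It mirrors test_hashedcross_1d above; the feature names, cardinalities, and input values are illustrative, not part of the patch itself.

import tensorflow as tf

import merlin.models.tf as ml
from merlin.models.utils.schema_utils import create_categorical_column
from merlin.schema import Schema, Tags

# Two categorical features to cross.
schema = Schema(
    [
        create_categorical_column("cat1", tags=[Tags.CATEGORICAL], num_items=20),
        create_categorical_column("cat2", tags=[Tags.CATEGORICAL], num_items=20),
    ]
)

inputs = {
    "cat1": tf.constant(["A", "B", "A", "B", "A"]),
    "cat2": tf.constant([101, 101, 101, 102, 102]),
}

# Conceptually hash(concatenation of cat1 and cat2) % num_bins, returned under
# the default output name "cross_cat1_cat2".
outputs = ml.HashedCross(schema=schema, num_bins=10)(inputs)
print(outputs["cross_cat1_cat2"])  # integer bin indices with shape [5]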