From e3df1792343da2d11f72cde366257dfae0cf7011 Mon Sep 17 00:00:00 2001
From: Vladimir Brkic
Date: Mon, 16 Sep 2024 09:57:56 +0000
Subject: [PATCH] Enable RGG

Issue #313

---
 forge/forge/op_repo/__init__.py               |   4 +
 forge/forge/op_repo/pybuda_operators.py       |  81 ++++
 forge/forge/op_repo/pytorch_operators.py      |  73 +++
 forge/test/random/conftest.py                 |  14 +
 forge/test/random/rgg/__init__.py             |  47 ++
 forge/test/random/rgg/algorithms.py           | 425 ++++++++++++++++++
 forge/test/random/rgg/base.py                 | 328 ++++++++++++++
 forge/test/random/rgg/config.py               |  41 ++
 forge/test/random/rgg/datatypes.py            | 201 +++++++++
 forge/test/random/rgg/frameworks.py           | 106 +++++
 forge/test/random/rgg/pybuda/__init__.py      |   3 +
 .../random/rgg/pybuda/generated_model.jinja2  |  52 +++
 forge/test/random/rgg/pybuda/model.py         |  20 +
 forge/test/random/rgg/pytorch/__init__.py     |   3 +
 .../random/rgg/pytorch/generated_model.jinja2 |  50 +++
 forge/test/random/rgg/pytorch/model.py        |  23 +
 forge/test/random/rgg/shapes.py               | 283 ++++++++++++
 forge/test/random/rgg/utils.py                | 300 +++++++++++++
 forge/test/random/test_graphs.py              | 230 ++++++++++
 19 files changed, 2284 insertions(+)
 create mode 100644 forge/forge/op_repo/pybuda_operators.py
 create mode 100644 forge/forge/op_repo/pytorch_operators.py
 create mode 100644 forge/test/random/rgg/__init__.py
 create mode 100644 forge/test/random/rgg/algorithms.py
 create mode 100644 forge/test/random/rgg/base.py
 create mode 100644 forge/test/random/rgg/config.py
 create mode 100644 forge/test/random/rgg/datatypes.py
 create mode 100644 forge/test/random/rgg/frameworks.py
 create mode 100644 forge/test/random/rgg/pybuda/__init__.py
 create mode 100644 forge/test/random/rgg/pybuda/generated_model.jinja2
 create mode 100644 forge/test/random/rgg/pybuda/model.py
 create mode 100644 forge/test/random/rgg/pytorch/__init__.py
 create mode 100644 forge/test/random/rgg/pytorch/generated_model.jinja2
 create mode 100644 forge/test/random/rgg/pytorch/model.py
 create mode 100644 forge/test/random/rgg/shapes.py
 create mode 100644 forge/test/random/rgg/utils.py
 create mode 100644 forge/test/random/test_graphs.py

diff --git a/forge/forge/op_repo/__init__.py b/forge/forge/op_repo/__init__.py
index 1f1596708..9f5096930 100644
--- a/forge/forge/op_repo/__init__.py
+++ b/forge/forge/op_repo/__init__.py
@@ -15,6 +15,8 @@
 from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange
 from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository
 from .datatypes import ShapeCalculationContext
+from .pybuda_operators import pybuda_operator_repository
+from .pytorch_operators import pytorch_operator_repository
 
 __ALL__ = [
     "OperandNumInt",
@@ -26,4 +28,6 @@
     "OperatorDefinition",
     "OperatorRepository",
     "ShapeCalculationContext",
+    "pybuda_operator_repository",
+    "pytorch_operator_repository",
 ]
diff --git a/forge/forge/op_repo/pybuda_operators.py b/forge/forge/op_repo/pybuda_operators.py
new file mode 100644
index 000000000..29d3e0bf3
--- /dev/null
+++ b/forge/forge/op_repo/pybuda_operators.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# PyBuda repository operators
+
+
+from .datatypes import OperatorDefinition, OperatorRepository
+from .datatypes import OperatorParamNumber
+
+
+# TODO describe operand and shapes
+_OPERATORS = [
+
+    # Unary operators
+    OperatorDefinition("exp", "forge.op.Exp", 1),
+    OperatorDefinition("reciprocal", "forge.op.Reciprocal", 1),
+    OperatorDefinition("buffer", "forge.op.Buffer", 1),
+    OperatorDefinition("sqrt", "forge.op.Sqrt", 1),
+    OperatorDefinition("relu", "forge.op.Relu", 1),
+    OperatorDefinition("leaky_relu", "forge.op.LeakyRelu", 1, forward_params=[
+        OperatorParamNumber("alpha", float, 0, 100),
+    ]),
+    OperatorDefinition("nop", "forge.op.Identity", 1),
+    OperatorDefinition("gelu", "forge.op.Gelu", 1),
+    OperatorDefinition("log", "forge.op.Log", 1),
+    OperatorDefinition("sigmoid", "forge.op.Sigmoid", 1),
+    OperatorDefinition("clip", "forge.op.Clip", 1, forward_params=[
+        OperatorParamNumber("min", float, 0, 100),
+        OperatorParamNumber("max", float, 0, 100),
+    ]),
+    OperatorDefinition("sine", "forge.op.Sine", 1),
+    OperatorDefinition("cosine", "forge.op.Cosine", 1),
+    OperatorDefinition("abs", "forge.op.Abs", 1),
+    OperatorDefinition("tanh", "forge.op.Tanh", 1),
+    OperatorDefinition("cumsum", "forge.op.CumSum", 1),
+    OperatorDefinition("argmax", "forge.op.Argmax", 1),
+    OperatorDefinition("logical_not", "forge.op.LogicalNot", 1),
+    OperatorDefinition("dropout", "forge.op.Dropout", 1),
+    OperatorDefinition("pow", "forge.op.Pow", 1, forward_params=[
+        OperatorParamNumber("exponent", float, 0, 100),
+    ]),
+    OperatorDefinition("tilizer", "forge.op.Tilize", 1),
+
+    # Binary operators
+    OperatorDefinition("add", "forge.op.Add", 2),
+    OperatorDefinition("divide", "forge.op.Divide", 2),
+    OperatorDefinition("subtract", "forge.op.Subtract", 2),
+    OperatorDefinition("multiply", "forge.op.Multiply", 2),
+    OperatorDefinition("maximum", "forge.op.Max", 2),
+    OperatorDefinition("minimum", "forge.op.Min", 2),
+    OperatorDefinition("heaviside", "forge.op.Heaviside", 2),
+    OperatorDefinition("binary_stack", "forge.op.BinaryStack", 2),
+    OperatorDefinition("power", "forge.op.Power", 2),
+    OperatorDefinition("greater", "forge.op.Greater", 2),
+    OperatorDefinition("greater_equal", "forge.op.GreaterEqual", 2),
+    OperatorDefinition("less", "forge.op.Less", 2),
+    OperatorDefinition("less_equal", "forge.op.LessEqual", 2),
+    OperatorDefinition("equal", "forge.op.Equal", 2),
+    OperatorDefinition("not_equal", "forge.op.NotEqual", 2),
+    OperatorDefinition("logical_and", "forge.op.LogicalAnd", 2),
+
+    # Nary operators
+    OperatorDefinition("where", "forge.op.Where", 3),
+    # OperatorDefinition("index_copy", "forge.op.IndexCopy", 3),  # Bug #2705
+    OperatorDefinition("interleave", "forge.op.Interleave", (1, 10), forward_params=[
+        OperatorParamNumber("axis", int, -3, -3),
+        OperatorParamNumber("stride", int, 1, 1),
+    ]),
+    OperatorDefinition("concatenate", "forge.op.Concatenate", (1, 10), forward_params=[
+        OperatorParamNumber("axis", int, -10, 10),
+    ]),
+    OperatorDefinition("stack", "forge.op.Stack", (2, 4), forward_params=[
+        OperatorParamNumber("axis", int, 1, 10),
+    ]),
+
+    OperatorDefinition("matmul", "forge.op.Matmul", 2),
+    # OperatorDefinition("sparse_matmul", "forge.op.SparseMatmul", 2),
+]
+
+
+pybuda_operator_repository = OperatorRepository([op for op in _OPERATORS])
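For orientation, a consumer of this repository needs only the `operators` list and the `name`/`full_name` fields used throughout this patch. A minimal sketch of picking a random operator definition, mirroring `_get_random_operator` in algorithms.py further below:

    import random

    from forge.op_repo import pybuda_operator_repository

    rng = random.Random(0)
    op = rng.choice(pybuda_operator_repository.operators)
    print(op.name, "->", op.full_name)  # e.g. "relu -> forge.op.Relu"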
"forge.op.Buffer", 1), + OperatorDefinition("sqrt", "forge.op.Sqrt", 1), + OperatorDefinition("relu", "forge.op.Relu", 1), + OperatorDefinition("leaky_relu", "forge.op.LeakyRelu", 1, forward_params=[ + OperatorParamNumber("alpha", float, 0, 100), + ]), + OperatorDefinition("nop", "forge.op.Identity", 1), + OperatorDefinition("gelu", "forge.op.Gelu", 1), + OperatorDefinition("log", "forge.op.Log", 1), + OperatorDefinition("sigmoid", "forge.op.Sigmoid", 1), + OperatorDefinition("clip", "forge.op.Clip", 1, forward_params=[ + OperatorParamNumber("min", float, 0, 100), + OperatorParamNumber("max", float, 0, 100), + ]), + OperatorDefinition("sine", "forge.op.Sine", 1), + OperatorDefinition("cosine", "forge.op.Cosine", 1), + OperatorDefinition("abs", "forge.op.Abs", 1), + OperatorDefinition("tanh", "forge.op.Tanh", 1), + OperatorDefinition("cumsum", "forge.op.CumSum", 1), + OperatorDefinition("argmax", "forge.op.Argmax", 1), + OperatorDefinition("logical_not", "forge.op.LogicalNot", 1), + OperatorDefinition("dropout", "forge.op.Dropout", 1), + OperatorDefinition("pow", "forge.op.Pow", 1, forward_params=[ + OperatorParamNumber("exponent", float, 0, 100), + ]), + OperatorDefinition("tilizer", "forge.op.Tilize", 1), + + # Binary operators + OperatorDefinition("add", "forge.op.Add", 2), + OperatorDefinition("divide", "forge.op.Divide", 2), + OperatorDefinition("subtract", "forge.op.Subtract", 2), + OperatorDefinition("multiply", "forge.op.Multiply", 2), + OperatorDefinition("maximum", "forge.op.Max", 2), + OperatorDefinition("minimum", "forge.op.Min", 2), + OperatorDefinition("heaviside", "forge.op.Heaviside", 2), + OperatorDefinition("binary_stack", "forge.op.BinaryStack", 2), + OperatorDefinition("power", "forge.op.Power", 2), + OperatorDefinition("greater", "forge.op.Greater", 2), + OperatorDefinition("greater_equal", "forge.op.GreaterEqual", 2), + OperatorDefinition("less", "forge.op.Less", 2), + OperatorDefinition("less_equal", "forge.op.LessEqual", 2), + OperatorDefinition("equal", "forge.op.Equal", 2), + OperatorDefinition("not_equal", "forge.op.NotEqual", 2), + OperatorDefinition("logical_and", "forge.op.LogicalAnd", 2), + + # Nary operators + OperatorDefinition("where", "forge.op.Where", 3), + # OperatorDefinition("index_copy", "forge.op.IndexCopy", 3), # Bug #2705 + OperatorDefinition("interleave", "forge.op.Interleave", (1,10), forward_params=[ + OperatorParamNumber("axis", int, -3, -3), + OperatorParamNumber("stride", int, 1, 1), + ]), + OperatorDefinition("concatenate", "forge.op.Concatenate", (1, 10), forward_params=[ + OperatorParamNumber("axis", int, -10, 10), + ]), + OperatorDefinition("stack", "forge.op.Stack", (2,4), forward_params=[ + OperatorParamNumber("axis", int, 1, 10), + ]), + + OperatorDefinition("matmul", "forge.op.Matmul", 2), + # OperatorDefinition("sparse_matmul", "forge.op.SparseMatmul", 2), +] + + +pybuda_operator_repository = OperatorRepository([op for op in _OPERATORS]) diff --git a/forge/forge/op_repo/pytorch_operators.py b/forge/forge/op_repo/pytorch_operators.py new file mode 100644 index 000000000..ade1b10ee --- /dev/null +++ b/forge/forge/op_repo/pytorch_operators.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +# PyTorch repostiory operators + + +from .datatypes import OperatorDefinition, OperatorRepository +from .datatypes import OperatorParamNumber + + +# TODO describe operand and shapes +_OPERATORS = [ + OperatorDefinition("linear", "torch.nn.Linear", 1, instantiate=True, constructor_params=[ 
diff --git a/forge/test/random/conftest.py b/forge/test/random/conftest.py
index 4cc94ff11..742b19172 100644
--- a/forge/test/random/conftest.py
+++ b/forge/test/random/conftest.py
@@ -6,6 +6,8 @@
 import os
 import forge
 
+from .rgg import get_randomizer_config_default
+
 test_rg = random.Random()
 seeds = []
 
@@ -27,6 +29,18 @@ def run_test(test_index, random_seeds):
     yield
 
 def pytest_generate_tests(metafunc):
+    if "randomizer_config" in metafunc.fixturenames:
+        configs = []
+        for (build_model_from_code,) in [
+            (True,),
+            # (False,),
+        ]:
+            config = get_randomizer_config_default()
+            # config.build_model_from_code = build_model_from_code
+            # config.debug_forward = not build_model_from_code
+            # config.print_code = not build_model_from_code
+            configs.append(config)
+        metafunc.parametrize("randomizer_config", configs)
     if "test_index" in metafunc.fixturenames:
         if "RANDOM_TEST_COUNT" in os.environ:
             test_count = int(os.environ["RANDOM_TEST_COUNT"])
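Because `randomizer_config` is a plain dataclass instance (see config.py and datatypes.py below), a local checkout can tweak it right here before parametrization. A sketch of the pattern, using only fields defined in this patch:

    from test.random.rgg import get_randomizer_config_default

    config = get_randomizer_config_default()
    config.run_test = False        # only generate and save test sources, skip verification
    config.save_tests = True
    config.num_of_nodes_max = 20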
diff --git a/forge/test/random/rgg/__init__.py b/forge/test/random/rgg/__init__.py
new file mode 100644
index 000000000..ca356a809
--- /dev/null
+++ b/forge/test/random/rgg/__init__.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+from .datatypes import TensorShape
+from .datatypes import RandomizerConstantNode
+from .datatypes import RandomizerInputNode, RandomizerNode, ExecutionContext, RandomizerParameters, RandomizerGraph, RandomizerConfig
+from .datatypes import NodeShapeCalculationContext
+from .datatypes import RandomizerTestContext
+from .datatypes import ModelBuilder, Framework
+from .config import get_randomizer_config_default
+from .utils import StrUtils, GraphUtils
+from .utils import DebugUtils
+from .base import GraphBuilder
+from .base import RandomizerRunner, RandomizerCodeGenerator, process_test
+from .frameworks import Frameworks
+from .frameworks import FrameworkTestUtils
+from .algorithms import GraphNodeSetup
+from .algorithms import RandomGraphAlgorithm
+
+__all__ = [
+    "TensorShape",
+    "RandomizerConstantNode",
+    "RandomizerInputNode",
+    "RandomizerNode",
+    "ExecutionContext",
+    "RandomizerParameters",
+    "RandomizerGraph",
+    "RandomizerConfig",
+    "NodeShapeCalculationContext",
+    "RandomizerTestContext",
+    "ModelBuilder",
+    "Framework",
+    "get_randomizer_config_default",
+    "StrUtils",
+    "GraphUtils",
+    "DebugUtils",
+    "GraphBuilder",
+    "RandomizerRunner",
+    "RandomizerCodeGenerator",
+    "process_test",
+    "Frameworks",
+    "FrameworkTestUtils",
+    "GraphNodeSetup",
+    "RandomGraphAlgorithm",
+]
+ """ + + nodes = test_context.graph.nodes + + # Setting node.index + op_index_cnt = 0 + for node in nodes: + op_index_cnt += 1 + node.index = op_index_cnt + + # Storing output values if needed as explicit input for later operator + logger.trace("Setting out_value for nodes") + for node in nodes: + # setting default output variable name + node.out_value = "v" + for input_node in node.inputs: + if input_node is not None and not input_node.constant and (not NodeUtils.is_previous_node(node, input_node) or cls.always_unique_variables): + # overriding default output variable name + input_node.out_value = input_node.operator_name + logger.trace(f"Set out_value = {input_node.out_value}") + + @classmethod + def init_nodes_inputs(cls, test_context: RandomizerTestContext): + """ + Setting input and contant nodes for open nodes. + + Args: + test_context (RandomizerTestContext): The test context. + + Returns: + None + """ + graph = test_context.graph + nodes = test_context.graph.nodes + + rng_shape = test_context.rng_shape + + constant_input_rate_limiter = RateLimiter(rng_shape, 100, test_context.randomizer_config.constant_input_rate) + same_inputs_rate_limiter = RateLimiter(rng_shape, 100, test_context.randomizer_config.same_inputs_percent_limit) + + logger.trace("Setting input nodes for open nodes") + open_nodes = NodeUtils.get_open_nodes(nodes) + logger.trace(f"Open nodes {StrUtils.nodes_to_str(open_nodes)}") + + # Setting input nodes for open nodes + for node in open_nodes: + input_shapes = node.input_shapes + # list of input nodes that are already connected to the node + used_input_nodes: List[RandomizerInputNode] = [] + for open_input_index in NodeUtils.get_open_input_indices(node): + input_shape = input_shapes[open_input_index] + + # There must be at least one input node for forward method + if len(graph.input_nodes) > 0 and constant_input_rate_limiter.is_allowed(): + # Creates a new constant node with the same shape + constant_node = RandomizerConstantNode(out_value=None, input_shape=input_shape) + logger.trace(f"Allowed constant input {constant_node.out_value} -> {node.name}[{open_input_index}] due to rate limit not exceeded: {constant_input_rate_limiter.limit_info()}") + # Stores the new constant node in the graph constant nodes + graph.constant_nodes.append(constant_node) + input_node = constant_node + else: + # list of all graph input nodes with the same shape as the input shape + input_nodes_with_same_shape = [input_node for input_node in graph.input_nodes if input_node.input_shape == input_shape] + # list of input nodes with the same shape that are not already connected to the node + input_nodes_with_same_shape_unused = [input_node for input_node in input_nodes_with_same_shape if input_node not in used_input_nodes] + if len(input_nodes_with_same_shape_unused) > 0: + # reuse existing input node with the same shape that is not already connected to the node + input_node = input_nodes_with_same_shape_unused[0] + used_input_nodes.append(input_node) + else: + # there are no input nodes with the same shape that are not already connected to the node + # check if same input value is allowed + # there must be at least one input node with the same shape to allow repeat + allow_repeat = len(input_nodes_with_same_shape) > 0 + + if allow_repeat: + if not same_inputs_rate_limiter.is_allowed(): + logger.trace(f"Not allowed same input value {input_node.out_value} -> {node.name}[{open_input_index}] due to rate limit exceeded: {same_inputs_rate_limiter.limit_info()}") + allow_repeat = False + + if 
+
+    @classmethod
+    def init_nodes_inputs(cls, test_context: RandomizerTestContext):
+        """
+        Setting input and constant nodes for open nodes.
+
+        Args:
+            test_context (RandomizerTestContext): The test context.
+
+        Returns:
+            None
+        """
+        graph = test_context.graph
+        nodes = test_context.graph.nodes
+
+        rng_shape = test_context.rng_shape
+
+        constant_input_rate_limiter = RateLimiter(rng_shape, 100, test_context.randomizer_config.constant_input_rate)
+        same_inputs_rate_limiter = RateLimiter(rng_shape, 100, test_context.randomizer_config.same_inputs_percent_limit)
+
+        logger.trace("Setting input nodes for open nodes")
+        open_nodes = NodeUtils.get_open_nodes(nodes)
+        logger.trace(f"Open nodes {StrUtils.nodes_to_str(open_nodes)}")
+
+        # Setting input nodes for open nodes
+        for node in open_nodes:
+            input_shapes = node.input_shapes
+            # list of input nodes that are already connected to the node
+            used_input_nodes: List[RandomizerInputNode] = []
+            for open_input_index in NodeUtils.get_open_input_indices(node):
+                input_shape = input_shapes[open_input_index]
+
+                # A constant input is allowed only if there is already at least one input node for the forward method
+                if len(graph.input_nodes) > 0 and constant_input_rate_limiter.is_allowed():
+                    # Creates a new constant node with the same shape
+                    constant_node = RandomizerConstantNode(out_value=None, input_shape=input_shape)
+                    logger.trace(f"Allowed constant input -> {node.name}[{open_input_index}] due to rate limit not exceeded: {constant_input_rate_limiter.limit_info()}")
+                    # Stores the new constant node in the graph constant nodes
+                    graph.constant_nodes.append(constant_node)
+                    input_node = constant_node
+                else:
+                    # list of all graph input nodes with the same shape as the input shape
+                    input_nodes_with_same_shape = [input_node for input_node in graph.input_nodes if input_node.input_shape == input_shape]
+                    # list of input nodes with the same shape that are not already connected to the node
+                    input_nodes_with_same_shape_unused = [input_node for input_node in input_nodes_with_same_shape if input_node not in used_input_nodes]
+                    if len(input_nodes_with_same_shape_unused) > 0:
+                        # reuse an existing input node with the same shape that is not already connected to the node
+                        input_node = input_nodes_with_same_shape_unused[0]
+                        used_input_nodes.append(input_node)
+                    else:
+                        # there are no unused input nodes with the same shape, so check whether reusing an
+                        # already-connected input value is allowed
+                        # there must be at least one input node with the same shape to allow a repeat
+                        allow_repeat = len(input_nodes_with_same_shape) > 0
+
+                        if allow_repeat:
+                            if not same_inputs_rate_limiter.is_allowed():
+                                logger.trace(f"Not allowed same input value -> {node.name}[{open_input_index}] due to rate limit exceeded: {same_inputs_rate_limiter.limit_info()}")
+                                allow_repeat = False
+
+                        if allow_repeat:
+                            input_node = rng_shape.choice(input_nodes_with_same_shape)
+                            logger.trace(f"Allowed same input value {input_node.out_value} -> {node.name}[{open_input_index}] due to rate limit not exceeded: {same_inputs_rate_limiter.limit_info()}")
+
+                        else:
+                            # create a new input node with the same shape, since there are no unused input
+                            # nodes with the same shape or a repeat is not allowed
+                            input_node = RandomizerInputNode(out_value=f"in_value{len(graph.input_nodes)+1}", input_shape=input_shape)
+                            used_input_nodes.append(input_node)
+                            # store the new input node in the graph input nodes
+                            graph.input_nodes.append(input_node)
+
+                # connect the input node to the open node input
+                node.inputs[open_input_index] = input_node
+
+        # Assign constant node values after connecting inputs
+        iconst_index = 0
+        for constant_node in graph.constant_nodes:
+            if constant_node.out_value is None:
+                iconst_index += 1
+                constant_node.out_value = f"iconst{iconst_index}"
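`RateLimiter` lives in test.operators.utils and is not part of this patch; from its use here — constructed as `(rng, 100, percent)` and queried via `is_allowed()`/`limit_info()` — it behaves like a percentage-gated coin flip. A hypothetical stand-in capturing that assumed behavior, not the real implementation:

    import random

    class PercentRateLimiter:
        # Hypothetical sketch: allows roughly `percent` out of every `scale` draws.
        def __init__(self, rng: random.Random, scale: int, percent: int):
            self.rng, self.scale, self.percent = rng, scale, percent
            self.last_draw = None

        def is_allowed(self) -> bool:
            self.last_draw = self.rng.randint(1, self.scale)
            return self.last_draw <= self.percent

        def limit_info(self) -> str:
            return f"{self.last_draw} <= {self.percent} of {self.scale}"

    limiter = PercentRateLimiter(random.Random(0), 100, 20)
    print(sum(limiter.is_allowed() for _ in range(1000)))  # roughly 200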
+
+    @classmethod
+    def init_node_params(cls, node: RandomizerNode, test_context: RandomizerTestContext):
+        """
+        Generates random parameters for the specified node.
+
+        Args:
+            node (RandomizerNode): The node.
+            test_context (RandomizerTestContext): The test context.
+
+        Returns:
+            None
+        """
+        rng_params = test_context.rng_params
+
+        node.constructor_kwargs = RandomUtils.constructor_kwargs(node.operator, node.constructor_kwargs, rng_params)
+        node.forward_kwargs = RandomUtils.forward_kwargs(node.operator, node.forward_kwargs, rng_params)
+
+    @classmethod
+    def validate_graph(cls, graph: RandomizerGraph):
+        '''Validates the graph
+        1. Validates the number of inputs for each node
+        2. Validates the operator class type
+
+        Args:
+            graph (RandomizerGraph): The graph to validate
+
+        Raises:
+            Exception: If the number of inputs for a node does not match the configured input number.
+            Exception: If the node operator is not of type OperatorDefinition.
+        '''
+        nodes = graph.nodes
+
+        # Validation of input configuration
+        for node in nodes:
+            if node.input_num and node.input_num > 1:
+                if NodeUtils.num_of_open_inputs(node) > 0:
+                    raise Exception(f"Closed {NodeUtils.num_of_closed_inputs(node)}/{node.input_num} inputs, missing {NodeUtils.num_of_open_inputs(node)} inputs for node {node.node_info}")
+
+        # Validation of operator and layer types
+        for node in nodes:
+            if node.operator and not isinstance(node.operator, OperatorDefinition):
+                raise Exception(f"Node operator is wrong type {node.node_info}: expected OperatorDefinition, got {type(node.operator)}")
+
+    @classmethod
+    def prepare_graph(cls, test_context: RandomizerTestContext):
+
+        graph = test_context.graph
+
+        logger.trace("Initializing nodes")
+        cls.init_nodes_names(test_context)
+        cls.init_nodes_inputs(test_context)
+        logger.trace("Nodes initialized")
+
+        logger.trace("Validating graph")
+        cls.validate_graph(graph)
+        logger.trace("Graph validated")
+
+        logger.trace("Serializing nodes")
+        nodes_str = StrUtils.nodes_to_str(graph.nodes)
+        logger.trace("Nodes serialized")
+        logger.trace(f"Nodes: \n{nodes_str}")
+
+
+class RandomGraphAlgorithm(GraphBuilder):
+    '''Implementation of the random graph building algorithm'''
+
+    # Log building progress every n seconds
+    building_progress_rate_in_sec = 2
+
+    def __init__(self, framework: Framework, randomizer_config):
+        super(RandomGraphAlgorithm, self).__init__(randomizer_config)
+        self.framework = framework
+        self.operators = framework.operator_repository.operators
+
+    def _get_random_operator(self, rng):
+        return rng.choice(self.operators)
+
+    @classmethod
+    def _init_default_constructor_params(cls, node: RandomizerNode):
+        '''Initializing default constructor parameters based on input and output shapes'''
+        # Operator specific settings
+        # TODO abstract this
+        if len([param for param in node.operator.constructor_params if param.name == "in_features"]) == 1:
+            node.constructor_kwargs["in_features"] = node.input_shapes[0][-1]
+        if len([param for param in node.operator.constructor_params if param.name == "out_features"]) == 1:
+            node.constructor_kwargs["out_features"] = node.output_shape[-1]
+        if len([param for param in node.operator.constructor_params if param.name == "in_channels"]) == 1:
+            node.constructor_kwargs["in_channels"] = node.input_shapes[0][1]
+        if len([param for param in node.operator.constructor_params if param.name == "out_channels"]) == 1:
+            node.constructor_kwargs["out_channels"] = node.output_shape[1]
+
+    @classmethod
+    def _adjust_params(cls, node: RandomizerNode, test_context: RandomizerTestContext):
+
+        function_name = f"{node.operator.name}_adjust"
+        if function_name in AdjustParameters.__dict__:
+            logger.trace(f"Found method {function_name}")
+            adjust_params_method = AdjustParameters.__dict__[function_name]
+            adjust_params_method(node, test_context)
+        else:
+            pass
+
+    # Build a graph of random operators via the random graph building algorithm.
+    # The graph contains between num_of_nodes_min and num_of_nodes_max nodes.
+    # The graph is constructed backwards, starting from the end node.
+    # In each step a random operator is selected and a new node is created.
+    # The output of the new node is connected as input to multiple randomly selected open nodes with the same input shape.
+    # When the new node is connected to more than one node, the graph constructs a fork join.
+    # The output shape of the first node is random.
+    # The output shape of every other node is based on the next input shape of a randomly picked open node.
+    # Input shapes for each node are calculated based on the output shape of the node.
+    def build_graph(self, test_context: RandomizerTestContext):
+        '''Implementation of the random graph building algorithm'''
+
+        graph = test_context.graph
+        nodes = graph.nodes
+
+        rng_graph = test_context.rng_graph
+        rng_shape = test_context.rng_shape
+
+        fork_join_counter = 0
+        fork_join_max = test_context.randomizer_config.num_fork_joins_max
+
+        constant_input_rate_limiter = RateLimiter(rng_shape, 100, test_context.randomizer_config.constant_input_rate)
+        same_inputs_rate_limiter = RateLimiter(rng_shape, 100, test_context.randomizer_config.same_inputs_percent_limit)
+
+        # Context object for shape calculation, node will be set later in the loop
+        shape_calculation_context = NodeShapeCalculationContext(node=None, test_context=test_context)
+
+        # Building the graph with a number of nodes between num_of_nodes_min and num_of_nodes_max
+        num_of_nodes = rng_graph.randint(self.randomizer_config.num_of_nodes_min, self.randomizer_config.num_of_nodes_max)
+
+        build_duration = Timer()
+        last_build_duration = 0
+
+        for node_index in range(num_of_nodes, 0, -1):
+
+            # Logging graph building progress
+            duration = math.floor(build_duration.get_duration())
+
+            # Log building progress every n seconds
+            if duration - last_build_duration >= self.building_progress_rate_in_sec:
+                last_build_duration = duration
+                logger.debug(f"Building node {num_of_nodes-node_index}/{num_of_nodes} in {duration} sec")
+
+            first_node = len(nodes) == 0
+
+            # Choose operator randomly based on rng
+            op1 = self._get_random_operator(rng_graph)
+
+            node_name = f"op{node_index}[{op1.name}]"
+
+            # Find all open nodes
+            open_nodes = NodeUtils.get_open_nodes(nodes)
+
+            # Select output shape for the new node
+            if first_node:
+                # For the first node set output shape as random shape
+                output_shape = RandomUtils.random_shape_from_config(self.randomizer_config, rng_shape)
+            else:
+                # For other nodes, output shape is based on input shapes of a random open node
+                # Select one of the open nodes randomly
+                random_open_node: RandomizerNode = rng_graph.choice(open_nodes)
+                # Setting output shape based on input shapes of the random open node
+                input_shapes = random_open_node.input_shapes
+                open_input_indices = [i for i in NodeUtils.get_open_input_indices(random_open_node)]
+                open_input_index = open_input_indices[rng_graph.randint(0, len(open_input_indices) - 1)]
+                output_shape = input_shapes[open_input_index]
+
+            # Find all other open nodes with an input shape matching the output shape of the new node
+            open_nodes = NodeUtils.get_open_nodes_with_input_shape(nodes, output_shape)
+
+            # Random nodes are selected by matching the same input shape as the new node
+            # Closing multiple nodes will construct fork joins
+            random_nodes: List[RandomizerNode]
+
+            if not first_node:
+                # There must be at least one node to close
+                subset_count_min = max(1, len(open_nodes) // 2)
+                subset_count_max = len(open_nodes)
+                # Choose a random number of nodes to close
+                subset_count = rng_graph.randint(subset_count_min, subset_count_max)
+
+                # Limit number of fork joins
+                subset_count = min(subset_count, fork_join_max - fork_join_counter + 1)
+
+                # Increase fork join counter
+                new_fork_join = subset_count - 1
+                if new_fork_join > 0:
+                    logger.trace(f"Constructing {new_fork_join} new fork join(s) from operator {node_name}")
+                fork_join_counter += new_fork_join
+
+                # Select a random subset of open nodes to close
+                random_nodes = rng_graph.sample(open_nodes, subset_count)
+
+                if len(random_nodes) > 1:
+                    for random_node in random_nodes[1:]:
+                        logger.trace(f"Constructing new fork join from operator {node_name} -> {random_node.name}")
+
+            else:
+                random_nodes = []
+
+            # Closing nodes are all random open nodes
+            closing_nodes = random_nodes
+
+            # Creating new node
+            node = RandomizerNode(operator=op1, output_shape=output_shape)
+
+            # Initializing node parameters
+            # Calculating input shapes may require input parameters for its calculation
+            GraphNodeSetup.init_node_params(node, test_context)
+
+            # Initializing random inputs based on operand num range
+            NodeUtils.init_random_inputs(node, test_context)
+
+            try:
+                # Try to adjust parameters to avoid invalid shapes
+                self._adjust_params(node, test_context)
+            except InvalidShape as e:
+                # Skip node if shape doesn't support fixing
+                logger.warning(f"Invalid shape -> Skip node {node_name} because params adjustment failed: {e}")
+                # TODO repeat node generation with different operator
+                continue
+
+            # Saving input shapes for the new node
+            shape_calculation_context.node = node
+            try:
+                node.input_shapes = NodeUtils.calc_input_shapes(node, shape_calculation_context)
+            except InvalidShape as e:
+                # Skip node if shape is invalid
+                logger.warning(f"Invalid shape calculation -> Skip node {node_name}: {e}")
+                # TODO repeat node generation with different operator
+                continue
+
+            # Initializing default constructor parameters based on input and output shapes
+            self._init_default_constructor_params(node)
+
+            for closing_node in closing_nodes:
+                node_connected = False
+                for open_input_index in NodeUtils.get_open_input_indices(closing_node):
+                    # check input shape of a closing node open input
+                    if closing_node.input_shapes[open_input_index] == node.output_shape:
+
+                        # Limit number of same inputs on the same node
+                        if node_connected:
+                            if not same_inputs_rate_limiter.is_allowed():
+                                logger.trace(f"Skipping same input node connection op{node_index} {node.name} -> {closing_node.name}[{open_input_index}] due to rate limit exceeded: {same_inputs_rate_limiter.limit_info()}")
+                                continue
+                            else:
+                                logger.trace(f"Allowed same input node connection op{node_index} {node.name} -> {closing_node.name}[{open_input_index}] due to rate limit not exceeded: {same_inputs_rate_limiter.limit_info()}")
+                        closing_node.inputs[open_input_index] = node
+                        node_connected = True
+
+            nodes.insert(0, node)
+
+            # Connecting constants randomly to current node inputs
+            open_nodes = NodeUtils.get_open_nodes(nodes)
+            open_nodes_count = len(open_nodes)
+            input_shapes = node.input_shapes
+            for open_input_index in NodeUtils.get_open_input_indices(node):
+                input_shape = input_shapes[open_input_index]
+                # Skip connecting a constant input for the last open input to avoid a disconnected graph
+                if open_nodes_count > 1 or NodeUtils.num_of_open_inputs(node) > 1:
+                    if constant_input_rate_limiter.is_allowed():
+                        # Creates a new constant node with the same shape
+                        constant_node = RandomizerConstantNode(out_value=None, input_shape=input_shape)
+                        logger.trace(f"Allowed constant input -> {node.name}[{open_input_index}] due to rate limit not exceeded: {constant_input_rate_limiter.limit_info()}")
+                        # Stores the new constant node in the graph constant nodes
+                        graph.constant_nodes.insert(0, constant_node)
+                        # Connects the input node to the open node input
+                        node.inputs[open_input_index] = constant_node
+
+        # Assign constant node values
+        for i, constant_node in enumerate(graph.constant_nodes):
+            constant_node.out_value = f"nconst{i+1}"
+
+        logger.trace(f"Graph built with {len(nodes)} nodes")
+
+        logger.trace("Preparing graph")
+        GraphNodeSetup.prepare_graph(test_context)
+        logger.trace("Graph prepared")
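Stripped of logging, rate limiting, and shape adjustment, the backward construction above reduces to a small loop. The following is a self-contained hypothetical sketch, not part of the patch, which assumes elementwise operators (input shapes equal to the output shape) and closes every matching open slot instead of a random subset:

    import random
    from dataclasses import dataclass, field
    from typing import List, Optional, Tuple

    @dataclass
    class Node:
        op: str
        output_shape: Tuple[int, ...]
        # two input slots, i.e. binary elementwise operators assumed
        inputs: List[Optional["Node"]] = field(default_factory=lambda: [None, None])

    def build_backwards(rng: random.Random, ops: List[str], num_nodes: int) -> List[Node]:
        nodes: List[Node] = []
        for _ in range(num_nodes):
            open_slots = [(n, i) for n in nodes for i, inp in enumerate(n.inputs) if inp is None]
            if not open_slots:
                # first node created is the last in forward order: random output shape
                shape = (1, rng.choice([16, 32, 64]), rng.choice([16, 32, 64]))
            else:
                # later nodes adopt an open consumer's shape so they can feed it
                shape = rng.choice(open_slots)[0].output_shape
            node = Node(op=rng.choice(ops), output_shape=shape)
            # closing more than one slot at once is what creates a fork join
            for consumer, i in open_slots:
                if consumer.output_shape == shape and consumer.inputs[i] is None:
                    consumer.inputs[i] = node
            nodes.insert(0, node)  # keep forward execution order
        return nodes

    graph = build_backwards(random.Random(0), ["add", "mul", "sub"], 6)
    print([n.op for n in graph])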
diff --git a/forge/test/random/rgg/base.py b/forge/test/random/rgg/base.py
new file mode 100644
index 000000000..d008b9fe4
--- /dev/null
+++ b/forge/test/random/rgg/base.py
@@ -0,0 +1,328 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Base classes for the random graph generator
+
+
+from typing import Type
+from loguru import logger
+from jinja2 import Environment, FileSystemLoader
+import os
+import random
+
+import forge
+
+from forge import ForgeModule
+from forge.verify import VerifyConfig  # , verify_module
+from test.operators.utils.compat import verify_module
+from forge.op_repo import OperatorRepository
+from test.operators.utils import ShapeUtils
+from test.operators.utils.compat import TestDevice
+from .utils import Timer
+from .datatypes import ModelBuilder, Framework
+from .datatypes import RandomizerNode, RandomizerGraph, RandomizerParameters, RandomizerConfig, ExecutionContext
+from .datatypes import RandomizerTestContext
+from .utils import StrUtils, GraphUtils
+from .utils import timeout, TimeoutException
+
+
+class GraphBuilder:
+    """
+    GraphBuilder is an interface that each graph building algorithm should implement.
+    GraphBuilder encapsulates the logic for generating random graphs.
+    """
+    def __init__(self, randomizer_config: RandomizerConfig):
+        self.randomizer_config = randomizer_config
+
+    def build_graph(self, test_context: RandomizerTestContext) -> None:
+        """
+        Generate a test graph with the input shapes needed for validation.
+
+        Args:
+            test_context (RandomizerTestContext): The context for the randomizer test.
+
+        Raises:
+            Exception: This method is not implemented.
+        """
+        raise Exception("Method build_graph() not implemented")
+
+    def get_name(self):
+        return self.__class__.__name__
+
+
+# Translates a randomized graph into framework NN model code
+class RandomizerCodeGenerator:
+
+    def __init__(self, template_dir: str):
+        self.template = Environment(loader=FileSystemLoader(template_dir)).get_template('generated_model.jinja2')
+
+    def constructor_kwargs(self, node: RandomizerNode):
+        return StrUtils.kwargs_str(**node.constructor_kwargs)
+
+    def forward_args(self, node: RandomizerNode) -> str:
+        args_str = ", ".join([f"inputs[{i}]" for i in range(node.input_num)])
+        return args_str
+
+    def forward_kwargs(self, node: RandomizerNode) -> str:
+        return StrUtils.kwargs_str(**node.forward_kwargs)
+
+    def generate_code(self, test_context: RandomizerTestContext, test_format: bool = True) -> str:
+        # TODO setup random seed in generated test function
+
+        parameters = test_context.parameters
+        template = self.template
+
+        code_str = template.render(
+            randomizer_config = test_context.randomizer_config,
+            graph_builder_name = parameters.graph_builder_name,
+            test_id = StrUtils.test_id(test_context),
+            test_format = test_format,
+            test_index = parameters.test_index,
+            random_seed = parameters.random_seed,
+            graph=test_context.graph,
+            constructor_kwargs=self.constructor_kwargs,
+            forward_args=self.forward_args,
+            forward_kwargs=self.forward_kwargs,
+            reduce_microbatch_size=ShapeUtils.reduce_microbatch_size,
+            ExecutionContext=ExecutionContext,
+        )
+
+        return code_str
+
+
+class RandomizerModelProviderFromSourceCode:
+
+    def __init__(self, code_generator: RandomizerCodeGenerator, model_builder: ModelBuilder):
+        self.code_generator = code_generator
+        self.model_builder = model_builder
+
+    def build_model(self, test_context: RandomizerTestContext) -> ForgeModule:
+        '''
+        Build a model from the generated test model class.
+
+        Args:
+            test_context (RandomizerTestContext): The context for the randomizer test.
+
+        Returns:
+            ForgeModule: The Forge model.
+        '''
+        GeneratedTestModel = self._get_model_class(test_context)
+        model = self.model_builder.build_model(test_context, GeneratedTestModel)
+        return model
+
+    def _get_model_class(self, test_context: RandomizerTestContext) -> Type:
+        class_name = self._get_model_class_name(test_context)
+        test_code_str = self.code_generator.generate_code(test_context, test_format=False)
+
+        GeneratedTestModel = self._get_model_class_from_code(class_name, test_code_str)
+        return GeneratedTestModel
+
+    def _get_model_class_name(self, test_context: RandomizerTestContext):
+        parameters = test_context.parameters
+        class_name = f"GeneratedTestModel_{parameters.test_index}_{parameters.random_seed}"
+        return class_name
+
+    def _get_model_class_from_code(self, class_name: str, class_string: str) -> Type:
+        # python magic, create class from class code string
+        namespace = {}
+        exec(class_string, namespace)
+
+        GeneratedTestModel = namespace[class_name]
+
+        return GeneratedTestModel
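The "python magic" here is ordinary exec-into-namespace: executing the generated source defines the class in a fresh dict, from which it is looked up by name. A minimal standalone illustration of the same pattern, with a hypothetical class name:

    source = (
        "class GeneratedTestModel_0_42:\n"
        "    def forward(self):\n"
        "        return 'hello from generated code'\n"
    )
    namespace = {}
    exec(source, namespace)
    GeneratedTestModel = namespace["GeneratedTestModel_0_42"]
    print(GeneratedTestModel().forward())  # hello from generated code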
+
+
+# TODO move RandomizerRunner and process_test to runner.py
+class RandomizerRunner:
+    """
+    The RandomizerRunner class is used for processing randomized tests.
+
+    Attributes:
+        test_context (RandomizerTestContext): The context for the randomizer test.
+        model_provider (RandomizerModelProviderFromSourceCode): The model provider for generating tests.
+
+    Methods:
+        run(): Builds the random graph, generates the test source code, saves it if
+        configured to do so, and verifies the resulting model via Forge.
+    """
+    def __init__(self, test_context: RandomizerTestContext, model_builder: ModelBuilder):
+        self.test_context = test_context
+        self.code_generator = RandomizerCodeGenerator(f"forge/test/random/rgg/{StrUtils.text_to_snake_case(test_context.parameters.framework.template_name)}")
+        self.model_provider = RandomizerModelProviderFromSourceCode(self.code_generator, model_builder)
+
+    def generate_code(self) -> str:
+        """
+        Generates test source code with a test function for the randomized graph.
+
+        Returns:
+            str: The generated code.
+        """
+        return self.code_generator.generate_code(self.test_context, test_format=True)
+
+    def build_graph(self, graph_builder: GraphBuilder) -> None:
+        self.test_context.graph = RandomizerGraph()
+
+        # Initialize the random number generator for graph building
+        self.test_context.rng_graph = random.Random(self.test_context.parameters.random_seed)
+        # Initialize the random number generator for shape generation
+        self.test_context.rng_shape = random.Random(self.test_context.parameters.random_seed)
+        # Initialize the random number generator for parameters
+        self.test_context.rng_params = random.Random(self.test_context.parameters.random_seed)
+
+        graph_builder.build_graph(self.test_context)
+
+    def build_model(self) -> ForgeModule:
+        model = self.model_provider.build_model(self.test_context)
+        return model
+
+    def verify(self, model: ForgeModule) -> None:
+
+        verification_timeout = self.test_context.randomizer_config.verification_timeout
+
+        try:
+            @timeout(verification_timeout)
+            def verify_model_timeout() -> None:
+                self.verify_model(model)
+
+            verify_model_timeout()
+        except TimeoutException as e:
+            logger.error(f"Module verification takes too long {e}.")
+            raise e
+
+    def verify_model(self, model: ForgeModule) -> None:
+        """
+        Verify the model by building it and performing validation via Forge.
+        The method is usually implemented once per framework.
+
+        Args:
+            model (ForgeModule): The Forge model to verify.
+
+        Raises:
+            Exception: If the model verification fails.
+        """
+
+        parameters = self.test_context.parameters
+        input_shapes = GraphUtils.get_input_shapes(self.test_context.graph)
+
+        # Reset default data format to None set by conftest.py
+        forge.config.g_compiler_config.default_df_override = None
+
+        # verify Forge model
+        verify_module(model, input_shapes)
+        # verify_module(model, input_shapes,
+        #               VerifyConfig(devtype=parameters.test_device.devtype, arch=parameters.test_device.arch))
+
+    def save_test(self, test_code_str: str, failing_test: bool = False):
+        test_dir = self.test_context.randomizer_config.test_dir
+        if failing_test:
+            test_dir = f"{test_dir}/failing_tests"
+            test_code_str = test_code_str.replace("# @pytest.mark.xfail", "@pytest.mark.xfail")
+        test_code_file_name = f"{test_dir}/test_gen_model_{StrUtils.test_id(self.test_context)}.py"
+
+        if not os.path.exists(test_dir):
+            logger.info(f"Creating test directory {test_dir}")
+            os.makedirs(test_dir)
+
+        logger.info(f"Saving test to {test_code_file_name}")
+        with open(test_code_file_name, "w") as f:
+            f.write(test_code_str)
+
+    def run(self, graph_builder: GraphBuilder):
+        """
+        Process the randomizer generator.
+        Usually the only method from this class that is called from the test.
+
+        This method generates the randomizer model config, initializes nodes, and performs verification via Forge.
+
+        Args:
+            graph_builder (GraphBuilder): The graph builder for generating tests.
+        """
+        logger.debug("-------------- Process Randomizer Generator -------------------")
+        randomizer_config = self.test_context.randomizer_config
+        parameters = self.test_context.parameters
+        logger.debug(f"Parameters test_index: {parameters.test_index} random_seed: {parameters.random_seed} test_device: {parameters.test_device}")
+
+        # build random graph for the specified parameters
+        logger.trace("Building graph started")
+        graph_duration = Timer()
+        try:
+            self.build_graph(graph_builder)
+        except Exception as e1:
+            # Try to save the test source code to a file for debugging purposes if an error occurs
+            try:
+                test_code_str = self.generate_code()
+                if randomizer_config.save_tests:
+                    # Saving test source code to a file for debugging purposes
+                    self.save_test(test_code_str, failing_test=True)
+            except Exception as e2:
+                logger.error(f"Error while saving test: {e2}")
+            # Re-raise the original exception from graph building
+            raise e1
+        logger.trace("Building graph completed")
+        graph = self.test_context.graph
+        logger.debug(f"Generating graph model {GraphUtils.short_description(graph)}")
+        if randomizer_config.print_graph:
+            # printing the generated graph to the console for debugging purposes
+            logger.debug(f"Graph config:\n{StrUtils.to_str(graph)}")
+
+        # generate test source code with a test function
+        test_code_str = self.generate_code()
+
+        if randomizer_config.print_code:
+            # printing the generated test source code to the console for debugging purposes
+            logger.debug(f"Generated code: \n{test_code_str}")
+
+        if randomizer_config.save_tests:
+            # saving the test source code to a file for debugging purposes
+            self.save_test(test_code_str, failing_test=False)
+
+        logger.info(f"Graph built in: {graph_duration.get_duration():.4f} seconds")
+
+        if randomizer_config.run_test:
+            # instantiate Forge model
+            model = self.build_model()
+            # perform model validation
+            try:
+                verify_duration = Timer()
+                verify_successful = False
+                self.verify(model)
+                verify_successful = True
+            finally:
+                if not verify_successful:
+                    if randomizer_config.save_failing_tests:
+                        # saving the failing test source code to a file for debugging purposes
+                        self.save_test(test_code_str, failing_test=True)
+                logger.debug(f"Test verified in: {verify_duration.get_duration():.4f} seconds")
+        else:
+            logger.info("Skipping test run")
+
+
+def process_test(test_name: str, test_index: int, random_seed: int, test_device: TestDevice, randomizer_config: RandomizerConfig, graph_builder_type: Type[GraphBuilder], framework: Framework):
+    '''
+    Process a single randomizer test.
+
+    Args:
+        test_name (str): The name of the test used for generating test code, test file name, etc.
+        test_index (int): The index of the test.
+        random_seed (int): The random seed for the test.
+        test_device (TestDevice): The device for the test.
+        randomizer_config (RandomizerConfig): The configuration for the randomizer.
+        graph_builder_type (Type[GraphBuilder]): The graph builder type (algorithm) for the test.
+        framework (Framework): The test framework for the test.
+    '''
+    # TODO read framework from randomizer_config
+
+    # instantiate graph_builder
+    graph_builder = graph_builder_type(framework, randomizer_config)
+    # instantiate parameters
+    parameters = RandomizerParameters(test_index, random_seed, test_device, framework=framework, graph_builder_name=graph_builder.get_name())
+    # instantiate test_context
+    test_context = RandomizerTestContext(randomizer_config=randomizer_config, parameters=parameters, graph=None, test_name=test_name)
+    # instantiate model_builder
+    model_builder = framework.ModelBuilderType()
+    # instantiate runner
+    runner = RandomizerRunner(test_context, model_builder)
+    # process test
+    runner.run(graph_builder)
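The patch's actual entry point lives in forge/test/random/test_graphs.py (not shown in this excerpt). A sketch of what such a test function looks like, assuming the `test_index`, `random_seeds`, `test_device`, and `randomizer_config` fixtures wired up in conftest.py above:

    from test.random.rgg import Frameworks, RandomGraphAlgorithm, process_test

    def test_random_graph_pybuda(test_index, random_seeds, test_device, randomizer_config):
        process_test(
            "Default",
            test_index,
            random_seeds[test_index],
            test_device,
            randomizer_config,
            graph_builder_type=RandomGraphAlgorithm,
            framework=Frameworks.PYBUDA.value,
        )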
diff --git a/forge/test/random/rgg/config.py b/forge/test/random/rgg/config.py
new file mode 100644
index 000000000..21dc8efb5
--- /dev/null
+++ b/forge/test/random/rgg/config.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Configuration of the randomizer
+
+
+import os
+
+from .datatypes import RandomizerConfig
+
+
+# TODO introduce environment variables to set these values
+# TODO read config from file
+def get_randomizer_config_default():
+    randomizer_config = RandomizerConfig(
+        print_graph = False,
+        print_code = True,
+        run_test = True,
+        save_tests = True,
+        save_failing_tests = True,
+        # build_model_from_code = False,
+        debug_shapes = False,
+        verify_shapes = False,
+        verification_timeout = int(os.environ.get("VERIFICATION_TIMEOUT", 60)),
+        # TODO ranges
+        # dim_min=int(os.environ.get("MIN_DIM", 3)),
+        dim_min=int(os.environ.get("MIN_DIM", 4)),  # Until #2722 is resolved
+        dim_max=int(os.environ.get("MAX_DIM", 4)),
+        op_size_per_dim_min=int(os.environ.get("MIN_OP_SIZE_PER_DIM", 16)),
+        op_size_per_dim_max=int(os.environ.get("MAX_OP_SIZE_PER_DIM", 64)),  # by default run with smaller sizes
+        # op_size_per_dim_max=int(os.environ.get("MAX_OP_SIZE_PER_DIM", 512)),
+        op_size_quantization=int(os.environ.get("OP_SIZE_QUANTIZATION", 1)),
+        microbatch_size_min=int(os.environ.get("MIN_MICROBATCH_SIZE", 1)),
+        microbatch_size_max=int(os.environ.get("MAX_MICROBATCH_SIZE", 8)),
+        num_of_nodes_min=int(os.environ.get("NUM_OF_NODES_MIN", 5)),
+        num_of_nodes_max=int(os.environ.get("NUM_OF_NODES_MAX", 10)),
+        num_fork_joins_max=int(os.environ.get("NUM_OF_FORK_JOINS_MAX", 50)),
+        constant_input_rate=int(os.environ.get("CONSTANT_INPUT_RATE", 20)),
+        same_inputs_percent_limit=int(os.environ.get("SAME_INPUTS_PERCENT_LIMIT", 10)),
+    )
+    return randomizer_config
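Since the environment is read each time the function is called, the variables must be set before `get_randomizer_config_default()` runs — in CI they would typically be exported in the shell. The same can be checked in-process:

    import os

    os.environ["NUM_OF_NODES_MAX"] = "20"
    os.environ["VERIFICATION_TIMEOUT"] = "120"

    from test.random.rgg import get_randomizer_config_default

    config = get_randomizer_config_default()
    assert config.num_of_nodes_max == 20 and config.verification_timeout == 120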
diff --git a/forge/test/random/rgg/datatypes.py b/forge/test/random/rgg/datatypes.py
new file mode 100644
index 000000000..91d64ccba
--- /dev/null
+++ b/forge/test/random/rgg/datatypes.py
@@ -0,0 +1,201 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Generic test model randomizer
+
+
+from typing import Dict, List, Type, Optional, Final
+from dataclasses import dataclass, field
+import random
+import torch
+
+from forge import ForgeModule
+
+from forge.op_repo import TensorShape
+from forge.op_repo import OperatorDefinition
+from forge.op_repo import OperatorRepository
+from forge.op_repo import ShapeCalculationContext
+
+from test.operators.utils.compat import TestDevice
+
+
+@dataclass
+class RandomizerInputNode:
+    constant: Final[bool] = field(default=False, init=False)
+    out_value: str
+    input_shape: TensorShape
+
+
+@dataclass
+class RandomizerConstantNode:
+    constant: Final[bool] = field(default=True, init=False)
+    out_value: str
+    input_shape: TensorShape
+
+
+@dataclass
+class RandomizerNode:
+    constant: Final[bool] = field(default=False, init=False)
+    index: Optional[int] = None
+    out_value: Optional[str] = None
+    operator: Optional[OperatorDefinition] = None
+    input_num: int = field(init=False)
+    inputs: List['RandomizerNode'] = field(init=False)
+    constructor_kwargs: Dict[str, object] = field(default_factory=dict)
+    forward_kwargs: Dict[str, object] = field(default_factory=dict)
+    input_shapes: List[TensorShape] = field(default_factory=list)
+    output_shape: TensorShape = None
+
+    def __post_init__(self):
+        # List of input nodes is initialized with None values for each input
+        # Inputs will be set later during graph construction
+        self.input_num = self.operator.input_num_range.operands_min
+        self.init_inputs()
+
+    def init_inputs(self):
+        self.inputs = [None for _ in range(self.input_num)]
+
+    @property
+    def operator_name(self):
+        return f"op{self.index}"
+
+    @property
+    def layer_name(self):
+        return f"l{self.index}"
+
+    @property
+    def node_name(self):
+        return self.operator_name if self.operator.is_operator else self.layer_name
+
+    @property
+    def name(self):
+        return self.operator.name
+
+    @property
+    def node_info(self):
+        return f"{self.node_name} {self.name}"
+
+
+class NodeShapeCalculationContext(ShapeCalculationContext):
+
+    def __init__(self, node: RandomizerNode, test_context: 'RandomizerTestContext'):
+        self.node = node
+        self.test_context = test_context
+
+    @property
+    def operator(self) -> OperatorDefinition:
+        return self.node.operator
+
+    @property
+    def input_num(self) -> int:
+        return self.node.input_num
+
+    @property
+    def constructor_kwargs(self) -> Dict[str, object]:
+        return self.node.constructor_kwargs
+
+    @property
+    def forward_kwargs(self) -> Dict[str, object]:
+        return self.node.forward_kwargs
+
+    @property
+    def output_shape(self) -> TensorShape:
+        return self.node.output_shape
+
+    @property
+    def rng_shape(self) -> random.Random:
+        return self.test_context.rng_shape
+
+
+@dataclass
+class ExecutionContext:
+    values: Dict
+    last_value: torch.Tensor
+    node: Optional[RandomizerNode] = None
+    inputs: Optional[List[torch.Tensor]] = None
+
+
+@dataclass
+class Framework:
+
+    template_name: str
+    framework_name: str
+    ModelBuilderType: Type["ModelBuilder"]
+    operator_repository: OperatorRepository
+
+
+@dataclass
+class RandomizerParameters:
+    test_index: int
+    random_seed: int
+    test_device: TestDevice
+    framework: Framework
+    graph_builder_name: str
+
+
+# TODO load from file
+@dataclass
+class RandomizerGraph:
+    # parameters: RandomizerParameters
+    nodes: List[RandomizerNode] = field(default_factory=list)
+    input_nodes: List[RandomizerInputNode] = field(default_factory=list)
+    constant_nodes: List[RandomizerConstantNode] = field(default_factory=list)
+    # graph_builder: Optional[str] = None
+
+
+@dataclass
+class RandomizerConfig:
+    print_graph: bool = True
+    print_code: bool = False
+    run_test: bool = True
+    test_dir: str = "forge/test/random_tests"
+    save_tests: bool = False
+    save_failing_tests: bool = False
+    # build_model_from_code: bool = False  # TODO remove obsoleted
+    debug_shapes: bool = False
+    verify_shapes: bool = False
+    verification_timeout: int = 60
+    dim_min: int = 3
+    dim_max: int = 4
+    op_size_per_dim_min: int = 16
+    op_size_per_dim_max: int = 512
+    op_size_quantization: int = 1
+    microbatch_size_min: int = 1
+    microbatch_size_max: int = 8
+    num_of_nodes_min: int = 5
+    num_of_nodes_max: int = 10
+    num_fork_joins_max: int = 50
+    constant_input_rate: int = 20
+    same_inputs_percent_limit: int = 10
+
+
+@dataclass
+class RandomizerTestContext:
+    randomizer_config: RandomizerConfig
+    parameters: RandomizerParameters
+    # framework: Framework
+    # graph_builder: GraphBuilder
+    graph: Optional[RandomizerGraph]  # graph will be constructed later during test processing
+    test_name: str = "Default"
+
+    # random number generator for graph building
+    rng_graph: Optional[random.Random] = None
+    # random number generator for shape generation
+    rng_shape: Optional[random.Random] = None
+    # random number generator for parameters
+    rng_params: Optional[random.Random] = None
+
+
+class ModelBuilder:
+    '''
+    ModelBuilder is an interface that each framework should implement for instantiating model instances from a previously generated test model class.
+    '''
+
+    def build_model(self, test_context: RandomizerTestContext, GeneratedTestModel: Type) -> ForgeModule:
+        raise Exception("Method build_model() not implemented")
+
+
+class InvalidShape(Exception):
+
+    def __init__(self, message):
+        super().__init__(message)
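A quick look at how these datatypes compose, assuming `OperatorDefinition` normalizes its integer operand count into the `input_num_range` that `RandomizerNode.__post_init__` reads (the shape tuple is an arbitrary example):

    from forge.op_repo import pybuda_operator_repository
    from test.random.rgg import RandomizerNode

    add_op = next(op for op in pybuda_operator_repository.operators if op.name == "add")
    node = RandomizerNode(operator=add_op, output_shape=(1, 3, 32, 32))
    node.index = 1

    print(node.inputs)     # [None, None] - two open input slots for a binary operator
    print(node.node_info)  # "op1 add", assuming forge ops report is_operator rather than is_layer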
diff --git a/forge/test/random/rgg/frameworks.py b/forge/test/random/rgg/frameworks.py
new file mode 100644
index 000000000..a4265ad9d
--- /dev/null
+++ b/forge/test/random/rgg/frameworks.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Registry of frameworks and utilities for adjusting their operator repositories
+
+
+from enum import Enum
+
+from loguru import logger
+from typing import Tuple, Type
+from copy import copy
+
+from .datatypes import Framework, ModelBuilder
+from .shapes import OperatorShapes
+
+from .pybuda.model import PyBudaModelBuilder
+from .pytorch.model import PyTorchModelBuilder
+
+from forge.op_repo import pybuda_operator_repository
+from forge.op_repo import pytorch_operator_repository
+from forge.op_repo import OperatorDefinition
+from forge.op_repo import OperatorRepository
+
+
+class FrameworkTestUtils:
+
+    @classmethod
+    def copy_framework(cls, framework: Framework, framework_name: str = None, skip_operators: Tuple[str] = ()) -> Framework:
+        framework0 = framework
+        framework = copy(framework)
+        if framework_name is not None:
+            framework.framework_name = framework_name
+        framework.operator_repository = copy(framework.operator_repository)
+        cls.skip_operators(framework, skip_operators)
+        assert len(framework.operator_repository.operators) + len(skip_operators) == len(framework0.operator_repository.operators), "Operators count should match after skipping operators"
+        return framework
+
+    @classmethod
+    def skip_operators(cls, framework: Framework, skip_operators: Tuple[str] = ()) -> None:
+        initial_operator_count = len(framework.operator_repository.operators)
+        framework.operator_repository.operators = [op for op in framework.operator_repository.operators if op.name not in skip_operators]
+        logger.debug(f"Skipped num of operators for framework {framework.framework_name}: {initial_operator_count} -> {len(framework.operator_repository.operators)}")
+        assert len(framework.operator_repository.operators) + len(skip_operators) == initial_operator_count, "Operators count should match after skipping operators"
+
+    @classmethod
+    def allow_operators(cls, framework: Framework, allow_operators: Tuple[str] = ()) -> None:
+        initial_operator_count = len(framework.operator_repository.operators)
+        framework.operator_repository.operators = [op for op in framework.operator_repository.operators if op.name in allow_operators]
+        logger.debug(f"Allowed num of operators for framework {framework.framework_name}: {initial_operator_count} -> {len(framework.operator_repository.operators)}")
+        assert len(allow_operators) == len(framework.operator_repository.operators), "Operators count should match the allowed operators"
+
+    @classmethod
+    def copy_operator(cls, framework: Framework, operator_name: str) -> OperatorDefinition:
+        operators = framework.operator_repository.operators
+
+        i, operator = next(((i, operator) for i, operator in enumerate(operators) if operator.name == operator_name), (None, None))
+        if not operator:
+            return None
+
+        operator = copy(operator)
+        operators[i] = operator
+        return operator
+
+    @classmethod
+    def set_calc_input_shapes(cls, framework: Framework, allow_operators: Tuple[str] = ()) -> None:
+        ''' Implicitly set calc_input_shapes for all operators in the framework '''
+        logger.debug(f"Setting calc_input_shapes for framework {framework.framework_name}")
+        for operator in framework.operator_repository.operators:
+            function_name = f"{operator.name}_inputs"
+            if function_name in OperatorShapes.__dict__:
+                logger.debug(f"Found method {function_name} for {operator.name}")
+                operator.calc_input_shapes = OperatorShapes.__dict__[function_name]
+            else:
+                operator.calc_input_shapes = OperatorShapes.same_input_shapes
+
+
+class Frameworks(Enum):
+    ''' Register of all frameworks '''
+
+    @staticmethod
+    def build_framework(template_name: str, framework_name: str, ModelBuilderType: Type[ModelBuilder], operator_repository: OperatorRepository):
+        framework = Framework(
+            template_name=template_name,
+            framework_name=framework_name,
+            ModelBuilderType=ModelBuilderType,
+            operator_repository=operator_repository,
+        )
+
+        framework = FrameworkTestUtils.copy_framework(framework=framework)
+
+        FrameworkTestUtils.set_calc_input_shapes(framework)
+
+        return framework
+
+    PYBUDA = build_framework(
+        template_name="PyBuda",
+        framework_name="PyBuda",
+        ModelBuilderType=PyBudaModelBuilder,
+        operator_repository=pybuda_operator_repository,
+    )
+    PYTORCH = build_framework(
+        template_name="PyTorch",
+        framework_name="PyTorch",
+        ModelBuilderType=PyTorchModelBuilder,
+        operator_repository=pytorch_operator_repository,
+    )
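`copy_framework` is the intended way to derive a restricted framework without mutating the shared registry entry; the copies are shallow, so only the operator list is replaced. A small usage sketch:

    from test.random.rgg import Frameworks, FrameworkTestUtils

    framework = Frameworks.PYBUDA.value
    no_matmul = FrameworkTestUtils.copy_framework(
        framework, framework_name="PyBuda-no-matmul", skip_operators=("matmul",)
    )
    print(len(framework.operator_repository.operators),
          len(no_matmul.operator_repository.operators))  # original count, original count - 1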
input_node.out_value }}"){% else %}{{ input_node.out_value }}{% endif %}{% if not loop.last %}, {% endif %}{% endfor %}]{% if randomizer_config.debug_shapes %} + print(f"{{ node.layer_name }} inputs: {DebugUtils.format_tensors(inputs)}"){% endif %}{% if node.operator.is_layer %} + {{ node.out_value }} = self.{{ node.layer_name }}(inputs[0]){% else %} + {{ node.out_value }} = {% if node.operator.forward_code %}{{node.operator.forward_code()}}{% else %}{{ node.operator.full_name }}('{{ node.node_name }}', {{ forward_args(node=node) }}, {{ forward_kwargs(node=node) }}){% endif %}{% endif %}{% if randomizer_config.verify_shapes %} + assert {{ node.out_value }}.shape.dims == {{ reduce_microbatch_size(node.output_shape) }}, f"Unexpected output shape of {{ node.out_value }} { {{ node.out_value }}.shape } <> {{ reduce_microbatch_size(node.output_shape) }}"{% endif %}{% endfor %} + + return v +{% if test_format %} + +# @pytest.mark.xfail(reason="The model triggers a bug.") +def test_gen_model_{{ test_index }}_{{ random_seed }}(test_device): + + input_shapes = [ + {% for input_node in graph.input_nodes %}{{ input_node.input_shape }}, + {% endfor %}] + model = GeneratedTestModel_{{ test_index }}_{{ random_seed }}("pytest_gen_model_{{ test_id }}") + + verify_module(model, input_shapes) + +{% endif %} diff --git a/forge/test/random/rgg/pybuda/model.py b/forge/test/random/rgg/pybuda/model.py new file mode 100644 index 000000000..021d1717c --- /dev/null +++ b/forge/test/random/rgg/pybuda/model.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +# Building PyBuda models + + +from typing import Type +from loguru import logger + +from forge import ForgeModule + +from .. import RandomizerGraph, ModelBuilder, StrUtils + + +class PyBudaModelBuilder(ModelBuilder): + + def build_model(self, graph: RandomizerGraph, GeneratedTestModel: Type[ForgeModule]) -> ForgeModule: + module_name = f"gen_model_pytest_{StrUtils.test_id(graph)}" + pybuda_model = GeneratedTestModel(module_name) + return pybuda_model diff --git a/forge/test/random/rgg/pytorch/__init__.py b/forge/test/random/rgg/pytorch/__init__.py new file mode 100644 index 000000000..2332467ef --- /dev/null +++ b/forge/test/random/rgg/pytorch/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 diff --git a/forge/test/random/rgg/pytorch/generated_model.jinja2 b/forge/test/random/rgg/pytorch/generated_model.jinja2 new file mode 100644 index 000000000..36222160f --- /dev/null +++ b/forge/test/random/rgg/pytorch/generated_model.jinja2 @@ -0,0 +1,50 @@ +import torch +import forge +{% if test_format %} +import pytest +from forge.verify import VerifyConfig +from test.operators.utils.compat import verify_module +{% endif %}{% if randomizer_config.debug_shapes %} +from test.random.rgg import DebugUtils{% endif %} + +{# TODO replace empty new lines with spaces to keep formatting in pipeline #} +class GeneratedTestModel_{{ test_index }}_{{ random_seed }}(torch.nn.Module): + # graph_builder: {{ graph_builder_name }} + # id: {{ test_id }} + # params.test_index: {{ test_index }} + # params.random_seed: {{ random_seed}} + + def __init__(self): + super(GeneratedTestModel_{{ test_index }}_{{ random_seed }}, self).__init__() +{% for node in graph.nodes %}{% if node.operator.is_layer %} + self.{{ node.layer_name }} = {{ node.operator.full_name }}({{ constructor_kwargs(node=node) }}){% endif %}{% endfor %} + {% for constant_node in graph.constant_nodes 
%} + self.{{ constant_node.out_value }} = torch.randn({{ reduce_microbatch_size(constant_node.input_shape) }}){% endfor %} + + def forward(self{% for node in graph.input_nodes %}, + {{ node.out_value }}: torch.Tensor{% endfor %} + ) -> torch.Tensor: + {% for node in graph.nodes %} + + # shapes: {{ node.input_shapes }} -> {{ node.output_shape }} + inputs = [{% for input_node in node.inputs %}{% if input_node.constant %}self.{{ input_node.out_value }}{% else %}{{ input_node.out_value }}{% endif %}{% if not loop.last %}, {% endif %}{% endfor %}]{% if randomizer_config.debug_shapes %} + print(f"{{ node.layer_name }} inputs: {DebugUtils.format_tensors(inputs)}"){% endif %}{% if node.operator.is_layer %} + {{ node.out_value }} = self.{{ node.layer_name }}(inputs[0]){% else %} + {{ node.out_value }} = {% if node.operator.forward_code %}{{node.operator.forward_code()}}{% else %}{{ node.operator.full_name }}({{ forward_args(node=node) }}, {{ forward_kwargs(node=node) }}){% endif %}{% endif %}{% if randomizer_config.verify_shapes %} + assert {{ node.out_value }}.shape == {{ reduce_microbatch_size(node.output_shape) }}, f"Unexpected output shape of {{ node.out_value }} { {{ node.out_value }}.shape } <> {{ reduce_microbatch_size(node.output_shape) }}"{% endif %}{% endfor %} + + return v +{% if test_format %} + +# @pytest.mark.xfail(reason="The model triggers a bug.") +def test_gen_model_{{ test_index }}_{{ random_seed }}(test_device): + + input_shapes = [ + {% for input_node in graph.input_nodes %}{{ input_node.input_shape }}, + {% endfor %}] + pytorch_model = GeneratedTestModel_{{ test_index }}_{{ random_seed }}() + # model = forge.PyTorchModule("pytest_gen_model_{{ test_id }}", pytorch_model) + + verify_module(pytorch_model, input_shapes) + +{% endif %} diff --git a/forge/test/random/rgg/pytorch/model.py b/forge/test/random/rgg/pytorch/model.py new file mode 100644 index 000000000..adf889ae3 --- /dev/null +++ b/forge/test/random/rgg/pytorch/model.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +# Building PyTorch models + + +import torch + +from typing import Type +from loguru import logger + +from forge import PyTorchModule, ForgeModule + +from .. 
import RandomizerGraph, ModelBuilder, StrUtils
+
+
+class PyTorchModelBuilder(ModelBuilder):
+
+    def build_model(self, graph: RandomizerGraph, GeneratedTestModel: Type[torch.nn.Module]) -> ForgeModule:
+        pytorch_model = GeneratedTestModel()
+        # module_name = f"gen_model_pytest_{StrUtils.test_id(graph)}"
+        # pybuda_model = PyTorchModule(module_name, pytorch_model)
+        return pytorch_model
diff --git a/forge/test/random/rgg/shapes.py b/forge/test/random/rgg/shapes.py
new file mode 100644
index 000000000..92c348383
--- /dev/null
+++ b/forge/test/random/rgg/shapes.py
@@ -0,0 +1,283 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Calculation of input shapes from output shapes for the specified operator
+
+
+import random
+
+from loguru import logger
+from typing import List
+
+from .datatypes import TensorShape
+from .datatypes import RandomizerNode
+from .datatypes import InvalidShape
+from .datatypes import RandomizerTestContext
+from .datatypes import ShapeCalculationContext
+
+from .utils import RandomUtils
+
+
+class OperatorShapes:
+
+    @staticmethod
+    def same_input_shapes(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        input_num, output_shape = calculation_context.input_num, calculation_context.output_shape
+        # each input operand has the same shape as the output
+        return [output_shape for _ in range(input_num)]
+
+    @staticmethod
+    def linear_inputs(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        output_shape = calculation_context.output_shape
+        test_context: RandomizerTestContext = calculation_context.test_context
+        # linear layer changes the last dimension of the input shape
+        batch_shape = output_shape[:-1]
+        # the input's last dimension (in_features) is randomized independently of output_shape[-1]
+        n = randomize_size(len(batch_shape), test_context)
+        input_shapes = [batch_shape + (n,)]
+        return input_shapes
+
+    # FIXME: conv2d in PyTorch not working properly in all cases
+    @staticmethod
+    def conv2d_inputs(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        output_shape = calculation_context.output_shape
+        test_context: RandomizerTestContext = calculation_context.test_context
+        shape1 = output_shape[:1]
+        shape2 = output_shape[2:]
+        # the input channel count (output_shape[1]) is randomized independently of the output channels
+        n = randomize_size(len(shape1), test_context)
+        input_shapes = [shape1 + (n,) + shape2]
+        return input_shapes
+
+    @staticmethod
+    def matmul_inputs(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        output_shape = calculation_context.output_shape
+        test_context: RandomizerTestContext = calculation_context.test_context
+        batch_shape = output_shape[:-2]
+        m = output_shape[-2]
+        n = output_shape[-1]
+        # calculate the random inner dimension q shared by both operands
+        # TODO the dim index passed to randomize_size is wrong for the second operand
+        q = randomize_size(len(batch_shape) + 1, test_context)
+        input_shapes = [batch_shape + (m,q), batch_shape + (q,n)]
+        return input_shapes
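+
+    # Worked example for matmul_inputs (illustrative values, not part of the
+    # original patch): for output_shape (2, 4, 6), batch_shape is (2,), m = 4
+    # and n = 6; with a randomized inner dimension q = 8 the calculated input
+    # shapes are (2, 4, 8) and (2, 8, 6).
+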
+    @staticmethod
+    def interleave_inputs(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        # Interleave joins the input shapes along the specified axis
+        # It requires that the axis dimension is divisible by the number of inputs
+        input_num, output_shape = calculation_context.input_num, calculation_context.output_shape
+        forward_kwargs = calculation_context.forward_kwargs
+        axis = forward_kwargs["axis"]
+
+        if axis >= len(output_shape) or axis < 0:
+            axis %= len(output_shape)
+
+        logger.trace(f"Interleave axis = {axis} output_shape = {output_shape}")
+
+        shape1 = output_shape[:axis]
+        mid_size = output_shape[axis]
+        shape2 = output_shape[axis+1:]
+
+        if mid_size < input_num:
+            raise InvalidShape(f"Output shape {output_shape} is too small mid_size={mid_size} < input_num={input_num}")
+
+        if mid_size % input_num != 0:
+            raise InvalidShape(f"Output shape {output_shape} axis[{axis}]={mid_size} is not divisible by input_num={input_num}")
+
+        dim = mid_size // input_num
+
+        input_shapes = [shape1 + (dim,) + shape2 for _ in range(input_num)]
+        return input_shapes
+
+    @staticmethod
+    def concatenate_inputs(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        # Concatenate joins the input shapes along the specified axis
+        # It requires that the axis dimension can be split into input_num parts
+        input_num, output_shape = calculation_context.input_num, calculation_context.output_shape
+        test_context: RandomizerTestContext = calculation_context.test_context
+        rng_shape = test_context.rng_shape
+        forward_kwargs = calculation_context.forward_kwargs
+        axis = forward_kwargs["axis"]
+
+        if axis >= len(output_shape) or axis < 0:
+            axis %= len(output_shape)
+
+        logger.trace(f"Concatenate axis = {axis} output_shape = {output_shape}")
+
+        shape1 = output_shape[:axis]
+        mid_size = output_shape[axis]
+        shape2 = output_shape[axis+1:]
+
+        if mid_size < input_num:
+            raise InvalidShape(f"Output shape {output_shape} is too small mid_size={mid_size} < input_num={input_num}")
+
+        dims = []
+        for input_pos in range(input_num):
+            reserved_size = input_num - input_pos - 1
+            mid_range = mid_size - reserved_size
+            logger.trace(f"input_num = {input_num} mid_size = {mid_size} reserved_size = {reserved_size} mid_range = {mid_range}")
+            if mid_range <= 0:
+                raise InvalidShape(f"Output shape {output_shape} is too small mid_range={mid_range} <= 0")
+            if reserved_size == 0:
+                dim = mid_size
+            else:
+                # TODO quantize size
+                dim = rng_shape.randint(1, mid_range)
+            logger.trace(f"dim = {dim}")
+            mid_size -= dim
+            dims.append(dim)
+
+        input_shapes = [shape1 + (dim,) + shape2 for dim in dims]
+        return input_shapes
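+
+    # Worked example for concatenate_inputs (illustrative values, not part of
+    # the original patch): for output_shape (2, 3, 12) with axis = -1 and
+    # input_num = 3, mid_size is 12 and one possible random split is
+    # dims = [5, 4, 3], giving input shapes (2, 3, 5), (2, 3, 4) and (2, 3, 3).
+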
+    @staticmethod
+    def stack_inputs(calculation_context: ShapeCalculationContext) -> List[TensorShape]:
+        # Stack adds a new dimension at the specified axis
+        input_num, output_shape = calculation_context.input_num, calculation_context.output_shape
+        test_context: RandomizerTestContext = calculation_context.test_context
+        forward_kwargs = calculation_context.forward_kwargs
+        axis = forward_kwargs["axis"]
+
+        if len(output_shape) <= test_context.randomizer_config.dim_min:
+            raise InvalidShape(f"Output shape {output_shape} is too small len(output_shape)={len(output_shape)} <= dim_min={test_context.randomizer_config.dim_min}")
+        dim = output_shape[axis]
+        if dim != input_num:
+            raise InvalidShape(f"Mismatch of dim and input_num in output shape {output_shape}. dim={dim} != input_num={input_num}")
+        shape1 = output_shape[:axis]
+        shape2 = output_shape[axis+1:]
+
+        input_shapes = [shape1 + shape2 for _ in range(input_num)]
+        return input_shapes
+
+
+def randomize_size(dim: int, test_context: RandomizerTestContext) -> int:
+    '''Randomize the size of a new dimension based on the operand size range
+
+    Args:
+        dim (int): new dimension
+        test_context: RandomizerTestContext
+
+    Returns:
+        int: random size of a dimension
+    '''
+    rng_shape = test_context.rng_shape
+    randomizer_config = test_context.randomizer_config
+    op_size_min = randomizer_config.op_size_per_dim_min
+    op_size_max = randomizer_config.op_size_per_dim_max
+    quantization = randomizer_config.op_size_quantization
+
+    n = rng_shape.randint(op_size_min, op_size_max)
+    n = RandomUtils.quantize(n, quantization)
+    # logger.trace(f"Randomize size: dim = {dim}, quant = {quantization} -> {n}")
+
+    return n
+
+
+class AdjustParameters:
+    '''Adjust parameters for operators based on output shape'''
+    # TODO Introduce adjustment method in operator definition similar to calc_input_shapes
+
+    @staticmethod
+    def interleave_adjust(node: RandomizerNode, test_context: RandomizerTestContext) -> None:
+        '''Adjust parameters and input number for interleave based on output shape'''
+        rng_shape = test_context.rng_shape
+        input_num_range = node.operator.input_num_range
+
+        input_num = node.input_num
+        output_shape = node.output_shape
+        axis = node.forward_kwargs["axis"]
+
+        if len(output_shape) < 4:
+            raise InvalidShape(f"Output shape {node.output_shape} has len(output_shape)={len(output_shape)} < 4")
+
+        if axis != -3:
+            raise InvalidShape(f"Invalid axis={axis} for output shape {node.output_shape}")
+
+        mid_size = output_shape[axis]
+
+        logger.trace(f"Interleave axis = {axis} output_shape = {output_shape} mid_size = {mid_size} input_num = {input_num}")
+
+        if mid_size % input_num == 0:
+            # If the axis dimension is divisible by the input number, no need to recalculate
+            return
+
+        # Currently axis is required to be -3 so there is no need to change the axis
+        supported_axes = [(axis, node.output_shape[axis])]
+        # supported_axes = list(enumerate(node.output_shape))
+
+        for axis, mid_size in rng_shape.sample(supported_axes, len(supported_axes)):
+            for input_num in rng_shape.sample(range(input_num_range.operands_min, input_num_range.operands_max+1), input_num_range.operands_max - input_num_range.operands_min + 1):
+                if mid_size % input_num == 0:
+                    node.forward_kwargs["axis"] = axis
+                    node.input_num = input_num
+                    node.init_inputs()
+                    return
+
+        raise InvalidShape(f"No valid params found for output shape {node.output_shape}")
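+
+    # Illustrative note (not part of the original patch): for output_shape
+    # (1, 6, 4, 4) with axis = -3, mid_size is 6, so input_num is re-sampled
+    # until it divides 6 evenly (e.g. 2 or 3, if the operator's operand range
+    # allows those values).
+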
+    @staticmethod
+    def concatenate_adjust(node: RandomizerNode, test_context: RandomizerTestContext) -> None:
+        '''Adjust parameters and input number for concatenate based on output shape'''
+        rng_shape = test_context.rng_shape
+        input_num_range = node.operator.input_num_range
+
+        input_num = node.input_num
+        output_shape = node.output_shape
+        axis = node.forward_kwargs["axis"]
+
+        if not -len(output_shape) <= axis < len(output_shape):
+            axis = None  # must be recalculated
+
+        if axis is not None and axis % len(output_shape) == 0:
+            # Axis 0 is not supported
+            axis = None  # must be recalculated
+
+        if axis is not None:
+            # The axis may be valid; normalize it and check
+            axis %= len(output_shape)
+
+            mid_size = output_shape[axis]
+
+            if mid_size >= input_num:
+                # The axis is valid, no need to recalculate
+                return
+
+            # TODO global limit for number of operands
+            if input_num_range.operands_min <= mid_size <= input_num_range.operands_max:
+                # The axis is valid, but the number of inputs is too big
+                # Lower the number of inputs to fit the axis dimension
+                node.input_num = rng_shape.randint(input_num_range.operands_min, mid_size)
+                node.init_inputs()
+                return
+
+        # Try another axis
+        for axis, mid_size in rng_shape.sample(list(enumerate(node.output_shape)), len(node.output_shape)):
+            if axis % len(output_shape) == 0:
+                # Axis 0 is not supported
+                continue
+            if input_num_range.operands_min <= mid_size:
+                node.forward_kwargs["axis"] = axis
+                node.input_num = rng_shape.randint(input_num_range.operands_min, min(mid_size, input_num_range.operands_max))
+                node.init_inputs()
+                return
+
+        raise InvalidShape(f"No valid params found for output shape {node.output_shape}")
+
+    @staticmethod
+    def stack_adjust(node: RandomizerNode, test_context: RandomizerTestContext) -> None:
+        '''Adjust parameters and input number for stack based on output shape'''
+        input_num_range = node.operator.input_num_range
+        output_shape = node.output_shape
+        if len(output_shape) <= test_context.randomizer_config.dim_min:
+            raise InvalidShape(f"Output shape {output_shape} is too small len(output_shape)={len(output_shape)} <= dim_min={test_context.randomizer_config.dim_min}")
+        for axis, dim in enumerate(node.output_shape):
+            if axis == 0:
+                # Axis 0 is not supported
+                continue
+            if input_num_range.operands_min <= dim <= input_num_range.operands_max:
+                node.forward_kwargs["axis"] = axis
+                node.input_num = dim
+                node.init_inputs()
+                return
+        raise InvalidShape(f"No valid params found for output shape {node.output_shape}")
diff --git a/forge/test/random/rgg/utils.py b/forge/test/random/rgg/utils.py
new file mode 100644
index 000000000..f9198203a
--- /dev/null
+++ b/forge/test/random/rgg/utils.py
@@ -0,0 +1,300 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Utility functions
+
+
+import random
+import signal
+import time
+from typing import Callable, Generator, List, Dict
+from dataclasses import asdict
+from loguru import logger
+import re
+import yaml
+
+import torch
+import forge
+
+from forge.op_repo import OperatorParam, OperatorDefinition, OperatorParamNumber
+
+from .datatypes import TensorShape
+from .datatypes import RandomizerConfig, RandomizerTestContext, RandomizerNode, RandomizerGraph
+from .datatypes import NodeShapeCalculationContext
+
+
+class StrUtils:
+
+    @staticmethod
+    def kwargs_str(**kwargs):
+        s = ', '.join([f"{key}={value}" for key, value in kwargs.items()])
+        return s
+
+    @staticmethod
+    def args_str(*args):
+        s = ', '.join([f"{value}" for value in args])
+        if s:
+            s = ", " + s
+        return s
+
+    @staticmethod
+    def camel_case_to_snake_case(camel_case: str) -> str:
+        pattern = re.compile(r'(?<!^)(?=[A-Z])')
+        snake_case = re.sub(pattern, '_', camel_case).lower()
+        return snake_case
+
+    @staticmethod
+    def text_to_snake_case(text: str) -> str:
+        text = text.lower()
+        pattern = re.compile(r'\ +')
+        snake_case = re.sub(pattern, '_', text).lower()
+        return snake_case
+
+    @classmethod
+    def test_id(cls, test_context: RandomizerTestContext) -> str:
+        parameters = test_context.parameters
+        framework_name = cls.text_to_snake_case(parameters.framework.framework_name)
+        graph_builder_snake_case = cls.camel_case_to_snake_case(parameters.graph_builder_name)
+        test_name = cls.text_to_snake_case(test_context.test_name)
+        test_id = f"{framework_name}_{graph_builder_snake_case}_{test_name}_{parameters.test_index}_{parameters.random_seed}"
+        return test_id
+
+    @staticmethod
+    def nodes_to_str(nodes: List[RandomizerNode]) -> str:
+        '''Converts a list of nodes to a string representation
+        Used for debugging purposes (currently disabled, always returns an empty string)
+
+        Args:
+            nodes (List[RandomizerNode]): list of nodes
+
+        Returns:
+            str: string representation of nodes
+        '''
+        # TODO Very slow -> implement in a faster way
+        # nodes_str = "\n".join([f"  {node}" for node in nodes])
+        nodes_str = ""
+        return nodes_str
+
+
+class RandomUtils:
+
+    @classmethod
+    def random_value_for_param(cls, param: OperatorParam, rng_params: random.Random):
+        if isinstance(param, OperatorParamNumber):
+            return cls.random_value_for_number_param(param, rng_params)
+        else:
+            raise ValueError(f"Unsupported param type {type(param)}")
+
+    @classmethod
+    def random_value_for_number_param(cls, param: OperatorParamNumber, rng_params: random.Random):
+        # Returns an int or a float depending on param.type
+        # TODO: support open intervals
+        # TODO: store rng_params in test_context
+        if param.type == float:
+            return rng_params.uniform(param.min_value, param.max_value)
+        elif param.type == int:
+            return rng_params.randint(param.min_value, param.max_value)
+        else:
+            raise ValueError(f"Unsupported type {param.type}")
+
+    @classmethod
+    def constructor_kwargs(cls, operator: OperatorDefinition, constructor_kwargs: Dict[str, object], rng_params: random.Random) -> Dict:
+        return {param.name: cls.random_value_for_param(param, rng_params) if param.name not in constructor_kwargs else constructor_kwargs[param.name] for param in operator.constructor_params}
+
+    @classmethod
+    def forward_kwargs(cls, operator: OperatorDefinition, forward_kwargs: Dict[str, object], rng_params: random.Random) -> Dict:
+        return {param.name: cls.random_value_for_param(param, rng_params) if param.name not in forward_kwargs else forward_kwargs[param.name] for param in operator.forward_params}
+
+    @classmethod
+    def quantize(cls, value: int, quantization: int = 2) -> int:
+        '''Quantize the value to the nearest multiple of quantization
+
+        Args:
+            value (int): value to quantize
+            quantization (int, optional): quantization factor. Defaults to 2.
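+
+        Examples (illustrative, not in the original patch; Python's round()
+        rounds half to even):
+            quantize(5, 2) -> 4, quantize(7, 2) -> 8, quantize(1, 4) -> 4
+            (results are clamped to at least one quantization step)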
+ + Returns: + int: quantized value + ''' + # Using max to avoid quantizing to 0 + return max(round(value / quantization) * quantization, quantization) + + @classmethod + def random_shape(cls, + rng_shape: random.Random, + dim_min: int, + dim_max: int, + op_size_min: int, + op_size_max: int, + quantization: int, + microbatch_size_min: int, + microbatch_size_max: int, + ) -> TensorShape: + shape = [cls.quantize(rng_shape.randint(op_size_min, op_size_max), quantization) for _ in range(rng_shape.randint(dim_min - 1, dim_max - 1))] + microbatch_size = rng_shape.randint(microbatch_size_min, microbatch_size_max) + shape.insert(0, microbatch_size) + shape = tuple(shape) + + return shape + + @classmethod + def random_shape_from_config(cls, randomizer_config: RandomizerConfig, rng_shape: random.Random) -> TensorShape: + op_size_min = randomizer_config.op_size_per_dim_min + op_size_max = randomizer_config.op_size_per_dim_max + op_size_quantization = randomizer_config.op_size_quantization + + dim_min = randomizer_config.dim_min + dim_max = randomizer_config.dim_max + + microbatch_size_min = randomizer_config.microbatch_size_min + microbatch_size_max = randomizer_config.microbatch_size_max + + return cls.random_shape( + rng_shape, + dim_min=dim_min, + dim_max=dim_max, + op_size_min=op_size_min, + op_size_max=op_size_max, + quantization=op_size_quantization, + microbatch_size_min=microbatch_size_min, + microbatch_size_max=microbatch_size_max, + ) + + +class GraphUtils: + + @classmethod + def get_input_shapes(cls, graph: RandomizerGraph) -> List[TensorShape]: + input_shapes = [input_node.input_shape for input_node in graph.input_nodes] + return input_shapes + + @classmethod + def to_ops_str(cls, graph: RandomizerGraph) -> str: + ops = [node.name for node in graph.nodes] + ops_str = " -> ".join(ops) + return ops_str + + @classmethod + def short_description(cls, graph: RandomizerGraph): + return f"ops: ({cls.to_ops_str(graph)}) input_shapes: {cls.get_input_shapes(graph)}" + + # TODO support serialization/deserialization of RandomizerGraph + @classmethod + def to_str(cls, graph: RandomizerGraph): + graph_dict = asdict(graph) + # Serialize dictionary to YAML string + yaml_str = yaml.dump(graph_dict) + # yaml_str = json.dumps(graph.__dict__) + return yaml_str + + +class NodeUtils: + + @staticmethod + def is_previous_node(node: RandomizerNode, previous_node: RandomizerNode) -> bool: + return node.index == previous_node.index + 1 + + @classmethod + def num_of_open_inputs(cls, node: RandomizerNode) -> int: + return node.inputs.count(None) + + @classmethod + def num_of_closed_inputs(cls, node: RandomizerNode) -> int: + return node.input_num - cls.num_of_open_inputs(node) + + @classmethod + def is_open(cls, node: RandomizerNode) -> bool: + return cls.num_of_open_inputs(node) > 0 + + # TODO replace list with generator + @classmethod + def get_open_nodes(cls, nodes: List[RandomizerNode]) -> List[RandomizerNode]: + return [node for node in nodes if cls.is_open(node)] + + @classmethod + def has_open_input_with_input_shape(cls, node: RandomizerNode, input_shape: TensorShape) -> bool: + for i, open_input in enumerate(node.inputs): + if open_input is None: + if input_shape == node.input_shapes[i]: + return True + return False + + @classmethod + def get_open_input_indices(cls, node: RandomizerNode) -> Generator[int, None, None]: + for i, open_input in enumerate(node.inputs): + if open_input is None: + yield i + + # TODO replace list with generator + @classmethod + def get_open_nodes_with_input_shape(cls, nodes: 
List[RandomizerNode], input_shape: TensorShape) -> List[RandomizerNode]:
+        return [node for node in nodes if cls.is_open(node) and cls.has_open_input_with_input_shape(node, input_shape)]
+
+    @classmethod
+    def calc_input_shapes(cls, node: RandomizerNode, shape_calculation_context: NodeShapeCalculationContext) -> List[TensorShape]:
+        return node.operator.calc_input_shapes(shape_calculation_context)
+
+    @classmethod
+    def get_random_input_num(cls, node: RandomizerNode, test_context: RandomizerTestContext) -> int:
+        input_num_range = node.operator.input_num_range
+        return test_context.rng_graph.randint(input_num_range.operands_min, input_num_range.operands_max)
+
+    @classmethod
+    def init_random_inputs(cls, node: RandomizerNode, test_context: RandomizerTestContext) -> None:
+        node.input_num = cls.get_random_input_num(node, test_context)
+        node.init_inputs()
+
+
+class DebugUtils:
+
+    @classmethod
+    def format_tensors(cls, tensors: List[forge.Tensor]):
+        if isinstance(tensors[0], forge.Tensor):
+            format_tensor: Callable[[forge.Tensor], str] = lambda t: f'{t.data_format}:{t.shape}'
+        elif isinstance(tensors[0], torch.Tensor):
+            format_tensor: Callable[[forge.Tensor], str] = lambda t: f'{t.type()}:{t.shape}'
+        else:
+            # Fallback to avoid an UnboundLocalError for unexpected tensor types
+            format_tensor: Callable[[forge.Tensor], str] = lambda t: f'{type(t)}:{getattr(t, "shape", None)}'
+        return [format_tensor(t) for t in tensors]
+
+    @classmethod
+    def debug_inputs(cls, inputs: List[forge.Tensor]):
+        logger.info(f"inputs: {cls.format_tensors(inputs)}")
+
+
+class Timer:
+    '''Timer class to measure the duration of a code block'''
+
+    def __init__(self):
+        self.start_time = time.perf_counter()
+
+    def get_duration(self):
+        '''Calculate the duration of the code block in seconds'''
+        end_time = time.perf_counter()
+        duration = end_time - self.start_time
+        return duration
+
+
+class TimeoutException(Exception):
+    pass
+
+
+# Handler for timeout signal
+def timeout_handler(signum, frame):
+    raise TimeoutException
+
+
+# Decorator for time limiting
+def timeout(seconds):
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            # Set signal handler
+            signal.signal(signal.SIGALRM, timeout_handler)
+            # Set alarm
+            signal.alarm(seconds)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                # Shutdown alarm
+                signal.alarm(0)
+            return result
+        return wrapper
+    return decorator
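+
+
+# Example usage of the timeout decorator (illustrative sketch, not part of the
+# original patch; 'generate_and_verify' is a hypothetical function name):
+#
+#   @timeout(30)
+#   def generate_and_verify():
+#       ...  # long-running work; a TimeoutException is raised after 30 seconds
+#
+# Note: SIGALRM-based timeouts are POSIX-only and must run on the main thread.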
diff --git a/forge/test/random/test_graphs.py b/forge/test/random/test_graphs.py
new file mode 100644
index 000000000..0cde6553f
--- /dev/null
+++ b/forge/test/random/test_graphs.py
@@ -0,0 +1,230 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+# Test random graph configurations using the Random Graph Generator algorithm, targeting the PyBuda and PyTorch frameworks
+
+from enum import Enum
+import pytest
+
+from copy import copy
+
+from forge.op_repo import OperatorParamNumber
+
+from test.random.rgg import Frameworks
+from test.random.rgg import FrameworkTestUtils
+from test.random.rgg import RandomGraphAlgorithm
+from test.random.rgg import RandomizerConfig
+from test.random.rgg import process_test
+
+
+class FrameworksHealthy(Enum):
+    ''' Adjust repositories to test healthy operators '''
+
+    @staticmethod
+    def healthy_pybuda():
+        SKIP_OPERATORS = (
+            # Unary operators
+            "exp",  # pcc?
+            "sqrt",  # skip because it's failing for negative values
+            "cumsum",  # bug
+            "argmax",  # shape calc is wrong
+            "logical_not",  # bug
+            "dropout",  # pcc?
+            "tilizer",  # bug
+
+            # Binary operators
+            "divide",  # bug
+            "binary_stack",  # bug
+            "power",  # occasionally fails
+            "logical_and",  # bug
+
+            # Nary operators
+            "where",  # pcc?
+        )
+
+        framework = FrameworkTestUtils.copy_framework(Frameworks.PYBUDA.value, "Healthy PyBuda", SKIP_OPERATORS)
+
+        pow_operator = FrameworkTestUtils.copy_operator(framework, "pow")
+        if pow_operator:
+            pow_operator.forward_params = [
+                # float exponent is currently not supported due to issue #2592
+                # OperatorParamNumber("exponent", float, 0, 100),
+                # OperatorParamNumber("exponent", int, 0, 100),
+                OperatorParamNumber("exponent", int, 0, 4),  # pcc for higher numbers fails
+            ]
+
+        return framework
+
+    @staticmethod
+    def healthy_pytorch():
+        SKIP_OPERATORS = (
+            "sqrt",  # skip because it's failing for negative values
+            # "linear",
+            "conv2d",  # skip until calc_input_shapes is properly implemented
+        )
+
+        framework = FrameworkTestUtils.copy_framework(Frameworks.PYTORCH.value, "Healthy PyTorch", SKIP_OPERATORS)
+
+        return framework
+
+    PYBUDA = healthy_pybuda()
+    PYTORCH = healthy_pytorch()
+
+
+class FrameworksCustom(Enum):
+    ''' Adjust repositories to prepare custom framework configurations '''
+
+    @staticmethod
+    def pybuda_fork_joins():
+        SKIP_OPERATORS = (
+        )
+
+        framework = FrameworkTestUtils.copy_framework(Frameworks.PYBUDA.value, "PyBuda fork joins", SKIP_OPERATORS)
+
+        ALLOW_OPERATORS = (
+            "relu",
+            "tanh",
+            "add",
+            "matmul",
+        )
+
+        FrameworkTestUtils.allow_operators(framework, ALLOW_OPERATORS)
+
+        return framework
+
+    @staticmethod
+    def pybuda_nary():
+        SKIP_OPERATORS = (
+        )
+
+        framework = FrameworkTestUtils.copy_framework(Frameworks.PYBUDA.value, "PyBuda nary", SKIP_OPERATORS)
+
+        ALLOW_OPERATORS = (
+            # "relu",
+            "tanh",
+            "add",
+            "matmul",  # TODO consider skipping matmul to increase the chance for the stack operator
+            "interleave",
+            # "where",  # pcc?
+            "concatenate",
+            "stack",
+        )
+
+        FrameworkTestUtils.allow_operators(framework, ALLOW_OPERATORS)
+
+        return framework
+
+    PYBUDA_FORK_JOINS = pybuda_fork_joins()
+    PYBUDA_NARY = pybuda_nary()
+
+
+@pytest.mark.parametrize("framework", [
+    FrameworksHealthy.PYBUDA.value,
+])
+def test_random_graph_algorithm_pybuda(test_index, random_seeds, test_device, randomizer_config: RandomizerConfig, framework):
+    # adjust randomizer_config
+    randomizer_config = copy(randomizer_config)
+    # randomizer_config.debug_shapes = True
+    # randomizer_config.verify_shapes = True
+
+    # Uncomment the following randomizer_config values to override the default values
+    # randomizer_config.dim_min = 3
+    # randomizer_config.dim_max = 4
+    # randomizer_config.op_size_per_dim_min = 4
+    # # randomizer_config.op_size_per_dim_min = 16
+    # randomizer_config.op_size_per_dim_max = 8
+    # # randomizer_config.op_size_per_dim_max = 64
+    # # randomizer_config.op_size_per_dim_max = 256
+    # randomizer_config.microbatch_size_min = 1
+    # randomizer_config.microbatch_size_max = 8
+    # randomizer_config.num_of_nodes_min = 5
+    # randomizer_config.num_of_nodes_max = 10
+    # randomizer_config.num_fork_joins_max = 5
+
+    # TODO random_seed instead of random_seeds
+    random_seed = random_seeds[test_index]
+    process_test("Default", test_index, random_seed, test_device, randomizer_config, graph_builder_type=RandomGraphAlgorithm, framework=framework)
+
+
+@pytest.mark.parametrize("framework", [
+    FrameworksHealthy.PYTORCH.value,
+])
+def test_random_graph_algorithm_pytorch(test_index, random_seeds, test_device, randomizer_config: RandomizerConfig, framework):
+    # adjust randomizer_config
+    randomizer_config = copy(randomizer_config)
+    # randomizer_config.debug_shapes = True
+    # randomizer_config.verify_shapes = True
+
+    # Uncomment the following randomizer_config values to override the default values
+    # 
randomizer_config.dim_min = 4 + # randomizer_config.dim_max = 4 + # randomizer_config.op_size_per_dim_min = 4 + # # randomizer_config.op_size_per_dim_min = 16 + # randomizer_config.op_size_per_dim_max = 8 + # # randomizer_config.op_size_per_dim_max = 64 + # # randomizer_config.op_size_per_dim_max = 256 + # randomizer_config.microbatch_size_min = 1 + # randomizer_config.microbatch_size_max = 8 + # randomizer_config.num_of_nodes_min = 3 + # randomizer_config.num_of_nodes_max = 5 + # randomizer_config.num_fork_joins_max = 5 + + # TODO random_seed instead of random_seeds + random_seed = random_seeds[test_index] + process_test("Default", test_index, random_seed, test_device, randomizer_config, graph_builder_type=RandomGraphAlgorithm, framework=framework) + + +@pytest.mark.parametrize("framework", [ + FrameworksCustom.PYBUDA_FORK_JOINS.value, +]) +def ttest_random_graph_algorithm_pybuda_fork_joins(test_index, random_seeds, test_device, randomizer_config: RandomizerConfig, framework): + # adjust randomizer_config + randomizer_config = copy(randomizer_config) + # randomizer_config.debug_shapes = True + # randomizer_config.verify_shapes = True + randomizer_config.dim_min = 3 + randomizer_config.dim_max = 4 + randomizer_config.op_size_per_dim_min = 4 + # randomizer_config.op_size_per_dim_min = 16 + randomizer_config.op_size_per_dim_max = 8 + # randomizer_config.op_size_per_dim_max = 64 + # randomizer_config.op_size_per_dim_max = 256 + randomizer_config.microbatch_size_min = 1 + randomizer_config.microbatch_size_max = 8 + randomizer_config.num_of_nodes_min = 10 + randomizer_config.num_of_nodes_max = 15 + randomizer_config.num_fork_joins_max = 10 + + # TODO random_seed instead of random_seeds + random_seed = random_seeds[test_index] + process_test("Fork Joins", test_index, random_seed, test_device, randomizer_config, graph_builder_type=RandomGraphAlgorithm, framework=framework) + + +# @pytest.mark.xfail(reason="Nary operators are buggy") +@pytest.mark.parametrize("framework", [ + FrameworksCustom.PYBUDA_NARY.value, +]) +def ttest_random_graph_algorithm_pybuda_nary(test_index, random_seeds, test_device, randomizer_config: RandomizerConfig, framework): + # adjust randomizer_config + randomizer_config = copy(randomizer_config) + # randomizer_config.debug_shapes = True + # randomizer_config.verify_shapes = True + randomizer_config.dim_min = 3 + randomizer_config.dim_max = 4 + randomizer_config.op_size_per_dim_min = 2 # avoid failing tests with smaller dimensions? + # randomizer_config.op_size_per_dim_min = 4 + # randomizer_config.op_size_per_dim_min = 16 + randomizer_config.op_size_per_dim_max = 8 + # randomizer_config.op_size_per_dim_max = 64 + # randomizer_config.op_size_per_dim_max = 256 + randomizer_config.op_size_quantization = 2 + randomizer_config.microbatch_size_min = 1 + randomizer_config.microbatch_size_max = 8 + randomizer_config.num_of_nodes_min = 10 + randomizer_config.num_of_nodes_max = 15 + randomizer_config.num_fork_joins_max = 10 + + # TODO random_seed instead of random_seeds + random_seed = random_seeds[test_index] + process_test("Nary", test_index, random_seed, test_device, randomizer_config, graph_builder_type=RandomGraphAlgorithm, framework=framework)
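
Note (editorial sketch, not part of the patch): the pieces above compose as follows
when driving a single run by hand. The function below is illustrative; 'run_one' and
its literal seed are hypothetical names, 'test_device' must still be supplied by the
usual pytest fixture, and RandomizerConfig is assumed to expose the fields used
throughout this patch.

    from copy import copy

    from test.random.rgg import Frameworks, FrameworkTestUtils
    from test.random.rgg import RandomGraphAlgorithm, RandomizerConfig, process_test

    def run_one(test_device, randomizer_config: RandomizerConfig):
        # Restrict the PyBuda repository by skipping a couple of known-flaky operators
        framework = FrameworkTestUtils.copy_framework(
            Frameworks.PYBUDA.value, "Demo PyBuda", skip_operators=("sqrt", "cumsum"))
        # Narrow the graph size so failures stay easy to reproduce
        config = copy(randomizer_config)
        config.num_of_nodes_min = 3
        config.num_of_nodes_max = 5
        # Build and verify one random graph with a fixed seed
        process_test("Demo", 0, 42, test_device, config,
                     graph_builder_type=RandomGraphAlgorithm, framework=framework)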