From 3a864dcba74352b336354ac9dbb2ec9bc25471db Mon Sep 17 00:00:00 2001 From: Vladimir Brkic Date: Tue, 18 Feb 2025 16:55:13 +0000 Subject: [PATCH] Operator repository for Forge and Torch --- forge/forge/op_repo/__init__.py | 4 + forge/forge/op_repo/datatypes.py | 5 + forge/forge/op_repo/forge_operators.py | 108 ++++++++++++++++++++++ forge/forge/op_repo/pytorch_operators.py | 111 +++++++++++++++++++++++ 4 files changed, 228 insertions(+) create mode 100644 forge/forge/op_repo/forge_operators.py create mode 100644 forge/forge/op_repo/pytorch_operators.py diff --git a/forge/forge/op_repo/__init__.py b/forge/forge/op_repo/__init__.py index 1f1596708..13a38b6e3 100644 --- a/forge/forge/op_repo/__init__.py +++ b/forge/forge/op_repo/__init__.py @@ -15,6 +15,8 @@ from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository from .datatypes import ShapeCalculationContext +from .forge_operators import forge_operator_repository +from .pytorch_operators import pytorch_operator_repository __ALL__ = [ "OperandNumInt", @@ -26,4 +28,6 @@ "OperatorDefinition", "OperatorRepository", "ShapeCalculationContext", + "forge_operator_repository", + "pytorch_operator_repository", ] diff --git a/forge/forge/op_repo/datatypes.py b/forge/forge/op_repo/datatypes.py index 86267e944..f5c5615bd 100644 --- a/forge/forge/op_repo/datatypes.py +++ b/forge/forge/op_repo/datatypes.py @@ -144,6 +144,11 @@ class OperatorRepository: def __init__(self, operators: List[OperatorDefinition]): self.operators = operators + # Check if operator names which are unique + names = [op.name for op in operators] + duplicates = set([name for name in names if names.count(name) > 1]) + if duplicates: + raise ValueError(f"Detected duplicate operator names: {duplicates}") def get_by_name(self, name: str): return [op for op in self.operators if op.name == name][0] diff --git a/forge/forge/op_repo/forge_operators.py b/forge/forge/op_repo/forge_operators.py new file mode 100644 index 000000000..37e23b907 --- /dev/null +++ b/forge/forge/op_repo/forge_operators.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +# Forge repostiory operators + + +from .datatypes import OperatorDefinition, OperatorRepository +from .datatypes import OperatorParamNumber + + +# TODO describe operand and shapes +_OPERATORS = [ + # Unary operators + OperatorDefinition("exp", "forge.op.Exp", 1), + OperatorDefinition("reciprocal", "forge.op.Reciprocal", 1), + OperatorDefinition("buffer", "forge.op.Buffer", 1), + OperatorDefinition("sqrt", "forge.op.Sqrt", 1), + OperatorDefinition("relu", "forge.op.Relu", 1), + OperatorDefinition( + "leaky_relu", + "forge.op.LeakyRelu", + 1, + forward_params=[ + OperatorParamNumber("alpha", float, 0, 100), + ], + ), + OperatorDefinition("nop", "forge.op.Identity", 1), + OperatorDefinition("gelu", "forge.op.Gelu", 1), + OperatorDefinition("log", "forge.op.Log", 1), + OperatorDefinition("sigmoid", "forge.op.Sigmoid", 1), + OperatorDefinition( + "clip", + "forge.op.Clip", + 1, + forward_params=[ + OperatorParamNumber("min", float, 0, 100), + OperatorParamNumber("max", float, 0, 100), + ], + ), + OperatorDefinition("sine", "forge.op.Sine", 1), + OperatorDefinition("cosine", "forge.op.Cosine", 1), + OperatorDefinition("abs", "forge.op.Abs", 1), + OperatorDefinition("tanh", "forge.op.Tanh", 1), + OperatorDefinition("cumsum", "forge.op.CumSum", 1), + OperatorDefinition("argmax", "forge.op.Argmax", 1), + OperatorDefinition("logical_not", "forge.op.LogicalNot", 1), + OperatorDefinition("dropout", "forge.op.Dropout", 1), + OperatorDefinition( + "pow", + "forge.op.Pow", + 1, + forward_params=[ + OperatorParamNumber("exponent", float, 0, 100), + ], + ), + OperatorDefinition("tilizer", "forge.op.Tilize", 1), + # Binary operators + OperatorDefinition("add", "forge.op.Add", 2), + OperatorDefinition("divide", "forge.op.Divide", 2), + OperatorDefinition("subtract", "forge.op.Subtract", 2), + OperatorDefinition("multiply", "forge.op.Multiply", 2), + OperatorDefinition("maximum", "forge.op.Max", 2), + OperatorDefinition("minimum", "forge.op.Min", 2), + OperatorDefinition("heaviside", "forge.op.Heaviside", 2), + OperatorDefinition("binary_stack", "forge.op.BinaryStack", 2), + OperatorDefinition("power", "forge.op.Power", 2), + OperatorDefinition("greater", "forge.op.Greater", 2), + OperatorDefinition("greater_equal", "forge.op.GreaterEqual", 2), + OperatorDefinition("less", "forge.op.Less", 2), + OperatorDefinition("less_equal", "forge.op.LessEqual", 2), + OperatorDefinition("equal", "forge.op.Equal", 2), + OperatorDefinition("not_equal", "forge.op.NotEqual", 2), + OperatorDefinition("logical_and", "forge.op.LogicalAnd", 2), + # Nary operators + OperatorDefinition("where", "forge.op.Where", 3), + # OperatorDefinition("index_copy", "forge.op.IndexCopy", 3), # Bug #2705 + OperatorDefinition( + "interleave", + "forge.op.Interleave", + (1, 10), + forward_params=[ + OperatorParamNumber("axis", int, -3, -3), + OperatorParamNumber("stride", int, 1, 1), + ], + ), + OperatorDefinition( + "concatenate", + "forge.op.Concatenate", + (1, 10), + forward_params=[ + OperatorParamNumber("axis", int, -10, 10), + ], + ), + OperatorDefinition( + "stack", + "forge.op.Stack", + (2, 4), + forward_params=[ + OperatorParamNumber("axis", int, 1, 10), + ], + ), + OperatorDefinition("matmul", "forge.op.Matmul", 2), + # OperatorDefinition("sparse_matmul", "forge.op.SparseMatmul", 2), +] + + +forge_operator_repository = OperatorRepository([op for op in _OPERATORS]) diff --git a/forge/forge/op_repo/pytorch_operators.py b/forge/forge/op_repo/pytorch_operators.py new file mode 100644 index 000000000..72b55d24d --- /dev/null +++ b/forge/forge/op_repo/pytorch_operators.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +# PyTorch repostiory operators + + +from .datatypes import OperatorDefinition, OperatorRepository +from .datatypes import OperatorParamNumber + + +# TODO describe operand and shapes +_OPERATORS = [ + OperatorDefinition( + "linear", + "torch.nn.Linear", + 1, + instantiate=True, + constructor_params=[ + OperatorParamNumber("in_features", int, 10, 50), + OperatorParamNumber("out_features", int, 10, 50), + ], + ), + OperatorDefinition( + "conv2d", + "torch.nn.Conv2d", + 1, + instantiate=True, + constructor_params=[ + OperatorParamNumber("in_channels", int, 10, 50), + OperatorParamNumber("out_channels", int, 10, 50), + OperatorParamNumber("kernel_size", int, 3, 3), + OperatorParamNumber("stride", int, 1, 1), + OperatorParamNumber("padding", int, 1, 1), + ], + ), + # Unary operators (implemented) + OperatorDefinition("relu", "torch.relu", 1), + OperatorDefinition("sqrt", "torch.sqrt", 1), + OperatorDefinition("reciprocal", "torch.reciprocal", 1), + OperatorDefinition("sigmoid", "torch.sigmoid", 1), + OperatorDefinition("abs", "torch.abs", 1), + OperatorDefinition("cos", "torch.cos", 1), + OperatorDefinition("exp", "torch.exp", 1), + OperatorDefinition("neg", "torch.neg", 1), + OperatorDefinition("rsqrt", "torch.rsqrt", 1), + OperatorDefinition("sin", "torch.sin", 1), + OperatorDefinition("square", "torch.square", 1), + OperatorDefinition("pow", "torch.pow", 1), + OperatorDefinition("clamp", "torch.clamp", 1), + OperatorDefinition("log", "torch.log", 1), + OperatorDefinition("log1p", "torch.log1p", 1), + OperatorDefinition("gelu", "torch.nn.functional.gelu", 1), + OperatorDefinition("leaky_relu", "torch.nn.functional.leaky_relu", 1), + # Unary operators (not implemented) + OperatorDefinition("acos", "torch.acos", 1), + OperatorDefinition("arccos", "torch.acos", 1), + OperatorDefinition("acosh", "torch.acosh", 1), + OperatorDefinition("arccosh", "torch.acosh", 1), + OperatorDefinition("angle", "torch.angle", 1), + OperatorDefinition("asin", "torch.asin", 1), + OperatorDefinition("arcsin", "torch.asin", 1), + OperatorDefinition("asinh", "torch.asinh", 1), + OperatorDefinition("arcsinh", "torch.asinh", 1), + OperatorDefinition("atan", "torch.atan", 1), + OperatorDefinition("arctan", "torch.atan", 1), + OperatorDefinition("atanh", "torch.atanh", 1), + OperatorDefinition("arctanh", "torch.atanh", 1), + OperatorDefinition("bitwise_not", "torch.bitwise_not", 1), + OperatorDefinition("ceil", "torch.ceil", 1), + OperatorDefinition("conj_physical", "torch.conj_physical", 1), + OperatorDefinition("cosh", "torch.cosh", 1), + OperatorDefinition("deg2rad", "torch.deg2rad", 1), + OperatorDefinition("digamma", "torch.digamma", 1), + OperatorDefinition("erf", "torch.erf", 1), + OperatorDefinition("erfc", "torch.erfc", 1), + OperatorDefinition("erfinv", "torch.erfinv", 1), + OperatorDefinition("exp2", "torch.exp2", 1), + OperatorDefinition("expm1", "torch.expm1", 1), + OperatorDefinition("fix", "torch.fix", 1), + OperatorDefinition("floor", "torch.floor", 1), + OperatorDefinition("frac", "torch.frac", 1), + OperatorDefinition("lgamma", "torch.lgamma", 1), + OperatorDefinition("log10", "torch.log10", 1), + OperatorDefinition("log2", "torch.log2", 1), + OperatorDefinition("logit", "torch.logit", 1), + OperatorDefinition("i0", "torch.i0", 1), + OperatorDefinition("isnan", "torch.isnan", 1), + OperatorDefinition("nan_to_num", "torch.nan_to_num", 1), + OperatorDefinition("positive", "torch.positive", 1), + OperatorDefinition("rad2deg", "torch.rad2deg", 1), + OperatorDefinition("round", "torch.round", 1), + OperatorDefinition("sign", "torch.sign", 1), + OperatorDefinition("sgn", "torch.sgn", 1), + OperatorDefinition("signbit", "torch.signbit", 1), + OperatorDefinition("sinc", "torch.sinc", 1), + OperatorDefinition("sinh", "torch.sinh", 1), + OperatorDefinition("tan", "torch.tan", 1), + OperatorDefinition("tanh", "torch.tanh", 1), + OperatorDefinition("trunc", "torch.trunc", 1), + # Binary operators + OperatorDefinition("add", "torch.add", 2), + OperatorDefinition("sub", "torch.sub", 2), + OperatorDefinition("mul", "torch.mul", 2), + OperatorDefinition("div", "torch.div", 2), + OperatorDefinition("ge", "torch.ge", 2), + OperatorDefinition("matmul", "torch.matmul", 2), +] + + +pytorch_operator_repository = OperatorRepository([op for op in _OPERATORS])