Skip to content

Commit

Permalink
Operator repository for Forge and Torch
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrkicTT committed Feb 19, 2025
1 parent 513203c commit 3a864dc
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 0 deletions.
4 changes: 4 additions & 0 deletions forge/forge/op_repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange
from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository
from .datatypes import ShapeCalculationContext
from .forge_operators import forge_operator_repository
from .pytorch_operators import pytorch_operator_repository

__ALL__ = [
"OperandNumInt",
Expand All @@ -26,4 +28,6 @@
"OperatorDefinition",
"OperatorRepository",
"ShapeCalculationContext",
"forge_operator_repository",
"pytorch_operator_repository",
]
5 changes: 5 additions & 0 deletions forge/forge/op_repo/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class OperatorRepository:

def __init__(self, operators: List[OperatorDefinition]):
    """Store *operators* and reject duplicate operator names.

    Args:
        operators: Operator definitions; each definition's ``name`` must be
            unique within the repository.

    Raises:
        ValueError: If two or more definitions share the same name.
    """
    self.operators = operators
    # Detect duplicates in a single O(n) pass instead of the original
    # quadratic names.count(...) scan over the whole list per element.
    seen = set()
    duplicates = set()
    for op in operators:
        if op.name in seen:
            duplicates.add(op.name)
        seen.add(op.name)
    if duplicates:
        raise ValueError(f"Detected duplicate operator names: {duplicates}")

def get_by_name(self, name: str):
    """Return the first registered operator whose name equals *name*."""
    matching = [candidate for candidate in self.operators if candidate.name == name]
    # Subscripting the filtered list keeps the original contract:
    # an IndexError is raised when no operator carries this name.
    return matching[0]
108 changes: 108 additions & 0 deletions forge/forge/op_repo/forge_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

# Forge repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# Catalog of Forge operators exposed through the shared operator repository.
# Each entry records (name, fully qualified Forge op, operand count) plus any
# forward-call parameters with their allowed numeric ranges.
# TODO describe operand and shapes
_OPERATORS = [
    # Unary operators
    OperatorDefinition("exp", "forge.op.Exp", 1),
    OperatorDefinition("reciprocal", "forge.op.Reciprocal", 1),
    OperatorDefinition("buffer", "forge.op.Buffer", 1),
    OperatorDefinition("sqrt", "forge.op.Sqrt", 1),
    OperatorDefinition("relu", "forge.op.Relu", 1),
    OperatorDefinition(
        "leaky_relu",
        "forge.op.LeakyRelu",
        1,
        forward_params=[
            OperatorParamNumber("alpha", float, 0, 100),
        ],
    ),
    OperatorDefinition("nop", "forge.op.Identity", 1),
    OperatorDefinition("gelu", "forge.op.Gelu", 1),
    OperatorDefinition("log", "forge.op.Log", 1),
    OperatorDefinition("sigmoid", "forge.op.Sigmoid", 1),
    OperatorDefinition(
        "clip",
        "forge.op.Clip",
        1,
        forward_params=[
            OperatorParamNumber("min", float, 0, 100),
            OperatorParamNumber("max", float, 0, 100),
        ],
    ),
    OperatorDefinition("sine", "forge.op.Sine", 1),
    OperatorDefinition("cosine", "forge.op.Cosine", 1),
    OperatorDefinition("abs", "forge.op.Abs", 1),
    OperatorDefinition("tanh", "forge.op.Tanh", 1),
    OperatorDefinition("cumsum", "forge.op.CumSum", 1),
    OperatorDefinition("argmax", "forge.op.Argmax", 1),
    OperatorDefinition("logical_not", "forge.op.LogicalNot", 1),
    OperatorDefinition("dropout", "forge.op.Dropout", 1),
    OperatorDefinition(
        "pow",
        "forge.op.Pow",
        1,
        forward_params=[
            OperatorParamNumber("exponent", float, 0, 100),
        ],
    ),
    OperatorDefinition("tilizer", "forge.op.Tilize", 1),
    # Binary operators
    OperatorDefinition("add", "forge.op.Add", 2),
    OperatorDefinition("divide", "forge.op.Divide", 2),
    OperatorDefinition("subtract", "forge.op.Subtract", 2),
    OperatorDefinition("multiply", "forge.op.Multiply", 2),
    OperatorDefinition("maximum", "forge.op.Max", 2),
    OperatorDefinition("minimum", "forge.op.Min", 2),
    OperatorDefinition("heaviside", "forge.op.Heaviside", 2),
    OperatorDefinition("binary_stack", "forge.op.BinaryStack", 2),
    OperatorDefinition("power", "forge.op.Power", 2),
    OperatorDefinition("greater", "forge.op.Greater", 2),
    OperatorDefinition("greater_equal", "forge.op.GreaterEqual", 2),
    OperatorDefinition("less", "forge.op.Less", 2),
    OperatorDefinition("less_equal", "forge.op.LessEqual", 2),
    OperatorDefinition("equal", "forge.op.Equal", 2),
    OperatorDefinition("not_equal", "forge.op.NotEqual", 2),
    OperatorDefinition("logical_and", "forge.op.LogicalAnd", 2),
    # Nary operators
    OperatorDefinition("where", "forge.op.Where", 3),
    # OperatorDefinition("index_copy", "forge.op.IndexCopy", 3),  # Bug #2705
    OperatorDefinition(
        "interleave",
        "forge.op.Interleave",
        (1, 10),
        forward_params=[
            OperatorParamNumber("axis", int, -3, -3),
            OperatorParamNumber("stride", int, 1, 1),
        ],
    ),
    OperatorDefinition(
        "concatenate",
        "forge.op.Concatenate",
        (1, 10),
        forward_params=[
            OperatorParamNumber("axis", int, -10, 10),
        ],
    ),
    OperatorDefinition(
        "stack",
        "forge.op.Stack",
        (2, 4),
        forward_params=[
            OperatorParamNumber("axis", int, 1, 10),
        ],
    ),
    OperatorDefinition("matmul", "forge.op.Matmul", 2),
    # OperatorDefinition("sparse_matmul", "forge.op.SparseMatmul", 2),
]


# Shallow-copy the catalog so later mutation of _OPERATORS cannot silently
# change the repository; list(...) replaces the redundant identity
# comprehension [op for op in _OPERATORS].
forge_operator_repository = OperatorRepository(list(_OPERATORS))
111 changes: 111 additions & 0 deletions forge/forge/op_repo/pytorch_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

# PyTorch repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# Catalog of PyTorch operators exposed through the shared operator
# repository. Module-style operators (Linear, Conv2d) are instantiated with
# constructor parameters; functional operators are plain torch calls.
# TODO describe operand and shapes
_OPERATORS = [
    OperatorDefinition(
        "linear",
        "torch.nn.Linear",
        1,
        instantiate=True,
        constructor_params=[
            OperatorParamNumber("in_features", int, 10, 50),
            OperatorParamNumber("out_features", int, 10, 50),
        ],
    ),
    OperatorDefinition(
        "conv2d",
        "torch.nn.Conv2d",
        1,
        instantiate=True,
        constructor_params=[
            OperatorParamNumber("in_channels", int, 10, 50),
            OperatorParamNumber("out_channels", int, 10, 50),
            OperatorParamNumber("kernel_size", int, 3, 3),
            OperatorParamNumber("stride", int, 1, 1),
            OperatorParamNumber("padding", int, 1, 1),
        ],
    ),
    # Unary operators (implemented)
    OperatorDefinition("relu", "torch.relu", 1),
    OperatorDefinition("sqrt", "torch.sqrt", 1),
    OperatorDefinition("reciprocal", "torch.reciprocal", 1),
    OperatorDefinition("sigmoid", "torch.sigmoid", 1),
    OperatorDefinition("abs", "torch.abs", 1),
    OperatorDefinition("cos", "torch.cos", 1),
    OperatorDefinition("exp", "torch.exp", 1),
    OperatorDefinition("neg", "torch.neg", 1),
    OperatorDefinition("rsqrt", "torch.rsqrt", 1),
    OperatorDefinition("sin", "torch.sin", 1),
    OperatorDefinition("square", "torch.square", 1),
    OperatorDefinition("pow", "torch.pow", 1),
    OperatorDefinition("clamp", "torch.clamp", 1),
    OperatorDefinition("log", "torch.log", 1),
    OperatorDefinition("log1p", "torch.log1p", 1),
    OperatorDefinition("gelu", "torch.nn.functional.gelu", 1),
    OperatorDefinition("leaky_relu", "torch.nn.functional.leaky_relu", 1),
    # Unary operators (not implemented)
    OperatorDefinition("acos", "torch.acos", 1),
    OperatorDefinition("arccos", "torch.acos", 1),
    OperatorDefinition("acosh", "torch.acosh", 1),
    OperatorDefinition("arccosh", "torch.acosh", 1),
    OperatorDefinition("angle", "torch.angle", 1),
    OperatorDefinition("asin", "torch.asin", 1),
    OperatorDefinition("arcsin", "torch.asin", 1),
    OperatorDefinition("asinh", "torch.asinh", 1),
    OperatorDefinition("arcsinh", "torch.asinh", 1),
    OperatorDefinition("atan", "torch.atan", 1),
    OperatorDefinition("arctan", "torch.atan", 1),
    OperatorDefinition("atanh", "torch.atanh", 1),
    OperatorDefinition("arctanh", "torch.atanh", 1),
    OperatorDefinition("bitwise_not", "torch.bitwise_not", 1),
    OperatorDefinition("ceil", "torch.ceil", 1),
    OperatorDefinition("conj_physical", "torch.conj_physical", 1),
    OperatorDefinition("cosh", "torch.cosh", 1),
    OperatorDefinition("deg2rad", "torch.deg2rad", 1),
    OperatorDefinition("digamma", "torch.digamma", 1),
    OperatorDefinition("erf", "torch.erf", 1),
    OperatorDefinition("erfc", "torch.erfc", 1),
    OperatorDefinition("erfinv", "torch.erfinv", 1),
    OperatorDefinition("exp2", "torch.exp2", 1),
    OperatorDefinition("expm1", "torch.expm1", 1),
    OperatorDefinition("fix", "torch.fix", 1),
    OperatorDefinition("floor", "torch.floor", 1),
    OperatorDefinition("frac", "torch.frac", 1),
    OperatorDefinition("lgamma", "torch.lgamma", 1),
    OperatorDefinition("log10", "torch.log10", 1),
    OperatorDefinition("log2", "torch.log2", 1),
    OperatorDefinition("logit", "torch.logit", 1),
    OperatorDefinition("i0", "torch.i0", 1),
    OperatorDefinition("isnan", "torch.isnan", 1),
    OperatorDefinition("nan_to_num", "torch.nan_to_num", 1),
    OperatorDefinition("positive", "torch.positive", 1),
    OperatorDefinition("rad2deg", "torch.rad2deg", 1),
    OperatorDefinition("round", "torch.round", 1),
    OperatorDefinition("sign", "torch.sign", 1),
    OperatorDefinition("sgn", "torch.sgn", 1),
    OperatorDefinition("signbit", "torch.signbit", 1),
    OperatorDefinition("sinc", "torch.sinc", 1),
    OperatorDefinition("sinh", "torch.sinh", 1),
    OperatorDefinition("tan", "torch.tan", 1),
    OperatorDefinition("tanh", "torch.tanh", 1),
    OperatorDefinition("trunc", "torch.trunc", 1),
    # Binary operators
    OperatorDefinition("add", "torch.add", 2),
    OperatorDefinition("sub", "torch.sub", 2),
    OperatorDefinition("mul", "torch.mul", 2),
    OperatorDefinition("div", "torch.div", 2),
    OperatorDefinition("ge", "torch.ge", 2),
    OperatorDefinition("matmul", "torch.matmul", 2),
]


# Shallow-copy the catalog so later mutation of _OPERATORS cannot silently
# change the repository; list(...) replaces the redundant identity
# comprehension [op for op in _OPERATORS].
pytorch_operator_repository = OperatorRepository(list(_OPERATORS))

0 comments on commit 3a864dc

Please sign in to comment.