Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Dec 23, 2024
1 parent fb2f44f commit 725a75c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class OperatorSetNames(Enum):
OPSET_CONV = "Conv"
OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
OPSET_CONV_TRANSPOSE = "ConvTraspose"
OPSET_CONV_TRANSPOSE = "ConvTranspose"
OPSET_FULLY_CONNECTED = "FullyConnected"
OPSET_CONCATENATE = "Concatenate"
OPSET_STACK = "Stack"
Expand All @@ -41,7 +41,8 @@ class OperatorSetNames(Enum):
OPSET_SUB = "Sub"
OPSET_MUL = "Mul"
OPSET_DIV = "Div"
OPSET_MIN_MAX = "MinMax"
OPSET_MIN = "Min"
OPSET_MAX = "Max"
OPSET_PRELU = "PReLU"
OPSET_SWISH = "Swish"
OPSET_SIGMOID = "Sigmoid"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def __init__(self):
OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
OperatorSetNames.OPSET_MIN_MAX.value: [tf.math.minimum, tf.math.maximum, Minimum, Maximum],
OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
OperatorSetNames.OPSET_PRELU.value: [PReLU],
OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self):
OperatorSetNames.OPSET_SUB.value: [operator.sub, sub, subtract],
OperatorSetNames.OPSET_MUL.value: [operator.mul, mul, multiply],
OperatorSetNames.OPSET_DIV.value: [operator.truediv, div, divide],
OperatorSetNames.OPSET_MIN_MAX.value: [minimum, maximum],
OperatorSetNames.OPSET_MIN.value: [minimum],
OperatorSetNames.OPSET_MAX.value: [maximum],
OperatorSetNames.OPSET_PRELU.value: [PReLU, prelu],
OperatorSetNames.OPSET_SWISH.value: [SiLU, silu],
OperatorSetNames.OPSET_SIGMOID.value: [Sigmoid, sigmoid, F.sigmoid],
Expand Down Expand Up @@ -86,4 +87,5 @@ def __init__(self):
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
BIAS_ATTR: DefaultDict(default_value=BIAS)}
self._opset2attr_mapping = {OperatorSetNames.OPSET_CONV.value: pytorch_linear_attr_mapping,
OperatorSetNames.OPSET_CONV_TRANSPOSE.value: pytorch_linear_attr_mapping,
OperatorSetNames.OPSET_FULLY_CONNECTED.value: pytorch_linear_attr_mapping}

0 comments on commit 725a75c

Please sign in to comment.