Commit ba386d6: fix tests

irenaby committed Aug 14, 2024
1 parent 5267c91
Showing 1 changed file with 5 additions and 4 deletions.
@@ -3,7 +3,6 @@
 import unittest
 import torch
 from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
-from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer
 from torch.nn import Conv2d
 import numpy as np

@@ -12,6 +11,8 @@
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
 from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
+from model_compression_toolkit.gptq.pytorch.quantizer.activation.ste_activation import \
+    STEActivationSymmetricGPTQTrainableQuantizer
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc
 from torch.fx import symbolic_trace
 from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters
@@ -73,7 +74,7 @@ def test_adding_holder_instead_quantize_wrapper(self):
         # check that 4 activation quantization holders where generated
         self.assertTrue(len(activation_quantization_holders_in_model) == 3)
         for a in activation_quantization_holders_in_model:
-            self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
+            self.assertTrue(isinstance(a.activation_holder_quantizer, STEActivationSymmetricGPTQTrainableQuantizer))
         for name, module in gptq_model.named_modules():
             if isinstance(module, PytorchQuantizationWrapper):
                 self.assertTrue(len(module.weights_quantizers) > 0)
@@ -87,7 +88,7 @@ def test_adding_holder_after_relu(self):
         # check that 3 activation quantization holders where generated
         self.assertTrue(len(activation_quantization_holders_in_model) == 3)
         for a in activation_quantization_holders_in_model:
-            self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
+            self.assertTrue(isinstance(a.activation_holder_quantizer, STEActivationSymmetricGPTQTrainableQuantizer))
         for name, module in gptq_model.named_modules():
             if isinstance(module, PytorchQuantizationWrapper):
                 self.assertTrue(len(module.weights_quantizers) > 0)
@@ -102,7 +103,7 @@ def test_adding_holders_after_reuse(self):
         # check that 4 activation quantization holders where generated
         self.assertTrue(len(activation_quantization_holders_in_model) == 3)
         for a in activation_quantization_holders_in_model:
-            self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
+            self.assertTrue(isinstance(a.activation_holder_quantizer, STEActivationSymmetricGPTQTrainableQuantizer))
         for name, module in gptq_model.named_modules():
             if isinstance(module, PytorchQuantizationWrapper):
                 self.assertTrue(len(module.weights_quantizers) > 0)
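With this change, the three GPTQ tests assert that each PytorchActivationQuantizationHolder in the built GPTQ model carries a trainable STEActivationSymmetricGPTQTrainableQuantizer rather than the inferable ActivationPOTInferableQuantizer. Below is a minimal sketch of the check pattern the tests repeat; the helper function name is hypothetical and not part of this commit, and only identifiers already appearing in the diff are used.

from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
from model_compression_toolkit.gptq.pytorch.quantizer.activation.ste_activation import \
    STEActivationSymmetricGPTQTrainableQuantizer


def check_gptq_model_quantizers(gptq_model, expected_holders=3):
    # Hypothetical helper illustrating the assertions in the updated tests.
    holders = [m for m in gptq_model.modules()
               if isinstance(m, PytorchActivationQuantizationHolder)]
    assert len(holders) == expected_holders
    for holder in holders:
        # After this commit, each holder is expected to hold the trainable
        # GPTQ STE quantizer instead of the inferable POT quantizer.
        assert isinstance(holder.activation_holder_quantizer,
                          STEActivationSymmetricGPTQTrainableQuantizer)
    for _, module in gptq_model.named_modules():
        if isinstance(module, PytorchQuantizationWrapper):
            assert len(module.weights_quantizers) > 0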
