Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Builder groupwise quant #1078

Draft
wants to merge 14 commits into
base: dev
Choose a base branch
from
Next Next commit
Fix (groupwise): correct log and groupdim
Giuseppe5 committed Oct 28, 2024
commit 76455504bfcad1f5e8011995cd0cb64890c84305
4 changes: 4 additions & 0 deletions src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
@@ -187,6 +187,8 @@ def __init__(
self.scaling_min_val = scaling_min_val
self.input_view_impl = input_view_impl
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
self.restrict_module = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
)

@brevitas.jit.script_method
def forward(
@@ -197,6 +199,8 @@ def forward(
threshold = torch.ones(1).type_as(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped) / threshold
# Apply log scaling
out = self.restrict_module(out)
# Scaling min val
out = self.restrict_clamp_scaling(out)
return out
3 changes: 2 additions & 1 deletion src/brevitas/quant/solver/common.py
Original file line number Diff line number Diff line change
@@ -178,7 +178,8 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None):
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output == ScalingPerOutputType.GROUP:
return group_dim + 1
reduce_dim = group_dim + 1 if group_dim != -1 else -1
return reduce_dim

@value
def keepdim(scaling_per_output):