Skip to content

Commit

Permalink
support arbitrary k
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyu-guo committed Jan 29, 2025
1 parent 5abf1e4 commit 4520e38
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 67 deletions.
140 changes: 73 additions & 67 deletions SmoothCache/calibration/calibration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,61 +184,70 @@ def wrapped_forward(*args, **kwargs):

def generate_schedule(self):
    """
    Generate a schedule for every component type (e.g., 'attn1', 'mlp1').

    Keys of ``self.calibration_results`` look like
        'transformer_blocks.0.attn1', 'transformer_blocks.1.attn1',
        'transformer_blocks.0.mlp1', ...
    The last dotted part (e.g., 'attn1', 'mlp1') is used as the component
    name, and all blocks that share it are grouped together.

    Each entry of ``self.calibration_results`` is a list of rows
    (row0, row1, ..., row(n-1)). Entries with fewer than
    ``self.calibration_lookahead`` rows are skipped as incomplete. For each
    component, every row is averaged across its blocks with
    ``self._average_arrays``; the averaged rows are then scanned by
    ``self._scan_nrows_sublists`` using ``self.calibration_threshold``.

    Returns:
        A dictionary mapping component name to its schedule, like:
        {
            'attn1': [schedule_length schedule],
            'mlp1': [schedule_length schedule],
            ...
        }
    """
    from collections import defaultdict

    # component name -> list of per-block row lists
    component_to_rows = defaultdict(list)

    # Step A: group the per-block rows by exact component name.
    for full_name, sublists in self.calibration_results.items():
        if len(sublists) < self.calibration_lookahead:
            # Incomplete calibration data for this block; ignore it.
            continue

        # e.g., 'transformer_blocks.0.attn1' => 'attn1'
        component_full = full_name.split('.')[-1]
        component_to_rows[component_full].append(sublists)

    final_schedules = {}

    # Step B: average each row across blocks, then scan to build the schedule.
    for component_full, sublist_groups in component_to_rows.items():
        # Only use rows that every block provides, so indexing below cannot
        # run past the end of a block that recorded fewer rows.
        num_rows = min(len(group) for group in sublist_groups)

        averaged_rows = [
            self._average_arrays([group[row_idx] for group in sublist_groups])
            for row_idx in range(num_rows)
        ]

        final_schedules[component_full] = self._scan_nrows_sublists(
            averaged_rows, self.calibration_threshold
        )

    return final_schedules

Expand Down Expand Up @@ -269,69 +278,66 @@ def _average_arrays(self, array_list):
avg_arr[i] = sum_vals[i] / count_vals[i]
return avg_arr.tolist()

def _scan_nrows_sublists(self, row_lists, threshold):
    """
    Scan an arbitrary number of rows to produce a 0/1 schedule.

    At each step ``i`` the rows are checked from the LAST row down to the
    first. The first row whose value at index ``i`` is <= ``threshold`` wins:
    step ``i`` is marked 1 (active), the following ``row_idx + 1`` steps are
    marked 0 (skipped), and scanning resumes right after them. If no row
    qualifies (or ``i`` is out of range for every row), step ``i`` is marked
    1 and scanning advances by one step.

    Parameters:
        row_lists (list of lists): One list of values per row, ordered from
            lowest priority (row 0, skips 1 step) to highest priority
            (last row, skips ``len(row_lists)`` steps).
        threshold (float): A row value at or below this triggers skipping.

    Returns:
        schedule (list): List of length ``self.schedule_length`` containing
            only 0s and 1s. The first and last entries are always 1.
    """
    length = self.schedule_length
    schedule = [None] * length
    i = 0

    while i < length:
        # Highest-priority (last) row first, so the longest skip wins.
        for row_idx in range(len(row_lists) - 1, -1, -1):
            row = row_lists[row_idx]
            if i < len(row) and row[i] <= threshold:
                num_skips = row_idx + 1
                schedule[i] = 1
                # Mark the skipped steps, clipped to the schedule bounds.
                for skip_step in range(i + 1, min(i + num_skips, length - 1) + 1):
                    schedule[skip_step] = 0
                # Continue with the step after the last skipped one.
                i += num_skips + 1
                break
        else:
            # Fallback: no row allowed skipping at this step.
            schedule[i] = 1
            i += 1

    # The first and last steps are always active.
    if length > 0:
        schedule[0] = 1
        schedule[-1] = 1

    # Defensive: any slot left unset defaults to active.
    return [1 if step is None else step for step in schedule]


def get_module_by_name(self, model, full_name):
"""
Utility to retrieve a module by full name.
Expand Down
1 change: 1 addition & 0 deletions examples/run_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def main():

# Run pipeline normally
words = ["Labrador retriever", "combination lock", "cassette player"]


class_ids = pipe.get_label_ids(words)

Expand Down

0 comments on commit 4520e38

Please sign in to comment.