diff --git a/SmoothCache/calibration/calibration_helper.py b/SmoothCache/calibration/calibration_helper.py index 6246ba6..5b0d304 100644 --- a/SmoothCache/calibration/calibration_helper.py +++ b/SmoothCache/calibration/calibration_helper.py @@ -184,61 +184,70 @@ def wrapped_forward(*args, **kwargs): def generate_schedule(self): """ - Generate schedules for each exact component name (e.g. 'attn1', 'mlp1', etc.) - using the 3-row scanning logic. - + Generate schedules for each exact component name (e.g., 'attn1', 'mlp1', etc.) + using n-row scanning logic, where n is arbitrary based on calibration_lookahead. + For example, if self.calibration_results has keys: 'transformer_blocks.0.attn1', 'transformer_blocks.1.attn1', 'transformer_blocks.0.mlp1', ... - we parse out the last part (e.g. 'attn1', 'mlp1') as `component_full`, + we parse out the last part (e.g., 'attn1', 'mlp1') as `component_full`, and group all blocks that share that same component_full. - Each group yields 3 arrays: row0, row1, row2, averaged across all blocks, + Each group yields n arrays: row0, row1, ..., row(n-1)_list, averaged across all blocks, then scanned to produce the schedule. Returns: A dictionary like: { - 'attn1': [schedule_length schedule], - 'mlp1': [schedule_length schedule], - ... + 'attn1': [schedule_length schedule], + 'mlp1': [schedule_length schedule], + ... } """ import numpy as np from collections import defaultdict - # Dictionary: component_name -> [ list_of_arrays_row0, list_of_arrays_row1, list_of_arrays_row2 ] - component_to_rows = defaultdict(lambda: [[], [], []]) + # Dictionary: component_name -> list of lists for each row + component_to_rows = defaultdict(list) - # Step A: Collect row0, row1, row2 arrays by exact component name + # Step A: Collect row arrays by exact component name for full_name, sublists in self.calibration_results.items(): if len(sublists) < self.calibration_lookahead: # skip if incomplete continue - # e.g. 'transformer_blocks.0.attn1' => component_full='attn1' - component_full = full_name.split('.')[-1] # e.g. 'attn1' + # e.g., 'transformer_blocks.0.attn1' => component_full='attn1' + component_full = full_name.split('.')[-1] # e.g., 'attn1' - # sublists is e.g. [row0_list, row1_list, row2_list] - # convert each to numpy array - for row_idx in range(len(sublists)): - arr = np.array(sublists[row_idx], dtype=float) - component_to_rows[component_full][row_idx].append(arr) + # sublists is a list of row arrays for this component + component_to_rows[component_full].append(sublists) final_schedules = {} - # Step B: For each component_full, average row0, row1, row2, then produce schedule - for component_full, row_lists in component_to_rows.items(): - row0_arrays = row_lists[0] # list of np arrays for row0 - row1_arrays = row_lists[1] # row1 - row2_arrays = row_lists[2] # row2 + breakpoint() + + # Step B: For each component_full, average rows and produce schedule + for component_full, sublist_groups in component_to_rows.items(): + # Assuming each sublist_group has the same number of rows (calibration_lookahead) + num_rows = len(sublist_groups[0]) if sublist_groups else 0 - avg0_list = self._average_arrays(row0_arrays) # length ~ 49 - avg1_list = self._average_arrays(row1_arrays) # length ~ 48 - avg2_list = self._average_arrays(row2_arrays) # length ~ 47 + # Average each row across all blocks + averaged_rows = [] + for row_idx in range(num_rows): + row_arrays = [sublist[row_idx] for sublist in sublist_groups] + avg_row_list = self._average_arrays(row_arrays) + averaged_rows.append(avg_row_list) - schedule = self._scan_3row_sublists(avg0_list, avg1_list, avg2_list, self.calibration_threshold) + + breakpoint() + print(averaged_rows[0]) + print(averaged_rows[1]) + print(averaged_rows[2]) + + schedule = self._scan_nrows_sublists(averaged_rows, self.calibration_threshold) final_schedules[component_full] = schedule + breakpoint() + print(final_schedules) return final_schedules @@ -269,69 +278,66 @@ def _average_arrays(self, array_list): avg_arr[i] = sum_vals[i] / count_vals[i] return avg_arr.tolist() - def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold): - """ - Based on your scanning logic: - - We produce a schedule of length schedule_length. - - schedule[0] = 1 - - For each i in [1..49], we check row2[i-1], row1[i-1], row0[i-1] (if in range) - * if row2[i-1] <= threshold => schedule i=1, i+1..i+3=0, skip i+4 - * else if row1[i-1] <= threshold => schedule i=1, i+1..i+2=0, skip i+3 - * else if row0[i-1] <= threshold => schedule i=1, i+1=0, skip i+2 - * else => schedule i=1, skip i+1 - - Finally override schedule[49] = 1 + def _scan_nrows_sublists(self, row_lists, threshold): """ + Scan through multiple rows (arbitrary number) in reverse order to produce a schedule. + + Parameters: + row_lists (list of lists): A list where each element is a row's list of values + ordered from highest priority to lowest. + threshold (float): The threshold value to check against. + Returns: + schedule (list): The generated schedule based on the scanning logic. + """ schedule = [None] * self.schedule_length i = 0 while i < self.schedule_length: - idx = i # to read from row2, row1, row0 + idx = i used = False - # check row2 if idx < len(row2_list) - if idx < len(row2_list): - if row2_list[idx] <= threshold: - schedule[i] = 1 - for skip_step in (i+1, i+2, i+3): - if skip_step < self.schedule_length: - schedule[skip_step] = 0 - i += 4 - used = True - if not used and idx < len(row1_list): - if row1_list[idx] <= threshold: + # Iterate through each row in reverse order (highest priority first) + for row_idx in range(len(row_lists)-1, -1, -1): + current_row_list = row_lists[row_idx] + if idx >= len(current_row_list): + continue # Skip if index is out of bounds for this row + + if current_row_list[idx] <= threshold: + # Activate the current step schedule[i] = 1 - for skip_step in (i+1, i+2): + + # Determine how many steps to skip based on the row priority + num_skips = row_idx + 1 # More skips for higher priority rows + skip_steps = [] + for s in range(1, num_skips + 1): + skip_step = i + s if skip_step < self.schedule_length: schedule[skip_step] = 0 - i += 3 - used = True - if not used and idx < len(row0_list): - if row0_list[idx] <= threshold: - schedule[i] = 1 - if i+1 < self.schedule_length: - schedule[i+1] = 0 - i += 2 + skip_steps.append(skip_step) + + # Move the index past the skipped steps + i += (num_skips + 1) # Move to the step after the last skip used = True + break + if not used: - # fallback => schedule[i]=1 + # Fallback: Activate current step without skipping schedule[i] = 1 i += 1 - # print(schedule) - # breakpoint() - # override schedule[49] = 1 - schedule[0] = 1 - schedule[-1] = 1 + # Override the first and last steps to be active + if self.schedule_length > 0: + schedule[0] = 1 + schedule[-1] = 1 - # fill any None with 1 + # Fill any remaining None values with 1 for x in range(self.schedule_length): if schedule[x] is None: schedule[x] = 1 return schedule - def get_module_by_name(self, model, full_name): """ Utility to retrieve a module by full name. diff --git a/examples/run_calibration.py b/examples/run_calibration.py index 2611362..d9cf833 100644 --- a/examples/run_calibration.py +++ b/examples/run_calibration.py @@ -25,6 +25,7 @@ def main(): # Run pipeline normally words = ["Labrador retriever", "combination lock", "cassette player"] + class_ids = pipe.get_label_ids(words)