From 5abf1e4d7a1dcb9762b7f508725bb787ec029ab5 Mon Sep 17 00:00:00 2001 From: Ziyu Guo Date: Thu, 23 Jan 2025 23:58:31 +0000 Subject: [PATCH] support multi-label and arbitrary schedule length --- SmoothCache/calibration/calibration_helper.py | 26 +++++++++---------- .../diffuser_calibration_helper.py | 3 +++ examples/run_calibration.py | 13 +++++++--- smoothcache_schedules/diffuser_schedule.json | 2 +- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/SmoothCache/calibration/calibration_helper.py b/SmoothCache/calibration/calibration_helper.py index b500ec0..6246ba6 100644 --- a/SmoothCache/calibration/calibration_helper.py +++ b/SmoothCache/calibration/calibration_helper.py @@ -34,6 +34,7 @@ def __init__( components_to_wrap: List[str], calibration_lookahead: int = 3, calibration_threshold: float = 0.0, + schedule_length: int = 50, log_file: str = "calibration_schedule.json" ): """ @@ -51,6 +52,7 @@ def __init__( self.components_to_wrap = components_to_wrap self.calibration_lookahead = calibration_lookahead self.calibration_threshold = calibration_threshold + self.schedule_length = schedule_length self.log_file = log_file # Tracking original forward methods @@ -171,7 +173,6 @@ def wrapped_forward(*args, **kwargs): # Compute error error = rel_l1_loss(prev_output, current_output) self.calibration_results[full_name][j].append(error) - print(len(self.calibration_results[full_name][j])) \ # Update previous outputs self.previous_layer_outputs[full_name].insert(0, current_output.detach().clone()) @@ -192,13 +193,13 @@ def generate_schedule(self): and group all blocks that share that same component_full. Each group yields 3 arrays: row0, row1, row2, averaged across all blocks, - then scanned to produce a 50-length schedule. + then scanned to produce the schedule. Returns: A dictionary like: { - 'attn1': [50-length schedule], - 'mlp1': [50-length schedule], + 'attn1': [schedule_length schedule], + 'mlp1': [schedule_length schedule], ... } """ @@ -225,7 +226,7 @@ def generate_schedule(self): final_schedules = {} - # Step B: For each component_full, average row0, row1, row2, then produce 50-length schedule + # Step B: For each component_full, average row0, row1, row2, then produce schedule for component_full, row_lists in component_to_rows.items(): row0_arrays = row_lists[0] # list of np arrays for row0 row1_arrays = row_lists[1] # row1 @@ -271,7 +272,7 @@ def _average_arrays(self, array_list): def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold): """ Based on your scanning logic: - - We produce a schedule of length 50. + - We produce a schedule of length schedule_length. - schedule[0] = 1 - For each i in [1..49], we check row2[i-1], row1[i-1], row0[i-1] (if in range) * if row2[i-1] <= threshold => schedule i=1, i+1..i+3=0, skip i+4 @@ -281,11 +282,10 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold): - Finally override schedule[49] = 1 """ - schedule = [None]*50 + schedule = [None] * self.schedule_length i = 0 - while i < 50: - # breakpoint() + while i < self.schedule_length: idx = i # to read from row2, row1, row0 used = False @@ -294,7 +294,7 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold): if row2_list[idx] <= threshold: schedule[i] = 1 for skip_step in (i+1, i+2, i+3): - if skip_step < 50: + if skip_step < self.schedule_length: schedule[skip_step] = 0 i += 4 used = True @@ -302,14 +302,14 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold): if row1_list[idx] <= threshold: schedule[i] = 1 for skip_step in (i+1, i+2): - if skip_step < 50: + if skip_step < self.schedule_length: schedule[skip_step] = 0 i += 3 used = True if not used and idx < len(row0_list): if row0_list[idx] <= threshold: schedule[i] = 1 - if i+1 < 50: + if i+1 < self.schedule_length: schedule[i+1] = 0 i += 2 used = True @@ -325,7 +325,7 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold): schedule[-1] = 1 # fill any None with 1 - for x in range(50): + for x in range(self.schedule_length): if schedule[x] is None: schedule[x] = 1 diff --git a/SmoothCache/calibration/diffuser_calibration_helper.py b/SmoothCache/calibration/diffuser_calibration_helper.py index c4281af..68d4da5 100644 --- a/SmoothCache/calibration/diffuser_calibration_helper.py +++ b/SmoothCache/calibration/diffuser_calibration_helper.py @@ -15,6 +15,7 @@ def __init__( model: nn.Module, calibration_lookahead: int = 3, calibration_threshold: float = 0.0, + schedule_length: int = 50, log_file: str = "calibration_schedule.json" ): """ @@ -24,6 +25,7 @@ def __init__( model (nn.Module): The model to wrap (e.g., pipe.transformer). calibration_lookahead (int): Steps to look back for error calculation. calibration_threshold (float): Cutoff L1 error value to enable caching. + schedule_length (int): Length of the generated schedule, 1:1 mapped to pipeline timesteps log_file (str): Path to save the generated schedule JSON. Raises: @@ -41,5 +43,6 @@ def __init__( components_to_wrap=components_to_wrap, calibration_lookahead=calibration_lookahead, calibration_threshold=calibration_threshold, + schedule_length=schedule_length, log_file=log_file ) diff --git a/examples/run_calibration.py b/examples/run_calibration.py index e08a3e7..2611362 100644 --- a/examples/run_calibration.py +++ b/examples/run_calibration.py @@ -9,11 +9,14 @@ def main(): pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) pipe = pipe.to("cuda") + + num_inference_steps = 50 # Initialize calibration helper calibration_helper = DiffuserCalibrationHelper( model=pipe.transformer, calibration_lookahead=3, calibration_threshold=0.15, + schedule_length=num_inference_steps, # should be consistent with num_inference_steps below log_file="smoothcache_schedules/diffuser_schedule.json" ) @@ -21,12 +24,14 @@ def main(): calibration_helper.enable() # Run pipeline normally - words = ["Labrador retriever"] + words = ["Labrador retriever", "combination lock", "cassette player"] + class_ids = pipe.get_label_ids(words) + generator = torch.manual_seed(33) images = pipe( class_labels=class_ids, - num_inference_steps=50, + num_inference_steps=num_inference_steps, generator=generator ).images # Normal pipeline call @@ -35,8 +40,8 @@ def main(): print("Calibration complete. Schedule saved to smoothcache_schedules/diffuser_schedule.json") - # breakpoint() - images[0].save('generated_image_cached.png') + for prompt, image in zip(words, images): + image.save(prompt + '.png') if __name__ == "__main__": main() diff --git a/smoothcache_schedules/diffuser_schedule.json b/smoothcache_schedules/diffuser_schedule.json index ccdcf1b..a274a0e 100644 --- a/smoothcache_schedules/diffuser_schedule.json +++ b/smoothcache_schedules/diffuser_schedule.json @@ -1,3 +1,3 @@ { - "attn1": [1,0,1,0,0,1,0,1,0,1,0,1,0,1,1,1,1,1,1,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,1] + "attn1": [1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,1,1] }