Skip to content

Commit

Permalink
support multi-label and arbitrary schedule length
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyu-guo committed Jan 23, 2025
1 parent ffeb0ed commit 5abf1e4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
26 changes: 13 additions & 13 deletions SmoothCache/calibration/calibration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
components_to_wrap: List[str],
calibration_lookahead: int = 3,
calibration_threshold: float = 0.0,
schedule_length: int = 50,
log_file: str = "calibration_schedule.json"
):
"""
Expand All @@ -51,6 +52,7 @@ def __init__(
self.components_to_wrap = components_to_wrap
self.calibration_lookahead = calibration_lookahead
self.calibration_threshold = calibration_threshold
self.schedule_length = schedule_length
self.log_file = log_file

# Tracking original forward methods
Expand Down Expand Up @@ -171,7 +173,6 @@ def wrapped_forward(*args, **kwargs):
# Compute error
error = rel_l1_loss(prev_output, current_output)
self.calibration_results[full_name][j].append(error)
print(len(self.calibration_results[full_name][j]))
\
# Update previous outputs
self.previous_layer_outputs[full_name].insert(0, current_output.detach().clone())
Expand All @@ -192,13 +193,13 @@ def generate_schedule(self):
and group all blocks that share that same component_full.
Each group yields 3 arrays: row0, row1, row2, averaged across all blocks,
then scanned to produce a 50-length schedule.
then scanned to produce the schedule.
Returns:
A dictionary like:
{
'attn1': [50-length schedule],
'mlp1': [50-length schedule],
'attn1': [schedule_length schedule],
'mlp1': [schedule_length schedule],
...
}
"""
Expand All @@ -225,7 +226,7 @@ def generate_schedule(self):

final_schedules = {}

# Step B: For each component_full, average row0, row1, row2, then produce 50-length schedule
# Step B: For each component_full, average row0, row1, row2, then produce schedule
for component_full, row_lists in component_to_rows.items():
row0_arrays = row_lists[0] # list of np arrays for row0
row1_arrays = row_lists[1] # row1
Expand Down Expand Up @@ -271,7 +272,7 @@ def _average_arrays(self, array_list):
def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold):
"""
Based on your scanning logic:
- We produce a schedule of length 50.
- We produce a schedule of length schedule_length.
- schedule[0] = 1
- For each i in [1..49], we check row2[i-1], row1[i-1], row0[i-1] (if in range)
* if row2[i-1] <= threshold => schedule i=1, i+1..i+3=0, skip i+4
Expand All @@ -281,11 +282,10 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold):
- Finally override schedule[49] = 1
"""

schedule = [None]*50
schedule = [None] * self.schedule_length
i = 0

while i < 50:
# breakpoint()
while i < self.schedule_length:
idx = i # to read from row2, row1, row0
used = False

Expand All @@ -294,22 +294,22 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold):
if row2_list[idx] <= threshold:
schedule[i] = 1
for skip_step in (i+1, i+2, i+3):
if skip_step < 50:
if skip_step < self.schedule_length:
schedule[skip_step] = 0
i += 4
used = True
if not used and idx < len(row1_list):
if row1_list[idx] <= threshold:
schedule[i] = 1
for skip_step in (i+1, i+2):
if skip_step < 50:
if skip_step < self.schedule_length:
schedule[skip_step] = 0
i += 3
used = True
if not used and idx < len(row0_list):
if row0_list[idx] <= threshold:
schedule[i] = 1
if i+1 < 50:
if i+1 < self.schedule_length:
schedule[i+1] = 0
i += 2
used = True
Expand All @@ -325,7 +325,7 @@ def _scan_3row_sublists(self, row0_list, row1_list, row2_list, threshold):
schedule[-1] = 1

# fill any None with 1
for x in range(50):
for x in range(self.schedule_length):
if schedule[x] is None:
schedule[x] = 1

Expand Down
3 changes: 3 additions & 0 deletions SmoothCache/calibration/diffuser_calibration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
model: nn.Module,
calibration_lookahead: int = 3,
calibration_threshold: float = 0.0,
schedule_length: int = 50,
log_file: str = "calibration_schedule.json"
):
"""
Expand All @@ -24,6 +25,7 @@ def __init__(
model (nn.Module): The model to wrap (e.g., pipe.transformer).
calibration_lookahead (int): Steps to look back for error calculation.
calibration_threshold (float): Cutoff L1 error value to enable caching.
schedule_length (int): Length of the generated schedule, 1:1 mapped to pipeline timesteps
log_file (str): Path to save the generated schedule JSON.
Raises:
Expand All @@ -41,5 +43,6 @@ def __init__(
components_to_wrap=components_to_wrap,
calibration_lookahead=calibration_lookahead,
calibration_threshold=calibration_threshold,
schedule_length=schedule_length,
log_file=log_file
)
13 changes: 9 additions & 4 deletions examples/run_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@ def main():
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")


num_inference_steps = 50
# Initialize calibration helper
calibration_helper = DiffuserCalibrationHelper(
model=pipe.transformer,
calibration_lookahead=3,
calibration_threshold=0.15,
schedule_length=num_inference_steps, # should be consistent with num_inference_steps below
log_file="smoothcache_schedules/diffuser_schedule.json"
)

# Enable calibration
calibration_helper.enable()

# Run pipeline normally
words = ["Labrador retriever"]
words = ["Labrador retriever", "combination lock", "cassette player"]

class_ids = pipe.get_label_ids(words)

generator = torch.manual_seed(33)
images = pipe(
class_labels=class_ids,
num_inference_steps=50,
num_inference_steps=num_inference_steps,
generator=generator
).images # Normal pipeline call

Expand All @@ -35,8 +40,8 @@ def main():

print("Calibration complete. Schedule saved to smoothcache_schedules/diffuser_schedule.json")

# breakpoint()
images[0].save('generated_image_cached.png')
for prompt, image in zip(words, images):
image.save(prompt + '.png')

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion smoothcache_schedules/diffuser_schedule.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"attn1": [1,0,1,0,0,1,0,1,0,1,0,1,0,1,1,1,1,1,1,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,1]
"attn1": [1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,1,1]
}

0 comments on commit 5abf1e4

Please sign in to comment.