Skip to content

Commit

Permalink
Fix index_select torch-mlir lowering issue
Browse files Browse the repository at this point in the history
  • Loading branch information
aviator19941 committed Mar 15, 2024
1 parent 8df7c2f commit f43f2c6
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,13 @@ def scale_model_input(
`torch.FloatTensor`:
A scaled input sample.
"""
dtype = sample.dtype
if self.step_index is None:
self._init_step_index(timestep)

sigma = self.sigmas.index_select(0, self.step_index)
sample = sample / ((sigma**2 + 1) ** 0.5)
sample = sample.to(dtype)

self.is_scale_input_called = True
return sample
Expand Down Expand Up @@ -351,7 +353,14 @@ def _init_step_index(self, timestep):

eq = torch.eq(self.timesteps, timestep)
index_candidates = torch.where(eq)[0]
index = torch.where(torch.scalar_tensor(torch.numel(index_candidates)) > 1, 1, 0)
# index_candidates = torch.where(self.timesteps == timestep)[0]
# index = torch.where(torch.scalar_tensor(torch.numel(index_candidates)) > 1, 1, 0)
a = torch.numel(index_candidates)
cond = torch.scalar_tensor(a)
one = torch.scalar_tensor(1, dtype=torch.int64)
zero = torch.scalar_tensor(0, dtype=torch.int64)
index = torch.where(cond > 1, one, zero)
index = index.unsqueeze(0)
step_index = index_candidates.index_select(0, index)
self._step_index = step_index

Expand Down

0 comments on commit f43f2c6

Please sign in to comment.