implement consistency flow matching
yxlllc committed Oct 10, 2024
1 parent b93e925 commit 017ea03
Showing 6 changed files with 41 additions and 13 deletions.
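An editorial summary of what the hunks below implement (notation ours, not from the commit): on the straight noising path $x_t = x_0 + t\,(x_1 - x_0)$ from noise $x_0$ to data $x_1$, the network $v_\theta$ predicts a velocity, and

$$f_\theta(x_t, t) = x_t + (1 - t)\, v_\theta(x_t, t)$$

is its one-step extrapolation to the data endpoint at $t = 1$. Training now samples two nearby times $t_a < t_b$ on the same path and minimizes the mean-squared differences

$$\mathcal{L} = \frac{\lambda_f}{\Delta t^2}\,\big\| f_\theta(x_{t_a}, t_a) - f_\theta(x_{t_b}, t_b) \big\|^2 + \lambda_v\,\big\| v_\theta(x_{t_a}, t_a) - v_\theta(x_{t_b}, t_b) \big\|^2,$$

with the $t_b$ branch treated as a stop-gradient target. $\Delta t$ (`consistency_delta_t`), $\lambda_f$ (`consistency_lambda_f`), and $\lambda_v$ (`consistency_lambda_v`) are the new hyperparameters added to the configs.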
configs/config_naivev2reflow.yaml (3 additions, 0 deletions)
@@ -47,6 +47,9 @@ model:
   loss_type: 'l2' # 'l1', 'l2' or 'l2_lognorm'
   consistency: false
   consistency_only: true
+  consistency_delta_t: 0.1
+  consistency_lambda_f: 1.0
+  consistency_lambda_v: 1.0
 device: 'cuda'
 ddp:
   use_ddp: false # if true, ddp_device will cover device and gpu id
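The same three keys are added to all four reflow configs below. As the reflow.py hunks further down show, `consistency_delta_t` sets the scale of the gap between the two sampled times (and normalizes the endpoint-consistency term), while `consistency_lambda_f` and `consistency_lambda_v` weight the endpoint-consistency and velocity-matching terms respectively. `consistency` itself stays `false`, so the defaults keep the previous training behavior.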
configs/config_naivev2reflow_combo.yaml (3 additions, 0 deletions)
@@ -62,6 +62,9 @@ model:
   naive_out_mel_cond_reflow: false # mel condition diffusion is a test function, maybe can make the model learn faster but less quality and pitch range.
   consistency: false
   consistency_only: true
+  consistency_delta_t: 0.1
+  consistency_lambda_f: 1.0
+  consistency_lambda_v: 1.0
 device: 'cuda'
 ddp:
   use_ddp: false # if true, ddp_device will cover device and gpu id
configs/config_naivev2reflow_shallow.yaml (3 additions, 0 deletions)
@@ -45,6 +45,9 @@ model:
   loss_type: 'l2' # 'l1', 'l2' or 'l2_lognorm'
   consistency: false
   consistency_only: true
+  consistency_delta_t: 0.1
+  consistency_lambda_f: 1.0
+  consistency_lambda_v: 1.0
 device: 'cuda'
 ddp:
   use_ddp: false # if true, ddp_device will cover device and gpu id
configs/config_v2_reflow.yaml (3 additions, 0 deletions)
@@ -41,6 +41,9 @@ model:
   loss_type: 'l2_lognorm' # 'l1', 'l2' or 'l2_lognorm'
   consistency: false
   consistency_only: true
+  consistency_delta_t: 0.1
+  consistency_lambda_f: 1.0
+  consistency_lambda_v: 1.0
 device: 'cuda'
 ddp:
   use_ddp: false # if true, ddp_device will cover device and gpu id
diffusion/reflow/reflow.py (16 additions, 12 deletions)
@@ -17,6 +17,9 @@ def __init__(self,
                  loss_type='l2',
                  consistency=False,
                  consistency_only=True,
+                 consistency_delta_t=0.1,
+                 consistency_lambda_f=1.0,
+                 consistency_lambda_v=1.0,
                  ):
         super().__init__()
         self.velocity_fn = velocity_fn
@@ -26,6 +29,9 @@ def __init__(self,
         self.loss_type = loss_type
         self.consistency = consistency
         self.consistency_only = consistency_only
+        self.consistency_delta_t = consistency_delta_t
+        self.consistency_lambda_f = consistency_lambda_f
+        self.consistency_lambda_v = consistency_lambda_v
 
     def reflow_loss(self, x_1, t, cond, loss_type=None):
         x_0 = torch.randn_like(x_1)
@@ -56,19 +62,17 @@ def reflow_consistency_loss(self, x_1, t_a, t_b, cond, loss_type=None):
         x_t_b = x_0 + t_b[:, None, None, None] * (x_1 - x_0)
         v_pred_a = self.velocity_fn(x_t_a, 1000 * t_a, cond)
         v_pred_b = self.velocity_fn(x_t_b, 1000 * t_b, cond).detach()
-
+        f_pred_a = x_t_a + (1 - t_a[:, None, None, None]) * v_pred_a
+        f_pred_b = x_t_b + (1 - t_b[:, None, None, None]) * v_pred_b
         if loss_type is None:
             loss_type = self.loss_type
         else:
             loss_type = loss_type
 
-        if loss_type == 'l1':
-            loss = (v_pred_a - v_pred_b).abs().mean()
-        elif loss_type == 'l2':
-            loss = F.mse_loss(v_pred_a, v_pred_b)
-        elif loss_type == 'l2_lognorm':
-            weights = 0.398942 / t_a / (1 - t_a) * torch.exp(-0.5 * torch.log(t_a / ( 1 - t_a)) ** 2)
-            loss = torch.mean(weights[:, None, None, None] * F.mse_loss(v_pred_a, v_pred_b, reduction='none'))
+        if loss_type == 'l2':
+            loss_f = F.mse_loss(f_pred_a, f_pred_b) / self.consistency_delta_t ** 2
+            loss_v = F.mse_loss(v_pred_a, v_pred_b)
+            loss = self.consistency_lambda_f * loss_f + self.consistency_lambda_v * loss_v
         else:
             raise NotImplementedError()
@@ -196,10 +200,10 @@ def forward(self,
         if self.consistency:
             x_1 = self.norm_spec(gt_spec)
             x_1 = x_1.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
-            t_a = t_start + (1.0 - t_start) * torch.rand(b, device=device)
-            t_b = t_start + (1.0 - t_start) * torch.rand(b, device=device)
-            t_a = torch.clip(t_a, 1e-7, 1 - 1e-7)
-            t_b = torch.clip(t_b, 1e-7, 1 - 1e-7)
+            t = t_start + (1.0 - t_start) * torch.rand(b, device=device)
+            dt = self.consistency_delta_t * torch.randn(b, device=device).abs()
+            t_a = torch.clip(t - 0.5 * dt, t_start, 1)
+            t_b = torch.clip(t + 0.5 * dt, t_start, 1)
             consistency_loss = self.reflow_consistency_loss(x_1, t_a, t_b, cond=cond)
             if self.consistency_only:
                 return consistency_loss
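The forward hunk replaces two independent uniform time samples with a correlated pair straddling a single time t. A sketch of just that sampling step (batch size `b` and `t_start` assumed as in the diff):

```python
import torch

b, t_start, consistency_delta_t = 4, 0.0, 0.1

t = t_start + (1.0 - t_start) * torch.rand(b)       # one time per batch element
dt = consistency_delta_t * torch.randn(b).abs()     # half-normal gap, scale delta_t
t_a = torch.clip(t - 0.5 * dt, t_start, 1)          # earlier end of the pair
t_b = torch.clip(t + 0.5 * dt, t_start, 1)          # later end of the pair
```

Because dt is half-normal with scale `consistency_delta_t`, most pairs lie within about one delta_t of each other, which keeps the finite-difference reading of the endpoint-consistency term sensible.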
diffusion/unit2mel.py (13 additions, 1 deletion)
@@ -277,6 +277,9 @@ def load_svc_model(args, vocoder_dimension):
             loss_type=args.model.loss_type,
             consistency=args.model.consistency,
             consistency_only=args.model.consistency_only,
+            consistency_delta_t=args.model.consistency_delta_t,
+            consistency_lambda_f=args.model.consistency_lambda_f,
+            consistency_lambda_v=args.model.consistency_lambda_v,
         )
 
     elif args.model.type == 'ReFlow1Step':
@@ -627,6 +630,9 @@ def spawn_decoder(self, velocity_fn, out_dims):
             loss_type=self.loss_type,
             consistency=self.consistency,
             consistency_only=self.consistency_only,
+            consistency_delta_t=self.consistency_delta_t,
+            consistency_lambda_f=self.consistency_lambda_f,
+            consistency_lambda_v=self.consistency_lambda_v,
         )
         return decoder
 
@@ -651,11 +657,17 @@ def __init__(
             naive_out_mel_cond_reflow=True,
             loss_type='l2',
             consistency=False,
-            consistency_only=True
+            consistency_only=True,
+            consistency_delta_t=0.1,
+            consistency_lambda_f=1.0,
+            consistency_lambda_v=1.0,
     ):
         self.loss_type = loss_type if (loss_type is not None) else 'l2'
         self.consistency = consistency if (consistency is not None) else False
         self.consistency_only = consistency_only if (consistency_only is not None) else True
+        self.consistency_delta_t = consistency_delta_t if (consistency_delta_t is not None) else 0.1
+        self.consistency_lambda_f = consistency_lambda_f if (consistency_lambda_f is not None) else 1.0
+        self.consistency_lambda_v = consistency_lambda_v if (consistency_lambda_v is not None) else 1.0
         super().__init__(
             input_channel,
             n_spk,
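The `value if (value is not None) else default` pattern in the hunk above mirrors the existing `consistency`/`consistency_only` handling and keeps older configuration files working: assuming the config loader returns None for keys that are absent, a config predating this commit simply falls back to the defaults (0.1, 1.0, 1.0) instead of propagating None into the loss.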
