Skip to content

Commit

Permalink
ack, forgotten how to build GANs, generator output is detached when training the discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 18, 2025
1 parent 61ed6fd commit f3a614e
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions transformer_lm_gan/transformer_lm_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,34 +371,7 @@ def discriminate_forward(

real_embed = self.token_emb(real[:, :-1])

if self.strategy == 'rotate':
fake_embed = self.generator(fake, return_only_embed = True)
fake_next_embeds = self.token_emb(fake[:, 1:])
fake_embed = rotate_to(fake_embed[:, :-1], fake_next_embeds)

fake_embed = cat((prompt_embed, fake_embed[:, (prompt_len - 1):]), dim = -2)

elif self.strategy == 'gumbel_one_hot':
logits = self.generator(fake)
logits = logits[:, (prompt_len - 1):-1]

filter_fn, filter_thres, temperature, eps, gumbel_noises = sampling_hparams

filtered_logits = filter_fn(logits, thres = filter_thres)
filtered_logits = filtered_logits / max(temperature, eps)

noised_filtered_logits = filtered_logits + gumbel_noises

# do a classic Gumbel one-hot straight-through

soft_prob = noised_filtered_logits.softmax(dim = -1)
soft_one_hot = soft_prob + soft_prob.detach() + F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1])

fake_embed = einsum(soft_one_hot, self.token_emb.emb.weight, 'b n e, e d -> b n d')

fake_embed = cat((prompt_embed, fake_embed), dim = -2)
else:
raise ValueError(f'unknown strategy {self.strategy}')
fake_embed = self.token_emb(fake[:, :-1])

discr_input, packed_shape = pack((real_embed, fake_embed), '* n d')

Expand Down

0 comments on commit f3a614e

Please sign in to comment.