the prompt portion of the generated sequences should be the straight token embeds
lucidrains committed Feb 18, 2025
1 parent 0544b82 commit f375cb7
Showing 1 changed file with 45 additions and 4 deletions.
transformer_lm_gan/transformer_lm_gan.py
@@ -216,7 +216,7 @@ def generate(
         if not return_with_prompt:
             out = out[..., prompt_seq_len:]
 
-        return out, (filter_fn, filter_thres, temperature, stack(gumbel_noises))
+        return out, (filter_fn, filter_thres, temperature, eps, stack(gumbel_noises))
 
     def forward(
         self,
@@ -306,14 +306,27 @@ def generate_forward(
 
         prompt = seq[:, :(seq_len // 2)]
 
-        generated, sampling_hparams = self.generator.generate(prompt, seq_len, **generate_kwargs)
-
-        embed = self.generator(generated, return_only_embed = True)
-        next_embeds = self.token_emb(generated[:, 1:])
-        embed = rotate_to(embed[:, :-1], next_embeds)
+        prompt_len = prompt.shape[-1]
+        prompt_embed = self.token_emb(prompt[:, :-1])
+
+        generated, sampling_hparams = self.generator.generate(prompt, seq_len, **generate_kwargs)
+
+        if self.strategy == 'rotate':
+            embed = self.generator(generated, return_only_embed = True)
+            next_embeds = self.token_emb(generated[:, 1:])
+            embed = rotate_to(embed[:, :-1], next_embeds)
+
+            # should not learn on the prompt portion
+
+            embed = cat((prompt_embed, embed[:, (prompt_len - 1):]), dim = -2)
+
+        elif self.strategy == 'gumbel_one_hot':
+            logits = self.generator(generated)
+            raise NotImplementedError
+
+        else:
+            raise ValueError(f'unknown strategy {self.strategy}')
 
         logits = self.discriminator(embed)
 
         loss = generator_hinge_loss(logits)
@@ -333,15 +346,43 @@ def discriminate_forward(
 
         real = seq
 
         prompt = seq[:, :(seq_len // 2)]
 
+        prompt_len = prompt.shape[-1]
+        prompt_embed = self.token_emb(prompt[:, :-1])
+
         fake, sampling_hparams = self.generator.generate(prompt, seq_len, **generate_kwargs)
 
         real_embed = self.token_emb(real[:, :-1])
-        fake_embed = self.generator(fake, return_only_embed = True)
-        fake_next_embeds = self.token_emb(fake[:, 1:])
-        fake_embed = rotate_to(fake_embed[:, :-1], fake_next_embeds)
+
+        if self.strategy == 'rotate':
+            fake_embed = self.generator(fake, return_only_embed = True)
+            fake_next_embeds = self.token_emb(fake[:, 1:])
+            fake_embed = rotate_to(fake_embed[:, :-1], fake_next_embeds)
+
+            # should not learn on the prompt portion
+
+            fake_embed = cat((prompt_embed, fake_embed[:, (prompt_len - 1):]), dim = -2)
+
+        elif self.strategy == 'gumbel_one_hot':
+            logits = self.generator(fake)
+
+            filter_fn, filter_thres, temperature, eps, gumbel_noises = sampling_hparams
+
+            filtered_logits = filter_fn(logits, thres = filter_thres)
+            filtered_logits = filtered_logits / max(temperature, eps)
+
+            noised_filtered_logits = filtered_logits + gumbel_noises
+
+            # do a classic gumbel one-hot straight through
+
+            soft_prob = noised_filtered_logits.softmax(dim = -1)
+            soft_one_hot = soft_prob - soft_prob.detach() + F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1])
+
+            raise NotImplementedError
+
+        else:
+            raise ValueError(f'unknown strategy {self.strategy}')
 
         discr_input, packed_shape = pack((real_embed, fake_embed), '* n d')
 
         if apply_grad_penalty:
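For context on the 'rotate' path: the generator's output embeddings are made to equal the sampled next-token embeddings in the forward pass while gradients still flow back, and this commit then splices in the plain token embeddings for the prompt positions, so the generator only learns on tokens it actually sampled. Below is a minimal sketch of that splicing; it substitutes a simple additive straight-through for the repo's rotate_to (which implements the rotation trick), and the function name and toy shapes are illustrative, not from the repo.

import torch

def spliced_straight_through_embeds(token_emb, generated, hidden, prompt_len):
    # hidden: generator output embeds, with hidden[:, t] aligned to predict generated[:, t + 1]
    # additive straight-through stand-in for rotate_to: the forward pass sees the
    # sampled next-token embeds, the backward pass flows into the generator's hidden states
    next_embeds = token_emb(generated[:, 1:])
    st_embeds = hidden[:, :-1] + (next_embeds - hidden[:, :-1]).detach()

    # the prompt portion should be the straight token embeds -
    # no generator gradient through positions it was merely conditioned on
    prompt_embeds = token_emb(generated[:, :prompt_len - 1])
    return torch.cat((prompt_embeds, st_embeds[:, (prompt_len - 1):]), dim = -2)

# usage with toy shapes (batch 2, seq 16, dim 64, vocab 256)
emb = torch.nn.Embedding(256, 64)
tokens = torch.randint(0, 256, (2, 16))
hidden = torch.randn(2, 16, 64, requires_grad = True)
out = spliced_straight_through_embeds(emb, tokens, hidden, prompt_len = 8)
assert out.shape == (2, 15, 64)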
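The unfinished 'gumbel_one_hot' path replays the sampling step inside the discriminator: generate now returns eps alongside the stored gumbel noises, so the same filtered, temperature-scaled, noised logits can be rebuilt, and a hard one-hot with soft gradients then selects token embeddings differentiably. A minimal sketch of that classic straight-through step under the same ordering as the diff; the noise replay and the matmul against an embedding weight are assumptions for illustration, not code from the repo.

import torch
import torch.nn.functional as F

def gumbel_one_hot_straight_through(logits, gumbel_noise, temperature = 1., eps = 1e-10):
    # same order as the diff: scale by temperature first, then add the stored noise
    noised = logits / max(temperature, eps) + gumbel_noise
    soft_prob = noised.softmax(dim = -1)

    # forward pass: hard one-hot of the argmax; backward pass: gradients of soft_prob
    hard = F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1]).float()
    return soft_prob - soft_prob.detach() + hard

# usage: reuse the noise drawn at sampling time, then look up embeds differentiably
logits = torch.randn(2, 8, 256, requires_grad = True)              # (batch, seq, vocab)
noise = -torch.log(-torch.log(torch.rand_like(logits).clamp(min = 1e-10)))
one_hot = gumbel_one_hot_straight_through(logits, noise)
embed_weight = torch.randn(256, 64)                                # (vocab, dim)
fake_embed = one_hot @ embed_weight                                # gradient flows to logits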
