make it all work end to end, save experiments for another day
lucidrains committed Feb 18, 2025
1 parent f375cb7 commit 61ed6fd
Showing 2 changed files with 23 additions and 5 deletions.
README.md: 1 addition, 0 deletions

@@ -17,6 +17,7 @@ from transformer_lm_gan import (
)

gan = GAN(
+    strategy = 'gumbel_one_hot', # or 'rotate' for the rotation trick; may try combining the two if both fail in experiments
    generator = dict(
        num_tokens = 256,
        dim = 512,
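For orientation, a fuller construction might look like the sketch below. Only `strategy` and the generator fields are taken from the README excerpt above; the `discriminator` dict and the token batch are illustrative assumptions, not from the repository.

```python
# hypothetical usage sketch; only `strategy` and the generator fields
# come from the README excerpt above, the discriminator dict is assumed

import torch
from transformer_lm_gan import GAN

gan = GAN(
    strategy = 'gumbel_one_hot',    # or 'rotate'
    generator = dict(
        num_tokens = 256,
        dim = 512,
    ),
    discriminator = dict(           # assumed keyword and fields
        dim = 512,
    )
)

seq = torch.randint(0, 256, (2, 1024))   # a batch of token ids
```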
transformer_lm_gan/transformer_lm_gan.py: 22 additions, 5 deletions

@@ -216,7 +216,7 @@ def generate(
        if not return_with_prompt:
            out = out[..., prompt_seq_len:]

-        return out, (filter_fn, filter_thres, temperature, eps, stack(gumbel_noises))
+        return out, (filter_fn, filter_thres, temperature, eps, stack(gumbel_noises, dim = -2))

    def forward(
        self,
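The change above stacks the per-step Gumbel noises along a new sequence dimension (dim = -2), so they line up with logits of shape (batch, seq, vocab) and generate_forward can later re-apply the exact noise that produced each sampled token. A minimal sketch of that bookkeeping, with hypothetical helper names:

```python
# minimal sketch: per-step Gumbel(0, 1) noise collected during sampling,
# then stacked along dim = -2 to align with (batch, seq, vocab) logits

import torch

def gumbel_noise(shape, eps = 1e-20):
    # standard Gumbel(0, 1) sample: -log(-log(U)), U ~ Uniform(0, 1)
    u = torch.rand(shape)
    return -torch.log((-torch.log(u.clamp(min = eps))).clamp(min = eps))

batch, vocab, steps = 2, 256, 4

gumbel_noises = [gumbel_noise((batch, vocab)) for _ in range(steps)]

stacked = torch.stack(gumbel_noises, dim = -2)

assert stacked.shape == (batch, steps, vocab)
```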
@@ -321,8 +321,25 @@ def generate_forward(
            embed = cat((prompt_embed, embed[:, (prompt_len - 1):]), dim = -2)

        elif self.strategy == 'gumbel_one_hot':
-            raise NotImplementedError

+            logits = self.generator(generated)
+            logits = logits[:, (prompt_len - 1):-1]
+
+            filter_fn, filter_thres, temperature, eps, gumbel_noises = sampling_hparams
+
+            filtered_logits = filter_fn(logits, thres = filter_thres)
+            filtered_logits = filtered_logits / max(temperature, eps)
+
+            noised_filtered_logits = filtered_logits + gumbel_noises
+
+            # do a classic gumbel one-hot straight through
+
+            soft_prob = noised_filtered_logits.softmax(dim = -1)
+            soft_one_hot = soft_prob - soft_prob.detach() + F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1])
+
+            embed = einsum(soft_one_hot, self.token_emb.emb.weight, 'b n e, e d -> b n d')
+
+            embed = cat((prompt_embed, embed), dim = -2)

        else:
            raise ValueError(f'unknown strategy {self.strategy}')
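Conceptually, the new gumbel_one_hot path re-scores the already-sampled sequence, re-applies the stored sampling transform (filtering, temperature, the saved Gumbel noises), and swaps the hard argmax one-hot in on the forward pass while letting gradients flow through the softmax. A self-contained sketch of just that step; top_k_filter is an illustrative stand-in for the repo's filter_fn:

```python
# standalone sketch of the gumbel one-hot straight-through step;
# top_k_filter is an illustrative stand-in for the repo's filter_fn

import torch
import torch.nn.functional as F

def top_k_filter(logits, thres = 0.9):
    # keep roughly the top (1 - thres) fraction of logits, mask the rest
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    vals, idx = logits.topk(k, dim = -1)
    out = torch.full_like(logits, float('-inf'))
    return out.scatter(-1, idx, vals)

batch, seq, vocab, dim = 2, 4, 256, 512

logits = torch.randn(batch, seq, vocab, requires_grad = True)
gumbel_noises = -torch.log(-torch.log(torch.rand(batch, seq, vocab)))
token_emb_weight = torch.randn(vocab, dim)

temperature, eps = 1.0, 1e-10

filtered = top_k_filter(logits) / max(temperature, eps)
noised = filtered + gumbel_noises

soft_prob = noised.softmax(dim = -1)
hard_one_hot = F.one_hot(soft_prob.argmax(dim = -1), vocab)

# straight-through: forward sees the hard one-hot, backward sees soft_prob
soft_one_hot = soft_prob - soft_prob.detach() + hard_one_hot

embed = soft_one_hot @ token_emb_weight   # (batch, seq, dim), differentiable
embed.sum().backward()                    # gradients reach the logits

assert logits.grad is not None
```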
@@ -359,12 +376,11 @@ def discriminate_forward(
            fake_next_embeds = self.token_emb(fake[:, 1:])
            fake_embed = rotate_to(fake_embed[:, :-1], fake_next_embeds)

            # should not learn on the prompt portion

            fake_embed = cat((prompt_embed, fake_embed[:, (prompt_len - 1):]), dim = -2)

        elif self.strategy == 'gumbel_one_hot':
            logits = self.generator(fake)
            logits = logits[:, (prompt_len - 1):-1]

            filter_fn, filter_thres, temperature, eps, gumbel_noises = sampling_hparams
@@ -378,8 +394,9 @@ def discriminate_forward(
            soft_prob = noised_filtered_logits.softmax(dim = -1)
            soft_one_hot = soft_prob - soft_prob.detach() + F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1])

-            raise NotImplementedError
+            fake_embed = einsum(soft_one_hot, self.token_emb.emb.weight, 'b n e, e d -> b n d')
+
+            fake_embed = cat((prompt_embed, fake_embed), dim = -2)

        else:
            raise ValueError(f'unknown strategy {self.strategy}')
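For the 'rotate' branch, rotate_to presumably implements the rotation trick (Fifty et al., arXiv:2410.06424): the fake embedding is rotated and rescaled so its forward value equals the ground-truth next-token embedding, while gradients still flow through the source. A minimal sketch under that assumption; the repo's actual implementation may differ in details:

```python
# minimal rotation-trick sketch, assuming rotate_to follows the
# formulation of arXiv:2410.06424; details may differ from the repo

import torch
import torch.nn.functional as F

def rotate_to(src, tgt, eps = 1e-6):
    # forward value equals tgt; gradients flow through src, with the
    # rotation and rescaling treated as constants (detached)
    s_norm = src.norm(dim = -1, keepdim = True)
    t_norm = tgt.norm(dim = -1, keepdim = True)

    u = F.normalize(src, dim = -1).detach()    # source direction
    q = F.normalize(tgt, dim = -1).detach()    # target direction
    w = F.normalize(u + q, dim = -1).detach()  # halfway direction

    # reflect across the halfway plane, then align with the target:
    # this maps the direction of src onto the direction of tgt
    rotated = (
        src
        - 2 * (src * w).sum(dim = -1, keepdim = True) * w
        + 2 * (src * u).sum(dim = -1, keepdim = True) * q
    )

    return rotated * (t_norm / s_norm.clamp(min = eps)).detach()

src = torch.randn(2, 8, 512, requires_grad = True)
tgt = torch.randn(2, 8, 512)

out = rotate_to(src, tgt)

assert torch.allclose(out, tgt, atol = 1e-3)  # forward matches the target
out.sum().backward()                          # but src still gets gradients
```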
