diff --git a/README.md b/README.md
index b3562ac..a87f3c1 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ from transformer_lm_gan import (
 )
 
 gan = GAN(
+    strategy = 'gumbel_one_hot', # or 'rotate' for the rotation trick; may try a combination of the two if both fail in experiments
     generator = dict(
         num_tokens = 256,
         dim = 512,
diff --git a/transformer_lm_gan/transformer_lm_gan.py b/transformer_lm_gan/transformer_lm_gan.py
index a665c03..322c0ea 100644
--- a/transformer_lm_gan/transformer_lm_gan.py
+++ b/transformer_lm_gan/transformer_lm_gan.py
@@ -216,7 +216,7 @@ def generate(
     if not return_with_prompt:
         out = out[..., prompt_seq_len:]
 
-    return out, (filter_fn, filter_thres, temperature, eps, stack(gumbel_noises))
+    return out, (filter_fn, filter_thres, temperature, eps, stack(gumbel_noises, dim = -2))
 
 def forward(
     self,
@@ -321,8 +321,25 @@ def generate_forward(
         embed = cat((prompt_embed, embed[:, (prompt_len - 1):]), dim = -2)
 
     elif self.strategy == 'gumbel_one_hot':
+        logits = self.generator(generated)
 
-        raise NotImplementedError
+        logits = logits[:, (prompt_len - 1):-1]
+
+        filter_fn, filter_thres, temperature, eps, gumbel_noises = sampling_hparams
+
+        filtered_logits = filter_fn(logits, thres = filter_thres)
+        filtered_logits = filtered_logits / max(temperature, eps)
+
+        noised_filtered_logits = filtered_logits + gumbel_noises
+
+        # do a classic gumbel one-hot straight through
+
+        soft_prob = noised_filtered_logits.softmax(dim = -1)
+        soft_one_hot = soft_prob - soft_prob.detach() + F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1])
+
+        embed = einsum(soft_one_hot, self.token_emb.emb.weight, 'b n e, e d -> b n d')
+
+        embed = cat((prompt_embed, embed), dim = -2)
 
     else:
         raise ValueError(f'unknown strategy')
@@ -359,12 +376,11 @@ def discriminate_forward(
         fake_next_embeds = self.token_emb(fake[:, 1:])
         fake_embed = rotate_to(fake_embed[:, :-1], fake_next_embeds)
 
-        # should not learn on the prompt portion
-
         fake_embed = cat((prompt_embed, fake_embed[:, (prompt_len - 1):]), dim = -2)
 
     elif self.strategy == 'gumbel_one_hot':
         logits = self.generator(fake)
+        logits = logits[:, (prompt_len - 1):-1]
 
         filter_fn, filter_thres, temperature, eps, gumbel_noises = sampling_hparams
 
@@ -378,8 +394,9 @@ def discriminate_forward(
 
         soft_prob = noised_filtered_logits.softmax(dim = -1)
         soft_one_hot = soft_prob - soft_prob.detach() + F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1])
 
-        raise NotImplementedError
+        fake_embed = einsum(soft_one_hot, self.token_emb.emb.weight, 'b n e, e d -> b n d')
+        fake_embed = cat((prompt_embed, fake_embed), dim = -2)
 
     else:
         raise ValueError(f'unknown strategy {self.strategy}')
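
For readers unfamiliar with the trick, here is a minimal, self-contained sketch of the gumbel one-hot straight-through step this diff wires into `generate_forward` and `discriminate_forward`. The helper name and toy shapes are illustrative only, not from the repo; unlike the repo it draws fresh gumbel noise instead of replaying the `gumbel_noises` captured during sampling, and it skips the `filter_fn` / `filter_thres` logit filtering:

```python
import torch
import torch.nn.functional as F

def gumbel_one_hot_straight_through(logits, temperature = 1., eps = 1e-10):
    # hypothetical helper, for illustration only

    # standard gumbel noise: -log(-log(u)), u ~ Uniform(0, 1)
    uniform = torch.rand_like(logits).clamp(min = eps)
    gumbel_noise = -torch.log(-torch.log(uniform))

    # perturbing temperature-scaled logits with gumbel noise makes
    # the argmax an exact categorical sample (the gumbel-max trick)
    noised = logits / max(temperature, eps) + gumbel_noise

    soft_prob = noised.softmax(dim = -1)

    # straight through: (soft_prob - soft_prob.detach()) is zero in the forward
    # pass, so the output equals the hard one-hot, while gradients flow to soft_prob
    hard = F.one_hot(soft_prob.argmax(dim = -1), soft_prob.shape[-1]).type_as(soft_prob)
    return soft_prob - soft_prob.detach() + hard

# toy usage: differentiable fake-token embeddings for a discriminator
logits = torch.randn(2, 16, 256, requires_grad = True)  # (batch, seq, num_tokens)
token_emb = torch.randn(256, 512)                       # (num_tokens, dim)

one_hot = gumbel_one_hot_straight_through(logits)
fake_embed = one_hot @ token_emb                        # (batch, seq, dim)

fake_embed.sum().backward()
assert logits.grad is not None  # generator receives gradients through discrete tokens
```

This also motivates why `generate` now returns `stack(gumbel_noises, dim = -2)`: both forward paths add those saved noises back onto the filtered logits, presumably so that the argmax of the straight-through one-hot reproduces exactly the tokens that were sampled during generation.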