diff --git a/torchmoji/lstm.py b/torchmoji/lstm.py index 67ed0e1..7b6b745 100644 --- a/torchmoji/lstm.py +++ b/torchmoji/lstm.py @@ -75,7 +75,7 @@ def reset_parameters(self): def forward(self, input, hx=None): is_packed = isinstance(input, PackedSequence) if is_packed: - input, batch_sizes = input + input, batch_sizes, _, _ = input max_batch_size = batch_sizes[0] else: batch_sizes = None