Reset rotary embeddings for chained inference
l-k-11235 committed Mar 20, 2024
1 parent 8479526 commit f0d012e
Showing 1 changed file with 6 additions and 0 deletions: onmt/modules/multi_headed_attn.py
@@ -464,6 +464,12 @@ def forward(
             or query.dtype != torch.float16
         ):
             if self.max_relative_positions == -1:  # Rotary Embeddings
+                if step == 0:
+                    self.rope, self.cos, self.sin = rotaryembeddings(
+                        self.rotary_dim,
+                        base=self.rotary_theta,
+                        device=self.rope.device,
+                    )
                 if seqlen + start_pos > self.rope.size(0):
                     # Resize rotary embeddings.
                     self.rope, self.cos, self.sin = rotaryembeddings(
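For context, the `rotaryembeddings` helper called in the diff precomputes a cache of per-position rotations (and their cos/sin components) up to a maximum sequence length. The sketch below is a hypothetical simplification of such a helper, not the exact OpenNMT-py implementation (the real one also handles dtype casting and the interleaved layout of the rotary dimensions); the `maxseqlen` default is an assumption:

```python
import torch

def rotaryembeddings(dim, maxseqlen=2048, base=10000.0, device=None):
    """Precompute rotary-embedding caches for positions 0..maxseqlen-1.

    Simplified sketch: returns a complex rotation table `rope` of shape
    (maxseqlen, dim // 2) plus its real (cos) and imaginary (sin) parts.
    """
    # One inverse frequency per pair of dimensions.
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device).float() / dim)
    )
    positions = torch.arange(maxseqlen, device=device).float()
    # Angle theta[p, j] = position p * inv_freq[j].
    freqs = torch.outer(positions, inv_freq)
    # Unit complex rotations e^{i * theta}.
    rope = torch.polar(torch.ones_like(freqs), freqs)
    return rope, torch.cos(freqs), torch.sin(freqs)
```

Rebuilding this cache when `step == 0` means each new decoding run starts from a fresh, default-size table: without the reset, a cache that was enlarged by the resize branch during a previous, longer run would be carried over into subsequent chained inference calls.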
