Skip to content

Commit

Permalink
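Fixed bucket refilling in translator._score: scores are now collected per bucket, keyed by each example's index within the bucket, and appended to the results in sorted index order whenever the bucket index changes (and once more after the final batch).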
Browse files Browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed Jan 29, 2024
1 parent cd0e08a commit cf768a1
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,14 @@ def _process_bucket(bucket_translations):

def _score(self, infer_iter):
self.with_scores = True
score_results = []
score_res = []
processed_bucket = {}
prev_bucket_idx = 0
for batch, bucket_idx in infer_iter:
if bucket_idx != prev_bucket_idx:
prev_bucket_idx += 1
score_res += [item for _, item in sorted(processed_bucket.items())]
processed_bucket = {}
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
Expand All @@ -590,16 +596,15 @@ def _score(self, infer_iter):
batch_gold_log_probs = [
None for i, _ in enumerate(batch_inds_in_bucket)
]

for i, _ in enumerate(batch_inds_in_bucket):
score_results.append(
(
batch_gold_scores[i],
batch_gold_log_probs[i],
batch_tgt_lengths[i],
),
)
return score_results
for i, ind in enumerate(batch_inds_in_bucket):
processed_bucket[ind] = [
batch_gold_scores[i],
batch_gold_log_probs[i],
batch_tgt_lengths[i],
]
if processed_bucket:
score_res += [item for _, item in sorted(processed_bucket.items())]
return score_res

def _align_pad_prediction(self, predictions, bos, pad):
"""
Expand Down

0 comments on commit cf768a1

Please sign in to comment.