diff --git a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu index 89f489d00..86bd2b624 100644 --- a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu @@ -115,6 +115,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa const int batch_size = output_tensors->at("output_ids").shape[1]; const int beam_width = output_tensors->at("output_ids").shape[2]; + const int max_input_length = input_tensors->at("max_input_length").getVal(); const int step = input_tensors->at("step").getVal(); const int ite = input_tensors->at("ite").getVal(); const int local_batch_size = input_tensors->at("logits").shape[0]; @@ -125,6 +126,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa input_tensors->isExist("len_penalty") ? input_tensors->at("len_penalty").getVal() : 0.0f; const int id_offset = step * batch_size * beam_width + local_batch_size * ite * beam_width; + const int gen_offset = (step - max_input_length) * batch_size * beam_width + local_batch_size * ite * beam_width; BeamHypotheses beam_hyps; if (output_tensors->isExist("beam_hyps")) { @@ -147,7 +149,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa output_tensors->at("finished").getPtr(), output_tensors->at("sequence_length").getPtr(), output_tensors->at("cum_log_probs").getPtr(), - output_tensors->getPtrWithOffset("output_log_probs", id_offset, nullptr), + output_tensors->getPtrWithOffset("output_log_probs", gen_offset, nullptr), output_tensors->at("output_ids").getPtrWithOffset(id_offset), topk_softmax_workspace_, topk_softmax_workspace_size_,