diff --git a/autorag/nodes/passagereranker/flag_embedding.py b/autorag/nodes/passagereranker/flag_embedding.py index 3546afce4..63254a5e6 100644 --- a/autorag/nodes/passagereranker/flag_embedding.py +++ b/autorag/nodes/passagereranker/flag_embedding.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Iterable import pandas as pd import torch @@ -58,7 +58,7 @@ def flag_embedding_run_model(input_texts, model, batch_size: int): for batch_texts in tqdm(batch_input_texts): with torch.no_grad(): pred_scores = model.compute_score(sentence_pairs=batch_texts) - if batch_size == 1: + if batch_size == 1 or not isinstance(pred_scores, Iterable): results.append(pred_scores) else: results.extend(pred_scores)