Commit 35851bc: update code
zehuichen123 committed Feb 21, 2024
1 parent 93c5272 commit 35851bc
Showing 3 changed files with 31 additions and 27 deletions.
1 change: 1 addition & 0 deletions test_all_en.sh
@@ -1,3 +1,4 @@
export CUDA_VISIBLE_DEVICES=1
echo "model_type: $1"

model_path=$2
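The one-line addition pins the evaluation run to GPU 1 via `CUDA_VISIBLE_DEVICES`; the model type and model path are still taken from the positional arguments `$1` and `$2`.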
56 changes: 30 additions & 26 deletions teval/evaluators/reason_retrieve_understand_evaluator.py
@@ -175,8 +175,8 @@ def _post_process(self, results_list):
if self.eval_type == 'understand':
metric_keys = ['args', 'parse_rate']
metrics_results = []
batch_data = []
batch_id = []
batch_data = []; batch_arg_data = []
batch_id = []; batch_arg_id = []
BATCH_LIMIT = 32
for id, data in enumerate(results_list):
metrics_results.append(
@@ -211,25 +211,29 @@ def _post_process(self, results_list):
metrics_results[id]['name'] = 0

if 'args' in data.pred and 'args' in data.gt:
# batch_arg_data.extend([str(data.pred['args']), str(data.gt['args'])])
# batch_arg_id.extend([id])
# if len(batch_arg_data) >= BATCH_LIMIT:
# pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
# for i in range(0, len(batch_arg_data), 2):
# cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
# metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
# batch_arg_data = []
# batch_arg_id = []
batch_arg_data.extend([str(data.pred['args']), str(data.gt['args'])])
batch_arg_id.extend([id])
if len(batch_arg_data) >= BATCH_LIMIT:
pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
for i in range(0, len(batch_arg_data), 2):
cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
batch_arg_data = []
batch_arg_id = []

            # NOTE we adopt a more strict evaluation protocol in v2
for gt_arg_name in data.gt['args']:
if gt_arg_name in data.pred['args'] and str(data.pred['args'][gt_arg_name]) == str(data.gt['args'][gt_arg_name]):
metrics_results[id]['args'] += 1
metrics_results[id]['args'] /= (len(data.gt['args']) + 1e-5)
if len(data.gt['args']) == 0 and len(data.pred['args']) == 0:
metrics_results[id]['args'] = 1
if len(data.gt['args']) == 0 and len(data.pred['args']) != 0:
metrics_results[id]['args'] = 0
# if isinstance(data.gt['args'], dict):
# for gt_arg_name in data.gt['args']:
# if gt_arg_name in data.pred['args'] and str(data.pred['args'][gt_arg_name]) == str(data.gt['args'][gt_arg_name]):
# metrics_results[id]['args'] += 1
# metrics_results[id]['args'] /= (len(data.gt['args']) + 1e-5)
# if len(data.gt['args']) == 0 and len(data.pred['args']) == 0:
# metrics_results[id]['args'] = 1
# if len(data.gt['args']) == 0 and len(data.pred['args']) != 0:
# metrics_results[id]['args'] = 0
# else:
# data.pred['args'] = data.pred['args'].strip("'").strip('"')
# metrics_results[id]['args'] = float(data.gt['args'] == data.pred['args'])

if len(batch_data) > 0:
pred_emb = self.sentence_model.encode(batch_data, convert_to_tensor=True)
@@ -239,13 +243,13 @@ def _post_process(self, results_list):
batch_data = []
batch_id = []

# if len(batch_arg_data) > 0:
# pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
# for i in range(0, len(batch_arg_data), 2):
# cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
# metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
# batch_arg_data = []
# batch_arg_id = []
if len(batch_arg_data) > 0:
pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
for i in range(0, len(batch_arg_data), 2):
cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
batch_arg_data = []
batch_arg_id = []

results = dict()
for key in metric_keys:
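For context, the uncommented lines re-enable batched semantic scoring of predicted arguments against ground truth: predictions and references are interleaved into one list, embedded in batches of `BATCH_LIMIT`, and scored with cosine similarity clamped at zero, with a final flush for any partial batch. Below is a minimal self-contained sketch of that pattern, assuming the sentence-transformers package used in the diff; the model name and the `score_args_similarity` helper are illustrative, not repository code.

```python
# Minimal sketch of the batched argument-similarity scoring this commit
# re-enables. SentenceTransformer and util.cos_sim are from the
# sentence-transformers package, as in the diff; the helper name and the
# model choice are assumptions for illustration.
import numpy as np
from sentence_transformers import SentenceTransformer, util

def score_args_similarity(pairs, model, batch_limit=32):
    """pairs: list of (pred_args, gt_args); returns one score per pair."""
    scores = [0.0] * len(pairs)
    batch_data, batch_id = [], []

    def flush():
        emb = model.encode(batch_data, convert_to_tensor=True)
        for i in range(0, len(batch_data), 2):
            # Clamp negative cosine similarities to 0, as the evaluator does.
            sim = np.maximum(util.cos_sim(emb[i], emb[i + 1]).cpu().numpy(), 0)
            scores[batch_id[i // 2]] = float(sim[0, 0])
        batch_data.clear()
        batch_id.clear()

    for idx, (pred, gt) in enumerate(pairs):
        # Predictions and references are interleaved: pair i occupies
        # positions 2*i (pred) and 2*i + 1 (gt) in the batch.
        batch_data.extend([str(pred), str(gt)])
        batch_id.append(idx)
        if len(batch_data) >= batch_limit:
            flush()
    if batch_data:  # score any remaining partial batch
        flush()
    return scores

model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model choice
print(score_args_similarity([({"city": "Paris"}, {"city": "Paris"})], model))
```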
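The hunk also shows the stricter v2 protocol, under which a ground-truth argument counts only when the prediction contains the same key with a string-identical value, with explicit handling of empty argument dicts. A hedged sketch of that scoring rule (`score_args_strict` is an illustrative name, not a repo function):

```python
# Sketch of the strict exact-match argument scoring shown in the hunk above.
def score_args_strict(pred_args: dict, gt_args: dict) -> float:
    if len(gt_args) == 0:
        # Edge cases from the diff: an empty ground truth matches only an
        # empty prediction.
        return 1.0 if len(pred_args) == 0 else 0.0
    hits = sum(
        1 for name, value in gt_args.items()
        if name in pred_args and str(pred_args[name]) == str(value)
    )
    # The small epsilon in the diff guards against division by zero.
    return hits / (len(gt_args) + 1e-5)

print(score_args_strict({"city": "Paris", "n": "2"}, {"city": "Paris", "n": 2}))
# -> 2 / (2 + 1e-5), i.e. approximately 1.0 (values compared as strings)
```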
1 change: 0 additions & 1 deletion teval/utils/meta_template.py
@@ -27,7 +27,6 @@
begin='\n<|im_start|>assistant\n',
end='<|im_end|>',
generate=True),

],
vicuna = [
dict(role='user', begin='user: ', end='\n'),
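The removed blank line sits inside a ChatML-style template definition (the `<|im_start|>` / `<|im_end|>` markers). As a rough illustration of how meta templates like these are typically consumed, the sketch below renders a conversation by wrapping each turn in its role's `begin`/`end` markers; `render_prompt` is a hypothetical helper, not a function from this repository.

```python
# Hypothetical illustration of applying a meta template: wrap each turn in
# the begin/end markers declared for its role.
def render_prompt(template: list, turns: list) -> str:
    markers = {t["role"]: t for t in template if "role" in t}
    out = []
    for role, text in turns:
        spec = markers[role]
        out.append(f"{spec['begin']}{text}{spec['end']}")
    return "".join(out)

vicuna = [
    dict(role="user", begin="user: ", end="\n"),
    dict(role="assistant", begin="assistant: ", end="\n"),
]
print(render_prompt(vicuna, [("user", "Hello"), ("assistant", "Hi!")]))
```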
