From 5df9e402bd964b101eb809d6d3b438f3029a7d79 Mon Sep 17 00:00:00 2001 From: BJH9 Date: Wed, 10 Jan 2024 22:50:24 +0000 Subject: [PATCH] =?UTF-8?q?[exp]:=20=EB=8D=B0=EC=9D=B4=ED=84=B0=20?= =?UTF-8?q?=EA=B0=95=ED=99=94=20=EC=8B=A4=ED=97=98=20#28?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- code/inference.py | 4 ++-- code/train.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/code/inference.py b/code/inference.py index 1070a36..4db7264 100644 --- a/code/inference.py +++ b/code/inference.py @@ -96,7 +96,7 @@ def main(args): # 아래 directory와 columns의 형태는 지켜주시기 바랍니다. output = pd.DataFrame({'id' : test_dataset['id'],'pred_label':pred_answer,'probs':output_prob,}) - output.to_csv('./prediction/submission.csv', index=False) # 최종적으로 완성된 예측한 라벨 csv 파일 형태로 저장. + output.to_csv('./prediction/augment_2.csv', index=False) # 최종적으로 완성된 예측한 라벨 csv 파일 형태로 저장. #### 필수!! ############################################## print('---- Finish! ----') @@ -105,7 +105,7 @@ def main(args): parser = argparse.ArgumentParser() # model dir - parser.add_argument('--model_path', type=str, default="./best_model/bestmodel.pth") + parser.add_argument('--model_path', type=str, default="./best_model/augment2_bestmodel.pth") args = parser.parse_args() print(args) main(args) diff --git a/code/train.py b/code/train.py index 52e3b80..6a9e4ea 100644 --- a/code/train.py +++ b/code/train.py @@ -20,7 +20,7 @@ def train(): # MODEL_NAME = "bert-base-uncased" MODEL_NAME = "klue/roberta-large" - TRAIN_PATH = "../dataset/train/augmented_train1.csv" + TRAIN_PATH = "../dataset/train/augmented_train2.csv" LABEL_CNT = 30 P_CONFIG = {'prompt_kind' : 's_and_o', # ['s_sep_o', 's_and_o', 'quiz'] 'preprocess_method' : 'typed_entity_marker_punct', # ['baseline_preprocessor', 'entity_mask', 'entity_marker', 'entity_marker_punct', 'typed_entity_marker', 'typed_entity_marker_punct'] @@ -104,7 +104,7 @@ def train(): trainer.train() # git에 올린 코드 model_state_dict = model.state_dict() - torch.save({'model_state_dict' : model_state_dict}, './best_model/augment1_bestmodel.pth') + torch.save({'model_state_dict' : model_state_dict}, './best_model/augment2_bestmodel.pth') def main(): train()