Skip to content

Commit

Permalink
[exp]: 데이터 강화 실험 #28
Browse files Browse the repository at this point in the history
  • Loading branch information
BJH9 committed Jan 10, 2024
1 parent faf15a7 commit 5df9e40
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions code/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main(args):
# 아래 directory와 columns의 형태는 지켜주시기 바랍니다.
output = pd.DataFrame({'id' : test_dataset['id'],'pred_label':pred_answer,'probs':output_prob,})

output.to_csv('./prediction/submission.csv', index=False) # 최종적으로 완성된 예측한 라벨 csv 파일 형태로 저장.
output.to_csv('./prediction/augment_2.csv', index=False) # 최종적으로 완성된 예측한 라벨 csv 파일 형태로 저장.
#### 필수!! ##############################################
print('---- Finish! ----')

Expand All @@ -105,7 +105,7 @@ def main(args):
parser = argparse.ArgumentParser()

# model dir
parser.add_argument('--model_path', type=str, default="./best_model/bestmodel.pth")
parser.add_argument('--model_path', type=str, default="./best_model/augment2_bestmodel.pth")
args = parser.parse_args()
print(args)
main(args)
Expand Down
4 changes: 2 additions & 2 deletions code/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def train():

# MODEL_NAME = "bert-base-uncased"
MODEL_NAME = "klue/roberta-large"
TRAIN_PATH = "../dataset/train/augmented_train1.csv"
TRAIN_PATH = "../dataset/train/augmented_train2.csv"
LABEL_CNT = 30
P_CONFIG = {'prompt_kind' : 's_and_o', # ['s_sep_o', 's_and_o', 'quiz']
'preprocess_method' : 'typed_entity_marker_punct', # ['baseline_preprocessor', 'entity_mask', 'entity_marker', 'entity_marker_punct', 'typed_entity_marker', 'typed_entity_marker_punct']
Expand Down Expand Up @@ -104,7 +104,7 @@ def train():
trainer.train()
# git에 올린 코드
model_state_dict = model.state_dict()
torch.save({'model_state_dict' : model_state_dict}, './best_model/augment1_bestmodel.pth')
torch.save({'model_state_dict' : model_state_dict}, './best_model/augment2_bestmodel.pth')

def main():
train()
Expand Down

0 comments on commit 5df9e40

Please sign in to comment.