Merge pull request #34 from boostcampaitech6/exp/#28
[feat] Add back-translation feature and data analysis #28
BJH9 authored Jan 9, 2024
2 parents e047965 + 879320b commit d2c118f
Showing 2 changed files with 30 additions and 7 deletions.
code/augmentation.py (22 additions, 2 deletions)
@@ -1,13 +1,20 @@
#%%
import pandas as pd
import numpy as np
from googletrans import Translator
from tqdm.auto import tqdm
import time

train_path = './dataset/train/train.csv'
#%%
train_path = '../dataset/train/train.csv'
train_df = pd.read_csv(train_path)
train_df_len = len(train_df)

trans_train_path = '../dataset/train/trans_train.csv'
trans_train = pd.read_csv(trans_train_path)

#%%
# back translation
translator = Translator()

def google_ko2en2ko(ko_text, translator):
@@ -41,4 +48,17 @@ def google_ko2en2ko(ko_text, translator):

# for idx, sentence in enumerate(tqdm(sen2_list)):
# result = google_ko2en2ko(sentence, translator)
# test.loc[idx,'sentence_2'] = result

#%%
# exclude three labels from the back-translated rows before augmenting;
# .copy() avoids pandas' SettingWithCopyWarning on the assignments below
smooth_trans = trans_train[~trans_train['label'].isin(
    ["no_relation", "org:top_members/employees", "per:employee_of"])].copy()
#augmented_df = pd.merge(train_df, smooth_trans)
smooth_trans['id'] = range(1, len(smooth_trans)+1)
smooth_trans.reset_index(drop=True, inplace=True)
smooth_trans.to_csv('../dataset/train/smooth_trans.csv', index=False)

# an outer merge on all shared columns keeps rows from both the original
# train set and the filtered back-translations; ids are then renumbered
augmented_df = pd.merge(train_df, smooth_trans, how='outer')
augmented_df['id'] = range(1, len(augmented_df)+1)
augmented_df.reset_index(drop=True, inplace=True)
augmented_df.to_csv('../dataset/train/augmented_train.csv', index=False)
# %%
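Note: the diff collapses the body of google_ko2en2ko. For reference, a minimal googletrans ko→en→ko round trip generally looks like the sketch below; the committed implementation, its error handling, and the sleep interval are assumptions, not taken from the diff.

from googletrans import Translator
import time

def google_ko2en2ko(ko_text, translator):
    # Korean -> English, then English -> Korean
    en_text = translator.translate(ko_text, src='ko', dest='en').text
    time.sleep(1)  # illustrative pause; the unofficial Google endpoint throttles bursts
    return translator.translate(en_text, src='en', dest='ko').text

translator = Translator()
print(google_ko2en2ko('역번역은 데이터 증강에 자주 쓰인다.', translator))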
code/data_analysis.py (8 additions, 5 deletions)
@@ -14,6 +14,9 @@
test_path = '../dataset/test/test_data.csv'
test_df = pd.read_csv(test_path)

augmented_path = '../dataset/train/augmented_train.csv'
augmented_df = pd.read_csv(augmented_path)

# %%
# Label types and distributions for subj_entity / obj_entity

@@ -48,8 +51,8 @@ def visualize_obj_label(train_df):
plt.show()


visualize_subj_label(train_df)
visualize_obj_label(train_df)
visualize_subj_label(augmented_df)
visualize_obj_label(augmented_df)

#%%

@@ -106,7 +109,7 @@ def visualize_relation(train_df):
plt.xlabel('Counts')
plt.show()

visualize_relation(train_df)
visualize_relation(augmented_df)

# %%
# Visualize sentence lengths
@@ -132,7 +135,7 @@ def visualize_sentence_len(train_df):
plt.ylabel('Counts', fontsize=40)
#plt.show()

visualize_sentence_len(train_df)
visualize_sentence_len(augmented_df)

# %%
# Find outlier ranges
@@ -185,6 +188,6 @@ def detect_duplicated(train_df):
print(filtered_df)


detect_duplicated(train_df)
detect_duplicated(augmented_df)

# %%
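The analysis functions in data_analysis.py now run on augmented_train.csv instead of train.csv. As a hypothetical follow-up check (not part of this commit), comparing label counts before and after augmentation shows how much each class grew; the paths follow the diff above.

import pandas as pd

train_df = pd.read_csv('../dataset/train/train.csv')
augmented_df = pd.read_csv('../dataset/train/augmented_train.csv')

# side-by-side label counts; fillna(0) covers labels missing from one set
counts = pd.DataFrame({
    'original': train_df['label'].value_counts(),
    'augmented': augmented_df['label'].value_counts(),
}).fillna(0).astype(int)
print(counts.sort_values('augmented', ascending=False))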
