Skip to content

Commit

Permalink
Merge pull request #40 from boostcampaitech2/preprocessing
Browse files Browse the repository at this point in the history
Update preprocessing
  • Loading branch information
gistarrr authored Dec 17, 2021
2 parents b060561 + 2333385 commit 39d22f9
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
5 changes: 4 additions & 1 deletion model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from utils.processor import preprocess_function
from utils.data_preprocessor import Preprocessor
from models.modeling_kobigbird_bart import EncoderDecoderModel

@contextmanager
Expand Down Expand Up @@ -60,8 +61,10 @@ def main() :

dataset_name = "metamong1/summarization"
datasets = load_dataset(dataset_name + "_part" if data_args.is_part else dataset_name, use_auth_token=USE_AUTH_TOKEN)
data_preprocessor = Preprocessor()
datasets = datasets.map(data_preprocessor.for_test)
valid_dataset = datasets['validation']

idx = 1600 ## 바꾸면서 test 해보세요!
text = valid_dataset[idx]['text']
title = valid_dataset[idx]['title']
Expand Down
8 changes: 8 additions & 0 deletions model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from utils.trainer import Seq2SeqTrainerWithConditionalDocType
from utils.data_preprocessor import Preprocessor, Filter
from utils.data_collator import DataCollatorForSeq2SeqWithDocType
from utils.processor import preprocess_function
from utils.rouge import compute_metrics
Expand Down Expand Up @@ -78,6 +79,13 @@ def main():

dataset_name = "metamong1/summarization"
datasets = load_dataset(dataset_name + "_part" if data_args.is_part else dataset_name, use_auth_token=USE_AUTH_TOKEN)
data_preprocessor = Preprocessor()
data_filter = Filter(min_size=5, max_size=100)

## data preprocessing
datasets = datasets.map(data_preprocessor.for_train)
datasets = datasets.filter(data_filter)

train_dataset = datasets['train']
valid_dataset = datasets['validation']

Expand Down
57 changes: 57 additions & 0 deletions model/utils/data_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

import re

class Preprocessor:
    """Text normalizer for the summarization dataset.

    Cleans a dataset record (a dict with ``'text'`` and, for training,
    ``'title'`` keys) by removing parenthesized spans, Private Use Area
    characters, characters outside the allowed Unicode ranges, and by
    collapsing runs of whitespace.
    """

    def __init__(self):
        # Private Use Area characters — typically scraping artifacts.
        self.private_comp = re.compile(r'[\ue000-\uf8ff]')
        # Negated class: any character OUTSIDE these ranges is replaced.
        # Built via implicit string concatenation instead of backslash
        # line continuations, which would embed the continuation lines'
        # leading indentation spaces into the character class.
        #   \u3040-\u30ff  Japanese hiragana / katakana
        #   \uac00-\ud7af  Hangul syllables (Korean)
        #   \u4e00-\u9fff  CJK unified ideographs (Hanja)
        #   \u0000-\u007f  basic Latin (ASCII)
        #   \u2000-\u206f  general punctuation
        #   \u25a0-\u25ff  geometric shapes (sentence markers)
        self.outrange_comp = re.compile(
            '[^\u3040-\u30ff'
            '\uac00-\ud7af'
            '\u4e00-\u9fff'
            '\u0000-\u007f'
            '\u2000-\u206f'
            '\u25a0-\u25ff]'
        )
        # Parenthesized spans such as "(annotation)".
        self.bracket_comp = re.compile(r"\([^)]+\)")
        # Compiled once here; the original re.sub('\s+', ...) both used a
        # non-raw string (invalid-escape warning) and re-resolved the
        # pattern on every call.
        self.space_comp = re.compile(r"\s+")

    def for_train(self, data):
        """Clean both 'title' and 'text' fields of a training record.

        Mutates and returns ``data`` (map-style interface).
        """
        for key in ('title', 'text'):
            value = self.bracket_comp.sub(' ', data[key])
            value = self.doc_preprocess(value)
            data[key] = self.strip(value)
        return data

    def for_test(self, data):
        """Clean only the 'text' field of an inference record.

        Mutates and returns ``data``; 'title' is left untouched.
        """
        text = self.bracket_comp.sub(' ', data['text'])
        text = self.doc_preprocess(text)
        data['text'] = self.strip(text)
        return data

    def strip(self, txt):
        """Collapse whitespace runs to single spaces and trim both ends."""
        return self.space_comp.sub(' ', txt).strip()

    def doc_preprocess(self, txt):
        """Replace PUA and out-of-range characters with spaces."""
        txt = self.private_comp.sub(' ', txt)
        return self.outrange_comp.sub(' ', txt)

class Filter:
    """Length filter for dataset records.

    A record passes when the length of its 'title' lies within the
    inclusive range [min_size, max_size].
    """

    def __init__(self, min_size, max_size):
        # Inclusive lower / upper bounds on title length.
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, data):
        """Return True iff the record's title length is in range."""
        return self.min_size <= len(data['title']) <= self.max_size

0 comments on commit 39d22f9

Please sign in to comment.