Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update preprocessing #40

Merged
merged 1 commit into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from utils.processor import preprocess_function
from utils.data_preprocessor import Preprocessor
from models.modeling_kobigbird_bart import EncoderDecoderModel

@contextmanager
Expand Down Expand Up @@ -60,8 +61,10 @@ def main() :

dataset_name = "metamong1/summarization"
datasets = load_dataset(dataset_name + "_part" if data_args.is_part else dataset_name, use_auth_token=USE_AUTH_TOKEN)
data_preprocessor = Preprocessor()
datasets = datasets.map(data_preprocessor.for_test)
valid_dataset = datasets['validation']

idx = 1600 ## 바꾸면서 test 해보세요!
text = valid_dataset[idx]['text']
title = valid_dataset[idx]['title']
Expand Down
8 changes: 8 additions & 0 deletions model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from utils.trainer import Seq2SeqTrainerWithConditionalDocType
from utils.data_preprocessor import Preprocessor, Filter
from utils.data_collator import DataCollatorForSeq2SeqWithDocType
from utils.processor import preprocess_function
from utils.rouge import compute_metrics
Expand Down Expand Up @@ -78,6 +79,13 @@ def main():

dataset_name = "metamong1/summarization"
datasets = load_dataset(dataset_name + "_part" if data_args.is_part else dataset_name, use_auth_token=USE_AUTH_TOKEN)
data_preprocessor = Preprocessor()
data_filter = Filter(min_size=5, max_size=100)

## data preprocessing
datasets = datasets.map(data_preprocessor.for_train)
datasets = datasets.filter(data_filter)

train_dataset = datasets['train']
valid_dataset = datasets['validation']

Expand Down
57 changes: 57 additions & 0 deletions model/utils/data_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

import re

class Preprocessor :
    """Text normalizer for summarization examples.

    Cleans the ``text`` (and, for training, ``title``) fields of a dataset
    example by removing parenthesized spans, Private Use Area characters,
    and any character outside an allow-list of scripts, then collapsing
    whitespace.  Designed to be passed to ``datasets.map`` via
    ``for_train`` / ``for_test``.
    """

    def __init__(self ) :
        # Unicode Private Use Area characters (no standard meaning) -> removed.
        self.private_comp = re.compile('[\ue000-\uf8ff]')
        # Negated class: anything NOT in the allowed ranges is removed.
        # Allowed: Japanese kana, Hangul syllables, CJK ideographs, basic
        # ASCII, general punctuation, geometric shapes.
        # NOTE: built by implicit string concatenation (not backslash line
        # continuation) so no stray indentation spaces leak into the class.
        self.outrange_comp = re.compile(
            '[^\u3040-\u30ff'   # Hiragana + Katakana
            '\uac00-\ud7af'     # Hangul syllables
            '\u4e00-\u9fff'     # CJK unified ideographs
            '\u0000-\u007f'     # basic Latin (ASCII)
            '\u2000-\u206f'     # general punctuation
            '\u25a0-\u25ff]'    # geometric shapes
        )

        # Parenthesized spans, e.g. "(예시)" -> removed entirely.
        self.bracket_comp = re.compile(r"\([^)]+\)")
        # Whitespace runs -> single space (compiled once, reused by strip()).
        self.space_comp = re.compile(r'\s+')

    def _clean(self, txt) :
        # Shared three-step pipeline: brackets -> character filter -> whitespace.
        txt = self.bracket_comp.sub(' ', txt)
        txt = self.doc_preprocess(txt)
        return self.strip(txt)

    def for_train(self, data) :
        """Clean both 'title' and 'text' of a training example in place."""
        data['title'] = self._clean(data['title'])
        data['text'] = self._clean(data['text'])
        return data

    def for_test(self, data) :
        """Clean only 'text' of a test example in place (title untouched)."""
        data['text'] = self._clean(data['text'])
        return data

    def strip(self, txt) :
        """Collapse whitespace runs to single spaces and trim the ends."""
        return self.space_comp.sub(' ', txt).strip()

    def doc_preprocess(self, txt) :
        """Replace disallowed characters (PUA + out-of-range) with spaces."""
        txt = self.private_comp.sub(' ', txt)
        return self.outrange_comp.sub(' ', txt)

class Filter :
    """Length-based example filter for ``datasets.filter``.

    Keeps an example only when the character length of its 'title'
    lies within the closed interval [min_size, max_size].
    """

    def __init__(self, min_size, max_size) :
        # Inclusive bounds on the title length.
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, data) :
        """Return True when the title length is within bounds."""
        title_len = len(data['title'])
        return self.min_size <= title_len <= self.max_size