Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Baseline 구축 #24

Open
presto105 opened this issue Dec 4, 2021 · 0 comments
Open

Baseline 구축 #24

presto105 opened this issue Dec 4, 2021 · 0 comments
Assignees
Labels
report Sharing information or results of analysis

Comments

@presto105
Copy link
Contributor

presto105 commented Dec 4, 2021

Content

final_project_baseline/
├── 1.EDA
│   ├── EDA_books_data_bj.ipynb
│   ├── EDA_document_summarization_ys.ipynb
│   ├── EDA_paper_data_sh.ipynb
│   └── EDA_patent_data_cy.ipynb
├── args
│   ├── DataTrainingArguments.py
│   ├── GenerationArguments.py
│   ├── __init__.py
│   ├── LoggingArguments.py
│   ├── ModelArguments.py
│   └── Seq2SeqTrainingArguments.py
├── dataloader.py
├── predict.py
├── processor.py
├── requirements
│   ├── requirements.sh
│   ├── rouge_requirement.sh
│   └── viz_requirement.sh
├── rouge.py
├── train.py
├── running.sh
└──use_auth_token.env

Running code

  • running.sh를 참고하여 running 해보세요! :)
python train.py \ 
--do_train \                                 => train mode
--dataset_name paper,news,magazine,law \     => 사용할 dataset 종류
--output_dir model/kobart \                  => 모델 저장 경로
--num_train_epochs 2 \                       => epoch 수
--learning_rate 3e-05 \                      => LR
--max_source_length 1024 \                   => Input max length
--max_target_length 128 \                    => Output max length
--metric_for_best_model rougeLsum \          => 평가지표
--es_patience 5 \                            => Early stopping patience
--relative_eval_steps 20 \                   => Eval step 전체 iteration에서 몇번 할지
--wandb_unique_tag kobart_ep2_lr3e05         => wandb tag

Requirements

requirements.sh
rouge_requirement.sh

위 두 파일은 모두 install 해주셔야 합니다!

EDA

  • 논문, 문서, 특허, 도서 data에 대해서 개인별 EDA입니다.
  • 최종적으로 사용하기로한 논문, 도서 등에 대해서 참고해보세요!

Train

데이터 로드 부분!

  • dataset_name 인자를 통해 여러 데이터 셋을 통합한 dataset을 불러옵니다.
types = data_args.dataset_name.split(',')
data_args.dataset_name = ['metamong1/summarization_' + dt for dt in types]

load_dotenv(dotenv_path=data_args.use_auth_token_path)
USE_AUTH_TOKEN = os.getenv("USE_AUTH_TOKEN")

train_dataset = SumDataset(data_args.dataset_name, 'train', USE_AUTH_TOKEN=USE_AUTH_TOKEN).load_data()
valid_dataset = SumDataset(data_args.dataset_name, 'validation', USE_AUTH_TOKEN=USE_AUTH_TOKEN).load_data()

위와 같은 input을 넣어주면 , split으로 string을 처리하여 dataset merge!
dataset_name argument => paper,news,magazine,law

상대적 eval step 수 정하기

  • 상대적 eval 수를 구하기 위해 iteration 구하기
    • iteration = epoch * 올림(dataset수 / batchsize)
    • 몇번 eval할지 = trainingargs.relativeevalsteps
    • 전체 iteration에서 / eval 횟수 = eval steps
if training_args.relative_eval_steps :
    iterations =  training_args.num_train_epochs*math.ceil(len(train_dataset)/training_args.per_device_train_batch_size)
    training_args.eval_steps = int(iterations // training_args.relative_eval_steps) ## dataset 크기에 상대적 eval step 적용
    training_args.save_steps = training_args.eval_steps

Seq2SeqTrainer

  • 특이한 점은 earlystopping!

callbacks = [EarlyStoppingCallback(early_stopping_patience=training_args.es_patience)]

  • training_args.es_patience로 3중 연산자를 통해 ES시행하지 않을 수도 있고 patience도 조절 가능
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset, # if training_args.do_train else None,
    eval_dataset=valid_dataset, # if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=comp_met_fn if training_args.predict_with_generate else None,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=training_args.es_patience)] if training_args.es_patience else None
)

Dataloder

  • 원하는 huggingFace에 저장된 데이터셋들을 concat하여 불러옵니다!

  • mode(train, validation)를 결정하고 지정한 데이터셋을 append 합니다.

def __init__(self,
        data_types: List[str],
        mode: str,
        USE_AUTH_TOKEN: str
    ) :
        self.dataset = []
        self.mode=mode
        for data_type in data_types :
            self.dataset.append(load_dataset(data_type, use_auth_token=USE_AUTH_TOKEN))
  • load_data함수를 이용하여 각 데이터셋을 concat하여 return합니다.
def load_data(self):
        dataset = concatenate_datasets([ds[self.mode] for ds in self.dataset])
        return dataset

Argument

  • 목적에 따라 다른 arguments file 분리
    • 아직은 목적에 따라 깔끔히 분류하지 못함 :(

DataTrainingArguments

  • use_auth_token_path: huggingface의 dataset을 사용하기 위해 필요한 API key를 저장한 경로입니다.

GenerationArguments

링크 참조!

LoggingArguments

  • dotenv_path: wandb를 사용하기 위해 필요한 API key를 저장한 경로입니다.
  • wandb log에서 tag 설정해줄때 참고하세요!

Seq2SeqTrainingArguments

  • save_total_limit: checkpoint 개수 조절 하여 여러분의 용량을 지켜줍니다 👍
  • metric_for_best_model: early stopping 할때 기준으로 삼을 metric을 정해줍니다.
  • es_patience: 몇 step 이상 metric 점수를 도달하지 못하면 early stopping 진행합니다

Predict

  • 입력받은 문단에 대해, 원하는 모델로 Title을 생성해줍니다!

  • tokenizer를 통해, text를 'input_ids'로 만들고 앞과 뒤에 시작과 끝을 알리는 token을 붙입니다.

raw_input_ids =  tokenizer(text, max_length=data_args.max_source_length, truncation=True)
input_ids = [tokenizer.bos_token_id] + raw_input_ids['input_ids'][:-2] + [tokenizer.eos_token_id]
  • text에 대한 input_ids를 generate의 입력으로 넣어 만든 id는 decode를 통해 title을 생성합니다. (num_return_sequences args를 통해 title 여러 개를 생성할 수 있습니다.)
with timer('** Generate title **') :
    summary_ids = model.generate(torch.tensor([input_ids]), num_beams=num_beams, **generation_args.__dict__)
    if len(summary_ids.shape) == 1 :
        title = tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
        print(title)
    else :
        titles = tokenizer.batch_decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
        print(titles)

Written by 범진, 유석

@presto105 presto105 added the report Sharing information or results of analysis label Dec 4, 2021
@presto105 presto105 changed the title Baseline Baseline 구축 Dec 4, 2021
@presto105 presto105 assigned presto105 and j961224 and unassigned presto105 and j961224 Dec 4, 2021
@j961224 j961224 assigned j961224 and presto105 and unassigned j961224 Dec 4, 2021
@presto105 presto105 assigned j961224 and unassigned presto105 Dec 4, 2021
@j961224 j961224 assigned presto105 and unassigned j961224 Dec 4, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
report Sharing information or results of analysis
Projects
None yet
Development

No branches or pull requests

2 participants