forked from boostcampaitech2/mrc-level2-nlp-02
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer_qa.py
132 lines (116 loc) · 5.69 KB
/
trainer_qa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Question-Answering task와 관련된 'Trainer'의 subclass 코드입니다.
(Subclass of the Hugging Face 'Trainer' for the question-answering task.)
"""
import math

from transformers import (
    Trainer,
    is_datasets_available,
    is_torch_tpu_available,
    AdamW,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
# Hugging Face Trainer subclass specialised for question answering.
class QuestionAnsweringTrainer(Trainer):
    """Hugging Face ``Trainer`` subclass for question answering.

    ``evaluate`` and ``predict`` run the standard ``prediction_loop`` with
    ``compute_metrics`` temporarily disabled. They then pass the raw
    predictions, together with the original examples and the tokenized
    features, to ``post_process_function``. ``evaluate`` additionally scores
    the post-processed output with ``compute_metrics``.
    """

    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        """Create the trainer.

        Args:
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
            eval_examples: Default examples handed to ``post_process_function``
                by ``evaluate`` when none are passed explicitly.
            post_process_function: Callable invoked as
                ``post_process_function(examples, features, predictions, args)``.
                Its result is what ``compute_metrics`` receives in ``evaluate``
                and what ``predict`` returns.
        """
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function

    def _prediction_loop_without_metrics(self, dataloader, description, ignore_keys):
        """Run ``prediction_loop`` with ``self.compute_metrics`` temporarily unset.

        Metrics are computed by the caller after post-processing, not inside
        the loop. ``self.compute_metrics`` is restored even if the loop raises.
        """
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            return self.prediction_loop(
                dataloader,
                description=description,
                # Without a metric there is no point gathering predictions, so
                # only the loss is collected (True). Otherwise, None defers to
                # self.args.prediction_loss_only.
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
            )
        finally:
            self.compute_metrics = compute_metrics

    @staticmethod
    def _restore_all_columns(dataset):
        """Re-expose every feature column of a ``datasets.Dataset`` in place.

        NOTE(review): presumably needed because building the dataloader can
        hide columns the model does not consume, while post-processing needs
        them. Confirm against the installed transformers version.
        """
        # `datasets` is only imported when it is available, so check that first
        # to avoid a NameError.
        if is_datasets_available() and isinstance(dataset, datasets.Dataset):
            dataset.set_format(
                type=dataset.format["type"],
                columns=list(dataset.features.keys()),
            )

    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
        """Evaluate on ``eval_dataset``, log the metrics and return them.

        Args:
            eval_dataset: Tokenized features to evaluate. Defaults to
                ``self.eval_dataset``.
            eval_examples: Original examples given to ``post_process_function``.
                Defaults to ``self.eval_examples``.
            ignore_keys: Forwarded to ``prediction_loop``.

        Returns:
            The dict produced by ``compute_metrics`` on the post-processed
            predictions. Returns ``{}`` when either ``post_process_function``
            or ``compute_metrics`` is unset. The result is also passed to
            ``callback_handler.on_evaluate``.
        """
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        output = self._prediction_loop_without_metrics(
            eval_dataloader, description="Evaluation", ignore_keys=ignore_keys
        )
        self._restore_all_columns(eval_dataset)

        if self.post_process_function is not None and self.compute_metrics is not None:
            eval_preds = self.post_process_function(
                eval_examples, eval_dataset, output.predictions, self.args
            )
            metrics = self.compute_metrics(eval_preds)
            self.log(metrics)
        else:
            metrics = {}

        # tpu-comment: PyTorch/XLA debug metrics (compile, execute times, ops, etc.).
        # `xm`/`met` are only imported when a TPU is available, so also check
        # for a TPU. Otherwise `--debug` on a CPU/GPU machine raises NameError.
        if (self.args.tpu_metrics_debug or self.args.debug) and is_torch_tpu_available():
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, metrics
        )
        return metrics

    def predict(self, test_dataset, test_examples, ignore_keys=None):
        """Run inference on ``test_dataset``.

        Args:
            test_dataset: Tokenized features to run the model on.
            test_examples: Original examples given to ``post_process_function``.
            ignore_keys: Forwarded to ``prediction_loop``.

        Returns:
            The raw ``prediction_loop`` output when ``post_process_function``
            or ``compute_metrics`` is unset. Otherwise, the result of
            ``post_process_function(test_examples, test_dataset, predictions,
            self.args)``. No metrics are computed here.
        """
        test_dataloader = self.get_test_dataloader(test_dataset)
        output = self._prediction_loop_without_metrics(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys
        )

        # Without compute_metrics the loop collected only the loss
        # (prediction_loss_only=True), so there is nothing to post-process.
        if self.post_process_function is None or self.compute_metrics is None:
            return output

        self._restore_all_columns(test_dataset)
        return self.post_process_function(
            test_examples, test_dataset, output.predictions, self.args
        )

    def create_optimizer_and_scheduler(
        self,
        num_training_steps: int,
        num_cycles: int = 1,
        another_scheduler_flag: bool = False,
    ):
        """Set up ``self.optimizer`` and ``self.lr_scheduler``.

        Args:
            num_training_steps: Total number of optimization steps.
            num_cycles: Number of hard restarts of the cosine schedule. Only
                used when ``another_scheduler_flag`` is True.
            another_scheduler_flag: If False, use the stock
                ``create_optimizer``/``create_scheduler``. If True, use the
                same optimizer factory but replace the scheduler with
                cosine-with-hard-restarts and linear warmup.

        NOTE(review): if ``Trainer.train`` calls this with only
        ``num_training_steps``, the custom branch is reachable only by calling
        this method explicitly beforehand. Confirm against the transformers
        version in use.
        """
        if not another_scheduler_flag:
            self.create_optimizer()
            self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
            return

        # Reuse the stock optimizer factory, as the default branch does. The
        # previous hand-built AdamW passed only betas/eps/lr, which silently
        # dropped settings such as args.weight_decay.
        self.create_optimizer()

        # warmup_steps takes precedence. Otherwise fall back to warmup_ratio,
        # which the previous implementation ignored.
        if self.args.warmup_steps > 0:
            warmup_steps = self.args.warmup_steps
        else:
            warmup_steps = math.ceil(num_training_steps * self.args.warmup_ratio)

        self.lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
        )