Skip to content

Commit

Permalink
Merge pull request #43 from boostcampaitech2/serving_test
Browse files Browse the repository at this point in the history
Serving test
  • Loading branch information
gistarrr authored Dec 22, 2021
2 parents 295fc22 + 956b870 commit a8a906b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 23 deletions.
13 changes: 7 additions & 6 deletions model/utils/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ def preprocess_function(examples:datasets,
# padding in the loss.
inputs_padding_bool = (padding == "max_length")
doc_type_ids = []

for i in range(len(model_inputs['input_ids'])) :
model_inputs["attention_mask"][i] = add_padding(sample_tokens=model_inputs["attention_mask"][i],
padding=inputs_padding_bool,
padding_num=0,
max_length=max_source_length,
bos_token_id=1,
eos_token_id=1)
bos_token_id=bos_token_id,
eos_token_id=eos_token_id)
model_inputs["input_ids"][i] = add_padding(sample_tokens=model_inputs["input_ids"][i],
padding=inputs_padding_bool,
padding_num= pad_token_id,
Expand Down Expand Up @@ -84,12 +85,12 @@ def add_padding(sample_tokens:List[int],
bos_token_id:int,
eos_token_id:int) -> List:
sample_tokens_len = len(sample_tokens)
if len(sample_tokens) > max_length - 2:
if len(sample_tokens) > max_length - 1:
if bos_token_id == 0: #bart tokenizer만 진행
sample_tokens = [bos_token_id] + sample_tokens[:max_length-2] + [eos_token_id]
sample_tokens = [bos_token_id] + sample_tokens[:max_length-1]# + [eos_token_id]
else:
if bos_token_id == 0: #bart tokenizer만 진행
sample_tokens = [bos_token_id] + sample_tokens + [eos_token_id] # + [padding_num]*(max_length-sample_tokens_len-2)
sample_tokens = [bos_token_id] + sample_tokens #+ [eos_token_id] # + [padding_num]*(max_length-sample_tokens_len-2)
if padding:
sample_tokens = sample_tokens + [padding_num]*(max_length-sample_tokens_len-2)
sample_tokens = sample_tokens + [padding_num]*(max_length-sample_tokens_len-1)
return sample_tokens
19 changes: 11 additions & 8 deletions serving/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ def main(args):
beams_input = st.sidebar.slider('Number of beams search', 1, 5, 3, key='beams')
layer = st.sidebar.slider('Layer', 0, 5, 5, key='layer')

model_name = args.model

with timer("load...") :
tokenizer, model = load(model_name)

tokenizer, model = load(args.model)

model_name = args.use_model
input_text = st.text_area('Prompt:', height=200)
if input_text :
with timer("generate...") :
with timer("generate...") :
generated_tokens = get_prediction(tokenizer, model, model_name, input_text, beams_input, generation_args)
title = tokenizer.decode(generated_tokens.squeeze().tolist(), skip_special_tokens=True)
title = re.sub('</s> |</s>|[CLS] | [SEP]', '', title)

pcs = TitlePostProcessor()
title = pcs.post_process(title)
st.write(f'Titles: {title}')
Expand All @@ -55,8 +56,10 @@ def main(args):
dec_split_indices = split_tensor_by_words(dec_tokens, model_name)
enc_split_indices = split_tensor_by_words(enc_tokens, model_name)


highlighted_text = text_highlight(st_cross_attn, enc_tokens)
# print(dec_split_indices)
# breakpoint()

highlighted_text = text_highlight(st_cross_attn, enc_tokens, model_name)
st.write(HTML(highlighted_text))

fig = attention_heatmap(st_cross_attn, enc_split, dec_split,
Expand All @@ -72,7 +75,7 @@ def main(args):

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='../model/checkpoint/baseV1.0_Kobart_ep3_0.7') #baseV1.0_Kobart
parser.add_argument('--use_model', type=str, default='kobart', help='kobigbirdbart or etc')
parser.add_argument('--use_model', type=str, default='bigbart', help='bigbart or etc')
args = parser.parse_args()

main(args)
4 changes: 2 additions & 2 deletions serving/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def load(model_name) :

if "longformerbart" in model_name:
model = LongformerBartWithDoctypeForConditionalGeneration.from_pretrained(model_name)
elif "kobigbirdbart" in model_name:
elif "bigbart" in model_name:
tokenizer = AutoTokenizer.from_pretrained('monologg/kobigbird-bert-base')
model = EncoderDecoderModel.from_pretrained(model_name, output_attentions=True)
model.encoder.encoder.layer = model.encoder.encoder.layer[:model.config.encoder.encoder_layers]
# model.encoder.encoder.layer = model.encoder.encoder.layer[:model.config.encoder.encoder_layers]
model.encoder.config.output_attentions = True
model.decoder.config.output_attentions = True
else:
Expand Down
13 changes: 8 additions & 5 deletions serving/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def split_tensor_by_words(
) -> List[int] :
i = 0
split_words_indices = []
if model_type != 'kobigbirdbart' :
if model_type != 'bigbart' :
for token in text_tokens :
if '▁' in token:
split_words_indices.append(i)
Expand All @@ -17,20 +17,23 @@ def split_tensor_by_words(
i += 1
split_words_indices.append(i)
else :
cnt = 0
for token in text_tokens :
cnt += 1
if '##' in token :
i += 1
else :
split_words_indices.append(i)
i = 1
split_words_indices.append(i)
split_words_indices = split_words_indices[1:]
return split_words_indices

def token_to_words(
text_tokens: List[str],
model_type: str
) -> List[str] :
if model_type != 'kobigbirdbart' :
if model_type != 'bigbart' :
join_text = ''.join(text_tokens).replace('▁', ' ')
space_text = join_text.split(' ')[1:]
else :
Expand Down Expand Up @@ -58,11 +61,11 @@ def format_attention(
return torch.stack(squeezed)

def model_forward(model, tokenizer, text, title) :
enc_input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=True).input_ids
dec_input_ids = tokenizer(title, return_tensors="pt", add_special_tokens=True).input_ids
enc_input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
dec_input_ids = tokenizer(title, return_tensors="pt", add_special_tokens=False).input_ids

outputs = model(input_ids=enc_input_ids, decoder_input_ids=dec_input_ids)

st_cross_attn = format_attention(outputs.cross_attentions)
return st_cross_attn, enc_input_ids, dec_input_ids

Expand Down
5 changes: 3 additions & 2 deletions serving/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def highlighter(
word = '<span style="background-color:' +color+ '">' +word+ '</span>'
return word

def text_highlight(st_cross_attn, encoder_tokens) :
def text_highlight(st_cross_attn, encoder_tokens, model_type) :
layer_mat = st_cross_attn.detach()
last_h_layer_mat = torch.mean(layer_mat, 1)[-1] ## mean by head side, last layer
enc_mat = torch.mean(last_h_layer_mat, 0) ## mean by decoder id side
Expand All @@ -28,9 +28,10 @@ def text_highlight(st_cross_attn, encoder_tokens) :
enc_mat /= enc_mat.max()

colors = [rgb_to_hex(255, 255, 255*(1-attn_s)) for attn_s in enc_mat.numpy()]
if model_type == 'bigbart' :
encoder_tokens = ['▁'+word if '##' not in word else word.replace('##','') for word in encoder_tokens ]
higlighted_text = ''.join([highlighter(colors[i], word) for i, word in enumerate(encoder_tokens)])
higlighted_text = higlighted_text.replace('▁',' ')

return higlighted_text

def attention_heatmap(
Expand Down

0 comments on commit a8a906b

Please sign in to comment.