Skip to content

Commit

Permalink
more update to training scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang533 committed Dec 12, 2024
1 parent 30aa240 commit 2129004
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion manga_ocr_dev/training/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):

decoder_config.num_hidden_layers = num_decoder_layers

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
config.tie_word_embeddings = False
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)

Expand Down
4 changes: 2 additions & 2 deletions manga_ocr_dev/training/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def compute_metrics(self, pred):
pred_ids = pred.predictions
print(label_ids.shape, pred_ids.shape)

pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
pred_str = self.processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id
label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)
label_str = self.processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

pred_str = np.array(["".join(text.split()) for text in pred_str])
label_str = np.array(["".join(text.split()) for text in label_str])
Expand Down

0 comments on commit 2129004

Please sign in to comment.