diff --git a/manga_ocr_dev/training/get_model.py b/manga_ocr_dev/training/get_model.py index a2bcadd..96d68c5 100644 --- a/manga_ocr_dev/training/get_model.py +++ b/manga_ocr_dev/training/get_model.py @@ -47,7 +47,7 @@ def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None): decoder_config.num_hidden_layers = num_decoder_layers - config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) config.tie_word_embeddings = False model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config) diff --git a/manga_ocr_dev/training/metrics.py b/manga_ocr_dev/training/metrics.py index 9c6b9ec..bfcf934 100644 --- a/manga_ocr_dev/training/metrics.py +++ b/manga_ocr_dev/training/metrics.py @@ -12,9 +12,9 @@ def compute_metrics(self, pred): pred_ids = pred.predictions print(label_ids.shape, pred_ids.shape) - pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True) + pred_str = self.processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True) label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id - label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True) + label_str = self.processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True) pred_str = np.array(["".join(text.split()) for text in pred_str]) label_str = np.array(["".join(text.split()) for text in label_str])