diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4423e3bb..71b65ad3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 0.7.24
+
+* Revised repetition handling for Chipper
+
 ## 0.7.23
 
 * fix: added handling in `UnstructuredTableTransformerModel` for if `recognize` returns an empty
diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py
index 95f9a206..9297625d 100644
--- a/test_unstructured_inference/inference/test_layout.py
+++ b/test_unstructured_inference/inference/test_layout.py
@@ -214,17 +214,14 @@ def points(self):
 
 class MockPageLayout(layout.PageLayout):
     def __init__(
-        self,
-        number=1,
-        image=None,
-        model=None,
-        detection_model=None,
+        self, number=1, image=None, model=None, detection_model=None, element_extraction_model=None
     ):
         self.image = image
         self.layout = layout
         self.model = model
         self.number = number
         self.detection_model = detection_model
+        self.element_extraction_model = element_extraction_model
 
 
 @pytest.mark.parametrize(
diff --git a/test_unstructured_inference/models/test_chippermodel.py b/test_unstructured_inference/models/test_chippermodel.py
index 065e24bc..cafc6c47 100644
--- a/test_unstructured_inference/models/test_chippermodel.py
+++ b/test_unstructured_inference/models/test_chippermodel.py
@@ -139,7 +139,7 @@ def test_no_repeat_ngram_logits():
 
     no_repeat_ngram_size = 2
 
-    logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2)
+    logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2, context_length=10)
     output = logitsProcessor(input_ids=input_ids, scores=logits)
 
     assert (
@@ -194,7 +194,7 @@ def test_ngram_repetiton_stopping_criteria():
     logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
 
     stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
-        repetition_window=2, skip_tokens={0, 1, 2, 3, 4}
+        ngram_size=2, skip_tokens={0, 1, 2, 3, 4}
     )
     output = stoppingCriteria(input_ids=input_ids, scores=logits)
 
@@ -202,12 +202,56 @@ def test_ngram_repetiton_stopping_criteria():
     assert output is False
 
     stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
-        repetition_window=2, skip_tokens={1, 2, 3, 4}
+        ngram_size=2, skip_tokens={1, 2, 3, 4}
     )
     output = stoppingCriteria(input_ids=input_ids, scores=logits)
 
     assert output is True
 
+
+def test_no_repeat_group_ngram_logits_processor():
+    input_ids = torch.tensor([[1, 2, 3, 4, 0, 1, 2, 3, 4]])
+    logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
+
+    logitsProcessor = chipper.NoRepeatGroupNGramLogitsProcessor(ngram_size=3, token_group=[1, 2])
+
+    output = logitsProcessor(input_ids=input_ids, scores=logits)
+
+    assert (
+        int(
+            torch.sum(
+                output == torch.tensor([[0.1000, -0.3000, -0.5000, 0.0000, 1.0000, -0.9000]]),
+            ),
+        )
+        == 6
+    )
+
+    input_ids = torch.tensor([[1, 1, 2, 1, 2, 1, 2, 1, 2]])
+    logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
+
+    output = logitsProcessor(input_ids=input_ids, scores=logits)
+
+    assert (
+        int(
+            torch.sum(
+                output
+                == torch.tensor([[0.1000, -float("inf"), -float("inf"), 0.0000, 1.0000, -0.9000]]),
+            ),
+        )
+        == 6
+    )
+
+
+def test_target_token_id_stopping_criterion():
+    input_ids = torch.tensor([1, 2, 3])
+    logits = torch.tensor([0.1, 0.2, 0.3])
+
+    stoppingCriterion = chipper.TargetTokenIdStoppingCriterion(1)
+
+    output = stoppingCriterion(input_ids=input_ids, scores=logits)
+
+    assert output is True
+
+
 @pytest.mark.parametrize(
     ("decoded_str", "expected_classes"),
     [
@@ -259,7 +303,7 @@ def test_predict_tokens_beam_indices():
     model = get_model("chipper")
     model.stopping_criteria = [
         chipper.NGramRepetitonStoppingCriteria(
-            repetition_window=1,
+            ngram_size=1,
             skip_tokens={},
         ),
     ]
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index b16a8c5b..688c38bb 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.7.23"  # pragma: no cover
+__version__ = "0.7.24"  # pragma: no cover
diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py
index 27c3eefe..3294f737 100644
--- a/unstructured_inference/inference/layout.py
+++ b/unstructured_inference/inference/layout.py
@@ -209,7 +209,9 @@ def get_elements_from_layout(
 
         # If the model is a chipper model, we don't want to order the
        # elements, as they are already ordered
-        order_elements = not isinstance(self.detection_model, UnstructuredChipperModel)
+        order_elements = not isinstance(
+            self.element_extraction_model, UnstructuredChipperModel
+        ) and not isinstance(self.detection_model, UnstructuredChipperModel)
         if order_elements:
             layout = order_layout(layout)
 
diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py
index 2df09c21..de87bdaa 100644
--- a/unstructured_inference/models/chipper.py
+++ b/unstructured_inference/models/chipper.py
@@ -1,7 +1,9 @@
 import copy
 import os
 import platform
+from collections import Counter
 from contextlib import nullcontext
+from difflib import SequenceMatcher as SM
 from typing import ContextManager, Dict, List, Optional, Sequence, TextIO, Tuple, Union
 
 import cv2
@@ -97,18 +99,40 @@ def initialize(
         self.tokenizer = self.processor.tokenizer
         self.logits_processor = [
             NoRepeatNGramLogitsProcessor(
-                no_repeat_ngram_size,
-                get_table_token_ids(self.processor),
+                ngram_size=no_repeat_ngram_size,
+                context_length=(no_repeat_ngram_size * 5) + 1,
+                skip_tokens=get_table_token_ids(self.processor),
             ),
         ]
 
         self.stopping_criteria = [
             NGramRepetitonStoppingCriteria(
-                repetition_window=30,
+                ngram_size=30,
                 skip_tokens=get_table_token_ids(self.processor),
             ),
         ]
 
+        # This check is needed since Chipperv1 does not process tables
+        if self.source in (Source.CHIPPERV2, Source.CHIPPERV3, Source.CHIPPER):
+            self.stopping_criteria.append(
+                TargetTokenIdStoppingCriterion(
+                    target_token_id=self.processor.tokenizer.encode(
+                        "",
+                        add_special_tokens=False,
+                    )[0],
+                ),
+            )
+
+            self.logits_processor.append(
+                NoRepeatGroupNGramLogitsProcessor(
+                    ngram_size=5,
+                    token_group=self.processor.tokenizer.encode(
+                        "",
+                        add_special_tokens=False,
+                    ),
+                ),
+            )
+
         self.model = VisionEncoderDecoderModel.from_pretrained(
             pre_trained_model_repo,
             ignore_mismatched_sizes=True,
@@ -168,8 +192,84 @@ def predict(self, image) -> List[LayoutElement]:
         elements = self.format_table_elements(
             self.postprocess(image, tokens, decoder_cross_attentions),
         )
+
+        elements = self.clean_elements(elements)
+
+        return elements
+
+    def clean_elements(cls, elements):
+        """
+        Perform cleaning of empty tables and repeated elements
+        """
+        elements = cls.remove_empty_table_elements(elements)
+        elements = cls.remove_repeated_elements_with_negative_coordinates(elements)
+        elements = cls.remove_repeated_elements(elements)
         return elements
 
+    def remove_repeated_elements(cls, elements):
+        """
+        Remove repeated elements. The IoU and the text similarity have to be large enough.
+        The first occurrence is kept, since the following ones might be affected by the
+        logits processors.
+        """
+        repeated_elements = []
+
+        # Get the list of parents; parents should not be removed
+        parents = [e.parent for e in elements if e.parent]
+
+        n_elements = len(elements)
+        for i, e1 in enumerate(elements):
+            if e1 in repeated_elements:
+                continue
+            for j in range(i + 1, n_elements):
+                e2 = elements[j]
+                if e1 == e2 or e2 in parents or e1.type == e2.type:
+                    continue
+                elements_iou = e1.bbox.intersection_over_union(e2.bbox)
+                ratio = SM(None, e1.text, e2.text).ratio()
+                if elements_iou > 0.75 and ratio > 0.99:
+                    repeated_elements.append(e2)
+
+        elements = [element for element in elements if element not in repeated_elements]
+
+        return elements
+
+    @staticmethod
+    def remove_repeated_elements_with_negative_coordinates(elements):
+        """
+        Remove repeated elements with negative coordinates.
+        Invalid bboxes are not evaluated.
+        """
+        location_to_remove = [
+            k
+            for k, v in Counter(
+                [
+                    str(element.bbox)
+                    for element in elements
+                    if has_bbox_negative_coordinates(element.bbox)
+                ],
+            ).items()
+            if v > 3
+        ]
+
+        return [
+            element
+            for element in elements
+            if not (not is_not_valid_bbox(element.bbox) and str(element.bbox) in location_to_remove)
+        ]
+
+    @staticmethod
+    def remove_empty_table_elements(elements):
+        """
+        Remove empty table elements that might be created by Chipper. It relies on
+        running format_table_elements first.
+        """
+        return [
+            element
+            for element in elements
+            if not (element.type == "Table" and len(element.text) == 0)
+        ]
+
     @staticmethod
     def format_table_elements(elements):
         """makes chipper table element return the same as other layout models
@@ -227,6 +327,7 @@ def predict_tokens(
                 do_sample=True,
                 no_repeat_ngram_size=0,
                 num_beams=3,
+                num_return_sequences=1,
                 return_dict_in_generate=True,
                 output_attentions=True,
                 output_scores=True,
@@ -527,7 +628,7 @@ def reduce_element_bbox(
         Given a LayoutElement element, reduce the size of the bounding box,
         depending on existing elements
         """
-        if element.bbox:
+        if not is_not_valid_bbox(element.bbox):
             bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
 
             if not self.element_overlap(elements, element):
@@ -564,7 +665,7 @@ def element_overlap(
         ]
 
         for check_element in elements:
-            if check_element == element:
+            if check_element == element or is_not_valid_bbox(check_element.bbox):
                 continue
 
             if self.bbox_overlap(
@@ -880,6 +981,8 @@ def resolve_bbox_overlaps(
                 continue
 
             ebbox1 = element.bbox
+            if is_not_valid_bbox(ebbox1):
+                continue
             bbox1 = [ebbox1.x1, ebbox1.y1, ebbox1.x2, max(ebbox1.y1, ebbox1.y2)]
 
             for celement in elements:
@@ -887,6 +990,8 @@ def resolve_bbox_overlaps(
                     continue
 
                 ebbox2 = celement.bbox
+                if is_not_valid_bbox(ebbox2):
+                    continue
                 bbox2 = [ebbox2.x1, ebbox2.y1, ebbox2.x2, max(ebbox2.y1, ebbox2.y2)]
 
                 if self.bbox_overlap(bbox1, bbox2):
@@ -982,17 +1087,23 @@ def image_padding(self, input_size, target_size):
 # Inspired on
 # https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706
 class NoRepeatNGramLogitsProcessor(LogitsProcessor):
-    def __init__(self, ngram_size: int, skip_tokens: Optional[Sequence[int]] = None):
+    def __init__(
+        self,
+        ngram_size: int,
+        context_length: int,
+        skip_tokens: Optional[Sequence[int]] = None,
+    ):
         if not isinstance(ngram_size, int) or ngram_size <= 0:
             raise ValueError(
                 f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}",
             )
         self.ngram_size = ngram_size
+        self.context_length = context_length
         self.skip_tokens = skip_tokens
 
     def __call__(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: torch.Tensor,
         scores: torch.FloatTensor,
     ) -> torch.FloatTensor:
         """
@@ -1012,9 +1123,12 @@ def __call__(
         """
         num_batch_hypotheses = scores.shape[0]
         cur_len = input_ids.shape[-1]
+        new_input_ids = input_ids[:, slice(-self.context_length, cur_len)]
+        new_cur_len = new_input_ids.shape[-1]
+
         return _no_repeat_ngram_logits(
-            input_ids,
-            cur_len,
+            new_input_ids,
+            new_cur_len,
             scores,
             batch_size=num_batch_hypotheses,
             no_repeat_ngram_size=self.ngram_size,
@@ -1022,9 +1136,55 @@ def __call__(
         )
 
 
+class NoRepeatGroupNGramLogitsProcessor(LogitsProcessor):
+    def __init__(self, ngram_size: int, token_group: List[int]):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(
+                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}",
+            )
+        self.ngram_size = ngram_size
+        self.token_group_list = torch.tensor(token_group * ngram_size)
+        self.token_group = token_group
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+                [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be logits
+                for each vocabulary token when not using beam search or log softmax
+                for each vocabulary token when using beam search.
+
+        Return:
+            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The
+                processed prediction scores.
+
+        """
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        if cur_len < len(self.token_group_list):
+            return scores
+
+        for num_batch in range(num_batch_hypotheses):
+            if all(
+                input_ids[num_batch, slice(-len(self.token_group_list), cur_len)].to("cpu")
+                == self.token_group_list,
+            ):
+                for token_id in self.token_group:
+                    scores[num_batch][token_id] = -float("inf")
+
+        return scores
+
+
 class NGramRepetitonStoppingCriteria(StoppingCriteria):
-    def __init__(self, repetition_window: int, skip_tokens: set = set()):
-        self.repetition_window = repetition_window
+    def __init__(self, ngram_size: int, skip_tokens: set = set()):
+        self.ngram_size = ngram_size
         self.skip_tokens = skip_tokens
 
     def __call__(
@@ -1058,7 +1218,7 @@ def __call__(
         for banned_tokens in _calc_banned_tokens(
             input_ids,
             num_batch_hypotheses,
-            self.repetition_window,
+            self.ngram_size,
             cur_len,
         ):
             for token in banned_tokens:
@@ -1068,8 +1228,25 @@ def __call__(
         return False
 
 
+class TargetTokenIdStoppingCriterion(StoppingCriteria):
+    def __init__(self, target_token_id):
+        super().__init__()
+        self.target_token_id = target_token_id
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+        **kwargs,
+    ) -> bool:
+        """
+        Check if the already generated tokens contain the target token_id
+        """
+        return self.target_token_id in input_ids
+
+
 def _no_repeat_ngram_logits(
-    input_ids: torch.LongTensor,
+    input_ids: torch.Tensor,
     cur_len: int,
     logits: torch.FloatTensor,
     batch_size: int = 1,
@@ -1099,7 +1276,7 @@ def _no_repeat_ngram_logits(
 
 
 def _calc_banned_tokens(
-    prev_input_ids: torch.LongTensor,
+    prev_input_ids: torch.Tensor,
     num_hypos: int,
     no_repeat_ngram_size: int,
     cur_len: int,
@@ -1145,3 +1322,22 @@ def get_table_token_ids(processor):
         if token.startswith("