Skip to content

Commit

Permalink
chore: Refactor codebase to limit the number of mypy errors
Browse files (browse the repository at this point in the history)
Signed-off-by: Nikos Livathinos <[email protected]>
  • Loading branch information
nikos-livathinos committed Jan 27, 2025
1 parent 71ff4bc commit 14e71a7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 234 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,10 @@ def predict(
else:
raise TypeError("Not supported input image format")
images_tmp.append(image)
images = images_tmp

images_tensor = torch.stack([self._image_processor(img) for img in images]).to(
self._device
)
images_tensor = torch.stack(
[self._image_processor(img) for img in images_tmp]
).to(self._device)

prompts = [self._get_prompt(label) for label in labels]

Expand Down
9 changes: 5 additions & 4 deletions docling_ibm_models/code_formula_model/models/sam_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def embed_tokens(self, x):

def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: torch.FloatTensor = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:

Expand All @@ -86,6 +86,7 @@ def forward(

if input_ids.shape[1] != 1 or self.training:
with torch.set_grad_enabled(self.training):
assert vision_tower is not None
image_features = vision_tower(images)
image_features = image_features.flatten(2).permute(0, 2, 1)
image_features = self.mm_projector(image_features)
Expand Down Expand Up @@ -115,13 +116,13 @@ def forward(

new_input_embeds.append(cur_input_embeds)

inputs_embeds = torch.stack(new_input_embeds, dim=0)
next_inputs_embeds = torch.stack(new_input_embeds, dim=0)

return super(SamOPTModel, self).forward(
input_ids=None,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
inputs_embeds=next_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,23 @@ def predict(
The predictions for each image are sorted in descending order of confidence.
"""
processed_images = []
rgb_images = []
for image in images:
if isinstance(image, Image.Image):
processed_images.append(image.convert("RGB"))
rgb_images.append(image.convert("RGB"))
elif isinstance(image, np.ndarray):
processed_images.append(Image.fromarray(image).convert("RGB"))
rgb_images.append(Image.fromarray(image).convert("RGB"))
else:
raise TypeError(
"Supported input formats are PIL.Image.Image or numpy.ndarray."
)
images = processed_images

# (batch_size, 3, 224, 224)
images = [self._image_processor(image) for image in images]
images = torch.stack(images).to(self._device)
processed_images = [self._image_processor(image) for image in rgb_images]
torch_images = torch.stack(processed_images).to(self._device)

with torch.no_grad():
logits = self._model(images).logits # (batch_size, num_classes)
logits = self._model(torch_images).logits # (batch_size, num_classes)
probs_batch = logits.softmax(dim=1) # (batch_size, num_classes)
probs_batch = probs_batch.cpu().numpy().tolist()

Expand Down
2 changes: 1 addition & 1 deletion docling_ibm_models/tableformer/otsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
logger = s.get_custom_logger("consolidate", LOG_LEVEL)
png_files = {} # Evaluation files
# png_files = {} # Evaluation files
total_pics = 0


Expand Down
5 changes: 3 additions & 2 deletions docling_ibm_models/tableformer/utils/mem_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import platform
import re
from typing import Dict, Union


class MemMonitor:
Expand Down Expand Up @@ -112,7 +113,7 @@ def __init__(self, enable=True):
regex_str = r"({}:)(\s+)(\d*)(.*)".format(mem_field)
self._status_regex[mem_field] = re.compile(regex_str)

def get_memory_full(self) -> dict:
def get_memory_full(self) -> Union[Dict, int]:
r"""
- Parse /proc/<pid>/status to get all memory info.
- The method returns a dict with the fields self._status_fields
Expand Down Expand Up @@ -140,7 +141,7 @@ def get_memory_full(self) -> dict:

return memory

def get_memory(self) -> dict:
def get_memory(self) -> Union[Dict, int]:
r"""
- Parse /proc/<pid>/statm to get the most important memory fields
- This is a fast implementation.
Expand Down
216 changes: 0 additions & 216 deletions docling_ibm_models/tableformer/utils/torch_utils.py

This file was deleted.

0 comments on commit 14e71a7

Please sign in to comment.