Skip to content

Commit

Permalink
Florence2 distributed inference example (#3123)
Browse files Browse the repository at this point in the history
* Florence2 distributed inference example

* optimized

* Documentation
  • Loading branch information
hlky authored Oct 9, 2024
1 parent 55136b8 commit f4ee5a2
Showing 1 changed file with 202 additions and 0 deletions.
202 changes: 202 additions & 0 deletions examples/inference/distributed/florence2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import pathlib
import queue
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Union

import fire
import torch
import webdataset as wds
from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor

from accelerate import PartialState


"""
Additional requirements: flash_attn einops timm webdataset fire tqdm huggingface_hub
pip install flash_attn einops timm webdataset fire tqdm huggingface_hub
Example:
accelerate launch --num_processes=2 florence2.py --data_path "https://huggingface.co/datasets/pixparse/cc3m-wds/resolve/main/cc3m-train-0000.tar" --output_path outputs --batch_size 12 --num_workers 1 --prompt "<CAPTION>"
"""


def main(
data_path: str,
output_path: str,
batch_size: int,
num_workers: int,
prompt: str = "<MORE_DETAILED_CAPTION>",
model_name: str = "microsoft/Florence-2-large",
max_new_tokens: int = 1024,
num_beams: int = 3,
):
output_dir = pathlib.Path(output_path)

distributed_state = PartialState()

if distributed_state.is_main_process:
output_dir.mkdir(exist_ok=True)

model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map=distributed_state.device,
torch_dtype=torch.float16,
trust_remote_code=True,
)

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, clean_up_tokenization_spaces=True)

class ExistsFilter:
def __init__(self, output_dir: Union[pathlib.Path, str]):
current_training_img_hashes = [f.split(".jpg")[0] for f in os.listdir(output_dir) if f.endswith(".jpg")]
self.current_training_img_hashes = set(current_training_img_hashes)
if distributed_state.is_main_process:
print(f"Existing images found: {len(self.current_training_img_hashes)}.")

def __call__(self, x):
if len(self.current_training_img_hashes) > 0:
if x["img_hash"] in self.current_training_img_hashes:
return False
else:
return True
else:
return True

def preprocess_fn(sample, processor):
image: Image.Image = sample["jpg"].convert("RGB")
img_hash = insecure_hashlib.sha1(image.tobytes()).hexdigest()
inputs = processor(
text=prompt,
images=image,
return_tensors="pt",
)
return {
"input_ids": inputs["input_ids"],
"pixel_values": inputs["pixel_values"],
"image": image,
"img_hash": img_hash,
"original_caption": sample["txt"],
}

def collate_fn(examples):
input_ids = torch.cat([example["input_ids"] for example in examples])
pixel_values = torch.cat([example["pixel_values"] for example in examples])
images = [example["image"] for example in examples]
img_hashes = [example["img_hash"] for example in examples]
captions = [example["original_caption"] for example in examples]
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"images": images,
"img_hashes": img_hashes,
"original_captions": captions,
}

exist_filter = ExistsFilter(output_dir)
dataset = (
wds.WebDataset(
data_path,
handler=wds.warn_and_continue,
nodesplitter=None,
shardshuffle=False,
empty_check=False,
)
.decode("pil", handler=wds.warn_and_continue)
.map(partial(preprocess_fn, processor=processor), handler=wds.warn_and_continue)
)
if len(exist_filter.current_training_img_hashes) > 0:
dataset = dataset.select(exist_filter)
dataset = dataset.batched(
batch_size,
partial=False,
collation_fn=collate_fn,
)
dataloader = wds.WebLoader(
dataset,
batch_size=None,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
)

def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, processor):
while True:
try:
item = output_queue.get(timeout=5)
if item is None:
break
original_captions, predictions, images, img_hashes = item
predicted_captions = processor.batch_decode(
predictions,
skip_special_tokens=False,
)
for caption, pred_caption, image, img_hash in zip(
original_captions, predicted_captions, images, img_hashes
):
processed_caption = processor.post_process_generation(
pred_caption, task=prompt, image_size=(image.width, image.height)
)[prompt]
img_path = output_dir.joinpath(f"{img_hash}.jpg")
image.save(img_path)

caption_dict = {"original": caption, "predicted": processed_caption}
with output_dir.joinpath(f"{img_hash}_caption.json").open("w") as f:
json.dump(caption_dict, f, indent=4)

except queue.Empty:
continue

output_queue = queue.Queue()
save_thread = ThreadPoolExecutor(max_workers=num_workers)
save_future = save_thread.submit(save_results, output_queue, output_dir, processor)

try:
for _, batch_raw in tqdm(
enumerate(dataloader),
disable=not distributed_state.is_main_process,
):
with distributed_state.split_between_processes(batch_raw) as batch:
outputs = model.generate(
input_ids=batch["input_ids"].to(distributed_state.device),
pixel_values=batch["pixel_values"].to(distributed_state.device, model.dtype),
max_new_tokens=max_new_tokens,
num_beams=num_beams,
)
output_queue.put(
(
batch["original_captions"],
outputs,
batch["images"],
batch["img_hashes"],
)
)
finally:
output_queue.put(None)
save_thread.shutdown(wait=True)

save_future.result()


if __name__ == "__main__":
fire.Fire(main)

0 comments on commit f4ee5a2

Please sign in to comment.