Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image captioner refactorization #70

Merged
merged 9 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import os
import glob
import logging
import csv
import json
from datetime import datetime
from dotenv import load_dotenv
import asyncio
import torch
from PIL import Image, UnidentifiedImageError
from transformers import BlipProcessor, BlipForConditionalGeneration, PreTrainedModel
from transformers import BlipProcessor, BlipForConditionalGeneration
import hashlib

# Initialize a module-specific logger (instead of logging.basicConfig) so that
# applications embedding this module can configure their own root handlers
# without this module clobbering them.
logger = logging.getLogger(__name__)
logging_level = os.getenv('LOGGING_LEVEL', 'INFO').upper()
# Fall back to INFO when LOGGING_LEVEL does not name a valid logging level.
logger.setLevel(getattr(logging, logging_level, logging.INFO))
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

class ImageCaptioner:
    """
    Generates captions for images using a BLIP image-captioning model.

    Attributes:
        processor: BlipProcessor that prepares image/text inputs for the model.
        model: BlipForConditionalGeneration used to generate captions.
        device (str): "cuda:0" when a CUDA GPU is available, else "cpu".
        caption_cache (dict): Maps image-hash/text keys to generated captions.
        is_initialized (bool): False when model/processor loading failed.
    """

    def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
        """
        Initializes the ImageCaptioner, loading the model and processor.

        Args:
            model_name (str): The name of the model to be loaded. The
                MODEL_NAME environment variable, when set, takes precedence.

        Raises:
            Exception: Re-raises any failure from model/processor loading
                after logging it and marking the instance uninitialized.
        """
        self.is_initialized = True
        self.caption_cache = {}
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # BUG FIX: previously os.getenv('MODEL_NAME', <hardcoded default>)
        # silently discarded the model_name argument. Use the argument as the
        # fallback so callers passing an explicit model are honored.
        model_name = os.getenv('MODEL_NAME', model_name)
        try:
            self.processor = BlipProcessor.from_pretrained(model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
            logger.info("Successfully loaded model and processor.")
        except Exception as e:
            logger.error(f"Failed to load model and processor: {e}")
            self.is_initialized = False
            raise

    def load_image(self, image_path: str) -> "Image.Image":
        """
        Loads an image from disk and converts it to RGB.

        Args:
            image_path (str): Path of the image file to load.

        Returns:
            Image.Image or None: The loaded RGB image, or None on any failure
            (unreadable image, missing file, or unexpected error).
        """
        try:
            return Image.open(image_path).convert('RGB')
        except UnidentifiedImageError as e:
            logger.error(f"Failed to load image: {e}")
        except FileNotFoundError:
            logger.error(f"Image file not found: {image_path}")
        except Exception as e:
            logger.error(f"Unknown error occurred while loading image: {e}")
        return None

    async def generate_caption(self, raw_image: "Image.Image", text: str = None) -> str:
        """
        Generates a caption for the given image asynchronously, with caching.

        Caching keys on a hash of the raw pixel bytes (plus the optional
        conditioning text), so identical image contents hit the cache even
        when loaded into distinct objects.

        Args:
            raw_image (Image.Image): The image for which to generate a caption.
            text (str, optional): Optional text to condition the captioning.

        Returns:
            str or None: The generated caption, or None if captioning failed.
        """
        try:
            def image_hash(image: "Image.Image") -> str:
                # Hash pixel contents rather than id() so equal images share
                # one cache entry across separate loads.
                return hashlib.md5(image.tobytes()).hexdigest()

            cache_key = f"{image_hash(raw_image)}_{text}"
            if cache_key in self.caption_cache:
                return self.caption_cache[cache_key]

            # Conditional captioning passes the text prompt; unconditional does not.
            inputs = (self.processor(raw_image, text, return_tensors="pt").to(self.device)
                      if text
                      else self.processor(raw_image, return_tensors="pt").to(self.device))
            out = self.model.generate(**inputs)
            caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]

            self.caption_cache[cache_key] = caption
            return caption
        except Exception as e:
            logger.error(f"Failed to generate caption: {e}")
            return None

    def save_to_csv(self, image_name: str, caption: str, file_name: str = None, csvfile=None, mode: str = 'a'):
        """
        Appends one (image_name, caption) row to a CSV file.

        A provided ``csvfile`` object takes precedence over ``file_name``.
        When neither is provided, a timestamp-based file name is generated.

        Args:
            image_name (str): The name of the image file.
            caption (str): The generated caption.
            file_name (str, optional): CSV file name; defaults to a
                timestamp-based name when no csvfile is given.
            csvfile (file object, optional): Open file to write to; takes
                precedence over file_name. Caller-supplied files are NOT
                closed by this method.
            mode (str, optional): File mode used when opening file_name
                (e.g. 'a' append, 'w' write). Defaults to 'a'.
        """
        # BUG FIX: only close files this method opened. The previous version
        # closed caller-supplied file objects and dropped the timestamp
        # default when both arguments were omitted.
        opened_here = False
        if csvfile is None:
            if file_name is None:
                file_name = f"captions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
            csvfile = open(file_name, mode, newline='')
            opened_here = True
        try:
            writer = csv.writer(csvfile)
            writer.writerow([image_name, caption])
        except Exception as e:
            logger.error(f"Failed to write to CSV file: {e}")
        finally:
            if opened_here:
                csvfile.close()

async def main() -> None:
    """
    Entry point: captions every image in the configured folder.

    Steps:
    1. Load environment variables from a .env file.
    2. Read configuration (folder, caption suffix, model, conditional flag).
    3. Initialize the ImageCaptioner with the configured model.
    4. Collect all .jpg/.jpeg/.png files in the configured directory.
    5. Generate and save a caption for each image.
    """
    # Load environment variables from a .env file
    load_dotenv()

    # Get configuration from environment variables
    image_folder = os.getenv('IMAGE_FOLDER', 'images')
    ending_caption = os.getenv('ENDING_CAPTION', 'AI generated Artwork by Daethyra using Stable Diffusion XL.').strip('"')
    model_name = os.getenv('MODEL_NAME', 'Salesforce/blip-image-captioning-base')
    use_conditional_caption = os.getenv('USE_CONDITIONAL_CAPTION', 'true').lower() == 'true'

    # Initialize the ImageCaptioner with the specified model
    captioner = ImageCaptioner(model_name=model_name)

    # BUG FIX: glob.glob() accepts a single pattern; passing several path
    # patterns as positional arguments raises TypeError (the second positional
    # parameter is root_dir on 3.10+). Glob each extension separately instead.
    image_files = []
    for pattern in ("*.jpg", "*.jpeg", "*.png"):
        image_files.extend(glob.glob(os.path.join(image_folder, pattern)))

    # Process each image file in the directory
    for image_file in image_files:
        try:
            raw_image = captioner.load_image(image_file)
            # Skip files that could not be loaded (load_image returns None)
            if raw_image:
                # Generate a caption, conditionally or unconditionally, based on the configuration
                caption = (await captioner.generate_caption(raw_image, ending_caption)
                           if use_conditional_caption
                           else await captioner.generate_caption(raw_image))

                # Save the image file name and its generated caption to a CSV file
                captioner.save_to_csv(os.path.basename(image_file), caption)
        except Exception as e:
            # Log and continue: one bad image must not abort the whole batch
            logger.error(f"An unexpected error occurred with {image_file}: {e}")

# Entry point for the script execution: run the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())

This file was deleted.

12 changes: 9 additions & 3 deletions src/llm_utilikit/HuggingFace/image_captioner/template.env
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# Logging level for the application (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
LOGGING_LEVEL=INFO

# Pretrained model name for the image captioning (e.g., Salesforce/blip-image-captioning-base)
MODEL_NAME=Salesforce/blip-image-captioning-base

# Whether to use conditional captioning logic (true/false)
USE_CONDITIONAL_CAPTION=true

# Set directory to search for images
IMAGE_FOLDER=images

# Append caption content
ENDING_CAPTION="AI generated Artwork by Daethyra using Stable Diffusion XL."
Loading