Skip to content

Commit

Permalink
Merge pull request #57 from Open-Model-Initiative/50-53-add-new-hdr-i…
Browse files Browse the repository at this point in the history
…mage-endpoints

Add new hdr image endpoints and hugging face endpoint
  • Loading branch information
CheesyLaZanya authored Sep 26, 2024
2 parents 860c484 + 66936fb commit 63037d5
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ DISCORD_CLIENT_ID=discord_client_id
DISCORD_CLIENT_SECRET=discord_client_secret

OAUTH2_REDIRECT_PATH=docs

## Hugging Face
HF_TOKEN=your_access_token
HF_HDR_DATASET_NAME=openmodelinitiative/hdr-submissions
2 changes: 1 addition & 1 deletion modules/odr_api/docker/Dockerfile.api
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ WORKDIR /app
COPY modules/odr_api/requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124

COPY modules/odr_core /app/modules/odr_core
RUN pip install --no-cache-dir -e /app/modules/odr_core
Expand Down
2 changes: 2 additions & 0 deletions modules/odr_api/odr_api/api/endpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
from odr_api.api.endpoints.user import router as user_router
from odr_api.api.endpoints.auth import router as auth_router
from odr_api.api.endpoints.health import router as health_router
from odr_api.api.endpoints.image import router as image_router
from odr_api.api.endpoints.hugging_face import router as hugging_face_router
45 changes: 45 additions & 0 deletions modules/odr_api/odr_api/api/endpoints/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
from fastapi import APIRouter, File, UploadFile, HTTPException
from huggingface_hub import HfApi
from pathlib import Path
import tempfile
from io import BytesIO
from odr_core.config import settings

router = APIRouter(tags=["hugging_face"])


@router.post("/hugging-face/upload-image")
async def upload_to_huggingface(file: UploadFile = File(...)):
    """Upload an image to the configured Hugging Face HDR dataset.

    The file is committed to ``images/<filename>`` in the dataset repository
    named by ``settings.HF_HDR_DATASET_NAME``, authenticating with
    ``settings.HF_TOKEN``.

    Args:
        file: The multipart upload to forward to the Hugging Face Hub.

    Returns:
        A dict with a success ``message`` and the ``commit_url`` taken from
        the commit info returned by ``HfApi.upload_file``.

    Raises:
        HTTPException: 500 when the token or dataset name is not configured,
            or when reading/uploading the file fails; 400 when the upload
            does not carry a usable filename.
    """
    hf_token = settings.HF_TOKEN
    if not hf_token:
        raise HTTPException(status_code=500, detail="Hugging Face token not configured")

    dataset_name = settings.HF_HDR_DATASET_NAME
    if not dataset_name:
        raise HTTPException(status_code=500, detail="Hugging Face hdr dataset name not configured")

    # Keep only the last path component of the client-supplied name so it
    # cannot address anything outside ``images/``. A missing/empty name or
    # "." collapses to "", and a name ending in ".." stays "..": none of
    # these is a valid target file, so reject them as a client error before
    # reading the payload or contacting the Hub.
    safe_filename = Path(file.filename or "").name
    if safe_filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename")

    try:
        contents = await file.read()

        # NOTE(review): HfApi.upload_file is a synchronous call made from an
        # async handler; consider moving it to a worker thread if uploads
        # are large or frequent.
        api = HfApi()
        commit_info = api.upload_file(
            path_or_fileobj=BytesIO(contents),
            path_in_repo=f"images/{safe_filename}",
            repo_id=dataset_name,
            repo_type="dataset",
            token=hf_token,
        )

        return {"message": "Image uploaded successfully", "commit_url": commit_info.commit_url}

    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
184 changes: 184 additions & 0 deletions modules/odr_api/odr_api/api/endpoints/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import torch
import numpy as np
import torchvision.transforms as transforms
import rawpy
from fastapi import APIRouter, File, UploadFile, HTTPException
from typing import Dict
from io import BytesIO
from PIL import Image, UnidentifiedImageError
from PIL.ExifTags import TAGS
from PIL.TiffImagePlugin import IFDRational
import imageio
from typing import Any
import base64

router = APIRouter(tags=["image"])


# Helper functions for HDR stats
def calculate_kurtosis(tensor: torch.Tensor):
    """Return the excess kurtosis of all values in ``tensor``.

    Every element is treated as one observation of a single sample. The
    variance and the fourth central moment are both population (biased)
    moments, so the result is ``m4 / m2**2 - 3`` and respects the
    mathematical lower bound of -2. Mixing a Bessel-corrected standard
    deviation with a population fourth moment would break that bound on
    small inputs.

    Args:
        tensor: Floating-point tensor of any shape.

    Returns:
        A 0-dim tensor with the excess kurtosis (0 for a normal
        distribution). The value is NaN when the variance is zero, e.g. for
        a constant tensor.
    """
    centered = tensor - torch.mean(tensor)
    squared = centered ** 2

    # Second and fourth central moments, both normalized by N.
    variance = torch.mean(squared)
    fourth_moment = torch.mean(squared ** 2)

    # Subtract 3 so a normal distribution has kurtosis 0.
    return fourth_moment / (variance ** 2) - 3


def calculate_msd(tensor: torch.Tensor):
    """Return the mean squared deviation of ``tensor`` from its own mean.

    All elements are reduced together, so the result is a 0-dim tensor.
    """
    deviation = tensor - torch.mean(tensor)
    return torch.mean(deviation ** 2)


def calculate_dynamic_range(tensor: torch.Tensor, epsilon=1e-10):
    """Return the dynamic range of ``tensor`` in decibels.

    The range is ``20 * log10(max / min)`` over all elements. Both extremes
    are raised to at least ``epsilon`` first, so a zero (or negative)
    minimum cannot cause a division by zero and the logarithm's argument
    stays positive.

    Args:
        tensor: Tensor of intensity values; reduced over all elements.
        epsilon: Lower bound applied to both the minimum and the maximum.

    Returns:
        A 0-dim tensor with the dynamic range in dB.
    """
    darkest = torch.clamp(torch.min(tensor), min=epsilon)
    brightest = torch.clamp(torch.max(tensor), min=epsilon)

    contrast_ratio = brightest / darkest
    return 20 * torch.log10(contrast_ratio)


def calculate_entropy(tensor: torch.Tensor):
    """Return the Shannon entropy (in bits) of the values in ``tensor``.

    The tensor is flattened and each distinct value is treated as a symbol
    whose probability is its relative frequency among all elements.

    Returns:
        A 0-dim tensor with ``-sum(p * log2(p))`` over the distinct values.
    """
    flat = tensor.flatten()

    # Only the occurrence counts matter; the distinct values are discarded.
    _, occurrences = torch.unique(flat, return_counts=True)
    probabilities = occurrences.float() / occurrences.sum()

    weighted_information = probabilities * torch.log2(probabilities)
    return -torch.sum(weighted_information)


# Helper functions for HDR metadata and preview conversion
def extract_metadata(image_bytes: bytes, is_dng: bool) -> Dict:
    """Read the EXIF tags of an in-memory image into a plain dict.

    Args:
        image_bytes: Encoded image data to open with Pillow.
        is_dng: Only selects the label ("DNG" or "JPG") used in the printed
            status messages; it does not change how the bytes are parsed.

    Returns:
        A dict mapping each EXIF tag name to its value. Tags without a known
        name keep their numeric id as the key. The dict is empty when the
        image has no EXIF data or Pillow cannot identify the image.
    """
    metadata = {}
    try:
        with Image.open(BytesIO(image_bytes)) as img:
            exif_data = img.getexif()
            if exif_data:
                # Translate numeric tag ids to names, falling back to the id.
                metadata.update(
                    (TAGS.get(tag_id, tag_id), value)
                    for tag_id, value in exif_data.items()
                )
        print(f"Metadata extracted from {'DNG' if is_dng else 'JPG'} file")
    except UnidentifiedImageError:
        # Unreadable input is not fatal: report it and return what we have.
        print(f"Could not extract metadata from {'DNG' if is_dng else 'JPG'}")
    return metadata


def convert_ifd_rational(value):
    """Convert ``IFDRational`` values to plain floats.

    A single ``IFDRational`` becomes a ``float``; a tuple made up entirely
    of ``IFDRational`` items becomes a tuple of floats. Anything else
    (including tuples with other item types) is returned unchanged.
    """
    if isinstance(value, IFDRational):
        return float(value)

    if not isinstance(value, tuple):
        return value

    # Leave mixed tuples alone; only convert when every item is a rational.
    if any(not isinstance(item, IFDRational) for item in value):
        return value

    return tuple(float(item) for item in value)


def check_metadata(metadata: Dict) -> Dict[str, Any]:
    """Select the metadata fields of interest and reject GPS data.

    Args:
        metadata: Mapping of EXIF tag names (or numeric tag ids) to values.

    Returns:
        A dict holding whichever of the important keys are present, with
        rational values converted via ``convert_ifd_rational``, plus
        ``DNGVersion`` rendered as a dot-separated string when present.

    Raises:
        ValueError: If any string key contains "GPS" (case-insensitive) or
            the numeric GPSInfo tag id (34853) is present.
    """
    important_keys = ['Make', 'Model', 'BitsPerSample', 'BaselineExposure', 'LinearResponseLimit', 'ImageWidth', 'ImageLength', 'DateTime']

    result = {}
    for key in important_keys:
        if key in metadata:
            result[key] = convert_ifd_rational(metadata.get(key))

    # Named GPS tags are listed first, then the raw GPSInfo tag number, so
    # the error message keeps a stable ordering.
    named_gps = [key for key in metadata if isinstance(key, str) and 'GPS' in key.upper()]
    numeric_gps = [key for key in metadata if isinstance(key, int) and key == 34853]
    gps_keys = named_gps + numeric_gps

    if gps_keys:
        raise ValueError(f"GPS data found in metadata: {gps_keys}")

    if 'DNGVersion' in metadata:
        # The version is a sequence of components; join them as "a.b.c.d".
        result['DNGVersion'] = '.'.join(map(str, metadata['DNGVersion']))

    return result


def convert_dng_to_jpg(dng_bytes: bytes) -> bytes:
    """Render raw image bytes to an RGB image and return it JPEG-encoded.

    Args:
        dng_bytes: Contents of a raw image file readable by ``rawpy``.

    Returns:
        The JPEG-encoded bytes of the image produced by rawpy's default
        post-processing.
    """
    raw_stream = BytesIO(dng_bytes)
    with rawpy.imread(raw_stream) as raw:
        rgb = raw.postprocess()

    # Encode into an in-memory buffer instead of touching the filesystem.
    jpeg_buffer = BytesIO()
    imageio.imwrite(jpeg_buffer, rgb, format='JPEG')
    return jpeg_buffer.getvalue()


# Endpoint for calculating HDR stats
@router.post("/image/hdr-stats", response_model=Dict[str, float])
async def calculate_hdr_stats(file: UploadFile = File(...)):
    """Compute HDR statistics for an uploaded raw image.

    The upload is decoded with rawpy to 16-bit RGB (camera white balance,
    ACES colour space), scaled to the 0..1 range and evaluated on the GPU
    when one is available.

    Returns:
        A dict with the float values ``kurtosis``, ``msd``,
        ``dynamic_range`` and ``entropy``.

    Raises:
        HTTPException: 400 with the underlying error message if reading,
            decoding or computing the statistics fails.
    """
    try:
        payload = await file.read()
        with rawpy.imread(BytesIO(payload)) as raw:
            pixels = raw.postprocess(
                use_camera_wb=True,
                output_color=rawpy.ColorSpace.ACES,
                output_bps=16
            )

        # Scale the 16-bit samples to floats in 0..1.
        normalized = pixels.astype(np.float32) / 65535.0

        # Move the last axis to the front (presumably HWC -> CHW; confirm
        # against rawpy's output layout).
        channels_first = torch.tensor(normalized, dtype=torch.float32).permute(2, 0, 1)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        image = channels_first.to(device)

        # Evaluated in insertion order: kurtosis, msd, dynamic_range, entropy.
        metrics = {
            "kurtosis": calculate_kurtosis,
            "msd": calculate_msd,
            "dynamic_range": calculate_dynamic_range,
            "entropy": calculate_entropy,
        }
        return {name: metric(image).item() for name, metric in metrics.items()}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.post("/image/jpg-preview")
async def create_jpg_preview(file: UploadFile = File(...)):
    """Convert an uploaded raw image to a base64-encoded JPEG preview.

    Returns:
        A dict with ``jpg_preview`` (base64 text of the JPEG bytes),
        ``content_type`` (always ``image/jpeg``) and ``filename`` (the
        upload's name with everything after its last dot replaced by
        ``jpg``).

    Raises:
        HTTPException: 400 with the underlying error message if any step
            fails.
    """
    try:
        raw_upload = await file.read()
        preview_bytes = convert_dng_to_jpg(raw_upload)
        preview_text = base64.b64encode(preview_bytes).decode('utf-8')

        # Drop only the final extension; names without a dot are kept whole.
        stem = file.filename.rsplit('.', 1)[0]

        response = {
            "jpg_preview": preview_text,
            "content_type": "image/jpeg",
            "filename": f"{stem}.jpg",
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


# Endpoint for metadata retrieval
@router.post("/image/metadata")
async def get_image_metadata(file: UploadFile = File(...)):
    """Return the important metadata fields of an uploaded image.

    Metadata is read from the upload itself. When the filename ends in
    ``.dng`` (case-insensitive) the file is also converted to JPEG and the
    metadata read from that conversion is merged on top, overriding
    duplicate keys.

    Returns:
        The filtered metadata dict produced by ``check_metadata``.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised (for example
            ``check_metadata`` finding GPS data); 500 for any other failure.
    """
    try:
        payload = await file.read()
        treat_as_dng = file.filename.lower().endswith('.dng')
        combined = extract_metadata(payload, treat_as_dng)

        if treat_as_dng:
            converted = convert_dng_to_jpg(payload)
            combined.update(extract_metadata(converted, False))

        return check_metadata(combined)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
14 changes: 13 additions & 1 deletion modules/odr_api/odr_api/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from odr_api.api.endpoints import user_router, team_router, content_router, annotation_router, auth_router, embedding_router, health_router
from odr_api.api.endpoints import (
user_router,
team_router,
content_router,
annotation_router,
auth_router,
embedding_router,
health_router,
image_router,
hugging_face_router
)
from odr_core.config import settings
import uvicorn

Expand Down Expand Up @@ -32,6 +42,8 @@ def test_communication():
app.include_router(auth_router, prefix=settings.API_V1_STR)
app.include_router(embedding_router, prefix=settings.API_V1_STR)
app.include_router(health_router, prefix=settings.API_V1_STR)
app.include_router(image_router, prefix=settings.API_V1_STR)
app.include_router(hugging_face_router, prefix=settings.API_V1_STR)

if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=31100, reload=True)
9 changes: 9 additions & 0 deletions modules/odr_api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@ psycopg2-binary==2.9.9
python-multipart==0.0.9
httpx==0.27.0
fastapi-sso==0.15.0

# hdr statistics
torch==2.4.0
torchvision

# dng conversions and metadata
exif==1.6.0
imageio==2.35.1
rawpy==0.22.0
6 changes: 5 additions & 1 deletion modules/odr_core/odr_core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Settings(BaseSettings):
# Models
MODEL_CACHE_DIR: str

# OUATH
# OAUTH
GOOGLE_CLIENT_ID: str
GOOGLE_CLIENT_SECRET: str

Expand All @@ -68,6 +68,10 @@ class Settings(BaseSettings):
CONTENT_EMBEDDING_DIMENSION: int = 512
ANNOTATION_EMBEDDING_DIMENSION: int = 384

# Hugging Face
HF_TOKEN: str
HF_HDR_DATASET_NAME: str

class Config:
env_file = ".env"
env_file_encoding = "utf-8"
Expand Down

0 comments on commit 63037d5

Please sign in to comment.