"""Vision processing for images and vision models.
This module provides image processing, vision model analysis, and multimodal embeddings.
"""
import base64
import io
import math
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from ..types import (EmbeddingModelMultimodal, ImageFormat, ImageInfo,
VisionAnalysis, VisionModel)
from ..utilities import get_mime_type
if TYPE_CHECKING:
from kerb.core.enums import Device
# ============================================================================
# Image Processing Functions
# ============================================================================
[docs]
def load_image(file_path: str) -> Any:
    """Open an image file with Pillow.

    Pillow is imported inside the function, so it is only required when an
    image is actually loaded.

    Args:
        file_path: Location of the image on disk

    Returns:
        PIL.Image: The object returned by ``PIL.Image.open``

    Raises:
        ImportError: If Pillow cannot be imported
        FileNotFoundError: If ``file_path`` does not exist

    Examples:
        >>> img = load_image("photo.jpg")
        >>> img.size
        (1920, 1080)
    """
    try:
        from PIL import Image as pil_image
    except ImportError:
        raise ImportError(
            "PIL (Pillow) is required for image processing. Install with: pip install Pillow"
        )

    path_exists = os.path.exists(file_path)
    if path_exists:
        return pil_image.open(file_path)
    raise FileNotFoundError(f"Image file not found: {file_path}")
[docs]
def get_image_info(file_path: str) -> ImageInfo:
    """Get detailed information about an image.

    The image is opened only long enough to read its header attributes and is
    closed before returning, so no file handle is left open.

    Args:
        file_path: Path to the image file

    Returns:
        ImageInfo: Width, height, format, colour mode, on-disk size in bytes,
        width/height aspect ratio (0.0 when the height is 0) and a shallow
        copy of the image's ``info`` metadata dict.

    Raises:
        ImportError: If Pillow is not installed (raised by ``load_image``)
        FileNotFoundError: If the file does not exist (raised by ``load_image``)

    Examples:
        >>> info = get_image_info("photo.jpg")
        >>> print(f"{info.width}x{info.height}")
        1920x1080
    """
    format_map = {
        "JPEG": ImageFormat.JPEG,
        "PNG": ImageFormat.PNG,
        "WEBP": ImageFormat.WEBP,
        "GIF": ImageFormat.GIF,
        "BMP": ImageFormat.BMP,
        "TIFF": ImageFormat.TIFF,
        "SVG": ImageFormat.SVG,
    }

    # Use the image as a context manager so the underlying file is closed
    # once the header attributes have been read.
    with load_image(file_path) as img:
        size_bytes = os.path.getsize(file_path)
        # NOTE(review): formats missing from format_map (including a None
        # format) are reported as JPEG. Kept for backward compatibility.
        img_format = format_map.get(img.format, ImageFormat.JPEG)
        aspect_ratio = img.width / img.height if img.height > 0 else 0.0
        metadata = dict(img.info) if hasattr(img, "info") else {}

        return ImageInfo(
            width=img.width,
            height=img.height,
            format=img_format,
            mode=img.mode,
            size_bytes=size_bytes,
            aspect_ratio=aspect_ratio,
            metadata=metadata,
        )
[docs]
def image_to_base64(file_path: str, include_prefix: bool = True) -> str:
    """Read an image file and return its contents base64-encoded.

    Args:
        file_path: Path to the image file
        include_prefix: When True, wrap the result as a ``data:`` URI whose
            MIME type comes from ``get_mime_type(file_path)``

    Returns:
        str: The base64 text, optionally prefixed with ``data:<mime>;base64,``

    Examples:
        >>> b64 = image_to_base64("photo.jpg")
        >>> b64[:30]
        'data:image/jpeg;base64,/9j/4A'
    """
    with open(file_path, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode("utf-8")

    if not include_prefix:
        return encoded
    return f"data:{get_mime_type(file_path)};base64,{encoded}"
[docs]
def base64_to_image(b64_string: str, output_path: str) -> str:
    """Decode base64 image data and write the bytes to a file.

    Args:
        b64_string: Base64 text, either bare or as a ``data:...;base64,`` URI
        output_path: Destination file path (created or overwritten)

    Returns:
        str: ``output_path``, unchanged

    Examples:
        >>> base64_to_image(b64_data, "output.jpg")
        'output.jpg'
    """
    payload = b64_string
    # A data URI carries a "data:<mime>;base64," header before the first comma.
    if payload.startswith("data:") and "," in payload:
        payload = payload.partition(",")[2]

    # Decode before opening the file so bad input never truncates the target.
    raw_bytes = base64.b64decode(payload)
    with open(output_path, "wb") as sink:
        sink.write(raw_bytes)
    return output_path
[docs]
def calculate_image_hash(file_path: str, hash_size: int = 8) -> str:
    """Calculate a perceptual difference hash (dHash) of an image.

    The image is converted to grayscale and resized to
    ``(hash_size + 1) x hash_size``. Each pixel is then compared with its
    right-hand neighbour, giving ``hash_size * hash_size`` bits. Bits are
    packed four per hex digit, and the first bit of each group is the
    least-significant bit of that digit.

    Args:
        file_path: Path to the image file
        hash_size: Number of comparisons per row and number of rows
            (the default 8 gives a 64-bit hash, i.e. 16 hex digits)

    Returns:
        str: Hexadecimal hash string

    Raises:
        ValueError: If ``hash_size`` is less than 1
        ImportError: If Pillow is not installed (raised by ``load_image``)
        FileNotFoundError: If the file does not exist (raised by ``load_image``)

    Examples:
        >>> hash1 = calculate_image_hash("photo1.jpg")
        >>> hash2 = calculate_image_hash("photo1_copy.jpg")
        >>> hash1 == hash2  # Identical image content gives identical hashes
        True

    Near-duplicate images typically differ in only a few bits, so compare
    hashes by Hamming distance rather than strict equality.
    """
    if hash_size < 1:
        raise ValueError(f"hash_size must be a positive integer, got {hash_size}")

    # Close the source file once the small grayscale copy has been produced.
    with load_image(file_path) as img:
        # resample=1 is Pillow's LANCZOS filter. The literal is kept so that
        # previously stored hashes stay identical.
        small = img.convert("L").resize((hash_size + 1, hash_size), resample=1)

    pixels = list(small.getdata())
    row_width = hash_size + 1

    # True where a pixel is brighter than its right-hand neighbour.
    bits = [
        pixels[row * row_width + col] > pixels[row * row_width + col + 1]
        for row in range(hash_size)
        for col in range(hash_size)
    ]

    return "".join(
        format(sum(1 << j for j, bit in enumerate(bits[i : i + 4]) if bit), "x")
        for i in range(0, len(bits), 4)
    )
# ============================================================================
# Vision Model Integration
# ============================================================================
[docs]
def analyze_image_with_vision_model(
    image_path: str,
    prompt: str,
    model: Union[str, VisionModel] = VisionModel.GPT4O,
    api_key: Optional[str] = None,
    max_tokens: int = 300,
) -> VisionAnalysis:
    """Ask a vision-capable LLM a question about an image.

    The provider is chosen from the model name prefix: ``"gpt-4"`` goes to
    OpenAI, ``"claude-3"`` to Anthropic and ``"gemini"`` to Google.

    Args:
        image_path: Path to the image file
        prompt: Text prompt/question about the image
        model: Vision model, as a ``VisionModel`` member or a raw model name
        api_key: API key for the model provider (each backend falls back to
            its own environment variable when this is None)
        max_tokens: Maximum tokens in response

    Returns:
        VisionAnalysis: Analysis result with description and metadata

    Raises:
        ValueError: If the model name matches none of the known prefixes

    Examples:
        >>> analysis = analyze_image_with_vision_model(
        ...     "photo.jpg",
        ...     "What objects are in this image?"
        ... )
        >>> print(analysis.description)
        'The image contains a cat, a book, and a coffee mug on a table.'
    """
    model_str = model.value if isinstance(model, VisionModel) else model

    # Checked in order; the first matching prefix wins.
    routes = (
        ("gpt-4", _analyze_openai_vision),
        ("claude-3", _analyze_anthropic_vision),
        ("gemini", _analyze_google_vision),
    )
    for prefix, backend in routes:
        if model_str.startswith(prefix):
            return backend(image_path, prompt, model_str, api_key, max_tokens)

    raise ValueError(f"Unsupported vision model: {model_str}")
def _analyze_openai_vision(
    image_path: str, prompt: str, model: str, api_key: Optional[str], max_tokens: int
) -> VisionAnalysis:
    """Send one image and a text prompt to an OpenAI chat-completions model.

    The image is embedded as a base64 ``data:`` URI in an ``image_url``
    content part, after the text part, in a single user message.

    Args:
        image_path: Path to the image file
        prompt: Question or instruction about the image
        model: OpenAI model name
        api_key: API key; falls back to ``OPENAI_API_KEY`` when falsy
        max_tokens: Passed through as ``max_tokens`` to the API

    Returns:
        VisionAnalysis: Content of the first choice's message, with the model
        name and the usage dump (or ``{}``) in ``metadata``.

    Raises:
        ImportError: If the ``openai`` package is not installed
        ValueError: If no API key is available
    """
    try:
        import openai
    except ImportError:
        raise ImportError("openai is required. Install with: pip install openai")

    resolved_key = api_key or os.getenv("OPENAI_API_KEY")
    if not resolved_key:
        raise ValueError("OpenAI API key required")

    client = openai.OpenAI(api_key=resolved_key)

    data_uri = image_to_base64(image_path, include_prefix=True)
    user_parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": data_uri}},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_parts}],
        max_tokens=max_tokens,
    )

    usage = response.usage.model_dump() if response.usage else {}
    return VisionAnalysis(
        description=response.choices[0].message.content,
        metadata={"model": model, "usage": usage},
    )
def _analyze_anthropic_vision(
    image_path: str, prompt: str, model: str, api_key: Optional[str], max_tokens: int
) -> VisionAnalysis:
    """Ask an Anthropic Claude model about one image.

    The image is sent inline as a base64 ``image`` content block, followed by
    the text prompt, in a single user message.

    Args:
        image_path: Path to the image file
        prompt: Question or instruction about the image
        model: Anthropic model name
        api_key: API key; falls back to ``ANTHROPIC_API_KEY`` when falsy
        max_tokens: Maximum tokens in the reply

    Returns:
        VisionAnalysis: Text of the first content block of the reply, plus
        the model name and input/output token counts in ``metadata``.

    Raises:
        ImportError: If the ``anthropic`` package is not installed
        ValueError: If no API key is available
    """
    try:
        import anthropic
    except ImportError:
        raise ImportError("anthropic is required. Install with: pip install anthropic")

    resolved_key = api_key or os.getenv("ANTHROPIC_API_KEY")
    if not resolved_key:
        raise ValueError("Anthropic API key required")

    client = anthropic.Anthropic(api_key=resolved_key)

    with open(image_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    media_type = get_mime_type(image_path)

    image_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": encoded,
        },
    }
    text_block = {"type": "text", "text": prompt}

    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": [image_block, text_block]}],
    )

    token_usage = {
        "input_tokens": response.usage.input_tokens,
        "output_tokens": response.usage.output_tokens,
    }
    return VisionAnalysis(
        description=response.content[0].text,
        metadata={"model": model, "usage": token_usage},
    )
def _analyze_google_vision(
    image_path: str, prompt: str, model: str, api_key: Optional[str], max_tokens: int
) -> VisionAnalysis:
    """Analyze image using Google Gemini vision model.

    Args:
        image_path: Path to the image file
        prompt: Question or instruction about the image
        model: Gemini model name
        api_key: API key; falls back to ``GOOGLE_API_KEY`` when falsy
        max_tokens: Upper bound on generated tokens. It is forwarded as
            ``max_output_tokens`` in the generation config (previously it was
            silently ignored for Gemini).

    Returns:
        VisionAnalysis: The response text, with the model name and the first
        candidate's finish reason (None when there are no candidates) in
        ``metadata``.

    Raises:
        ImportError: If ``google-generativeai`` or Pillow is not installed
        ValueError: If no API key is available
        FileNotFoundError: If the image file does not exist (via ``load_image``)
    """
    try:
        import google.generativeai as genai
    except ImportError:
        raise ImportError(
            "google-generativeai is required. Install with: pip install google-generativeai"
        )

    api_key = api_key or os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("Google API key required")

    genai.configure(api_key=api_key)
    model_obj = genai.GenerativeModel(model)

    # load_image gives the same Pillow install hint as the rest of the module.
    # The context manager closes the file once the request has been built and sent.
    with load_image(image_path) as img:
        response = model_obj.generate_content(
            [prompt, img],
            generation_config={"max_output_tokens": max_tokens},
        )

    return VisionAnalysis(
        description=response.text,
        metadata={
            "model": model,
            "finish_reason": (
                response.candidates[0].finish_reason if response.candidates else None
            ),
        },
    )
# ============================================================================
# Multi-Modal Embeddings
# ============================================================================
[docs]
def embed_multimodal(
    content: Union[str, bytes],
    content_type: str,
    model: Union[
        str, EmbeddingModelMultimodal
    ] = EmbeddingModelMultimodal.CLIP_VIT_B_32,
    device: Union["Device", str] = "cpu",
) -> List[float]:
    """Embed an image, audio clip, or text with a multi-modal model.

    Args:
        content: Content to embed (file path for images/audio, text string for text)
        content_type: Type of content ("image", "audio", "text")
        model: ``EmbeddingModelMultimodal`` member or raw model name. Names
            beginning with ``"clip"`` use the CLIP backend, and ``"imagebind"``
            uses the ImageBind backend.
        device: Device to run model on (Device enum or string: "cpu", "cuda", "cuda:0", "cuda:1", "mps").
            It is checked by ``validate_enum_or_string`` before use.

    Returns:
        List[float]: The embedding vector produced by the selected backend

    Raises:
        ValueError: If the model name selects no known backend

    Examples:
        >>> from kerb.core.enums import Device
        >>> embedding = embed_multimodal("photo.jpg", "image", device=Device.CUDA)
        >>> len(embedding)
        512
    """
    from kerb.core.enums import Device, validate_enum_or_string

    if isinstance(model, EmbeddingModelMultimodal):
        model_name = model.value
    else:
        model_name = model

    # Accept a Device member or a plain string; the backends take a string.
    checked_device = validate_enum_or_string(device, Device, "device")
    device_name = (
        checked_device.value if isinstance(checked_device, Device) else checked_device
    )

    # NOTE(review): only names starting with "clip" reach the CLIP backend,
    # which passes the name to CLIPModel.from_pretrained. Hub ids such as
    # "openai/clip-..." would be rejected here. Confirm against the
    # EmbeddingModelMultimodal values.
    if model_name.startswith("clip"):
        return _embed_clip(content, content_type, model_name, device_name)
    if model_name == "imagebind":
        return _embed_imagebind(content, content_type, device_name)
    raise ValueError(f"Unsupported embedding model: {model_name}")
def _embed_clip(
    content: Union[str, bytes], content_type: str, model: str, device: str
) -> List[float]:
    """Generate a CLIP embedding for a single image or text input.

    Args:
        content: For ``"image"``, a path to the image file (opened via
            ``load_image``). For ``"text"``, the text to encode; ``bytes``
            are decoded as UTF-8.
        content_type: ``"image"`` or ``"text"``
        model: Model name/path passed to ``from_pretrained``
        device: Torch device string the model and inputs are moved to

    Returns:
        List[float]: The first row of the image or text feature tensor

    Raises:
        ImportError: If torch/transformers (or, for images, Pillow) are missing
        ValueError: If ``content_type`` is not ``"image"`` or ``"text"``
    """
    try:
        import torch
        from transformers import CLIPModel, CLIPProcessor
    except ImportError:
        raise ImportError(
            "transformers and torch required. Install with: pip install transformers torch"
        )

    # Reject unsupported content types before paying for a model load.
    if content_type not in ("image", "text"):
        raise ValueError(f"Content type {content_type} not supported for CLIP")

    # NOTE(review): model and processor are reloaded on every call. Consider
    # caching them if this function is called repeatedly.
    model_obj = CLIPModel.from_pretrained(model).to(device)
    processor = CLIPProcessor.from_pretrained(model)

    if content_type == "image":
        # Close the image file once the processor has produced its tensors.
        with load_image(content) as image:
            inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model_obj.get_image_features(**inputs)
    else:
        # The signature allows bytes; the tokenizer expects str.
        text = content.decode("utf-8") if isinstance(content, bytes) else content
        inputs = processor(text=text, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            features = model_obj.get_text_features(**inputs)

    return features[0].cpu().numpy().tolist()
def _embed_imagebind(
content: Union[str, bytes], content_type: str, device: str
) -> List[float]:
"""Generate ImageBind embeddings (supports image, audio, text)."""
try:
import torch
# ImageBind would be imported here
# from imagebind import data
# from imagebind.models import imagebind_model
except ImportError:
raise ImportError("imagebind required for ImageBind embeddings")
# Placeholder - actual implementation would use ImageBind
raise NotImplementedError("ImageBind implementation requires the imagebind package")
[docs]
def compute_multimodal_similarity(
    embedding1: List[float], embedding2: List[float]
) -> float:
    """Compute cosine similarity between two multi-modal embeddings.

    Args:
        embedding1: First embedding vector
        embedding2: Second embedding vector (must have the same length)

    Returns:
        float: Cosine similarity score in [-1, 1], or 0.0 when either vector
        has zero magnitude (including empty vectors)

    Raises:
        ValueError: If the two embeddings have different lengths

    Examples:
        >>> emb1 = embed_multimodal("photo1.jpg", "image")
        >>> emb2 = embed_multimodal("photo2.jpg", "image")
        >>> similarity = compute_multimodal_similarity(emb1, emb2)
        >>> print(f"Similarity: {similarity:.3f}")
        Similarity: 0.892
    """
    # zip() would silently truncate the longer vector while the magnitudes
    # below use every element, which produces a meaningless score.
    if len(embedding1) != len(embedding2):
        raise ValueError(
            f"Embedding dimensions differ: {len(embedding1)} != {len(embedding2)}"
        )

    dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
    mag1 = math.sqrt(sum(a * a for a in embedding1))
    mag2 = math.sqrt(sum(b * b for b in embedding2))

    # Avoid division by zero
    if mag1 == 0 or mag2 == 0:
        return 0.0

    # Float rounding can push the ratio marginally past +/-1; clamp so the
    # documented range always holds.
    return max(-1.0, min(1.0, dot_product / (mag1 * mag2)))