Source code for kerb.cache.backends

"""Cache backend implementations.

This module provides concrete cache implementations:
- MemoryCache: Fast in-memory LRU cache
- DiskCache: Persistent disk-based cache
- TieredCache: Two-tier memory + disk cache
- LLMCache: High-level wrapper for LLM applications
"""

import hashlib
import json
import os
import pickle
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from .strategies import generate_embedding_key, generate_prompt_key
from .types import BaseCache, CacheEntry, CacheStats

# ============================================================================
# In-Memory Cache
# ============================================================================


[docs] class MemoryCache(BaseCache): """In-memory cache with LRU eviction and TTL support."""
[docs] def __init__( self, max_size: Optional[int] = 1000, default_ttl: Optional[float] = None ): """Initialize memory cache. Args: max_size: Maximum number of entries default_ttl: Default TTL in seconds Example: >>> cache = MemoryCache(max_size=100, default_ttl=3600) >>> cache.set("key", "value") >>> cache.get("key") 'value' """ super().__init__(max_size, default_ttl) self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
[docs] def get(self, key: str) -> Optional[Any]: """Get value from cache.""" self.stats.total_requests += 1 if key not in self._cache: self.stats.misses += 1 return None entry = self._cache[key] # Check if expired if entry.is_expired(): self.delete(key) self.stats.misses += 1 return None # Update access metadata entry.last_accessed = time.time() entry.access_count += 1 # Move to end (LRU) self._cache.move_to_end(key) self.stats.hits += 1 return entry.value
[docs] def set( self, key: str, value: Any, ttl: Optional[float] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: """Set value in cache.""" # Use default TTL if not specified if ttl is None: ttl = self.default_ttl # Check if we need to evict if self.max_size is not None and key not in self._cache: if len(self._cache) >= self.max_size: self._evict_oldest() # Create entry now = time.time() entry = CacheEntry( key=key, value=value, created_at=now, last_accessed=now, access_count=0, ttl=ttl, metadata=metadata or {}, ) self._cache[key] = entry self._cache.move_to_end(key) self.stats.size = len(self._cache)
[docs] def delete(self, key: str) -> bool: """Delete key from cache.""" if key in self._cache: del self._cache[key] self.stats.size = len(self._cache) return True return False
[docs] def clear(self) -> None: """Clear all cache entries.""" self._cache.clear() self.stats.size = 0
[docs] def exists(self, key: str) -> bool: """Check if key exists and is not expired.""" if key not in self._cache: return False entry = self._cache[key] if entry.is_expired(): self.delete(key) return False return True
[docs] def size(self) -> int: """Get current cache size.""" return len(self._cache)
[docs] def keys(self) -> List[str]: """Get all cache keys.""" return list(self._cache.keys())
def _evict_oldest(self) -> None: """Evict the oldest (least recently used) entry.""" if self._cache: self._cache.popitem(last=False) self.stats.evictions += 1 self.stats.size = len(self._cache)
[docs] def get_entry(self, key: str) -> Optional[CacheEntry]: """Get full cache entry with metadata.""" if key not in self._cache: return None entry = self._cache[key] if entry.is_expired(): self.delete(key) return None return entry
[docs] def get_stats(self) -> CacheStats: """Get cache statistics.""" return self.stats
[docs] def reset_stats(self) -> None: """Reset cache statistics.""" self.stats = CacheStats(size=len(self._cache))
# ============================================================================ # Disk Cache # ============================================================================
[docs] class DiskCache(BaseCache): """Persistent disk-based cache."""
[docs] def __init__( self, cache_dir: str = ".cache", max_size: Optional[int] = None, default_ttl: Optional[float] = None, serializer: str = "pickle", ): """Initialize disk cache. Args: cache_dir: Directory to store cache files max_size: Maximum number of entries default_ttl: Default TTL in seconds serializer: Serialization format ('pickle' or 'json') Example: >>> cache = DiskCache(cache_dir=".cache/llm") >>> cache.set("key", {"data": "value"}) >>> cache.get("key") {'data': 'value'} """ super().__init__(max_size, default_ttl) self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self.serializer = serializer # Metadata file self.metadata_file = self.cache_dir / "_metadata.json" self._metadata: Dict[str, Dict[str, Any]] = self._load_metadata() # Clean expired entries on init self._clean_expired()
def _load_metadata(self) -> Dict[str, Dict[str, Any]]: """Load metadata from disk.""" if self.metadata_file.exists(): with open(self.metadata_file, "r") as f: return json.load(f) return {} def _save_metadata(self) -> None: """Save metadata to disk.""" with open(self.metadata_file, "w") as f: json.dump(self._metadata, f) def _get_cache_path(self, key: str) -> Path: """Get file path for a cache key.""" # Use hash of key as filename to avoid filesystem issues key_hash = hashlib.md5(key.encode()).hexdigest() return self.cache_dir / f"{key_hash}.cache"
[docs] def get(self, key: str) -> Optional[Any]: """Get value from cache.""" self.stats.total_requests += 1 if key not in self._metadata: self.stats.misses += 1 return None entry_meta = self._metadata[key] # Check if expired ttl = entry_meta.get("ttl") if ttl is not None: created_at = entry_meta["created_at"] if time.time() - created_at > ttl: self.delete(key) self.stats.misses += 1 return None # Load from disk cache_path = self._get_cache_path(key) if not cache_path.exists(): # Metadata exists but file doesn't - clean up del self._metadata[key] self._save_metadata() self.stats.misses += 1 return None try: if self.serializer == "pickle": with open(cache_path, "rb") as f: value = pickle.load(f) else: # json with open(cache_path, "r") as f: value = json.load(f) # Update access metadata entry_meta["last_accessed"] = time.time() entry_meta["access_count"] = entry_meta.get("access_count", 0) + 1 self._save_metadata() self.stats.hits += 1 return value except Exception: # Failed to load - clean up self.delete(key) self.stats.misses += 1 return None
[docs] def set( self, key: str, value: Any, ttl: Optional[float] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: """Set value in cache.""" # Use default TTL if not specified if ttl is None: ttl = self.default_ttl # Check if we need to evict if self.max_size is not None and key not in self._metadata: if len(self._metadata) >= self.max_size: self._evict_oldest() # Save to disk cache_path = self._get_cache_path(key) try: if self.serializer == "pickle": with open(cache_path, "wb") as f: pickle.dump(value, f) else: # json with open(cache_path, "w") as f: json.dump(value, f) # Update metadata now = time.time() self._metadata[key] = { "created_at": now, "last_accessed": now, "access_count": 0, "ttl": ttl, "metadata": metadata or {}, } self._save_metadata() self.stats.size = len(self._metadata) except Exception as e: # Failed to save - clean up if cache_path.exists(): cache_path.unlink() raise e
[docs] def delete(self, key: str) -> bool: """Delete key from cache.""" if key in self._metadata: # Delete file cache_path = self._get_cache_path(key) if cache_path.exists(): cache_path.unlink() # Delete metadata del self._metadata[key] self._save_metadata() self.stats.size = len(self._metadata) return True return False
[docs] def clear(self) -> None: """Clear all cache entries.""" # Delete all cache files for key in list(self._metadata.keys()): cache_path = self._get_cache_path(key) if cache_path.exists(): cache_path.unlink() # Clear metadata self._metadata.clear() self._save_metadata() self.stats.size = 0
[docs] def exists(self, key: str) -> bool: """Check if key exists and is not expired.""" if key not in self._metadata: return False entry_meta = self._metadata[key] ttl = entry_meta.get("ttl") if ttl is not None: created_at = entry_meta["created_at"] if time.time() - created_at > ttl: self.delete(key) return False return True
[docs] def size(self) -> int: """Get current cache size.""" return len(self._metadata)
[docs] def keys(self) -> List[str]: """Get all cache keys.""" return list(self._metadata.keys())
def _evict_oldest(self) -> None: """Evict the oldest entry by last access time.""" if not self._metadata: return oldest_key = min( self._metadata.keys(), key=lambda k: self._metadata[k]["last_accessed"] ) self.delete(oldest_key) self.stats.evictions += 1 def _clean_expired(self) -> None: """Remove all expired entries.""" expired_keys = [] now = time.time() for key, meta in self._metadata.items(): ttl = meta.get("ttl") if ttl is not None: created_at = meta["created_at"] if now - created_at > ttl: expired_keys.append(key) for key in expired_keys: self.delete(key)
[docs] def get_stats(self) -> CacheStats: """Get cache statistics.""" return self.stats
[docs] def reset_stats(self) -> None: """Reset cache statistics.""" self.stats = CacheStats(size=len(self._metadata))
# ============================================================================ # Tiered Cache (Memory + Disk) # ============================================================================
[docs] class TieredCache(BaseCache): """Two-tier cache: fast memory cache backed by persistent disk cache."""
[docs] def __init__( self, memory_max_size: int = 100, disk_cache_dir: str = ".cache", disk_max_size: Optional[int] = None, default_ttl: Optional[float] = None, ): """Initialize tiered cache. Args: memory_max_size: Maximum entries in memory cache disk_cache_dir: Directory for disk cache disk_max_size: Maximum entries in disk cache default_ttl: Default TTL in seconds Example: >>> cache = TieredCache(memory_max_size=50, disk_cache_dir=".cache") >>> cache.set("key", "value") >>> cache.get("key") # Fast memory access 'value' """ super().__init__(max_size=None, default_ttl=default_ttl) self.memory_cache = MemoryCache( max_size=memory_max_size, default_ttl=default_ttl ) self.disk_cache = DiskCache( cache_dir=disk_cache_dir, max_size=disk_max_size, default_ttl=default_ttl )
[docs] def get(self, key: str) -> Optional[Any]: """Get value from cache (memory first, then disk).""" # Try memory first value = self.memory_cache.get(key) if value is not None: return value # Try disk value = self.disk_cache.get(key) if value is not None: # Promote to memory cache self.memory_cache.set(key, value) return value return None
[docs] def set( self, key: str, value: Any, ttl: Optional[float] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: """Set value in both caches.""" if ttl is None: ttl = self.default_ttl self.memory_cache.set(key, value, ttl, metadata) self.disk_cache.set(key, value, ttl, metadata)
[docs] def delete(self, key: str) -> bool: """Delete key from both caches.""" mem_deleted = self.memory_cache.delete(key) disk_deleted = self.disk_cache.delete(key) return mem_deleted or disk_deleted
[docs] def clear(self) -> None: """Clear both caches.""" self.memory_cache.clear() self.disk_cache.clear()
[docs] def exists(self, key: str) -> bool: """Check if key exists in either cache.""" return self.memory_cache.exists(key) or self.disk_cache.exists(key)
[docs] def size(self) -> int: """Get total unique keys across both caches.""" mem_keys = set(self.memory_cache.keys()) disk_keys = set(self.disk_cache.keys()) return len(mem_keys | disk_keys)
[docs] def keys(self) -> List[str]: """Get all unique cache keys.""" mem_keys = set(self.memory_cache.keys()) disk_keys = set(self.disk_cache.keys()) return list(mem_keys | disk_keys)
[docs] def get_stats(self) -> Dict[str, CacheStats]: """Get statistics for both caches.""" return { "memory": self.memory_cache.get_stats(), "disk": self.disk_cache.get_stats(), }
[docs] def reset_stats(self) -> None: """Reset statistics for both caches.""" self.memory_cache.reset_stats() self.disk_cache.reset_stats()
# ============================================================================ # LLM-Specific Cache Wrapper # ============================================================================
[docs] class LLMCache: """High-level cache wrapper for LLM applications."""
[docs] def __init__( self, backend: Optional[BaseCache] = None, cost_per_token: float = 0.00001, # Default: ~$0.01 per 1K tokens avg_tokens_per_request: int = 1000, avg_response_time: float = 2.0, # seconds ): """Initialize LLM cache. Args: backend: Cache backend to use (defaults to MemoryCache) cost_per_token: Cost per token for cost tracking avg_tokens_per_request: Average tokens per request avg_response_time: Average response time in seconds Example: >>> cache = LLMCache() >>> response = cache.get_or_compute( ... key="prompt:123", ... compute_fn=lambda: call_llm("What is AI?"), ... cost=0.001 ... ) """ self.backend = backend or MemoryCache() self.cost_per_token = cost_per_token self.avg_tokens_per_request = avg_tokens_per_request self.avg_response_time = avg_response_time
[docs] def cache_prompt( self, prompt: str, response: str, model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, ttl: Optional[float] = None, cost: Optional[float] = None, **kwargs, ) -> str: """Cache an LLM prompt and response. Args: prompt: The prompt text response: The LLM response model: Model name temperature: Temperature setting max_tokens: Max tokens setting ttl: Time to live in seconds cost: Actual cost of the request **kwargs: Additional parameters Returns: str: The cache key Example: >>> key = cache.cache_prompt( ... prompt="What is AI?", ... response="AI is...", ... model="gpt-4o", ... cost=0.001 ... ) """ key = generate_prompt_key(prompt, model, temperature, max_tokens, **kwargs) metadata = {} if cost is not None: metadata["cost"] = cost self.backend.set(key, response, ttl=ttl, metadata=metadata) return key
[docs] def get_cached_prompt( self, prompt: str, model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, **kwargs, ) -> Optional[str]: """Get cached LLM response for a prompt. Args: prompt: The prompt text model: Model name temperature: Temperature setting max_tokens: Max tokens setting **kwargs: Additional parameters Returns: Optional[str]: Cached response or None Example: >>> response = cache.get_cached_prompt( ... prompt="What is AI?", ... model="gpt-4o" ... ) """ key = generate_prompt_key(prompt, model, temperature, max_tokens, **kwargs) cached = self.backend.get(key) # Track cost savings if cached is not None and isinstance(self.backend, MemoryCache): entry = self.backend.get_entry(key) if entry: cost = entry.metadata.get( "cost", self.cost_per_token * self.avg_tokens_per_request ) self.backend.stats.estimated_cost_saved += cost self.backend.stats.estimated_time_saved += self.avg_response_time return cached
[docs] def cache_embedding( self, text: str, embedding: List[float], model: Optional[str] = None, ttl: Optional[float] = None, cost: Optional[float] = None, **kwargs, ) -> str: """Cache an embedding. Args: text: The text that was embedded embedding: The embedding vector model: Model name ttl: Time to live in seconds cost: Actual cost of the request **kwargs: Additional parameters Returns: str: The cache key Example: >>> key = cache.cache_embedding( ... text="Hello world", ... embedding=[0.1, 0.2, ...], ... model="text-embedding-3-small", ... cost=0.00001 ... ) """ key = generate_embedding_key(text, model, **kwargs) metadata = {} if cost is not None: metadata["cost"] = cost self.backend.set(key, embedding, ttl=ttl, metadata=metadata) return key
[docs] def get_cached_embedding( self, text: str, model: Optional[str] = None, **kwargs ) -> Optional[List[float]]: """Get cached embedding for text. Args: text: The text to get embedding for model: Model name **kwargs: Additional parameters Returns: Optional[List[float]]: Cached embedding or None Example: >>> embedding = cache.get_cached_embedding( ... text="Hello world", ... model="text-embedding-3-small" ... ) """ key = generate_embedding_key(text, model, **kwargs) cached = self.backend.get(key) # Track cost savings if cached is not None and isinstance(self.backend, MemoryCache): entry = self.backend.get_entry(key) if entry: cost = entry.metadata.get("cost", 0.0001) # Default embedding cost self.backend.stats.estimated_cost_saved += cost self.backend.stats.estimated_time_saved += 0.5 # Embeddings are faster return cached
[docs] def get_or_compute( self, key: str, compute_fn: Callable[[], Any], ttl: Optional[float] = None, cost: Optional[float] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Any: """Get from cache or compute if not found. Args: key: Cache key compute_fn: Function to compute value if not cached ttl: Time to live in seconds cost: Cost of computing the value metadata: Additional metadata Returns: Any: Cached or computed value Example: >>> result = cache.get_or_compute( ... key="expensive:computation", ... compute_fn=lambda: expensive_api_call(), ... ttl=3600, ... cost=0.01 ... ) """ # Try to get from cache value = self.backend.get(key) if value is not None: # Track savings if cost is not None and isinstance(self.backend, MemoryCache): self.backend.stats.estimated_cost_saved += cost return value # Compute value value = compute_fn() # Store in cache cache_metadata = metadata or {} if cost is not None: cache_metadata["cost"] = cost self.backend.set(key, value, ttl=ttl, metadata=cache_metadata) return value
[docs] def get_stats(self) -> CacheStats: """Get cache statistics.""" if isinstance(self.backend, (MemoryCache, DiskCache)): return self.backend.get_stats() elif isinstance(self.backend, TieredCache): stats_dict = self.backend.get_stats() # Combine stats from both caches combined = CacheStats() combined.hits = stats_dict["memory"].hits + stats_dict["disk"].hits combined.misses = stats_dict["memory"].misses + stats_dict["disk"].misses combined.evictions = ( stats_dict["memory"].evictions + stats_dict["disk"].evictions ) combined.total_requests = stats_dict["memory"].total_requests combined.estimated_cost_saved = ( stats_dict["memory"].estimated_cost_saved + stats_dict["disk"].estimated_cost_saved ) combined.estimated_time_saved = ( stats_dict["memory"].estimated_time_saved + stats_dict["disk"].estimated_time_saved ) combined.size = self.backend.size() return combined return self.backend.stats
[docs] def clear(self) -> None: """Clear all cache entries.""" self.backend.clear()
[docs] def invalidate_by_prefix(self, prefix: str) -> int: """Invalidate all keys with a given prefix. Args: prefix: Key prefix to invalidate Returns: int: Number of keys invalidated Example: >>> cache.invalidate_by_prefix("prompt:") 42 """ keys = self.backend.keys() invalidated = 0 for key in keys: if key.startswith(prefix): if self.backend.delete(key): invalidated += 1 return invalidated
[docs] def invalidate_by_pattern(self, pattern: Callable[[str], bool]) -> int: """Invalidate keys matching a pattern function. Args: pattern: Function that returns True for keys to invalidate Returns: int: Number of keys invalidated Example: >>> cache.invalidate_by_pattern(lambda k: "gpt-4o" in k) 15 """ keys = self.backend.keys() invalidated = 0 for key in keys: if pattern(key): if self.backend.delete(key): invalidated += 1 return invalidated