Source code for kerb.generation.generator

"""Main generation functions for LLM interaction.

This module provides the core generation functions that orchestrate calls
to different LLM providers.
"""

import asyncio
import os
import time
from typing import Callable, Dict, Iterator, List, Optional, Union

from kerb.core.types import Message, MessageRole

# Import from our reorganized modules
from .config import GenerationConfig, GenerationResponse, StreamChunk, Usage
from .enums import LLMProvider, ModelName
from .providers.anthropic import (_generate_anthropic,
                                  _generate_stream_anthropic)
from .providers.google import _generate_google, _generate_stream_google
# Import provider-specific functions
from .providers.openai import _generate_openai, _generate_stream_openai
from .utils import (CostTracker, RateLimiter, ResponseCache,
                    _global_cost_tracker, calculate_cost,
                    retry_with_exponential_backoff)


[docs] def generate( messages: Union[List[Message], List[Dict[str, str]], str], model: Optional[Union[str, ModelName]] = None, config: Optional[GenerationConfig] = None, api_key: Optional[str] = None, provider: Optional[LLMProvider] = None, use_cache: bool = True, cost_tracker: Optional[CostTracker] = None, track_cost: bool = False, rate_limiter: Optional[RateLimiter] = None, max_retries: int = 3, **kwargs, ) -> GenerationResponse: """Universal generator function - generate responses from any LLM provider. This is the main generation function that routes to the appropriate provider based on the model and provider parameters. Args: messages: Input messages (can be string, list of dicts, or list of Message objects) model: Model to use (ModelName enum or string for custom models). If not provided, must be specified in config. config: Generation configuration api_key: API key (if not provided, uses environment variable) provider: LLMProvider enum specifying which API to use use_cache: Whether to use response caching cost_tracker: Optional cost tracker instance track_cost: Whether to track costs in global tracker rate_limiter: Optional rate limiter instance max_retries: Maximum retry attempts for failed requests **kwargs: Additional config parameters Returns: GenerationResponse: The generated response Examples: >>> # Using ModelName enum >>> response = generate("Hello", model=ModelName.GPT_4O_MINI, provider=LLMProvider.OPENAI) >>> # Using custom model name >>> response = generate("Hello", model="my-custom-gpt", provider=LLMProvider.OPENAI) >>> # Different providers >>> response = generate("Hello", model=ModelName.CLAUDE_35_HAIKU, provider=LLMProvider.ANTHROPIC) """ if not messages: raise ValueError("Messages cannot be empty") # Determine the model to use if config is not None and model is None: # Use model from config model_str = config.model model_for_detection = config.model elif model is not None: # Convert ModelName enum to string for internal use model_str = model.value if isinstance(model, ModelName) else model model_for_detection = model else: raise ValueError( "Either 'model' parameter or 'config' with a model must be provided" ) # Convert string to messages if isinstance(messages, str): messages = [Message(role=MessageRole.USER, content=messages)] elif isinstance(messages, list) and messages and isinstance(messages[0], dict): messages = [ Message(role=m.get("role", "user"), content=m["content"]) for m in messages ] # Create or update config if config is None: config = GenerationConfig(model=model_str) elif model is not None: # If both config and model are provided, model parameter takes precedence config.model = model_str # Apply kwargs for key, value in kwargs.items(): if hasattr(config, key): setattr(config, key, value) # Validate provider if provider is None: raise ValueError( "Provider must be specified. Pass the provider parameter.\\n" "Example: generate('Hello', model='gpt-4o-mini', provider=LLMProvider.OPENAI)" ) # Validate API key if api_key is None: if provider == LLMProvider.OPENAI and not os.getenv("OPENAI_API_KEY"): raise ValueError("OpenAI API key not found") elif provider == LLMProvider.ANTHROPIC and not os.getenv("ANTHROPIC_API_KEY"): raise ValueError("Anthropic API key not found") elif provider == LLMProvider.GOOGLE and not os.getenv("GOOGLE_API_KEY"): raise ValueError("Google API key not found") # Check cache if use_cache: cache = ResponseCache() cached_response = cache.get(messages, config) if cached_response: return cached_response # Rate limiting if rate_limiter: estimated_tokens = sum(len(m.content.split()) * 1.3 for m in messages) rate_limiter.wait_if_needed(int(estimated_tokens)) # Generate def _generate(): start_time = time.time() if provider == LLMProvider.OPENAI: response = _generate_openai(messages, config, api_key) elif provider == LLMProvider.ANTHROPIC: response = _generate_anthropic(messages, config, api_key) elif provider == LLMProvider.GOOGLE: response = _generate_google(messages, config, api_key) else: response = _generate_mock(messages, config, provider) response.latency = time.time() - start_time return response response = retry_with_exponential_backoff(_generate, max_retries=max_retries) response.cost = calculate_cost(model_for_detection, response.usage) # Track cost if track_cost or cost_tracker: tracker = cost_tracker if cost_tracker else _global_cost_tracker tracker.add_request(model_str, response.usage, response.cost) # Cache if use_cache: cache.set(messages, config, response) return response
[docs] def generate_stream( messages: Union[List[Message], List[Dict[str, str]], str], model: Optional[Union[str, ModelName]] = None, config: Optional[GenerationConfig] = None, api_key: Optional[str] = None, provider: Optional[LLMProvider] = None, callback: Optional[Callable[[StreamChunk], None]] = None, **kwargs, ) -> Iterator[StreamChunk]: """Generate streaming response from any LLM provider. Args: messages: Input messages (can be string, list of dicts, or list of Message objects) model: Model to use (ModelName enum or string for custom models). If not provided, must be specified in config. config: Generation configuration api_key: API key (if not provided, uses environment variable) provider: LLMProvider enum specifying which API to use callback: Optional callback function for each chunk **kwargs: Additional config parameters Yields: StreamChunk: Chunks of the generated response """ # Determine the model to use if config is not None and model is None: # Use model from config model_str = config.model model_for_detection = config.model elif model is not None: # Convert ModelName enum to string for internal use model_str = model.value if isinstance(model, ModelName) else model model_for_detection = model else: raise ValueError( "Either 'model' parameter or 'config' with a model must be provided" ) if isinstance(messages, str): messages = [Message(role=MessageRole.USER, content=messages)] elif isinstance(messages, list) and messages and isinstance(messages[0], dict): messages = [ Message(role=m.get("role", "user"), content=m["content"]) for m in messages ] if config is None: config = GenerationConfig(model=model_str, stream=True) else: if model is not None: config.model = model_str config.stream = True for key, value in kwargs.items(): if hasattr(config, key): setattr(config, key, value) # Validate provider if provider is None: raise ValueError( "Provider must be specified. Pass the provider parameter.\\n" "Example: generate_stream('Hello', model='gpt-4o-mini', provider=LLMProvider.OPENAI)" ) if provider == LLMProvider.OPENAI: yield from _generate_stream_openai(messages, config, api_key, callback) elif provider == LLMProvider.ANTHROPIC: yield from _generate_stream_anthropic(messages, config, api_key, callback) elif provider == LLMProvider.GOOGLE: yield from _generate_stream_google(messages, config, api_key, callback) else: response = generate( messages, model, config, api_key, provider=provider, **kwargs ) chunk = StreamChunk( content=response.content, finish_reason=response.finish_reason, model=model_str, ) if callback: callback(chunk) yield chunk
[docs] def generate_batch( prompts: List[Union[str, List[Message]]], model: Optional[Union[str, ModelName]] = None, config: Optional[GenerationConfig] = None, api_key: Optional[str] = None, provider: Optional[LLMProvider] = None, max_concurrent: int = 5, show_progress: bool = False, **kwargs, ) -> List[GenerationResponse]: """Generate batch responses. Args: prompts: List of prompts to process model: Model to use (ModelName enum or string for custom models). If not provided, must be specified in config. config: Generation configuration api_key: API key (if not provided, uses environment variable) provider: LLMProvider enum specifying which API to use max_concurrent: Maximum concurrent requests show_progress: Whether to show progress **kwargs: Additional config parameters Returns: List[GenerationResponse]: List of generated responses """ if model is None and config is None: raise ValueError( "Either 'model' parameter or 'config' with a model must be provided" ) async def _batch(): sem = asyncio.Semaphore(max_concurrent) async def _one(prompt): async with sem: return await asyncio.to_thread( generate, prompt, model=model, config=config, api_key=api_key, provider=provider, **kwargs, ) tasks = [_one(p) for p in prompts] if show_progress: results = [] for i, task in enumerate(asyncio.as_completed(tasks)): results.append(await task) print(f"Completed {i+1}/{len(prompts)}", end="\r") print() return results return await asyncio.gather(*tasks) return asyncio.run(_batch())
[docs] async def generate_async( messages: Union[List[Message], List[Dict[str, str]], str], model: Optional[Union[str, ModelName]] = None, config: Optional[GenerationConfig] = None, api_key: Optional[str] = None, provider: Optional[LLMProvider] = None, use_cache: bool = True, cost_tracker: Optional[CostTracker] = None, track_cost: bool = False, max_retries: int = 3, **kwargs, ) -> GenerationResponse: """Async generation. Args: messages: Input messages (can be string, list of dicts, or list of Message objects) model: Model to use (ModelName enum or string for custom models). If not provided, must be specified in config. config: Generation configuration api_key: API key (if not provided, uses environment variable) provider: LLMProvider enum specifying which API to use use_cache: Whether to use response caching cost_tracker: Optional cost tracker instance track_cost: Whether to track costs in global tracker max_retries: Maximum retry attempts for failed requests **kwargs: Additional config parameters Returns: GenerationResponse: The generated response """ return await asyncio.to_thread( generate, messages, model=model, config=config, api_key=api_key, provider=provider, use_cache=use_cache, cost_tracker=cost_tracker, track_cost=track_cost, max_retries=max_retries, **kwargs, )
def _generate_mock( messages: List[Message], config: GenerationConfig, provider: LLMProvider ) -> GenerationResponse: """Mock generation. Args: messages: Input messages config: Generation configuration provider: The detected provider to use in the response Returns: GenerationResponse: Mock response with the specified provider """ content = f"Mock response for model {config.model}" prompt_tokens = sum(len(m.content.split()) * 1.3 for m in messages) completion_tokens = len(content.split()) * 1.3 usage = Usage( prompt_tokens=int(prompt_tokens), completion_tokens=int(completion_tokens), total_tokens=int(prompt_tokens + completion_tokens), ) return GenerationResponse( content=content, model=config.model, provider=provider, # Use the provider passed in (already detected) usage=usage, finish_reason="stop", metadata={"mock": True}, )
[docs] class Generator: """Universal LLM generator - easily switch between models and providers. This class provides a convenient stateful interface for LLM generation with support for both enum-based and string-based model specification. It makes it easy to switch between different models and providers without changing your code structure. Examples: >>> # Using ModelName enum >>> gen = Generator(model=ModelName.GPT_4O_MINI, provider=LLMProvider.OPENAI) >>> response = gen.generate("Hello!") >>> # Using custom model name >>> gen = Generator(model="my-custom-model", provider=LLMProvider.OPENAI) >>> response = gen.generate("Hello!") >>> # Easy model switching >>> gen_gpt = Generator(model=ModelName.GPT_4O_MINI, provider=LLMProvider.OPENAI, temperature=0.7) >>> gen_claude = Generator(model=ModelName.CLAUDE_35_HAIKU, provider=LLMProvider.ANTHROPIC, temperature=0.7) """
[docs] def __init__( self, model: Union[str, ModelName], api_key: Optional[str] = None, provider: Optional[LLMProvider] = None, cost_tracker: Optional[CostTracker] = None, **default_config, ): """Initialize the universal Generator. Args: model: Model to use (ModelName enum or string for custom models) api_key: API key (if not provided, uses environment variable) provider: LLMProvider enum specifying which API to use cost_tracker: Optional cost tracker instance **default_config: Default configuration parameters (temperature, max_tokens, etc.) """ self.model = model self.api_key = api_key self.provider = provider self.cost_tracker = cost_tracker self.default_config = default_config
[docs] def generate( self, messages: Union[List[Message], List[Dict[str, str]], str], **kwargs ) -> GenerationResponse: """Generate a response. Args: messages: Input messages **kwargs: Override default config parameters Returns: GenerationResponse: The generated response """ config = {**self.default_config, **kwargs} return generate( messages, model=self.model, api_key=self.api_key, provider=self.provider, cost_tracker=self.cost_tracker, **config, )
[docs] def stream( self, messages: Union[List[Message], List[Dict[str, str]], str], **kwargs ) -> Iterator[StreamChunk]: """Generate a streaming response. Args: messages: Input messages **kwargs: Override default config parameters Yields: StreamChunk: Chunks of the generated response """ config = {**self.default_config, **kwargs} return generate_stream( messages, model=self.model, api_key=self.api_key, provider=self.provider, **config, )
[docs] def batch( self, prompts: List[Union[str, List[Message]]], **kwargs ) -> List[GenerationResponse]: """Generate batch responses. Args: prompts: List of prompts to process **kwargs: Override default config parameters Returns: List[GenerationResponse]: List of generated responses """ config = {**self.default_config, **kwargs} return generate_batch( prompts, model=self.model, api_key=self.api_key, provider=self.provider, **config, )