"""Base provider abstraction and registry for LLM generation.
This module provides the base class for all LLM providers and a registry
system for managing custom providers.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional
from kerb.core.types import Message
from .config import GenerationConfig, GenerationResponse, StreamChunk
# ============================================================================
# Provider Registry
# ============================================================================
# Module-level mapping from provider name to the registered provider instance.
# Written by register_provider() and read by get_provider()/list_providers().
_provider_registry: Dict[str, "BaseProvider"] = {}
def register_provider(name: str, provider: "BaseProvider") -> None:
    """Add a custom provider to the module-level registry.

    The provider is stored under ``name``. Registering another provider
    under a name that is already present replaces the earlier entry.

    Args:
        name: Provider name (used in model strings like
            "custom::model-name").
        provider: Provider instance to store.

    Examples:
        >>> from kerb.generation.base import register_provider
        >>> provider = MyCustomProvider(api_key="...")
        >>> register_provider("mycustom", provider)
        >>> # Now can use: generate(messages, model="mycustom::my-model")
    """
    _provider_registry[name] = provider
def get_provider(name: str) -> Optional["BaseProvider"]:
    """Look up a previously registered provider.

    Args:
        name: Provider name the instance was registered under.

    Returns:
        The registered provider instance, or ``None`` when nothing is
        registered under ``name``.
    """
    registered = _provider_registry.get(name)
    return registered
def list_providers() -> List[str]:
    """Return the names of all registered providers.

    Returns:
        A new list holding every key currently in the registry. Changing
        the returned list does not affect the registry itself.
    """
    # Iterating a dict yields its keys, so this matches list(d.keys()).
    return list(_provider_registry)
# ============================================================================
# Base Provider Interface
# ============================================================================
class BaseProvider(ABC):
    """Abstract interface shared by all LLM providers.

    A custom provider subclasses this class and implements the three
    abstract methods ``generate``, ``generate_stream`` and
    ``generate_async``. ``validate_config`` has a default implementation
    and may be overridden.

    Attributes:
        api_key: The API key passed to the constructor (may be ``None``).
        config: Dict of the provider-specific keyword arguments passed to
            the constructor.
    """

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        """Store the API key and any provider-specific options.

        Args:
            api_key: API key. The base class stores the value exactly as
                given; looking it up from the environment when it is
                ``None`` is left to the concrete provider.
            **kwargs: Provider-specific configuration, kept as
                ``self.config``.
        """
        self.config = kwargs
        self.api_key = api_key

    @abstractmethod
    def generate(
        self, messages: List[Message], config: GenerationConfig
    ) -> GenerationResponse:
        """Produce a complete response for a conversation.

        Args:
            messages: List of conversation messages.
            config: Generation configuration.

        Returns:
            GenerationResponse
        """
        ...

    @abstractmethod
    def generate_stream(
        self, messages: List[Message], config: GenerationConfig
    ) -> Iterator[StreamChunk]:
        """Produce a response incrementally as a stream of chunks.

        Args:
            messages: List of conversation messages.
            config: Generation configuration.

        Yields:
            StreamChunk
        """
        ...

    @abstractmethod
    async def generate_async(
        self, messages: List[Message], config: GenerationConfig
    ) -> GenerationResponse:
        """Produce a complete response from a coroutine.

        Args:
            messages: List of conversation messages.
            config: Generation configuration.

        Returns:
            GenerationResponse
        """
        ...

    def validate_config(self, config: GenerationConfig) -> bool:
        """Check whether ``config`` is acceptable for this provider.

        The base implementation performs no checks and accepts every
        configuration. Subclasses can override it with real validation.

        Args:
            config: Generation configuration to check.

        Returns:
            bool: Always ``True`` in the base implementation.

        Raises:
            ValueError: Reserved for overriding implementations to signal
                an invalid configuration; never raised by the base class.
        """
        return True