Source code for kerb.testing.mocking

"""Mock LLM providers for testing."""

import random
import re
import time
from datetime import datetime
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

from .types import MockBehavior, MockResponse


class MockLLM:
    """Mock LLM provider with configurable responses.

    This class provides a drop-in replacement for real LLM providers,
    useful for testing without making actual API calls.

    Every call to ``generate`` increments ``call_count`` and appends an
    entry to ``call_history`` so tests can assert on how the mock was used.
    """

    def __init__(
        self,
        responses: Optional[Union[str, List[str], Dict[str, str]]] = None,
        behavior: MockBehavior = MockBehavior.FIXED,
        default_response: str = "Mock response",
        latency: float = 0.1,
        token_calculator: Optional[Callable[[str], int]] = None,
    ):
        """Initialize mock LLM.

        Args:
            responses: Response(s) to return. A string becomes a one-element
                response list, a list is used as given, and a dict is
                treated as a mapping of regex pattern -> response for
                pattern matching. ``None`` (or any other type) falls back
                to a single ``default_response``.
            behavior: Behavior mode for returning responses
            default_response: Default response when no match found or no
                responses are configured
            latency: Simulated latency per response, in seconds (passed to
                ``time.sleep``)
            token_calculator: Function to calculate token counts; defaults
                to a whitespace-split word count
        """
        self.behavior = behavior
        self.default_response = default_response
        self.latency = latency
        self.token_calculator = token_calculator or self._simple_token_count
        self.call_count = 0
        self.call_history: List[Dict[str, Any]] = []

        # Always define the pattern table so PATTERN behavior degrades to
        # default_response (rather than raising AttributeError) when the
        # mock was configured with a string, a list, or nothing at all.
        self.pattern_responses: Dict[str, str] = {}

        # Configure responses based on the type that was supplied
        if isinstance(responses, str):
            self.responses = [responses]
        elif isinstance(responses, list):
            self.responses = responses
        elif isinstance(responses, dict):
            self.pattern_responses = responses
            self.responses = []
        else:
            self.responses = [default_response]

        self.current_index = 0

    def generate(
        self, prompt: Union[str, List[Dict[str, str]]], **kwargs
    ) -> MockResponse:
        """Generate a mock response.

        Args:
            prompt: Input prompt (string or message list). For a message
                list, the ``"content"`` values are joined with spaces.
            **kwargs: Additional generation parameters (recorded in the
                call history, otherwise ignored)

        Returns:
            MockResponse object
        """
        self.call_count += 1

        # Extract text from prompt
        if isinstance(prompt, list):
            prompt_text = " ".join([msg.get("content", "") for msg in prompt])
        else:
            prompt_text = prompt

        # Record call
        self.call_history.append(
            {
                "prompt": prompt_text,
                "kwargs": kwargs,
                "timestamp": datetime.now().isoformat(),
            }
        )

        # Generate response based on behavior
        if self.behavior == MockBehavior.FIXED:
            content = self.responses[0] if self.responses else self.default_response
        elif self.behavior == MockBehavior.SEQUENTIAL:
            if self.responses:
                # Cycle through the configured responses, wrapping around.
                content = self.responses[self.current_index % len(self.responses)]
                self.current_index += 1
            else:
                # No responses to cycle through (empty list or dict-configured
                # mock): fall back like FIXED/RANDOM do instead of dividing
                # by zero in the modulo above.
                content = self.default_response
        elif self.behavior == MockBehavior.RANDOM:
            content = (
                random.choice(self.responses)
                if self.responses
                else self.default_response
            )
        elif self.behavior == MockBehavior.PATTERN:
            content = self._match_pattern(prompt_text)
        else:
            content = self.default_response

        # Simulate latency
        time.sleep(self.latency)

        return MockResponse(
            content=content,
            prompt_tokens=self.token_calculator(prompt_text),
            completion_tokens=self.token_calculator(content),
            latency=self.latency,
            metadata={"call_count": self.call_count},
        )

    def _match_pattern(self, prompt: str) -> str:
        """Match prompt against patterns and return response.

        Patterns are tried in dict iteration order with a case-insensitive
        ``re.search``; the first match wins. Returns ``default_response``
        when nothing matches.
        """
        for pattern, response in self.pattern_responses.items():
            if re.search(pattern, prompt, re.IGNORECASE):
                return response
        return self.default_response

    def _simple_token_count(self, text: str) -> int:
        """Simple token count estimation (number of whitespace-separated words)."""
        return len(text.split())

    def reset(self) -> None:
        """Reset call count, call history and the sequential response index."""
        self.call_count = 0
        self.call_history = []
        self.current_index = 0

    def get_last_call(self) -> Optional[Dict[str, Any]]:
        """Get the last call made to the mock, or ``None`` if never called."""
        return self.call_history[-1] if self.call_history else None

    def assert_called(self) -> None:
        """Assert that the mock was called at least once.

        Raises:
            AssertionError: If the mock has not been called.
        """
        # Raise explicitly instead of using ``assert`` so the check still
        # runs under ``python -O`` and matches assert_called_with.
        if not self.call_count > 0:
            raise AssertionError("Mock LLM was not called")

    def assert_called_with(self, prompt_contains: str) -> None:
        """Assert that the mock was called with a prompt containing text.

        Args:
            prompt_contains: Substring that must appear in at least one
                recorded prompt.

        Raises:
            AssertionError: If no recorded prompt contains the substring.
        """
        for call in self.call_history:
            if prompt_contains in call["prompt"]:
                return
        raise AssertionError(
            f"Mock LLM was not called with prompt containing: {prompt_contains}"
        )
class MockStreamingLLM:
    """Mock streaming LLM for testing streaming responses.

    The configured response is split into fixed-size character chunks that
    are yielded one at a time, with a sleep before each chunk.
    """

    def __init__(
        self, response: str, chunk_size: int = 10, delay_per_chunk: float = 0.01
    ):
        """Initialize mock streaming LLM.

        Args:
            response: Full response to stream
            chunk_size: Characters per chunk (must be positive)
            delay_per_chunk: Delay between chunks in seconds

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        # Fail fast: a zero step would only blow up inside range() once the
        # stream is iterated, and a negative step would silently yield
        # nothing, hiding a misconfigured test.
        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size!r}"
            )
        self.response = response
        self.chunk_size = chunk_size
        self.delay_per_chunk = delay_per_chunk

    def generate_stream(
        self, prompt: Union[str, List[Dict[str, str]]], **kwargs
    ) -> Iterator[str]:
        """Generate streaming mock response.

        Args:
            prompt: Input prompt (not used to build the output)
            **kwargs: Additional parameters (ignored)

        Yields:
            Response chunks of at most ``chunk_size`` characters, in order.
        """
        for start in range(0, len(self.response), self.chunk_size):
            # Sleep before each chunk to simulate streaming delay.
            time.sleep(self.delay_per_chunk)
            yield self.response[start : start + self.chunk_size]
def create_mock_llm(
    responses: Union[str, List[str], Dict[str, str]],
    behavior: MockBehavior = MockBehavior.FIXED,
    **kwargs,
) -> MockLLM:
    """Build a ``MockLLM`` from the given responses and behavior.

    Args:
        responses: Response(s) the mock should be configured with
        behavior: Behavior mode used to pick a response
        **kwargs: Any further keyword arguments, forwarded to ``MockLLM``

    Returns:
        The newly constructed MockLLM instance
    """
    # Gather every constructor argument in one mapping, then forward it.
    constructor_args = {"responses": responses, "behavior": behavior, **kwargs}
    return MockLLM(**constructor_args)