# Source code for kerb.context.window

"""Core context window management functions.

This module provides functions for creating and managing context windows,
including truncation strategies.
"""

from typing import Callable, Iterable, List, Optional, Tuple, Union

from kerb.tokenizer import Tokenizer, count_tokens

from .types import ContextItem, ContextWindow, TruncationStrategy


def create_context_window(
    items: Union[List[str], List[ContextItem]],
    max_tokens: Optional[int] = None,
    strategy: TruncationStrategy = TruncationStrategy.LAST,
    token_estimator: Optional[Callable[[str], int]] = None,
) -> ContextWindow:
    """Create a managed context window from items.

    Plain strings are wrapped in new ``ContextItem`` objects with a computed
    ``token_count``. ``ContextItem`` inputs are used as-is; if one has no
    ``token_count`` yet, it is computed and assigned on that object, so the
    caller's item is modified in place.

    Args:
        items: List of strings or ContextItem objects.
        max_tokens: Maximum tokens allowed in the window. ``None`` means no
            limit; ``0`` is treated as a real limit rather than "unlimited".
        strategy: Truncation strategy to apply if the limit is exceeded.
        token_estimator: Custom token estimation function (defaults to
            ``count_tokens`` from the tokenizer module with CL100K_BASE).

    Returns:
        ContextWindow: Managed context window, truncated if it exceeded
        ``max_tokens``.

    Example:
        >>> window = create_context_window(["Hello", "World"], max_tokens=1000)
        >>> len(window.items)
        2
    """
    # Use count_tokens from the tokenizer module as the default estimator.
    estimator = token_estimator or (
        lambda text: count_tokens(text, Tokenizer.CL100K_BASE)
    )

    # Normalize every input into a ContextItem with a known token count.
    context_items: List[ContextItem] = []
    for item in items:
        if isinstance(item, str):
            context_items.append(
                ContextItem(content=item, token_count=estimator(item))
            )
        else:
            if item.token_count is None:
                item.token_count = estimator(item.content)
            context_items.append(item)

    window = ContextWindow(
        items=context_items, max_tokens=max_tokens, strategy=strategy
    )
    # Items whose count is still falsy contribute nothing to the total.
    window.current_tokens = sum(item.token_count or 0 for item in context_items)

    # Compare against None explicitly: a limit of 0 is falsy but is still a
    # limit, and must not be mistaken for "no limit".
    if max_tokens is not None and window.current_tokens > max_tokens:
        window = truncate_context_window(window, max_tokens, strategy)

    return window
def truncate_context_window(
    window: ContextWindow,
    max_tokens: int,
    strategy: TruncationStrategy = TruncationStrategy.LAST,
) -> ContextWindow:
    """Shrink a context window so that it fits within a token budget.

    A window that already fits is returned unchanged (the same object).
    Otherwise the helper matching ``strategy`` builds a new, smaller window.
    Any strategy not listed below falls back to keeping the most recent items.

    Args:
        window: Context window to truncate.
        max_tokens: Maximum number of tokens the result may hold.
        strategy: Which items to keep: FIRST (leading items), LAST (trailing
            items), MIDDLE (leading and trailing items, dropping the middle)
            or PRIORITY (highest-priority items).

    Returns:
        ContextWindow: The original window if it fits, else a truncated one.

    Example:
        >>> window = truncate_context_window(window, max_tokens=500)
    """
    if window.current_tokens <= max_tokens:
        return window

    # Strategies are compared with ``==`` in a fixed order, exactly like an
    # if/elif chain; the first match wins.
    handlers = (
        (TruncationStrategy.FIRST, _truncate_first),
        (TruncationStrategy.LAST, _truncate_last),
        (TruncationStrategy.MIDDLE, _truncate_middle),
        (TruncationStrategy.PRIORITY, _truncate_priority),
    )
    for candidate, handler in handlers:
        if strategy == candidate:
            return handler(window, max_tokens)

    # Unrecognized strategy: default to keeping the last items.
    return _truncate_last(window, max_tokens)
def _item_tokens(item: ContextItem) -> int:
    """Return *item*'s token count, treating a missing (falsy) count as 0."""
    return item.token_count or 0


def _fit_greedy_run(
    items: Iterable[ContextItem], budget: int
) -> Tuple[List[ContextItem], int]:
    """Collect items in iteration order until the next one would exceed *budget*.

    Args:
        items: Items to consume, in the order they should be considered.
        budget: Maximum total tokens the collected run may use.

    Returns:
        A ``(kept, used)`` pair: the leading run of *items* that fits (in
        iteration order) and the total number of tokens it uses.
    """
    kept: List[ContextItem] = []
    used = 0
    for item in items:
        tokens = _item_tokens(item)
        if used + tokens > budget:
            # Stop at the first overflow; later items are not considered.
            break
        kept.append(item)
        used += tokens
    return kept, used


def _truncate_first(window: ContextWindow, max_tokens: int) -> ContextWindow:
    """Keep the leading items of *window* up to *max_tokens*.

    Scanning stops at the first item that does not fit, so everything after
    it is dropped, even items small enough to fit on their own.
    """
    kept_items, used = _fit_greedy_run(window.items, max_tokens)
    return ContextWindow(
        items=kept_items,
        max_tokens=max_tokens,
        current_tokens=used,
        strategy=window.strategy,
        metadata=window.metadata,
    )


def _truncate_last(window: ContextWindow, max_tokens: int) -> ContextWindow:
    """Keep the trailing items of *window* up to *max_tokens*.

    Scanning runs from the end and stops at the first item that does not fit.
    Kept items are returned in their original order.
    """
    kept_items, used = _fit_greedy_run(reversed(window.items), max_tokens)
    # Items were collected newest-first; reversing once is O(n), unlike
    # repeated list.insert(0, ...), which is O(n) per insertion.
    kept_items.reverse()
    return ContextWindow(
        items=kept_items,
        max_tokens=max_tokens,
        current_tokens=used,
        strategy=window.strategy,
        metadata=window.metadata,
    )


def _truncate_middle(window: ContextWindow, max_tokens: int) -> ContextWindow:
    """Keep leading and trailing items, dropping items from the middle.

    The budget is split in half: the first half (rounded down) goes to
    leading items and the remainder to trailing items. Each half is filled
    greedily and stops at its first item that does not fit. Unused budget in
    one half is not transferred to the other.
    """
    items = window.items
    if not items:
        return window

    start_budget = max_tokens // 2
    end_budget = max_tokens - start_budget

    start_items, _ = _fit_greedy_run(items, start_budget)
    end_items, _ = _fit_greedy_run(reversed(items), end_budget)
    end_items.reverse()

    if len(start_items) + len(end_items) > len(items):
        # The head and tail runs overlap, so together they cover every item
        # (possible if current_tokens was stale or this was called directly).
        # Keep each item exactly once instead of duplicating the overlap.
        kept_items = list(items)
    else:
        kept_items = start_items + end_items

    return ContextWindow(
        items=kept_items,
        max_tokens=max_tokens,
        current_tokens=sum(_item_tokens(item) for item in kept_items),
        strategy=window.strategy,
        metadata=window.metadata,
    )


def _truncate_priority(window: ContextWindow, max_tokens: int) -> ContextWindow:
    """Keep the highest-priority items that fit, in their original order.

    Items are visited from highest to lowest ``priority``; ties keep their
    original relative order because ``sorted`` is stable. Unlike the other
    strategies, an item that does not fit is skipped and scanning continues,
    so smaller lower-priority items may still be kept.
    """
    items = window.items
    # Sort positions rather than objects. Restoring the original order by
    # position stays correct even if one item object appears more than once.
    by_priority = sorted(
        range(len(items)), key=lambda index: items[index].priority, reverse=True
    )

    kept_positions: List[int] = []
    used = 0
    for index in by_priority:
        tokens = _item_tokens(items[index])
        if used + tokens <= max_tokens:
            kept_positions.append(index)
            used += tokens

    kept_positions.sort()
    return ContextWindow(
        items=[items[index] for index in kept_positions],
        max_tokens=max_tokens,
        current_tokens=used,
        strategy=window.strategy,
        metadata=window.metadata,
    )