# Source code for kerb.retrieval.query

"""Query processing utilities for retrieval.

This module provides functions for query rewriting, expansion, and decomposition.
"""

import re
from typing import TYPE_CHECKING, List, Optional, Union

if TYPE_CHECKING:
    from kerb.core.enums import ExpansionMethod, QueryStyle


def rewrite_query(
    query: str,
    style: Union["QueryStyle", str] = "clear",
    max_length: Optional[int] = None,
) -> str:
    """Rewrite a query for better retrieval.

    Supported styles:

    * ``"clear"``    - lower-case the query and drop articles / forms of "to be".
    * ``"detailed"`` - prefix the query with a request for detailed information.
    * ``"keyword"``  - lower-case the query and keep only non-stop-words longer
      than two characters.
    * ``"concise"``  - keep the first five whitespace-separated words.
    * ``"natural"``  - turn the query into a question ("...?" or "What is ...?").

    Any other style value accepted by ``validate_enum_or_string`` leaves the
    (stripped) query unchanged.

    Args:
        query: The original query text. Leading/trailing whitespace is ignored.
        style: Rewriting style (QueryStyle enum or string: "clear", "detailed",
            "concise", "keyword", "natural").
        max_length: Maximum length of the rewritten query. ``None`` or a
            non-positive value disables truncation. Truncation backs off to the
            previous space when the cut would land in the middle of a word; a
            single word longer than ``max_length`` is cut hard.

    Returns:
        str: Rewritten query. An empty (or whitespace-only) query is returned
        as ``""`` for every style.

    Raises:
        Whatever ``kerb.core.enums.validate_enum_or_string`` raises for an
        unrecognised ``style`` (that helper is not defined in this module).

    Examples:
        >>> from kerb.core.enums import QueryStyle
        >>> rewritten = rewrite_query("python async", style=QueryStyle.DETAILED)
    """
    from kerb.core.enums import QueryStyle, validate_enum_or_string

    query = query.strip()

    # Validate first so an invalid style is reported even for an empty query.
    style_val = validate_enum_or_string(style, QueryStyle, "style")
    style_str = style_val.value if isinstance(style_val, QueryStyle) else style_val

    # Nothing to rewrite; avoids outputs such as "What is ?".
    if not query:
        return ""

    if style_str == "clear":
        # Remove filler words and simplify.
        filler_words = {
            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
        }
        rewritten = " ".join(
            w for w in query.lower().split() if w not in filler_words
        )
    elif style_str == "detailed":
        # Add context and specificity.
        if "?" not in query:
            rewritten = f"Detailed information about {query}"
        else:
            rewritten = f"Please provide comprehensive details: {query}"
    elif style_str == "keyword":
        # Extract key terms only.
        stop_words = {
            "how", "what", "when", "where", "why", "who", "which",
            "the", "a", "an", "and", "or", "but",
            "in", "on", "at", "to", "for", "of", "with", "by", "from",
            "is", "are", "was", "were",
        }
        rewritten = " ".join(
            w for w in query.lower().split()
            if w not in stop_words and len(w) > 2
        )
    elif style_str == "concise":
        # Keep only the first few words.
        rewritten = " ".join(query.split()[:5])
    elif style_str == "natural":
        # Convert to natural question format.
        if "?" in query:
            rewritten = query
        elif re.match(
            r"(?:how|what|when|where|why|who|whose|whom|which)\b",
            query,
            flags=re.IGNORECASE,
        ):
            # Whole-word match: "whole wheat" must not count as a "who" question,
            # while "what's ..." still does because "'" ends the word.
            rewritten = query + "?"
        else:
            rewritten = f"What is {query}?"
    else:
        rewritten = query

    if max_length is not None and 0 < max_length < len(rewritten):
        truncated = rewritten[:max_length]
        # Only drop the trailing fragment when the cut split a word; if the
        # next character is whitespace the last word is complete and is kept.
        if not rewritten[max_length].isspace():
            truncated = truncated.rsplit(" ", 1)[0]
        rewritten = truncated.rstrip()

    return rewritten
def expand_query(
    query: str,
    expansions: Optional[List[str]] = None,
    method: Union["ExpansionMethod", str] = "synonyms",
) -> List[str]:
    """Expand a query into multiple variations for broader retrieval.

    The result always starts with the original query, followed by any custom
    ``expansions`` and then the variations produced by ``method``:

    * ``"synonyms"``      - for every known abbreviation that appears in the
      query as a whole word, add a copy of the query with that word replaced
      by each synonym. The rest of the query keeps its original casing.
    * ``"related_terms"`` - for every known topic that appears in the query as
      a whole word, add a fixed list of related search phrases.
    * ``"llm"`` / ``"embeddings"`` - placeholders; no variations are generated
      beyond the query and the custom expansions.

    Args:
        query: The original query text.
        expansions: Custom expansion terms to add.
        method: Expansion method (ExpansionMethod enum or string: "synonyms",
            "related_terms", "llm", "embeddings").

    Returns:
        List[str]: List of query variations, de-duplicated case-insensitively
        with the first occurrence kept.

    Raises:
        Whatever ``kerb.core.enums.validate_enum_or_string`` raises for an
        unrecognised ``method`` (that helper is not defined in this module).

    Examples:
        >>> from kerb.core.enums import ExpansionMethod
        >>> queries = expand_query("machine learning", method=ExpansionMethod.SYNONYMS)
    """
    from kerb.core.enums import ExpansionMethod, validate_enum_or_string

    variations = [query]
    if expansions:
        variations.extend(expansions)

    # Validate and normalize method.
    method_val = validate_enum_or_string(method, ExpansionMethod, "method")
    if isinstance(method_val, ExpansionMethod):
        method_str = method_val.value
    else:
        method_str = method_val

    def whole_word(term: str) -> "re.Pattern[str]":
        # Match ``term`` only as a complete word (optionally followed by a
        # plural "s", which is left in place). A plain substring test would
        # fire on "ai" inside "training" or "auth" inside "authentication".
        return re.compile(rf"\b{re.escape(term)}(?=s?\b)", flags=re.IGNORECASE)

    if method_str == "synonyms":
        # Simple synonym expansion (can be enhanced with a synonym dictionary).
        synonym_map = {
            "ml": ["machine learning", "ML"],
            "ai": ["artificial intelligence", "AI"],
            "llm": ["large language model", "LLM"],
            "api": ["API", "application programming interface"],
            "db": ["database", "DB"],
            "async": ["asynchronous", "async"],
            "auth": ["authentication", "auth"],
            "config": ["configuration", "config"],
        }
        for term, synonyms in synonym_map.items():
            pattern = whole_word(term)
            if not pattern.search(query):
                continue
            for syn in synonyms:
                # A function replacement inserts the synonym literally, with
                # no backslash/group-reference processing.
                variations.append(pattern.sub(lambda _m, _syn=syn: _syn, query))
    elif method_str == "related_terms":
        # Add related terms.
        related_map = {
            "python": ["python programming", "python code", "python development"],
            "database": ["database design", "database query", "data storage"],
            "api": ["REST API", "API endpoint", "API integration"],
            "error": ["exception", "bug", "issue"],
        }
        for term, related in related_map.items():
            if whole_word(term).search(query):
                variations.extend(related)
    # "llm" / "embeddings": placeholder for LLM-based or embedding-based
    # expansion. In production this would call an LLM or use embedding
    # similarity; for now only the query and custom expansions are returned.

    # Remove duplicates (case-insensitively) while preserving order.
    seen = set()
    unique_variations = []
    for candidate in variations:
        key = candidate.lower()
        if key not in seen:
            seen.add(key)
            unique_variations.append(candidate)

    return unique_variations
def generate_sub_queries(query: str, max_queries: int = 3) -> List[str]:
    """Generate sub-queries from a complex query for step-by-step retrieval.

    Two heuristics are applied in order until ``max_queries`` is reached:

    1. If the query contains " and ", it is split on the word "and" and each
       non-empty part becomes a sub-query.
    2. If at least two distinct "concepts" are found, a ``"What is <concept>?"``
       sub-query is added for each. A concept is a lower-cased word longer
       than three characters (after surrounding punctuation is removed) that
       is not a question word.

    If neither heuristic yields anything, the original query is returned as
    the only sub-query.

    Args:
        query: The original complex query.
        max_queries: Maximum number of sub-queries to generate. A value of
            zero or less yields an empty list.

    Returns:
        List[str]: List of at most ``max_queries`` sub-queries.

    Example:
        >>> generate_sub_queries("How to implement authentication in a Python FastAPI app?")
        ['What is implement?', 'What is authentication?', 'What is python?']
    """
    if max_queries <= 0:
        return []

    sub_queries: List[str] = []

    # Split on conjunctions, discarding empty fragments such as the one left
    # behind by a trailing "and".
    if " and " in query.lower():
        parts = (p.strip() for p in re.split(r"\band\b", query, flags=re.IGNORECASE))
        sub_queries.extend([p for p in parts if p][:max_queries])

    # Extract core concepts: drop question words and short words.
    question_words = {"how", "what", "when", "where", "why", "who", "which"}
    edge_punctuation = ".,;:!?\"'()[]{}"
    tokens = (w.strip(edge_punctuation) for w in query.lower().split())
    # dict.fromkeys removes repeated concepts while keeping first-seen order.
    concepts = list(
        dict.fromkeys(
            t for t in tokens if len(t) > 3 and t not in question_words
        )
    )

    # Generate sub-queries from concepts, filling whatever room is left.
    remaining = max_queries - len(sub_queries)
    if len(concepts) >= 2 and remaining > 0:
        sub_queries.extend(f"What is {concept}?" for concept in concepts[:remaining])

    if not sub_queries:
        sub_queries = [query]

    return sub_queries[:max_queries]