Source code for kerb.fine_tuning.validation

"""Dataset validation functions for fine-tuning."""

from typing import Any, Dict, List, Optional

from .types import (DatasetFormat, FineTuningProvider, TrainingDataset,
                    ValidationLevel, ValidationResult)


def validate_dataset(
    dataset: TrainingDataset,
    level: ValidationLevel = ValidationLevel.MODERATE,
    max_tokens: Optional[int] = None,
) -> ValidationResult:
    """Validate dataset for fine-tuning.

    Each example is checked against the requirements of ``dataset.format``
    (chat examples need messages with ``role`` and ``content``; completion
    examples need a prompt and, ideally, a completion) and, when
    ``max_tokens`` is given, against a rough whitespace-based token estimate.

    An example is counted in ``invalid_examples`` as soon as it produces at
    least one error, whichever check produced it.

    With ``ValidationLevel.STRICT`` validation stops at the first example
    that produces an error: ``valid_examples`` is then the number of examples
    that passed before it, and every remaining example (checked or not) is
    reported in ``invalid_examples``. Missing completions and token-limit
    overruns are errors in strict mode and warnings otherwise.

    Args:
        dataset: Dataset to validate
        level: Validation strictness
        max_tokens: Maximum tokens per example (falsy values disable the check)

    Returns:
        ValidationResult
    """
    total = len(dataset)
    result = ValidationResult(is_valid=True, total_examples=total)

    if len(dataset.examples) == 0:
        result.add_error("Dataset is empty")
        return result

    strict = level == ValidationLevel.STRICT

    for i, example in enumerate(dataset.examples):
        # Snapshot the error count so we can tell whether *this* example failed.
        errors_before = len(result.errors)
        # Structurally broken examples have no meaningful text to measure.
        check_tokens = True

        # Check format-specific requirements
        if dataset.format == DatasetFormat.CHAT:
            if not example.messages:
                result.add_error(f"Example {i}: Missing messages for chat format")
                check_tokens = False
            else:
                # Validate message structure
                for j, msg in enumerate(example.messages):
                    if not isinstance(msg, dict):
                        result.add_error(
                            f"Example {i}, Message {j}: Must be a dictionary"
                        )
                        continue
                    if "role" not in msg:
                        result.add_error(
                            f"Example {i}, Message {j}: Missing 'role' field"
                        )
                    if "content" not in msg:
                        result.add_error(
                            f"Example {i}, Message {j}: Missing 'content' field"
                        )

        elif dataset.format == DatasetFormat.COMPLETION:
            if not example.prompt:
                result.add_error(
                    f"Example {i}: Missing prompt for completion format"
                )
                check_tokens = False
            elif not example.completion:
                if strict:
                    result.add_error(f"Example {i}: Missing completion")
                else:
                    result.add_warning(f"Example {i}: Missing completion")

        # Check token limits if specified
        if check_tokens and max_tokens:
            text = example.get_text_content()
            estimated_tokens = len(text.split())  # Rough estimate
            if estimated_tokens > max_tokens:
                if strict:
                    result.add_error(
                        f"Example {i}: Exceeds token limit ({estimated_tokens} > {max_tokens})"
                    )
                else:
                    result.add_warning(
                        f"Example {i}: May exceed token limit ({estimated_tokens} tokens)"
                    )

        if len(result.errors) > errors_before:
            if strict:
                # Fail fast: examples 0..i-1 passed, the rest are reported invalid.
                result.valid_examples = i
                result.invalid_examples = total - i
                return result
            result.invalid_examples += 1

    result.valid_examples = total - result.invalid_examples

    # Final checks
    if result.valid_examples < 10:
        result.add_warning(
            f"Dataset has only {result.valid_examples} valid examples. Recommended: at least 50-100"
        )

    return result
def validate_format(
    data: List[Dict[str, Any]], provider: FineTuningProvider
) -> ValidationResult:
    """Validate format for specific provider.

    For OpenAI every item must have a ``messages`` field whose entries are
    dictionaries containing both ``role`` and ``content``. For Anthropic only
    the presence of ``messages`` is checked. Other providers are not checked.

    ``invalid_examples`` counts items that produced at least one error (not
    the number of error messages), so ``valid_examples`` never drops below
    zero when a single item has several bad messages.

    Args:
        data: Data to validate
        provider: Target provider

    Returns:
        ValidationResult
    """
    result = ValidationResult(is_valid=True, total_examples=len(data))
    invalid_count = 0

    for i, item in enumerate(data):
        # Snapshot so several errors on one item still count as one bad item.
        errors_before = len(result.errors)

        if provider == FineTuningProvider.OPENAI:
            if "messages" not in item:
                result.add_error(
                    f"Example {i}: Missing 'messages' field for OpenAI format"
                )
            else:
                for j, msg in enumerate(item["messages"]):
                    if (
                        not isinstance(msg, dict)
                        or "role" not in msg
                        or "content" not in msg
                    ):
                        result.add_error(
                            f"Example {i}, Message {j}: Must have 'role' and 'content'"
                        )

        elif provider == FineTuningProvider.ANTHROPIC:
            if "messages" not in item:
                result.add_error(
                    f"Example {i}: Missing 'messages' field for Anthropic format"
                )

        # Add more provider-specific validation as needed

        if len(result.errors) > errors_before:
            invalid_count += 1

    result.valid_examples = len(data) - invalid_count
    result.invalid_examples = invalid_count

    return result


def check_token_limits(
    dataset: TrainingDataset,
    max_tokens: int = 4096,
    tokenizer_name: str = "cl100k_base",
) -> Dict[str, Any]:
    """Check if examples exceed token limits.

    Token counts come from the package's ``count_tokens`` helper when it can
    be imported; otherwise a whitespace word count is used as a fallback.

    Args:
        dataset: Dataset to check
        max_tokens: Maximum allowed tokens
        tokenizer_name: Tokenizer to use for counting (passed as the second
            positional argument to ``count_tokens``)

    Returns:
        Dictionary with statistics about token usage: ``total_examples``,
        ``exceeding_limit``, ``exceeding_examples`` (index and token count of
        each offender), ``avg_tokens``, ``max_tokens_found`` and
        ``min_tokens_found``. The three aggregates are 0 when the dataset has
        no examples.
    """
    try:
        from ..tokenizer import count_tokens
    except ImportError:
        # Fallback to simple word count
        def count_tokens(text, model):
            return len(text.split())

    exceeding = []
    token_counts = []

    for i, example in enumerate(dataset.examples):
        text = example.get_text_content()
        tokens = count_tokens(text, tokenizer_name)
        token_counts.append(tokens)

        if tokens > max_tokens:
            exceeding.append({"index": i, "tokens": tokens})

    return {
        "total_examples": len(dataset),
        "exceeding_limit": len(exceeding),
        "exceeding_examples": exceeding,
        "avg_tokens": sum(token_counts) / len(token_counts) if token_counts else 0,
        "max_tokens_found": max(token_counts) if token_counts else 0,
        "min_tokens_found": min(token_counts) if token_counts else 0,
    }


def validate_messages(messages: List[Dict[str, str]]) -> ValidationResult:
    """Validate message structure for chat format.

    Every message must be a dictionary with ``role`` and ``content`` keys.
    A role outside the known set only produces a warning, not an error.

    Args:
        messages: List of message dictionaries

    Returns:
        ValidationResult
    """
    result = ValidationResult(is_valid=True)

    valid_roles = {"system", "user", "assistant", "function", "tool"}

    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            result.add_error(f"Message {i}: Must be a dictionary")
            continue

        if "role" not in msg:
            result.add_error(f"Message {i}: Missing 'role' field")
        elif msg["role"] not in valid_roles:
            result.add_warning(
                f"Message {i}: Unusual role '{msg['role']}'. Expected: {valid_roles}"
            )

        if "content" not in msg:
            result.add_error(f"Message {i}: Missing 'content' field")

    return result


def estimate_training_tokens(dataset: TrainingDataset) -> int:
    """Estimate total training tokens.

    Args:
        dataset: Dataset to analyze

    Returns:
        Estimated total tokens
    """
    # Rough estimate: 1 token ≈ 4 characters (floored per example)
    return sum(
        len(example.get_text_content()) // 4 for example in dataset.examples
    )
def estimate_cost(
    dataset: TrainingDataset, model: str = "gpt-4o-mini", n_epochs: int = 3
) -> Dict[str, float]:
    """Estimate fine-tuning cost.

    The rate is looked up by the exact model name first, then by the first
    two dash-separated segments of the name (so ``"gpt-4-turbo"`` uses the
    ``"gpt-4"`` rates). Any model without an entry, including the default
    ``"gpt-4o-mini"``, falls back to the ``"gpt-3.5-turbo"`` rates.

    Args:
        dataset: Dataset to train on
        model: Base model name
        n_epochs: Number of training epochs

    Returns:
        Dictionary with cost estimates: ``total_training_tokens``,
        ``estimated_training_cost_usd``, ``cost_per_epoch_usd`` (0.0 when
        ``n_epochs`` is 0), plus the ``model`` and ``n_epochs`` echoed back.
    """
    # OpenAI pricing (as of 2024), per 1K tokens
    pricing = {
        "gpt-3.5-turbo": {
            "training": 0.008,
            "input": 0.003,
            "output": 0.006,
        },
        "gpt-4": {"training": 0.03, "input": 0.03, "output": 0.06},
    }
    fallback_rates = pricing["gpt-3.5-turbo"]

    # "gpt-4-0613" -> "gpt-4"; a name without dashes is kept unchanged.
    base_model = "-".join(model.split("-")[:2])
    rates = pricing.get(model, pricing.get(base_model, fallback_rates))

    total_tokens = estimate_training_tokens(dataset)
    training_tokens = total_tokens * n_epochs

    training_cost = (training_tokens / 1000) * rates["training"]
    # Zero epochs means zero cost; avoid dividing by zero for the per-epoch figure.
    cost_per_epoch = training_cost / n_epochs if n_epochs else 0.0

    return {
        "total_training_tokens": training_tokens,
        "estimated_training_cost_usd": round(training_cost, 2),
        "cost_per_epoch_usd": round(cost_per_epoch, 2),
        "model": model,
        "n_epochs": n_epochs,
    }
def validate_completion_format(prompt: str, completion: str) -> ValidationResult:
    """Validate completion-based format.

    Empty, whitespace-only or ``None`` prompts and completions are reported
    as errors. The "completion includes the prompt" warning is only
    evaluated when both are non-empty, because every string starts with the
    empty string and the check would otherwise always fire for an empty
    prompt.

    Args:
        prompt: Prompt text
        completion: Completion text

    Returns:
        ValidationResult
    """
    result = ValidationResult(is_valid=True)

    has_prompt = bool(prompt and prompt.strip())
    has_completion = bool(completion and completion.strip())

    if not has_prompt:
        result.add_error("Prompt is empty")

    if not has_completion:
        result.add_error("Completion is empty")

    # Check for common issues
    if has_prompt and has_completion and completion.startswith(prompt):
        result.add_warning("Completion appears to include the prompt")

    return result


def validate_chat_format(messages: List[Dict[str, str]]) -> ValidationResult:
    """Validate chat-based format.

    Delegates to ``validate_messages``.

    Args:
        messages: List of message dictionaries

    Returns:
        ValidationResult
    """
    return validate_messages(messages)