# Source code for kerb.fine_tuning.quality

"""Data quality analysis functions for fine-tuning datasets."""

import re
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from .types import DatasetStats, TrainingDataset

if TYPE_CHECKING:
    from kerb.core.enums import Device


def analyze_dataset(dataset: TrainingDataset) -> DatasetStats:
    """Analyze dataset statistics.

    Token counts are a rough whitespace-split word count, not model tokens.

    Args:
        dataset: Dataset to analyze

    Returns:
        DatasetStats with example and token counts, average prompt and
        completion lengths, the label distribution (every label that is not
        ``None`` is counted, including falsy labels such as ``0``), and the
        number of duplicate examples (examples whose ``compute_hash()`` equals
        that of an earlier example).
    """
    stats = DatasetStats()
    stats.total_examples = len(dataset)

    token_counts: List[int] = []
    prompt_tokens: List[int] = []
    completion_tokens: List[int] = []
    labels: List[Any] = []
    hashes: List[Any] = []

    for example in dataset.examples:
        # Whitespace split is a cheap, tokenizer-agnostic length estimate.
        token_counts.append(len(example.get_text_content().split()))
        if example.prompt:
            prompt_tokens.append(len(example.prompt.split()))
        if example.completion:
            completion_tokens.append(len(example.completion.split()))
        # Explicit None check: a truthiness test would drop valid labels such
        # as 0 (same rule as check_label_distribution).
        if example.label is not None:
            labels.append(example.label)
        hashes.append(example.compute_hash())

    if token_counts:
        stats.total_tokens = sum(token_counts)
        stats.avg_tokens_per_example = stats.total_tokens / len(token_counts)
        stats.min_tokens = min(token_counts)
        stats.max_tokens = max(token_counts)

    if prompt_tokens:
        stats.avg_prompt_tokens = sum(prompt_tokens) / len(prompt_tokens)

    if completion_tokens:
        stats.avg_completion_tokens = sum(completion_tokens) / len(completion_tokens)

    if labels:
        stats.label_distribution = dict(Counter(labels))

    # Every example beyond the first occurrence of a hash is a duplicate.
    stats.duplicate_count = len(hashes) - len(set(hashes))

    return stats
def check_data_quality(dataset: "TrainingDataset") -> Dict[str, Any]:
    """Check dataset for quality issues.

    Each example is reported as at most one of:

    * empty: no content left after stripping whitespace;
    * very short: non-empty, but its raw text is fewer than 10 characters.

    Duplicates are counted by ``compute_hash()``: every example whose hash
    equals that of an earlier example counts once.

    Args:
        dataset: Dataset to check

    Returns:
        Dictionary with ``total_examples``, ``empty_examples``,
        ``short_examples``, ``duplicate_examples``, ``issues`` (at most the
        first 100 issue messages, empty-content issues listed before
        short-content ones) and ``total_issues`` (the uncapped count).
    """
    empty_issues: List[str] = []
    short_issues: List[str] = []
    hashes: List[Any] = []

    for i, example in enumerate(dataset.examples):
        text = example.get_text_content()
        if not text.strip():
            empty_issues.append(f"Example {i}: Empty content")
        elif len(text) < 10:
            # `elif`: an empty example is reported once, not again as short.
            short_issues.append(f"Example {i}: Very short content ({len(text)} chars)")
        hashes.append(example.compute_hash())

    # Report all empty-content issues before the short-content ones.
    issues = empty_issues + short_issues

    return {
        "total_examples": len(dataset),
        "empty_examples": len(empty_issues),
        "short_examples": len(short_issues),
        "duplicate_examples": len(hashes) - len(set(hashes)),
        "issues": issues[:100],  # Limit to first 100 issues
        "total_issues": len(issues),
    }
# Pre-compiled PII patterns. They are simple heuristics and will produce both
# false positives and false negatives.
_EMAIL_PATTERN = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
# 3-3-4 digit groups, optionally separated by '-' or '.'.
_PHONE_PATTERN = re.compile(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b")
# 3-2-4 digit groups separated by '-' (US SSN layout).
_SSN_PATTERN = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
# Four groups of four digits, optionally separated by whitespace or '-'.
_CREDIT_CARD_PATTERN = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")


def detect_pii(text: str) -> Dict[str, List[str]]:
    """Detect personally identifiable information in text.

    Args:
        text: Text to analyze

    Returns:
        Dictionary mapping PII type ("emails", "phone_numbers", "ssn",
        "credit_cards") to the matched substrings. Types with no match are
        omitted, so clean text returns an empty dict.
    """
    pii = {
        "emails": _EMAIL_PATTERN.findall(text),
        "phone_numbers": _PHONE_PATTERN.findall(text),
        "ssn": _SSN_PATTERN.findall(text),
        "credit_cards": _CREDIT_CARD_PATTERN.findall(text),
    }
    return {k: v for k, v in pii.items() if v}


def compute_perplexity(
    dataset: "TrainingDataset",
    model_name: str = "gpt2",
    max_examples: Optional[int] = None,
    device: Union["Device", str] = "cpu",
) -> Dict[str, Any]:
    """Compute perplexity distribution for dataset using a HuggingFace model.

    Perplexity measures how well the model predicts the text - lower is better.
    Useful for identifying low-quality or out-of-distribution examples.

    Each example's text is truncated to 512 tokens. Examples that are blank,
    or that tokenize to fewer than two tokens, are skipped: a causal LM has no
    next-token target to score for them.

    Args:
        dataset: Dataset to analyze
        model_name: HuggingFace model name (e.g., "gpt2", "meta-llama/Llama-2-7b-hf")
        max_examples: Maximum number of examples to evaluate, taken from the
            start of the dataset. ``None`` evaluates all examples; ``0``
            evaluates none.
        device: Device to run on (Device enum or string: "cpu", "cuda", "cuda:0",
            "cuda:1", "mps")

    Returns:
        Dictionary with perplexity statistics (mean, median, min, max, p25,
        p75) and ``perplexities``, the per-example values sorted ascending
        (they are not aligned with example indices). If no example could be
        scored, a dict with only a ``"message"`` key. If the dependencies are
        missing or evaluation fails, a dict with ``"error"`` and ``"message"``
        keys.

    Examples:
        >>> from kerb.core.enums import Device
        >>> stats = compute_perplexity(dataset, model_name="gpt2", device=Device.CUDA)
        >>> print(f"Average perplexity: {stats['mean_perplexity']:.2f}")

    Note:
        Requires transformers and torch packages.
        Install with: pip install transformers torch
    """
    import warnings

    from kerb.core.enums import Device, validate_enum_or_string

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        return {
            "error": "Required packages not installed",
            "message": "Install with: pip install transformers torch",
        }

    # Validate and normalize device
    device_val = validate_enum_or_string(device, Device, "device")
    device_str = device_val.value if isinstance(device_val, Device) else device_val

    try:
        # Only None means "all": a truthiness test would turn 0 into "all".
        examples_to_process = (
            dataset.examples if max_examples is None else dataset.examples[:max_examples]
        )

        # Silence library warnings for this call only, rather than installing
        # a process-wide "ignore" filter that outlives the function.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name)
            model.to(device_str)
            model.eval()

            # Set pad token if not set
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            perplexities: List[float] = []
            with torch.no_grad():
                for example in examples_to_process:
                    text = example.get_text_content()
                    if not text.strip():
                        continue

                    inputs = tokenizer(
                        text, return_tensors="pt", truncation=True, max_length=512
                    )
                    # A causal LM scores token t given tokens < t, so a
                    # single-token input has no targets and its loss is NaN,
                    # which would corrupt the mean and the sort below.
                    if inputs["input_ids"].shape[-1] < 2:
                        continue
                    inputs = {k: v.to(device_str) for k, v in inputs.items()}

                    # Loss is the mean negative log-likelihood per token;
                    # perplexity = exp(loss).
                    outputs = model(**inputs, labels=inputs["input_ids"])
                    loss = outputs.loss.item()
                    perplexities.append(torch.exp(torch.tensor(loss)).item())

        if not perplexities:
            return {"message": "No valid examples to compute perplexity"}

        # Percentiles are read by index from the sorted values.
        perplexities.sort()
        n = len(perplexities)
        return {
            "model": model_name,
            "examples_evaluated": n,
            "mean_perplexity": sum(perplexities) / n,
            "median_perplexity": perplexities[n // 2],
            "min_perplexity": perplexities[0],
            "max_perplexity": perplexities[-1],
            "p25_perplexity": perplexities[n // 4],
            "p75_perplexity": perplexities[3 * n // 4],
            "perplexities": perplexities,
        }

    except Exception as e:
        # Best-effort analysis: report the failure instead of raising.
        return {
            "error": str(e),
            "message": f"Failed to compute perplexity with model {model_name}",
        }


def check_length_distribution(dataset: "TrainingDataset") -> Dict[str, Any]:
    """Analyze the distribution of example lengths.

    A length is the whitespace-split word count of ``get_text_content()``,
    a rough proxy for token count. Percentiles are read from the sorted
    lengths at index ``k * n // 4``, so ``median`` is the upper-middle value
    when ``n`` is even.

    Args:
        dataset: Dataset to analyze

    Returns:
        Dictionary with ``count``, ``min``, ``max``, ``mean``, ``median``,
        ``p25`` and ``p75``; every value is 0 for an empty dataset.
    """
    lengths = sorted(len(ex.get_text_content().split()) for ex in dataset.examples)
    n = len(lengths)

    if n == 0:
        return {
            "count": 0,
            "min": 0,
            "max": 0,
            "mean": 0,
            "median": 0,
            "p25": 0,
            "p75": 0,
        }

    return {
        "count": n,
        "min": lengths[0],
        "max": lengths[-1],
        "mean": sum(lengths) / n,
        "median": lengths[n // 2],
        "p25": lengths[n // 4],
        "p75": lengths[3 * n // 4],
    }


def detect_duplicates(
    dataset: "TrainingDataset", threshold: float = 0.95
) -> List[Tuple[int, int]]:
    """Find exact duplicate examples.

    Two examples are duplicates when their ``compute_hash()`` values are
    equal. Each repeat is paired with the index of the first example that
    had that hash.

    Args:
        dataset: Dataset to check
        threshold: Reserved for similarity-based near-duplicate matching.
            Currently unused: only exact hash matches are reported, whatever
            its value.

    Returns:
        List of ``(first_index, duplicate_index)`` pairs in dataset order.
    """
    first_seen: Dict[Any, int] = {}
    duplicates: List[Tuple[int, int]] = []

    for i, example in enumerate(dataset.examples):
        content_hash = example.compute_hash()
        if content_hash in first_seen:
            duplicates.append((first_seen[content_hash], i))
        else:
            first_seen[content_hash] = i

    return duplicates


def check_label_distribution(dataset: "TrainingDataset") -> Dict[str, Any]:
    """Analyze label distribution for classification tasks.

    Every label that is not ``None`` is counted, including falsy labels such
    as ``0``.

    Args:
        dataset: Dataset to analyze

    Returns:
        Dictionary with ``total_labeled``, ``unique_labels``, ``label_counts``,
        ``label_percentages`` (rounded to 2 decimals), ``most_common`` (up to
        5 ``(label, count)`` pairs) and ``is_balanced``. If no example has a
        label, a dict with only a ``"message"`` key.
    """
    labels = [ex.label for ex in dataset.examples if ex.label is not None]
    if not labels:
        return {"message": "No labels found in dataset"}

    label_counts = Counter(labels)
    total = len(labels)
    counts = label_counts.values()

    return {
        "total_labeled": total,
        "unique_labels": len(label_counts),
        "label_counts": dict(label_counts),
        "label_percentages": {
            k: round(v / total * 100, 2) for k, v in label_counts.items()
        },
        "most_common": label_counts.most_common(5),
        # Balanced: the largest class is less than twice the smallest.
        # labels is non-empty here, so min() is never taken of an empty view.
        "is_balanced": max(counts) / min(counts) < 2,
    }