"""OpenAI provider implementation for LLM generation.
This module provides OpenAI-specific generation functionality.
"""
import os
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional
from kerb.core.types import Message
from ..config import GenerationConfig, GenerationResponse, StreamChunk, Usage
from ..enums import LLMProvider, ModelName
[docs]
class OpenAIGenerator:
    """Thin convenience wrapper around the OpenAI generation helpers.

    The instance remembers an API key and forwards ``generate`` / ``stream``
    calls to the module-level OpenAI functions.
    """

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        """Create a generator bound to an API key.

        Args:
            api_key: OpenAI API key. When falsy, the ``OPENAI_API_KEY``
                environment variable is read instead.
            **kwargs: Additional configuration, stored unchanged on
                ``self.config``.
        """
        if not api_key:
            api_key = os.getenv("OPENAI_API_KEY")
        self.api_key = api_key
        self.config = kwargs

    def generate(
        self,
        messages: List[Message],
        model: str = ModelName.GPT_4O_MINI.value,
        **kwargs,
    ) -> GenerationResponse:
        """Produce a single completion through the OpenAI API.

        Args:
            messages: Conversation messages.
            model: Model name.
            **kwargs: Extra fields passed to ``GenerationConfig``.

        Returns:
            GenerationResponse
        """
        generation_config = GenerationConfig(model=model, **kwargs)
        response = _generate_openai(messages, generation_config, self.api_key)
        return response

    def stream(
        self,
        messages: List[Message],
        model: str = ModelName.GPT_4O_MINI.value,
        callback: Optional[Callable[[StreamChunk], None]] = None,
        **kwargs,
    ) -> Iterator[StreamChunk]:
        """Stream a completion through the OpenAI API.

        Args:
            messages: Conversation messages.
            model: Model name.
            callback: Optional callable handed on to the streaming helper,
                which invokes it with each chunk.
            **kwargs: Extra fields passed to ``GenerationConfig``.

        Returns:
            Iterator of StreamChunks
        """
        stream_config = GenerationConfig(model=model, **kwargs)
        chunks = _generate_stream_openai(
            messages, stream_config, self.api_key, callback
        )
        return chunks
# ============================================================================
# Internal OpenAI Functions
# ============================================================================
def _generate_openai(
    messages: List[Message], config: GenerationConfig, api_key: Optional[str] = None
) -> GenerationResponse:
    """Run a single (non-streaming) chat completion against the OpenAI API.

    Args:
        messages: Conversation messages; each one is serialized with
            ``to_dict()`` before being sent.
        config: Generation configuration (model, sampling parameters,
            tools, reasoning level, ...).
        api_key: OpenAI API key. Falls back to the ``OPENAI_API_KEY``
            environment variable when not given.

    Returns:
        GenerationResponse built from the first choice of the API response.
        Token counts are reported as 0 when the response has no usage data.

    Raises:
        ImportError: If the ``openai`` package is not installed.
        ValueError: If no API key is available.
    """
    try:
        import openai
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "OpenAI package not installed. Install with: pip install openai"
        ) from err

    # Get API key
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OpenAI API key not provided and OPENAI_API_KEY env var not set"
        )

    client = openai.OpenAI(api_key=api_key)

    # 'o' reasoning models (o1, o3) take a different set of request
    # parameters than the regular chat models, so detect them once.
    is_reasoning_model = config.model.startswith(("o1", "o3"))

    # Build request
    request_params = {
        "model": config.model,
        "messages": [m.to_dict() for m in messages],
        "temperature": config.temperature,
        "top_p": config.top_p,
        "frequency_penalty": config.frequency_penalty,
        "presence_penalty": config.presence_penalty,
        "n": config.n,
    }

    if config.max_tokens:
        # The deprecated ``max_tokens`` parameter is rejected by o-series
        # models; they require ``max_completion_tokens`` instead.
        token_param = "max_completion_tokens" if is_reasoning_model else "max_tokens"
        request_params[token_param] = config.max_tokens
    if config.stop_sequences:
        request_params["stop"] = config.stop_sequences
    if config.logprobs:
        request_params["logprobs"] = True
        request_params["top_logprobs"] = config.logprobs
    if config.seed is not None:
        # Explicit None check: a seed of 0 is falsy but must still be sent.
        request_params["seed"] = config.seed
    if config.response_format:
        request_params["response_format"] = config.response_format
    if config.tools:
        request_params["tools"] = config.tools
    if config.tool_choice:
        request_params["tool_choice"] = config.tool_choice

    # Handle reasoning level
    if config.reasoning_level:
        if is_reasoning_model:
            # Accept either an enum member or a plain string.
            level = (
                config.reasoning_level.value
                if hasattr(config.reasoning_level, "value")
                else config.reasoning_level
            )
            request_params["reasoning_effort"] = level
        else:
            import warnings

            # stacklevel=2 attributes the warning to the caller rather than
            # to this internal helper.
            warnings.warn(
                f"Reasoning level is not supported for model {config.model}. Ignoring.",
                UserWarning,
                stacklevel=2,
            )

    # Make request
    response = client.chat.completions.create(**request_params)

    # Parse response (only the first choice is surfaced; the full payload
    # remains available through ``raw_response``).
    choice = response.choices[0]
    content = choice.message.content or ""

    # ``usage`` is optional on the response object; fall back to zeros
    # instead of failing with AttributeError when it is missing.
    raw_usage = response.usage
    has_usage = raw_usage is not None
    usage = Usage(
        prompt_tokens=raw_usage.prompt_tokens if has_usage else 0,
        completion_tokens=raw_usage.completion_tokens if has_usage else 0,
        total_tokens=raw_usage.total_tokens if has_usage else 0,
    )

    return GenerationResponse(
        content=content,
        model=response.model,
        provider=LLMProvider.OPENAI,
        usage=usage,
        finish_reason=choice.finish_reason,
        raw_response=response,
    )
def _generate_stream_openai(
    messages: List[Message],
    config: GenerationConfig,
    api_key: Optional[str] = None,
    callback: Optional[Callable[[StreamChunk], None]] = None,
) -> Iterator[StreamChunk]:
    """Stream a chat completion from the OpenAI API.

    This is a generator function: nothing below (including the import and
    API-key checks) runs until the first chunk is requested.

    The sampling parameters, stop sequences, seed, response format and
    reasoning level from ``config`` are forwarded. ``n``, ``logprobs``,
    ``tools`` and ``tool_choice`` are deliberately not forwarded, because
    the loop below only reads the text delta of the first choice and could
    not surface their output.

    Args:
        messages: Conversation messages; each one is serialized with
            ``to_dict()`` before being sent.
        config: Generation configuration.
        api_key: OpenAI API key. Falls back to the ``OPENAI_API_KEY``
            environment variable when not given.
        callback: Optional callable invoked with each chunk just before it
            is yielded.

    Yields:
        StreamChunk for every streamed event that carries text content or a
        finish reason.

    Raises:
        ImportError: If the ``openai`` package is not installed.
        ValueError: If no API key is available.
    """
    try:
        import openai
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "OpenAI package not installed. Install with: pip install openai"
        ) from err

    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OpenAI API key not provided")

    client = openai.OpenAI(api_key=api_key)

    # 'o' reasoning models (o1, o3) take a different set of request
    # parameters than the regular chat models, so detect them once.
    is_reasoning_model = config.model.startswith(("o1", "o3"))

    request_params = {
        "model": config.model,
        "messages": [m.to_dict() for m in messages],
        "temperature": config.temperature,
        "top_p": config.top_p,
        "frequency_penalty": config.frequency_penalty,
        "presence_penalty": config.presence_penalty,
        "stream": True,
    }

    if config.max_tokens:
        # The deprecated ``max_tokens`` parameter is rejected by o-series
        # models; they require ``max_completion_tokens`` instead.
        token_param = "max_completion_tokens" if is_reasoning_model else "max_tokens"
        request_params[token_param] = config.max_tokens
    if config.stop_sequences:
        request_params["stop"] = config.stop_sequences
    if config.seed is not None:
        # Explicit None check: a seed of 0 is falsy but must still be sent.
        request_params["seed"] = config.seed
    if config.response_format:
        request_params["response_format"] = config.response_format

    if config.reasoning_level:
        if is_reasoning_model:
            # Accept either an enum member or a plain string.
            level = (
                config.reasoning_level.value
                if hasattr(config.reasoning_level, "value")
                else config.reasoning_level
            )
            request_params["reasoning_effort"] = level
        else:
            import warnings

            warnings.warn(
                f"Reasoning level is not supported for model {config.model}. Ignoring.",
                UserWarning,
                stacklevel=2,
            )

    stream = client.chat.completions.create(**request_params)

    for chunk_data in stream:
        # Some events carry no choices at all; skip them.
        if not chunk_data.choices:
            continue

        choice = chunk_data.choices[0]
        content = choice.delta.content or ""
        finish_reason = choice.finish_reason

        # Drop events that have neither text nor a finish reason.
        if not (content or finish_reason):
            continue

        chunk = StreamChunk(
            content=content, finish_reason=finish_reason, model=chunk_data.model
        )
        if callback:
            callback(chunk)
        yield chunk