# Adding a New LLM Provider to Serapeum
This guide explains how to add a new LLM provider integration to the Serapeum framework.
## Overview
Serapeum uses a provider-based architecture where each LLM provider is implemented as a separate package under libs/providers/. Each provider package is self-contained and includes all features that provider offers (LLM, embeddings, etc.).
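In practice this means each provider ships as its own distribution while sharing one import namespace. A small illustration, assuming the existing Ollama provider is installed and exports a class named `OllamaLLM`:

```python
# serapeum-core and serapeum-ollama are separate packages,
# but both contribute to the shared `serapeum` namespace:
from serapeum.core.llms import Message            # from libs/core/
from serapeum.providers.ollama import OllamaLLM   # from libs/providers/ollama/
```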
## Quick Start
To add a new provider (e.g., OpenAI):
1. Create the package structure in `libs/providers/openai/`
2. Implement an LLM class inheriting from `ChatToCompletionMixin` and `FunctionCallingLLM`
3. Add tests
4. Configure the workspace
5. Document
## Step-by-Step Guide

### 1. Create Package Structure
Create a new provider package following this structure:
```text
libs/providers/{provider-name}/
├── src/
│   └── serapeum/
│       └── providers/
│           └── {provider_name}/
│               ├── __init__.py
│               ├── llm.py            # Chat/completion LLM implementation
│               ├── embeddings.py     # Embeddings (if available)
│               └── shared/           # Shared utilities
│                   ├── __init__.py
│                   ├── client.py     # HTTP client, config
│                   └── errors.py     # Provider-specific errors
├── tests/
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_llm.py
│   ├── test_embeddings.py
│   └── fixtures/
├── pyproject.toml
└── README.md
```
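Note that `__init__.py` appears only under `{provider_name}/`: by this layout, the intermediate `serapeum/` and `serapeum/providers/` directories are PEP 420 namespace packages, which is what lets independently installed provider packages share one import namespace. A quick way to see this at work, assuming at least one provider is installed:

```python
# A namespace package aggregates directories from every distribution
# that contributes to it -- core, plus each installed provider:
import serapeum.providers

print(list(serapeum.providers.__path__))  # one path entry per contributing package
```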
### 2. Create `pyproject.toml`

**File:** `libs/providers/{provider}/pyproject.toml`
```toml
[project]
name = "serapeum-{provider}"
version = "0.1.0"
description = "{Provider} integration for Serapeum LLM framework"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "serapeum-core",
    "{provider-sdk}>=1.0.0",  # The official provider SDK
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.24.0",
    "pytest-cov>=6.0.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/serapeum"]

# Workspace source - points to local serapeum-core during development
[tool.uv.sources]
serapeum-core = { workspace = true }

# Pytest configuration
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
# Test markers
markers = [
    "e2e: End-to-end tests requiring external services",
    "integration: Integration tests",
    "unit: Unit tests",
]
```
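If you want to sanity-check the file before moving on, the stdlib `tomllib` (available on Python 3.11+, matching `requires-python` above) can parse it. The path below assumes the OpenAI example name:

```python
# Quick parse check for the new pyproject.toml (path is illustrative):
import tomllib

with open("libs/providers/openai/pyproject.toml", "rb") as f:
    project = tomllib.load(f)["project"]
print(project["name"], project["version"])  # expect: serapeum-openai 0.1.0
```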
### 3. Update Root Workspace

**File:** `pyproject.toml` (root)

Add your provider to the workspace members:

```toml
[tool.uv.workspace]
members = [
    "libs/core",
    "libs/providers/*",  # This includes your new provider
]

[tool.uv.sources]
serapeum-{provider} = { workspace = true }
```
### 4. Implement the LLM Class

**File:** `libs/providers/{provider}/src/serapeum/providers/{provider}/llm.py`
"""LLM implementation for {Provider}."""
from typing import Any, Sequence
from pydantic import Field, PrivateAttr
from serapeum.core.llms import (
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
ChatToCompletionMixin,
FunctionCallingLLM,
Message,
MessageList,
Metadata,
)
from serapeum.core.tools import ToolCallArguments
class {Provider}LLM(ChatToCompletionMixin, FunctionCallingLLM):
"""LLM implementation for {Provider}.
This class provides chat and completion interfaces for {Provider} models.
Completion methods are automatically provided by ChatToCompletionMixin.
Args:
model: Model identifier (e.g., "gpt-4", "claude-3-opus")
api_key: API key for authentication (optional if set via env)
temperature: Sampling temperature (0.0-1.0)
max_tokens: Maximum tokens in response
**kwargs: Additional provider-specific options
Examples:
Basic chat:
```python
from serapeum.providers.{provider} import {Provider}LLM
llm = {Provider}LLM(model="model-name", api_key="your-key")
response = llm.chat([Message(role=MessageRole.USER, chunks=[TextChunk(content="Hello")])])
print(response.message.content)
```
Streaming chat:
```python
for chunk in llm.chat([Message(role=MessageRole.USER, chunks=[TextChunk(content="Hi")])], stream=True):
print(chunk.delta, end="", flush=True)
```
Completion interface (via mixin):
```python
response = llm.complete("What is 2+2?")
print(response.text)
```
"""
model: str = Field(description="Model identifier")
api_key: str | None = Field(default=None, description="API key")
temperature: float = Field(default=0.7, ge=0.0, le=1.0)
max_tokens: int = Field(default=1000, gt=0)
# Private attributes for client instances
_client: Any = PrivateAttr(default=None)
_async_client: Any = PrivateAttr(default=None)
def __init__(
self,
model: str,
api_key: str | None = None,
temperature: float = 0.7,
max_tokens: int = 1000,
**kwargs: Any,
):
"""Initialize the {Provider} LLM."""
super().__init__(
model=model,
api_key=api_key,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
self._client = None
self._async_client = None
@classmethod
def class_name(cls) -> str:
"""Return the class identifier."""
return "{Provider}LLM"
@property
def metadata(self) -> Metadata:
"""Return LLM metadata."""
return Metadata(
context_window=self._get_context_window(),
num_output=self.max_tokens,
model_name=self.model,
is_chat_model=True,
is_function_calling_model=True,
)
@property
def client(self):
"""Lazy-loaded synchronous client."""
if self._client is None:
# Import and initialize the provider's SDK client
from {provider}_sdk import Client
self._client = Client(api_key=self.api_key or self._get_api_key_from_env())
return self._client
@property
def async_client(self):
"""Lazy-loaded async client."""
if self._async_client is None:
from {provider}_sdk import AsyncClient
self._async_client = AsyncClient(api_key=self.api_key or self._get_api_key_from_env())
return self._async_client
def _get_api_key_from_env(self) -> str:
"""Get API key from environment variable."""
import os
api_key = os.getenv("{PROVIDER}_API_KEY")
if not api_key:
raise ValueError(
f"{PROVIDER}_API_KEY environment variable not set. "
"Please set it or pass api_key parameter."
)
return api_key
def _get_context_window(self) -> int:
"""Get context window size for the model."""
# Map model names to context windows
context_windows = {
"model-small": 4096,
"model-large": 128000,
# Add more models
}
return context_windows.get(self.model, 4096)
def _convert_messages(self, messages: MessageList) -> list[dict]:
"""Convert MessageList to provider's format."""
provider_messages = []
for msg in messages:
provider_messages.append({
"role": msg.role.value,
"content": msg.content or "",
})
return provider_messages
# ===== REQUIRED: Implement these 2 chat methods =====
# Each accepts stream: bool and handles both modes internally.
def chat(self, messages: MessageList, stream: bool = False, **kwargs: Any) -> ChatResponse | ChatResponseGen:
"""Send a chat request (or streaming if stream=True).
Args:
messages: List of messages in the conversation
stream: If True, return a generator yielding incremental chunks
**kwargs: Provider-specific options
Returns:
ChatResponse when stream=False, ChatResponseGen when stream=True
"""
if stream:
return self._stream_chat(messages, **kwargs)
provider_messages = self._convert_messages(messages)
response = self.client.chat.completions.create(
model=self.model,
messages=provider_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
**kwargs,
)
return ChatResponse(
message=Message(
role=MessageRole.ASSISTANT,
chunks=[TextChunk(content=response.choices[0].message.content)],
),
raw=response,
)
def _stream_chat(self, messages: MessageList, **kwargs: Any) -> ChatResponseGen:
"""Internal streaming chat helper.
Yields:
ChatResponse chunks with delta content
"""
provider_messages = self._convert_messages(messages)
def gen():
response_stream = self.client.chat.completions.create(
model=self.model,
messages=provider_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
stream=True,
**kwargs,
)
full_content = ""
for chunk in response_stream:
delta = chunk.choices[0].delta.content or ""
full_content += delta
yield ChatResponse(
message=Message(
role=MessageRole.ASSISTANT,
chunks=[TextChunk(content=full_content)],
),
delta=delta,
raw=chunk,
)
return gen()
async def achat(self, messages: MessageList, stream: bool = False, **kwargs: Any) -> ChatResponse | ChatResponseAsyncGen:
"""Async chat request (or streaming if stream=True).
Args:
messages: List of messages in the conversation
stream: If True, return an async generator yielding incremental chunks
**kwargs: Provider-specific options
Returns:
ChatResponse when stream=False, ChatResponseAsyncGen when stream=True
"""
if stream:
return await self._astream_chat(messages, **kwargs)
provider_messages = self._convert_messages(messages)
response = await self.async_client.chat.completions.create(
model=self.model,
messages=provider_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
**kwargs,
)
return ChatResponse(
message=Message(
role=MessageRole.ASSISTANT,
chunks=[TextChunk(content=response.choices[0].message.content)],
),
raw=response,
)
async def _astream_chat(
self, messages: MessageList, **kwargs: Any
) -> ChatResponseAsyncGen:
"""Internal async streaming chat helper.
Returns:
Async generator yielding ChatResponse chunks
"""
provider_messages = self._convert_messages(messages)
async def gen():
response_stream = await self.async_client.chat.completions.create(
model=self.model,
messages=provider_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
stream=True,
**kwargs,
)
full_content = ""
async for chunk in response_stream:
delta = chunk.choices[0].delta.content or ""
full_content += delta
yield ChatResponse(
message=Message(
role=MessageRole.ASSISTANT,
chunks=[TextChunk(content=full_content)],
),
delta=delta,
raw=chunk,
)
return gen()
# ===== COMPLETION METHODS =====
# complete(stream=False), acomplete(stream=False)
# are automatically provided by ChatToCompletionMixin!
# ===== FUNCTION CALLING (Optional) =====
def _prepare_chat_with_tools(
self,
tools: Sequence["BaseTool"],
user_msg: str | Message | None = None,
chat_history: list[Message] | None = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
"""Prepare chat request with tools."""
tool_specs = [
tool.metadata.to_openai_tool(skip_length_check=True) for tool in tools
]
if isinstance(user_msg, str):
user_msg = Message(role=MessageRole.USER, chunks=[TextChunk(content=user_msg)])
messages = chat_history or []
if user_msg:
messages.append(user_msg)
return {
"messages": messages,
"tools": tool_specs or None,
}
def get_tool_calls_from_response(
self,
response: ChatResponse,
error_on_no_tool_call: bool = True,
) -> list[ToolCallArguments]:
"""Extract tool calls from response."""
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
if not tool_calls:
if error_on_no_tool_call:
raise ValueError("No tool calls found in response")
return []
# Parse tool calls into ToolCallArguments
selections = []
for tool_call in tool_calls:
selections.append(
ToolCallArguments(
tool_id=tool_call.get("id", tool_call["function"]["name"]),
tool_name=tool_call["function"]["name"],
tool_kwargs=tool_call["function"]["arguments"],
)
)
return selections
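The package layout in step 1 reserves `shared/errors.py` for provider-specific errors, which this guide doesn't otherwise show. A minimal sketch of what it might contain; the class names are illustrative, not a framework contract:

```python
# libs/providers/{provider}/src/serapeum/providers/{provider}/shared/errors.py
# Illustrative error hierarchy; adapt the names to your SDK's failure modes.


class ProviderError(Exception):
    """Base class for all errors raised by this provider package."""


class ProviderAuthError(ProviderError):
    """API key missing, malformed, or rejected."""


class ProviderRateLimitError(ProviderError):
    """The provider reported rate limiting; callers may back off and retry."""
```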
### 5. Implement Package `__init__.py`

**File:** `libs/providers/{provider}/src/serapeum/providers/{provider}/__init__.py`
"""{Provider} integration for Serapeum LLM framework."""
from serapeum.providers.{provider}.llm import {Provider}LLM
# Add embeddings if available
# from serapeum.providers.{provider}.embeddings import {Provider}Embeddings
__all__ = [
"{Provider}LLM",
# "{Provider}Embeddings",
]
### 6. Create Tests

**File:** `libs/providers/{provider}/tests/test_llm.py`
"""Tests for {Provider}LLM."""
import pytest
from unittest.mock import Mock, patch
from serapeum.core.base.llms.types import Message, MessageRole, MessageList
from serapeum.providers.{provider} import {Provider}LLM
@pytest.fixture
def mock_client():
"""Create a mock {Provider} client."""
with patch("{provider}_sdk.Client") as mock:
yield mock
@pytest.fixture
def llm(mock_client):
"""Create a {Provider}LLM instance with mocked client."""
return {Provider}LLM(model="test-model", api_key="test-key")
class TestChat:
"""Test chat functionality."""
def test_chat_basic(self, llm, mock_client):
"""Test basic chat."""
# Mock response
mock_response = Mock()
mock_response.choices = [Mock()]
mock_response.choices[0].message.content = "Hello!"
llm.client.chat.completions.create.return_value = mock_response
# Test
messages = MessageList.from_list([
Message(role=MessageRole.USER, chunks=[TextChunk(content="Hi")])
])
response = llm.chat(messages)
assert response.message.content == "Hello!"
assert response.message.role == MessageRole.ASSISTANT
def test_stream_chat(self, llm, mock_client):
"""Test streaming chat via chat(stream=True)."""
# Mock streaming response
mock_chunks = [
Mock(choices=[Mock(delta=Mock(content="Hel"))]),
Mock(choices=[Mock(delta=Mock(content="lo!"))]),
]
llm.client.chat.completions.create.return_value = iter(mock_chunks)
# Test
messages = MessageList.from_list([
Message(role=MessageRole.USER, chunks=[TextChunk(content="Hi")])
])
chunks = list(llm.chat(messages, stream=True))
assert len(chunks) == 2
assert chunks[0].delta == "Hel"
assert chunks[1].delta == "lo!"
@pytest.mark.asyncio
async def test_achat(self, llm, mock_client):
"""Test async chat."""
# Mock async response
mock_response = Mock()
mock_response.choices = [Mock()]
mock_response.choices[0].message.content = "Async hello!"
llm.async_client.chat.completions.create.return_value = mock_response
# Test
messages = MessageList.from_list([
Message(role=MessageRole.USER, chunks=[TextChunk(content="Hi")])
])
response = await llm.achat(messages)
assert response.message.content == "Async hello!"
class TestCompletion:
"""Test completion functionality (from ChatToCompletionMixin)."""
def test_complete(self, llm, mock_client):
"""Test that complete() works via mixin."""
mock_response = Mock()
mock_response.choices = [Mock()]
mock_response.choices[0].message.content = "Answer: 42"
llm.client.chat.completions.create.return_value = mock_response
# Test completion interface
response = llm.complete("What is the answer?")
assert response.text == "Answer: 42"
@pytest.mark.e2e
class TestE2E:
"""End-to-end tests requiring actual API."""
def test_real_chat(self):
"""Test with real API (requires API key)."""
import os
api_key = os.getenv("{PROVIDER}_API_KEY")
if not api_key:
pytest.skip("{PROVIDER}_API_KEY not set")
llm = {Provider}LLM(model="model-name", api_key=api_key)
response = llm.chat(
MessageList.from_list([
Message(role=MessageRole.USER, chunks=[TextChunk(content="Say 'test'")])
])
)
assert response.message.content
assert isinstance(response.message.content, str)
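The test layout in step 1 also lists a `conftest.py`, which this guide doesn't otherwise show. A minimal sketch of shared fixtures it might hold; the `sample_messages` fixture is illustrative, not part of the framework:

```python
# libs/providers/{provider}/tests/conftest.py -- minimal shared fixtures (illustrative).
import pytest

from serapeum.core.base.llms.types import Message, MessageList, MessageRole, TextChunk


@pytest.fixture
def sample_messages() -> MessageList:
    """A single-turn user conversation reused across tests."""
    return MessageList.from_list(
        [Message(role=MessageRole.USER, chunks=[TextChunk(content="Hi")])]
    )
```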
### 7. Create README

**File:** `libs/providers/{provider}/README.md`
````markdown
# Serapeum {Provider} Integration

{Provider} integration for the Serapeum LLM framework.

## Installation

```bash
pip install serapeum-{provider}
```

## Usage

### Basic Chat

```python
from serapeum.providers.{provider} import {Provider}LLM
from serapeum.core.base.llms.types import Message, MessageRole, TextChunk

llm = {Provider}LLM(
    model="model-name",
    api_key="your-api-key",  # or set {PROVIDER}_API_KEY env var
)
response = llm.chat([
    Message(role=MessageRole.USER, chunks=[TextChunk(content="Hello!")])
])
print(response.message.content)
```

### Streaming

```python
for chunk in llm.chat([
    Message(role=MessageRole.USER, chunks=[TextChunk(content="Tell me a story")])
], stream=True):
    print(chunk.delta, end="", flush=True)
```

### Completion Interface

```python
response = llm.complete("What is 2+2?")
print(response.text)
```

### Async

```python
import asyncio

async def main():
    response = await llm.achat([
        Message(role=MessageRole.USER, chunks=[TextChunk(content="Hi")])
    ])
    print(response.message.content)

asyncio.run(main())
```

## Configuration

- `model`: Model identifier
- `api_key`: API key (or set the `{PROVIDER}_API_KEY` environment variable)
- `temperature`: Sampling temperature (0.0-1.0, default: 0.7)
- `max_tokens`: Maximum response tokens (default: 1000)

## Features

- ✅ Chat interface
- ✅ Streaming
- ✅ Async support
- ✅ Completion interface (via ChatToCompletionMixin)
- ✅ Function calling (if supported by provider)
- ✅ Embeddings (if available)

## Testing

```bash
# Unit tests (mocked)
uv run pytest libs/providers/{provider}/tests/ -m "not e2e"

# E2E tests (requires API key)
export {PROVIDER}_API_KEY="your-key"
uv run pytest libs/providers/{provider}/tests/ -m e2e
```

## License

Same as Serapeum core.
````
### 8. Build and Install

```bash
# Sync workspace dependencies
uv sync --dev

# Build the provider package
uv build --wheel -o dist libs/providers/{provider}

# Install locally for testing
uv pip install dist/serapeum_{provider}-0.1.0-py3-none-any.whl

# Or use workspace install
uv sync --package serapeum-{provider}
```
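After installing, a quick import smoke test confirms the wheel is wired up correctly. The names below assume the OpenAI example:

```python
# Verify the package imports and reports the expected version.
from importlib.metadata import version

from serapeum.providers.openai import OpenAILLM  # hypothetical provider name

print(version("serapeum-openai"))  # expect "0.1.0"
print(OpenAILLM.class_name())      # expect "OpenAILLM"
```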
### 9. Test the Provider

```bash
# Run unit tests (mocked)
uv run pytest libs/providers/{provider}/tests/ -v -m "not e2e"

# Run E2E tests (requires API key)
export {PROVIDER}_API_KEY="your-api-key"
uv run pytest libs/providers/{provider}/tests/ -v -m e2e

# Run with coverage
uv run pytest libs/providers/{provider}/tests/ --cov=serapeum.providers.{provider}
```
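The same runs can be launched programmatically, which is handy inside a debugger session; a sketch assuming the OpenAI example path:

```python
# Programmatic equivalent of the unit-test CLI invocation above.
import pytest

exit_code = pytest.main(["libs/providers/openai/tests/", "-v", "-m", "not e2e"])
print(exit_code)  # 0 on success
```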
## Key Points

### ✅ Do This

- **Inherit from `ChatToCompletionMixin` FIRST**: order matters for Method Resolution Order (MRO)! See the sanity check at the end of this section.
- **Implement the 2 chat methods**, each accepting `stream: bool` and handling both modes: `chat(messages, stream=False, **kwargs)` and `achat(messages, stream=False, **kwargs)`
- **Get completion methods for free** from the mixin
- **Use lazy-loaded clients** for better resource management
- **Add proper error handling** for API errors
- **Write comprehensive tests**, both mocked and E2E

### ❌ Don't Do This

- Don't implement completion methods manually; use the mixin
- Don't put `FunctionCallingLLM` before `ChatToCompletionMixin` in the inheritance list
- Don't forget to handle async properly (event loops, etc.)
- Don't hardcode API keys; use environment variables
- Don't skip testing, either mocked or E2E
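A quick way to confirm the inheritance order is to inspect the MRO directly. A minimal sketch, assuming your provider class is named `OpenAILLM` (any provider class works the same way):

```python
from serapeum.core.llms import ChatToCompletionMixin, FunctionCallingLLM
from serapeum.providers.openai import OpenAILLM  # hypothetical provider name

mro = OpenAILLM.__mro__
# The mixin must come before FunctionCallingLLM so its chat-based
# completion methods win during attribute lookup.
assert mro.index(ChatToCompletionMixin) < mro.index(FunctionCallingLLM)
# If the mixin is doing its job, you never defined complete() yourself:
assert "complete" not in vars(OpenAILLM)
```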
## Examples

### Existing Providers

See these for reference:

- **Ollama**: `libs/providers/ollama/` - complete implementation with streaming, async, and tool calling
- **Core abstractions**: `libs/core/src/serapeum/core/llms/abstractions/` - base classes and adapters
### Minimal Example
For a minimal provider implementation, you only need ~100 lines of code:
```python
from serapeum.core.llms import ChatResponse, ChatToCompletionMixin, FunctionCallingLLM


class MinimalLLM(ChatToCompletionMixin, FunctionCallingLLM):
    def chat(self, messages, stream=False, **kwargs):
        if stream:
            return self._stream_chat(messages, **kwargs)
        # Call your provider's API
        return ChatResponse(...)

    def _stream_chat(self, messages, **kwargs):
        # Stream from the provider
        yield ChatResponse(...)

    async def achat(self, messages, stream=False, **kwargs):
        if stream:
            return await self._astream_chat(messages, **kwargs)
        # Async call
        return ChatResponse(...)

    async def _astream_chat(self, messages, **kwargs):
        async def gen():
            yield ChatResponse(...)

        return gen()

# That's it! Completion methods come from ChatToCompletionMixin
```
## Checklist

Use this checklist when adding a new provider:

- [ ] Created package structure under `libs/providers/{provider}/`
- [ ] Created `pyproject.toml` with dependencies
- [ ] Updated root `pyproject.toml` workspace members
- [ ] Implemented LLM class with `ChatToCompletionMixin`
- [ ] Implemented the 2 chat methods (`chat`, `achat`)
- [ ] Verified completion methods work (from the mixin)
- [ ] Added `_prepare_chat_with_tools()` for function calling
- [ ] Created comprehensive tests (unit + E2E)
- [ ] Created README.md with examples
- [ ] Added conftest.py with fixtures
- [ ] Tested with `uv sync --package serapeum-{provider}`
- [ ] Verified MRO is correct: `ChatToCompletionMixin` before `FunctionCallingLLM`
- [ ] All tests passing
- [ ] Added to documentation
## Next Steps
After implementing your provider:
- Add to documentation: Update main docs with provider info
- Add examples: Create example notebooks/scripts
- Publish: Build and publish to PyPI (if public)
- Announce: Add to changelog and release notes
## Need Help?

- **Example code**: see `examples/chat_to_completion_mixin_example.py`
- **Architecture**: see `docs/architecture/` for detailed docs
Happy coding! 🚀