Custom Serving Applications
MLflow's custom serving applications allow you to build sophisticated model serving solutions that go beyond simple prediction endpoints. Using the PyFunc framework, you can create custom applications with complex preprocessing, postprocessing, multi-model inference, and business logic integration.
Overview
Custom serving applications in MLflow are built using the mlflow.pyfunc.PythonModel class, which provides a flexible framework for creating deployable models with custom logic. This approach is ideal when you need to:
- Implement advanced preprocessing and postprocessing logic
- Combine multiple models within a single serving pipeline
- Apply business rules and custom validation checks
- Support diverse input and output data formats
- Integrate seamlessly with external systems or databases
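Structurally, every custom application is a subclass of mlflow.pyfunc.PythonModel that implements load_context for one-time setup (loading artifacts, clients, and configuration) and predict for per-request logic. A minimal sketch of that skeleton follows; the class name and greeting logic are placeholders, not part of the examples later in this page:

import mlflow
import pandas as pd


class MyCustomApp(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        # Runs once when the model is loaded: open clients, read config artifacts, etc.
        self.greeting = "Hello"

    def predict(self, context, model_input):
        # Runs for every request; served models typically receive a pandas DataFrame.
        return pd.DataFrame(
            {"response": [f"{self.greeting}, {name}!" for name in model_input["name"]]}
        )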
Custom PyFunc Model
Custom Model
Here's an example of a custom PyFunc model that wraps an LLM client with configurable prompt preprocessing and response postprocessing:
import mlflow
import pandas as pd
import json
from typing import Dict, List, Any
import openai  # or any other LLM client


class CustomLLMModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """Load LLM configuration and initialize client"""
        # Load model configuration from artifacts
        config_path = context.artifacts.get("config", "config.json")
        with open(config_path, "r") as f:
            self.config = json.load(f)

        # Initialize LLM client
        self.client = openai.OpenAI(api_key=self.config["api_key"])
        self.model_name = self.config["model_name"]
        self.system_prompt = self.config.get(
            "system_prompt", "You are a helpful assistant."
        )

    def predict(self, context, model_input):
        """Core LLM prediction logic"""
        if isinstance(model_input, pd.DataFrame):
            # Handle DataFrame input with prompts
            responses = []
            for _, row in model_input.iterrows():
                user_prompt = row.get("prompt", row.get("input", ""))
                processed_prompt = self._preprocess_prompt(user_prompt)
                response = self._generate_response(processed_prompt)
                post_processed = self._postprocess_response(response)
                responses.append(post_processed)
            return pd.DataFrame({"response": responses})
        elif isinstance(model_input, dict):
            # Handle single prompt
            user_prompt = model_input.get("prompt", model_input.get("input", ""))
            processed_prompt = self._preprocess_prompt(user_prompt)
            response = self._generate_response(processed_prompt)
            return self._postprocess_response(response)
        else:
            # Handle string input
            processed_prompt = self._preprocess_prompt(str(model_input))
            response = self._generate_response(processed_prompt)
            return self._postprocess_response(response)

    def _preprocess_prompt(self, prompt: str) -> str:
        """Custom prompt preprocessing logic"""
        # Example: add context, format prompt, apply templates
        template = self.config.get("prompt_template", "{prompt}")
        return template.format(prompt=prompt)

    def _generate_response(self, prompt: str) -> str:
        """Core LLM inference"""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=self.config.get("temperature", 0.7),
                max_tokens=self.config.get("max_tokens", 1000),
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def _postprocess_response(self, response: str) -> str:
        """Custom response postprocessing logic"""
        # Example: format output, apply filters, extract specific parts
        if self.config.get("strip_whitespace", True):
            response = response.strip()

        max_length = self.config.get("max_response_length")
        if max_length and len(response) > max_length:
            response = response[:max_length] + "..."

        return response


# Example configuration
config = {
    "api_key": "your-api-key",
    "model_name": "gpt-4",
    "system_prompt": "You are an expert data analyst. Provide clear, concise answers.",
    "temperature": 0.3,
    "max_tokens": 500,
    "prompt_template": "Context: Data Analysis Task\n\nQuestion: {prompt}\n\nAnswer:",
    "strip_whitespace": True,
    "max_response_length": 1000,
}

# Save configuration
with open("config.json", "w") as f:
    json.dump(config, f)

# Log the model
with mlflow.start_run():
    # Log configuration as artifact
    mlflow.log_artifact("config.json")

    # Create input example
    input_example = pd.DataFrame(
        {"prompt": ["What is machine learning?", "Explain neural networks"]}
    )

    model_info = mlflow.pyfunc.log_model(
        name="custom_llm_model",
        python_model=CustomLLMModel(),
        artifacts={"config": "config.json"},
        input_example=input_example,
    )
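Once logged, the model can be loaded back and queried like any regular PyFunc model. The sketch below continues the snippet above; it assumes the model_info returned by mlflow.pyfunc.log_model and a valid API key in the logged configuration, since each prediction calls the OpenAI API:

# Load the logged model back as a generic PyFunc model
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Batch prediction with a DataFrame of prompts
result = loaded_model.predict(
    pd.DataFrame({"prompt": ["Summarize the main drivers of customer churn"]})
)
print(result["response"][0])

The same model can also be served as a REST endpoint with the MLflow CLI, for example mlflow models serve -m <model_uri> -p 5000, after which clients send prompts to the /invocations endpoint.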
Multi-Model Ensemble
Create a custom application that combines multiple LLMs with different strengths and routes each request to the models best suited for its task type:
import mlflow
import mlflow.pyfunc
import pandas as pd
import json
import openai
import anthropic
from typing import List, Dict, Any


class MultiLLMEnsemble(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """Load multiple LLM configurations from artifacts"""
        # Load ensemble configuration
        config_path = context.artifacts["ensemble_config"]
        with open(config_path, "r") as f:
            self.config = json.load(f)

        # Initialize multiple LLM clients
        self.llm_clients = {}

        # OpenAI client
        if "openai" in self.config["models"]:
            self.llm_clients["openai"] = openai.OpenAI(
                api_key=self.config["models"]["openai"]["api_key"]
            )

        # Anthropic client
        if "anthropic" in self.config["models"]:
            self.llm_clients["anthropic"] = anthropic.Anthropic(
                api_key=self.config["models"]["anthropic"]["api_key"]
            )

        # Add other LLM clients as needed
        self.voting_strategy = self.config.get("voting_strategy", "weighted_average")
        self.model_weights = self.config.get("model_weights", {})

    def predict(self, context, model_input):
        """Ensemble prediction with multiple LLMs"""
        if isinstance(model_input, pd.DataFrame):
            responses = []
            for _, row in model_input.iterrows():
                prompt = row.get("prompt", row.get("input", ""))
                task_type = row.get("task_type", "general")
                ensemble_response = self._generate_ensemble_response(prompt, task_type)
                responses.append(ensemble_response)
            return pd.DataFrame({"response": responses})
        elif isinstance(model_input, dict):
            # Handle a single dict input
            prompt = model_input.get("prompt", model_input.get("input", ""))
            task_type = model_input.get("task_type", "general")
            return self._generate_ensemble_response(prompt, task_type)
        else:
            # Handle plain string input
            return self._generate_ensemble_response(str(model_input), "general")

    def _generate_ensemble_response(
        self, prompt: str, task_type: str = "general"
    ) -> str:
        """Generate responses from multiple LLMs and combine them"""
        responses = {}

        # Get task-specific model configuration
        task_config = self.config.get("task_routing", {}).get(task_type, {})
        active_models = task_config.get("models", list(self.llm_clients.keys()))

        # Generate responses from each active model
        for model_name in active_models:
            if model_name in self.llm_clients:
                response = self._generate_single_response(model_name, prompt, task_type)
                responses[model_name] = response

        # Combine responses based on voting strategy
        return self._combine_responses(responses, task_type)

    def _generate_single_response(
        self, model_name: str, prompt: str, task_type: str
    ) -> str:
        """Generate response from a single LLM"""
        model_config = self.config["models"][model_name]

        try:
            if model_name == "openai":
                response = self.llm_clients["openai"].chat.completions.create(
                    model=model_config["model_name"],
                    messages=[
                        {
                            "role": "system",
                            "content": model_config.get("system_prompt", ""),
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=model_config.get("temperature", 0.7),
                    max_tokens=model_config.get("max_tokens", 1000),
                )
                return response.choices[0].message.content
            elif model_name == "anthropic":
                response = self.llm_clients["anthropic"].messages.create(
                    model=model_config["model_name"],
                    max_tokens=model_config.get("max_tokens", 1000),
                    temperature=model_config.get("temperature", 0.7),
                    messages=[{"role": "user", "content": prompt}],
                )
                return response.content[0].text
            # Add other LLM implementations here
            return f"Error: no implementation for model '{model_name}'"
        except Exception as e:
            return f"Error from {model_name}: {str(e)}"

    def _combine_responses(self, responses: Dict[str, str], task_type: str) -> str:
        """Combine multiple LLM responses using specified strategy"""
        if self.voting_strategy == "best_for_task":
            # Route to best model for specific task type
            task_config = self.config.get("task_routing", {}).get(task_type, {})
            preferred_model = task_config.get("preferred_model")
            if preferred_model and preferred_model in responses:
                return responses[preferred_model]
        elif self.voting_strategy == "consensus":
            # Return response if multiple models agree (simplified)
            response_list = list(responses.values())
            if len(set(response_list)) == 1:
                return response_list[0]
            else:
                # If no consensus, return the longest response
                return max(response_list, key=len)
        elif self.voting_strategy == "weighted_combination":
            # Combine responses with weights (simplified text combination)
            combined_response = "Combined insights:\n\n"
            for model_name, response in responses.items():
                weight = self.model_weights.get(model_name, 1.0)
                combined_response += (
                    f"[{model_name.upper()} - Weight: {weight}]: {response}\n\n"
                )
            return combined_response

        # Default: return first available response
        return list(responses.values())[0] if responses else "No response generated"


# Example ensemble configuration
ensemble_config = {
    "voting_strategy": "best_for_task",
    "models": {
        "openai": {
            "api_key": "your-openai-key",
            "model_name": "gpt-4",
            "system_prompt": "You are a helpful assistant specialized in technical analysis.",
            "temperature": 0.3,
            "max_tokens": 800,
        },
        "anthropic": {
            "api_key": "your-anthropic-key",
            "model_name": "claude-3-sonnet-20240229",
            "temperature": 0.5,
            "max_tokens": 1000,
        },
    },
    "task_routing": {
        "code_analysis": {"models": ["openai"], "preferred_model": "openai"},
        "creative_writing": {"models": ["anthropic"], "preferred_model": "anthropic"},
        "general": {"models": ["openai", "anthropic"], "preferred_model": "openai"},
    },
    "model_weights": {"openai": 0.6, "anthropic": 0.4},
}

# Save configuration
with open("ensemble_config.json", "w") as f:
    json.dump(ensemble_config, f)

# Log the ensemble model
with mlflow.start_run():
    # Log configuration as artifact
    mlflow.log_artifact("ensemble_config.json")

    # Create input example
    input_example = pd.DataFrame(
        {
            "prompt": ["Explain quantum computing", "Write a creative story about AI"],
            "task_type": ["general", "creative_writing"],
        }
    )

    mlflow.pyfunc.log_model(
        name="multi_llm_ensemble",
        python_model=MultiLLMEnsemble(),
        artifacts={"ensemble_config": "ensemble_config.json"},
        input_example=input_example,
    )
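As with the single-model example, the logged ensemble can be served with mlflow models serve and queried over REST. The sketch below is a hypothetical client call that assumes the scoring server is already running locally on port 5000 and that valid API keys are present in the logged configuration:

import requests

# One record per row; column names match what MultiLLMEnsemble.predict expects
payload = {
    "dataframe_records": [
        {
            "prompt": "Review this Python function for bugs",
            "task_type": "code_analysis",
        }
    ]
}

# The MLflow scoring server exposes predictions at /invocations
resp = requests.post(
    "http://localhost:5000/invocations",
    json=payload,
    timeout=60,
)
print(resp.json())

The dataframe_records payload maps directly onto the prompt and task_type columns from the input example, so the request is routed through the same task_routing logic shown above.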