AI Gateway Integration
Learn how to integrate the MLflow AI Gateway with applications, frameworks, and production systems.
Application Integrations
FastAPI Integration
Build REST APIs that proxy requests to the AI Gateway, adding your own business logic, authentication, and data processing:
from fastapi import FastAPI, HTTPException
from mlflow.deployments import get_deploy_client
app = FastAPI()
client = get_deploy_client("http://localhost:5000")
@app.post("/chat")
async def chat_endpoint(message: str):
    try:
        response = client.predict(
            endpoint="chat", inputs={"messages": [{"role": "user", "content": message}]}
        )
        return {"response": response["choices"][0]["message"]["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/embed")
async def embed_endpoint(text: str):
    try:
        response = client.predict(endpoint="embeddings", inputs={"input": text})
        return {"embedding": response["data"][0]["embedding"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
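With the app above running locally (for example via uvicorn; the module name and port used here are assumptions), the message and text arguments are plain query parameters, so the endpoints can be exercised like this:
import requests
# Assumes the FastAPI app above is served at http://localhost:8000
resp = requests.post("http://localhost:8000/chat", params={"message": "Hello, AI Gateway!"})
print(resp.json()["response"])
resp = requests.post("http://localhost:8000/embed", params={"text": "Some text to embed"})
print(len(resp.json()["embedding"]))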
Flask Integration
Create Flask applications that integrate AI capabilities using familiar request/response patterns:
from flask import Flask, request, jsonify
from mlflow.deployments import get_deploy_client
app = Flask(__name__)
client = get_deploy_client("http://localhost:5000")
@app.route("/chat", methods=["POST"])
def chat():
    try:
        data = request.get_json()
        response = client.predict(
            endpoint="chat",
            inputs={"messages": [{"role": "user", "content": data["message"]}]},
        )
        return jsonify({"response": response["choices"][0]["message"]["content"]})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    app.run(debug=True)
Async/Await Support
Handle multiple concurrent requests efficiently using asyncio for high-throughput applications:
import asyncio
import aiohttp
import json
async def async_query_gateway(endpoint, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"http://localhost:5000/gateway/{endpoint}/invocations",
            headers={"Content-Type": "application/json"},
            data=json.dumps(data),
        ) as response:
            return await response.json()
async def main():
    # Concurrent requests
    tasks = [
        async_query_gateway(
            "chat", {"messages": [{"role": "user", "content": f"Question {i}"}]}
        )
        for i in range(5)
    ]
    responses = await asyncio.gather(*tasks)
    for i, response in enumerate(responses):
        print(f"Response {i}: {response['choices'][0]['message']['content']}")
# Run async example
asyncio.run(main())
LangChain Integration
Setup
LangChain provides pre-built components that work directly with the AI Gateway, enabling easy integration with LangChain's ecosystem of tools and chains:
from langchain_community.llms import MLflowAIGateway
from langchain_community.embeddings import MlflowAIGatewayEmbeddings
from langchain_community.chat_models import ChatMLflowAIGateway
# Configure LangChain to use your gateway
gateway_uri = "http://localhost:5000"
Chat Models
Create LangChain chat models that route through your gateway, allowing you to switch providers without changing your application code:
# Chat model
chat = ChatMLflowAIGateway(
    gateway_uri=gateway_uri,
    route="chat",
    params={
        "temperature": 0.7,
        "top_p": 0.95,
    },
)
# Generate response
from langchain_core.messages import HumanMessage, SystemMessage
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is LangChain?"),
]
response = chat.invoke(messages)
print(response.content)
Embeddings
Use gateway-powered embeddings for vector search, semantic similarity, and RAG applications:
# Embeddings
embeddings = MlflowAIGatewayEmbeddings(gateway_uri=gateway_uri, route="embeddings")
# Generate embeddings
text_embeddings = embeddings.embed_documents(
["This is a document", "This is another document"]
)
query_embedding = embeddings.embed_query("This is a query")
Complete RAG Example
Build a complete Retrieval-Augmented Generation (RAG) system using the gateway for both embeddings and chat completion:
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
# Load documents
loader = TextLoader("path/to/document.txt")
documents = loader.load()
# Split documents
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
# Create vector store
vectorstore = FAISS.from_documents(docs, embeddings)
# Create QA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=chat, chain_type="stuff", retriever=vectorstore.as_retriever()
)
# Query the system
question = "What is the main topic of the document?"
result = qa_chain.invoke({"query": question})
print(result["result"])
OpenAI Compatibility
The AI Gateway provides OpenAI-compatible endpoints, allowing you to migrate existing OpenAI applications with minimal code changes:
from openai import OpenAI
# Point the OpenAI client (v1+) at the gateway
client = OpenAI(
    base_url="http://localhost:5000/gateway/chat",
    api_key="not-needed",  # Gateway handles provider authentication
)
# Use the standard OpenAI client
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # Endpoint name in your gateway config
    messages=[{"role": "user", "content": "Hello, AI Gateway!"}],
)
print(response.choices[0].message.content)
MLflow Models Integration
Deploy your own custom models alongside external providers for a unified interface to both proprietary and third-party models.
Registering Models
Train and register your models using MLflow's standard workflow, then expose them through the gateway:
import mlflow
import mlflow.pyfunc
# Log and register a model
with mlflow.start_run():
    # Your model training code here; MyCustomModel is a mlflow.pyfunc.PythonModel subclass
    mlflow.pyfunc.log_model(
        name="my_model",
        python_model=MyCustomModel(),
        registered_model_name="custom-chat-model",
    )
# Deploy the model
# Then configure it in your gateway config.yaml:
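# For example, one way to serve the registered model locally before wiring it into the
# gateway (model version 1 and port 5001 are assumptions; adjust to your registry):
#   mlflow models serve -m "models:/custom-chat-model/1" --port 5001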
endpoints:
  - name: custom-model
    endpoint_type: llm/v1/chat
    model:
      provider: mlflow-model-serving
      name: custom-chat-model
      config:
        model_server_url: http://localhost:5001
Production Best Practices
Performance Optimization
- Connection Pooling: Use persistent HTTP connections for high-throughput applications
- Batch Requests: Group multiple requests when possible
- Async Operations: Use async/await for concurrent requests
- Caching: Implement response caching for repeated queries (see the sketch below)
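As a rough sketch of the connection-reuse and caching points above (the cached_predict helper and in-memory dict are illustrative, not part of MLflow):
import hashlib
import json
from mlflow.deployments import get_deploy_client
# Reuse a single client (and its underlying HTTP connection) across requests
client = get_deploy_client("http://localhost:5000")
# Simple process-local cache keyed on the request payload
_cache = {}
def cached_predict(endpoint, inputs):
    key = hashlib.sha256(
        json.dumps({"endpoint": endpoint, "inputs": inputs}, sort_keys=True).encode()
    ).hexdigest()
    if key not in _cache:
        _cache[key] = client.predict(endpoint=endpoint, inputs=inputs)
    return _cache[key]
For production use, a shared cache such as Redis with an expiry policy is usually a better fit than a process-local dict.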
Error Handling
import time
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException
def robust_query(client, endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            return client.predict(endpoint=endpoint, inputs=inputs)
        except MlflowException as e:
            if attempt < max_retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
                continue
            raise e
# Usage
client = get_deploy_client("http://localhost:5000")
response = robust_query(
client, "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)
Security
- Use HTTPS in production
- Implement authentication at the application level
- Validate inputs before sending to the gateway (see the sketch below)
- Monitor usage and implement rate limiting
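One way to apply the input-validation point above is a small guard in front of every gateway call; the limit and helper name here are illustrative, not part of MLflow:
MAX_MESSAGE_CHARS = 4000  # illustrative limit; tune for your models and cost budget
def validate_chat_message(message):
    # Reject empty or non-string input before it ever reaches the gateway
    if not isinstance(message, str) or not message.strip():
        raise ValueError("message must be a non-empty string")
    # Cap request size to limit cost and abuse
    if len(message) > MAX_MESSAGE_CHARS:
        raise ValueError(f"message exceeds {MAX_MESSAGE_CHARS} characters")
    return message.strip()
# Use before client.predict(...), e.g. inside the Flask/FastAPI handlers above
safe_message = validate_chat_message("Hello, AI Gateway!")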
Monitoring and Logging
import logging
import time
from mlflow.deployments import get_deploy_client
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def monitored_query(client, endpoint, inputs):
    start_time = time.time()
    try:
        logger.info(f"Querying endpoint: {endpoint}")
        response = client.predict(endpoint=endpoint, inputs=inputs)
        duration = time.time() - start_time
        logger.info(f"Query completed in {duration:.2f}s")
        return response
    except Exception as e:
        duration = time.time() - start_time
        logger.error(f"Query failed after {duration:.2f}s: {e}")
        raise
Load Balancing
For high-availability setups, consider running multiple gateway instances:
import random
from mlflow.deployments import get_deploy_client
# Multiple gateway instances
gateway_urls = ["http://gateway1:5000", "http://gateway2:5000", "http://gateway3:5000"]
def get_client():
    url = random.choice(gateway_urls)
    return get_deploy_client(url)
# Use with automatic failover
def resilient_query(endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            client = get_client()
            return client.predict(endpoint=endpoint, inputs=inputs)
        except Exception as e:
            if attempt < max_retries - 1:
                continue
            raise e
Health and Monitoring
# Check gateway health via HTTP
import requests
def check_gateway_health(gateway_url):
    try:
        response = requests.get(f"{gateway_url}/health", timeout=5)
        return {
            "status": response.status_code,
            "healthy": response.status_code == 200,
            "response": response.json() if response.status_code == 200 else None,
        }
    except requests.RequestException as e:
        return {"status": "error", "healthy": False, "error": str(e)}
# Example usage
health = check_gateway_health("http://localhost:5000")
print(f"Gateway Health: {health}")