AI Gateway Integration

Learn how to integrate the MLflow AI Gateway with applications, frameworks, and production systems.

Application Integrations

FastAPI Integration

Build REST APIs that proxy requests to the AI Gateway, adding your own business logic, authentication, and data processing:

from fastapi import FastAPI, HTTPException
from mlflow.deployments import get_deploy_client

app = FastAPI()
client = get_deploy_client("http://localhost:5000")


@app.post("/chat")
async def chat_endpoint(message: str):
    try:
        response = client.predict(
            endpoint="chat", inputs={"messages": [{"role": "user", "content": message}]}
        )
        return {"response": response["choices"][0]["message"]["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/embed")
async def embed_endpoint(text: str):
    try:
        response = client.predict(endpoint="embeddings", inputs={"input": text})
        return {"embedding": response["data"][0]["embedding"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
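
To run the FastAPI service locally, launch it with an ASGI server such as uvicorn. A minimal launch sketch, assuming the code above is saved as main.py (the module name is an assumption):

import uvicorn

if __name__ == "__main__":
    # "main:app" assumes the FastAPI code above lives in main.py.
    uvicorn.run("main:app", host="0.0.0.0", port=8000)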

Flask Integration

Create Flask applications that integrate AI capabilities using familiar request/response patterns:

from flask import Flask, request, jsonify
from mlflow.deployments import get_deploy_client

app = Flask(__name__)
client = get_deploy_client("http://localhost:5000")


@app.route("/chat", methods=["POST"])
def chat():
    try:
        data = request.get_json()
        response = client.predict(
            endpoint="chat",
            inputs={"messages": [{"role": "user", "content": data["message"]}]},
        )
        return jsonify({"response": response["choices"][0]["message"]["content"]})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(debug=True)

Async/Await Support

Handle multiple concurrent requests efficiently using asyncio for high-throughput applications:

import asyncio
import aiohttp
import json


async def async_query_gateway(endpoint, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"http://localhost:5000/gateway/{endpoint}/invocations",
            headers={"Content-Type": "application/json"},
            data=json.dumps(data),
        ) as response:
            return await response.json()


async def main():
    # Concurrent requests
    tasks = [
        async_query_gateway(
            "chat", {"messages": [{"role": "user", "content": f"Question {i}"}]}
        )
        for i in range(5)
    ]

    responses = await asyncio.gather(*tasks)
    for i, response in enumerate(responses):
        print(f"Response {i}: {response['choices'][0]['message']['content']}")


# Run async example
asyncio.run(main())

LangChain Integration

Setup

LangChain provides pre-built components that work directly with the AI Gateway, enabling easy integration with LangChain's ecosystem of tools and chains:

from langchain_community.llms import MLflowAIGateway
from langchain_community.embeddings import MlflowAIGatewayEmbeddings
from langchain_community.chat_models import ChatMLflowAIGateway

# Configure LangChain to use your gateway
gateway_uri = "http://localhost:5000"

Chat Models

Create LangChain chat models that route through your gateway, allowing you to switch providers without changing your application code:

# Chat model
chat = ChatMLflowAIGateway(
    gateway_uri=gateway_uri,
    route="chat",
    params={
        "temperature": 0.7,
        "top_p": 0.95,
    },
)

# Generate response
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is LangChain?"),
]

response = chat(messages)
print(response.content)

Embeddings

Use gateway-powered embeddings for vector search, semantic similarity, and RAG applications:

# Embeddings
embeddings = MlflowAIGatewayEmbeddings(gateway_uri=gateway_uri, route="embeddings")

# Generate embeddings
text_embeddings = embeddings.embed_documents(
    ["This is a document", "This is another document"]
)

query_embedding = embeddings.embed_query("This is a query")
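
As a quick illustration of the semantic-similarity use case, you can rank the documents embedded above by cosine similarity to the query embedding. This is a minimal sketch using only the standard library and the text_embeddings / query_embedding values computed above:

import math


def cosine_similarity(a, b):
    # Cosine similarity between two equal-length vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


# Score each embedded document against the query and pick the best match.
scores = [cosine_similarity(query_embedding, doc) for doc in text_embeddings]
best = max(range(len(scores)), key=lambda i: scores[i])
print(f"Most similar document: index {best}, score {scores[best]:.3f}")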

Complete RAG Example

Build a complete Retrieval-Augmented Generation (RAG) system using the gateway for both embeddings and chat completion:

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA

# Load documents
loader = TextLoader("path/to/document.txt")
documents = loader.load()

# Split documents
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create vector store
vectorstore = FAISS.from_documents(docs, embeddings)

# Create QA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=chat, chain_type="stuff", retriever=vectorstore.as_retriever()
)

# Query the system
question = "What is the main topic of the document?"
result = qa_chain.run(question)
print(result)

OpenAI Compatibility

The AI Gateway provides OpenAI-compatible endpoints, allowing you to migrate existing OpenAI applications with minimal code changes:

import openai

# Configure OpenAI client to use the gateway
openai.api_base = "http://localhost:5000/gateway/chat"
openai.api_key = "not-needed" # Gateway handles authentication

# Use standard OpenAI client
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # Endpoint name in your gateway config
    messages=[{"role": "user", "content": "Hello, AI Gateway!"}],
)

print(response.choices[0].message.content)
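
The snippet above targets the pre-1.0 openai package. With openai>=1.0 the same request goes through a client object; a sketch, reusing the gateway base URL from the legacy example (adjust it to match your gateway's OpenAI-compatible route):

from openai import OpenAI

# base_url mirrors the legacy example above and is an assumption;
# point it at your gateway's OpenAI-compatible endpoint.
client = OpenAI(base_url="http://localhost:5000/gateway/chat", api_key="not-needed")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # Endpoint name in your gateway config
    messages=[{"role": "user", "content": "Hello, AI Gateway!"}],
)
print(response.choices[0].message.content)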

MLflow Models Integration

Deploy your own custom models alongside external providers for a unified interface to both proprietary and third-party models.

Registering Models

Train and register your models using MLflow's standard workflow, then expose them through the gateway:

import mlflow
import mlflow.pyfunc

# Log and register a model
with mlflow.start_run():
    # Your model training code here
    mlflow.pyfunc.log_model(
        name="my_model",
        python_model=MyCustomModel(),
        registered_model_name="custom-chat-model",
    )

Deploy the model, then configure it in your gateway config.yaml:

endpoints:
  - name: custom-model
    endpoint_type: llm/v1/chat
    model:
      provider: mlflow-model-serving
      name: custom-chat-model
      config:
        model_server_url: http://localhost:5001
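
Once the custom-model endpoint is deployed, it can be queried like any other gateway endpoint. A minimal sketch using the deployments client from the earlier examples (the chat-style response shape assumes the llm/v1/chat endpoint type configured above):

from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://localhost:5000")

# Query the custom model through the same unified interface as external providers.
response = client.predict(
    endpoint="custom-model",
    inputs={"messages": [{"role": "user", "content": "Hello, custom model!"}]},
)
print(response["choices"][0]["message"]["content"])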

Production Best Practices

Performance Optimization

  1. Connection Pooling: Use persistent HTTP connections for high-throughput applications
  2. Batch Requests: Group multiple requests when possible
  3. Async Operations: Use async/await for concurrent requests
  4. Caching: Implement response caching for repeated queries (see the sketch after this list)
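
A minimal sketch of the caching idea from item 4: memoize gateway responses in-process, keyed by endpoint and serialized inputs (assumes identical queries can safely reuse an earlier response):

import json

from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://localhost:5000")
_response_cache = {}


def cached_predict(endpoint, inputs):
    # Cache key combines the endpoint name with a canonical JSON form of the inputs.
    key = (endpoint, json.dumps(inputs, sort_keys=True))
    if key not in _response_cache:
        _response_cache[key] = client.predict(endpoint=endpoint, inputs=inputs)
    return _response_cache[key]


# The second identical call is served from the cache without hitting the gateway.
cached_predict("chat", {"messages": [{"role": "user", "content": "Hello"}]})
cached_predict("chat", {"messages": [{"role": "user", "content": "Hello"}]})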

Error Handling

import time
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException


def robust_query(client, endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            return client.predict(endpoint=endpoint, inputs=inputs)
        except MlflowException as e:
            if attempt < max_retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
                continue
            raise e


# Usage
client = get_deploy_client("http://localhost:5000")
response = robust_query(
    client, "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)

Security

  1. Use HTTPS in production
  2. Implement authentication at the application level
  3. Validate inputs before sending to the gateway (see the sketch after this list)
  4. Monitor usage and implement rate limiting
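
A minimal sketch of item 3, validating user input before it is forwarded to the gateway (the length limit and rules below are illustrative assumptions, not gateway requirements):

MAX_MESSAGE_LENGTH = 4000


def validate_chat_message(message):
    # Reject empty or non-string input and enforce an illustrative length limit.
    if not isinstance(message, str) or not message.strip():
        raise ValueError("Message must be a non-empty string")
    if len(message) > MAX_MESSAGE_LENGTH:
        raise ValueError(f"Message exceeds {MAX_MESSAGE_LENGTH} characters")
    return message.strip()


def safe_chat(client, message):
    # Validate before the request ever reaches the gateway.
    validated = validate_chat_message(message)
    return client.predict(
        endpoint="chat",
        inputs={"messages": [{"role": "user", "content": validated}]},
    )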

Monitoring and Logging

import time
import logging
from mlflow.deployments import get_deploy_client

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def monitored_query(client, endpoint, inputs):
    start_time = time.time()
    try:
        logger.info(f"Querying endpoint: {endpoint}")
        response = client.predict(endpoint=endpoint, inputs=inputs)
        duration = time.time() - start_time
        logger.info(f"Query completed in {duration:.2f}s")
        return response
    except Exception as e:
        duration = time.time() - start_time
        logger.error(f"Query failed after {duration:.2f}s: {e}")
        raise
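
Usage mirrors the earlier examples, for instance:

# Example usage of monitored_query with the deployments client.
client = get_deploy_client("http://localhost:5000")
response = monitored_query(
    client, "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)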

Load Balancing

For high-availability setups, consider running multiple gateway instances:

import random
from mlflow.deployments import get_deploy_client

# Multiple gateway instances
gateway_urls = ["http://gateway1:5000", "http://gateway2:5000", "http://gateway3:5000"]


def get_client():
    url = random.choice(gateway_urls)
    return get_deploy_client(url)


# Use with automatic failover
def resilient_query(endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            client = get_client()
            return client.predict(endpoint=endpoint, inputs=inputs)
        except Exception as e:
            if attempt < max_retries - 1:
                continue
            raise e
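
For example, a failed request is transparently retried against another randomly chosen instance:

response = resilient_query(
    "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)
print(response["choices"][0]["message"]["content"])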

Health and Monitoring

# Check gateway health via HTTP
import requests


def check_gateway_health(gateway_url):
    try:
        response = requests.get(f"{gateway_url}/health")
        return {
            "status": response.status_code,
            "healthy": response.status_code == 200,
            "response": response.json() if response.status_code == 200 else None,
        }
    except requests.RequestException as e:
        return {"status": "error", "healthy": False, "error": str(e)}


# Example usage
health = check_gateway_health("http://localhost:5000")
print(f"Gateway Health: {health}")

Next Steps