Chapter 25: Model Serving Strategies

"How you serve a model is just as important as how well you trained it."


25.1 Serving Patterns Overview

┌──────────────────────────────────────────────────────────────────────┐
│                    MODEL SERVING PATTERNS                            │
│                                                                      │
│  ONLINE (Real-Time)     BATCH               STREAMING                │
│  ┌──────────────────┐   ┌──────────────┐    ┌──────────────────┐     │
│  │  REST API        │   │  Score file  │    │  Kafka consumer  │     │
│  │  gRPC            │   │  in bulk     │    │  Flink job       │     │
│  │  WebSocket       │   │  (overnight) │    │  Pub/Sub         │     │
│  │                  │   │              │    │                  │     │
│  │  Latency: <100ms │   │  Latency:    │    │  Latency: ~1s    │     │
│  │  Throughput: med │   │  Hours       │    │  Throughput: high│     │
│  │  Use: fraud det. │   │  Use: recs   │    │  Use: IoT, feeds │     │
│  └──────────────────┘   └──────────────┘    └──────────────────┘     │
└──────────────────────────────────────────────────────────────────────┘

25.2 Online Serving (REST API)

The most common serving pattern — expose the model as an HTTP endpoint.

# src/serve.py — FastAPI model server
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, validator
import pickle
import numpy as np
import time
import logging
from prometheus_client import Counter, Histogram, generate_latest
from starlette.responses import Response

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Churn Prediction API", version="2.0.0")

# ── Load Model ───────────────────────────────────────────────────
# NOTE: unpickle only model files you trust; pickle can execute arbitrary code
with open("models/model.pkl", "rb") as f:
    model = pickle.load(f)

MODEL_VERSION = "v2.1.0"

# ── Prometheus Metrics ───────────────────────────────────────────
REQUEST_COUNT = Counter("predictions_total", "Total predictions", ["version", "result"])
LATENCY = Histogram("prediction_latency_seconds", "Latency histogram")
CONFIDENCE = Histogram("prediction_confidence", "Confidence scores")

# ── Request/Response Models ─────────────────────────────────────
class PredictionRequest(BaseModel):
    customer_id: str
    age: int
    income: float
    tenure_months: int
    monthly_charges: float
    plan: str

    @validator("age")  # pydantic v1 style; on pydantic v2, use @field_validator
    def valid_age(cls, v):
        if not 18 <= v <= 120:
            raise ValueError("Age must be between 18 and 120")
        return v

class PredictionResponse(BaseModel):
    customer_id: str
    prediction: int           # 0 = no churn, 1 = churn
    will_churn: bool
    confidence: float
    model_version: str
    latency_ms: float

# ── Endpoints ─────────────────────────────────────────────────
@app.get("/health")
def health():
    return {"status": "healthy", "model_version": MODEL_VERSION}

@app.get("/metrics")
def metrics():
    return Response(generate_latest(), media_type="text/plain")

@app.post("/predict", response_model=PredictionResponse)
def predict(req: PredictionRequest):
    start = time.time()

    plan_map = {"basic": 0, "standard": 1, "premium": 2}
    if req.plan not in plan_map:
        raise HTTPException(422, f"Unknown plan: {req.plan}")

    features = np.array([[
        req.age, req.income, req.tenure_months,
        req.monthly_charges, plan_map[req.plan]
    ]])

    prediction = int(model.predict(features)[0])
    confidence = float(model.predict_proba(features).max())
    latency_ms = (time.time() - start) * 1000

    REQUEST_COUNT.labels(version=MODEL_VERSION, result=str(prediction)).inc()
    LATENCY.observe(latency_ms / 1000)
    CONFIDENCE.observe(confidence)

    return PredictionResponse(
        customer_id=req.customer_id,
        prediction=prediction,
        will_churn=bool(prediction),
        confidence=confidence,
        model_version=MODEL_VERSION,
        latency_ms=round(latency_ms, 2),
    )

@app.post("/predict/batch")
def batch_predict(requests: list[PredictionRequest]):
    # Scores sequentially; fine for small batches, vectorize for large ones
    return [predict(req) for req in requests]
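Once the server is running (e.g. via uvicorn), clients call it over plain HTTP. A minimal client sketch using only the standard library; the localhost URL and the field values are illustrative:

```python
# client.py — minimal sketch of calling the /predict endpoint
import json
import urllib.request

def predict_churn(payload: dict, url: str = "http://localhost:8000/predict") -> dict:
    """POST one customer record and return the parsed prediction."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=5) as resp:
        return json.loads(resp.read())

example = {
    "customer_id": "C-1001",
    "age": 42,
    "income": 55000.0,
    "tenure_months": 18,
    "monthly_charges": 79.9,
    "plan": "standard",
}
# predict_churn(example)  # requires the server to be running
```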

25.3 Batch Serving

Score large datasets offline — typically overnight or on a schedule.

# batch_predict.py — score a CSV of customers
import argparse
import pickle

import pandas as pd  # gs:// paths require gcsfs to be installed

def batch_score(input_path: str, output_path: str, model_path: str):
    # Load model
    with open(model_path, "rb") as f:
        model = pickle.load(f)

    # Load data
    df = pd.read_csv(input_path)
    print(f"Scoring {len(df)} customers...")

    # Prepare features
    feature_cols = ["age", "income", "tenure_months", "monthly_charges", "plan_encoded"]
    X = df[feature_cols]

    # Score
    df["churn_prediction"] = model.predict(X)
    df["churn_probability"] = model.predict_proba(X)[:, 1]
    df["scored_at"] = pd.Timestamp.now()

    # Save results
    df.to_csv(output_path, index=False)
    print(f"Results saved to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Batch-score a CSV of customers")
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--model", default="models/model.pkl")
    args = parser.parse_args()
    batch_score(args.input, args.output, args.model)

# Run:
# python batch_predict.py \
#   --input gs://bucket/data/customers_jan.csv \
#   --output gs://bucket/predictions/jan_predictions.csv \
#   --model models/model.pkl

# Airflow DAG task for nightly batch scoring (assumes `dag` is defined in this file)
from airflow.operators.bash import BashOperator

nightly_score = BashOperator(
    task_id="nightly_batch_score",
    bash_command="""
        python batch_predict.py \
          --input gs://bucket/data/customers_{{ ds }}.csv \
          --output gs://bucket/predictions/{{ ds }}_predictions.csv \
          --model models/model.pkl
    """,
    dag=dag,
)

25.4 Deployment Strategies

Recreate Deployment (Replace all at once)

Old pods:  [v1] [v1] [v1]
           STOP all
New pods:  [v2] [v2] [v2]
⚠️ Downtime during switch
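In Kubernetes, this "replace all at once" pattern is the `Recreate` deployment strategy. A minimal sketch of the manifest; the deployment name and image are placeholders:

```yaml
# Recreate: all old pods are terminated before any new pod starts
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-model
spec:
  replicas: 3
  strategy:
    type: Recreate
  selector:
    matchLabels:
      app: ml-model
  template:
    metadata:
      labels:
        app: ml-model
    spec:
      containers:
        - name: server
          image: ml-model:v2
```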

Rolling Update (Zero downtime — K8s default)

[v1] [v1] [v1]
[v1] [v1] [v2]   ← add v2, remove v1 gradually
[v1] [v2] [v2]
[v2] [v2] [v2]
✅ No downtime
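Kubernetes controls the rollout pace through the `RollingUpdate` strategy fields. A sketch of the relevant stanza; the values shown are illustrative:

```yaml
# RollingUpdate: add at most 1 extra pod, never drop below full capacity
spec:
  strategy:
    type: RollingUpdate
    rollingUpdate:
      maxSurge: 1          # extra pods allowed above the desired count
      maxUnavailable: 0    # old pods removed only after new ones are Ready
```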

Blue/Green Deployment (Instant switch, easy rollback)

Load Balancer ────▶ Blue (v1) [ACTIVE]
              ────▶ Green (v2) [STANDBY]

After validation:
Load Balancer ────▶ Blue (v1) [STANDBY — easy rollback]
              ────▶ Green (v2) [ACTIVE]

✅ Instant switch, instant rollback
✅ Full parallel testing
❌ 2x infrastructure cost during transition

# K8s blue-green with two deployments + Service switch
# Deploy green
kubectl apply -f k8s/green-deployment.yaml

# Test green (internal)
kubectl port-forward svc/ml-model-green 8001:80

# Switch traffic
kubectl patch service ml-model-svc -p '{"spec":{"selector":{"version":"green"}}}'

# Rollback instantly
kubectl patch service ml-model-svc -p '{"spec":{"selector":{"version":"blue"}}}'

Canary Deployment (Gradual traffic shift)

100% ─────────────────────────────────▶ v1
                                          ↑ reduce gradually

10% ─────────────────────────────────▶ v2 (canary)
                                          ↑ increase gradually

If no issues after 24h → 100% to v2

# Vertex AI canary deployment
endpoint.update(traffic_split={
    model_v1_deployed_id: 90,   # 90% to stable
    model_v2_deployed_id: 10,   # 10% to canary
})

# Monitor canary for 24h, then promote
# If metrics look good:
endpoint.update(traffic_split={
    model_v1_deployed_id: 0,
    model_v2_deployed_id: 100,
})
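The "metrics look good" check can be automated as a simple gate before promotion. A sketch with made-up thresholds; the metric names and limits are assumptions, not part of any Vertex AI API:

```python
# Canary promotion gate sketch: promote only if the canary is no worse
def should_promote(stable: dict, canary: dict,
                   max_error_rate_delta: float = 0.005,
                   max_p99_ratio: float = 1.10) -> bool:
    """Return True if canary error rate and p99 latency are acceptable."""
    error_ok = canary["error_rate"] <= stable["error_rate"] + max_error_rate_delta
    latency_ok = canary["p99_latency_ms"] <= stable["p99_latency_ms"] * max_p99_ratio
    return error_ok and latency_ok

stable = {"error_rate": 0.010, "p99_latency_ms": 80.0}
good_canary = {"error_rate": 0.011, "p99_latency_ms": 85.0}
bad_canary = {"error_rate": 0.030, "p99_latency_ms": 85.0}
```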

Shadow Mode (Test without risk)

User Request ─────────▶ v1 (RESPONDS to user) ──▶ user gets response
              ─────────▶ v2 (runs in SHADOW, response discarded)
                              │
                              ▼ log predictions for analysis

✅ Zero risk — v2 never affects users
✅ Real production traffic
✅ Compare v1 vs v2 predictions offline
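The pattern above can be sketched in a few lines: only v1's answer is returned, while v2's is logged for offline comparison. The model functions here are stand-ins; in production, run the shadow call off the request path (thread pool or queue) so it adds no user-facing latency:

```python
# Shadow mode sketch: v2 runs on real traffic but never answers the user
import logging

logger = logging.getLogger("shadow")

def predict_with_shadow(features, model_v1, model_v2):
    """Return v1's prediction; log v2's prediction for offline comparison."""
    primary = model_v1(features)
    try:
        shadow = model_v2(features)
        logger.info("shadow_compare primary=%s shadow=%s match=%s",
                    primary, shadow, primary == shadow)
    except Exception:
        # Shadow failures must never affect the user-facing response
        logger.exception("shadow model failed")
    return primary
```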

25.5 SLA & Latency Targets

Common serving SLA targets by use case:

Use Case                    P99 Latency    Availability
─────────────────────────────────────────────────────────
Fraud detection             < 50ms          99.99%
Real-time recommendations   < 100ms         99.9%
Risk scoring                < 200ms         99.9%
Customer churn (online)     < 500ms         99.5%
Document classification     < 1s            99%
Batch scoring               Hours           99%
Model training              Hours/days      95%
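P99 means 99% of requests complete within the target, so SLA compliance can be checked directly from recorded latencies. A quick sketch with synthetic numbers:

```python
# Compute latency percentiles from recorded request latencies (ms)
import numpy as np

def latency_percentiles(latencies_ms, percentiles=(50, 95, 99)):
    """Return {percentile: value_ms} for the given latency samples."""
    values = np.percentile(latencies_ms, percentiles)
    return dict(zip(percentiles, values))

# Synthetic example: mostly fast requests with a slow tail
samples = [20] * 980 + [200] * 20
stats = latency_percentiles(samples)
```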

25.6 Model Versioning in Production

Model Registry → Endpoint traffic split → Monitoring → Decision

v1 (stable)    → 80% traffic  → accuracy 0.89 → ✅ keep
v2 (canary)    → 20% traffic  → accuracy 0.91 → ✅ promote to 100%
v3 (shadow)    → 0% traffic   → accuracy 0.85 → ❌ send back to team

Next → Chapter 26: Serving Frameworks