Chapter 13: Apache Airflow for MLOps🔗
"Airflow is the de-facto standard for orchestrating ML pipelines — scheduling, dependencies, retries, and monitoring all in one."
13.1 What is Apache Airflow?🔗
Apache Airflow is an open-source workflow orchestration platform. You define workflows as DAGs (Directed Acyclic Graphs) in Python code, and Airflow schedules, executes, and monitors them.
Why Airflow for MLOps?🔗
Without Airflow:
    cron job 1: preprocess.sh
    cron job 2: train.sh (runs at 3am)
    cron job 3: evaluate.sh
    Cron jobs don't know about each other. If preprocess fails, train
    still runs on stale data. No visibility into failures.

With Airflow:
    DAG: ml_pipeline
    ├── preprocess (task)
    ├── train (task — depends on preprocess)
    └── evaluate (task — depends on train)
    Built-in retry, alerting, dependencies, UI dashboard.
13.2 Core Concepts🔗
| Term | Definition |
|---|---|
| DAG | Directed Acyclic Graph — a workflow definition with tasks and dependencies |
| Task | A single unit of work in a DAG (run a script, call an API, move data) |
| Operator | Template for a task type (PythonOperator, BashOperator, GCSOperator...) |
| Sensor | Special operator that waits for a condition (file arrives, API responds) |
| Scheduler | Process that reads DAG files and schedules task execution |
| Executor | Runs the tasks (LocalExecutor, CeleryExecutor, KubernetesExecutor) |
| XCom | Cross-communication — pass data between tasks |
13.3 DAG Anatomy🔗
ml_pipeline DAG:
start ──▶ ingest_data ──▶ validate_data ──▶ preprocess ──▶ train ──▶ evaluate ──▶ deploy
              │                                                │
          FAIL: alert                               accuracy < 0.85? FAIL: stop, alert
13.4 Complete ML Pipeline DAG🔗
# dags/ml_training_pipeline.py
from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.gcs import GCSFileTransformOperator
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
CreateCustomTrainingJobOperator
)
from airflow.utils.trigger_rule import TriggerRule
import json
# ── Default Arguments ─────────────────────────────────────────────────
# Inherited by every task in the DAG unless a task overrides them.
default_args = {
    "owner": "mlops-team",                    # shown in the Airflow UI
    "retries": 2,                             # re-run a failed task twice
    "retry_delay": timedelta(minutes=5),      # pause between retry attempts
    "email": ["mlops@company.com"],
    "email_on_failure": True,                 # alert once retries are exhausted
    "email_on_retry": False,                  # stay quiet on intermediate retries
}
# ── DAG Definition ────────────────────────────────────────────────────
# ── DAG Definition ────────────────────────────────────────────────────
with DAG(
    dag_id="churn_model_training_pipeline",
    default_args=default_args,
    description="Weekly churn model retraining pipeline",
    # NOTE: Airflow 2.4+ renames `schedule_interval` to `schedule`.
    schedule_interval="0 2 * * 1",  # Every Monday at 2am (cron syntax)
    start_date=datetime(2024, 1, 1),
    catchup=False,  # don't backfill missed runs since start_date
    tags=["mlops", "churn", "training"],
) as dag:

    # ── Task 1: Ingest data from BigQuery ─────────────────────────────
    def ingest_data():
        """Pull the last 90 days of customer features into a local CSV.

        NOTE(review): handing data between tasks via /tmp only works when
        all tasks run on the same worker (LocalExecutor) — use GCS or XCom
        on Celery/Kubernetes executors.
        """
        from google.cloud import bigquery  # pandas is used implicitly by to_dataframe()

        client = bigquery.Client()
        query = """
        SELECT * FROM `my-project.ml_data.customer_features`
        WHERE date >= DATE_SUB(CURRENT_DATE(), INTERVAL 90 DAY)
        """
        df = client.query(query).to_dataframe()
        df.to_csv("/tmp/raw_data.csv", index=False)
        print(f"Ingested {len(df)} rows")

    ingest = PythonOperator(
        task_id="ingest_data",
        python_callable=ingest_data,
    )

    # ── Task 2: Validate data with Great Expectations ─────────────────
    def validate_data():
        """Raise (failing the task and stopping the pipeline) if the GE checkpoint fails."""
        import great_expectations as gx

        context = gx.get_context()
        result = context.run_checkpoint("churn_checkpoint")
        if not result.success:
            raise ValueError("Data validation FAILED — check GE data docs")
        print("✅ Data validation passed")

    validate = PythonOperator(
        task_id="validate_data",
        python_callable=validate_data,
    )

    # ── Task 3: Preprocess ────────────────────────────────────────────
    preprocess = BashOperator(
        task_id="preprocess_data",
        bash_command="python /opt/airflow/src/preprocess.py "
                     "--input /tmp/raw_data.csv "
                     "--output /tmp/features.csv",
    )

    # ── Task 4: Train model on Vertex AI ──────────────────────────────
    train_on_vertex = CreateCustomTrainingJobOperator(
        task_id="train_model_vertex",
        project_id="my-gcp-project",
        region="us-central1",
        display_name="churn-training-{{ ds }}",  # {{ ds }} = execution date
        worker_pool_specs=[{
            "machine_spec": {"machine_type": "n1-standard-8"},
            "replica_count": 1,
            "python_package_spec": {
                "executor_image_uri": "us-docker.pkg.dev/vertex-ai/training/scikit-learn-cpu.1-0:latest",
                "package_uris": ["gs://my-bucket/trainer-0.1.tar.gz"],
                "python_module": "trainer.task",
                # NOTE(review): the Vertex job runs remotely — a /tmp path from
                # the Airflow worker is not visible there; pass a GCS URI instead.
                "args": ["--data-path", "/tmp/features.csv",
                         "--model-dir", "gs://my-bucket/models/"],
            },
        }],
    )

    # ── Task 5: Evaluate and branch ───────────────────────────────────
    def evaluate_model(**context):
        """Read offline metrics and return the task_id of the branch to follow.

        Deploys when accuracy clears the 0.85 gate, otherwise routes to
        the failure notification.
        """
        import json

        with open("/tmp/metrics.json") as f:
            metrics = json.load(f)
        accuracy = metrics["accuracy"]
        print(f"Model accuracy: {accuracy}")
        # Push to XCom so downstream tasks can report the exact number
        context["task_instance"].xcom_push(key="accuracy", value=accuracy)
        if accuracy >= 0.85:
            return "deploy_to_staging"
        return "notify_failure"

    # FIX: `provide_context=True` was an Airflow 1.x argument and was removed
    # in Airflow 2.x (these import paths are 2.x) — passing it raises an
    # "invalid arguments" error. Context kwargs are now injected automatically
    # into callables that accept **kwargs.
    evaluate = BranchPythonOperator(
        task_id="evaluate_model",
        python_callable=evaluate_model,
    )

    # ── Task 6a: Deploy to staging ────────────────────────────────────
    def deploy_staging():
        """Point the staging Deployment at the freshly built model image."""
        import subprocess

        subprocess.run([
            "kubectl", "set", "image",
            "deployment/churn-model",
            "churn-model=gcr.io/my-project/churn-model:latest",
            "-n", "staging"
        ], check=True)  # non-zero kubectl exit -> task failure

    deploy = PythonOperator(
        task_id="deploy_to_staging",
        python_callable=deploy_staging,
    )

    # ── Task 6b: Notify failure ───────────────────────────────────────
    def notify_failure(**context):
        """Post the below-threshold accuracy to Slack; the model is NOT deployed."""
        accuracy = context["task_instance"].xcom_pull(
            task_ids="evaluate_model", key="accuracy"
        )
        # Send Slack message
        import requests
        requests.post(
            "https://hooks.slack.com/services/...",
            json={"text": f"⚠️ Model accuracy {accuracy:.3f} below threshold — NOT deployed"}
        )

    notify = PythonOperator(
        task_id="notify_failure",
        python_callable=notify_failure,
    )

    # ── Task 7: Update MLflow model registry ──────────────────────────
    def update_registry(**context):
        """Promote the new model version to Staging in the MLflow registry."""
        import mlflow

        accuracy = context["task_instance"].xcom_pull(
            task_ids="evaluate_model", key="accuracy"
        )
        print(f"Promoting churn-classifier to Staging (accuracy={accuracy})")
        client = mlflow.tracking.MlflowClient()
        # TODO(review): version is hard-coded — look up the latest registered
        # version of "churn-classifier" instead of assuming version 1.
        client.transition_model_version_stage(
            name="churn-classifier",
            version=1,
            stage="Staging"
        )

    update_registry_task = PythonOperator(
        task_id="update_model_registry",
        python_callable=update_registry,
        # Downstream of the branch: run as soon as any upstream succeeds.
        trigger_rule=TriggerRule.ONE_SUCCESS,
    )

    # ── DAG Dependencies ──────────────────────────────────────────────
    ingest >> validate >> preprocess >> train_on_vertex >> evaluate
    evaluate >> [deploy, notify]  # BranchPythonOperator follows exactly one
    deploy >> update_registry_task
13.5 Airflow Sensors (Wait for Events)🔗
from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceSensor
from airflow.sensors.time_sensor import TimeSensor
# Gate the pipeline on today's batch file landing in GCS
wait_for_data = GCSObjectExistenceSensor(
    task_id="wait_for_data_file",
    bucket="my-ml-bucket",
    object="data/new_batch_{{ ds }}.csv",  # {{ ds }} = today's date (templated)
    mode="reschedule",       # release the worker slot between pokes
    poke_interval=60 * 5,    # re-check every 5 minutes
    timeout=60 * 60 * 6,     # give up (fail the task) after 6 hours
)
# Usage in DAG:
wait_for_data >> ingest >> validate >> ...
13.6 GCP-Specific Airflow Operators🔗
from airflow.providers.google.cloud.operators.bigquery import (
    BigQueryInsertJobOperator, BigQueryCheckOperator
)
# NOTE(review): confirm this class name — the Google provider's public Dataflow
# operators are DataflowCreatePythonJobOperator / DataflowCreateJavaJobOperator /
# DataflowTemplatedJobStartOperator; "DataflowCreateJobOperator" may not exist.
from airflow.providers.google.cloud.operators.dataflow import DataflowCreateJobOperator
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
    CreateAutoMLTabularTrainingJobOperator
)
# Run a BigQuery query, materializing the result into a dated features table
run_query = BigQueryInsertJobOperator(
    task_id="run_feature_query",
    # `configuration` mirrors the BigQuery Jobs API request body
    configuration={
        "query": {
            "query": "SELECT * FROM `project.dataset.table`",
            "destinationTable": {
                "projectId": "my-project",
                "datasetId": "ml_data",
                # {{ ds_nodash }} = run date as YYYYMMDD (valid in a table id)
                "tableId": "features_{{ ds_nodash }}"
            },
            "writeDisposition": "WRITE_TRUNCATE",  # overwrite on re-run
        }
    },
)
# Check data quality in BigQuery — the task fails if the boolean is false
check_data = BigQueryCheckOperator(
    task_id="check_row_count",
    sql="SELECT COUNT(*) > 1000 FROM `project.ml_data.features_{{ ds_nodash }}`",
    # FIX: backtick-quoted table names are Standard SQL, but the operator
    # defaults to use_legacy_sql=True, which would reject this query.
    use_legacy_sql=False,
)
# Train an AutoML Tabular classification model on Vertex AI
automl_train = CreateAutoMLTabularTrainingJobOperator(
    task_id="automl_train",
    project_id="my-project",
    region="us-central1",
    display_name="churn-automl-{{ ds }}",      # dated job name per run
    optimization_prediction_type="classification",
    # Dataset id produced by an upstream task, fetched via templated XCom
    dataset_id="{{ ti.xcom_pull('create_dataset', key='dataset_id') }}",
    target_column="churned",                   # label column to predict
    training_fraction_split=0.8,               # 80% train / 20% held out
    budget_milli_node_hours=1000,              # 1 node-hour training budget
)
13.7 Cloud Composer (Managed Airflow on GCP)🔗
Cloud Composer is Google's managed Apache Airflow service.
# Create a small Cloud Composer (managed Airflow) environment
gcloud composer environments create ml-airflow-env \
--location=us-central1 \
--image-version=composer-3-airflow-2.7.3 \
--environment-size=small
# Upload DAG — Composer serves DAGs from the environment's GCS dags/ folder
gcloud composer environments storage dags import \
--environment=ml-airflow-env \
--location=us-central1 \
--source=dags/ml_training_pipeline.py
# Access Airflow UI — prints the environment's Airflow web UI URL
gcloud composer environments describe ml-airflow-env \
--location=us-central1 \
--format="value(config.airflowUri)"
13.8 Airflow vs Other Orchestrators🔗
| Feature | Airflow | Kubeflow Pipelines | Prefect | Vertex AI Pipelines |
|---|---|---|---|---|
| Paradigm | DAG in Python | K8s-native pipelines | Python-first flows | Kubeflow on GCP |
| UI | Good | Good | Excellent | Good |
| Scheduling | ✅ Rich | ❌ Limited | ✅ Yes | ✅ Yes |
| K8s native | Partial | ✅ Yes | Partial | ✅ Yes |
| ML-specific | Partial | ✅ Yes | ✅ Yes | ✅ Yes |
| GCP managed | Cloud Composer | Via GKE | Cloud Run | ✅ Native |
| Community | Very large | Growing | Growing | GCP-specific |