
Chapter 13: Apache Airflow for MLOps🔗

"Airflow is the de-facto standard for orchestrating ML pipelines — scheduling, dependencies, retries, and monitoring all in one."


13.1 What is Apache Airflow?🔗

Apache Airflow is an open-source workflow orchestration platform. You define workflows as DAGs (Directed Acyclic Graphs) in Python code, and Airflow schedules, executes, and monitors them.

Why Airflow for MLOps?🔗

Without Airflow:                         With Airflow:
  cron job 1: preprocess.sh               DAG: ml_pipeline
  cron job 2: train.sh (runs at 3am)        │
  cron job 3: evaluate.sh                   ├── preprocess (task)
  Cron jobs don't know about each           │       │ depends on
  other. If preprocess fails, train         ├── train (task)
  still runs on stale data.                 │       │ depends on
                                            └── evaluate (task)
  No visibility into failures.
                                          Built-in retry, alerting,
                                          dependencies, UI dashboard.
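To make the "directed acyclic graph" idea concrete, here is a sketch in plain Python (not the Airflow API) that resolves the same three tasks into a valid execution order using the standard library's graphlib. Airflow's scheduler does the equivalent resolution, plus state tracking, retries, and parallelism.

```python
from graphlib import TopologicalSorter

# Each task maps to the set of tasks it depends on,
# mirroring the ml_pipeline DAG above.
deps = {
    "preprocess": set(),
    "train": {"preprocess"},
    "evaluate": {"train"},
}

order = list(TopologicalSorter(deps).static_order())
print(order)  # preprocess always comes before train, train before evaluate
```

If the graph contained a cycle (e.g. evaluate depending on itself via preprocess), `static_order()` would raise `graphlib.CycleError`, which is exactly why Airflow insists on the "acyclic" part.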

13.2 Core Concepts🔗

| Term | Definition |
|-----------|------------|
| DAG | Directed Acyclic Graph — a workflow definition with tasks and dependencies |
| Task | A single unit of work in a DAG (run a script, call an API, move data) |
| Operator | Template for a task type (PythonOperator, BashOperator, GCSOperator...) |
| Sensor | Special operator that waits for a condition (file arrives, API responds) |
| Scheduler | Process that reads DAG files and schedules task execution |
| Executor | Runs the tasks (LocalExecutor, CeleryExecutor, KubernetesExecutor) |
| XCom | Cross-communication — pass data between tasks |
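XCom is essentially a small key-value store scoped to a DAG run: a task pushes a value, and downstream tasks pull it by task id and key. The following toy model (not Airflow's implementation, which persists values in the metadata database) shows the contract that `ti.xcom_push` / `ti.xcom_pull` provide:

```python
class ToyXCom:
    """Toy XCom store: values keyed by (task_id, key), as in Airflow."""

    def __init__(self):
        self._store = {}

    def push(self, task_id, key, value):
        self._store[(task_id, key)] = value

    def pull(self, task_id, key="return_value"):
        # "return_value" is also Airflow's default key: a PythonOperator
        # callable's return value is pushed under it automatically
        return self._store.get((task_id, key))

xcom = ToyXCom()
xcom.push("evaluate_model", "accuracy", 0.91)
print(xcom.pull("evaluate_model", "accuracy"))  # 0.91
```

XCom is meant for small metadata (metrics, ids, paths), not for datasets; pass large artifacts by reference, e.g. a GCS URI.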

13.3 DAG Anatomy🔗

ml_pipeline DAG:

  start ──▶ ingest_data ──▶ validate_data ──▶ preprocess ──▶ train ──▶ evaluate ──▶ deploy
                                │                                           │
                           FAIL: alert                                 accuracy < 0.85?
                                                                            │
                                                                      FAIL: stop,
                                                                            alert

13.4 Complete ML Pipeline DAG🔗

# dags/ml_training_pipeline.py
from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomTrainingJobOperator
)
from airflow.utils.trigger_rule import TriggerRule

# ── Default Arguments ─────────────────────────────────────────────────
default_args = {
    "owner": "mlops-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "email": ["mlops@company.com"],
    "email_on_failure": True,
    "email_on_retry": False,
}

# ── DAG Definition ────────────────────────────────────────────────────
with DAG(
    dag_id="churn_model_training_pipeline",
    default_args=default_args,
    description="Weekly churn model retraining pipeline",
    schedule="0 2 * * 1",  # cron: every Monday at 2am (Airflow 2.4+ name for schedule_interval)
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=["mlops", "churn", "training"],
) as dag:

    # ── Task 1: Ingest data from BigQuery ─────────────────────────────
    def ingest_data():
        from google.cloud import bigquery
        import pandas as pd

        client = bigquery.Client()
        query = """
            SELECT * FROM `my-project.ml_data.customer_features`
            WHERE date >= DATE_SUB(CURRENT_DATE(), INTERVAL 90 DAY)
        """
        df = client.query(query).to_dataframe()
        # NOTE: /tmp only works when every task runs on the same worker
        # (LocalExecutor); with Celery/Kubernetes executors, stage
        # intermediate files in GCS instead
        df.to_csv("/tmp/raw_data.csv", index=False)
        print(f"Ingested {len(df)} rows")

    ingest = PythonOperator(
        task_id="ingest_data",
        python_callable=ingest_data,
    )

    # ── Task 2: Validate data with Great Expectations ─────────────────
    def validate_data():
        import great_expectations as gx
        context = gx.get_context()
        result = context.run_checkpoint("churn_checkpoint")
        if not result.success:
            raise ValueError("Data validation FAILED — check GE data docs")
        print("✅ Data validation passed")

    validate = PythonOperator(
        task_id="validate_data",
        python_callable=validate_data,
    )

    # ── Task 3: Preprocess ────────────────────────────────────────────
    preprocess = BashOperator(
        task_id="preprocess_data",
        bash_command="python /opt/airflow/src/preprocess.py "
                     "--input /tmp/raw_data.csv "
                     "--output /tmp/features.csv "
                     "&& gsutil cp /tmp/features.csv gs://my-bucket/features/",
    )

    # ── Task 4: Train model on Vertex AI ─────────────────────────────
    train_on_vertex = CreateCustomTrainingJobOperator(
        task_id="train_model_vertex",
        project_id="my-gcp-project",
        region="us-central1",
        display_name="churn-training-{{ ds }}",   # {{ ds }} = execution date
        worker_pool_specs=[{
            "machine_spec": {"machine_type": "n1-standard-8"},
            "replica_count": 1,
            "python_package_spec": {
                "executor_image_uri": "us-docker.pkg.dev/vertex-ai/training/scikit-learn-cpu.1-0:latest",
                "package_uris": ["gs://my-bucket/trainer-0.1.tar.gz"],
                "python_module": "trainer.task",
                "args": ["--data-path", "/tmp/features.csv",
                         "--model-dir", "gs://my-bucket/models/"],
            },
        }],
    )

    # ── Task 5: Evaluate and branch ───────────────────────────────────
    def evaluate_model(**context):
        import json
        # Assumes the trainer wrote metrics to GCS and they were copied
        # to /tmp/metrics.json (e.g. with gsutil) before this task runs
        with open("/tmp/metrics.json") as f:
            metrics = json.load(f)
        accuracy = metrics["accuracy"]
        print(f"Model accuracy: {accuracy}")

        # Push to XCom for downstream tasks
        context["task_instance"].xcom_push(key="accuracy", value=accuracy)

        if accuracy >= 0.85:
            return "deploy_to_staging"
        else:
            return "notify_failure"

    evaluate = BranchPythonOperator(
        task_id="evaluate_model",
        python_callable=evaluate_model,  # Airflow 2 injects context automatically
    )

    # ── Task 6a: Deploy to staging ────────────────────────────────────
    def deploy_staging():
        import subprocess
        subprocess.run([
            "kubectl", "set", "image",
            "deployment/churn-model",
            "churn-model=gcr.io/my-project/churn-model:latest",
            "-n", "staging"
        ], check=True)

    deploy = PythonOperator(
        task_id="deploy_to_staging",
        python_callable=deploy_staging,
    )

    # ── Task 6b: Notify failure ───────────────────────────────────────
    def notify_failure(**context):
        accuracy = context["task_instance"].xcom_pull(
            task_ids="evaluate_model", key="accuracy"
        )
        # Send Slack message
        import requests
        requests.post(
            "https://hooks.slack.com/services/...",
            json={"text": f"⚠️ Model accuracy {accuracy:.3f} below threshold — NOT deployed"}
        )

    notify = PythonOperator(
        task_id="notify_failure",
        python_callable=notify_failure,
    )

    # ── Task 7: Update MLflow model registry ─────────────────────────
    def update_registry(**context):
        import mlflow
        accuracy = context["task_instance"].xcom_pull(
            task_ids="evaluate_model", key="accuracy"
        )
        client = mlflow.tracking.MlflowClient()
        # Version hardcoded for illustration; in practice look up the
        # newest version of the registered model
        client.transition_model_version_stage(
            name="churn-classifier",
            version=1,
            stage="Staging",
        )
        print(f"Promoted churn-classifier v1 to Staging (accuracy={accuracy})")

    update_registry_task = PythonOperator(
        task_id="update_model_registry",
        python_callable=update_registry,
        trigger_rule=TriggerRule.ONE_SUCCESS,
    )

    # ── DAG Dependencies ──────────────────────────────────────────────
    ingest >> validate >> preprocess >> train_on_vertex >> evaluate
    evaluate >> [deploy, notify]
    deploy >> update_registry_task

13.5 Airflow Sensors (Wait for Events)🔗

from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceSensor

# Wait for new data file in GCS before starting pipeline
wait_for_data = GCSObjectExistenceSensor(
    task_id="wait_for_data_file",
    bucket="my-ml-bucket",
    object="data/new_batch_{{ ds }}.csv",   # {{ ds }} = today's date
    timeout=60 * 60 * 6,                    # timeout after 6 hours
    poke_interval=60 * 5,                   # check every 5 minutes
    mode="reschedule",                       # don't block a worker slot
)

# Usage in DAG:
wait_for_data >> ingest >> validate >> ...
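Under the hood, a sensor simply calls a poke() method every poke_interval seconds until it returns True or the timeout is exceeded. This toy loop (plain Python, not Airflow's BaseSensorOperator) shows what the two parameters control:

```python
def run_sensor(poke, timeout, poke_interval):
    """Toy poke loop over simulated elapsed time.

    In Airflow's 'poke' mode the worker slot stays occupied and sleeps
    between checks; 'reschedule' mode frees the slot until the next check.
    """
    elapsed = 0
    while elapsed <= timeout:
        if poke():
            return True  # condition met; downstream tasks may run
        elapsed += poke_interval
    raise TimeoutError("sensor timed out waiting for condition")

# Simulate a file that appears on the third check
checks = iter([False, False, True])
print(run_sensor(lambda: next(checks), timeout=3600, poke_interval=300))  # True
```

With `timeout=60 * 60 * 6` and `poke_interval=60 * 5` as above, the GCS sensor checks up to roughly 72 times before failing the task.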

13.6 GCP-Specific Airflow Operators🔗

from airflow.providers.google.cloud.operators.bigquery import (
    BigQueryInsertJobOperator, BigQueryCheckOperator
)
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
    CreateAutoMLTabularTrainingJobOperator
)

# Run BigQuery query
run_query = BigQueryInsertJobOperator(
    task_id="run_feature_query",
    configuration={
        "query": {
            "query": "SELECT * FROM `project.dataset.table`",
            "destinationTable": {
                "projectId": "my-project",
                "datasetId": "ml_data",
                "tableId": "features_{{ ds_nodash }}"
            },
            "writeDisposition": "WRITE_TRUNCATE",
        }
    },
)

# Check data quality in BigQuery
# Check data quality in BigQuery
check_data = BigQueryCheckOperator(
    task_id="check_row_count",
    sql="SELECT COUNT(*) > 1000 FROM `project.ml_data.features_{{ ds_nodash }}`",
    use_legacy_sql=False,  # backticked table names require standard SQL
)

# Train AutoML model
automl_train = CreateAutoMLTabularTrainingJobOperator(
    task_id="automl_train",
    project_id="my-project",
    region="us-central1",
    display_name="churn-automl-{{ ds }}",
    optimization_prediction_type="classification",
    dataset_id="{{ ti.xcom_pull('create_dataset', key='dataset_id') }}",
    target_column="churned",
    training_fraction_split=0.8,
    budget_milli_node_hours=1000,
)

13.7 Cloud Composer (Managed Airflow on GCP)🔗

Cloud Composer is Google's managed Apache Airflow service.

# Create Cloud Composer environment
gcloud composer environments create ml-airflow-env \
  --location=us-central1 \
  --image-version=composer-3-airflow-2.7.3 \
  --environment-size=small

# Upload DAG
gcloud composer environments storage dags import \
  --environment=ml-airflow-env \
  --location=us-central1 \
  --source=dags/ml_training_pipeline.py

# Access Airflow UI
gcloud composer environments describe ml-airflow-env \
  --location=us-central1 \
  --format="value(config.airflowUri)"
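Once the DAG file is uploaded, you can drive the Airflow CLI through gcloud without opening the UI. A sketch, assuming Airflow 2 CLI syntax (arguments after `--` are passed through to the Airflow CLI):

```shell
# List the DAGs registered in the environment
gcloud composer environments run ml-airflow-env \
  --location=us-central1 dags list

# Trigger a manual run of the training pipeline
gcloud composer environments run ml-airflow-env \
  --location=us-central1 dags trigger -- churn_model_training_pipeline
```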

13.8 Airflow vs Other Orchestrators🔗

| Feature | Airflow | Kubeflow Pipelines | Prefect | Vertex AI Pipelines |
|---------|---------|--------------------|---------|---------------------|
| Paradigm | DAG in Python | K8s-native pipelines | Python-first flows | Kubeflow on GCP |
| UI | Good | Good | Excellent | Good |
| Scheduling | ✅ Rich | ❌ Limited | ✅ Yes | ✅ Yes |
| K8s native | Partial | ✅ Yes | Partial | ✅ Yes |
| ML-specific | Partial | ✅ Yes | ✅ Yes | ✅ Yes |
| GCP managed | Cloud Composer | Via GKE | Cloud Run | ✅ Native |
| Community | Very large | Growing | Growing | GCP-specific |

Next → Chapter 14: Kubeflow Pipelines