
Chapter 19: Distributed Training🔗

"When your data or model is too big for one machine, distributed training is the answer."


19.1 Why Distributed Training?🔗

SINGLE MACHINE LIMITS:            DISTRIBUTED TRAINING SOLVES:
  GPU memory: 24–80GB               Training 70B parameter LLMs
  Training time: days to weeks      Reducing time from days → hours
  Dataset size: limited by RAM      Training on 100B+ sample datasets
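
A quick sanity check on those numbers: a 70B-parameter model stored in fp16 needs roughly 70B × 2 bytes ≈ 140 GB for the weights alone, before gradients, optimizer state, and activations, so it cannot fit on even a single 80 GB GPU.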

19.2 Distributed Training Strategies🔗

Data Parallelism (Most Common)🔗

Each GPU holds a full replica of the model, and the data is split across them:

  GPU 0: batch 1 ──▶ gradients ──┐
  GPU 1: batch 2 ──▶ gradients ──┼──▶ Average gradients ──▶ Update model
  GPU 2: batch 3 ──▶ gradients ──┤
  GPU 3: batch 4 ──▶ gradients ──┘

  Best for: Most deep learning tasks
  Tools: PyTorch DDP, Horovod
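
To make the diagram concrete, here is a minimal sketch of what the "average gradients" step looks like when done by hand with torch.distributed collectives (DDP performs this all-reduce for you during backward()):

import torch.distributed as dist

def average_gradients(model, world_size):
    # Sum each parameter's gradient across all processes,
    # then divide by the number of processes.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size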

Model Parallelism🔗

When one model is too large for a single GPU, split the model itself across GPUs, either by layer or within layers (tensor parallelism):

  GPU 0: Layers 1-10 ──▶ GPU 1: Layers 11-20 ──▶ GPU 2: Layers 21-30

  Best for: LLMs (70B+ parameters)
  Tools: DeepSpeed, Megatron-LM
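
A minimal sketch of the layer-split idea in plain PyTorch (a hypothetical toy model; LLM-scale training relies on DeepSpeed or Megatron-LM rather than hand-placing layers):

import torch
import torch.nn as nn

class TwoGPUModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First half of the layers lives on GPU 0, second half on GPU 1
        self.stage0 = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU()).to("cuda:0")
        self.stage1 = nn.Linear(4096, 10).to("cuda:1")

    def forward(self, x):
        x = self.stage0(x.to("cuda:0"))
        return self.stage1(x.to("cuda:1"))   # move activations between GPUs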

Pipeline Parallelism🔗

Layers are split across GPUs as in model parallelism, but the batch is broken into micro-batches that stream through the stages so every GPU stays busy:

  GPU 0: Layer 1   ──▶ GPU 1: Layer 2 ──▶ GPU 2: Layer 3
  Batch A ──▶                                         Output A
          Batch B ──▶                         Output B
                  Batch C ──▶         Output C
  (Overlapping like an assembly line)
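
A naive sketch of the micro-batching idea (sequential here for clarity; real pipeline engines such as DeepSpeed or GPipe schedule the stages so they genuinely overlap):

import torch

def pipeline_step(model, batch, num_microbatches=4):
    # Split the batch into micro-batches. In a real pipeline engine,
    # stage 0 starts micro-batch i+1 while stage 1 is still working
    # on micro-batch i, which is what fills the "assembly line".
    outputs = [model(micro) for micro in torch.chunk(batch, num_microbatches)]
    return torch.cat(outputs)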

19.3 PyTorch Distributed Data Parallel (DDP)🔗

# train_distributed.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    # mp.spawn does not set these; a launcher like torchrun would
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size, epochs=10):
    setup(rank, world_size)

    # Model on this process's GPU (MyModel / MyDataset are placeholder classes)
    device = torch.device(f"cuda:{rank}")
    model = MyModel().to(device)
    model = DDP(model, device_ids=[rank])

    # Sampler ensures each GPU gets different data
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCELoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        for batch_features, batch_labels in loader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(batch_features)
            loss = criterion(outputs, batch_labels)
            loss.backward()           # gradients averaged across GPUs automatically
            optimizer.step()

        if rank == 0:                 # only log from main process
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    # Save model (only from rank 0)
    if rank == 0:
        torch.save(model.module.state_dict(), "models/model.pt")

    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # number of GPUs
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)

# Launch distributed training across multiple nodes.
# Run on every node, changing --node_rank: 4 GPUs per machine, 2 machines,
# and 10.0.0.1 is the IP of the master (rank-0) node.
torchrun \
  --nproc_per_node=4 \
  --nnodes=2 \
  --node_rank=0 \
  --master_addr=10.0.0.1 \
  --master_port=29500 \
  train_distributed.py
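
Note that torchrun spawns one process per GPU itself, so a torchrun-launched script skips mp.spawn and reads its rank from the environment. A minimal sketch of that entry point:

import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every process
    dist.init_process_group("nccl")        # picks rank/world size up from env
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # ... then run the same training loop as above ...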

19.4 Ray — Distributed Python for ML🔗

Ray is a Python framework for distributed computing — used for HPO, distributed training, and ML serving.
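
Its core primitive is turning ordinary Python functions into tasks that run in parallel across a cluster; a minimal example:

import ray

ray.init()  # connect to (or locally start) a Ray cluster

@ray.remote
def square(x):
    return x * x

# The eight calls run in parallel on the cluster's workers
futures = [square.remote(i) for i in range(8)]
print(ray.get(futures))   # [0, 1, 4, 9, 16, 25, 36, 49]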

# pip install ray[train]

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_func(config):
    """Training function run on each worker."""
    import torch

    model = MyModel()                           # placeholder for your model
    model = train.torch.prepare_model(model)    # Ray wraps it in DDP for you

    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    dataset = train.get_dataset_shard("train")

    for epoch in range(config["epochs"]):
        for batch in dataset.iter_torch_batches(batch_size=64):
            optimizer.zero_grad()
            ...                        # forward pass and loss computation
            loss.backward()
            optimizer.step()

        # Report metrics to Ray Train
        train.report({"loss": loss.item(), "epoch": epoch})

# Launch distributed training across 4 GPUs
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"lr": 0.001, "epochs": 10},
    scaling_config=ScalingConfig(
        num_workers=4,
        use_gpu=True,
        resources_per_worker={"GPU": 1, "CPU": 4},
    ),
    datasets={"train": ray.data.read_csv("data/train.csv")},
)

result = trainer.fit()
print(f"Final metrics: {result.metrics}")   # checkpoints require train.report(..., checkpoint=...)

Ray Tune for Distributed HPO🔗

from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score

def train_with_config(config):
    # X_train / y_train / X_val / y_val are assumed to be loaded already
    model = GradientBoostingClassifier(
        n_estimators=config["n_estimators"],
        learning_rate=config["lr"],
        max_depth=config["max_depth"],
    )
    model.fit(X_train, y_train)
    acc = accuracy_score(y_val, model.predict(X_val))
    train.report({"accuracy": acc})    # Ray 2.x API; older Ray used tune.report(accuracy=acc)

tuner = tune.Tuner(
    train_with_config,
    param_space={
        "n_estimators": tune.randint(50, 500),
        "lr": tune.loguniform(0.001, 0.3),
        "max_depth": tune.randint(3, 10),
    },
    tune_config=tune.TuneConfig(
        metric="accuracy",
        mode="max",
        num_samples=50,        # 50 trials total, run in parallel as resources allow
        scheduler=ASHAScheduler(max_t=10, grace_period=1),
    ),
    run_config=train.RunConfig(
        name="churn-hpo",
        storage_path="gs://my-bucket/ray-results",
    ),
)

results = tuner.fit()
best = results.get_best_result("accuracy", "max")
print(f"Best accuracy: {best.metrics['accuracy']:.4f}")
print(f"Best config: {best.config}")

19.5 Distributed Training on Vertex AI🔗

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Multi-GPU distributed training on Vertex AI
job = aiplatform.CustomTrainingJob(
    display_name="distributed-training",
    script_path="train_distributed.py",
    container_uri="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-0:latest",
    requirements=["torchvision"],
)

job.run(
    machine_type="a2-highgpu-4g",        # 4x A100 40GB GPUs per machine
    accelerator_type="NVIDIA_TESLA_A100",
    accelerator_count=4,
    replica_count=2,                     # 2 machines × 4 GPUs = 8 GPUs total
    reduction_server_replica_count=0,    # >0 adds Reduction Servers to speed up all-reduce
)
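
The same job can also be submitted from the command line (a sketch for the single-node, 4-GPU case; multi-node jobs add a second --worker-pool-spec):

gcloud ai custom-jobs create \
  --region=us-central1 \
  --display-name=distributed-training \
  --worker-pool-spec=machine-type=a2-highgpu-4g,replica-count=1,accelerator-type=NVIDIA_TESLA_A100,accelerator-count=4,container-image-uri=us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-0:latest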

19.6 Distributed Training Tools Summary🔗

Tool                 Best For                  Protocol    Scale
PyTorch DDP          Standard DL training      NCCL        Multi-GPU, multi-node
Horovod              Framework-agnostic        NCCL/Gloo   Multi-GPU, multi-node
Ray Train            Python ML, HPO            Ray         Multi-node clusters
DeepSpeed            LLMs, ZeRO optimizer      NCCL        Massive scale (1000+ GPUs)
Megatron-LM          Transformer LLMs          NCCL        Very large scale
Vertex AI Training   Managed GCP               Auto        Managed scale
Spark MLlib          Traditional ML at scale   Spark       Petabyte-scale

Next → Chapter 20: Feature Stores