Chapter 19: Distributed Training🔗
"When your data or model is too big for one machine, distributed training is the answer."
19.1 Why Distributed Training?🔗
| Single-machine limits | What distributed training enables |
|---|---|
| GPU memory: 24–80 GB | Training 70B-parameter LLMs |
| Training time: days to weeks | Reducing training time from days to hours |
| Dataset size: limited by RAM | Training on datasets with 100B+ samples |
19.2 Distributed Training Strategies🔗
Data Parallelism (Most Common)🔗
Each GPU holds a full replica of the model; the data is split across GPUs:
GPU 0: batch 1 ──▶ gradients ──┐
GPU 1: batch 2 ──▶ gradients ──┼──▶ Average gradients ──▶ Update model
GPU 2: batch 3 ──▶ gradients ──┤
GPU 3: batch 4 ──▶ gradients ──┘
Best for: Most deep learning tasks
Tools: PyTorch DDP, Horovod
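Under the hood, data parallelism boils down to averaging gradients across GPUs after every backward pass. Below is a minimal sketch of what DDP automates, assuming a process group has already been initialized and that model is a plain nn.Module:

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module, world_size: int) -> None:
    # Manually average gradients across all processes: this is what DDP does for you.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # sum gradients from every rank
            param.grad /= world_size                           # turn the sum into an average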
Model Parallelism🔗
When the model is too large for a single GPU, split the model itself across GPUs:
GPU 0: Layers 1-10 ──▶ GPU 1: Layers 11-20 ──▶ GPU 2: Layers 21-30
Best for: LLMs (70B+ parameters)
Tools: DeepSpeed, Megatron-LM
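In its simplest form, model parallelism just places different layers on different devices and moves activations between them. A minimal sketch of a hypothetical two-GPU split (real frameworks such as DeepSpeed and Megatron-LM also shard individual layers, optimizer state, and activations):

import torch.nn as nn

class TwoGPUModel(nn.Module):
    # Naive model parallelism: first half of the network on cuda:0, second half on cuda:1.
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU()).to("cuda:0")
        self.stage2 = nn.Linear(4096, 10).to("cuda:1")

    def forward(self, x):
        x = self.stage1(x.to("cuda:0"))
        return self.stage2(x.to("cuda:1"))  # activations are copied between GPUs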
Pipeline Parallelism🔗
Different pipeline stages on different GPUs:
GPU 0: Layer 1 ──▶ GPU 1: Layer 2 ──▶ GPU 2: Layer 3
Batch A ──▶ Output A
Batch B ──▶ Output B
Batch C ──▶ Output C
(Overlapping like an assembly line)
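Pipeline parallelism keeps all stages busy by splitting each batch into micro-batches. A conceptual sketch with hypothetical stage1/stage2 modules on cuda:0 and cuda:1, written sequentially for clarity (real implementations such as GPipe or DeepSpeed schedule the stages so they actually run concurrently):

import torch

def pipelined_forward(stage1, stage2, batch, num_microbatches=4):
    # Split the batch into micro-batches and push them through both stages.
    # In a real pipeline, stage1 already processes micro-batch i+1 while
    # stage2 is still working on micro-batch i.
    outputs = []
    for micro_batch in torch.chunk(batch, num_microbatches):
        hidden = stage1(micro_batch.to("cuda:0"))
        outputs.append(stage2(hidden.to("cuda:1")))
    return torch.cat(outputs)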
19.3 PyTorch Distributed Data Parallel (DDP)🔗
# train_distributed.py
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    # Rendezvous address of the rank-0 process; must be set before init_process_group
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size, epochs=10):
    setup(rank, world_size)

    # Model replica on this process's GPU (MyModel is a placeholder for your own nn.Module)
    device = torch.device(f"cuda:{rank}")
    model = MyModel().to(device)
    model = DDP(model, device_ids=[rank])

    # Sampler ensures each GPU gets a different, non-overlapping shard of the data
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCELoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle the shards every epoch
        for batch_features, batch_labels in loader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(batch_features)
            loss = criterion(outputs, batch_labels)
            loss.backward()  # gradients are averaged across GPUs automatically
            optimizer.step()

        if rank == 0:  # only log from the main process
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    # Save the model (only from rank 0; .module unwraps the DDP wrapper)
    if rank == 0:
        torch.save(model.module.state_dict(), "models/model.pt")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # number of GPUs on this machine
    mp.spawn(train, args=(world_size,), nprocs=world_size)
# Launch distributed training across multiple nodes (run on each machine):
#   --nproc_per_node=4             → 4 GPUs per machine
#   --nnodes=2                     → 2 machines in total
#   --node_rank=0                  → this machine is node 0 (use 1 on the second machine)
#   --master_addr / --master_port  → address of the rendezvous host (node 0)
torchrun \
  --nproc_per_node=4 \
  --nnodes=2 \
  --node_rank=0 \
  --master_addr=10.0.0.1 \
  --master_port=29500 \
  train_distributed.py
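When the script is launched with torchrun, each process gets its rank and world size from environment variables instead of mp.spawn, so the entry point changes roughly as follows (a sketch, reusing the same train loop as above):

# train_distributed.py entry point when using torchrun instead of mp.spawn
import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    rank = int(os.environ["RANK"])              # global rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

    dist.init_process_group("nccl")             # MASTER_ADDR/MASTER_PORT are set by torchrun
    torch.cuda.set_device(local_rank)
    # ... run the same training loop as above, using device cuda:{local_rank}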
19.4 Ray — Distributed Python for ML🔗
Ray is a Python framework for distributed computing, used for hyperparameter optimization (HPO), distributed training, and ML serving.
# pip install "ray[train]"
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()


def train_func(config):
    """Training function run on each Ray worker."""
    import torch

    model = MyModel()
    model = train.torch.prepare_model(model)  # Ray wraps the model in DDP and moves it to the right device

    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    dataset = train.get_dataset_shard("train")  # each worker receives its own data shard

    for epoch in range(config["epochs"]):
        for batch in dataset.iter_torch_batches(batch_size=64):
            ...  # forward pass, loss computation, and optimizer.zero_grad() (elided)
            loss.backward()
            optimizer.step()

        # Report metrics to Ray Train
        train.report({"loss": loss.item(), "epoch": epoch})
# Launch distributed training across 4 GPUs
trainer = TorchTrainer(
train_loop_per_worker=train_func,
train_loop_config={"lr": 0.001, "epochs": 10},
scaling_config=ScalingConfig(
num_workers=4,
use_gpu=True,
resources_per_worker={"GPU": 1, "CPU": 4},
),
datasets={"train": ray.data.read_csv("data/train.csv")},
)
result = trainer.fit()
print(f"Best checkpoint: {result.best_checkpoints}")
Ray Tune for Distributed HPO🔗
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score

from ray import train, tune
from ray.tune.schedulers import ASHAScheduler


def train_with_config(config):
    # X_train, y_train, X_val, y_val are assumed to be loaded elsewhere (e.g., at module level)
    model = GradientBoostingClassifier(
        n_estimators=config["n_estimators"],
        learning_rate=config["lr"],
        max_depth=config["max_depth"],
    )
    model.fit(X_train, y_train)
    acc = accuracy_score(y_val, model.predict(X_val))
    tune.report(accuracy=acc)  # report the trial's metric to Ray Tune
tuner = tune.Tuner(
train_with_config,
param_space={
"n_estimators": tune.randint(50, 500),
"lr": tune.loguniform(0.001, 0.3),
"max_depth": tune.randint(3, 10),
},
tune_config=tune.TuneConfig(
metric="accuracy",
mode="max",
num_samples=50,  # 50 trials in total, scheduled in parallel across available resources
scheduler=ASHAScheduler(max_t=10, grace_period=1),
),
run_config=train.RunConfig(
name="churn-hpo",
storage_path="gs://my-bucket/ray-results",
),
)
results = tuner.fit()
best = results.get_best_result("accuracy", "max")
print(f"Best accuracy: {best.metrics['accuracy']:.4f}")
print(f"Best config: {best.config}")
19.5 Distributed Training on Vertex AI🔗
from google.cloud import aiplatform
aiplatform.init(project="my-project", location="us-central1")
# Multi-GPU distributed training on Vertex AI
job = aiplatform.CustomTrainingJob(
display_name="distributed-training",
script_path="train_distributed.py",
container_uri="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-0:latest",
requirements=["torchvision"],
)
job.run(
    machine_type="a2-highgpu-4g",          # 4x NVIDIA A100 40GB GPUs per replica
    accelerator_type="NVIDIA_TESLA_A100",
    accelerator_count=4,
    replica_count=2,                       # 2 machines × 4 GPUs = 8 GPUs total
    reduction_server_replica_count=0,      # no dedicated Reduction Server replicas
)
19.6 Distributed Training Tools Summary🔗
| Tool | Best For | Protocol | Scale |
|---|---|---|---|
| PyTorch DDP | Standard DL training | NCCL | Multi-GPU, multi-node |
| Horovod | Framework-agnostic | NCCL/Gloo | Multi-GPU, multi-node |
| Ray Train | Python ML, HPO | Ray | Multi-node clusters |
| DeepSpeed | LLMs, ZeRO optimizer | NCCL | Massive scale (1000+ GPUs) |
| Megatron-LM | Transformer LLMs | NCCL | Very large scale |
| Vertex AI Training | Managed training on GCP | Auto | Managed scale |
| Spark MLlib | Traditional ML at scale | Spark | Petabyte-scale |
Next → Chapter 20: Feature Stores