All-In-One Handbook

This page compresses the rest of the site into one long-form reference. It is the page to open if you want the entire story in one place: architecture, trainer code, data sharding, checkpointing, observability, performance tuning, and hard interview questions.

Use this when you want:

  • one continuous explanation instead of many short docs
  • one full trainer example you can study line by line
  • one place to rehearse the talk track for a senior/staff interview
  • one page you can skim in the final hour before the interview

Before writing code, define the system in terms of planes:

  • control plane: launch, retry policy, placement, config, run metadata
  • training plane: forward, backward, optimizer step, gradient synchronization
  • data plane: dataset partitioning, loading, preprocessing, host-to-device movement
  • artifact plane: checkpoints, manifests, final model outputs
  • observability plane: metrics, logs, traces, alerts
flowchart LR
  A[Job Spec] --> B[Launcher / Orchestrator]
  B --> C[Rendezvous]
  C --> D[Rank 0]
  C --> E[Rank 1..N]
  F[Dataset or Feature Store] --> G[Distributed Sampler]
  G --> D
  G --> E
  D --> H[Checkpoint Store]
  E --> H
  D --> I[Metrics + Logs]
  E --> I
This is the minimum system shape that still sounds like a production-minded training pipeline.

The First Principles You Need To Say Out Loud

In a good interview answer, the first few sentences should establish invariants:

  1. Each rank must know its identity and device ownership.
  2. The input stream must be partitioned deterministically across ranks.
  3. The optimizer state must move forward in lock-step with the model state.
  4. Resume semantics must be explicit instead of hand-waved.
  5. We need enough observability to separate input stalls, compute inefficiency, and communication bottlenecks.

That framing is often more important than the exact code.
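
To make the first invariant concrete, a small fail-fast check like the following can run before anything else; the helper name and the env-var defaults are illustrative, not part of the trainer below:

import os
import torch

def assert_process_identity() -> None:
    """Fail fast if the launcher-provided identity is inconsistent (invariant 1)."""
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    assert 0 <= rank < world_size, f"rank {rank} outside a world of {world_size}"
    if torch.cuda.is_available():
        assert local_rank < torch.cuda.device_count(), (
            f"local_rank {local_rank} has no matching GPU on this host"
        )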

The default distributed baseline is:

  • one process per GPU
  • one DDP-wrapped model replica per process
  • one rank-aware sampler per process
  • one shared checkpoint contract across all processes
flowchart TD
  A[Node] --> B[local_rank 0 -> cuda:0]
  A --> C[local_rank 1 -> cuda:1]
  B --> D[rank 0]
  C --> E[rank 1]
  D --> F[full model replica]
  E --> G[full model replica]
  F --> H[gradient all-reduce]
  G --> H
This is the most legible interview baseline because it keeps process ownership and collective semantics clear.

The goal of this code is not to be a turnkey production package. The goal is to show a complete, production-shaped training skeleton that you can explain under pressure.

from __future__ import annotations

import hashlib
import json
import os
import random
import time
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler


@dataclass
class TrainConfig:
    backend: str = "nccl"
    init_method: str = "env://"
    world_size: int = 1
    rank: int = 0
    local_rank: int = 0
    seed: int = 17
    dataset_size: int = 20_000
    input_width: int = 256
    hidden_width: int = 512
    num_classes: int = 8
    micro_batch_size: int = 16
    grad_accum_steps: int = 1
    num_workers: int = 2
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    max_epochs: int = 3
    log_every: int = 20
    checkpoint_every_steps: int = 200
    checkpoint_dir: str = "/tmp/torch-control-plane/checkpoints"
    run_id: str = "demo-run"
    use_amp: bool = True

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_rank(self) -> bool:
        return self.rank == 0

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{self.local_rank}")
        return torch.device("cpu")

    @property
    def global_batch_size(self) -> int:
        return self.micro_batch_size * self.grad_accum_steps * self.world_size

    def digest(self) -> str:
        # A stable content hash; Python's built-in hash() is salted per process,
        # so it cannot be compared across runs or ranks.
        payload = json.dumps(asdict(self), sort_keys=True).encode("utf-8")
        return hashlib.sha256(payload).hexdigest()


class ToyClassificationDataset(Dataset):
    def __init__(self, size: int, width: int, num_classes: int) -> None:
        g = torch.Generator().manual_seed(1234)
        self.features = torch.randn(size, width, generator=g)
        self.labels = torch.randint(0, num_classes, (size,), generator=g)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        return {
            "inputs": self.features[index],
            "targets": self.labels[index],
        }


class ResumeAwareDistributedSampler(DistributedSampler):
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: int,
        rank: int,
        seed: int = 0,
        consumed: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            seed=seed,
            **kwargs,
        )
        self.consumed = consumed

    def state_dict(self) -> dict[str, int]:
        return {"epoch": self.epoch, "consumed": self.consumed}

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.epoch = state["epoch"]
        self.consumed = state["consumed"]

    def set_epoch(self, epoch: int) -> None:
        # Reset mid-epoch progress when moving to a new epoch; keep it when
        # set_epoch() is called again for the epoch we are resuming into.
        if epoch != self.epoch:
            self.consumed = 0
        super().set_epoch(epoch)

    def __iter__(self):
        indices = list(super().__iter__())
        start = min(self.consumed, len(indices))
        for index in indices[start:]:
            yield index
            self.consumed += 1


class TinyNet(nn.Module):
    def __init__(self, width: int, hidden: int, num_classes: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def setup_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def maybe_init_dist(cfg: TrainConfig) -> None:
    if not cfg.is_distributed:
        return
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(
        backend=cfg.backend,
        init_method=cfg.init_method,
        world_size=cfg.world_size,
        rank=cfg.rank,
    )


def cleanup_dist() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def unwrap(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def latest_checkpoint_path(checkpoint_dir: str) -> Path | None:
    base = Path(checkpoint_dir)
    if not base.exists():
        return None
    candidates = sorted(base.glob("step-*.pt"))
    return candidates[-1] if candidates else None


def ensure_checkpoint_dir(path: str) -> None:
    Path(path).mkdir(parents=True, exist_ok=True)


def log_rank_event(cfg: TrainConfig, event: str, **fields) -> None:
    payload = {
        "event": event,
        "rank": cfg.rank,
        "local_rank": cfg.local_rank,
        "run_id": cfg.run_id,
        **fields,
    }
    print(json.dumps(payload, sort_keys=True))


@contextmanager
def phase(timer_store: dict[str, list[float]], name: str):
    started = time.perf_counter()
    try:
        yield
    finally:
        timer_store.setdefault(name, []).append(time.perf_counter() - started)


def move_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    return {key: value.to(device, non_blocking=True) for key, value in batch.items()}


def build_dataloader(cfg: TrainConfig, dataset: Dataset, sampler: DistributedSampler) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=cfg.micro_batch_size,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=cfg.num_workers > 0,
        drop_last=True,
    )


def build_checkpoint(
    step: int,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler | None,
    sampler: ResumeAwareDistributedSampler,
    cfg: TrainConfig,
) -> dict:
    return {
        "step": step,
        "epoch": epoch,
        "model": unwrap(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict() if scaler else None,
        "sampler": sampler.state_dict(),
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        },
        "metadata": {
            "run_id": cfg.run_id,
            "world_size": cfg.world_size,
            "config_digest": cfg.digest(),
            "global_batch_size": cfg.global_batch_size,
        },
    }


def save_checkpoint(
    step: int,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler | None,
    sampler: ResumeAwareDistributedSampler,
    cfg: TrainConfig,
) -> None:
    if not cfg.is_main_rank:
        return
    ensure_checkpoint_dir(cfg.checkpoint_dir)
    state = build_checkpoint(step, epoch, model, optimizer, scaler, sampler, cfg)
    path = Path(cfg.checkpoint_dir) / f"step-{step:08d}.pt"
    torch.save(state, path)
    log_rank_event(cfg, "checkpoint_saved", path=str(path), step=step, epoch=epoch)


def restore_if_present(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler | None,
    sampler: ResumeAwareDistributedSampler,
    cfg: TrainConfig,
) -> SimpleNamespace:
    path = latest_checkpoint_path(cfg.checkpoint_dir)
    if path is None:
        return SimpleNamespace(step=0, epoch=0)
    # The checkpoint holds RNG and optimizer objects, not just tensors, so opt
    # out of the weights-only loader.
    state = torch.load(path, map_location="cpu", weights_only=False)
    unwrap(model).load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    if scaler and state["scaler"] is not None:
        scaler.load_state_dict(state["scaler"])
    sampler.load_state_dict(state["sampler"])
    random.setstate(state["rng"]["python"])
    np.random.set_state(state["rng"]["numpy"])
    torch.set_rng_state(state["rng"]["torch"])
    if torch.cuda.is_available() and state["rng"]["cuda"] is not None:
        torch.cuda.set_rng_state_all(state["rng"]["cuda"])
    log_rank_event(
        cfg,
        "checkpoint_restored",
        path=str(path),
        step=state["step"],
        epoch=state["epoch"],
    )
    return SimpleNamespace(step=state["step"], epoch=state["epoch"])


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: dict[str, torch.Tensor],
    scaler: GradScaler | None,
    cfg: TrainConfig,
    timer_store: dict[str, list[float]],
) -> float:
    with phase(timer_store, "h2d"):
        batch = move_batch_to_device(batch, cfg.device)
    with phase(timer_store, "forward"):
        with autocast(device_type=cfg.device.type, enabled=scaler is not None and scaler.is_enabled()):
            logits = model(batch["inputs"])
            loss = nn.functional.cross_entropy(logits, batch["targets"])
    with phase(timer_store, "backward"):
        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()
    with phase(timer_store, "optimizer"):
        if scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return float(loss.detach().cpu().item())


def summarize_timers(timer_store: dict[str, list[float]]) -> dict[str, float]:
    return {
        name: float(sum(values) / max(len(values), 1))
        for name, values in timer_store.items()
    }


def run(cfg: TrainConfig) -> None:
    setup_seed(cfg.seed)
    maybe_init_dist(cfg)
    dataset = ToyClassificationDataset(
        size=cfg.dataset_size,
        width=cfg.input_width,
        num_classes=cfg.num_classes,
    )
    sampler = ResumeAwareDistributedSampler(
        dataset,
        num_replicas=cfg.world_size,
        rank=cfg.rank,
        seed=cfg.seed,
        shuffle=True,
    )
    loader = build_dataloader(cfg, dataset, sampler)
    model = TinyNet(
        width=cfg.input_width,
        hidden=cfg.hidden_width,
        num_classes=cfg.num_classes,
    ).to(cfg.device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    scaler = GradScaler(device=cfg.device.type, enabled=cfg.use_amp and cfg.device.type == "cuda")
    if cfg.is_distributed:
        model = DDP(
            model,
            device_ids=[cfg.local_rank] if cfg.device.type == "cuda" else None,
            output_device=cfg.local_rank if cfg.device.type == "cuda" else None,
            gradient_as_bucket_view=True,
        )
    state = restore_if_present(model, optimizer, scaler, sampler, cfg)
    step = state.step
    try:
        for epoch in range(state.epoch, cfg.max_epochs):
            sampler.set_epoch(epoch)
            timer_store: dict[str, list[float]] = {}
            for batch_index, batch in enumerate(loader):
                loss = train_step(model, optimizer, batch, scaler, cfg, timer_store)
                step += 1
                if cfg.is_main_rank and step % cfg.log_every == 0:
                    metrics = summarize_timers(timer_store)
                    log_rank_event(
                        cfg,
                        "train_progress",
                        step=step,
                        epoch=epoch,
                        batch_index=batch_index,
                        loss=round(loss, 4),
                        avg_h2d_ms=round(metrics.get("h2d", 0.0) * 1000, 2),
                        avg_forward_ms=round(metrics.get("forward", 0.0) * 1000, 2),
                        avg_backward_ms=round(metrics.get("backward", 0.0) * 1000, 2),
                        avg_optimizer_ms=round(metrics.get("optimizer", 0.0) * 1000, 2),
                    )
                if step % cfg.checkpoint_every_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, scaler, sampler, cfg)
            save_checkpoint(step, epoch, model, optimizer, scaler, sampler, cfg)
    finally:
        cleanup_dist()


if __name__ == "__main__":
    cfg = TrainConfig(
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        rank=int(os.environ.get("RANK", "0")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
    )
    run(cfg)

The code is easier to defend if you explain it in layers rather than top to bottom.

TrainConfig carries:

  • process identity
  • batch semantics
  • I/O policy
  • optimizer settings
  • checkpoint settings

The strong sentence here is:

“I keep the effective batch semantics on the config because topology changes should not silently change training behavior.”
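
A quick worked example of the batch math the config makes explicit (the numbers are illustrative, using the TrainConfig defined above):

cfg_example = TrainConfig(micro_batch_size=16, grad_accum_steps=4, world_size=8)
# 16 samples per micro-batch x 4 accumulation steps x 8 ranks = 512 samples
# contribute to every optimizer step.
assert cfg_example.global_batch_size == 512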

The dataset itself is boring. That is fine. The interesting part is the sampler:

  • it is rank-aware
  • it supports deterministic reshuffling via set_epoch
  • it stores consumed progress so resume semantics are explicit

This is one of the most important parts of the entire page. Many weak distributed-training answers talk about DDP and never talk about sampler correctness.
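
One way to make the partitioning invariant concrete is a local sketch that instantiates the stock DistributedSampler with explicit ranks (no process group needed) and checks that the shards do not overlap; it reuses the toy dataset from the trainer above:

from torch.utils.data import DistributedSampler

probe = ToyClassificationDataset(size=1_000, width=8, num_classes=4)
shards = [
    set(DistributedSampler(probe, num_replicas=4, rank=r, shuffle=True, seed=17))
    for r in range(4)
]
# With a shared seed and epoch, every rank draws from the same permutation and
# takes a disjoint slice of it.
assert all(shards[i].isdisjoint(shards[j]) for i in range(4) for j in range(i + 1, 4))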

The maybe_init_dist() function exists for one reason: to separate local single-process behavior from distributed process-group behavior.

That lets you say:

  • the training loop stays mostly the same
  • the launcher changes the process context
  • DDP is a wrapper over a process group, not a magical all-in-one training platform
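
Concretely, a launcher such as torchrun hands each process its context through environment variables, and the trainer's __main__ block simply reads them (the command line and the values shown are illustrative):

# Typical single-node launch: torchrun --standalone --nproc_per_node=2 trainer.py
# Each spawned process then sees something like (rank 1 of 2 shown):
#   RANK=1  LOCAL_RANK=1  WORLD_SIZE=2  MASTER_ADDR=127.0.0.1  MASTER_PORT=...
import os

rank = int(os.environ["RANK"])              # global identity across all nodes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this host
world_size = int(os.environ["WORLD_SIZE"])  # total number of ranks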

The DDP wrap is intentionally minimal:

  • wrap after moving to the right device
  • pass one device id per process for GPU training
  • keep the baseline simple and explainable

This is the right place to mention that the current PyTorch DDP docs make clear that DDP handles gradient synchronization, not input sharding.

The code uses autocast() and GradScaler.

Important talking points:

  • autocast scopes the forward pass and loss computation
  • GradScaler protects optimizer steps under mixed precision
  • AMP improves throughput and reduces memory, but it introduces another piece of state that must be checkpointed
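
The config above carries grad_accum_steps even though the baseline step never accumulates; if asked, a hedged sketch of how the step changes under accumulation (using DDP's no_sync() to skip the all-reduce on non-final micro-steps; the helper name and the batches argument are assumptions, not part of the trainer above) looks like this:

from contextlib import nullcontext

def accumulation_step(model, optimizer, scaler, batches, cfg):
    """One optimizer step built from cfg.grad_accum_steps micro-batches (sketch)."""
    optimizer.zero_grad(set_to_none=True)
    for i, batch in enumerate(batches):
        batch = move_batch_to_device(batch, cfg.device)
        is_last = i == len(batches) - 1
        # Only the final micro-step triggers DDP's gradient all-reduce.
        sync_ctx = nullcontext() if (is_last or not cfg.is_distributed) else model.no_sync()
        with sync_ctx:
            with autocast(device_type=cfg.device.type, enabled=scaler.is_enabled()):
                logits = model(batch["inputs"])
                loss = nn.functional.cross_entropy(logits, batch["targets"])
                loss = loss / cfg.grad_accum_steps  # keep gradient magnitude comparable
            scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()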

The checkpoint contains:

  • model state
  • optimizer state
  • scaler state
  • sampler state
  • RNG state
  • metadata

That is the right modern answer for a strong interview. Saving weights alone is not a recovery plan.

This is the section where you sound senior instead of merely API-literate.

  • each rank sees a deterministic shard
  • reshuffle order changes by epoch when set_epoch() is called
  • the checkpoint captures enough information to restart without optimizer drift
  • global batch math is explicit
flowchart TD
  A[Incorrect training results] --> B[Duplicate samples across ranks]
  A --> C[Skipped samples after resume]
  A --> D[Changed batch semantics after topology shift]
  A --> E[Stale optimizer or scaler state]
The most dangerous failures are often silent: the job runs, loss still moves, and the model quality quietly degrades.

If an interviewer asks “how do you know this distributed trainer is correct?”, answer in this order:

  1. define correctness
  2. name the state that must survive
  3. explain the sample-partitioning invariant
  4. explain the resume validation checks

Do not discuss performance as a single number. Discuss it by phase.

flowchart LR
  A[Data wait] --> B[H2D copy]
  B --> C[Forward]
  C --> D[Backward]
  D --> E[Gradient sync]
  E --> F[Optimizer step]
  F --> G[Checkpoint / side work]
A useful training performance answer decomposes step time instead of saying 'GPU utilization is low.'

The timer store in the sample code supports the most important performance conversation:

  • Are we input-bound?
  • Are we communication-bound?
  • Are we compute-bound?
  • Are we stalling on checkpointing or artifact writes?
| Symptom | Likely source | First move |
| --- | --- | --- |
| GPU idle before forward | data loader, preprocessing, host-to-device copy | inspect loader workers, caching, prefetch |
| Long backward tail | collective communication | inspect DDP sync cost, topology, bucket behavior |
| Periodic step spikes | checkpoint or storage jitter | correlate spikes with save cadence |
| OOM after scaling | activation footprint or optimizer state | lower micro-batch, use checkpointing, reconsider parallelism |
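
When the phase timers and the table above point at compute or communication rather than input, torch.profiler gives kernel-level attribution; a minimal sketch over a handful of steps:

from itertools import islice
from torch.profiler import ProfilerActivity, profile, schedule

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3),
) as prof:
    for batch in islice(loader, 8):  # a few steps are enough for attribution
        train_step(model, optimizer, batch, scaler, cfg, timer_store)
        prof.step()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))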

A mature answer distinguishes:

  • checkpoint frequency
  • checkpoint format
  • checkpoint atomicity
  • checkpoint restore policy

The page’s code assumes a conservative default:

  • full job restart
  • latest good checkpoint
  • same topology preferred

This is a good default for interviews because it is correct and easy to explain.
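
Atomicity is the easiest of these to demonstrate in code: write to a temporary file in the same directory, then publish it with an atomic rename. This is a sketch of the pattern, not something the trainer above does:

import os
from pathlib import Path
import torch

def save_atomically(state: dict, final_path: Path) -> None:
    """Readers never observe a half-written checkpoint file."""
    tmp_path = final_path.with_suffix(".tmp")
    torch.save(state, tmp_path)
    os.replace(tmp_path, final_path)  # atomic for same-filesystem renames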

Bring up torch.distributed.checkpoint when:

  • model state is large enough that rank-parallel checkpointing matters
  • you want to discuss resharding across topology changes
  • the interviewer wants a current PyTorch-native answer beyond plain torch.save
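
A minimal sketch of the torch.distributed.checkpoint entry points; recent PyTorch exposes dcp.save and dcp.load, but treat the exact keywords and the in-place load semantics as version-dependent rather than a drop-in replacement for the torch.save path above:

import torch.distributed.checkpoint as dcp

state = {"model": unwrap(model).state_dict(), "optimizer": optimizer.state_dict()}
dcp.save(state, checkpoint_id="/tmp/torch-control-plane/dcp/step-00000100")
# Later, possibly on a different topology: load back into an existing state dict,
# then apply it to the live modules.
dcp.load(state, checkpoint_id="/tmp/torch-control-plane/dcp/step-00000100")
unwrap(model).load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])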

After restore, validate:

  • step monotonicity
  • learning rate continuity
  • optimizer state presence
  • scaler state presence
  • sampler progress continuity
  • loss continuity over the next few steps
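
A hedged sketch of those checks, assuming the restore helper is extended to also hand back the raw checkpoint dictionary (the trainer above returns only step and epoch):

def validate_resume(prev: dict, cfg: TrainConfig, optimizer, sampler) -> None:
    """Cheap invariant checks to run immediately after a restore."""
    meta = prev["metadata"]
    assert meta["world_size"] == cfg.world_size, "topology changed; restore needs resharding"
    assert meta["global_batch_size"] == cfg.global_batch_size, "batch semantics drifted"
    assert prev["step"] >= 0 and prev["epoch"] >= 0
    assert optimizer.state_dict()["state"], "optimizer state is empty after restore"
    assert sampler.consumed == prev["sampler"]["consumed"], "sampler progress not restored"
    # Learning-rate and loss continuity are verified over the next few steps, not here.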

You do not need a specific vendor story. You do need a telemetry story.

Logs:

  • structured logs on all ranks
  • dense human-readable progress on rank 0
  • explicit events for checkpoint save and restore

Metrics:

  • loss
  • learning rate
  • samples/sec
  • step time by phase
  • checkpoint latency
  • restart count
  • all-reduce time if you expose it

Use traces when you want to correlate:

  • launcher events
  • storage stalls
  • long-tail step latency
  • checkpoint publishing
flowchart LR
  A[Trainer ranks] --> B[Structured logs]
  A --> C[Metrics]
  A --> D[Spans]
  B --> E[Log backend]
  C --> F[Metrics backend]
  D --> G[Trace backend]
  E --> H[Operator view]
  F --> H
  G --> H
The observability plane is separate from the training logic, but it must reflect training semantics.

Start with DDP unless memory pressure proves it is insufficient.

DDP:

  • easiest to explain
  • easiest to debug
  • clearest baseline for shared-screen interviews

FSDP:

  • helps when model state is the memory bottleneck
  • introduces more checkpoint, wrapping, and state-dict complexity
  • is worth mentioning, but not usually worth live-coding first

Activation (gradient) checkpointing:

  • reduces memory pressure
  • increases compute
  • can affect determinism and runtime characteristics

The strong answer is never “use every optimization.” The strong answer is “introduce the next complexity only when the current bottleneck is clear.”
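
If the conversation does move past DDP, the minimal FSDP wrap is worth being able to sketch; FullyShardedDataParallel is the real class, the defaults here are the library's, and checkpointing an FSDP model needs the sharded state-dict APIs this page deliberately leaves out:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Shards parameters, gradients, and optimizer state across ranks instead of
# replicating them; the shape of the training loop stays the same.
model = FSDP(TinyNet(cfg.input_width, cfg.hidden_width, cfg.num_classes).to(cfg.device))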

If the exercise is in Colab, do not paste an entire giant module at once. Instead, split the same logic into notebook-friendly cells:

  1. config
  2. dataset and sampler
  3. model and optimizer
  4. distributed init helper
  5. train step
  6. checkpoint helpers
  7. main loop

This gives you narration points between cells and keeps the interviewer inside your reasoning.

“Why start with DDP instead of FSDP?”

Because DDP is the smallest correct distributed baseline. If the model fits, DDP keeps failure semantics and debugging simpler. I would move to FSDP when replicated model state becomes the bottleneck.

“What is the most dangerous silent failure?”

Silent sample duplication or omission across ranks, because the system can look healthy while learning on the wrong data distribution.

“What state must a checkpoint capture?”

Model, optimizer, scaler, sampler, RNG, step metadata, and enough configuration identity to detect incompatible restores.

“How would you debug a slowdown that happens every 20 minutes?”

I would correlate the slowdown with checkpoint cadence, storage writes, and rank-level step-time breakdowns before assuming a model-side issue.

“What would you deliberately not build during the interview?”

Cloud-specific auth, scheduler-specific orchestration, vendor-specific telemetry exporters, and advanced hybrid parallelism. I would preserve interfaces for them, but I would not burn interview time implementing them.

If you need one closing paragraph:

“I built the smallest correct distributed trainer that still preserves production boundaries: rank-aware initialization, deterministic data partitioning, synchronized optimizer progress, resumable state, and observable step behavior. I would start with DDP because it maximizes clarity and correctness in a live exercise, then add FSDP, distributed checkpointing, or more complex topology only when memory, I/O, or communication data proves the simpler baseline is insufficient.”

These are the official docs this page is aligned with: