Training Systems Architecture

This page gives you the mental model you need to explain a distributed training system before you write a line of PyTorch.

The Simplest Architecture That Still Looks Senior

```mermaid
flowchart LR
  A[Job Spec] --> B[Orchestrator]
  B --> C[Rendezvous + World Setup]
  B --> D[Trainer Rank 0]
  B --> E[Trainer Rank 1..N]
  F[Dataset / Feature Store] --> G[Shard + Sampler Layer]
  G --> D
  G --> E
  D --> H[Checkpoint Store]
  E --> H
  D --> I[Metrics / Logs / Traces]
  E --> I
```
Most interview answers improve immediately once this separation is visible.

The interviewer does not need your exact internal platform. They need evidence that you can separate concerns:

  • orchestration decides when work starts and where it runs
  • rendezvous decides which ranks belong to the job
  • trainer runtime performs forward, backward, and optimizer steps
  • data plane ensures each rank gets the correct sample stream
  • artifact plane stores checkpoints and model outputs
  • observability plane tells operators whether the run is healthy

In PyTorch distributed training, these terms should be automatic:

| Term | Meaning | Why interviewers care |
| --- | --- | --- |
| `world_size` | Total process count in the job | Determines collective behavior and global batch semantics |
| `rank` | Unique process index across the job | Used for role assignment and output suppression |
| `local_rank` | Device index local to a node | Needed for correct device pinning |
| process group | Communicator over a set of ranks | Becomes central when discussing hybrid parallelism |
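
Under torchrun, these values arrive as environment variables before your script runs; a minimal sketch of reading them (how they land on a config object is up to you):

```python
import os

# torchrun exports these for every process it launches.
rank = int(os.environ["RANK"])              # unique across the whole job
local_rank = int(os.environ["LOCAL_RANK"])  # device index on this node
world_size = int(os.environ["WORLD_SIZE"])  # total process count
```

With those values wired into a config object, a baseline DDP training loop looks like this: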
```python
import torch
import torch.distributed as dist


def train_job(cfg: TrainConfig) -> None:
    # Join the process group before touching any CUDA state.
    dist.init_process_group(
        backend=cfg.backend,
        init_method=cfg.init_method,
        world_size=cfg.world_size,
        rank=cfg.rank,
    )
    torch.cuda.set_device(cfg.local_rank)

    model = build_model(cfg).to(cfg.device)
    optimizer = build_optimizer(model, cfg)

    # DDP does not shard inputs; the sampler layer owns data partitioning.
    sampler = build_sampler(cfg.dataset, cfg.rank, cfg.world_size, cfg.seed)
    loader = build_loader(cfg.dataset, sampler, cfg)

    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[cfg.local_rank],
        output_device=cfg.local_rank,
        gradient_as_bucket_view=True,
    )

    # Restore model, optimizer, and sampler state so a restart resumes cleanly.
    state = maybe_restore_checkpoint(model, optimizer, sampler, cfg)
    for epoch in range(state.epoch, cfg.max_epochs):
        sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
        for batch in loader:
            loss = train_step(model, optimizer, batch, cfg)
            emit_metrics(loss=loss.item(), rank=cfg.rank)
        save_checkpoint(model, optimizer, sampler, epoch, cfg)
```

This is intentionally boring. Boring is good in interviews if you can explain why:

  • DDP is the most legible baseline
  • sampler state is part of correctness, not just convenience
  • rank-aware metrics prevent log duplication
  • checkpoint save frequency is an RPO tradeoff, not a cosmetic choice

The current official DDP docs make two points worth quoting into your mental model:

  • DDP synchronizes gradients across model replicas, but it does not shard inputs for you. PyTorch explicitly puts input partitioning on the user side, usually through DistributedSampler (sketched after this list).
  • gradient_as_bucket_view=True can reduce peak memory and avoid gradient-to-bucket copies, but the gradients become views and some in-place assumptions break.
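
A minimal sketch of that user-side partitioning, using the real DistributedSampler API (the placeholder dataset, batch size, and epoch count are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder dataset; any map-style Dataset works. DistributedSampler's
# defaults read rank and world_size from the initialized process group.
dataset = TensorDataset(torch.randn(1024, 16))

# Each rank iterates a disjoint 1/world_size slice of the dataset.
# DDP itself never does this; it only synchronizes gradients.
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # vary the shuffle across epochs
    for batch in loader:
        ...  # forward / backward / step as in the skeleton above
```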

That matters in interviews because it separates:

  • model replication semantics
  • communication semantics
  • data partitioning semantics

1. Global batch semantics must be explicit

If each rank processes micro_batch_size=8 and world_size=8, your effective global batch is at least 64, and larger if you add gradient accumulation. State it clearly.
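
The arithmetic is one line, and saying it out loud is the point; grad_accum_steps here is a hypothetical config knob:

```python
micro_batch_size = 8   # per-rank batch
world_size = 8         # total ranks
grad_accum_steps = 1   # hypothetical knob; >1 multiplies the global batch

global_batch = micro_batch_size * world_size * grad_accum_steps  # 64
```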

2. Storage is part of the training algorithm

Checkpoint throughput, atomicity, and restore latency affect whether your pipeline is operationally viable. Treat storage as a first-class design dependency.
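
One common pattern behind the atomicity claim is write-then-rename; a minimal sketch, assuming a local or POSIX-style filesystem (the helper name is illustrative):

```python
import os

import torch


def save_checkpoint_atomic(state: dict, path: str) -> None:
    # Write to a temp file first so a crash mid-save never leaves a
    # truncated checkpoint where the restore path expects a valid one.
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)  # atomic rename on POSIX filesystems
```

Object stores do not give you that rename for free, which is exactly why storage belongs in the design conversation.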

3. Logging is a distributed systems problem

If all ranks write identical logs, you have noise. If only rank 0 writes, you may miss local failures. The practical answer, sketched after this list, is:

  • structured logs on every rank
  • dense console output only on rank 0
  • counters and histograms aggregated centrally
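
A minimal sketch of that split with the standard logging module (the file layout and format are illustrative):

```python
import logging
import os


def setup_logging(rank: int, log_dir: str = "logs") -> logging.Logger:
    # Every rank keeps its own log file so local failures stay visible,
    # but only rank 0 echoes to the console.
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger("trainer")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(f"%(asctime)s rank={rank} %(levelname)s %(message)s")

    file_handler = logging.FileHandler(os.path.join(log_dir, f"rank{rank}.log"))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    if rank == 0:
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)
    return logger
```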
```mermaid
flowchart TD
  A[Notebook Cell] --> B[Config dataclass]
  A --> C[Dataset stub]
  A --> D[Trainer wrapper]
  D --> E[Mock metrics sink]
  D --> F[Checkpoint adapter]
  D --> G[Launch abstraction]
```
Even in Colab, shape the code like a system: adapters at the edges, trainer logic in the middle.

A strong notebook does not fully implement Kubernetes launch, object-store auth, or Prometheus exporters. It does preserve the boundaries where those concerns would attach.
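
One lightweight way to keep those boundaries visible is a pair of Protocol seams; the names here (MetricsSink, CheckpointStore) are illustrative, not a real API:

```python
from typing import Any, Optional, Protocol


class MetricsSink(Protocol):
    def emit(self, name: str, value: float, step: int) -> None: ...


class CheckpointStore(Protocol):
    def save(self, state: dict[str, Any], step: int) -> None: ...
    def load_latest(self) -> Optional[dict[str, Any]]: ...


class PrintMetricsSink:
    # Notebook stand-in; a Prometheus exporter or similar client would
    # attach at this same seam without touching trainer logic.
    def emit(self, name: str, value: float, step: int) -> None:
        print(f"step={step} {name}={value:.4f}")
```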

| Decision | Why you choose it first | What cost it creates later |
| --- | --- | --- |
| DDP over FSDP | Lowest cognitive load, reliable default for live coding | Less memory efficiency at larger model sizes |
| Full checkpoints | Simplest restore semantics | Large I/O footprint and slower save times |
| Map-style dataset | Easy deterministic sharding | Harder to model streaming datasets |
| Rank-0 orchestration decisions | Clear control path | Can become a scaling bottleneck for advanced coordination |