
Petabyte-Scale Image Training With DDP

This page is the image counterpart to the molecular walkthrough: same DDP baseline, very different data plane. At petabyte scale the model is often not the first problem. File layout, cache strategy, decode placement, and transform economics decide whether the GPUs train or wait.

The scenario below is also hypothetical. Use it as a production-shaped example, not as a claim about one specific public dataset.

Assume a web-scale vision pretraining corpus with:

  • 9.5 billion images
  • 1.4 PB of compressed JPEG and WebP objects
  • captions or class labels stored as sidecar metadata
  • a ViT-sized model that still fits per GPU, so plain DDP remains the simplest correct trainer
| Field | Type | Why it exists |
|---|---|---|
| sample_id | string | Stable dedupe and replay key |
| image_uri | string | Object-store path to compressed bytes |
| caption | string | Optional contrastive text supervision |
| label | int or null | Classification target if present |
| width | int | Precomputed shape metadata for bucketing and QA |
| height | int | Precomputed shape metadata for bucketing and QA |
| mime_type | string | JPEG, PNG, or WebP handling |
| quality_flags | bitfield | Corrupt decode, NSFW filter, duplicate detection, etc. |
| shard_id | string | Training scheduling and cache accounting |
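A manifest row maps naturally onto a frozen dataclass. The quality_flags bitfield lets a job drop flagged samples without opening any image bytes. The sketch below is illustrative only: the bit assignments in QualityFlag and the filter_rows helper are hypothetical, since the page does not fix an encoding.

```python
from dataclasses import dataclass
from enum import IntFlag
from typing import Optional


class QualityFlag(IntFlag):
    # Hypothetical bit layout for the quality_flags field.
    CORRUPT_DECODE = 1
    NSFW = 2
    DUPLICATE = 4


@dataclass(frozen=True)
class ManifestRow:
    sample_id: str
    image_uri: str
    caption: Optional[str]
    label: Optional[int]
    width: int
    height: int
    mime_type: str
    quality_flags: int
    shard_id: str


DEFAULT_REJECT = QualityFlag.CORRUPT_DECODE | QualityFlag.NSFW | QualityFlag.DUPLICATE


def filter_rows(rows, reject=DEFAULT_REJECT):
    """Keep rows with none of the rejected bits set; no image bytes are read."""
    return [row for row in rows if (row.quality_flags & int(reject)) == 0]
```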

Do not train from billions of individual files if you can avoid it.

The real system wants:

  • immutable medium-sized shards
  • sequential reads
  • predictable cache behavior
  • enough metadata in the manifest to route work without opening every image

Config Management Has To Encode The Data Contract


At petabyte scale, config management is not just about learning rate and batch size. It is how you make a run explainable later.

| Config area | Example fields | Why it matters |
|---|---|---|
| topology | world_size, micro_batch_size, grad_accum_steps | changes optimizer cadence and effective global batch |
| dataset definition | manifest_uri, manifest_version, filter_policy | binds the run to one corpus view |
| featurization | decode_backend, crop_size, augmentation_policy | controls throughput and train/eval comparability |
| cache and I/O | cache_dir, prefetch_depth, max_open_shards | determines whether storage can keep up with the trainer |
| evaluation policy | eval_interval_steps, val_manifest_version, top_k_list | ties decision-making to a stable validation contract |
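```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VisionTrainConfig:
    run_name: str
    seed: int
    # topology: together these fix the effective global batch
    world_size: int
    micro_batch_size: int
    grad_accum_steps: int
    # dataset definition: binds the run to one corpus view
    manifest_uri: str
    manifest_version: str
    # featurization: controls throughput and train/eval comparability
    decode_backend: str
    crop_size: int
    augmentation_policy: str
    # cache and I/O
    cache_dir: str
    # evaluation policy
    eval_interval_steps: int
    val_manifest_version: str
```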
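The topology fields multiply together, which is why they belong in the recorded config. Below is a minimal sketch of the batch arithmetic. It assumes full optimizer steps only, matching drop_last. The helper names and the 512-GPU topology in the usage lines are illustrative and not part of the scenario.

```python
def effective_global_batch(world_size: int, micro_batch_size: int, grad_accum_steps: int) -> int:
    # Images that contribute to a single optimizer step across all ranks.
    return world_size * micro_batch_size * grad_accum_steps


def optimizer_steps_per_epoch(num_images: int, global_batch: int) -> int:
    # Full steps only; the tail that does not fill a global batch is skipped.
    return num_images // global_batch


# Hypothetical topology: 512 GPUs, 256 images per micro-batch, no accumulation.
gb = effective_global_batch(world_size=512, micro_batch_size=256, grad_accum_steps=1)
steps = optimizer_steps_per_epoch(9_500_000_000, gb)
```

With these hypothetical numbers, one pass over 9.5 billion images is 72,479 optimizer steps. Halving world_size while doubling grad_accum_steps keeps the global batch, and therefore the step count, unchanged. Silently changing only world_size would not.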

The sentence I would say directly is:

“If I cannot reconstruct the exact corpus view, transform policy, and topology from config plus checkpoint metadata, the run is not really reproducible.”

```mermaid
flowchart LR
  A[Config payload] --> B[Manifest version]
  A --> C[Transform policy]
  A --> D[Topology and batch math]
  B --> E[Run metadata]
  C --> E
  D --> E
  E --> F[Checkpoint lineage]
  F --> G[Resume or postmortem]
```
For vision training, the run is only auditable if config, corpus view, and transform policy survive into checkpoint metadata.
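One way to make that lineage checkable is to hash the fields that define the run and store the hash in every checkpoint. The sketch below uses a hypothetical RunContract, which holds a subset of the config fields above plus a code revision. A real run would hash the full config.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class RunContract:
    # Hypothetical subset of fields that define "the same run".
    code_revision: str
    manifest_version: str
    val_manifest_version: str
    augmentation_policy: str
    world_size: int
    micro_batch_size: int
    grad_accum_steps: int
    seed: int


def fingerprint(contract: RunContract) -> str:
    # sort_keys makes the hash independent of field order in the JSON payload.
    payload = json.dumps(asdict(contract), sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def checkpoint_metadata(contract: RunContract, global_step: int) -> dict:
    # Stored next to weights so a resume or postmortem can verify the contract.
    return {"contract": asdict(contract), "fingerprint": fingerprint(contract), "global_step": global_step}
```

On resume, recompute the fingerprint from the current config and compare it with the stored one before trusting any replay claim.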
```mermaid
flowchart LR
  A[Raw image objects] --> B[Offline filtering + dedupe]
  B --> C[Shard packing + manifest build]
  C --> D[Rank-aware shard assignment]
  D --> E[Local NVMe cache]
  E --> F[CPU or GPU decode]
  F --> G[Crop / resize / normalize]
  G --> H[Pinned host memory]
  H --> I[DDP trainer rank 0..N]
  I --> J[Checkpoint + metrics]
```
The cluster only looks compute-heavy from a distance. In practice, storage and decode dominate unless you engineer them deliberately.
  • shard compressed images into tar, parquet, or another sequential container
  • target roughly 512 MB to 2 GB per shard, depending on cache and network behavior
  • keep sidecar metadata with the image bytes or in a co-located index file
  • avoid millions of tiny remote object requests in steady-state training
  • include row counts and checksums so incomplete shards can be rejected early
| Problem with file-per-image | What sharding fixes |
|---|---|
| object-store metadata overhead | one request fetches many samples |
| poor throughput under high fanout | sequential or batched reads improve bandwidth utilization |
| expensive local cache indexing | cache at shard granularity |
| resume is hard to reason about | shards provide a natural replay unit |
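Here is a minimal packing sketch that uses only the standard library's tarfile. It assumes samples arrive as (name, encoded_bytes) pairs. pack_shards and the manifest fields num_samples and sha256 are illustrative names. The 1 GiB default target falls inside the 512 MB to 2 GB range above.

```python
import hashlib
import io
import tarfile
from pathlib import Path


def sha256_file(path, chunk_size=1 << 20):
    # Stream the file so multi-GB shards are never fully loaded into memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()


def pack_shards(samples, out_dir, target_bytes=1 << 30):
    """Pack (name, encoded_bytes) pairs into tar shards of roughly target_bytes each."""
    out = Path(out_dir).resolve()
    out.mkdir(parents=True, exist_ok=True)
    manifest, pending, pending_bytes = [], [], 0

    def flush():
        nonlocal pending, pending_bytes
        if not pending:
            return
        path = out / f"shard-{len(manifest):06d}.tar"
        with tarfile.open(path, mode="w") as tar:
            for name, data in pending:
                info = tarfile.TarInfo(name=name)
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))
        # Row count and checksum let readers reject incomplete shards early.
        manifest.append({"uri": path.as_uri(), "num_samples": len(pending), "sha256": sha256_file(path)})
        pending, pending_bytes = [], 0

    for name, data in samples:
        pending.append((name, data))
        pending_bytes += len(data)
        if pending_bytes >= target_bytes:
            flush()
    flush()
    return manifest
```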

Large image jobs benefit from more offline prep than many teams expect.

| Step | Offline or online | Why |
|---|---|---|
| dedupe and content filtering | offline | too expensive and too important to repeat |
| corrupt-file detection | offline | decode failures in hot path cause rank skew |
| width / height extraction | offline | useful for sampling and crop diagnostics |
| shard packing | offline | critical for object-store efficiency |
| heavy resizing to many variants | sometimes offline | worth it if the same sizes are reused repeatedly |
| final stochastic crop / flip | online | cheap, training-specific, and intentionally random |
| normalization | online | trivial tensor work |
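Exact-duplicate removal is the cheapest part of offline dedupe. The sketch below uses content hashes through a hypothetical dedupe_exact helper. It only catches byte-identical files, so re-encoded or resized near-duplicates still need perceptual hashing or embeddings. At billions of images, the seen-set would also live in a distributed store or a sort-based join, not an in-memory dict.

```python
import hashlib


def dedupe_exact(samples):
    """Drop byte-identical images, keeping the first occurrence.

    samples: iterable of (sample_id, encoded_bytes). Returns (kept_ids, duplicate_ids).
    """
    seen = {}  # content digest -> first sample_id with that content
    kept, duplicates = [], []
    for sample_id, data in samples:
        digest = hashlib.sha256(data).hexdigest()
        if digest in seen:
            duplicates.append(sample_id)
        else:
            seen[digest] = sample_id
            kept.append(sample_id)
    return kept, duplicates
```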

The clean mental split is:

move expensive global hygiene offline; keep cheap training-time randomness online

For images, “featurization” usually means decode plus transforms.

A CPU-first path:

  • local worker reads compressed bytes from shard cache
  • torchvision.io.decode_image (to a uint8 tensor) or PIL (to a PIL image) decodes the bytes
  • worker applies resize and crop
  • pin_memory=True stages finished batches in page-locked host memory so non_blocking host-to-device copies can overlap with compute

When CPU decode becomes the bottleneck:

  • keep bytes compressed until late
  • decode or resize with GPU-aware tooling
  • move simple color or normalization work onto device

Use the CPU-first path until you have evidence it is the bottleneck. It is easier to debug and usually good enough for moderate cluster sizes.

At this scale, an IterableDataset is usually a better fit than pretending the corpus is a local random-access array.

```python
from __future__ import annotations

import io
import random
import tarfile
from dataclasses import dataclass

from PIL import Image
from torch.utils.data import IterableDataset, get_worker_info


@dataclass
class StreamState:
    epoch: int = 0
    shard_offset: int = 0
    sample_offset: int = 0


class ImageShardDataset(IterableDataset):
    def __init__(self, manifest, rank: int, world_size: int, seed: int, state: StreamState | None = None):
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.state = state or StreamState()

    def state_dict(self) -> dict[str, int]:
        return {
            "epoch": self.state.epoch,
            "shard_offset": self.state.shard_offset,
            "sample_offset": self.state.sample_offset,
        }

    def set_epoch(self, epoch: int) -> None:
        self.state = StreamState(epoch=epoch)

    def _shards_for_rank(self):
        # Same seed + epoch on every rank gives one global order, so rank slices are disjoint.
        rng = random.Random(self.seed + self.state.epoch)
        shards = list(self.manifest)
        rng.shuffle(shards)
        shards = shards[self.rank :: self.world_size]
        # Split again across DataLoader workers; otherwise every worker replays the rank's shards.
        info = get_worker_info()
        if info is not None:
            shards = shards[info.id :: info.num_workers]
        return shards

    def __iter__(self):
        # NOTE: with num_workers > 0 each worker mutates its own copy of self.state, so the
        # trainer must report positions back (e.g. inside the batch) for state_dict() to be useful.
        shards = self._shards_for_rank()
        for shard_idx, shard in enumerate(shards[self.state.shard_offset :], start=self.state.shard_offset):
            # open_local_cached_shard() is the cache boundary (assumed helper, not shown).
            with open_local_cached_shard(shard["uri"]) as fp:
                with tarfile.open(fileobj=fp, mode="r|*") as tf:
                    sample_idx = 0
                    for member in tf:
                        if not member.isfile() or not member.name.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
                            continue
                        if shard_idx == self.state.shard_offset and sample_idx < self.state.sample_offset:
                            sample_idx += 1
                            continue
                        image = Image.open(io.BytesIO(tf.extractfile(member).read())).convert("RGB")
                        # Position is recorded before yielding, so a resume replays this sample (at-least-once).
                        self.state.shard_offset = shard_idx
                        self.state.sample_offset = sample_idx
                        sample_idx += 1
                        # -100 is cross_entropy's default ignore_index for unlabeled samples.
                        yield {"image": image, "label": shard["label_lookup"].get(member.name, -100)}
            self.state.sample_offset = 0
```

That code is simplified, but the important bit is the contract:

  • deterministic shard ownership by rank
  • explicit replay state
  • local cache boundary hidden behind open_local_cached_shard()
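Deterministic ownership has to hold at two levels: across ranks, and across DataLoader workers within a rank. Otherwise every worker replays its rank's shards. The sketch below shows the assignment rule with a hypothetical shards_for helper.

```python
import random


def shards_for(manifest, epoch, seed, rank, world_size, worker_id=0, num_workers=1):
    """Deterministic shard ownership: the same inputs always return the same list."""
    shards = list(manifest)
    # Every rank shuffles with the same seed + epoch, so all ranks agree on one global order.
    random.Random(seed + epoch).shuffle(shards)
    # Rank stride first, then worker stride inside the rank.
    return shards[rank::world_size][worker_id::num_workers]
```

Every rank shuffles with the same seed + epoch, so the rank slices are disjoint and together cover the whole manifest. The same holds for the worker slices inside each rank.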

If every worker reads every shard straight from remote storage, your cluster becomes an object-store benchmark.

  • cache shards on node-local NVMe
  • download ahead in a background thread or helper process
  • keep cache accounting per node, not per worker, so workers share data
  • evict by shard, not by file
  • record cache hit rate as a first-class metric

For a training cluster, local cache policy is part of the data plane, not an implementation detail.
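The sketch below is a single-process version of shard-granular caching, with LRU eviction and hit-rate accounting. ShardCache and its fetch_fn(uri, dest_path) download callback are hypothetical. To share the cache per node rather than per worker, this logic would run in one helper process, or behind file locks, that all workers ask for paths.

```python
from collections import OrderedDict
from pathlib import Path


class ShardCache:
    """Shard-granular LRU cache on local disk with hit-rate accounting."""

    def __init__(self, root, capacity_bytes, fetch_fn):
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)
        self.capacity_bytes = capacity_bytes
        self.fetch_fn = fetch_fn  # hypothetical: fetch_fn(uri, dest_path) downloads one shard
        self.entries = OrderedDict()  # shard file name -> size in bytes, least recently used first
        self.used_bytes = 0
        self.hits = 0
        self.misses = 0

    def get(self, uri):
        name = uri.rsplit("/", 1)[-1]
        path = self.root / name
        if name in self.entries:
            self.hits += 1
            self.entries.move_to_end(name)  # mark as most recently used
            return path
        self.misses += 1
        partial = self.root / (name + ".partial")
        self.fetch_fn(uri, partial)
        partial.replace(path)  # rename so readers never see a half-downloaded shard
        size = path.stat().st_size
        self.entries[name] = size
        self.used_bytes += size
        self._evict()
        return path

    def _evict(self):
        # Evict whole shards, least recently used first, but never the one just fetched.
        while self.used_bytes > self.capacity_bytes and len(self.entries) > 1:
            victim, size = self.entries.popitem(last=False)
            self.used_bytes -= size
            (self.root / victim).unlink(missing_ok=True)

    @property
    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total else 0.0
```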

Reproducibility Means Stable Contracts, Not Magic


Image pipelines include stochastic crops, shuffle order, cache timing, and occasionally non-deterministic decode behavior, so be exact about the contract.

Record and hold stable:

  • code revision or container image digest
  • training and validation manifest versions
  • augmentation policy version
  • crop size and normalization constants
  • global seed and per-epoch shuffle seed rule
  • checkpoint lineage and resume step

Do not promise:

  • bitwise-identical results across different GPU counts
  • identical sample order if a streaming source is intentionally at-least-once
  • identical timing once cache warmup or remote storage conditions change

The senior framing is:

“I optimize for explainable and statistically repeatable runs, and I reserve exact replay for controlled debugging paths.”

```mermaid
flowchart TD
  A[Resume request] --> B{Same val and train manifests?}
  B -->|No| C[New run or offline comparison only]
  B -->|Yes| D{Same topology and transform policy?}
  D -->|Yes| E[Closer replay path]
  D -->|No| F[Statistical continuation]
  E --> G[Resume checkpoint]
  F --> G
```
At petabyte scale, reproducibility is a contract over manifests, transforms, and topology, not just a seed.
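The resume decision flow reduces to a small classifier. The sketch below assumes topology is summarized by world_size and the transform policy by a version string. RunView and resume_mode are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunView:
    train_manifest: str
    val_manifest: str
    world_size: int  # stand-in for "topology"
    transform_policy: str


def resume_mode(checkpoint: RunView, current: RunView) -> str:
    """Classify a resume request the way the decision flow does."""
    if (checkpoint.train_manifest, checkpoint.val_manifest) != (current.train_manifest, current.val_manifest):
        return "new_run"  # offline comparison only
    if (checkpoint.world_size, checkpoint.transform_policy) == (current.world_size, current.transform_policy):
        return "closer_replay"
    return "statistical_continuation"
```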

Vision workloads often waste performance in one of two ways:

  • tiny batches because decode is slow
  • expensive dynamic shapes because crops or aspect ratios are unmanaged

The fixes:

  • normalize final tensor shapes before the model boundary
  • use drop_last=True so every rank steps on full-sized batches in steady state
  • if aspect ratio matters, use a bounded set of crop buckets instead of arbitrary image shapes
  • prefetch enough batches that short storage hiccups do not starve the device
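The sketch below shows bounded crop buckets, using a hypothetical set of five target shapes with similar pixel counts. Width and height are precomputed in the manifest, so samples can be grouped by bucket without decoding. Each batch is then drawn from one bucket, so every tensor in it has the same shape.

```python
import math
from collections import defaultdict

# Hypothetical bucket set: (width, height) targets with similar pixel counts.
BUCKETS = [(224, 224), (256, 192), (192, 256), (288, 160), (160, 288)]


def assign_bucket(width: int, height: int, buckets=BUCKETS):
    """Pick the bucket whose aspect ratio is closest to the image's, compared in log space."""
    aspect = math.log(width / height)  # log makes wide and tall images symmetric
    return min(buckets, key=lambda b: abs(math.log(b[0] / b[1]) - aspect))


def group_by_bucket(rows, buckets=BUCKETS):
    """rows: iterable of (sample_id, width, height) taken from the manifest."""
    groups = defaultdict(list)
    for sample_id, width, height in rows:
        groups[assign_bucket(width, height, buckets)].append(sample_id)
    return dict(groups)
```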
```python
import torch
import torchvision.transforms.v2 as T

train_tfms = T.Compose(
    [
        T.ToImage(),  # PIL image -> uint8 image tensor (C, H, W)
        T.RandomResizedCrop(size=(224, 224), antialias=True),
        T.RandomHorizontalFlip(p=0.5),
        T.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet statistics
    ]
)


def collate_images(batch):
    # With num_workers > 0 this runs inside DataLoader worker processes, so transforms cost CPU there.
    images = torch.stack([train_tfms(sample["image"]) for sample in batch])
    labels = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)
    return {"images": images, "labels": labels}
```

If your transforms are much heavier than this, measure whether they belong offline or on a different execution device.

| Setting | Why it helps |
|---|---|
| persistent_workers=True | avoid repeated worker startup cost |
| prefetch_factor | smooth short stalls between storage, decode, and train |
| pin_memory=True | better host-to-device transfer overlap |
| drop_last=True | keep per-rank step counts aligned |
| num_workers tuning | balance decode throughput against CPU oversubscription |

These do not fix bad storage layout, but they matter once the basics are correct.
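prefetch_factor and num_workers also determine how much host memory sits in queued batches. PyTorch's DataLoader keeps up to num_workers * prefetch_factor batches in flight. The rough estimator below counts only the collated float32 image tensors; pinned copies and decoded PIL images add more. The function name is hypothetical.

```python
def inflight_host_bytes(num_workers, prefetch_factor, micro_batch_size,
                        channels=3, height=224, width=224, bytes_per_elem=4):
    """Approximate host memory held by batches queued in a DataLoader."""
    batches_in_flight = num_workers * prefetch_factor
    bytes_per_batch = micro_batch_size * channels * height * width * bytes_per_elem
    return batches_in_flight * bytes_per_batch
```

The full script later on this page uses num_workers=8, prefetch_factor=4, micro_batch_size=128 and 224 px crops. With those settings, the estimate is about 2.3 GiB per rank, or roughly 18 GiB on an 8-GPU node, before any pinned-memory copies.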

```python
from contextlib import nullcontext

import torch
from torch.amp import GradScaler, autocast


def train_epoch(model, loader, optimizer, scaler: GradScaler, cfg):
    # model is DDP-wrapped; cfg provides device, use_amp, and grad_accum_steps.
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, batch in enumerate(loader):
        images = batch["images"].to(
            cfg.device,
            non_blocking=True,
            memory_format=torch.channels_last,
        )
        labels = batch["labels"].to(cfg.device, non_blocking=True)
        is_boundary = (step + 1) % cfg.grad_accum_steps == 0
        # Skip the gradient all-reduce on every micro-batch except the last of each window.
        sync_ctx = model.no_sync() if not is_boundary else nullcontext()
        with sync_ctx:
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=cfg.use_amp):
                logits = model(images)
                loss = torch.nn.functional.cross_entropy(logits, labels)
                loss = loss / cfg.grad_accum_steps  # average over the accumulation window
            # Loss scaling only matters for float16; with bfloat16 the scaler can be built with enabled=False.
            scaler.scale(loss).backward()
        if is_boundary:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```
  • DDP does synchronization; it does not solve image ingest.
  • channels_last often helps convolution-heavy models.
  • AMP is a throughput lever, but only after data starvation is under control.
  • Gradient accumulation changes optimizer cadence and should be explained, not hidden.
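The accumulation cadence can be stated explicitly in a few lines worth keeping next to the config. The sketch below mirrors the (step + 1) % grad_accum_steps test in the loop above. accumulation_plan is a hypothetical helper.

```python
def accumulation_plan(num_micro_batches: int, grad_accum_steps: int) -> dict:
    """Which micro-batches all-reduce and step the optimizer under (step + 1) % grad_accum_steps == 0."""
    sync_at = [step for step in range(num_micro_batches) if (step + 1) % grad_accum_steps == 0]
    return {
        "sync_at": sync_at,
        "optimizer_steps": len(sync_at),
        # Micro-batches after the last full window never reach optimizer.step().
        "unstepped_micro_batches": num_micro_batches % grad_accum_steps,
    }
```

A trailing partial window never reaches an optimizer step. In the loop above, its gradients are discarded by zero_grad at the start of the next train_epoch call.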

At petabyte scale, evaluation needs to be smaller, cleaner, and much more stable than the training corpus.

  • freeze a versioned validation manifest that is never mixed with train shards
  • keep validation transforms deterministic, typically resize plus center crop
  • maintain at least one cheap smoke-eval set and one more representative holdout
  • if the job is multimodal, version the text side of the validation data too
| Objective | Better metric set | Why one scalar is not enough |
|---|---|---|
| image classification | top-1, top-5, per-class recall | long-tail classes disappear in a global average |
| multimodal contrastive | retrieval recall@K, median rank | loss alone does not expose retrieval usefulness |
| ranking or recommendation | NDCG, recall@K, slice metrics by traffic segment | deployment quality depends on who the model fails for |
| generative or representation | downstream probe metrics and drift slices | proxy loss may improve while representation quality drops |
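For the multimodal row, recall@K and median rank are simple to compute once you have a similarity matrix. The pure-Python sketch below assumes a square matrix where query i's correct match is candidate i. At real scale this would be a batched top-k on device.

```python
import statistics


def _ranked(row):
    # Candidate indices ordered from most to least similar.
    return sorted(range(len(row)), key=lambda j: row[j], reverse=True)


def recall_at_k(similarity, k):
    """Fraction of queries whose correct match (the diagonal) is in the top k."""
    hits = 0
    for i, row in enumerate(similarity):
        if i in _ranked(row)[:k]:
            hits += 1
    return hits / len(similarity)


def median_rank(similarity):
    """Median 1-based rank of the correct match across queries."""
    return statistics.median(_ranked(row).index(i) + 1 for i, row in enumerate(similarity))
```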
  • run frequent small validation passes to catch regressions early
  • run larger validation passes less often to reduce cluster interruption
  • keep eval throughput separate from train throughput so the cost is visible
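The sketch below implements the two-tier cadence and reports eval cost separately from training throughput. The 500 and 5,000 step intervals are hypothetical defaults, not recommendations.

```python
def eval_plan(global_step, smoke_every=500, full_every=5_000):
    """Which validation passes to run after an optimizer step."""
    passes = []
    if global_step % smoke_every == 0:
        passes.append("smoke")
    if global_step % full_every == 0:
        passes.append("full_holdout")
    return passes


def eval_overhead(train_seconds, eval_seconds):
    """Fraction of wall-clock time spent in evaluation, reported apart from train throughput."""
    total = train_seconds + eval_seconds
    return eval_seconds / total if total else 0.0
```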
```mermaid
flowchart LR
  A[Versioned train shards] --> B[DDP train loop]
  C[Versioned validation shards] --> D[Deterministic val transforms]
  B --> E[Checkpoint]
  E --> F[Smoke eval]
  E --> G[Full holdout eval]
  D --> F
  D --> G
  F --> H[Fast guardrail decision]
  G --> I[Promotion or rollback decision]
```
The evaluation pipeline should be smaller and cleaner than training, with explicit separation between fast guardrails and decision-grade holdouts.

When asked how you monitor the system, split the answer into layers:

  1. data-plane health: read bandwidth, cache hit rate, decode time, corrupt-sample rate
  2. trainer health: images/sec, batch wait, all-reduce fraction, OOM and restart counts
  3. model quality: top-1 or recall@K, per-slice performance, calibration or drift checks

That is stronger than saying “we watch loss and GPU utilization.”

```mermaid
flowchart TD
  A[Monitoring stack] --> B[Data-plane health]
  A --> C[Trainer health]
  A --> D[Model quality]
  B --> E[read BW, cache hit, decode time]
  C --> F[images/sec, wait, all-reduce, restarts]
  D --> G[top-1, recall@K, slice drift]
```
A useful interview answer separates ingestion, trainer, and quality signals instead of collapsing them into one dashboard.

Image models are often better compile candidates than molecular sequence models because shapes can be made more regular.

Good conditions for torch.compile:

  • fixed final crop size
  • stable forward graph
  • limited control flow
  • enough steady-state steps to amortize compile overhead

Bad conditions:

  • lots of dynamic aspect-ratio branches in-model
  • constant re-specialization from unconstrained shapes
  • debugging a new training job where compile obscures basic failures

Petabyte-scale datasets make epoch-only checkpoints a weak story because an epoch may be extremely long.

Checkpoint at step intervals and store:

  • model, optimizer, and scaler state
  • logical epoch
  • dataset stream state or sampler state
  • manifest version
  • any cache or prefetch cursor state that affects replay semantics

For the data plane, at-least-once with bounded duplication is often a more honest target than pretending you have perfectly exact sample replay on a giant streaming corpus.
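Step-interval checkpoints only help if a crash mid-write cannot leave a truncated file that looks like the newest checkpoint. The sketch below writes to a temporary file and then renames it into place. It assumes a POSIX filesystem where os.replace is atomic within one filesystem. atomic_write and latest_checkpoint are hypothetical helpers. With torch, serialize into an io.BytesIO via torch.save and pass buffer.getvalue().

```python
import os
import tempfile
from pathlib import Path


def atomic_write(path, data: bytes) -> None:
    """Write to a temp file in the same directory, fsync, then rename into place."""
    path = Path(path)
    # The ".tmp" suffix keeps partial files out of the "step-*.pt" glob below.
    fd, tmp = tempfile.mkstemp(dir=str(path.parent), prefix=path.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)  # readers see either the old state or the complete new file
    except BaseException:
        Path(tmp).unlink(missing_ok=True)
        raise


def latest_checkpoint(directory):
    # Zero-padded step numbers (as in step-{global_step:08d}.pt) make lexical order match step order.
    checkpoints = sorted(Path(directory).glob("step-*.pt"))
    return checkpoints[-1] if checkpoints else None
```

Resuming from the newest complete checkpoint replays at most the samples consumed after it was written. That is the bound in "at-least-once with bounded duplication."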

| Metric | Why it matters |
|---|---|
| images/sec per rank | top-line throughput |
| object-store read bandwidth | tells you whether storage is the actual limiter |
| cache hit rate | proves local staging is doing useful work |
| decode time / batch | separates CPU pressure from model compute |
| batch wait time | direct signal for GPU starvation |
| step time skew across ranks | exposes corrupt shards, slow nodes, or uneven work |
| dropped or corrupt sample rate | catches upstream data quality regressions |
| all-reduce fraction of step | tells you when communication tuning is worth your time |
| validation freshness | tells you whether current quality signals still reflect the run |
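The sketch below computes the step-time skew metric. It assumes per-rank timings have already been gathered, for example with torch.distributed.all_gather_object. The function names and the 1.25 threshold are hypothetical. Compare data-wait or compute time rather than total step time: once ranks block on the gradient all-reduce, total step times equalize and hide the slow rank.

```python
import statistics


def step_time_skew(times_by_rank):
    """Return (slowest_rank, slowest_time / median_time) for one per-rank timing."""
    median = statistics.median(times_by_rank.values())
    slowest = max(times_by_rank, key=times_by_rank.get)
    return slowest, times_by_rank[slowest] / median


def slow_ranks(times_by_rank, threshold=1.25):
    """Ranks whose timing exceeds the median by more than the threshold ratio."""
    median = statistics.median(times_by_rank.values())
    return sorted(rank for rank, t in times_by_rank.items() if t / median > threshold)
```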
```mermaid
flowchart TD
  A[GPU utilization falls] --> B{Primary symptom}
  B --> C[High batch wait]
  B --> D[High decode time]
  B --> E[Rank skew]
  B --> F[OOM after increasing batch]
  C --> G[Storage or cache pipeline issue]
  D --> H[Too few workers or expensive transforms]
  E --> I[Corrupt shards or uneven input]
  F --> J[Activation or optimizer-state pressure]
```
Most petabyte-image incidents are input incidents first and model incidents second.
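The triage flowchart can be encoded as an ordered rule set. The sketch below uses hypothetical thresholds, with symptoms expressed as fractions of step wall time. Decode is checked before batch wait because slow decode also shows up as batch wait.

```python
def triage(batch_wait_frac, decode_frac, rank_skew, oom_after_batch_increase,
           wait_threshold=0.2, decode_threshold=0.3, skew_threshold=1.25):
    """Map the primary symptom to its likely cause; thresholds are illustrative only."""
    if oom_after_batch_increase:
        return "activation or optimizer-state pressure"
    if decode_frac > decode_threshold:
        return "too few workers or expensive transforms"
    if batch_wait_frac > wait_threshold:
        return "storage or cache pipeline issue"
    if rank_skew > skew_threshold:
        return "corrupt shards or uneven input"
    return "no input-side symptom; profile the model step"
```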

If the model stops fitting, or optimizer state dominates device memory, then DDP may no longer be sufficient. Until that point, resist the urge to complicate the trainer.

The first escalation is often not FSDP. It is:

  • better shard layout
  • better cache hit rate
  • faster decode path
  • better crop bucketing
  • cleaner observability

Only after the model-state problem is real should you add model-state complexity.

This is the full article in code form: config, shard streaming, local-cache boundary, transforms, DDP setup, evaluation, and checkpointing in one script.

```python
from __future__ import annotations

import io
import json
import os
import random
import shutil
import tarfile
import time
from contextlib import nullcontext
from dataclasses import asdict, dataclass, field
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
import torchvision.transforms.v2 as T
from PIL import Image
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".webp")
MISSING_LABEL = -100  # cross_entropy's default ignore_index: unlabeled samples add no loss
AMP_DTYPE = torch.bfloat16  # bfloat16 needs no loss scaling; float16 would


@dataclass
class TrainConfig:
    backend: str = "nccl"
    seed: int = 17
    world_size: int = int(os.environ.get("WORLD_SIZE", "1"))
    rank: int = int(os.environ.get("RANK", "0"))
    local_rank: int = int(os.environ.get("LOCAL_RANK", "0"))
    micro_batch_size: int = 128
    grad_accum_steps: int = 2
    max_epochs: int = 2
    max_steps: int = 10_000
    image_size: int = 224
    num_classes: int = 1_000
    learning_rate: float = 2e-4
    weight_decay: float = 0.05
    num_workers: int = 8
    log_interval_steps: int = 50
    eval_interval_steps: int = 500
    checkpoint_interval_steps: int = 500
    train_manifest_path: str = "manifests/train_images.json"
    val_manifest_path: str = "manifests/val_images.json"
    cache_dir: str = "/tmp/tcp-image-cache"
    checkpoint_dir: str = "artifacts/checkpoints/images"
    use_amp: bool = True

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_rank(self) -> bool:
        return self.rank == 0

    @property
    def amp_enabled(self) -> bool:
        return self.use_amp and torch.cuda.is_available()

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{self.local_rank}")
        return torch.device("cpu")

    @property
    def global_batch_size(self) -> int:
        return self.micro_batch_size * self.grad_accum_steps * self.world_size


@dataclass
class StreamState:
    epoch: int = 0
    # Resume cursor per DataLoader worker: worker_id -> [shard_offset, sample_offset].
    # Only meaningful if num_workers is unchanged on resume.
    cursors: dict[int, list[int]] = field(default_factory=dict)


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> None:
    if not cfg.is_distributed:
        return
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.local_rank)
    # NCCL needs GPUs; fall back to gloo for CPU-only smoke tests.
    backend = cfg.backend if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=cfg.rank, world_size=cfg.world_size)


def cleanup_distributed() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def load_manifest(path: str) -> list[dict]:
    return json.loads(Path(path).read_text())


def ensure_local_shard(uri: str, cache_dir: str) -> Path:
    """Stage a shard on local disk once. Stand-in for an object-store download (file:// only here)."""
    cache_root = Path(cache_dir)
    cache_root.mkdir(parents=True, exist_ok=True)
    destination = cache_root / Path(uri).name
    if destination.exists():
        return destination
    source = Path(uri.replace("file://", "", 1))
    partial = cache_root / f"{destination.name}.{os.getpid()}.partial"
    shutil.copyfile(source, partial)  # streamed copy instead of read_bytes() into memory
    os.replace(partial, destination)  # atomic rename: a crash never leaves a truncated shard
    return destination


class ImageShardDataset(IterableDataset):
    def __init__(
        self,
        manifest: list[dict],
        rank: int,
        world_size: int,
        seed: int,
        cache_dir: str,
        state: StreamState | None = None,
    ) -> None:
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.cache_dir = cache_dir
        self.state = state or StreamState()

    def state_dict(self) -> dict:
        return asdict(self.state)

    def load_state_dict(self, state: dict) -> None:
        cursors = {int(k): [int(x) for x in v] for k, v in state.get("cursors", {}).items()}
        self.state = StreamState(epoch=int(state["epoch"]), cursors=cursors)

    def set_epoch(self, epoch: int) -> None:
        # Keep restored cursors when resuming into the same epoch; start fresh otherwise.
        if epoch != self.state.epoch:
            self.state = StreamState(epoch=epoch)

    def record(self, cursor: list[int]) -> None:
        """Called in the main process for every consumed batch; workers only hold copies of self.state."""
        worker_id, shard_offset, sample_offset = cursor
        self.state.cursors[int(worker_id)] = [int(shard_offset), int(sample_offset)]

    def _worker_shards(self) -> tuple[int, list[dict]]:
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        shards = list(self.manifest)
        random.Random(self.seed + self.state.epoch).shuffle(shards)
        # Rank stride, then worker stride: every shard has exactly one (rank, worker) owner.
        return worker_id, shards[self.rank :: self.world_size][worker_id::num_workers]

    def __iter__(self):
        worker_id, shards = self._worker_shards()
        # Read the cursor without mutating self.state, so persistent workers can re-iterate cleanly.
        start_shard, start_sample = self.state.cursors.get(worker_id, (0, 0))
        for shard_idx in range(start_shard, len(shards)):
            shard = shards[shard_idx]
            skip = start_sample if shard_idx == start_shard else 0
            local_path = ensure_local_shard(shard["uri"], self.cache_dir)
            with tarfile.open(local_path, mode="r:*") as archive:
                sample_idx = 0
                for member in archive:
                    if not member.isfile() or not member.name.lower().endswith(IMAGE_SUFFIXES):
                        continue
                    position = sample_idx
                    sample_idx += 1
                    if position < skip:
                        continue
                    handle = archive.extractfile(member)
                    if handle is None:
                        continue
                    try:
                        image = Image.open(io.BytesIO(handle.read())).convert("RGB")
                    except OSError:
                        continue  # should be rare once corrupt files are filtered offline
                    yield {
                        "image": image,
                        "label": int(shard["label_lookup"].get(member.name, MISSING_LABEL)),
                        # Resuming at this cursor skips everything up to and including this sample.
                        "cursor": (worker_id, shard_idx, position + 1),
                    }


def build_transforms(image_size: int, training: bool) -> T.Compose:
    if training:
        return T.Compose(
            [
                T.ToImage(),
                T.RandomResizedCrop(size=(image_size, image_size), antialias=True),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )
    # Deterministic validation transforms: resize plus center crop.
    return T.Compose(
        [
            T.ToImage(),
            T.Resize(size=image_size + 32, antialias=True),
            T.CenterCrop(size=(image_size, image_size)),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )


def make_collate(transform: T.Compose):
    def collate(batch: list[dict]) -> dict[str, torch.Tensor]:
        images = torch.stack([transform(sample["image"]) for sample in batch])
        labels = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)
        # With an IterableDataset each batch is built inside one worker, so the last
        # sample's cursor is that worker's resume point.
        cursor = torch.tensor(batch[-1]["cursor"], dtype=torch.long)
        return {"images": images, "labels": labels, "cursor": cursor}

    return collate


def make_loader(dataset: IterableDataset, cfg: TrainConfig, training: bool) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=cfg.micro_batch_size,
        num_workers=cfg.num_workers,
        prefetch_factor=4 if cfg.num_workers > 0 else None,
        # Training workers are re-created each epoch so they see set_epoch(); validation
        # workers persist across the many eval passes.
        persistent_workers=cfg.num_workers > 0 and not training,
        pin_memory=torch.cuda.is_available(),
        collate_fn=make_collate(build_transforms(cfg.image_size, training=training)),
        drop_last=training,
    )


def unwrap(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ImageShardDataset,
    cfg: TrainConfig,
    epoch: int,
    global_step: int,
) -> None:
    # Each rank streams different shards, so gather every rank's cursors.
    # This is a collective: all ranks must call save_checkpoint.
    if cfg.is_distributed:
        dataset_states = [None] * cfg.world_size
        dist.all_gather_object(dataset_states, dataset.state_dict())
    else:
        dataset_states = [dataset.state_dict()]
    if not cfg.is_main_rank:
        return
    checkpoint_dir = Path(cfg.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "model": unwrap(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "dataset_states": dataset_states,
        "epoch": epoch,
        "global_step": global_step,
        "config": asdict(cfg),
    }
    tmp_path = checkpoint_dir / f".step-{global_step:08d}.pt.tmp"
    torch.save(payload, tmp_path)
    os.replace(tmp_path, checkpoint_dir / f"step-{global_step:08d}.pt")  # restore never sees a partial file


def restore_if_available(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ImageShardDataset,
    cfg: TrainConfig,
) -> tuple[int, int]:
    # Assumes checkpoint_dir is on storage every rank can read.
    checkpoint_dir = Path(cfg.checkpoint_dir)
    if not checkpoint_dir.exists():
        return 0, 0
    checkpoints = sorted(checkpoint_dir.glob("step-*.pt"))
    if not checkpoints:
        return 0, 0
    payload = torch.load(checkpoints[-1], map_location="cpu")
    unwrap(model).load_state_dict(payload["model"])
    optimizer.load_state_dict(payload["optimizer"])
    scaler.load_state_dict(payload["scaler"])
    dataset_states = payload.get("dataset_states", [])
    if len(dataset_states) == cfg.world_size:
        # Same topology: each rank resumes its own stream (closer replay).
        dataset.load_state_dict(dataset_states[cfg.rank])
    else:
        # Topology changed: statistical continuation from the start of the saved epoch.
        dataset.load_state_dict({"epoch": int(payload["epoch"]), "cursors": {}})
    return int(payload["epoch"]), int(payload["global_step"])


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, cfg: TrainConfig) -> dict[str, float]:
    model.eval()
    loss_sum = torch.zeros(1, device=cfg.device)
    count_sum = torch.zeros(1, device=cfg.device)
    correct_sum = torch.zeros(1, device=cfg.device)
    for batch in loader:
        images = batch["images"].to(cfg.device, non_blocking=True, memory_format=torch.channels_last)
        labels = batch["labels"].to(cfg.device, non_blocking=True)
        logits = model(images)
        labeled = labels != MISSING_LABEL
        # Summed loss so the final average weights every labeled image equally.
        loss_sum += F.cross_entropy(logits, labels, ignore_index=MISSING_LABEL, reduction="sum")
        count_sum += labeled.sum()
        correct_sum += ((logits.argmax(dim=-1) == labels) & labeled).sum()
    if cfg.is_distributed:
        for total in (loss_sum, count_sum, correct_sum):
            dist.all_reduce(total)
    denom = count_sum.clamp_min(1)
    return {
        "val_loss": (loss_sum / denom).item(),
        "top1": (correct_sum / denom).item(),
    }


def synced_batches(loader: DataLoader, cfg: TrainConfig):
    """Yield batches until any rank runs out of data.

    Shards hold different sample counts, so ranks finish an epoch at different times.
    Without this check, a rank that stops early leaves the others blocked in DDP's
    gradient all-reduce. The one-integer all-reduce is cheap but adds a sync point.
    """
    iterator = iter(loader)
    while True:
        batch = next(iterator, None)
        if cfg.is_distributed:
            has_batch = torch.tensor([int(batch is not None)], device=cfg.device)
            dist.all_reduce(has_batch, op=dist.ReduceOp.MIN)
            if has_batch.item() == 0:
                return
        elif batch is None:
            return
        yield batch


def train(cfg: TrainConfig) -> None:
    seed_everything(cfg.seed)
    setup_distributed(cfg)
    train_dataset = ImageShardDataset(
        manifest=load_manifest(cfg.train_manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed,
        cache_dir=cfg.cache_dir,
    )
    val_dataset = ImageShardDataset(
        manifest=load_manifest(cfg.val_manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed + 1000,
        cache_dir=cfg.cache_dir,
    )
    train_loader = make_loader(train_dataset, cfg, training=True)
    val_loader = make_loader(val_dataset, cfg, training=False)
    model = tvm.resnet50(num_classes=cfg.num_classes).to(
        cfg.device, memory_format=torch.channels_last
    )
    model = torch.compile(model) if hasattr(torch, "compile") else model
    if cfg.is_distributed:
        model = DDP(
            model,
            device_ids=[cfg.local_rank] if torch.cuda.is_available() else None,
            output_device=cfg.local_rank if torch.cuda.is_available() else None,
            gradient_as_bucket_view=True,
            static_graph=True,
            broadcast_buffers=False,
        )
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    # Loss scaling is only needed for float16 autocast; with bfloat16 the scaler stays disabled.
    scaler = GradScaler("cuda", enabled=cfg.amp_enabled and AMP_DTYPE == torch.float16)
    start_epoch, global_step = restore_if_available(model, optimizer, scaler, train_dataset, cfg)
    for epoch in range(start_epoch, cfg.max_epochs):
        train_dataset.set_epoch(epoch)
        model.train()
        optimizer.zero_grad(set_to_none=True)
        window_start, window_images = time.perf_counter(), 0
        for batch_idx, batch in enumerate(synced_batches(train_loader, cfg)):
            train_dataset.record(batch["cursor"].tolist())
            images = batch["images"].to(
                cfg.device,
                non_blocking=True,
                memory_format=torch.channels_last,
            )
            labels = batch["labels"].to(cfg.device, non_blocking=True)
            window_images += labels.numel()
            is_boundary = (batch_idx + 1) % cfg.grad_accum_steps == 0
            # Skip DDP's gradient all-reduce on all but the last micro-batch of each window.
            sync_ctx = model.no_sync() if cfg.is_distributed and not is_boundary else nullcontext()
            with sync_ctx:
                with autocast(device_type=cfg.device.type, dtype=AMP_DTYPE, enabled=cfg.amp_enabled):
                    logits = model(images)
                    loss = F.cross_entropy(logits, labels, ignore_index=MISSING_LABEL) / cfg.grad_accum_steps
                scaler.scale(loss).backward()
            if not is_boundary:
                continue
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1
            if global_step % cfg.log_interval_steps == 0:
                if cfg.is_main_rank:
                    elapsed = time.perf_counter() - window_start
                    print(
                        json.dumps(
                            {
                                "step": global_step,
                                "loss": round(loss.item() * cfg.grad_accum_steps, 5),
                                # Wall-clock window includes loader wait, so starvation shows up here;
                                # scaling by world_size assumes ranks run at similar speed.
                                "images_per_sec": round(
                                    window_images * cfg.world_size / max(elapsed, 1e-6),
                                    2,
                                ),
                            }
                        )
                    )
                window_start, window_images = time.perf_counter(), 0
            if global_step % cfg.eval_interval_steps == 0:
                metrics = evaluate(model, val_loader, cfg)
                if cfg.is_main_rank:
                    print(json.dumps({"step": global_step, **metrics}))
                model.train()
                window_start, window_images = time.perf_counter(), 0  # keep eval time out of throughput
            if global_step % cfg.checkpoint_interval_steps == 0 or global_step >= cfg.max_steps:
                save_checkpoint(model, optimizer, scaler, train_dataset, cfg, epoch, global_step)
                window_start, window_images = time.perf_counter(), 0
            if global_step >= cfg.max_steps:
                cleanup_distributed()
                return
    save_checkpoint(model, optimizer, scaler, train_dataset, cfg, cfg.max_epochs - 1, global_step)
    cleanup_distributed()


if __name__ == "__main__":
    train(TrainConfig())
```
  1. Pack images into immutable sequential shards and train from manifests, not individual files.
  2. Cache shards on local NVMe and treat cache hit rate as a core production metric.
  3. Decode and crop efficiently enough that DDP workers stay fed before worrying about exotic model parallelism.
  4. Use fixed final image shapes, AMP, and channels_last to get more from PyTorch once the input path is healthy.
  5. Save stream position alongside weights or restarts will be operationally expensive and logically ambiguous.