
Billion-Scale SMILES Training With DDP

This page is a concrete scale-up story: not “what is DDP?” in the abstract, but how you would actually use PyTorch DistributedDataParallel for a huge molecular training job where the data volume, featurization cost, and resume semantics dominate the design.

The example below is intentionally hypothetical. Think of it as a production-shaped reference scenario you can explain in an interview or use to structure a real design.

Assume a pretraining and weak-supervision corpus with:

  • 4.2 billion unique compounds
  • raw source rows containing compound_id, vendor metadata, smiles, assay identifiers, and sparse labels
  • approximately 30 TB of compressed tabular source data before featurization
  • a transformer molecular encoder whose weights, gradients, and optimizer state still fit on a single GPU, so DDP is the simplest correct parallelism choice
| Column | Type | Why it exists |
| --- | --- | --- |
| compound_id | string | Stable join key across ingest, featurization, and training |
| raw_smiles | string | Original vendor or registry representation |
| canonical_smiles | string | Deterministic representation used for dedupe and tokenizing |
| murcko_scaffold | string | Split and evaluation grouping |
| assay_ids | list[string] | Sparse multi-task supervision |
| labels | list[float] | Weak labels or regression targets |
| token_ids | list[uint16] | Offline tokenized molecular sequence |
| token_length | int | Bucketing and padding control |
| quality_flags | bitfield | Invalid chemistry, salt issues, or missing labels |
| source_partition | string | Provenance and replay/debug support |
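The schema describes quality_flags only as a bitfield. A minimal sketch of how an ETL job might set and filter on it, assuming hypothetical bit assignments (the schema names the failure classes, not their bit positions) and a hypothetical `trainable` helper:

```python
from enum import IntFlag


class QualityFlag(IntFlag):
    """Hypothetical bit assignments for the quality_flags column."""

    NONE = 0
    INVALID_CHEMISTRY = 1 << 0  # RDKit could not parse the SMILES
    SALT_ISSUE = 1 << 1  # e.g. multi-fragment record with an unstripped counter-ion
    MISSING_LABELS = 1 << 2  # no assay label attached to this compound


def trainable(flags: int, allow_missing_labels: bool = True) -> bool:
    """Return True if a row with these flags may enter a training shard."""
    blocking = QualityFlag.INVALID_CHEMISTRY | QualityFlag.SALT_ISSUE
    if not allow_missing_labels:
        blocking |= QualityFlag.MISSING_LABELS
    return (QualityFlag(flags) & blocking) == QualityFlag.NONE
```

Keeping unlabeled rows trainable by default fits a self-supervised pretraining corpus; a supervised fine-tuning shard builder would pass `allow_missing_labels=False`.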

If the model replica fits comfortably on each GPU, DDP is usually still the right first answer even at billion-sample scale.

  • Model state is replicated, which simplifies recovery and debug.
  • Gradient synchronization happens at well-defined boundaries.
  • Input partitioning stays under your control instead of being hidden behind a more complex sharding wrapper.
  • The dominant engineering risk moves to the data plane, which is exactly where molecular systems usually fail first.

The wrong instinct is to see “billions of molecules” and jump directly to FSDP. Dataset scale and model scale are different problems.

flowchart LR
  A[Raw SMILES tables in object storage] --> B[Offline canonicalization + dedupe]
  B --> C[Offline tokenization + scaffold assignment]
  C --> D[Shard manifest and length buckets]
  D --> E[Rank-aware iterable dataset]
  E --> F[DataLoader workers]
  F --> G[Pinned host buffers]
  G --> H[DDP trainer rank 0..N]
  H --> I[Distributed checkpoint]
  H --> J[Metrics: tokens/sec, padding, invalid rows]
At this scale, chemistry cleanup and shard layout matter as much as the training loop.

Do not put heavy chemistry work in the per-sample hot path (__getitem__ for map-style datasets, __iter__ for iterable ones).

| Step | Offline or online | Reason |
| --- | --- | --- |
| SMILES validity check | offline | RDKit parse failures in worker hot paths destroy throughput |
| Canonicalization and dedupe | offline | Must be deterministic before split and sharding |
| Murcko scaffold assignment | offline | Evaluation correctness depends on it |
| Tokenization to integer ids | offline | CPU cost is too high to repeat every epoch |
| Simple padding and collation | online | Depends on current batch composition |
| Lightweight masking | online | Cheap, batch-specific, and easy to keep deterministic |
| Stochastic augmentations | online, minimal | Use sparingly; molecular corruption can change chemistry semantics |
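"Easy to keep deterministic" for online masking usually means deriving the mask from stable inputs rather than from a worker's global RNG. A minimal sketch, assuming the mask is keyed on (seed, epoch, compound_id), that token id 0 is the only special id, and a hypothetical `mask_tokens` helper; the mask id and rate are illustrative:

```python
import hashlib
import random


def mask_tokens(
    token_ids: list[int],
    compound_id: str,
    seed: int,
    epoch: int,
    mask_id: int,
    mask_rate: float = 0.15,
    special_ids: frozenset[int] = frozenset({0}),
) -> list[int]:
    """Mask a deterministic subset of tokens for one (seed, epoch, compound)."""
    # A per-example RNG derived from stable inputs masks the same row the same
    # way no matter which rank or DataLoader worker happens to read it.
    key = f"{seed}:{epoch}:{compound_id}".encode("utf-8")
    rng = random.Random(int.from_bytes(hashlib.sha256(key).digest()[:8], "big"))
    masked = list(token_ids)
    for i, tok in enumerate(masked):
        if tok not in special_ids and rng.random() < mask_rate:
            masked[i] = mask_id
    return masked
```

Because the key includes the epoch, each epoch sees a different mask, while a resumed or re-sharded run reproduces the same masks.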

The core principle is simple:

expensive chemistry once, cheap tensor shaping every epoch
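The canonicalization-and-dedupe row above only works if the survivor of each duplicate group is chosen the same way on every rerun. A minimal sketch of one such rule, assuming canonicalization already ran and that the lexicographically smallest compound_id wins (a hypothetical policy); at 4.2 billion rows this would be a distributed group-by, and the in-memory dict only illustrates the rule:

```python
from typing import Optional


def dedupe_by_canonical(rows: list[tuple[str, Optional[str]]]) -> dict[str, str]:
    """Map canonical_smiles -> surviving compound_id, independent of input order.

    rows holds (compound_id, canonical_smiles) pairs; None marks a row that
    failed canonicalization and is dropped.
    """
    survivors: dict[str, str] = {}
    for compound_id, canonical in rows:
        if canonical is None:
            continue
        current = survivors.get(canonical)
        # Keep the smallest id so re-partitioned ETL jobs pick the same survivor.
        if current is None or compound_id < current:
            survivors[canonical] = compound_id
    return survivors
```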

A molecule-per-file layout will fall apart on metadata and request overhead. A row-wise object-store table is better, but training still improves dramatically if you materialize rank-friendly shards.

  • write immutable Parquet or Arrow shards sized around 128 MB to 512 MB compressed
  • keep token arrays and lengths in the same row group so workers can read sequentially
  • align shard boundaries with manifest units, not arbitrary byte offsets
  • store scaffold metadata beside tokens so evaluation or curriculum jobs do not need a second lookup
  • include a checksum or row-count footer so a rank can prove a shard is complete before advancing resume state
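The last bullet can be sketched with the standard library. This assumes the shard has already been downloaded to local disk, that the manifest stores checksums as "sha256:<hex>" (mirroring the manifest entry example), and hypothetical helper names:

```python
import hashlib
from pathlib import Path


def sha256_of_file(path: Path, chunk_bytes: int = 8 << 20) -> str:
    """Stream the file in chunks so a multi-hundred-MB shard never sits in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_bytes), b""):
            digest.update(chunk)
    return "sha256:" + digest.hexdigest()


def shard_is_complete(path: Path, expected_checksum: str, expected_rows: int, observed_rows: int) -> bool:
    """A rank only acknowledges a shard when both bytes and row count match the manifest."""
    return observed_rows == expected_rows and sha256_of_file(path) == expected_checksum
```

The row-count comparison runs first because it is free once the rows have been read; hashing is the expensive check.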
{
  "split": "train",
  "shard_id": "train-084231",
  "uri": "s3://chem-corpus/train/train-084231.parquet",
  "num_rows": 131072,
  "token_count": 6423119,
  "length_bucket": "128-160",
  "scaffold_families": 5841,
  "checksum": "sha256:9e3d..."
}

At training time, the manifest is the real unit of work. Rows are too fine-grained for efficient remote scheduling; full dataset partitions are too coarse for recovery.
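Treating the manifest as the unit of work is easier when entries are parsed into a typed record and validated once at startup. A minimal sketch, assuming the manifest is a JSON list of objects with exactly the fields in the example entry above; `ShardEntry`, `length_range`, and `load_manifest_entries` are hypothetical names:

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardEntry:
    split: str
    shard_id: str
    uri: str
    num_rows: int
    token_count: int
    length_bucket: str
    scaffold_families: int
    checksum: str

    @property
    def length_range(self) -> tuple[int, int]:
        lo, hi = (int(part) for part in self.length_bucket.split("-"))
        return lo, hi


def load_manifest_entries(text: str, split: str) -> list[ShardEntry]:
    # ShardEntry(**record) rejects unknown or missing keys, which surfaces schema drift early.
    entries = [ShardEntry(**record) for record in json.loads(text)]
    bad = [e.shard_id for e in entries if e.split != split or e.num_rows <= 0]
    if bad:
        raise ValueError(f"manifest rows with wrong split or no rows: {bad}")
    # Sort by shard_id so every rank starts from the same order before the epoch shuffle.
    return sorted(entries, key=lambda e: e.shard_id)
```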

In an interview, config management is one of the easiest ways to signal that you think like an operator instead of a notebook-only user.

The trainer config should encode at least four classes of information:

| Config area | Example fields | Why it must be explicit |
| --- | --- | --- |
| runtime topology | world_size, global_batch_tokens, grad_accum_steps | changes optimization semantics and DDP behavior |
| data contract | manifest_uri, split_policy, length_bucket_spec | ties the run to a specific corpus definition |
| featurization | tokenizer_version, vocab_hash, canonicalization_version | prevents silent feature drift across runs |
| optimization and eval | lr, warmup_steps, eval_interval, early_stop_metric | keeps training and decision policy reproducible |
from dataclasses import dataclass


@dataclass(frozen=True)
class MoleculeTrainConfig:
    run_name: str
    seed: int
    world_size: int
    micro_batch_size: int
    grad_accum_steps: int
    max_tokens_per_batch: int
    manifest_uri: str
    manifest_version: str
    split_policy: str
    tokenizer_version: str
    vocab_hash: str
    canonicalization_version: str
    model_name: str
    learning_rate: float
    eval_interval_steps: int
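The "derived batch math" that later feeds run metadata can be computed directly from these fields. A minimal sketch, assuming max_tokens_per_batch is a per-rank, per-micro-batch cap (the config does not say which); `BatchMath` and `derive_batch_math` are hypothetical names:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchMath:
    sequences_per_step: int  # sequences contributing to one optimizer step
    max_tokens_per_step: int  # upper bound on tokens per optimizer step


def derive_batch_math(
    world_size: int, micro_batch_size: int, grad_accum_steps: int, max_tokens_per_batch: int
) -> BatchMath:
    if min(world_size, micro_batch_size, grad_accum_steps, max_tokens_per_batch) <= 0:
        raise ValueError("all batch settings must be positive")
    micro_batches_per_step = world_size * grad_accum_steps
    return BatchMath(
        sequences_per_step=micro_batches_per_step * micro_batch_size,
        max_tokens_per_step=micro_batches_per_step * max_tokens_per_batch,
    )
```

For example, 64 ranks with micro-batches of 32 and 4 accumulation steps give 8,192 sequences per step; moving the same job to 32 ranks halves that unless grad_accum_steps doubles, which is why world_size belongs in the config.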

The staff-level sentence is:

“I want the checkpoint to point back to a complete data and featurization contract, not just to weights.”

flowchart LR
  A[Config payload] --> B[Derived batch math]
  A --> C[Manifest version]
  A --> D[Tokenizer and canon version]
  B --> E[Run metadata]
  C --> E
  D --> E
  E --> F[Checkpoint lineage]
  F --> G[Resume or audit]
For molecular training, reproducibility starts with linking topology, manifest, and featurization versions into the checkpoint lineage.
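One way to make the lineage link concrete is to hash only the data and featurization contract and store that fingerprint next to the weights. A minimal sketch; the choice of fields is an assumption drawn from MoleculeTrainConfig, and `contract_fingerprint` is a hypothetical helper:

```python
import hashlib
import json

# Assumed contract fields: anything that changes what the model sees, not how it is optimized.
CONTRACT_FIELDS = (
    "manifest_uri",
    "manifest_version",
    "split_policy",
    "tokenizer_version",
    "vocab_hash",
    "canonicalization_version",
)


def contract_fingerprint(config: dict) -> str:
    """Hash the data and featurization contract, ignoring tunable hyperparameters."""
    contract = {name: config[name] for name in CONTRACT_FIELDS}
    canonical = json.dumps(contract, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
```

A resume can then compare fingerprints instead of diffing whole configs: a learning-rate change keeps the fingerprint, while a tokenizer bump changes it.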

For large molecular corpora, choose one of these paths deliberately.

| Representation | Best for | Hot-path cost | Main risk |
| --- | --- | --- | --- |
| tokenized SMILES | transformer pretraining, contrastive SSL | low after offline tokenization | string grammar may hide 3D chemistry |
| offline graph tensors | GNN training on supervised assays | moderate I/O, low CPU | large serialized tensors can bloat storage |
| on-the-fly graph conversion | rapid experimentation only | high CPU, poor cache locality | worker starvation at scale |
| hybrid token + descriptor | multitask screening and retrieval | moderate | feature drift if offline and online disagree |

For the DDP story here, tokenized SMILES is the clean baseline because it isolates DDP behavior from featurization churn.

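from __future__ import annotations

from rdkit import Chem


def canonicalize_smiles(raw_smiles: str) -> str | None:
    """Return RDKit canonical isomeric SMILES, or None when RDKit cannot parse the input."""
    mol = Chem.MolFromSmiles(raw_smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)


def encode_smiles(smi: str, vocab: dict[str, int], unk_id: int) -> list[int]:
    """Greedy tokenizer: prefer a two-character vocab entry (e.g. "Cl", "Br"), else one character."""
    tokens: list[int] = []
    i = 0
    while i < len(smi):
        pair = smi[i : i + 2]
        if len(pair) == 2 and pair in vocab:
            tokens.append(vocab[pair])
            i += 2
            continue
        tokens.append(vocab.get(smi[i], unk_id))
        i += 1
    return tokens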

That code belongs in a batch ETL job, not in your training worker.
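A greedy one-or-two-character matcher splits bracket atoms such as [nH] or [C@@H] into fragments. A common alternative, swapped in here and not the tokenizer this article assumes, is the regex-based atom-level SMILES tokenization popularized by molecular transformer work, which keeps bracket atoms, two-letter halogens, and %NN ring closures whole. A minimal sketch with a hypothetical `tokenize_smiles` helper:

```python
import re

# Atom-level SMILES pattern: bracket atoms, Br/Cl, organic-subset atoms,
# aromatic atoms, bonds, branches, and one- or two-digit ring closures.
SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def tokenize_smiles(smiles: str) -> list[str]:
    tokens = SMILES_REGEX.findall(smiles)
    # Refuse silently dropped characters: the tokens must rebuild the input exactly.
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens


def encode_tokens(tokens: list[str], vocab: dict[str, int], unk_id: int) -> list[int]:
    return [vocab.get(tok, unk_id) for tok in tokens]
```

The round-trip check is the useful part for an ETL job: rows that fail it can be flagged in quality_flags rather than silently truncated.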

Because DDP does not shard inputs for you, the dataset layer needs a deterministic contract for both assignment and replay.

  • partition shards by rank before reading rows
  • reshuffle manifest order deterministically by (seed, epoch)
  • keep length buckets stable enough that compile and kernel selection are not constantly invalidated
  • checkpoint both data position and logical epoch
  • never advance the global cursor until the current shard has been fully consumed and acknowledged
from __future__ import annotations

import random
from dataclasses import asdict, dataclass

import pyarrow.parquet as pq
import torch
from torch.utils.data import IterableDataset, get_worker_info


@dataclass
class Cursor:
    epoch: int = 0
    shard_offset: int = 0  # index into this reader's shard list
    row_offset: int = 0  # next row to read in that shard


class ShardedMoleculeDataset(IterableDataset):
    def __init__(
        self,
        manifest: list[dict],
        rank: int,
        world_size: int,
        seed: int,
        cursor: Cursor | None = None,
    ) -> None:
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.cursor = cursor or Cursor()

    def state_dict(self) -> dict[str, int]:
        return asdict(self.cursor)

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.cursor = Cursor(**state)

    def set_epoch(self, epoch: int) -> None:
        # Reset offsets only on a real epoch change, so a restored mid-epoch
        # cursor is not wiped when the trainer calls set_epoch() after resume.
        if epoch != self.cursor.epoch:
            self.cursor = Cursor(epoch=epoch)

    def _rank_manifest(self) -> list[dict]:
        rng = random.Random(self.seed + self.cursor.epoch)
        shuffled = list(self.manifest)
        rng.shuffle(shuffled)
        rank_shards = shuffled[self.rank :: self.world_size]
        # Split again across DataLoader workers; otherwise every worker would
        # read all of this rank's shards and duplicate its data.
        info = get_worker_info()
        if info is None:
            return rank_shards
        return rank_shards[info.id :: info.num_workers]

    def __iter__(self):
        # With num_workers > 0, each worker advances its own copy of this cursor;
        # the main process must track consumed positions itself, for example by
        # having each sample carry its shard and row index.
        shards = self._rank_manifest()
        start = self.cursor.shard_offset
        for shard_idx in range(start, len(shards)):
            table = pq.read_table(shards[shard_idx]["uri"], columns=["token_ids", "token_length", "labels"])
            first_row = self.cursor.row_offset if shard_idx == start else 0
            for row_idx in range(first_row, table.num_rows):
                # Record the *next* row to read, so a resume does not replay this one.
                self.cursor.shard_offset = shard_idx
                self.cursor.row_offset = row_idx + 1
                yield {
                    "input_ids": torch.tensor(table["token_ids"][row_idx].as_py(), dtype=torch.long),
                    "labels": torch.tensor(table["labels"][row_idx].as_py(), dtype=torch.float32),
                    "length": int(table["token_length"][row_idx].as_py()),
                }
            # Only a fully consumed shard advances the shard offset.
            self.cursor = Cursor(epoch=self.cursor.epoch, shard_offset=shard_idx + 1)

In production you would avoid read_table() on a whole shard; streaming it row group by row group with pyarrow's ParquetFile.iter_batches() or read_row_group() keeps worker memory bounded. The shape of the contract is the important part: the dataset has explicit replay state.
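A rank-level slice alone is not enough once DataLoader workers are involved: each worker process iterates its own copy of an IterableDataset, so a rank's shard list must be split again by worker id (available from torch.utils.data.get_worker_info()). A pure-Python sketch of the two-level assignment, with a hypothetical `shards_for` helper and a string-seeded RNG keyed on (seed, epoch):

```python
import random


def shards_for(
    manifest_ids: list[str],
    seed: int,
    epoch: int,
    rank: int,
    world_size: int,
    worker_id: int,
    num_workers: int,
) -> list[str]:
    """Deterministic assignment: epoch shuffle, then rank slice, then worker slice."""
    order = sorted(manifest_ids)  # identical starting order on every rank
    # String seeds are hashed deterministically, so (seed, epoch) pairs never collide.
    random.Random(f"{seed}:{epoch}").shuffle(order)
    rank_shards = order[rank::world_size]
    return rank_shards[worker_id::num_workers]
```

Shard counts rarely divide evenly across ranks, so ranks can finish an epoch after different step counts; the trainer has to make every rank stop together (for example by all-reducing an "out of data" flag), or the remaining ranks block in DDP's gradient all-reduce.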

Large molecular corpora usually have a long sequence-length tail. Padding everything to a global maximum is a direct tax on throughput.

  • pre-bucket shards or records by token length
  • form batches from a narrow length range
  • cap total tokens per batch, not just examples per batch
  • record padding ratio as a first-class metric
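import torch
from torch.nn.utils.rnn import pad_sequence


def collate_molecules(batch: list[dict]) -> dict[str, torch.Tensor]:
    # Sort longest-first; labels move with their examples since both come from the same item.
    batch = sorted(batch, key=lambda item: int(item["length"]), reverse=True)
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=0,
    )
    lengths = torch.tensor([int(item["length"]) for item in batch], dtype=torch.long)
    # Derive the mask from lengths rather than input_ids.ne(0), so it stays
    # correct even if token id 0 is not reserved for padding.
    attention_mask = torch.arange(input_ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    labels = torch.stack([item["labels"] for item in batch])
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }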

This improves three things at once:

  • fewer wasted FLOPs on padding
  • more stable shapes for torch.compile
  • easier reasoning about tokens/sec instead of just examples/sec
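The bucketing bullets can be made concrete with a token-budget batcher. A pure-Python sketch, assuming lengths are already known (for example from the token_length column) and that batches are formed from a locally sorted buffer; `token_budget_batches` and `padding_ratio` are hypothetical names, and a single example longer than the budget still gets its own batch:

```python
def token_budget_batches(lengths: list[int], max_tokens: int) -> list[list[int]]:
    """Group example indices so padded size (batch_size * max_len) stays within max_tokens."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches: list[list[int]] = []
    current: list[int] = []
    current_max = 0
    for idx in order:
        new_max = max(current_max, lengths[idx])
        # Close the batch if adding this example would exceed the padded-token budget.
        if current and new_max * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, new_max = [], lengths[idx]
        current.append(idx)
        current_max = new_max
    if current:
        batches.append(current)
    return batches


def padding_ratio(lengths: list[int], batches: list[list[int]]) -> float:
    """Fraction of padded positions that are padding rather than real tokens."""
    padded = sum(max(lengths[i] for i in b) * len(b) for b in batches)
    return 1.0 - sum(lengths) / padded
```

With lengths [10, 12, 30, 31, 11, 29] and a 40-token budget, this yields [[0, 4, 1], [5], [2], [3]] with a padding ratio of 3/126, versus 63/186 if all six were padded to 31 in one batch.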

If the model fits, the data plane is often the bottleneck. But there are still a few DDP settings worth being explicit about.

| Setting | Why use it | When to avoid it |
| --- | --- | --- |
| gradient_as_bucket_view=True | lower peak memory and fewer gradient copies | if code calls detach_() on gradients or otherwise assumes they are standalone tensors |
| static_graph=True | cheaper reducer bookkeeping on stable graphs | if control flow or parameter usage changes per batch |
| bucket_cap_mb tuning | better overlap between backward compute and all-reduce | if you have not yet measured communication as a bottleneck |
| broadcast_buffers=False | skips the per-forward buffer broadcast for models without BatchNorm-style running statistics | if buffer state must stay synchronized every forward |
| no_sync() | gradient accumulation without an all-reduce on every micro-step | if you do not actually intend to accumulate gradients across micro-batches |
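from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP


def train_epoch(model, loader, optimizer, scaler: GradScaler, cfg) -> None:
    # cfg is assumed to expose grad_accum_steps, device, and use_amp.
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, batch in enumerate(loader):
        at_boundary = (step + 1) % cfg.grad_accum_steps == 0
        # Skip the gradient all-reduce on every micro-step except the last one in
        # the accumulation window; no_sync() only exists on the DDP wrapper.
        sync_ctx = model.no_sync() if isinstance(model, DDP) and not at_boundary else nullcontext()
        with sync_ctx:
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=cfg.use_amp):
                outputs = model(
                    input_ids=batch["input_ids"].to(cfg.device, non_blocking=True),
                    attention_mask=batch["attention_mask"].to(cfg.device, non_blocking=True),
                )
                # Multi-label BCE, as in the end-to-end script.
                loss = F.binary_cross_entropy_with_logits(
                    outputs, batch["labels"].to(cfg.device, non_blocking=True)
                )
                # Average over the window so one step matches one larger batch.
                loss = loss / cfg.grad_accum_steps
            # backward() must run inside no_sync() for the skip to take effect.
            scaler.scale(loss).backward()
        if at_boundary:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)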

The talk track for this code is:

  • one process per GPU
  • deterministic data ownership outside DDP
  • accumulation reduces sync frequency
  • AMP improves throughput if the numerics are stable
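The `loss / cfg.grad_accum_steps` line is what makes accumulation reproduce a single large batch. A minimal exact-arithmetic check on a one-parameter least-squares model (an illustrative stand-in for the real loss), using Python's Fraction to avoid rounding noise; `grad_mse` and `accumulated_grad` are hypothetical helpers:

```python
from fractions import Fraction


def grad_mse(w: Fraction, xs: list[Fraction], ys: list[Fraction]) -> Fraction:
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: Fraction, xs: list[Fraction], ys: list[Fraction], accum_steps: int) -> Fraction:
    """Sum of per-micro-batch gradients, each scaled by 1 / accum_steps."""
    if len(xs) % accum_steps != 0:
        raise ValueError("this identity needs equal-sized micro-batches")
    size = len(xs) // accum_steps
    total = Fraction(0)
    for k in range(accum_steps):
        chunk = slice(k * size, (k + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk]) / accum_steps
    return total
```

The identity holds only for equal-sized micro-batches. With token-budget batching, micro-batch sizes vary, so dividing by the step count weights short-sequence batches differently from one big batch; normalizing by total tokens or examples in the window is the usual fix.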

At this scale, exact replay is not free. The honest answer is to define what level of reproducibility you are promising.

| Tier | What you can usually guarantee | Where it fits |
| --- | --- | --- |
| exact single-host replay | same code, same checkpoint, same shard order, same seeds | easiest in controlled debugging environments |
| topology-stable distributed replay | same world size and deterministic shard assignment | practical for production incident analysis |
| cross-topology statistical replay | similar metric curves after resume on a different topology | more realistic than bitwise equivalence at cluster scale |

Whatever tier you promise, record this lineage with every run:

  • code revision or image digest
  • config payload and derived batch math
  • manifest version and shard checksums
  • tokenizer and canonicalization versions
  • RNG seeds for Python, NumPy, and PyTorch
  • sampler or dataset cursor state

The honest reproducibility statement:

  • “I do not promise bitwise identity across different cluster topologies unless the platform is designed around that constraint.”
  • “I do promise that a run can be explained from a stable config, data manifest, featurization version, and checkpoint lineage.”
flowchart TD
  A[Checkpoint load request] --> B{Same manifest and featurization versions?}
  B -->|No| C[Reject resume or fork new run]
  B -->|Yes| D{Same topology?}
  D -->|Yes| E[Topology-stable replay]
  D -->|No| F[Statistical replay only]
  E --> G[Continue training]
  F --> G
The practical reproducibility decision is whether a restart is a strict replay or a statistically equivalent continuation.
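The flowchart's decision can be encoded as a small pure function run before any state is loaded. A minimal sketch, assuming the contract is compared field by field and that topology means both world size and DataLoader worker count (since per-worker shard assignment depends on the latter); all names are hypothetical:

```python
from dataclasses import dataclass
from enum import Enum


class ResumeMode(Enum):
    REJECT = "reject_or_fork"
    TOPOLOGY_STABLE = "topology_stable_replay"
    STATISTICAL = "statistical_replay"


@dataclass(frozen=True)
class RunContract:
    manifest_version: str
    tokenizer_version: str
    canonicalization_version: str
    world_size: int
    num_workers: int


def resume_mode(checkpoint: RunContract, current: RunContract) -> ResumeMode:
    same_data = (
        checkpoint.manifest_version == current.manifest_version
        and checkpoint.tokenizer_version == current.tokenizer_version
        and checkpoint.canonicalization_version == current.canonicalization_version
    )
    if not same_data:
        return ResumeMode.REJECT  # different corpus or features: fork a new run instead
    same_topology = (checkpoint.world_size, checkpoint.num_workers) == (
        current.world_size,
        current.num_workers,
    )
    # Cursors only map onto shards under the same topology; otherwise restart the epoch order.
    return ResumeMode.TOPOLOGY_STABLE if same_topology else ResumeMode.STATISTICAL
```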

torch.compile can help, but only after the shapes are reasonably well-behaved.

For molecular training that usually means:

  • bucket lengths aggressively enough that shape churn is bounded
  • keep the model graph static from batch to batch
  • compile after the model is wrapped and stable in its final execution mode
  • measure compile cache churn if you allow many different sequence lengths

If every batch has a radically different padded length, compile may spend more time specializing than accelerating.
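One way to bound shape churn is to round each batch's padded length up to a coarse grid, so compile only ever sees a handful of sequence lengths. A minimal sketch, assuming a 32-token grid and a 256-token cap (illustrative values) and a hypothetical `bucketed_length` helper:

```python
def bucketed_length(length: int, multiple: int = 32, max_length: int = 256) -> int:
    """Round a batch's longest sequence up to a fixed grid so compile sees few shapes."""
    if length <= 0:
        raise ValueError("length must be positive")
    rounded = -(-length // multiple) * multiple  # ceiling division
    # Sequences beyond max_length must already be truncated upstream.
    return min(rounded, max_length)
```

Every length from 1 to 256 then maps to one of 8 padded shapes. The collate function would pad to `bucketed_length(longest)` (for example with torch.nn.functional.pad on the right of the sequence dimension), trading a little extra padding for far fewer recompilations; torch.compile's `dynamic=True` option is the other lever when a grid is too coarse.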

For billion-sample jobs, “save model and optimizer” is not enough.

A resumable checkpoint should capture:

  • model weights
  • optimizer state
  • scaler state
  • logical epoch
  • dataset cursor or sampler state
  • RNG state (not only seeds) if stochastic masking or sampling affects training semantics, because seeds alone cannot restore a mid-run position
  • manifest version so resume does not silently reload a different corpus layout

Operationally:

  • use distributed checkpointing when save time becomes a bottleneck
  • stage checkpoint writes asynchronously only if you can observe backlog growth
  • checkpoint on step counts, not just epoch boundaries, because an epoch may be many hours long
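A checkpoint that can be half-written is worse than none, because resume logic will happily pick it up. A minimal write-to-temp-then-rename sketch for a POSIX filesystem; on object stores, which have no atomic rename, the usual substitute is publishing a small manifest only after the payload upload succeeds. The helper names are hypothetical, and torch.save can target an io.BytesIO buffer to produce the bytes:

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(path: Path, payload: bytes) -> None:
    """Write to a temp file in the same directory, fsync, then rename into place."""
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), prefix=path.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(payload)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_name, path)  # atomic within one POSIX filesystem
    except BaseException:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


def write_checkpoint_metadata(ckpt_dir: Path, step: int, metadata: dict) -> Path:
    """Zero-padded step numbers keep lexicographic and numeric order identical."""
    path = ckpt_dir / f"step-{step:08d}.json"
    atomic_write_bytes(path, json.dumps(metadata, sort_keys=True).encode("utf-8"))
    return path
```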

A billion-scale training run is useless if evaluation is optimistic or too coarse to catch chemistry-specific failure.

  • split by scaffold family for the default offline benchmark
  • add a time-based or prospective split if the interview context includes discovery over time
  • keep dedupe rules shared between train and eval data prep
  • version the evaluation manifest separately from the training manifest
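A scaffold split stays stable across reruns and re-shardings if the split is a pure function of the scaffold string. A minimal hash-based sketch, assuming a versioned salt stored with the evaluation manifest and an illustrative 5% eval fraction; `scaffold_split` is a hypothetical helper, and note the fraction is by scaffold family, so a few very large families can skew the molecule-level ratio:

```python
import hashlib


def scaffold_split(murcko_scaffold: str, eval_fraction: float = 0.05, salt: str = "split-v1") -> str:
    """Send every molecule that shares a scaffold to the same split, deterministically."""
    # Acyclic molecules share the empty scaffold and therefore land together.
    digest = hashlib.sha256(f"{salt}:{murcko_scaffold}".encode("utf-8")).digest()
    # Keep 53 bits so the bucket is an exact float in [0, 1).
    bucket = (int.from_bytes(digest[:8], "big") >> 11) / float(1 << 53)
    return "eval" if bucket < eval_fraction else "train"
```

Changing the salt reshuffles families between splits, so it belongs in the evaluation manifest's version rather than in code defaults.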
| Task shape | Better metric set | Why accuracy is weak |
| --- | --- | --- |
| binary hit classification | AUROC, AUPRC, BEDROC, EF@1% | heavy class imbalance and top-of-list economics |
| multi-task sparse assays | macro and per-task AUROC with missing-label masks | label sparsity hides failure if you collapse to one number |
| regression on potency | RMSE, Spearman, calibration by assay family | ranking and calibration often matter more than mean error |
| retrieval or similarity | recall@K, nearest-neighbor scaffold novelty | average loss says little about downstream screening quality |

Evaluation cadence:
  • run lightweight rank-0 validation every fixed number of steps
  • run heavier scaffold-stratified or assay-family slices less frequently
  • keep a small deterministic smoke-eval set for every checkpointable stage
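EF@1% from the metric table is simple enough to compute inline in a smoke eval. A minimal sketch of the standard definition (hit rate among the top-scored fraction divided by the overall hit rate); ties are broken by input order, which is an assumption, and BEDROC is not covered here:

```python
import math


def enrichment_factor(scores: list[float], labels: list[int], fraction: float = 0.01) -> float:
    """EF@fraction: hit rate in the top-scored fraction divided by the overall hit rate."""
    if not scores or len(scores) != len(labels):
        raise ValueError("scores and labels must be non-empty and aligned")
    total_actives = sum(labels)
    if total_actives == 0:
        raise ValueError("EF is undefined without actives")
    n_top = max(1, math.ceil(fraction * len(scores)))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top_hits = sum(labels[i] for i in ranked[:n_top])
    return (top_hits / n_top) / (total_actives / len(scores))
```

With 200 compounds, 10 actives, and both top-1% slots filled by actives, EF@1% is 20: the model's shortlist is twenty times richer than random picking, which is the quantity screening teams actually pay for.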
flowchart LR
  A[Train manifest] --> B[Scaffold split policy]
  B --> C[Versioned train shards]
  B --> D[Versioned eval shards]
  C --> E[DDP training]
  D --> F[Smoke eval]
  D --> G[Scaffold eval]
  D --> H[Assay-family slices]
  E --> I[Checkpoint]
  I --> F
  I --> G
  I --> H
A strong evaluation story keeps scaffold-aware split policy separate from the training loop and tied to checkpointed model states.

When the interviewer asks “how do you know it is working?”, the stronger answer is layered:

  1. training health: tokens/sec, padding ratio, all-reduce fraction, invalid-row rate
  2. optimization health: loss, gradient norm, learning-rate schedule, divergence checks
  3. scientific usefulness: scaffold-split AUROC, BEDROC, EF@1%, assay-family calibration

That framing keeps system health separate from scientific success.

flowchart TD
  A[Run health] --> B[Data-plane metrics]
  A --> C[Optimization metrics]
  A --> D[Scientific metrics]
  B --> E[tokens/sec, padding, invalid rows]
  C --> F[loss, grad norm, lr, divergence]
  D --> G[AUROC, BEDROC, EF@1%, calibration]
Interview answers get stronger when metrics are grouped by operational layer instead of recited as one flat list.
| Metric | Why it matters |
| --- | --- |
| tokens/sec per rank | best direct throughput measure for variable-length sequence input |
| padding ratio | tells you whether featurization and bucketing are wasting compute |
| invalid SMILES rejection rate | catches upstream ingest drift |
| per-rank shard lag | highlights storage or worker skew |
| all-reduce time / step time | separates comm bottlenecks from input starvation |
| queue depth before H2D | tells you whether the GPU is waiting on the loader |
| scaffold coverage by split | proves your evaluation policy is still being enforced |
| eval lag in steps | tells you when validation results are too stale to trust decisions |
flowchart TD
  A[Training slows or diverges] --> B{First bad signal}
  B --> C[Padding ratio spikes]
  B --> D[Invalid rows rise]
  B --> E[All-reduce tail grows]
  B --> F[Rank skew appears]
  C --> G[Length buckets too wide or token stats drifted]
  D --> H[Upstream canonicalization regression]
  E --> I[Communication or batch-size issue]
  F --> J[Object-store hotspot or corrupt shard]
Most large molecular jobs fail through data inconsistency or skew before they fail through pure model math.

Move away from plain DDP only when the bottleneck is truly model-state related.

| Symptom | Better next step |
| --- | --- |
| model barely fits | activation checkpointing, sequence packing |
| optimizer state dominates memory | ZeroRedundancyOptimizer or FSDP |
| single-rank memory remains too large | FSDP or tensor-parallel design |
| communication dominates despite tuning | revisit topology or hybrid parallelism |

Until then, keep the system boring and spend your sophistication on data correctness and pipeline throughput.
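A back-of-envelope model-state estimate shows when the "optimizer state dominates memory" row applies. A minimal sketch using the common rule of thumb of 16 bytes per parameter for AdamW with fp32 weights (4 bytes weights, 4 gradients, 8 for the two moment buffers); autocast does not change parameter or optimizer-state dtypes, and activations and DDP communication buckets are excluded, so treat the numbers as a floor:

```python
def training_state_gb(num_params: int, bytes_per_param: int = 16) -> float:
    """Per-rank model-state memory for plain DDP + AdamW, excluding activations."""
    return num_params * bytes_per_param / 1e9


def zero_sharded_state_gb(num_params: int, world_size: int) -> float:
    """Same estimate when only the 8 bytes/param of Adam moments are sharded across ranks."""
    return num_params * (4 + 4 + 8 / world_size) / 1e9
```

A 1-billion-parameter encoder needs about 16 GB of model state per rank under plain DDP, and about 9 GB with optimizer state sharded across 8 ranks (the ZeroRedundancyOptimizer row); if the unsharded number already fits with room for activations, plain DDP remains the right answer.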

This is the “read it top to bottom” version of the article: config, distributed init, iterable dataset, collation, model, checkpoint state, training, and evaluation in one script.

from __future__ import annotations

import json
import os
import random
import time
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


@dataclass
class TrainConfig:
    backend: str = "nccl"
    seed: int = 17
    world_size: int = int(os.environ.get("WORLD_SIZE", "1"))
    rank: int = int(os.environ.get("RANK", "0"))
    local_rank: int = int(os.environ.get("LOCAL_RANK", "0"))
    micro_batch_size: int = 32
    grad_accum_steps: int = 4
    num_workers: int = 4
    max_epochs: int = 2
    max_steps: int = 20_000
    eval_interval_steps: int = 1_000
    checkpoint_interval_steps: int = 1_000
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    max_length: int = 256
    vocab_size: int = 512
    hidden_size: int = 768
    num_layers: int = 8
    num_heads: int = 12
    num_labels: int = 64
    use_amp: bool = True
    manifest_path: str = "manifests/train_manifest.json"
    eval_manifest_path: str = "manifests/eval_manifest.json"
    checkpoint_dir: str = "artifacts/checkpoints/smiles"
    run_name: str = "billion-smiles-ddp"
    tokenizer_version: str = "smiles-bpe-v3"
    canonicalization_version: str = "rdkit-2025-iso"

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_rank(self) -> bool:
        return self.rank == 0

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{self.local_rank}")
        return torch.device("cpu")

    @property
    def global_batch_size(self) -> int:
        return self.micro_batch_size * self.grad_accum_steps * self.world_size


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> None:
    if not cfg.is_distributed:
        return
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.local_rank)
    # NCCL requires GPUs; fall back to Gloo so the script still runs on CPU-only hosts.
    backend = cfg.backend if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=cfg.rank, world_size=cfg.world_size)


def cleanup_distributed() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def load_manifest(path: str) -> list[dict]:
    return json.loads(Path(path).read_text())


class ShardedMoleculeDataset(IterableDataset):
    """Rank- and worker-aware shard reader with explicit replay state.

    DataLoader workers iterate private copies of this object, so iteration never
    advances the cursors. Each sample carries its (worker, shard, row) position
    instead, and the trainer calls mark_consumed() in the main process for the
    batches it has actually trained on.
    """

    def __init__(
        self,
        manifest: list[dict],
        rank: int,
        world_size: int,
        seed: int,
        num_workers: int,
        max_length: int,
    ) -> None:
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.num_workers = max(1, num_workers)  # num_workers=0 means one in-process reader
        self.max_length = max_length
        self.epoch = 0
        # worker_id -> (position in this rank's shard list, next row to read)
        self.worker_cursors: dict[int, tuple[int, int]] = {}

    def state_dict(self) -> dict:
        return {
            "epoch": self.epoch,
            "num_workers": self.num_workers,
            "worker_cursors": {str(w): list(pos) for w, pos in self.worker_cursors.items()},
        }

    def load_state_dict(self, state: dict) -> None:
        self.epoch = int(state["epoch"])
        if int(state["num_workers"]) != self.num_workers:
            # Shard-to-worker mapping changed: restart this epoch's data order.
            self.worker_cursors = {}
            return
        self.worker_cursors = {
            int(w): (int(pos[0]), int(pos[1])) for w, pos in state["worker_cursors"].items()
        }

    def set_epoch(self, epoch: int) -> None:
        # Reset only on a real epoch change, so a restored mid-epoch cursor survives.
        if epoch != self.epoch:
            self.epoch = epoch
            self.worker_cursors = {}

    def mark_consumed(self, worker_id: int, shard_pos: int, next_row: int) -> None:
        previous = self.worker_cursors.get(worker_id, (-1, 0))
        self.worker_cursors[worker_id] = max(previous, (shard_pos, next_row))

    def _rank_shards(self) -> list[dict]:
        shards = list(self.manifest)
        random.Random(self.seed + self.epoch).shuffle(shards)
        return shards[self.rank :: self.world_size]

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        shards = self._rank_shards()
        # Each worker takes every num_workers-th shard of the rank's list, so no shard is read twice.
        start_pos, start_row = self.worker_cursors.get(worker_id, (worker_id, 0))
        for shard_pos in range(start_pos, len(shards), num_workers):
            table = pq.read_table(
                shards[shard_pos]["uri"],
                columns=["token_ids", "labels", "compound_id"],
            )
            first_row = start_row if shard_pos == start_pos else 0
            for row_idx in range(first_row, table.num_rows):
                # Truncate so positions never exceed the model's positional table.
                token_ids = table["token_ids"][row_idx].as_py()[: self.max_length]
                yield {
                    "compound_id": table["compound_id"][row_idx].as_py(),
                    "input_ids": torch.tensor(token_ids, dtype=torch.long),
                    "length": len(token_ids),
                    "labels": torch.tensor(table["labels"][row_idx].as_py(), dtype=torch.float32),
                    "worker_id": worker_id,
                    "shard_pos": shard_pos,
                    "next_row": row_idx + 1,
                }


def collate_molecules(batch: list[dict]) -> dict:
    # Capture the stream position before sorting. With an IterableDataset each
    # batch is built inside one worker, so its last yielded sample marks how far
    # this batch consumed.
    last = batch[-1]
    batch = sorted(batch, key=lambda item: item["length"], reverse=True)
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=0,
    )
    lengths = torch.tensor([item["length"] for item in batch], dtype=torch.long)
    # Build the mask from lengths so a real token id of 0 is never treated as padding.
    attention_mask = torch.arange(input_ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    labels = torch.stack([item["labels"] for item in batch])
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "lengths": lengths,
        "position": (last["worker_id"], last["shard_pos"], last["next_row"]),
    }


class MoleculeEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        num_heads: int,
        num_labels: int,
        max_positions: int,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        # Sized from the config; the dataset truncates sequences to the same limit.
        self.pos_emb = nn.Embedding(max_positions, hidden_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 4,
            batch_first=True,
            norm_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.encoder(x, src_key_padding_mask=~attention_mask)
        # Mean-pool over real tokens only.
        pooled = (x * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp_min(1)
        return self.head(pooled)


def make_loader(dataset: IterableDataset, cfg: TrainConfig, training: bool) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=cfg.micro_batch_size,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
        # Non-persistent workers re-copy the dataset on every iter(loader), so
        # set_epoch() and restored cursors actually reach the worker processes.
        persistent_workers=False,
        collate_fn=collate_molecules,
        drop_last=training,
    )


def unwrap(model: nn.Module) -> nn.Module:
    model = getattr(model, "_orig_mod", model)  # strip the torch.compile wrapper
    return model.module if isinstance(model, DDP) else model


def all_ranks_have_batch(has_batch: bool, cfg: TrainConfig) -> bool:
    """Stop every rank together once any rank runs out of data.

    Ranks rarely own identical row counts; a rank that leaves the loop early
    would leave the others blocked in DDP's gradient all-reduce. Rows beyond the
    shortest rank are skipped for this epoch.
    """
    if not cfg.is_distributed:
        return has_batch
    flag = torch.tensor([int(has_batch)], device=cfg.device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ShardedMoleculeDataset,
    cfg: TrainConfig,
    epoch: int,
    global_step: int,
) -> None:
    # Every rank owns a different cursor, so gather all of them before rank 0 writes.
    dataset_states: list = [dataset.state_dict()]
    if cfg.is_distributed:
        dataset_states = [None] * cfg.world_size
        dist.all_gather_object(dataset_states, dataset.state_dict())
    if cfg.is_main_rank:
        checkpoint_dir = Path(cfg.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        payload = {
            "model": unwrap(model).state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "dataset_states": dataset_states,
            "epoch": epoch,
            "global_step": global_step,
            "config": asdict(cfg),
        }
        path = checkpoint_dir / f"step-{global_step:08d}.pt"
        tmp_path = checkpoint_dir / f"step-{global_step:08d}.pt.tmp"
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)  # readers never see a half-written checkpoint
    if cfg.is_distributed:
        dist.barrier()


def restore_if_available(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ShardedMoleculeDataset,
    cfg: TrainConfig,
) -> tuple[int, int]:
    # Assumes checkpoint_dir is a shared filesystem visible to every rank.
    checkpoint_dir = Path(cfg.checkpoint_dir)
    checkpoints = sorted(checkpoint_dir.glob("step-*.pt")) if checkpoint_dir.exists() else []
    if not checkpoints:
        return 0, 0
    payload = torch.load(checkpoints[-1], map_location="cpu")
    saved = payload["config"]
    # Refuse to resume across a changed data or featurization contract.
    for key in ("manifest_path", "tokenizer_version", "canonicalization_version"):
        if saved.get(key) != getattr(cfg, key):
            raise RuntimeError(f"refusing to resume: {key} changed ({saved.get(key)!r} -> {getattr(cfg, key)!r})")
    unwrap(model).load_state_dict(payload["model"])
    optimizer.load_state_dict(payload["optimizer"])
    scaler.load_state_dict(payload["scaler"])
    states = payload["dataset_states"]
    if len(states) == cfg.world_size:
        dataset.load_state_dict(states[cfg.rank])  # topology-stable replay
    else:
        dataset.set_epoch(int(payload["epoch"]))  # world size changed: statistical replay
    return int(payload["epoch"]), int(payload["global_step"])


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, cfg: TrainConfig) -> dict[str, float]:
    model.eval()
    loss_sum = torch.zeros(1, device=cfg.device)
    batch_count = torch.zeros(1, device=cfg.device)
    for batch in loader:
        logits = model(
            input_ids=batch["input_ids"].to(cfg.device, non_blocking=True),
            attention_mask=batch["attention_mask"].to(cfg.device, non_blocking=True),
        )
        labels = batch["labels"].to(cfg.device, non_blocking=True)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        loss_sum += loss.detach()
        batch_count += 1
    # Collectives run after the loop, so uneven eval shards cannot deadlock ranks.
    if cfg.is_distributed:
        dist.all_reduce(loss_sum)
        dist.all_reduce(batch_count)
    return {"eval_loss": (loss_sum / batch_count.clamp_min(1)).item()}


def train(cfg: TrainConfig) -> None:
    seed_everything(cfg.seed)
    setup_distributed(cfg)
    train_dataset = ShardedMoleculeDataset(
        manifest=load_manifest(cfg.manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed,
        num_workers=cfg.num_workers,
        max_length=cfg.max_length,
    )
    eval_dataset = ShardedMoleculeDataset(
        manifest=load_manifest(cfg.eval_manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed + 1000,
        num_workers=cfg.num_workers,
        max_length=cfg.max_length,
    )
    train_loader = make_loader(train_dataset, cfg, training=True)
    eval_loader = make_loader(eval_dataset, cfg, training=False)
    model: nn.Module = MoleculeEncoder(
        vocab_size=cfg.vocab_size,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        num_labels=cfg.num_labels,
        max_positions=cfg.max_length,
    ).to(cfg.device)
    ddp_model = None
    if cfg.is_distributed:
        ddp_model = DDP(
            model,
            device_ids=[cfg.local_rank] if torch.cuda.is_available() else None,
            output_device=cfg.local_rank if torch.cuda.is_available() else None,
            gradient_as_bucket_view=True,
            static_graph=True,
            broadcast_buffers=False,
        )
        model = ddp_model
    # Compile last, once the model is wrapped and in its final execution mode.
    if hasattr(torch, "compile"):
        model = torch.compile(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    use_autocast = cfg.use_amp and torch.cuda.is_available()
    # Loss scaling guards float16 underflow; with bfloat16 autocast it is optional but harmless.
    scaler = GradScaler("cuda", enabled=use_autocast)
    start_epoch, global_step = restore_if_available(model, optimizer, scaler, train_dataset, cfg)
    for epoch in range(start_epoch, cfg.max_epochs):
        train_dataset.set_epoch(epoch)
        model.train()
        optimizer.zero_grad(set_to_none=True)
        train_iter = iter(train_loader)
        micro_step = 0
        window_start, window_tokens = time.perf_counter(), 0
        while True:
            batch = next(train_iter, None)
            if not all_ranks_have_batch(batch is not None, cfg):
                break
            micro_step += 1
            at_boundary = micro_step % cfg.grad_accum_steps == 0
            # Skip the gradient all-reduce on every micro-step except the window's last.
            sync_ctx = ddp_model.no_sync() if ddp_model is not None and not at_boundary else nullcontext()
            with sync_ctx:
                with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_autocast):
                    logits = model(
                        input_ids=batch["input_ids"].to(cfg.device, non_blocking=True),
                        attention_mask=batch["attention_mask"].to(cfg.device, non_blocking=True),
                    )
                    labels = batch["labels"].to(cfg.device, non_blocking=True)
                    loss = F.binary_cross_entropy_with_logits(logits, labels)
                    loss = loss / cfg.grad_accum_steps
                scaler.scale(loss).backward()
            # Advance replay state only for data that has actually been trained on.
            train_dataset.mark_consumed(*batch["position"])
            window_tokens += int(batch["lengths"].sum())
            if not at_boundary:
                continue
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1
            if cfg.is_main_rank and global_step % 50 == 0:
                loss_value = loss.item() * cfg.grad_accum_steps  # .item() also waits for the GPU
                elapsed = time.perf_counter() - window_start
                print(
                    json.dumps(
                        {
                            "step": global_step,
                            "loss": round(loss_value, 5),
                            # Rank-0 tokens scaled by world size: an estimate, not an all-reduced count.
                            "tokens_per_sec": round(window_tokens * cfg.world_size / max(elapsed, 1e-6), 2),
                            "padding_ratio": round(1 - batch["attention_mask"].float().mean().item(), 4),
                        }
                    )
                )
            if global_step % cfg.eval_interval_steps == 0:
                metrics = evaluate(model, eval_loader, cfg)
                if cfg.is_main_rank:
                    print(json.dumps({"step": global_step, **metrics}))
                model.train()
            done = global_step >= cfg.max_steps
            if done or global_step % cfg.checkpoint_interval_steps == 0:
                save_checkpoint(model, optimizer, scaler, train_dataset, cfg, epoch, global_step)
            if done:
                cleanup_distributed()
                return
            window_start, window_tokens = time.perf_counter(), 0
    save_checkpoint(model, optimizer, scaler, train_dataset, cfg, cfg.max_epochs - 1, global_step)
    cleanup_distributed()


if __name__ == "__main__":
    train(TrainConfig())

If you have thirty seconds to summarize the design:

  1. Canonicalize, dedupe, scaffold-tag, and tokenize molecules offline.
  2. Train from immutable length-bucketed shards with deterministic rank assignment.
  3. Use DDP because the model fits; optimize the data plane before escalating parallelism complexity.
  4. Track tokens/sec, padding, shard lag, and invalid-row rate.
  5. Checkpoint dataset cursor state, not just weights, or resume will be wrong.