Billion-Scale SMILES Training With DDP
This page is a concrete scale-up story: not “what is DDP?” in the abstract, but how you would actually use PyTorch DistributedDataParallel for a huge molecular training job where the data volume, featurization cost, and resume semantics dominate the design.
The example below is intentionally hypothetical. Think of it as a production-shaped reference scenario you can explain in an interview or use to structure a real design.
Reference Scenario
Assume a pretraining and weak-supervision corpus with:
- 4.2 billion unique compounds
- raw source rows containing compound_id, vendor metadata, smiles, assay identifiers, and sparse labels
- approximately 30 TB of compressed tabular source data before featurization
- a transformer-sized molecular encoder that still fits per rank, so DDP is the simplest correct parallelism choice
Example dataset schema
| Column | Type | Why it exists |
|---|---|---|
| compound_id | string | Stable join key across ingest, featurization, and training |
| raw_smiles | string | Original vendor or registry representation |
| canonical_smiles | string | Deterministic representation used for dedupe and tokenizing |
| murcko_scaffold | string | Split and evaluation grouping |
| assay_ids | list[string] | Sparse multi-task supervision |
| labels | list[float] | Weak labels or regression targets |
| token_ids | list[uint16] | Offline tokenized molecular sequence |
| token_length | int | Bucketing and padding control |
| quality_flags | bitfield | Invalid chemistry, salt issues, or missing labels |
| source_partition | string | Provenance and replay/debug support |
Why DDP First
If the model replica fits comfortably on each GPU, DDP is usually still the right first answer even at billion-sample scale.
- Model state is replicated, which simplifies recovery and debug.
- Gradient synchronization happens at well-defined boundaries.
- Input partitioning stays under your control instead of being hidden behind a more complex sharding wrapper.
- The dominant engineering risk moves to the data plane, which is exactly where molecular systems usually fail first.
The wrong instinct is to see “billions of molecules” and jump directly to FSDP. Dataset scale and model scale are different problems.
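A back-of-the-envelope estimate of per-rank model state is often enough to make that distinction concrete. The sketch below assumes fp32 weights and gradients plus an Adam-style optimizer with two fp32 moment buffers, so 16 bytes per parameter. The function name and the parameter counts are illustrative, and activations are deliberately left out.

```python
def ddp_model_state_gib(num_params: int, bytes_per_param: int = 16) -> float:
    """Replicated model state per rank under DDP, in GiB.

    Assumed layout: 4 B weights + 4 B gradients + 8 B Adam moments per
    parameter. Activations and DDP communication buckets are not counted.
    """
    return num_params * bytes_per_param / 2**30


# Dataset size never appears in the formula; only parameter count does.
for name, n in [("100M-param encoder", 100_000_000), ("1B-param encoder", 1_000_000_000)]:
    print(f"{name}: {ddp_model_state_gib(n):.1f} GiB of model state per rank")
```

If that number plus activations fits on one GPU with headroom, the corpus size alone is not a reason to leave DDP.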
End-To-End System Shape
flowchart LR
A[Raw SMILES tables in object storage] --> B[Offline canonicalization + dedupe]
B --> C[Offline tokenization + scaffold assignment]
C --> D[Shard manifest and length buckets]
D --> E[Rank-aware iterable dataset]
E --> F[DataLoader workers]
F --> G[Pinned host buffers]
G --> H[DDP trainer rank 0..N]
H --> I[Distributed checkpoint]
H --> J[Metrics: tokens/sec, padding, invalid rows]
What Must Happen Offline
Do not put heavy chemistry work in __getitem__.
| Step | Offline or online | Reason |
|---|---|---|
| SMILES validity check | offline | RDKit parse failures in worker hot paths destroy throughput |
| Canonicalization and dedupe | offline | Must be deterministic before split and sharding |
| Murcko scaffold assignment | offline | Evaluation correctness depends on it |
| Tokenization to integer ids | offline | CPU cost is too high to repeat every epoch |
| Simple padding and collation | online | Depends on current batch composition |
| Lightweight masking | online | Cheap, batch-specific, and easy to keep deterministic |
| Stochastic augmentations | online, minimal | Use sparingly; molecular corruption can change chemistry semantics |
The core principle is simple:
expensive chemistry once, cheap tensor shaping every epoch
Storage Layout That Survives Scale
A molecule-per-file layout will fall apart on metadata and request overhead. A row-wise object-store table is better, but training still improves dramatically if you materialize rank-friendly shards.
Practical shard plan
- write immutable Parquet or Arrow shards sized around 128 MB to 512 MB compressed
- keep token arrays and lengths in the same row group so workers can read sequentially
- align shard boundaries with manifest units, not arbitrary byte offsets
- store scaffold metadata beside tokens so evaluation or curriculum jobs do not need a second lookup
- include a checksum or row-count footer so a rank can prove a shard is complete before advancing resume state
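The last bullet can be sketched with the standard library alone. This is a minimal illustration that assumes the shard has already been downloaded to local disk and that the manifest stores the digest as "sha256:<hex>"; the helper names are hypothetical, and a real job would likely stream from object storage instead.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the file in chunks so large shards never sit fully in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return f"sha256:{digest.hexdigest()}"


def shard_is_complete(path: Path, entry: dict) -> bool:
    """True only if the file exists and matches the manifest checksum."""
    return path.is_file() and sha256_of_file(path) == entry["checksum"]


# Demo: a temp file stands in for a downloaded shard.
with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "train-000000.parquet"
    shard.write_bytes(b"placeholder bytes, not real parquet")
    entry = {"shard_id": "train-000000", "checksum": sha256_of_file(shard)}
    intact = shard_is_complete(shard, entry)
    shard.write_bytes(b"truncated")  # simulate a partial write
    damaged = shard_is_complete(shard, entry)

print(intact, damaged)  # True False
```

A rank would run this kind of check before it records the shard as consumed in its resume state.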
Example manifest entry
```json
{
  "split": "train",
  "shard_id": "train-084231",
  "uri": "s3://chem-corpus/train/train-084231.parquet",
  "num_rows": 131072,
  "token_count": 6423119,
  "length_bucket": "128-160",
  "scaffold_families": 5841,
  "checksum": "sha256:9e3d..."
}
```
At training time, the manifest entry (one shard) is the real unit of work. Rows are too fine-grained for efficient remote scheduling; full dataset partitions are too coarse for recovery.
Config Management Is Part Of The Design
In an interview, config management is one of the easiest ways to signal that you think like an operator instead of a notebook-only user.
The trainer config should encode at least four classes of information:
| Config area | Example fields | Why it must be explicit |
|---|---|---|
| runtime topology | world_size, global_batch_tokens, grad_accum_steps | changes optimization semantics and DDP behavior |
| data contract | manifest_uri, split_policy, length_bucket_spec | ties the run to a specific corpus definition |
| featurization | tokenizer_version, vocab_hash, canonicalization_version | prevents silent feature drift across runs |
| optimization and eval | lr, warmup_steps, eval_interval, early_stop_metric | keeps training and decision policy reproducible |
Example config shape
```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MoleculeTrainConfig:
    run_name: str
    seed: int
    world_size: int
    micro_batch_size: int
    grad_accum_steps: int
    max_tokens_per_batch: int
    manifest_uri: str
    manifest_version: str
    split_policy: str
    tokenizer_version: str
    vocab_hash: str
    canonicalization_version: str
    model_name: str
    learning_rate: float
    eval_interval_steps: int
```
The staff-level sentence is:
“I want the checkpoint to point back to a complete data and featurization contract, not just to weights.”
flowchart LR
A[Config payload] --> B[Derived batch math]
A --> C[Manifest version]
A --> D[Tokenizer and canon version]
B --> E[Run metadata]
C --> E
D --> E
E --> F[Checkpoint lineage]
F --> G[Resume or audit]
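One way to make that lineage checkable is to store the derived batch math and a fingerprint of the data contract next to every checkpoint. The sketch below is an assumption about how you might do it: the helper names are hypothetical, and the fingerprint is simply a SHA-256 over a key-sorted JSON dump of whatever fields you treat as the contract.

```python
import hashlib
import json


def derived_batch_math(world_size: int, micro_batch_size: int, grad_accum_steps: int) -> dict:
    """Quantities that change optimization semantics when the topology changes."""
    return {
        "samples_per_rank_per_step": micro_batch_size * grad_accum_steps,
        "global_batch_size": micro_batch_size * grad_accum_steps * world_size,
    }


def contract_fingerprint(contract: dict) -> str:
    """Stable short hash: sorting keys makes the result independent of dict order."""
    payload = json.dumps(contract, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


# Illustrative values only.
contract = {
    "manifest_version": "2025-01-a",
    "tokenizer_version": "smiles-bpe-v3",
    "canonicalization_version": "rdkit-2025-iso",
}
run_metadata = {
    **derived_batch_math(world_size=8, micro_batch_size=32, grad_accum_steps=4),
    "contract_fingerprint": contract_fingerprint(contract),
}
print(run_metadata["global_batch_size"])  # 1024
```

Two runs with the same fingerprint saw the same corpus definition and featurization; a mismatch is a signal to fork the run rather than resume it.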
Featurization Strategy
For large molecular corpora, choose one of these paths deliberately.
| Representation | Best for | Hot-path cost | Main risk |
|---|---|---|---|
| tokenized SMILES | transformer pretraining, contrastive SSL | low after offline tokenization | string grammar may hide 3D chemistry |
| offline graph tensors | GNN training on supervised assays | moderate I/O, low CPU | large serialized tensors can bloat storage |
| on-the-fly graph conversion | rapid experimentation only | high CPU, poor cache locality | worker starvation at scale |
| hybrid token + descriptor | multitask screening and retrieval | moderate | feature drift if offline and online disagree |
For the DDP story here, tokenized SMILES is the clean baseline because it isolates DDP behavior from featurization churn.
Offline tokenization example
```python
from __future__ import annotations

from rdkit import Chem


def canonicalize_smiles(raw_smiles: str) -> str | None:
    """Return canonical isomeric SMILES, or None if RDKit cannot parse the input."""
    mol = Chem.MolFromSmiles(raw_smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)


def encode_smiles(smi: str, vocab: dict[str, int], unk_id: int) -> list[int]:
    """Greedy encoder: prefer a two-character token (such as Cl or Br) over one character."""
    tokens: list[int] = []
    i = 0
    while i < len(smi):
        pair = smi[i : i + 2]
        if pair in vocab:
            tokens.append(vocab[pair])
            i += 2
            continue
        tokens.append(vocab.get(smi[i], unk_id))
        i += 1
    return tokens
```
That code belongs in a batch ETL job, not in your training worker.
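A quick check of the greedy two-character rule helps when you explain it. The snippet restates encode_smiles so it runs on its own without RDKit; the six-entry vocabulary is a toy assumption, not a real tokenizer vocabulary.

```python
def encode_smiles(smi: str, vocab: dict, unk_id: int) -> list:
    """Same greedy rule as the ETL function: try two characters, then one."""
    tokens = []
    i = 0
    while i < len(smi):
        pair = smi[i : i + 2]
        if pair in vocab:
            tokens.append(vocab[pair])
            i += 2
            continue
        tokens.append(vocab.get(smi[i], unk_id))
        i += 1
    return tokens


# Toy vocabulary (hypothetical). Id 0 doubles as the unknown token here.
toy_vocab = {"C": 1, "Cl": 2, "(": 3, ")": 4, "=": 5, "O": 6}

acetyl_chloride = encode_smiles("CC(=O)Cl", toy_vocab, unk_id=0)
print(acetyl_chloride)  # [1, 1, 3, 5, 6, 4, 2]

# "N" is not in the toy vocabulary, so it falls back to unk_id.
print(encode_smiles("CN", toy_vocab, unk_id=0))  # [1, 0]
```

The trailing Cl becomes one token instead of C followed by an unknown l, which is the reason for trying pairs first.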
Distributed Dataset Design
Because DDP does not shard inputs for you, the dataset layer needs a deterministic contract for both assignment and replay.
- partition shards by rank before reading rows
- reshuffle manifest order deterministically by (seed, epoch)
- keep length buckets stable enough that compile and kernel selection are not constantly invalidated
- checkpoint both data position and logical epoch
- never advance the global cursor until the current shard has been fully consumed and acknowledged
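The first two rules can be sketched with the standard library. This version seeds the shuffle from the string "seed:epoch" instead of the sum of the two, an assumption made here so that (seed=1, epoch=2) and (seed=2, epoch=1) do not collide; the function name is illustrative.

```python
import random


def rank_shards(shard_ids: list, rank: int, world_size: int, seed: int, epoch: int) -> list:
    """Shuffle the manifest identically on every rank, then take a strided slice."""
    rng = random.Random(f"{seed}:{epoch}")  # string seeds are hashed deterministically
    order = list(shard_ids)
    rng.shuffle(order)
    return order[rank::world_size]


shard_ids = [f"train-{i:06d}" for i in range(10)]
world_size = 4
assignment = [rank_shards(shard_ids, r, world_size, seed=17, epoch=0) for r in range(world_size)]

# Every shard is owned by exactly one rank.
owned = [s for shards in assignment for s in shards]
print(sorted(owned) == shard_ids, len(set(owned)) == len(owned))  # True True
```

With 10 shards and 4 ranks, two ranks own 3 shards and two own 2. That imbalance is why a real job either pads the manifest to a multiple of world_size or handles uneven step counts explicitly.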
Example iterable dataset
```python
from __future__ import annotations

import random
from dataclasses import dataclass

import pyarrow.parquet as pq
import torch
from torch.utils.data import IterableDataset, get_worker_info


@dataclass
class Cursor:
    epoch: int = 0
    shard_offset: int = 0
    row_offset: int = 0


class ShardedMoleculeDataset(IterableDataset):
    def __init__(
        self,
        manifest: list[dict],
        rank: int,
        world_size: int,
        seed: int,
        cursor: Cursor | None = None,
    ) -> None:
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.cursor = cursor or Cursor()

    def state_dict(self) -> dict[str, int]:
        return {
            "epoch": self.cursor.epoch,
            "shard_offset": self.cursor.shard_offset,
            "row_offset": self.cursor.row_offset,
        }

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.cursor = Cursor(**state)

    def set_epoch(self, epoch: int) -> None:
        # Starting a new epoch rewinds the cursor to the first shard and row.
        self.cursor.epoch = epoch
        self.cursor.shard_offset = 0
        self.cursor.row_offset = 0

    def _rank_manifest(self) -> list[dict]:
        # Same shuffle on every rank, then a strided slice per rank.
        rng = random.Random(self.seed + self.cursor.epoch)
        shuffled = list(self.manifest)
        rng.shuffle(shuffled)
        return shuffled[self.rank :: self.world_size]

    def __iter__(self):
        shards = self._rank_manifest()
        # Split this rank's shards across DataLoader workers; otherwise every
        # worker would replay the full rank-local shard list.
        worker = get_worker_info()
        if worker is not None:
            shards = shards[worker.id :: worker.num_workers]
        first_shard = self.cursor.shard_offset
        for shard_idx, shard in enumerate(shards[first_shard:], start=first_shard):
            table = pq.read_table(shard["uri"], columns=["token_ids", "token_length", "labels"])
            # Only the shard we resumed into starts from a non-zero row.
            start_row = self.cursor.row_offset if shard_idx == first_shard else 0
            for row_idx in range(start_row, table.num_rows):
                self.cursor.shard_offset = shard_idx
                self.cursor.row_offset = row_idx
                yield {
                    "input_ids": torch.tensor(table["token_ids"][row_idx].as_py(), dtype=torch.long),
                    "labels": torch.tensor(table["labels"][row_idx].as_py(), dtype=torch.float32),
                    "length": int(table["token_length"][row_idx].as_py()),
                }
            self.cursor.row_offset = 0
```
In production you would avoid read_table() on an entire shard if row groups are too large, but the shape of the contract is the important part: the dataset has explicit replay state. Note that with DataLoader workers the cursor advances inside each worker process, so the position has to travel back to the trainer (for example, alongside each batch) before it can be checkpointed.
Efficient Batching
Large molecular corpora usually have a long sequence-length tail. Padding everything to a global maximum is a direct tax on throughput.
Better approach
- pre-bucket shards or records by token length
- form batches from a narrow length range
- cap total tokens per batch, not just examples per batch
- record padding ratio as a first-class metric
```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_molecules(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    # Longest first, so the padded width is set by the first item.
    batch = sorted(batch, key=lambda item: int(item["length"]), reverse=True)
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=0,
    )
    # Assumes token id 0 is reserved for padding in the vocabulary.
    attention_mask = input_ids.ne(0)
    labels = torch.stack([item["labels"] for item in batch])
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
```
This improves three things at once:
- fewer wasted FLOPs on padding
- more stable shapes for torch.compile
- easier reasoning about tokens/sec instead of just examples/sec
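Capping tokens per batch is easier to reason about with a small sketch. The version below is an illustration under simplifying assumptions: it sorts a list of lengths globally, whereas a real pipeline would apply the same rule inside a shuffled window or a length bucket to keep some randomness. The function names are hypothetical.

```python
def token_budget_batches(lengths: list, max_tokens: int) -> list:
    """Group indices so that count * longest (the padded size) stays within max_tokens."""
    order = sorted(range(len(lengths)), key=lengths.__getitem__)
    batches, current, longest = [], [], 0
    for idx in order:
        new_longest = max(longest, lengths[idx])
        if current and (len(current) + 1) * new_longest > max_tokens:
            batches.append(current)
            current, new_longest = [], lengths[idx]
        current.append(idx)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


def padding_ratio(batch_lengths: list) -> float:
    """Fraction of the padded batch that is padding rather than real tokens."""
    padded = len(batch_lengths) * max(batch_lengths)
    return 1 - sum(batch_lengths) / padded


lengths = [10, 12, 40, 42, 11, 41]
batches = token_budget_batches(lengths, max_tokens=48)
print(batches)  # [[0, 4, 1], [2], [5], [3]]
print(round(padding_ratio([10, 11, 12]), 4))  # 0.0833
```

Short molecules pack three to a batch while long ones go alone, so every batch costs roughly the same number of tokens. A single sequence longer than max_tokens still forms its own batch in this sketch, so filter or truncate those upstream.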
DDP Trainer Settings That Usually Matter
If the model fits, the data plane is often the bottleneck. But there are still a few DDP settings worth being explicit about.
| Setting | Why use it | When to avoid it |
|---|---|---|
| gradient_as_bucket_view=True | lower peak memory and fewer gradient copies | if code calls detach_() on gradients, which is not allowed once they are bucket views |
| static_graph=True | cheaper reducer bookkeeping on stable graphs | if control flow or parameter usage changes per batch |
| bucket_cap_mb tuning | better overlap between backward compute and all-reduce | if you have not yet measured comm as a bottleneck |
| broadcast_buffers=False | avoid useless buffer sync for models without BN-like semantics | if buffer state must stay synchronized every forward |
| no_sync() | gradient accumulation without all-reducing every micro-step | if accumulation logic is not mathematically intended |
Training step example
```python
from contextlib import nullcontext

import torch
from torch.amp import GradScaler, autocast


def train_epoch(model, loader, optimizer, scaler: GradScaler, cfg):
    # model is the DDP-wrapped module; compute_loss is the task loss defined elsewhere.
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(loader):
        is_boundary = (step + 1) % cfg.grad_accum_steps == 0
        # Skip the gradient all-reduce on every micro-step except the last one.
        sync_ctx = nullcontext() if is_boundary else model.no_sync()
        with sync_ctx:
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=cfg.use_amp):
                outputs = model(
                    input_ids=batch["input_ids"].to(cfg.device, non_blocking=True),
                    attention_mask=batch["attention_mask"].to(cfg.device, non_blocking=True),
                )
                loss = compute_loss(outputs, batch["labels"].to(cfg.device, non_blocking=True))
                # Scale so the accumulated gradient matches one large batch.
                loss = loss / cfg.grad_accum_steps
            # backward() must run inside no_sync() for the all-reduce to be skipped.
            scaler.scale(loss).backward()

        if is_boundary:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```
The talk track for this code is:
- one process per GPU
- deterministic data ownership outside DDP
- accumulation reduces sync frequency
- AMP improves throughput if the numerics are stable
Reproducibility: Be Precise, Not Naive
At this scale, exact replay is not free. The honest answer is to define what level of reproducibility you are promising.
Reproducibility tiers
| Tier | What you can usually guarantee | Where it fits |
|---|---|---|
| exact single-host replay | same code, same checkpoint, same shard order, same seeds | easiest in controlled debugging environments |
| topology-stable distributed replay | same world size and deterministic shard assignment | practical for production incident analysis |
| cross-topology statistical replay | similar metric curves after resume on different topology | more realistic than bitwise equivalence at cluster scale |
What I would save and pin
- code revision or image digest
- config payload and derived batch math
- manifest version and shard checksums
- tokenizer and canonicalization versions
- RNG seeds for Python, NumPy, and PyTorch
- sampler or dataset cursor state
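Seeds reproduce a run from step 0; resuming mid-run additionally needs the generator state at the checkpoint. The sketch below shows the idea for Python's random module only. NumPy and PyTorch expose analogous calls (np.random.get_state() and torch.get_rng_state(), plus torch.cuda.get_rng_state_all() for GPUs), which you would add to the same payload; the helper names here are hypothetical.

```python
import random


def capture_rng_state() -> dict:
    """Snapshot generator state so stochastic masking resumes where it left off."""
    return {"python": random.getstate()}


def restore_rng_state(state: dict) -> None:
    random.setstate(state["python"])


random.seed(17)
_ = [random.random() for _ in range(5)]  # training consumes some randomness

snapshot = capture_rng_state()  # taken at checkpoint time
after_checkpoint = [random.random() for _ in range(3)]

restore_rng_state(snapshot)  # simulate a resume
after_resume = [random.random() for _ in range(3)]

print(after_checkpoint == after_resume)  # True
```

Re-seeding with 17 on resume would instead replay the first five draws, which is exactly the silent divergence this list is meant to prevent.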
What I would say in the room
- “I do not promise bitwise identity across different cluster topologies unless the platform is designed around that constraint.”
- “I do promise that a run can be explained from a stable config, data manifest, featurization version, and checkpoint lineage.”
flowchart TD
A[Checkpoint load request] --> B{Same manifest and featurization versions?}
B -->|No| C[Reject resume or fork new run]
B -->|Yes| D{Same topology?}
D -->|Yes| E[Topology-stable replay]
D -->|No| F[Statistical replay only]
E --> G[Continue training]
F --> G
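The decision flow above is small enough to encode directly as a guard in the resume path. This is a sketch under assumptions: the metadata keys and return labels are hypothetical names, and world_size stands in for "topology".

```python
CONTRACT_KEYS = ("manifest_version", "tokenizer_version", "canonicalization_version")


def resume_mode(saved: dict, current: dict) -> str:
    """Classify a resume request as reject, topology_stable, or statistical_only."""
    if any(saved.get(key) != current.get(key) for key in CONTRACT_KEYS):
        return "reject"  # different corpus or featurization: fork a new run
    if saved.get("world_size") == current.get("world_size"):
        return "topology_stable"
    return "statistical_only"


saved = {
    "manifest_version": "m1",
    "tokenizer_version": "smiles-bpe-v3",
    "canonicalization_version": "rdkit-2025-iso",
    "world_size": 64,
}
print(resume_mode(saved, dict(saved)))                                # topology_stable
print(resume_mode(saved, {**saved, "world_size": 32}))                # statistical_only
print(resume_mode(saved, {**saved, "tokenizer_version": "other"}))    # reject
```

Logging the returned mode with the run metadata makes it clear afterwards which reproducibility tier a resumed run can claim.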
torch.compile Without Self-Sabotage
torch.compile can help, but only after the shapes are reasonably well-behaved.
For molecular training that usually means:
- bucket lengths aggressively enough that shape churn is bounded
- keep the model graph static from batch to batch
- compile after the model is wrapped and stable in its final execution mode
- measure compile cache churn if you allow many different sequence lengths
If every batch has a radically different padded length, compile may spend more time specializing than accelerating.
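One simple way to bound shape churn is to round each batch's padded length up to a fixed multiple, so the compiler sees a handful of shapes instead of hundreds. The sketch below assumes a multiple of 32 and a 256-token cap; both values and the function name are illustrative, and the right granularity depends on how much extra padding you can afford.

```python
def padded_length(longest: int, multiple: int = 32, max_length: int = 256) -> int:
    """Round the longest sequence in a batch up to the next multiple, capped at max_length."""
    rounded = -(-longest // multiple) * multiple  # ceiling division without floats
    return min(rounded, max_length)


# Every batch length from 1 to 256 collapses onto 8 distinct padded shapes.
distinct_shapes = sorted({padded_length(n) for n in range(1, 257)})
print(distinct_shapes)  # [32, 64, 96, 128, 160, 192, 224, 256]
```

The cap means sequences longer than max_length must already be truncated or filtered upstream; this helper only chooses the padded width.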
Checkpointing What Actually Matters
For billion-sample jobs, “save model and optimizer” is not enough.
Minimum restart state
- model weights
- optimizer state
- scaler state
- logical epoch
- dataset cursor or sampler state
- RNG seeds if stochastic masking or sampling affects training semantics
- manifest version so resume does not silently reload a different corpus layout
Recommended pattern
- use distributed checkpointing when save time becomes a bottleneck
- stage checkpoint writes asynchronously only if you can observe backlog growth
- checkpoint on step counts, not just epoch boundaries, because an epoch may be many hours long
Evaluation Design For Molecular Training
A billion-scale training run is useless if evaluation is optimistic or too coarse to catch chemistry-specific failure.
Evaluation contract
- split by scaffold family for the default offline benchmark
- add a time-based or prospective split if the interview context includes discovery over time
- keep dedupe rules shared between train and eval data prep
- version the evaluation manifest separately from the training manifest
Metrics by task shape
| Task shape | Better metric set | Why accuracy is weak |
|---|---|---|
| binary hit classification | AUROC, AUPRC, BEDROC, EF@1% | heavy class imbalance and top-of-list economics |
| multi-task sparse assays | macro and per-task AUROC with missing-label masks | label sparsity hides failure if you collapse to one number |
| regression on potency | RMSE, Spearman, calibration by assay family | ranking and calibration often matter more than mean error |
| retrieval or similarity | recall@K, nearest-neighbor scaffold novelty | average loss says little about downstream screening quality |
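EF@1% is the metric in this table that people most often cannot define on the spot, so a small sketch helps. It uses the standard definition: the hit rate in the top-ranked fraction divided by the hit rate in the whole set. The function name and the rounding rule for the top-k size (floor, with a minimum of one compound) are assumptions of this sketch.

```python
def enrichment_factor(scores: list, labels: list, fraction: float = 0.01) -> float:
    """Hit rate among the top-scoring fraction divided by the overall hit rate."""
    n = len(scores)
    total_hits = sum(labels)
    if n == 0 or total_hits == 0:
        return 0.0
    n_top = max(1, int(n * fraction))
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    top_hits = sum(labels[i] for i in ranked[:n_top])
    return (top_hits / n_top) / (total_hits / n)


# 100 compounds, 10 actives, and a model that ranks every active first.
labels = [1] * 10 + [0] * 90
perfect_scores = [100 - i for i in range(100)]
worst_scores = list(range(100))  # actives ranked last

print(round(enrichment_factor(perfect_scores, labels, 0.01), 2))  # 10.0
print(round(enrichment_factor(worst_scores, labels, 0.01), 2))    # 0.0
```

With a 10% base rate the ceiling is 10x, which is why EF values are only comparable across assays with similar active fractions.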
Evaluation cadence
- run lightweight rank-0 validation every fixed number of steps
- run heavier scaffold-stratified or assay-family slices less frequently
- keep a small deterministic smoke-eval set for every checkpointable stage
flowchart LR
A[Train manifest] --> B[Scaffold split policy]
B --> C[Versioned train shards]
B --> D[Versioned eval shards]
C --> E[DDP training]
D --> F[Smoke eval]
D --> G[Scaffold eval]
D --> H[Assay-family slices]
E --> I[Checkpoint]
I --> F
I --> G
I --> H
Interview Language For Metrics And Evals
When the interviewer asks “how do you know it is working?”, the stronger answer is layered:
- training health: tokens/sec, padding ratio, all-reduce fraction, invalid-row rate
- optimization health: loss, gradient norm, learning-rate schedule, divergence checks
- scientific usefulness: scaffold-split AUROC, BEDROC, EF@1%, assay-family calibration
That framing keeps system health separate from scientific success.
flowchart TD
A[Run health] --> B[Data-plane metrics]
A --> C[Optimization metrics]
A --> D[Scientific metrics]
B --> E[tokens/sec, padding, invalid rows]
C --> F[loss, grad norm, lr, divergence]
D --> G[AUROC, BEDROC, EF@1%, calibration]
Metrics That Matter More Than Loss
| Metric | Why it matters |
|---|---|
| tokens/sec per rank | best direct throughput measure for variable-length sequence input |
| padding ratio | tells you whether featurization and bucketing are wasting compute |
| invalid SMILES rejection rate | catches upstream ingest drift |
| per-rank shard lag | highlights storage or worker skew |
| all-reduce time / step time | separates comm bottlenecks from input starvation |
| queue depth before H2D | tells you whether the GPU is waiting on the loader |
| scaffold coverage by split | proves your evaluation policy is still being enforced |
| eval lag in steps | tells you when validation results are too stale to trust decisions |
Common Failure Modes
flowchart TD
A[Training slows or diverges] --> B{First bad signal}
B --> C[Padding ratio spikes]
B --> D[Invalid rows rise]
B --> E[All-reduce tail grows]
B --> F[Rank skew appears]
C --> G[Length buckets too wide or token stats drifted]
D --> H[Upstream canonicalization regression]
E --> I[Communication or batch-size issue]
F --> J[Object-store hotspot or corrupt shard]
When DDP Stops Being Enough
Move away from plain DDP only when the bottleneck is truly model-state related.
| Symptom | Better next step |
|---|---|
| model barely fits | activation checkpointing, sequence packing |
| optimizer state dominates memory | ZeroRedundancyOptimizer or FSDP |
| single-rank memory remains too large | FSDP or tensor-parallel design |
| communication dominates despite tuning | revisit topology or hybrid parallelism |
Until then, keep the system boring and spend your sophistication on data correctness and pipeline throughput.
Full Reference Implementation
This is the “read it top to bottom” version of the article: config, distributed init, iterable dataset, collation, model, checkpoint state, training, and evaluation in one script.
```python
from __future__ import annotations

import json
import os
import random
import time
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


@dataclass
class TrainConfig:
    backend: str = "nccl"
    seed: int = 17
    # Topology comes from the launcher's (e.g. torchrun) environment variables.
    world_size: int = int(os.environ.get("WORLD_SIZE", "1"))
    rank: int = int(os.environ.get("RANK", "0"))
    local_rank: int = int(os.environ.get("LOCAL_RANK", "0"))
    micro_batch_size: int = 32
    grad_accum_steps: int = 4
    max_epochs: int = 2
    max_steps: int = 20_000
    eval_interval_steps: int = 1_000
    checkpoint_interval_steps: int = 1_000
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    max_length: int = 256
    vocab_size: int = 512
    hidden_size: int = 768
    num_layers: int = 8
    num_heads: int = 12
    num_labels: int = 64
    use_amp: bool = True
    manifest_path: str = "manifests/train_manifest.json"
    eval_manifest_path: str = "manifests/eval_manifest.json"
    checkpoint_dir: str = "artifacts/checkpoints/smiles"
    run_name: str = "billion-smiles-ddp"
    tokenizer_version: str = "smiles-bpe-v3"
    canonicalization_version: str = "rdkit-2025-iso"

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_rank(self) -> bool:
        return self.rank == 0

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{self.local_rank}")
        return torch.device("cpu")

    @property
    def global_batch_size(self) -> int:
        return self.micro_batch_size * self.grad_accum_steps * self.world_size


@dataclass
class Cursor:
    epoch: int = 0
    shard_offset: int = 0
    row_offset: int = 0


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> None:
    if not cfg.is_distributed:
        return
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(backend=cfg.backend, rank=cfg.rank, world_size=cfg.world_size)


def cleanup_distributed() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def load_manifest(path: str) -> list[dict]:
    return json.loads(Path(path).read_text())


class ShardedMoleculeDataset(IterableDataset):
    def __init__(
        self,
        manifest: list[dict],
        rank: int,
        world_size: int,
        seed: int,
        cursor: Cursor | None = None,
    ) -> None:
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.cursor = cursor or Cursor()

    def state_dict(self) -> dict[str, int]:
        return asdict(self.cursor)

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.cursor = Cursor(**state)

    def set_epoch(self, epoch: int) -> None:
        # Starting an epoch rewinds the cursor to the first shard and row.
        self.cursor.epoch = epoch
        self.cursor.shard_offset = 0
        self.cursor.row_offset = 0

    def _rank_shards(self) -> list[dict]:
        # Same shuffle on every rank, then a strided slice per rank.
        shards = list(self.manifest)
        random.Random(self.seed + self.cursor.epoch).shuffle(shards)
        return shards[self.rank :: self.world_size]

    def __iter__(self):
        shards = self._rank_shards()
        # Split this rank's shards across DataLoader workers; otherwise every
        # worker would replay the full rank-local shard list.
        worker = get_worker_info()
        if worker is not None:
            shards = shards[worker.id :: worker.num_workers]
        first_shard = self.cursor.shard_offset
        for shard_idx, shard in enumerate(shards[first_shard:], start=first_shard):
            table = pq.read_table(
                shard["uri"],
                columns=["token_ids", "token_length", "labels", "compound_id"],
            )
            start_row = self.cursor.row_offset if shard_idx == first_shard else 0
            for row_idx in range(start_row, table.num_rows):
                self.cursor.shard_offset = shard_idx
                self.cursor.row_offset = row_idx
                yield {
                    "compound_id": table["compound_id"][row_idx].as_py(),
                    "input_ids": torch.tensor(table["token_ids"][row_idx].as_py(), dtype=torch.long),
                    "length": int(table["token_length"][row_idx].as_py()),
                    "labels": torch.tensor(table["labels"][row_idx].as_py(), dtype=torch.float32),
                }
            self.cursor.row_offset = 0


def collate_molecules(batch: list[dict]) -> dict[str, torch.Tensor]:
    batch = sorted(batch, key=lambda item: item["length"], reverse=True)
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=0,
    )
    # Token id 0 is the padding id (matches padding_idx=0 in the embedding).
    attention_mask = input_ids.ne(0)
    labels = torch.stack([item["labels"] for item in batch])
    lengths = torch.tensor([item["length"] for item in batch], dtype=torch.long)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "lengths": lengths,
    }


class MoleculeEncoder(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, num_layers: int, num_heads: int, num_labels: int):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.pos_emb = nn.Embedding(512, hidden_size)  # supports up to 512 positions
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 4,
            batch_first=True,
            norm_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.encoder(x, src_key_padding_mask=~attention_mask)
        # Mean-pool over real tokens only.
        pooled = (x * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp_min(1)
        return self.head(pooled)


def make_loader(dataset: IterableDataset, cfg: TrainConfig, training: bool) -> DataLoader:
    # Workers are not persistent: fresh workers copy the dataset on every pass,
    # so set_epoch() and restored cursors in the parent actually reach them.
    # Limitation: the cursor still advances inside worker processes only, so the
    # parent's state_dict() marks the start of the pass, not the exact row.
    return DataLoader(
        dataset,
        batch_size=cfg.micro_batch_size,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_molecules,
        drop_last=training,
    )


def unwrap(model: nn.Module) -> nn.Module:
    # Strip the DDP wrapper, then the torch.compile wrapper (its private
    # _orig_mod attribute), so checkpoint keys carry no wrapper prefixes.
    model = getattr(model, "module", model)
    return getattr(model, "_orig_mod", model)


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ShardedMoleculeDataset,
    cfg: TrainConfig,
    epoch: int,
    global_step: int,
) -> None:
    # Only rank 0 writes, so only rank 0's cursor is captured here. Per-rank
    # cursors would need per-rank files or a gather before saving.
    if not cfg.is_main_rank:
        return
    checkpoint_dir = Path(cfg.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "model": unwrap(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "dataset_state": dataset.state_dict(),
        "epoch": epoch,
        "global_step": global_step,
        "config": asdict(cfg),
    }
    path = checkpoint_dir / f"step-{global_step:08d}.pt"
    torch.save(payload, path)


def restore_if_available(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ShardedMoleculeDataset,
    cfg: TrainConfig,
) -> tuple[int, int]:
    checkpoint_dir = Path(cfg.checkpoint_dir)
    if not checkpoint_dir.exists():
        return 0, 0
    # Zero-padded step numbers make lexicographic order match step order.
    checkpoints = sorted(checkpoint_dir.glob("step-*.pt"))
    if not checkpoints:
        return 0, 0
    payload = torch.load(checkpoints[-1], map_location="cpu")
    unwrap(model).load_state_dict(payload["model"])
    optimizer.load_state_dict(payload["optimizer"])
    scaler.load_state_dict(payload["scaler"])
    dataset.load_state_dict(payload["dataset_state"])
    return int(payload["epoch"]), int(payload["global_step"])


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, cfg: TrainConfig) -> dict[str, float]:
    model.eval()
    loader.dataset.set_epoch(0)  # rewind so every eval pass covers the same rows
    loss_sum = torch.zeros(1, device=cfg.device)
    batch_count = torch.zeros(1, device=cfg.device)
    for batch in loader:
        logits = model(
            input_ids=batch["input_ids"].to(cfg.device, non_blocking=True),
            attention_mask=batch["attention_mask"].to(cfg.device, non_blocking=True),
        )
        labels = batch["labels"].to(cfg.device, non_blocking=True)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        loss_sum += loss.detach()
        batch_count += 1
    if cfg.is_distributed:
        # Default reduce op is SUM, so this yields the mean batch loss across ranks.
        dist.all_reduce(loss_sum)
        dist.all_reduce(batch_count)
    return {"eval_loss": (loss_sum / batch_count.clamp_min(1)).item()}


def train(cfg: TrainConfig) -> None:
    seed_everything(cfg.seed)
    setup_distributed(cfg)

    train_dataset = ShardedMoleculeDataset(
        manifest=load_manifest(cfg.manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed,
    )
    eval_dataset = ShardedMoleculeDataset(
        manifest=load_manifest(cfg.eval_manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed + 1000,
    )
    train_loader = make_loader(train_dataset, cfg, training=True)
    eval_loader = make_loader(eval_dataset, cfg, training=False)

    model = MoleculeEncoder(
        vocab_size=cfg.vocab_size,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        num_labels=cfg.num_labels,
    ).to(cfg.device)
    if cfg.is_distributed:
        model = DDP(
            model,
            device_ids=[cfg.local_rank] if torch.cuda.is_available() else None,
            output_device=cfg.local_rank if torch.cuda.is_available() else None,
            gradient_as_bucket_view=True,
            static_graph=True,
            broadcast_buffers=False,
        )
    # Compile after wrapping, in line with the torch.compile guidance above.
    if hasattr(torch, "compile"):
        model = torch.compile(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    amp_enabled = cfg.use_amp and torch.cuda.is_available()
    # bfloat16 autocast does not strictly need loss scaling; the scaler is kept
    # so the same loop also works if the dtype is switched to float16.
    scaler = GradScaler("cuda", enabled=amp_enabled)

    start_epoch, global_step = restore_if_available(model, optimizer, scaler, train_dataset, cfg)

    for epoch in range(start_epoch, cfg.max_epochs):
        # set_epoch() rewinds the cursor, so skip it for the epoch being resumed.
        if epoch != start_epoch:
            train_dataset.set_epoch(epoch)
        model.train()
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, batch in enumerate(train_loader):
            step_start = time.perf_counter()
            is_boundary = (batch_idx + 1) % cfg.grad_accum_steps == 0
            sync_ctx = model.no_sync() if cfg.is_distributed and not is_boundary else nullcontext()
            with sync_ctx:
                with autocast(device_type=cfg.device.type, dtype=torch.bfloat16, enabled=amp_enabled):
                    logits = model(
                        input_ids=batch["input_ids"].to(cfg.device, non_blocking=True),
                        attention_mask=batch["attention_mask"].to(cfg.device, non_blocking=True),
                    )
                    labels = batch["labels"].to(cfg.device, non_blocking=True)
                    loss = F.binary_cross_entropy_with_logits(logits, labels)
                    loss = loss / cfg.grad_accum_steps
                # backward() runs inside no_sync() so non-boundary steps skip all-reduce.
                scaler.scale(loss).backward()

            if not is_boundary:
                continue

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if cfg.is_main_rank and global_step % 50 == 0:
                # Rough estimate: last micro-batch only, extrapolated to all ranks.
                tokens = int(batch["attention_mask"].sum().item()) * cfg.world_size
                step_time = time.perf_counter() - step_start
                print(
                    json.dumps(
                        {
                            "step": global_step,
                            "loss": round(loss.item() * cfg.grad_accum_steps, 5),
                            "tokens_per_sec": round(tokens / max(step_time, 1e-6), 2),
                            "padding_ratio": round(
                                1 - batch["attention_mask"].float().mean().item(),
                                4,
                            ),
                        }
                    )
                )

            # Eval and checkpoint only fire on optimizer-step boundaries.
            if global_step % cfg.eval_interval_steps == 0:
                metrics = evaluate(model, eval_loader, cfg)
                if cfg.is_main_rank:
                    print(json.dumps({"step": global_step, **metrics}))
                model.train()

            if global_step % cfg.checkpoint_interval_steps == 0:
                save_checkpoint(model, optimizer, scaler, train_dataset, cfg, epoch, global_step)

            if global_step >= cfg.max_steps:
                save_checkpoint(model, optimizer, scaler, train_dataset, cfg, epoch, global_step)
                cleanup_distributed()
                return

    save_checkpoint(model, optimizer, scaler, train_dataset, cfg, cfg.max_epochs - 1, global_step)
    cleanup_distributed()


if __name__ == "__main__":
    train(TrainConfig())
```
Interview-Ready Summary
If you have thirty seconds to summarize the design:
- Canonicalize, dedupe, scaffold-tag, and tokenize molecules offline.
- Train from immutable length-bucketed shards with deterministic rank assignment.
- Use DDP because the model fits; optimize the data plane before escalating parallelism complexity.
- Track tokens/sec, padding, shard lag, and invalid-row rate.
- Checkpoint dataset cursor state, not just weights, or resume will be wrong.