
Sequence and Structure Modeling

Protein language models, equivariant graph networks, and structure prediction models introduce distributed training challenges that do not appear in standard CV or NLP pipelines. Sequence length variation is extreme. Pair representations are quadratic in memory. Geometric equivariance adds architectural constraints that interact badly with standard parallelism strategies.

flowchart LR
  A[Protein sequence] --> B[PLM encoder: ESM-style]
  C[Molecular graph] --> D[GNN encoder: MPNN / SchNet]
  E[3D structure] --> F[Equivariant encoder: SE3 / EGNN]
  B --> G[Task head]
  D --> G
  F --> G
  G --> H[Property or structure output]
Biotech ML models usually combine a domain-specific encoder with a task head. The encoder choice determines parallelism strategy.

ESM-style protein language models are transformer encoders trained on amino acid sequences with masked token prediction. The distributed training challenges differ from standard NLP in two ways:

  1. Sequence length distribution is bimodal and heavy-tailed. UniRef50 spans lengths from 10 to 35,000 residues. A single long protein in a batch can dominate memory and cause a rank-level OOM that does not affect other ranks.
  2. The vocabulary is small (20–33 tokens) but the activation footprint is large. Memory pressure comes from activations and attention matrices, not embedding tables.
from dataclasses import dataclass, field

from training_systems_architecture import TrainConfig


@dataclass
class ProteinLMConfig(TrainConfig):
    max_length: int = 1024
    vocab_size: int = 33          # 20 standard residues plus special and ambiguous tokens
    d_model: int = 1280
    num_layers: int = 33
    num_heads: int = 20
    length_bucket_boundaries: list[int] = field(
        default_factory=lambda: [128, 256, 512, 1024]
    )

    @property
    def pair_representation_memory_gb(self) -> float:
        # L x L x d_model pair tensor in FP32 (4 bytes per element)
        return (self.max_length**2 * self.d_model * 4) / (1024**3)

At max_length=1024 and d_model=1280, a pair representation alone consumes 5.0 GB per sequence in FP32 (1024² × 1280 × 4 bytes). That number drives the decisions about sequence parallelism and activation checkpointing.
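The length_bucket_boundaries field exists so that batches can be built from length-homogeneous buckets, which keeps padding waste low and prevents one long-tail protein from dominating a batch's memory. A minimal sketch of bucket assignment, assuming the ProteinLMConfig above; bucket_for_length is an illustrative helper, not part of the baseline trainer:

import bisect

def bucket_for_length(seq_len: int, boundaries: list[int]) -> int:
    # Index of the first boundary >= seq_len; anything past the last boundary
    # falls into an overflow bucket (crop it, or route it to sequence parallelism).
    return bisect.bisect_left(boundaries, seq_len)

config = ProteinLMConfig()
for length in [87, 250, 900, 4000]:
    print(f"length {length} -> bucket {bucket_for_length(length, config.length_bucket_boundaries)}")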

Standard data parallelism (DDP) replicates the full model across ranks. For transformers on long sequences, the bottleneck is per-sequence activation footprint, not parameter count.

Sequence parallelism partitions the sequence dimension across ranks so each rank processes a contiguous subsequence. The attention layer requires an all-gather across the sequence axis before computing attention scores:

flowchart TD
  A[Sequence of length L] --> B[Split: L divided by N tokens per rank]
  B --> C[Rank 0: tokens 0 to L/N]
  B --> D[Rank 1: tokens L/N to 2L/N]
  B --> E[Rank N-1: remaining tokens]
  C --> F[Local Q/K/V projection]
  D --> F
  E --> F
  F --> G[All-gather keys and values]
  G --> H[Full attention on each rank]
  H --> I[Scatter output back to sequence shards]
Sequence parallelism cuts per-rank activation memory from O(L) to O(L/N) per layer at the cost of a communication round-trip per attention layer. Worth it when sequences exceed roughly 2K tokens.
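A minimal sketch of the all-gather pattern above, assuming every rank in the default process group holds one contiguous token shard of the same batch with tensors shaped (batch, heads, L/N, head_dim); sequence_parallel_attention is an illustrative name, not a library API:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sequence_parallel_attention(
    q_local: torch.Tensor,  # (B, H, L/N, d) queries for this rank's token shard
    k_local: torch.Tensor,  # (B, H, L/N, d)
    v_local: torch.Tensor,  # (B, H, L/N, d)
) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Gather key/value shards from every rank along the sequence axis.
    k_shards = [torch.empty_like(k_local) for _ in range(world_size)]
    v_shards = [torch.empty_like(v_local) for _ in range(world_size)]
    dist.all_gather(k_shards, k_local)
    dist.all_gather(v_shards, v_local)
    k_full = torch.cat(k_shards, dim=2)  # (B, H, L, d)
    v_full = torch.cat(v_shards, dim=2)
    # Each rank attends its local queries against the full key/value sequence.
    return F.scaled_dot_product_attention(q_local, k_full, v_full)

The sketch shows forward data flow only; in a real trainer the gather must be autograd-aware (for example torch.distributed.nn.functional.all_gather) so gradients reach the key and value shards held on other ranks.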
| Condition | Action |
| --- | --- |
| Sequences fit in device memory at target batch size | DDP first; simpler failure model |
| Sequences exceed device memory even at micro_batch=1 | Sequence parallelism is required |
| Mixed short and long sequences | Dynamic batching; sequence parallelism for the long tail only |

AlphaFold-style architectures maintain a pair representation of shape (L, L, d) where L is sequence length and d is the channel dimension. Memory grows quadratically in L:

def pair_memory_bytes(seq_len: int, d_pair: int = 128, dtype_bytes: int = 2) -> int:
    # L x L x d_pair pair representation for a single sequence
    return seq_len * seq_len * d_pair * dtype_bytes


for length in [256, 512, 1024, 2048]:
    gb = pair_memory_bytes(length) / (1024**3)
    print(f"L={length:5d}: {gb:.3f} GB per sequence (bfloat16, d_pair=128)")
| Sequence length | Pair memory (bfloat16, d_pair=128) |
| --- | --- |
| 256 | 0.016 GB |
| 512 | 0.063 GB |
| 1024 | 0.25 GB |
| 2048 | 1.0 GB |

Gradient accumulation does not help pair-representation memory because the tensor lives in the forward graph, not the optimizer state. The two effective mitigations are activation checkpointing over Evoformer-style pair update blocks and a sequence length curriculum that caps length during early training.
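A hedged sketch of the curriculum half of that statement: cap the crop length as a function of optimizer step. The step boundaries and lengths below are illustrative values, not recommendations.

def length_cap(step: int) -> int:
    # Start short, relax the cap as training stabilizes.
    schedule = [(0, 256), (10_000, 512), (50_000, 1024)]
    cap = schedule[0][1]
    for boundary, max_len in schedule:
        if step >= boundary:
            cap = max_len
    return cap

def crop_to_cap(tokens: list[int], step: int) -> list[int]:
    # Sequences longer than the cap get a contiguous crop (random offset omitted for brevity).
    return tokens[: length_cap(step)]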

Equivariant networks for 3D molecular structures produce outputs that transform predictably under rotation and translation. This constraint affects the distributed training setup in non-obvious ways:

flowchart LR
  A[3D atom coordinates] --> B[Canonicalize coordinate frame]
  B --> C[Equivariant message passing]
  C --> D[Scalar invariant features]
  C --> E[Vector equivariant features]
  D --> F[Task readout]
  E --> F
  B -.->|Must precede sharding| G[Data partition]
Frame canonicalization must precede distributed sharding. Rank-local re-centering of a molecular shard is geometrically meaningless.

Three training constraints that matter operationally:

  1. Frame canonicalization before sharding. Re-centering or aligning a molecular structure must happen in preprocessing, not inside a DataLoader worker. A rank that receives a shard of a protein cannot independently re-center it because it lacks the full structural context; see the sketch after this list.
  2. No rotation augmentation. An equivariant model handles rotations by construction: scalar outputs are invariant and vector outputs rotate with the input. Augmenting with random rotations wastes compute and can degrade training when equivariance is approximate rather than exact.
  3. Layer normalization over batch normalization. Batch normalization aggregates across the batch dimension. For variable-size molecular graphs, this is inconsistent across batches. Layer normalization operates per-node and is safe.
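A minimal sketch of constraint 1, assuming coordinates arrive as a (num_atoms, 3) tensor; canonicalize_frame is an illustrative preprocessing step, run once over the full structure before any partitioning:

import torch

def canonicalize_frame(coords: torch.Tensor) -> torch.Tensor:
    # coords: (num_atoms, 3). Centering must see every atom; centering a
    # rank-local shard would place each shard in its own inconsistent frame.
    return coords - coords.mean(dim=0, keepdim=True)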

Biotech models are frequently trained simultaneously on dozens to hundreds of biological assays. This creates correctness challenges that resemble class imbalance but are harder to detect because the label matrix is sparse:

import torch
import torch.nn as nn


def multitask_loss(
    predictions: torch.Tensor,  # (batch, num_tasks) logits
    targets: torch.Tensor,      # (batch, num_tasks) binary labels; arbitrary where unmeasured
    mask: torch.Tensor,         # (batch, num_tasks) True where the assay was actually measured
    weights: torch.Tensor,      # (num_tasks,) per-task loss weights
) -> torch.Tensor:
    per_task_loss = nn.functional.binary_cross_entropy_with_logits(
        predictions, targets, reduction="none"
    )
    # Zero out losses for compound/assay pairs that were never measured.
    masked = per_task_loss * mask.float()
    weighted = masked * weights.unsqueeze(0)
    # Normalize by the number of observed labels, not by batch * num_tasks.
    return weighted.sum() / mask.float().sum().clamp(min=1.0)

The mask is critical. In a multi-task bioassay dataset, most compounds are not tested in most assays. Treating missing labels as negatives is a common silent error that systematically biases every head toward predicting inactivity and inflates apparent performance.
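As an illustration of that convention, a sparse assay matrix stored with NaN for untested compound/assay pairs might be converted into (targets, mask) like this; the NaN storage format is an assumption for the example, not something the loss above requires:

# Hypothetical label matrix: rows are compounds, columns are assays; NaN = not tested.
raw_labels = torch.tensor([
    [1.0, float("nan"), 0.0],
    [float("nan"), 0.0, float("nan")],
])
mask = ~torch.isnan(raw_labels)                  # only measured entries contribute to the loss
targets = torch.nan_to_num(raw_labels, nan=0.0)  # placeholder value, excluded via the mask
weights = torch.ones(raw_labels.shape[1])

logits = torch.zeros_like(raw_labels)            # stand-in predictions for the example
loss = multitask_loss(logits, targets, mask, weights)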

The standard TrainConfig from the baseline trainer needs adjustment for long-sequence biological models:

| Parameter | Generic default | Protein LM adjustment | Why |
| --- | --- | --- | --- |
| micro_batch_size | 16–64 | 1–4 | Activation footprint per sequence |
| grad_accum_steps | 1–4 | 8–32 | Recover effective batch size |
| use_amp | True (float16) | True (bfloat16 preferred) | Larger dynamic range; avoids overflow on long sequences |
| max_length | N/A | 512–2048 with curriculum | Memory budget sets the ceiling |
| num_workers | 2–4 | 1–2 | Variable-length batching reduces prefetch benefit |
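Putting the table together, a hedged instantiation sketch; micro_batch_size, grad_accum_steps, use_amp, and num_workers are assumed to be fields inherited from the baseline TrainConfig, and the exact values are illustrative:

config = ProteinLMConfig(
    micro_batch_size=2,     # assumed TrainConfig field; small because of activation footprint
    grad_accum_steps=16,    # assumed TrainConfig field; recovers the effective batch size
    use_amp=True,           # assumed TrainConfig field; pair with bfloat16 autocast, not float16
    max_length=1024,
    num_workers=2,          # assumed TrainConfig field
)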

Activation Checkpointing for Long Sequences


For protein LMs and structure prediction models, activation checkpointing is not optional. It is the primary tool for fitting large sequences in device memory:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Recompute the block's activations during backward instead of storing them.
        return checkpoint(self.block, x, mask, use_reentrant=False)

use_reentrant=False aligns with the current PyTorch recommendation. The reentrant variant has known issues with some forward graph patterns common in attention implementations and does not support torch.compile.
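A short usage sketch, assuming the encoder exposes its transformer blocks as an nn.ModuleList attribute named layers (the attribute name is an assumption, not a fixed API):

# Wrap every encoder block in place so long sequences fit in device memory.
model.layers = nn.ModuleList(
    CheckpointedTransformerBlock(block) for block in model.layers
)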

| Decision | Why you choose it | What it costs |
| --- | --- | --- |
| DDP over FSDP for PLMs | Simpler baseline | Breaks down once model exceeds per-device memory |
| Sequence parallelism | Required for L > 2048 on 80 GB devices | Communication overhead per attention layer |
| Activation checkpointing on pair blocks | Keeps pair-representation memory bounded | ~33% more compute per forward pass |
| Sequence length curriculum | Stable early training; prevents OOM from long-tail examples | Complicates epoch definition and sampler resume |
| bfloat16 over float16 | Larger dynamic range; fewer overflow events | Slightly lower throughput on older hardware |
| Masking missing assay labels | Prevents silent negative-label contamination | Reduces effective batch density per task |