# Checkpoint Module

The `highentDCA.checkpoint` module provides checkpoint strategies for saving model state during training.
## Overview

Checkpoints are essential for:

- Saving model progress during long training runs
- Creating snapshots at specific densities during decimation
- Logging training metrics
- Enabling training resumption after interruptions
- Tracking experiments with Weights & Biases
## Classes

### `Checkpoint` (Abstract Base Class)

Base class for all checkpoint strategies. Defines the interface that all checkpoint implementations must follow.

```python
from highentDCA.checkpoint import Checkpoint
```
#### Constructor

```python
Checkpoint(
    file_paths: dict,
    tokens: str,
    args: dict,
    params: Dict[str, torch.Tensor] | None = None,
    chains: torch.Tensor | None = None,
    use_wandb: bool = False,
)
```
**Parameters:**

- `file_paths` (`dict`): Dictionary with file paths for saving outputs.
    - Required keys: `"log"`, `"params"`, `"chains"`
    - Values: `pathlib.Path` objects or strings
- `tokens` (`str`): Alphabet for sequence encoding (e.g., `"protein"`, `"rna"`, or custom)
- `args` (`dict`): Training configuration dictionary with keys:
    - `"model"`: Model type (`"bmDCA"`, `"eaDCA"`, `"edDCA"`)
    - `"data"`: Path to input MSA
    - `"alphabet"`: Sequence alphabet
    - `"lr"`: Learning rate
    - `"nsweeps"`: Number of sweeps
    - `"target"`: Target Pearson correlation
    - Additional model-specific args
- `params` (`Dict[str, torch.Tensor] | None`): Initial model parameters
    - `"bias"`: Fields, shape `(L, q)`
    - `"coupling_matrix"`: Couplings, shape `(L, q, L, q)`
- `chains` (`torch.Tensor | None`): Initial Markov chains, shape `(n_chains, L, q)`
- `use_wandb` (`bool`): Enable Weights & Biases logging (default: `False`)
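For instance, a `file_paths` dictionary with the three required keys might look like the following (the `out/` directory is an arbitrary example path, not one mandated by the package):

```python
from pathlib import Path

# The three required keys; values may be pathlib.Path objects or plain strings
file_paths = {
    "log": Path("out/training.log"),
    "params": Path("out/params.dat"),
    "chains": Path("out/chains.fasta"),
}
```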
#### Methods

##### `header_log()`

Write the header row to the log file.

```python
checkpoint.header_log()
```

**Example:**

```python
checkpoint = DecCheckpoint(...)
checkpoint.header_log()
# Writes: "Epochs Pearson Entropy Density Time "
```
##### `log(record: Dict[str, Any])`

Log training metrics and append them to the log file.

```python
checkpoint.log(record)
```

**Parameters:**

- `record` (`Dict[str, Any]`): Metrics to log. Valid keys:
    - `"Epochs"`: Epoch number (`int`)
    - `"Pearson"`: Pearson correlation (`float`)
    - `"Entropy"`: Model entropy (`float`)
    - `"Density"`: Coupling density (`float`)
    - `"Time"`: Elapsed time in seconds (`float`)

**Raises:**

- `ValueError`: If `record` contains unrecognized keys
**Example:**

```python
checkpoint.log({
    "Epochs": 100,
    "Pearson": 0.95,
    "Entropy": 120.5,
    "Density": 0.587,
    "Time": 125.3,
})
```
##### `check(*args, **kwargs)` (Abstract)

Check whether a checkpoint condition is met. Must be implemented by subclasses.

```python
should_save = checkpoint.check(...)
```

**Returns:**

- `bool`: `True` if a checkpoint should be saved, `False` otherwise
##### `save(params, mask, chains, log_weights)` (Abstract)

Save model state to disk. Must be implemented by subclasses.

```python
checkpoint.save(
    params=params,
    mask=mask,
    chains=chains,
    log_weights=log_weights,
)
```

**Parameters:**

- `params` (`Dict[str, torch.Tensor]`): Model parameters
- `mask` (`torch.Tensor`): Binary coupling mask, shape `(L, q, L, q)`
- `chains` (`torch.Tensor`): Markov chains, shape `(n_chains, L, q)`
- `log_weights` (`torch.Tensor`): AIS log-weights, shape `(n_chains,)`
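A concrete subclass must supply both `check` and `save`. As a hypothetical sketch of the pattern (using a stand-in abstract base so the snippet is self-contained; the real base class is `highentDCA.checkpoint.Checkpoint`, and `EpochCheckpoint` is an invented example, not part of the package):

```python
from abc import ABC, abstractmethod

class CheckpointBase(ABC):
    """Stand-in for the Checkpoint interface described above."""

    @abstractmethod
    def check(self, *args, **kwargs) -> bool: ...

    @abstractmethod
    def save(self, params, mask, chains, log_weights) -> None: ...

class EpochCheckpoint(CheckpointBase):
    """Hypothetical strategy: save every `interval` epochs."""

    def __init__(self, interval: int = 100):
        self.interval = interval

    def check(self, epoch: int) -> bool:
        # Save whenever the epoch count hits a multiple of the interval
        return epoch > 0 and epoch % self.interval == 0

    def save(self, params, mask, chains, log_weights) -> None:
        # A real implementation would serialize these to the configured file_paths
        print("saving checkpoint")
```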
### `DecCheckpoint` (Density-Based Checkpoint)

Checkpoint strategy based on coupling-matrix density thresholds. Saves the model state whenever the density crosses one of a set of predefined values.

```python
from highentDCA.checkpoint import DecCheckpoint
```
#### Constructor

```python
DecCheckpoint(
    file_paths: dict,
    tokens: str,
    args: dict,
    params: Dict[str, torch.Tensor] | None = None,
    chains: torch.Tensor | None = None,
    checkpt_steps: list[float] | None = None,
    use_wandb: bool = False,
    target_density: float | None = None,
    n_steps: int = 10,
    **kwargs,
)
```
**Additional Parameters:**

- `checkpt_steps` (`list[float] | None`): Custom density thresholds (e.g., `[0.9, 0.5, 0.2, 0.05]`)
    - If `None`, generates `n_steps` geometrically spaced values from `0.99` to `target_density`
    - Values are automatically sorted in descending order
- `target_density` (`float | None`): Final target density (used to generate `checkpt_steps` if `None` is provided)
- `n_steps` (`int`): Number of checkpoint steps to generate (default: `10`)
**Automatic Checkpoint Generation:**

If `checkpt_steps=None`, the thresholds are generated as a geometric sequence:

$$\rho_k = \rho_{\text{start}} \left( \frac{\rho_{\text{target}}}{\rho_{\text{start}}} \right)^{k/(N-1)}, \qquad k = 0, 1, \dots, N-1,$$

where \(\rho_{\text{start}} = 0.99\), \(\rho_{\text{target}}\) = `target_density`, and \(N\) = `n_steps`.
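As a sketch of this spacing (a hypothetical helper, not part of the package; the package's own generation code may differ in details such as rounding or how the final step is clamped to the target):

```python
def geometric_checkpoints(start: float = 0.99, target: float = 0.02, n_steps: int = 10) -> list[float]:
    """Return n_steps densities spaced geometrically from start down to target."""
    # Consecutive thresholds differ by a constant ratio (target/start)**(1/(n_steps-1))
    return [start * (target / start) ** (k / (n_steps - 1)) for k in range(n_steps)]

steps = geometric_checkpoints()
```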
**Example:**

```python
from pathlib import Path
from highentDCA.checkpoint import DecCheckpoint

checkpoint = DecCheckpoint(
    file_paths={
        "log": Path("output/training.log"),
        "params": Path("output/params.dat"),
        "chains": Path("output/chains.fasta"),
    },
    tokens="protein",
    args={
        "model": "edDCA",
        "data": "data/alignment.fasta",
        "alphabet": "protein",
        "lr": 0.01,
        "nsweeps": 10,
        "nsweeps_dec": 10,
        "target": 0.95,
        "density": 0.02,
        "drate": 0.01,
        "pseudocount": None,
        "dtype": "float32",
        "label": None,
        "nepochs": 50000,
        "sampler": "gibbs",
        "nchains": 10000,
        "seed": 42,
    },
    target_density=0.02,
    n_steps=10,
    use_wandb=False,
)

# Automatic checkpoints at:
# [0.99, 0.82, 0.68, 0.56, 0.46, 0.38, 0.31, 0.26, 0.21, 0.02]
```
#### Methods

##### `check(density: float)`

Check whether the current density has crossed a checkpoint threshold.

```python
should_save = checkpoint.check(density=0.35)
```

**Parameters:**

- `density` (`float`): Current coupling matrix density

**Returns:**

- `bool`: `True` if the density crossed a threshold, `False` otherwise

**Behavior:**

- Automatically removes all passed thresholds
- Allows skipping intermediate checkpoints if the density drops rapidly
- Returns `False` once all checkpoints are exhausted
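The threshold bookkeeping described above can be sketched as a stand-alone function (a simplified illustration of the semantics, not the actual `DecCheckpoint.check` implementation):

```python
def check(thresholds: list[float], density: float) -> bool:
    """Return True if `density` has crossed any remaining threshold.

    Every threshold the density has passed is removed, so a rapid density
    drop consumes several thresholds in a single call.
    """
    crossed = [t for t in thresholds if density <= t]
    for t in crossed:
        thresholds.remove(t)
    return bool(crossed)
```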
Example:
checkpoint = DecCheckpoint(
...,
checkpt_steps=[0.9, 0.5, 0.2, 0.05],
)
checkpoint.check(0.85) # False (0.85 > 0.9)
checkpoint.check(0.60) # True (crossed 0.9)
checkpoint.check(0.45) # True (crossed 0.5)
checkpoint.check(0.18) # True (crossed 0.2, skipped)
checkpoint.check(0.03) # True (crossed 0.05)
checkpoint.check(0.02) # False (all exhausted)
##### `save(params, mask, chains, log_weights, density: float)`

Save model state with density-labeled filenames.

```python
checkpoint.save(
    params=params,
    mask=mask,
    chains=chains,
    log_weights=log_weights,
    density=0.35,
)
```

**Parameters:**

- `density` (`float`): Current density (appended to filenames)
- Other parameters: Same as the base `Checkpoint.save()`

**Output Files:**

Creates two files with density labels:

- `{stem}_density_{density:.3f}.dat`: Parameters file
- `{stem}_density_{density:.3f}.fasta`: Chains file
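As a quick illustration of this naming pattern using `pathlib` (a stand-alone snippet, not the package's actual code):

```python
from pathlib import Path

def density_labelled(path: Path, density: float) -> Path:
    # Insert "_density_{density:.3f}" between the stem and the suffix
    return path.with_name(f"{path.stem}_density_{density:.3f}{path.suffix}")

density_labelled(Path("output/params.dat"), 0.587)
# -> output/params_density_0.587.dat
```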
**Example:**

```python
checkpoint.save(
    params=params,
    mask=mask,
    chains=chains,
    log_weights=log_weights,
    density=0.587,
)

# Creates:
# - output/params_density_0.587.dat
# - output/chains_density_0.587.fasta
```
## Complete Usage Example

```python
import torch
from pathlib import Path

from adabmDCA.dataset import DatasetDCA
from adabmDCA.utils import init_chains, init_parameters
from adabmDCA.sampling import get_sampler
from highentDCA.models.edDCA import fit
from highentDCA.checkpoint import DecCheckpoint

# Configuration
device = torch.device("cuda")
dtype = torch.float32

# Load dataset
dataset = DatasetDCA(
    path_data="data/PF00072.fasta",
    alphabet="protein",
    device=device,
    dtype=dtype,
)

# Initialize model
params = init_parameters(L=dataset.L, q=dataset.q, device=device, dtype=dtype)
chains = init_chains(10000, dataset.L, dataset.q, device=device, dtype=dtype)
log_weights = torch.zeros(10000, device=device, dtype=dtype)
sampler = get_sampler("gibbs")

# Create output directory
output_dir = Path("results/PF00072")
output_dir.mkdir(parents=True, exist_ok=True)

# Configure checkpoint with custom densities
checkpoint = DecCheckpoint(
    file_paths={
        "log": output_dir / "training.log",
        "params": output_dir / "params.dat",
        "chains": output_dir / "chains.fasta",
    },
    tokens=dataset.tokens,
    args={
        "model": "edDCA",
        "data": "data/PF00072.fasta",
        "alphabet": "protein",
        "lr": 0.01,
        "nsweeps": 10,
        "nsweeps_dec": 10,
        "target": 0.95,
        "density": 0.02,
        "drate": 0.01,
        "pseudocount": None,
        "dtype": "float32",
        "label": "PF00072_edDCA",
        "nepochs": 50000,
        "sampler": "gibbs",
        "nchains": 10000,
        "seed": 42,
    },
    checkpt_steps=[0.98, 0.80, 0.60, 0.40, 0.20, 0.10, 0.05, 0.02],
    target_density=0.02,
    use_wandb=False,
)

# Initialize log header
checkpoint.header_log()

# Train model with checkpoint
mask = torch.ones(
    dataset.L, dataset.q, dataset.L, dataset.q,
    device=device, dtype=torch.bool,
)

fit(
    sampler=sampler,
    chains=chains,
    log_weights=log_weights,
    fi_target=dataset.fi,
    fij_target=dataset.fij,
    params=params,
    mask=mask,
    lr=0.01,
    nsweeps=10,
    target_pearson=0.95,
    target_density=0.02,
    drate=0.01,
    checkpoint=checkpoint,
    args=checkpoint.args,
)

print("Training complete!")
print(f"Results saved in: {output_dir}")
```
## Weights & Biases Integration

Enable experiment tracking with W&B:

```python
checkpoint = DecCheckpoint(
    ...,
    use_wandb=True,  # Enable W&B
)

# Metrics automatically logged to W&B:
# - Epochs
# - Pearson correlation
# - Entropy
# - Coupling density
# - Training time
```

**Prerequisites:**

```bash
pip install wandb
wandb login
```
## See Also

- Training Module: Training functions using checkpoints
- edDCA Model: Model fitting with checkpoints
- Usage Guide: CLI checkpoint options