module training
function update_params
update_params(
    fi: Tensor,
    fij: Tensor,
    pi: Tensor,
    pij: Tensor,
    params: Dict[str, Tensor],
    mask: Tensor,
    lr: float
) -> Dict[str, Tensor]
Updates the parameters of the model.
Args:
fi (torch.Tensor): Single-point frequencies of the data.
fij (torch.Tensor): Two-point frequencies of the data.
pi (torch.Tensor): Single-point marginals of the model.
pij (torch.Tensor): Two-point marginals of the model.
params (Dict[str, torch.Tensor]): Parameters of the model.
mask (torch.Tensor): Mask of the interaction graph.
lr (float): Learning rate.
Returns:
Dict[str, torch.Tensor]: Updated parameters.
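For orientation, the arithmetic behind a parameter update of this kind can be sketched as follows. This is a minimal sketch under stated assumptions, not the library's implementation: it uses NumPy arrays instead of torch.Tensor so it runs without PyTorch, assumes the standard Boltzmann-learning gradient (data frequencies minus model marginals), and assumes hypothetical parameter keys "bias" (shape (L, q)) and "coupling_matrix" (shape (L, q, L, q)), with mask shaped like the couplings. The name update_params_sketch is also hypothetical.

```python
import numpy as np
from typing import Dict


def update_params_sketch(
    fi: np.ndarray,
    fij: np.ndarray,
    pi: np.ndarray,
    pij: np.ndarray,
    params: Dict[str, np.ndarray],
    mask: np.ndarray,
    lr: float,
) -> Dict[str, np.ndarray]:
    """One gradient-ascent step on the log-likelihood (illustrative sketch).

    Hypothetical conventions: params["bias"] has shape (L, q),
    params["coupling_matrix"] and mask have shape (L, q, L, q),
    and mask is 1 on edges of the interaction graph and 0 elsewhere.
    """
    # Assumed gradient: data statistics minus model statistics.
    grad_bias = fi - pi
    grad_coupling = fij - pij
    return {
        "bias": params["bias"] + lr * grad_bias,
        # Multiplying by the mask sets every coupling outside the graph to zero.
        "coupling_matrix": (params["coupling_matrix"] + lr * grad_coupling) * mask,
    }


# Toy usage: L = 2 sites, q = 2 states, couplings allowed only between different sites.
L, q = 2, 2
mask = np.ones((L, q, L, q))
for i in range(L):
    mask[i, :, i, :] = 0.0  # no self-interactions
params = {"bias": np.zeros((L, q)), "coupling_matrix": np.ones((L, q, L, q))}
fi, pi = np.full((L, q), 0.5), np.full((L, q), 0.4)
fij = pij = np.zeros((L, q, L, q))
new_params = update_params_sketch(fi, fij, pi, pij, params, mask, lr=0.1)
```

The sketch returns a new dictionary instead of modifying params in place; the documentation above does not say which behavior update_params follows.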
function train_graph
train_graph(
    sampler: Callable,
    chains: Tensor,
    mask: Tensor,
    fi: Tensor,
    fij: Tensor,
    params: Dict[str, Tensor],
    nsweeps: int,
    lr: float,
    max_epochs: int,
    target_pearson: float,
    fi_test: Tensor | None = None,
    fij_test: Tensor | None = None,
    checkpoint: Checkpoint | None = None,
    check_slope: bool = False,
    log_weights: Tensor | None = None,
    progress_bar: bool = True
) -> Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]
Trains the model on a given graph until the target Pearson correlation is reached or the maximum number of epochs is exceeded.
Args:
sampler (Callable): Sampling function.
chains (torch.Tensor): Markov chains simulated with the model.
mask (torch.Tensor): Mask encoding the sparse graph.
fi (torch.Tensor): Single-point frequencies of the data.
fij (torch.Tensor): Two-point frequencies of the data.
params (Dict[str, torch.Tensor]): Parameters of the model.
nsweeps (int): Number of sampling sweeps performed for each gradient estimation.
lr (float): Learning rate.
max_epochs (int): Maximum number of gradient updates to perform.
target_pearson (float): Target Pearson correlation coefficient.
fi_test (torch.Tensor | None, optional): Single-point frequencies of the test data. Defaults to None.
fij_test (torch.Tensor | None, optional): Two-point frequencies of the test data. Defaults to None.
checkpoint (Checkpoint | None, optional): Checkpoint object used to save the model. Defaults to None.
check_slope (bool, optional): Whether the convergence criterion also takes the slope into account. Defaults to False.
log_weights (torch.Tensor | None, optional): Log-weights used for the online computation of the log-likelihood. Defaults to None.
progress_bar (bool, optional): Whether to display a progress bar. Defaults to True.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains, updated parameters, log-weights for the log-likelihood computation, and a dictionary of recorded training metrics, each stored as a list of floats.
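The stopping rule described for train_graph (stop once target_pearson is reached, or after max_epochs updates) can be illustrated with a deliberately simplified loop. This toy version is an assumption-laden sketch, not the library's algorithm: it fits only the single-site fields of an independent-site model, draws fresh exact samples through a hypothetical sample_independent function in place of persistent Gibbs chains, measures the Pearson correlation between fi and the model marginals only, and records a hypothetical history with keys "epochs" and "pearson". Couplings, the mask, test-set statistics, checkpoints, the slope criterion and log_weights are all omitted, and the names train_sketch and site_frequencies are likewise hypothetical.

```python
import numpy as np
from typing import Callable, Dict, List, Tuple


def pearson(x: np.ndarray, y: np.ndarray) -> float:
    # Pearson correlation coefficient between two statistics, flattened to 1-D.
    return float(np.corrcoef(x.ravel(), y.ravel())[0, 1])


def sample_independent(rng: np.random.Generator, fields: np.ndarray, nchains: int) -> np.ndarray:
    # Hypothetical stand-in for `sampler`: draws every site independently from
    # softmax(fields[i]); exact for a fields-only model, so no Gibbs sweeps are needed.
    L, q = fields.shape
    probs = np.exp(fields - fields.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    chains = np.empty((nchains, L), dtype=int)
    for i in range(L):
        chains[:, i] = rng.choice(q, size=nchains, p=probs[i])
    return chains


def site_frequencies(chains: np.ndarray, q: int) -> np.ndarray:
    # Single-point marginals pi[i, a]: fraction of chains in state a at site i.
    return np.stack(
        [np.bincount(chains[:, i], minlength=q) / len(chains) for i in range(chains.shape[1])]
    )


def train_sketch(
    fi: np.ndarray,
    sampler: Callable[[np.random.Generator, np.ndarray, int], np.ndarray],
    lr: float,
    max_epochs: int,
    target_pearson: float,
    nchains: int = 5000,
    seed: int = 0,
) -> Tuple[np.ndarray, Dict[str, List[float]]]:
    rng = np.random.default_rng(seed)
    q = fi.shape[1]
    fields = np.zeros_like(fi)
    history: Dict[str, List[float]] = {"epochs": [], "pearson": []}
    for epoch in range(1, max_epochs + 1):
        chains = sampler(rng, fields, nchains)  # sample from the current model
        pi = site_frequencies(chains, q)        # estimate the model marginals
        r = pearson(fi, pi)                     # agreement between data and model
        history["epochs"].append(float(epoch))
        history["pearson"].append(r)
        if r >= target_pearson:                 # convergence criterion
            break
        fields += lr * (fi - pi)                # gradient-ascent step on the fields
    return fields, history


# Toy data: 3 sites, 4 states; each row of fi is a probability distribution.
fi = np.array([[0.7, 0.1, 0.1, 0.1], [0.1, 0.7, 0.1, 0.1], [0.4, 0.3, 0.2, 0.1]])
fields, history = train_sketch(fi, sample_independent, lr=0.5, max_epochs=200, target_pearson=0.95)
```

Checking the criterion before applying the update means the returned fields are the ones whose samples actually met the target; an unreachable target (anything above 1) makes the loop run for exactly max_epochs iterations.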
This file was automatically generated via lazydocs.