
module training


function update_params

update_params(
    fi: Tensor,
    fij: Tensor,
    pi: Tensor,
    pij: Tensor,
    params: Dict[str, Tensor],
    mask: Tensor,
    lr: float
) → Dict[str, Tensor]

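Updates the parameters of the model by one learning-rate-scaled step, computed from the mismatch between the data frequencies and the model marginals and restricted to the interaction graph encoded by mask.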

Args:

  • fi (torch.Tensor): Single-point frequencies of the data.
  • fij (torch.Tensor): Two-point frequencies of the data.
  • pi (torch.Tensor): Single-point marginals of the model.
  • pij (torch.Tensor): Two-point marginals of the model.
  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • mask (torch.Tensor): Mask of the interaction graph.
  • lr (float): Learning rate.

Returns:

  • Dict[str, torch.Tensor]: Updated parameters.
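A minimal sketch of such an update, assuming the standard moment-matching gradient of the log-likelihood of a pairwise Potts (Boltzmann machine) model: the fields move by lr * (fi - pi) and the couplings by lr * (fij - pij), masked to the edges of the interaction graph. The function name update_params_sketch, the parameter keys "bias" and "coupling_matrix", and the shapes (L, q) and (L, q, L, q) are assumptions for illustration, not necessarily what this module uses.

```python
import torch


def update_params_sketch(fi, fij, pi, pij, params, mask, lr):
    """One gradient-ascent step on the log-likelihood of a pairwise Potts model.

    Hypothetical layout: params["bias"] has shape (L, q) and
    params["coupling_matrix"] has shape (L, q, L, q), like mask.
    """
    with torch.no_grad():
        # Field gradient: data single-point frequencies minus model marginals.
        params["bias"] += lr * (fi - pi)
        # Coupling gradient, zeroed on pairs that are not edges of the graph.
        params["coupling_matrix"] += lr * (fij - pij) * mask
    return params
```

Because the mask multiplies the coupling step, couplings outside the graph receive no update and keep their current value (zero, if they were initialized at zero).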

function train_graph

train_graph(
    sampler: Callable,
    chains: Tensor,
    mask: Tensor,
    fi: Tensor,
    fij: Tensor,
    params: Dict[str, Tensor],
    nsweeps: int,
    lr: float,
    max_epochs: int,
    target_pearson: float,
    fi_test: Tensor | None = None,
    fij_test: Tensor | None = None,
    checkpoint: Checkpoint | None = None,
    check_slope: bool = False,
    log_weights: Tensor | None = None,
    progress_bar: bool = True
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]

Trains the model on a given graph until the target Pearson correlation is reached or the maximum number of epochs is exceeded.

Args:

  • sampler (Callable): Sampling function.
  • chains (torch.Tensor): Markov chains simulated with the model.
  • mask (torch.Tensor): Mask encoding the sparse graph.
  • fi (torch.Tensor): Single-point frequencies of the data.
  • fij (torch.Tensor): Two-point frequencies of the data.
  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • nsweeps (int): Number of Gibbs sweeps run on the chains before each gradient estimation.
  • lr (float): Learning rate.
  • max_epochs (int): Maximum number of gradient updates to perform.
  • target_pearson (float): Target Pearson correlation coefficient at which training stops.
  • fi_test (torch.Tensor | None, optional): Single-point frequencies of the test data. Defaults to None.
  • fij_test (torch.Tensor | None, optional): Two-point frequencies of the test data. Defaults to None.
  • checkpoint (Checkpoint | None, optional): Checkpoint object used to save the model. Defaults to None.
  • check_slope (bool, optional): Whether to also include the slope in the convergence criterion. Defaults to False.
  • log_weights (torch.Tensor | None, optional): Log-weights used for the online computation of the log-likelihood. Defaults to None.
  • progress_bar (bool, optional): Whether to display a progress bar or not. Defaults to True.
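The sampler and nsweeps arguments can be illustrated with a hypothetical Gibbs sampler over one-hot encoded chains, together with a helper that computes the model marginals pi and pij consumed by update_params. The names gibbs_sampler and model_marginals, the chain layout (N, L, q), the parameter keys "bias" and "coupling_matrix", and the assumption of zero self-couplings are all illustrative; the sampler actually passed to train_graph may differ.

```python
import torch
import torch.nn.functional as F


def gibbs_sampler(chains: torch.Tensor, params: dict, nsweeps: int) -> torch.Tensor:
    """Run nsweeps Gibbs sweeps over one-hot Potts chains of shape (N, L, q).

    Hypothetical stand-in for the sampler Callable. Assumes params["bias"] has
    shape (L, q), params["coupling_matrix"] has shape (L, q, L, q), and the
    self-couplings J_ii are zero.
    """
    _, L, q = chains.shape
    for _ in range(nsweeps):
        # One sweep resamples every site once, in random order.
        for i in torch.randperm(L).tolist():
            # Local field on site i for each chain: h_i(a) + sum_{j,b} J_ij(a, b) x_j(b)
            logits = params["bias"][i] + torch.einsum(
                "ajb,njb->na", params["coupling_matrix"][i], chains
            )
            # Draw site i from its conditional distribution given all other sites.
            states = torch.distributions.Categorical(logits=logits).sample()
            chains[:, i, :] = F.one_hot(states, num_classes=q).to(chains.dtype)
    return chains


def model_marginals(chains: torch.Tensor):
    """Empirical single-point (L, q) and two-point (L, q, L, q) marginals of the chains."""
    n = chains.shape[0]
    pi = chains.mean(dim=0)
    pij = torch.einsum("nia,njb->iajb", chains, chains) / n
    return pi, pij
```

Updating one site at a time from its exact conditional keeps the chains' stationary distribution equal to the model distribution, so after enough sweeps pi and pij estimate the model marginals.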

Returns:

  • Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains, updated parameters, log-weights for the log-likelihood computation, and a dictionary of lists of floats recording the training history.
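The stopping rule of train_graph can be sketched as follows, assuming, as is common in DCA tooling, that the Pearson coefficient compares data and model connected correlations C_ij(a, b) = f_ij(a, b) - f_i(a) f_j(b) over distinct site pairs, and that the slope is the least-squares slope of model against data correlations (ideally 1). The helper names connected_correlations, pearson_and_slope, and has_converged, and the tolerance slope_tol, are hypothetical.

```python
import torch


def connected_correlations(fi: torch.Tensor, fij: torch.Tensor) -> torch.Tensor:
    """Flattened C_ij(a, b) = f_ij(a, b) - f_i(a) f_j(b) over site pairs i < j."""
    L = fi.shape[0]
    cij = fij - torch.einsum("ia,jb->iajb", fi, fi)
    rows, cols = torch.triu_indices(L, L, offset=1)
    # Index tensors separated by a slice: the result has shape (n_pairs, q, q).
    return cij[rows, :, cols, :].flatten()


def pearson_and_slope(x: torch.Tensor, y: torch.Tensor):
    """Pearson coefficient and least-squares slope of y (model) against x (data)."""
    xc, yc = x - x.mean(), y - y.mean()
    cov = (xc * yc).sum()
    pearson = cov / torch.sqrt((xc ** 2).sum() * (yc ** 2).sum())
    slope = cov / (xc ** 2).sum()
    return pearson.item(), slope.item()


def has_converged(pearson, slope, target_pearson, check_slope=False, slope_tol=0.1):
    """Stop once the Pearson target is met and, optionally, the slope is close to 1."""
    if pearson < target_pearson:
        return False
    # slope_tol is a hypothetical tolerance chosen only for this illustration.
    return not check_slope or abs(slope - 1.0) < slope_tol
```

Requiring a slope close to 1 on top of a high Pearson coefficient rejects models whose correlations are proportional to the data but systematically over- or under-estimated.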

This file was automatically generated via lazydocs.