pymdp.maths

calc_vfe(qs: list[ArrayLike], prior: list[ArrayLike], *, obs: list[ArrayLike] | None = None, A: list[ArrayLike] | None = None, B: list[ArrayLike] | None = None, past_actions: ArrayLike | None = None, A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, joint_qs: list[ArrayLike] | None = None, qA: list[ArrayLike] | None = None, pA: list[ArrayLike] | None = None, qB: list[ArrayLike] | None = None, pB: list[ArrayLike] | None = None, obs_valid_mask: ArrayLike | None = None, transition_valid_mask: ArrayLike | None = None, distr_obs: bool = True, return_decomposition: bool = False) -> tuple[ArrayLike, ArrayLike] | tuple[ArrayLike, ArrayLike, dict[str, ArrayLike]]

Compute canonical variational free energy from a model/posterior pair.

This function supports both:

- single-step posteriors, where each qs[f] has shape (num_states_f,), and
- sequence posteriors, where each qs[f] has shape (T, num_states_f).

In the sequence case, each transition contribution is assigned to the timestep at which the transition terminates. vfe_t[t] therefore contains the q(s_t) entropy, the observation accuracy for o_t, and either the initial-prior term (at sequence starts) or the transition-model cross-entropy for the step from t-1 to t.

Parameters:

Name Type Description Default
qs list[ArrayLike]

Posterior state marginals.

required
prior list[ArrayLike]

Prior over initial hidden states (or single-step empirical prior).

required
obs list[ArrayLike] | None

Observations in distributional or discrete-index form.

None
A list[ArrayLike] | None

Likelihood tensors.

None
B list[ArrayLike] | None

Transition tensors.

None
past_actions ArrayLike | None

Action history with shape (T-1, num_factors) for sequence VFE.

None
A_dependencies list[list[int]] | None

Sparse modality-to-factor dependency mapping.

None
B_dependencies list[list[int]] | None

Sparse transition-factor dependency mapping.

None
joint_qs list[ArrayLike] | None

Optional pairwise posterior beliefs for sequence models. Each joint_qs[f] should have shape (T-1, num_states[f], *[num_states[d] for d in B_dependencies[f]]). When provided, the state-dependent terms of sequence VFE are computed from the full smoothed chain posterior rather than the mean-field product of adjacent marginals.

None
qA list[ArrayLike] | None

Optional posterior Dirichlet parameters for the likelihood tensors A, paired with the prior pA. When the pair is provided, the corresponding KL term is added to the total vfe.

None
pA list[ArrayLike] | None

Optional prior Dirichlet parameters for the likelihood tensors A, paired with the posterior qA. When the pair is provided, the corresponding KL term is added to the total vfe.

None
qB list[ArrayLike] | None

Optional posterior Dirichlet parameters for the transition tensors B, paired with the prior pB. When the pair is provided, the corresponding KL term is added to the total vfe.

None
pB list[ArrayLike] | None

Optional prior Dirichlet parameters for the transition tensors B, paired with the posterior qB. When the pair is provided, the corresponding KL term is added to the total vfe.

None
obs_valid_mask ArrayLike | None

Validity mask for padded observation windows.

None
transition_valid_mask ArrayLike | None

Validity mask for transitions in padded sequence windows.

None
distr_obs bool

Whether observations are already categorical distributions.

True
return_decomposition bool

If True, also return a dictionary of component terms.

False

Returns:

Type Description
tuple

(vfe_t, vfe) by default. vfe_t has shape (T,) for sequences or scalar shape () for single-step posteriors. vfe is always scalar. When return_decomposition=True, a third element is returned with component arrays and optional parameter KL terms.
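
For the single-step case, the three terms described above (posterior entropy, prior term, observation accuracy) can be sketched in plain NumPy. This is an illustrative reconstruction for one state factor and one modality, not the library's implementation: the helper names `safe_log` and `single_step_vfe`, the clipping floor `EPS`, and the sign convention (negative entropy plus prior cross-entropy minus accuracy) are assumptions.

```python
import numpy as np

EPS = 1e-16  # assumed clipping floor; the library's constant may differ


def safe_log(x):
    """Logarithm with inputs clipped away from zero."""
    return np.log(np.clip(x, EPS, None))


def single_step_vfe(q, prior, A, obs_idx):
    """Hypothetical helper: VFE for one state factor and one modality."""
    neg_entropy = np.sum(q * safe_log(q))         # E_q[ln q(s)]
    prior_term = -np.sum(q * safe_log(prior))     # -E_q[ln p(s)]
    accuracy = np.sum(q * safe_log(A[obs_idx]))   # E_q[ln p(o | s)]
    return neg_entropy + prior_term - accuracy


A = np.array([[0.9, 0.2],
              [0.1, 0.8]])   # A[o, s] = p(o | s)
prior = np.array([0.5, 0.5])

# Exact Bayesian posterior after observing outcome 0.
joint = A[0] * prior
evidence = joint.sum()
q_exact = joint / evidence

vfe_exact = single_step_vfe(q_exact, prior, A, obs_idx=0)
vfe_prior = single_step_vfe(prior, prior, A, obs_idx=0)
```

Under these assumptions, `vfe_exact` equals the negative log evidence, and any other posterior (here the unchanged prior) gives a larger value.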

compute_accuracy(qs: list[ArrayLike], obs: list[ArrayLike], A: list[ArrayLike], A_dependencies: list[list[int]] | None = None, distr_obs: bool = True) -> ArrayLike

Compute the accuracy portion of variational free energy.

Parameters:

Name Type Description Default
qs list[ArrayLike]

Marginal state beliefs.

required
obs list[ArrayLike]

Observations for each modality.

required
A list[ArrayLike]

Likelihood tensors for each modality.

required
A_dependencies list[list[int]] | None

Sparse modality-to-factor dependency mapping.

None
distr_obs bool

Whether observations are already categorical distributions.

True

Returns:

Type Description
ArrayLike

Expected log-likelihood term.
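
As a rough illustration of what an expected log-likelihood looks like when one modality depends on two state factors, the sketch below contracts the log-likelihood of the observed outcome with each marginal in turn. The helper name `expected_log_likelihood` and the clipping floor are assumptions made for this example; the library function may differ in detail.

```python
import numpy as np


def expected_log_likelihood(qs, obs, A, eps=1e-16):
    """Hypothetical helper: sum over states of q(s1) q(s2) ln p(o | s1, s2)."""
    # Likelihood of the (distributional) observation for every state combination.
    lik = np.tensordot(obs, A, axes=(0, 0))          # shape (n1, n2)
    log_lik = np.log(np.clip(lik, eps, None))
    # Contract the leading state axis with each marginal, one factor at a time.
    for q in qs:
        log_lik = np.tensordot(q, log_lik, axes=(0, 0))
    return float(log_lik)


# A[o, s1, s2] with two outcomes and two states per factor.
A0 = np.array([[0.9, 0.5],
               [0.5, 0.1]])
A = np.stack([A0, 1.0 - A0])

qs = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]   # certain beliefs: s1=0, s2=1
obs = np.array([1.0, 0.0])                          # one-hot outcome 0

acc = expected_log_likelihood(qs, obs, A)
```

With certain beliefs, the result reduces to the log of the single likelihood entry `A[0, 0, 1]`.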

compute_log_likelihood(obs: list[ArrayLike], A: list[ArrayLike], distr_obs: bool = True) -> ArrayLike

Compute the log-likelihood over hidden states, combined across observations from all modalities.

Parameters:

Name Type Description Default
obs list[ArrayLike]

Observations for each modality.

required
A list[ArrayLike]

Likelihood tensors for each modality.

required
distr_obs bool

Interpret observations as distributions if True.

True

Returns:

Type Description
ArrayLike

Combined log-likelihood over hidden states.
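
Combining modalities in log space amounts to summing per-modality log-likelihoods, which corresponds to multiplying the likelihoods. The sketch below assumes discrete-index observations and that every modality is defined over the same state axes; `combined_log_likelihood` is a hypothetical name, not the library function.

```python
import numpy as np


def combined_log_likelihood(obs_idx, A, eps=1e-16):
    """Hypothetical helper: sum of per-modality log-likelihoods over states."""
    per_modality = [
        np.log(np.clip(A_m[o_m], eps, None))   # ln p(o_m | s) for modality m
        for o_m, A_m in zip(obs_idx, A)
    ]
    return np.sum(per_modality, axis=0)


A = [
    np.array([[0.9, 0.2],
              [0.1, 0.8]]),
    np.array([[0.6, 0.3],
              [0.4, 0.7]]),
]
obs_idx = [0, 1]

ll = combined_log_likelihood(obs_idx, A)
```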

compute_log_likelihood_per_modality(obs: list[ArrayLike], A: list[ArrayLike], distr_obs: bool = True) -> list[ArrayLike]

Compute the log-likelihood over hidden states separately for each modality.

Parameters:

Name Type Description Default
obs list[ArrayLike]

Observations for each modality.

required
A list[ArrayLike]

Likelihood tensors for each modality.

required
distr_obs bool

Interpret observations as distributions if True.

True

Returns:

Type Description
list[ArrayLike]

Per-modality log-likelihood tensors.

compute_log_likelihood_per_modality_end2end2_padded(obs_padded: ArrayLike, A_padded: ArrayLike, sparsity: str) -> ArrayLike

Compute padded end-to-end per-modality likelihood.

Parameters:

Name Type Description Default
obs_padded ArrayLike

Padded observations.

required
A_padded ArrayLike

Padded likelihood tensors.

required
sparsity str

If "ll_only", return only the dense log-likelihoods; otherwise use the sparse variant.

required

Returns:

Type Description
ArrayLike

Log-likelihood tensor.

compute_log_likelihood_single_modality(o_m: ArrayLike, A_m: ArrayLike, distr_obs: bool = True) -> ArrayLike

Compute observation log-likelihood for a single modality.

Parameters:

Name Type Description Default
o_m ArrayLike

Observation for one modality.

required
A_m ArrayLike

Likelihood tensor for one modality.

required
distr_obs bool

Interpret o_m as distribution if True, otherwise as discrete index.

True

Returns:

Type Description
ArrayLike

Log-likelihood for this modality.
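
The two observation formats controlled by distr_obs can be compared with a small NumPy sketch. It assumes that a distributional observation weights the outcome slices of A_m and that a discrete index selects one slice; the helper name and the clipping floor are illustrative only.

```python
import numpy as np


def log_likelihood_single_modality(o_m, A_m, distr_obs=True, eps=1e-16):
    """Hypothetical helper mirroring the two observation formats."""
    if distr_obs:
        # Weight each outcome slice of A_m by its probability and sum over outcomes.
        lik = np.tensordot(np.asarray(o_m, dtype=float), A_m, axes=(0, 0))
    else:
        # Select the slice for the observed outcome index.
        lik = A_m[o_m]
    return np.log(np.clip(lik, eps, None))


A_m = np.array([[0.7, 0.1],
                [0.2, 0.3],
                [0.1, 0.6]])   # 3 outcomes, 2 hidden states

ll_index = log_likelihood_single_modality(1, A_m, distr_obs=False)
ll_onehot = log_likelihood_single_modality([0.0, 1.0, 0.0], A_m, distr_obs=True)
```

A one-hot distribution and the matching discrete index give the same result, so the choice is mainly about how observations are stored upstream.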

compute_log_likelihoods_block_diag(A_big: ArrayLike, obs_big: ArrayLike, state_shapes: Sequence[tuple[int, ...]], cuts: Sequence[int], use_einsum: bool = False) -> list[ArrayLike]

Compute log-likelihoods using a block-diagonal approach.

Parameters:

Name Type Description Default
A_big ArrayLike

Block-diagonal likelihood matrix.

required
obs_big ArrayLike

Concatenated block observations.

required
state_shapes Sequence[tuple[int, ...]]

Shape of each modality block.

required
cuts Sequence[int]

Cumulative cut indices.

required
use_einsum bool

Use explicit einsum path if True.

False

Returns:

Type Description
list[ArrayLike]

Per-modality log-likelihood tensors.

compute_log_likelihoods_flat_block_diag(A_big: ArrayLike, obs_big: ArrayLike) -> ArrayLike

Compute flat log-likelihoods using block-diagonal multiplication.

Parameters:

Name Type Description Default
A_big ArrayLike

Block-diagonal likelihood matrix.

required
obs_big ArrayLike

Block-diagonal observations.

required

Returns:

Type Description
ArrayLike

Flat log-likelihoods.

compute_log_likelihoods_flat_block_diag_einsum(A_big: ArrayLike, obs_big: ArrayLike) -> ArrayLike

Compute flat log-likelihoods using block-diagonal einsum.

Parameters:

Name Type Description Default
A_big ArrayLike

Block-diagonal likelihood matrix.

required
obs_big ArrayLike

Block-diagonal observations.

required

Returns:

Type Description
ArrayLike

Flat log-likelihoods.

compute_log_likelihoods_padded(obs_padded: ArrayLike, A_padded: ArrayLike) -> ArrayLike

Compute padded log-likelihoods.

Parameters:

Name Type Description Default
obs_padded ArrayLike

Padded observations.

required
A_padded ArrayLike

Padded likelihood tensor.

required

Returns:

Type Description
ArrayLike

Log-likelihood over the padded input, computed with the numerically stable logarithm.

deconstruct_lls(lls_padded: ArrayLike, A_shapes: Sequence[tuple[int, ...]]) -> list[ArrayLike]

Split padded likelihood tensor into modality-specific blocks.

Parameters:

Name Type Description Default
lls_padded ArrayLike

Combined padded log-likelihood tensor.

required
A_shapes Sequence[tuple[int, ...]]

Original unpadded shapes per modality.

required

Returns:

Type Description
list[ArrayLike]

One tensor per modality.

deconstruct_log_likelihoods_block_diag(ll_flat: ArrayLike, state_shapes: Sequence[tuple[int, ...]], cuts: Sequence[int]) -> list[ArrayLike]

Split block-diagonal likelihood tensor into per-modality tensors.

Parameters:

Name Type Description Default
ll_flat ArrayLike

Flat log-likelihood tensor.

required
state_shapes Sequence[tuple[int, ...]]

Unwrapped state shapes per modality.

required
cuts Sequence[int]

Boundary indices for each modality block.

required

Returns:

Type Description
list[ArrayLike]

Reshaped per-modality likelihood tensors.
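
The block-diagonal idea behind these helpers can be illustrated as follows: all modalities are evaluated with a single matrix product, and the flat result is then split and reshaped per modality. The layout used here is an assumption made for the sketch (outcome rows, flattened state columns, and `cuts` as the interior cumulative column boundaries); the library's actual conventions may differ.

```python
import numpy as np

A_list = [
    np.array([[0.9, 0.2],
              [0.1, 0.8]]),                      # 2 outcomes, states (2,)
    np.array([[[0.5, 0.1], [0.2, 0.3]],
              [[0.3, 0.1], [0.2, 0.3]],
              [[0.2, 0.8], [0.6, 0.4]]]),        # 3 outcomes, states (2, 2)
]
obs_list = [np.array([1.0, 0.0]), np.array([0.0, 0.0, 1.0])]
state_shapes = [A.shape[1:] for A in A_list]

# Flatten the state axes and place each modality on the block diagonal.
blocks = [A.reshape(A.shape[0], -1) for A in A_list]
A_big = np.zeros((sum(b.shape[0] for b in blocks),
                  sum(b.shape[1] for b in blocks)))
row = col = 0
for b in blocks:
    A_big[row:row + b.shape[0], col:col + b.shape[1]] = b
    row += b.shape[0]
    col += b.shape[1]

obs_big = np.concatenate(obs_list)

# One product evaluates every modality; off-diagonal zeros keep them separate.
ll_flat = np.log(np.clip(obs_big @ A_big, 1e-16, None))

# Assumed meaning of cuts: interior cumulative column boundaries.
cuts = np.cumsum([b.shape[1] for b in blocks])[:-1]
lls = [piece.reshape(shape)
       for piece, shape in zip(np.split(ll_flat, cuts), state_shapes)]
```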

dirichlet_expected_value(dir_arr: ArrayLike, event_dim: int = 0) -> ArrayLike

Return the expected Categorical probabilities under Dirichlet distributions parameterised by dir_arr. The event (output) dimension of each Categorical distribution lies along the axis given by event_dim (default 0).

Parameters:

Name Type Description Default
dir_arr ArrayLike

Dirichlet parameters.

required
event_dim int

Event axis to normalize over.

0

Returns:

Type Description
ArrayLike

Expected Categorical probabilities.
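
The mean of a Dirichlet distribution is its concentration vector divided by the sum of the concentrations. A minimal NumPy sketch, normalising along the event axis, is shown below; `dirichlet_mean` is a hypothetical name, and the library function may apply additional clipping.

```python
import numpy as np


def dirichlet_mean(dir_arr, event_dim=0):
    """Hypothetical helper: normalise concentration parameters along event_dim."""
    dir_arr = np.asarray(dir_arr, dtype=float)
    return dir_arr / dir_arr.sum(axis=event_dim, keepdims=True)


# Each column holds the concentration parameters of one Categorical distribution.
pA = np.array([[1.0, 3.0],
               [3.0, 1.0]])

expected_A = dirichlet_mean(pA, event_dim=0)
```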

dirichlet_kl_divergence(q_dir: ArrayLike, p_dir: ArrayLike, event_dim: int = 0) -> ArrayLike

Compute KL divergence between two Dirichlet distributions.

Parameters:

Name Type Description Default
q_dir ArrayLike

Posterior Dirichlet concentration parameters.

required
p_dir ArrayLike

Prior Dirichlet concentration parameters.

required
event_dim int

Axis containing the categorical event dimension.

0

Returns:

Type Description
ArrayLike

Scalar KL divergence summed over all conditional contexts.
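
For reference, the closed-form KL divergence between two Dirichlet distributions can be sketched with the standard library's `math.lgamma`. The `digamma` helper below is a hand-rolled approximation (upward recurrence plus an asymptotic series) used only to keep the example dependency-free; it is not how the library computes the digamma function, and all function names here are hypothetical.

```python
import math
import numpy as np


def digamma(x):
    """Approximate digamma: shift x up to >= 6, then use an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result + math.log(x) - 0.5 / x - series


def dirichlet_kl_vector(q, p):
    """KL(Dir(q) || Dir(p)) for one pair of concentration vectors."""
    q0, p0 = sum(q), sum(p)
    kl = math.lgamma(q0) - math.lgamma(p0)
    for qi, pi in zip(q, p):
        kl += math.lgamma(pi) - math.lgamma(qi)
        kl += (qi - pi) * (digamma(qi) - digamma(q0))
    return kl


def dirichlet_kl(q_dir, p_dir, event_dim=0):
    """Sum the KL over every conditional context (every non-event index)."""
    q = np.moveaxis(np.asarray(q_dir, dtype=float), event_dim, 0)
    p = np.moveaxis(np.asarray(p_dir, dtype=float), event_dim, 0)
    q = q.reshape(q.shape[0], -1)
    p = p.reshape(p.shape[0], -1)
    return sum(dirichlet_kl_vector(q[:, j], p[:, j]) for j in range(q.shape[1]))


q_dir = np.array([[2.0, 1.0],
                  [1.0, 1.0]])   # columns: Dir(2, 1) and Dir(1, 1)
p_dir = np.ones((2, 2))          # flat prior Dir(1, 1) in both columns

kl_total = dirichlet_kl(q_dir, p_dir)
kl_same = dirichlet_kl(p_dir, p_dir)
```

In this example only the first column contributes, and its divergence has the closed form ln 2 - 1/2.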

factor_dot(M: JAXSparse, xs: list[ArrayLike], keep_dims: Optional[tuple[int]] = None) -> ArrayLike

Dot product of a sparse array with a list of factors.

Parameters:

Name Type Description Default
M JAXSparse

Sparse input tensor.

required
xs list[ArrayLike]

Factors to contract against M.

required
keep_dims Optional[tuple[int]]

Axes retained in the output.

None

Returns:

Type Description
ArrayLike

Contracted result.

factor_dot_flex(M: ArrayLike, xs: list[ArrayLike], dims: list[tuple[int]], keep_dims: Optional[Tuple[int]] = None) -> ArrayLike

Dot product of a multidimensional array M with the list of factors xs.

Parameters:

Name Type Description Default
M ArrayLike

Tensor to be contracted.

required
xs list[ArrayLike]

Factors to contract against M.

required
dims list[tuple[int]]

Axes in M aligned to each tensor in xs.

required
keep_dims Optional[Tuple[int]]

Axes to retain in the output even if listed in dims.

None

Returns:

Type Description
ArrayLike

Result of the contracted dot product.
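
The role of dims can be illustrated with a simplified contraction in which each factor is a vector aligned to exactly one axis of M. This sketch does not implement keep_dims or multi-axis factors, and the helper name `contract` is an assumption made for the example.

```python
import numpy as np


def contract(M, xs, dims):
    """Hypothetical helper: contract each vector in xs with its axis of M."""
    # Contract the highest axis first so the remaining axis indices stay valid.
    pairs = sorted(zip(xs, dims), key=lambda pair: pair[1][0], reverse=True)
    for x, (axis,) in pairs:
        M = np.tensordot(M, x, axes=(axis, 0))
    return M


M = np.arange(12.0).reshape(2, 2, 3)         # axes: (outcome, factor 0, factor 1)
xs = [np.array([0.25, 0.75]), np.array([0.2, 0.3, 0.5])]
dims = [(1,), (2,)]                          # xs[0] -> axis 1, xs[1] -> axis 2

result = contract(M, xs, dims)
reference = np.einsum("oij,i,j->o", M, xs[0], xs[1])
```

Axis 0 is not listed in dims, so it survives the contraction and the result has shape (2,).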

log_stable(x: ArrayLike) -> ArrayLike

Compute stable logarithm with minimum clipping.

Parameters:

Name Type Description Default
x ArrayLike

Input tensor.

required

Returns:

Type Description
ArrayLike

Log-transformed tensor with sparse support handled.
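
Clipping before taking the logarithm keeps exact zeros from producing negative infinity. A minimal NumPy sketch follows; the floor value `MINVAL` (machine epsilon here) and the name `log_clipped` are assumptions, and the library may use a different constant.

```python
import numpy as np

MINVAL = np.finfo(float).eps   # assumed floor for this sketch


def log_clipped(x):
    """Hypothetical helper: log of x after clipping values below MINVAL."""
    return np.log(np.clip(x, MINVAL, None))


vals = log_clipped(np.array([0.0, 0.5, 1.0]))
```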

log_stable_sparse(x: ArrayLike) -> ArrayLike

Compute numerically stable log for sparse or dense input.

Parameters:

Name Type Description Default
x ArrayLike

Input tensor.

required

Returns:

Type Description
ArrayLike

Elementwise log with sparse support preserved.

multidimensional_outer(arrs: list[ArrayLike]) -> ArrayLike

Compute the outer product of a list of arrays by iterative expansion.

Parameters:

Name Type Description Default
arrs list[ArrayLike]

List of arrays to combine.

required

Returns:

Type Description
ArrayLike

Outer product tensor.
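
Iterative expansion means appending a trailing axis to the running product and broadcasting against the next array. The sketch below assumes one-dimensional inputs; `outer_product` is a hypothetical name.

```python
import numpy as np


def outer_product(arrs):
    """Hypothetical helper: outer product of 1-D arrays by repeated broadcasting."""
    out = arrs[0]
    for a in arrs[1:]:
        # (n1, ..., nk) -> (n1, ..., nk, 1), then broadcast against (n_next,).
        out = np.expand_dims(out, -1) * a
    return out


qs = [np.array([0.2, 0.8]),
      np.array([0.5, 0.3, 0.2]),
      np.array([0.6, 0.4])]

joint = outer_product(qs)
```

When the inputs are factorised marginals, the result is the mean-field joint distribution over all factors.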

spm_dot_sparse(X: JAXSparse, x: list[ArrayLike], dims: Optional[list[tuple[int]]], keep_dims: Optional[list[tuple[int]]]) -> ArrayLike

Sparse contraction helper used by factor_dot.

Parameters:

Name Type Description Default
X JAXSparse

Sparse tensor to contract.

required
x list[ArrayLike]

Factors to contract against X.

required
dims Optional[list[tuple[int]]]

Input axes in X aligned to each entry in x.

required
keep_dims Optional[list[tuple[int]]]

Axes preserved in the output.

required

Returns:

Type Description
ArrayLike

Contraction result.

spm_wnorm(A: ArrayLike, exact_param_info_gain: bool = True) -> ArrayLike

Returns the weight matrix used in PyMDP's parameter information-gain term.

Historically, this was the heuristic 1/Σα − 1/α. If exact_param_info_gain is True, the function instead returns the exact weight matrix used in the information-gain computation, as defined in _exact_wnorm. The original function signature is kept so that the rest of the codebase remains unchanged.

Parameters:

Name Type Description Default
A ArrayLike

Dirichlet concentration-like array.

required
exact_param_info_gain bool

Choose exact (True) or legacy heuristic (False) form.

True

Returns:

Type Description
ArrayLike

Parameter information-gain weight matrix.
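
The legacy heuristic 1/Σα − 1/α mentioned above can be written in a few lines of NumPy. This sketch covers only the heuristic form, not the exact variant; summing over axis 0 and the small additive constant are assumptions, and `legacy_wnorm` is a hypothetical name.

```python
import numpy as np


def legacy_wnorm(alpha, eps=1e-16):
    """Hypothetical helper: 1 / sum(alpha) - 1 / alpha, summing over axis 0."""
    alpha = np.asarray(alpha, dtype=float) + eps   # guard against division by zero
    return 1.0 / alpha.sum(axis=0, keepdims=True) - 1.0 / alpha


alpha = np.array([[1.0, 2.0],
                  [3.0, 2.0]])   # each column sums to 4

w = legacy_wnorm(alpha)
```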

stable_cross_entropy(x: ArrayLike, y: ArrayLike) -> ArrayLike

Compute cross-entropy between two tensors.

Parameters:

Name Type Description Default
x ArrayLike

Source distribution tensor.

required
y ArrayLike

Target distribution tensor.

required

Returns:

Type Description
ArrayLike

Cross-entropy value.

stable_entropy(x: ArrayLike) -> ArrayLike

Compute entropy of a probability-like tensor.

Parameters:

Name Type Description Default
x ArrayLike

Tensor of probabilities.

required

Returns:

Type Description
ArrayLike

Entropy value.

stable_xlogx(x: ArrayLike) -> ArrayLike

Compute x log(x) with non-zero clipping.

Parameters:

Name Type Description Default
x ArrayLike

Input tensor.

required

Returns:

Type Description
ArrayLike

Elementwise x * log(x).
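
The three stable_* helpers are closely related: entropy is the negated sum of x log(x), and cross-entropy replaces log(x) with log(y). The NumPy sketch below assumes that sign convention and a full sum over all elements; the function names and clipping floor are illustrative, not the library's.

```python
import numpy as np

EPS = 1e-16  # assumed clipping floor


def xlogx(x):
    """Elementwise x * log(x), with x clipped inside the log so 0 * log(0) is 0."""
    return x * np.log(np.clip(x, EPS, None))


def entropy(x):
    """Shannon entropy: -sum x log x."""
    return -np.sum(xlogx(x))


def cross_entropy(x, y):
    """Cross-entropy: -sum x log y."""
    return -np.sum(x * np.log(np.clip(y, EPS, None)))


x = np.array([0.5, 0.5, 0.0])
y = np.array([0.25, 0.25, 0.5])

h = entropy(x)
ce = cross_entropy(x, y)
kl = ce - h   # KL(x || y) = cross-entropy minus entropy
```

The zero entry in `x` contributes nothing to either sum, which is the case the clipping is meant to handle.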