pymdp.maths
calc_vfe(qs: list[ArrayLike], prior: list[ArrayLike], *, obs: list[ArrayLike] | None = None, A: list[ArrayLike] | None = None, B: list[ArrayLike] | None = None, past_actions: ArrayLike | None = None, A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, joint_qs: list[ArrayLike] | None = None, qA: list[ArrayLike] | None = None, pA: list[ArrayLike] | None = None, qB: list[ArrayLike] | None = None, pB: list[ArrayLike] | None = None, obs_valid_mask: ArrayLike | None = None, transition_valid_mask: ArrayLike | None = None, distr_obs: bool = True, return_decomposition: bool = False) -> tuple[ArrayLike, ArrayLike] | tuple[ArrayLike, ArrayLike, dict[str, ArrayLike]]
Compute canonical variational free energy from a model/posterior pair.
This function supports both:
- single-step posteriors, where each qs[f] has shape (num_states_f,), and
- sequence posteriors, where each qs[f] has shape (T, num_states_f).
In the sequence case, transition contributions are assigned to the timestep
they terminate at, so vfe_t[t] contains the q(s_t) entropy, the
observation accuracy for o_t, and either the initial-prior term (at
sequence starts) or the transition-model cross-entropy from t-1 -> t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs` | `list[ArrayLike]` | Posterior state marginals. | required |
| `prior` | `list[ArrayLike]` | Prior over initial hidden states (or single-step empirical prior). | required |
| `obs` | `list[ArrayLike] \| None` | Observations in distributional or discrete-index form. | `None` |
| `A` | `list[ArrayLike] \| None` | Likelihood tensors. | `None` |
| `B` | `list[ArrayLike] \| None` | Transition tensors. | `None` |
| `past_actions` | `ArrayLike \| None` | Action history with shape | `None` |
| `A_dependencies` | `list[list[int]] \| None` | Sparse modality-to-factor dependency mapping. | `None` |
| `B_dependencies` | `list[list[int]] \| None` | Sparse transition-factor dependency mapping. | `None` |
| `joint_qs` | `list[ArrayLike] \| None` | Optional pairwise posterior beliefs for sequence models. Each | `None` |
| `qA` | `list[ArrayLike] \| None` | Optional posterior/prior Dirichlet parameter pairs. When provided, the corresponding KL terms are added to the total | `None` |
| `pA` | `list[ArrayLike] \| None` | Optional posterior/prior Dirichlet parameter pairs. When provided, the corresponding KL terms are added to the total | `None` |
| `qB` | `list[ArrayLike] \| None` | Optional posterior/prior Dirichlet parameter pairs. When provided, the corresponding KL terms are added to the total | `None` |
| `pB` | `list[ArrayLike] \| None` | Optional posterior/prior Dirichlet parameter pairs. When provided, the corresponding KL terms are added to the total | `None` |
| `obs_valid_mask` | `ArrayLike \| None` | Validity mask for padded observation windows. | `None` |
| `transition_valid_mask` | `ArrayLike \| None` | Validity mask for transitions in padded sequence windows. | `None` |
| `distr_obs` | `bool` | Whether observations are already categorical distributions. | `True` |
| `return_decomposition` | `bool` | If | `False` |

Returns:
| Type | Description |
|---|---|
| `tuple` | |
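The single-step case can be made concrete with a small NumPy sketch. It is not pymdp's implementation: the helper name `single_step_vfe`, the clipping floor `EPS`, and the restriction to one hidden-state factor, one modality, and a discrete-index observation are assumptions made for illustration. It evaluates F = -H[q(s)] - E_q[ln p(s)] - E_q[ln p(o|s)] and compares the value at the exact posterior with the value at a flat belief.

```python
import numpy as np

EPS = 1e-16  # assumed clipping floor, not necessarily the value pymdp uses


def safe_log(x):
    """Logarithm with a lower clip so that zeros do not produce -inf."""
    return np.log(np.clip(x, EPS, None))


def single_step_vfe(qs, prior, obs_idx, A):
    """Hypothetical helper: VFE for one factor and one modality."""
    neg_entropy = np.sum(qs * safe_log(qs))             # -H[q(s)]
    prior_cross_entropy = -np.sum(qs * safe_log(prior))  # -E_q[ln p(s)]
    accuracy = np.sum(qs * safe_log(A[obs_idx]))         # E_q[ln p(o | s)]
    return neg_entropy + prior_cross_entropy - accuracy


A = np.array([[0.9, 0.2],
              [0.1, 0.8]])   # p(o | s); each column sums to 1
prior = np.array([0.5, 0.5])
obs_idx = 0

# Exact Bayesian posterior for this observation.
joint = A[obs_idx] * prior
exact_q = joint / joint.sum()

vfe_exact = single_step_vfe(exact_q, prior, obs_idx, A)
vfe_flat = single_step_vfe(prior, prior, obs_idx, A)
```

At the exact posterior the free energy equals the negative log evidence, `-ln p(o)`. Any other belief gives a larger value, because the gap is the KL divergence from the exact posterior.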
compute_accuracy(qs: list[ArrayLike], obs: list[ArrayLike], A: list[ArrayLike], A_dependencies: list[list[int]] | None = None, distr_obs: bool = True) -> ArrayLike
Compute the accuracy portion of variational free energy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs` | `list[ArrayLike]` | Marginal state beliefs. | required |
| `obs` | `list[ArrayLike]` | Observations for each modality. | required |
| `A` | `list[ArrayLike]` | Likelihood tensors for each modality. | required |
| `A_dependencies` | `list[list[int]] \| None` | Sparse modality-to-factor dependency mapping. | `None` |
| `distr_obs` | `bool` | Whether observations are already categorical distributions. | `True` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Expected log-likelihood term. |
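As a rough illustration of what an expected log-likelihood term looks like, the NumPy sketch below handles a single modality that depends on two hidden-state factors. The helper `expected_log_likelihood`, the `eps` floor, and the assumption that the observation axis is the leading axis of `A` are illustrative choices, not pymdp's actual code path.

```python
import numpy as np


def expected_log_likelihood(qs, obs, A, eps=1e-16):
    """Hypothetical helper: E_q[ln p(o | s)] for one modality.

    A has shape (num_obs, num_states_0, num_states_1, ...) and obs is a
    categorical distribution (for example a one-hot vector) over outcomes.
    """
    lik = np.tensordot(obs, A, axes=(0, 0))        # p(o | s_0, s_1, ...)
    log_lik = np.log(np.clip(lik, eps, None))
    for q in qs:
        # Average over the leading remaining state axis under q.
        log_lik = np.tensordot(q, log_lik, axes=(0, 0))
    return float(log_lik)


rng = np.random.default_rng(0)
A = rng.random((2, 2, 3))
A /= A.sum(axis=0, keepdims=True)                  # normalise over outcomes

q1 = np.array([0.7, 0.3])
q2 = np.array([0.2, 0.5, 0.3])
obs = np.array([1.0, 0.0])                         # one-hot for outcome 0

accuracy = expected_log_likelihood([q1, q2], obs, A)
reference = float(np.einsum("i,j,ij->", q1, q2, np.log(A[0])))
```

The accuracy is never positive, since it is an average of log-probabilities.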
compute_log_likelihood(obs: list[ArrayLike], A: list[ArrayLike], distr_obs: bool = True) -> ArrayLike
Compute the log-likelihood over hidden states, combined across observations from different modalities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs` | `list[ArrayLike]` | Observations for each modality. | required |
| `A` | `list[ArrayLike]` | Likelihood tensors for each modality. | required |
| `distr_obs` | `bool` | Interpret observations as distributions if | `True` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Combined log-likelihood over hidden states. |
compute_log_likelihood_per_modality(obs: list[ArrayLike], A: list[ArrayLike], distr_obs: bool = True) -> list[ArrayLike]
Compute the log-likelihood over hidden states separately for each modality.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs` | `list[ArrayLike]` | Observations for each modality. | required |
| `A` | `list[ArrayLike]` | Likelihood tensors for each modality. | required |
| `distr_obs` | `bool` | Interpret observations as distributions if | `True` |

Returns:
| Type | Description |
|---|---|
| `list[ArrayLike]` | Per-modality log-likelihood tensors. |
compute_log_likelihood_per_modality_end2end2_padded(obs_padded: ArrayLike, A_padded: ArrayLike, sparsity: str) -> ArrayLike
Compute padded end-to-end per-modality likelihood.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs_padded` | `ArrayLike` | Padded observations. | required |
| `A_padded` | `ArrayLike` | Padded likelihood tensors. | required |
| `sparsity` | `str` | If | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Log-likelihood tensor. |
compute_log_likelihood_single_modality(o_m: ArrayLike, A_m: ArrayLike, distr_obs: bool = True) -> ArrayLike
Compute observation log-likelihood for a single modality.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `o_m` | `ArrayLike` | Observation for one modality. | required |
| `A_m` | `ArrayLike` | Likelihood tensor for one modality. | required |
| `distr_obs` | `bool` | Interpret | `True` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Log-likelihood for this modality. |
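The two observation formats (a categorical distribution versus a discrete index) can be illustrated with a short NumPy sketch. The helper `log_likelihood_single_modality`, the `eps` floor, and the convention that outcomes index the leading axis of `A_m` are assumptions for this example and may differ from what pymdp does internally.

```python
import numpy as np


def log_likelihood_single_modality(o_m, A_m, distr_obs=True, eps=1e-16):
    """Hypothetical helper: ln p(o_m | s) for one modality."""
    if distr_obs:
        # o_m is a distribution over outcomes: mix the rows of A_m.
        lik = np.tensordot(np.asarray(o_m, dtype=float), A_m, axes=(0, 0))
    else:
        # o_m is an integer outcome index: select the matching row.
        lik = A_m[int(o_m)]
    return np.log(np.clip(lik, eps, None))


A_m = np.array([[0.9, 0.2],
                [0.1, 0.8]])

ll_from_one_hot = log_likelihood_single_modality(np.array([0.0, 1.0]), A_m)
ll_from_index = log_likelihood_single_modality(1, A_m, distr_obs=False)
```

A one-hot distribution and the corresponding index give the same result; a soft distribution would instead blend rows of `A_m` before the logarithm is taken.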
compute_log_likelihoods_block_diag(A_big: ArrayLike, obs_big: ArrayLike, state_shapes: Sequence[tuple[int, ...]], cuts: Sequence[int], use_einsum: bool = False) -> list[ArrayLike]
Compute log-likelihoods using a block-diagonal approach.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A_big` | `ArrayLike` | Block-diagonal likelihood matrix. | required |
| `obs_big` | `ArrayLike` | Concatenated block observations. | required |
| `state_shapes` | `Sequence[tuple[int, ...]]` | Shape of each modality block. | required |
| `cuts` | `Sequence[int]` | Cumulative cut indices. | required |
| `use_einsum` | `bool` | Use explicit einsum path if | `False` |

Returns:
| Type | Description |
|---|---|
| `list[ArrayLike]` | Per-modality log-likelihood tensors. |
compute_log_likelihoods_flat_block_diag(A_big: ArrayLike, obs_big: ArrayLike) -> ArrayLike
Compute flat log-likelihoods using block-diagonal multiplication.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A_big` | `ArrayLike` | Block-diagonal likelihood matrix. | required |
| `obs_big` | `ArrayLike` | Block-diagonal observations. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Flat log-likelihoods. |
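The idea behind a block-diagonal likelihood can be sketched in NumPy as follows. The layout used here (one block per modality with outcomes as rows and flattened states as columns, and observations concatenated into a single vector) is an assumption made for illustration; the helper `block_diag` is hypothetical and pymdp's actual memory layout may differ.

```python
import numpy as np


def block_diag(mats):
    """Hypothetical helper: place 2-D matrices on the diagonal of a zero matrix."""
    rows = sum(m.shape[0] for m in mats)
    cols = sum(m.shape[1] for m in mats)
    out = np.zeros((rows, cols))
    r = c = 0
    for m in mats:
        out[r:r + m.shape[0], c:c + m.shape[1]] = m
        r += m.shape[0]
        c += m.shape[1]
    return out


A1 = np.array([[0.9, 0.2],
               [0.1, 0.8]])            # 2 outcomes x 2 states
A2 = np.array([[0.5, 0.1, 0.3],
               [0.5, 0.9, 0.7]])       # 2 outcomes x 3 states

A_big = block_diag([A1, A2])                         # shape (4, 5)
obs_big = np.concatenate([[1.0, 0.0], [0.0, 1.0]])   # one-hot per modality

# A single matrix product yields every modality's likelihood at once.
ll_flat = np.log(np.clip(obs_big @ A_big, 1e-16, None))
```

Because the off-diagonal blocks are zero, each modality's observation only touches its own likelihood block, so `ll_flat` is the concatenation of the per-modality log-likelihoods.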
compute_log_likelihoods_flat_block_diag_einsum(A_big: ArrayLike, obs_big: ArrayLike) -> ArrayLike
Compute flat log-likelihoods using block-diagonal einsum.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A_big` | `ArrayLike` | Block-diagonal likelihood matrix. | required |
| `obs_big` | `ArrayLike` | Block-diagonal observations. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Flat log-likelihoods. |
compute_log_likelihoods_padded(obs_padded: ArrayLike, A_padded: ArrayLike) -> ArrayLike
Compute padded log-likelihoods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obs_padded` | `ArrayLike` | Padded observations. | required |
| `A_padded` | `ArrayLike` | Padded likelihood tensor. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Log-stable likelihood over padded input. |
deconstruct_lls(lls_padded: ArrayLike, A_shapes: Sequence[tuple[int, ...]]) -> list[ArrayLike]
Split padded likelihood tensor into modality-specific blocks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `lls_padded` | `ArrayLike` | Combined padded log-likelihood tensor. | required |
| `A_shapes` | `Sequence[tuple[int, ...]]` | Original unpadded shapes per modality. | required |

Returns:
| Type | Description |
|---|---|
| `list[ArrayLike]` | One tensor per modality. |
deconstruct_log_likelihoods_block_diag(ll_flat: ArrayLike, state_shapes: Sequence[tuple[int, ...]], cuts: Sequence[int]) -> list[ArrayLike]
Split block-diagonal likelihood tensor into per-modality tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ll_flat` | `ArrayLike` | Flat log-likelihood tensor. | required |
| `state_shapes` | `Sequence[tuple[int, ...]]` | Unwrapped state shapes per modality. | required |
| `cuts` | `Sequence[int]` | Boundary indices for each modality block. | required |

Returns:
| Type | Description |
|---|---|
| `list[ArrayLike]` | Reshaped per-modality likelihood tensors. |
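Splitting a flat vector back into per-modality tensors might look like the sketch below. It assumes that `cuts` holds the cumulative end index of every block, including the last one; the documentation above does not state whether endpoints are included, so treat that convention and the helper name `split_flat_log_likelihood` as hypothetical.

```python
import numpy as np


def split_flat_log_likelihood(ll_flat, state_shapes, cuts):
    """Hypothetical helper: slice a flat vector at `cuts` and reshape each piece."""
    starts = [0] + list(cuts[:-1])
    return [
        ll_flat[start:end].reshape(shape)
        for start, end, shape in zip(starts, cuts, state_shapes)
    ]


state_shapes = [(2, 3), (4,)]
sizes = [int(np.prod(shape)) for shape in state_shapes]   # [6, 4]
cuts = [int(c) for c in np.cumsum(sizes)]                 # [6, 10]

ll_flat = np.arange(10, dtype=float)
blocks = split_flat_log_likelihood(ll_flat, state_shapes, cuts)
```

Each block recovers the state-space shape of its modality, so downstream code can treat the result like a list of ordinary per-modality tensors.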
dirichlet_expected_value(dir_arr: ArrayLike, event_dim: int = 0) -> ArrayLike
Returns the expected Categorical probabilities under a set of Dirichlet distributions whose event (output) dimension lies along the axis of the array given by event_dim (default 0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dir_arr` | `ArrayLike` | Dirichlet parameters. | required |
| `event_dim` | `int` | Event axis to normalize over. | `0` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Expected Categorical probabilities. |
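The mean of a Dirichlet distribution is its concentration vector divided by its sum, which is what normalizing over the event axis amounts to. A minimal NumPy sketch, with the helper name `dirichlet_mean` chosen for illustration:

```python
import numpy as np


def dirichlet_mean(dir_arr, event_dim=0):
    """Hypothetical helper: normalise concentration parameters along one axis."""
    dir_arr = np.asarray(dir_arr, dtype=float)
    return dir_arr / dir_arr.sum(axis=event_dim, keepdims=True)


# Two conditional contexts (columns), each with its own Dirichlet over 2 outcomes.
pA = np.array([[2.0, 1.0],
               [2.0, 3.0]])
expected_A = dirichlet_mean(pA, event_dim=0)
```

Each column of `expected_A` is a valid Categorical distribution over outcomes.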
dirichlet_kl_divergence(q_dir: ArrayLike, p_dir: ArrayLike, event_dim: int = 0) -> ArrayLike
Compute KL divergence between two Dirichlet distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `q_dir` | `ArrayLike` | Posterior Dirichlet concentration parameters. | required |
| `p_dir` | `ArrayLike` | Prior Dirichlet concentration parameters. | required |
| `event_dim` | `int` | Axis containing the categorical event dimension. | `0` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Scalar KL divergence summed over all conditional contexts. |
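For reference, the closed-form KL divergence between two Dirichlet distributions can be written out in a few lines. The sketch below uses only `math` and NumPy; the scalar `digamma` approximation (recurrence plus asymptotic series) and the helper name `dirichlet_kl` are assumptions made to keep the example dependency-free, not pymdp's implementation.

```python
import math

import numpy as np


def digamma(x):
    """Scalar digamma for x > 0 via recurrence and an asymptotic series."""
    result = 0.0
    while x < 6.0:            # shift upward: psi(x) = psi(x + 1) - 1 / x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result + math.log(x) - 0.5 / x - series


_lgamma = np.vectorize(math.lgamma)
_digamma = np.vectorize(digamma)


def dirichlet_kl(q_dir, p_dir, event_dim=0):
    """Hypothetical helper: KL[Dir(q) || Dir(p)], summed over all contexts."""
    q = np.asarray(q_dir, dtype=float)
    p = np.asarray(p_dir, dtype=float)
    q0 = q.sum(axis=event_dim, keepdims=True)
    p0 = p.sum(axis=event_dim, keepdims=True)
    log_norm = (_lgamma(q0) - _lgamma(p0)).sum() - (_lgamma(q) - _lgamma(p)).sum()
    expectation = ((q - p) * (_digamma(q) - _digamma(q0))).sum()
    return float(log_norm + expectation)


kl_same = dirichlet_kl([2.0, 1.0], [2.0, 1.0])
kl_diff = dirichlet_kl([2.0, 1.0], [1.0, 1.0])   # analytically ln(2) - 1/2
```

The divergence is zero when posterior and prior counts coincide and grows as the posterior accumulates evidence beyond the prior.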
factor_dot(M: JAXSparse, xs: list[ArrayLike], keep_dims: Optional[tuple[int]] = None) -> ArrayLike
Dot product of a sparse array with a list of factors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `M` | `JAXSparse` | Sparse input tensor. | required |
| `xs` | `list[ArrayLike]` | Factors to contract against | required |
| `keep_dims` | `Optional[tuple[int]]` | Axes retained in the output. | `None` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Contracted result. |
factor_dot_flex(M: ArrayLike, xs: list[ArrayLike], dims: list[tuple[int]], keep_dims: Optional[Tuple[int]] = None) -> ArrayLike
Dot product of a multidimensional array with a list of factors xs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `M` | `ArrayLike` | Tensor to be contracted. | required |
| `xs` | `list[ArrayLike]` | Factors to contract against | required |
| `dims` | `list[tuple[int]]` | Axes in | required |
| `keep_dims` | `Optional[Tuple[int]]` | Axes to retain in the output even if listed in | `None` |

Returns:
| Type | Description |
|---|---|
| `Array` | Result of the contracted dot product. |
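One way to picture this kind of contraction is with `numpy.einsum` in its integer-subscript form. The sketch below assumes that each entry of `dims` names the axes of `M` that the matching factor lines up with, and that an axis listed in `keep_dims` is multiplied but not summed out. Both the interpretation and the helper name `contract_factors` are assumptions for illustration.

```python
import numpy as np


def contract_factors(M, xs, dims, keep_dims=None):
    """Hypothetical helper: contract M against factors along the given axes."""
    keep = set(keep_dims or ())
    summed = {axis for dim in dims for axis in dim if axis not in keep}
    operands = [M, list(range(M.ndim))]
    for x, dim in zip(xs, dims):
        operands += [x, list(dim)]
    out_axes = [axis for axis in range(M.ndim) if axis not in summed]
    return np.einsum(*operands, out_axes)


rng = np.random.default_rng(1)
M = rng.random((2, 3, 4))
x1 = rng.random(3)
x2 = rng.random(4)

contracted = contract_factors(M, [x1, x2], dims=[(1,), (2,)])
kept = contract_factors(M, [x1, x2], dims=[(1,), (2,)], keep_dims=(1,))
```

With no `keep_dims`, both state axes are summed out and only axis 0 remains; keeping axis 1 leaves a `(2, 3)` result weighted by `x1` along that axis.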
log_stable(x: ArrayLike) -> ArrayLike
Compute stable logarithm with minimum clipping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ArrayLike` | Input tensor. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Log-transformed tensor with sparse support handled. |
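Clipping before taking the logarithm keeps zero probabilities from producing `-inf`. A minimal NumPy sketch of the idea follows; the floor `MINVAL` (machine epsilon here) and the helper name `log_with_floor` are assumptions, and the constant pymdp uses may differ.

```python
import numpy as np

MINVAL = np.finfo(float).eps  # assumed floor for this sketch


def log_with_floor(x, minval=MINVAL):
    """Hypothetical helper: elementwise log after clipping from below."""
    return np.log(np.clip(x, minval, None))


probs = np.array([0.0, 0.5, 1.0])
log_probs = log_with_floor(probs)
```

The zero entry maps to a large negative but finite number, so later sums and products stay free of `nan`.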
log_stable_sparse(x: ArrayLike) -> ArrayLike
Compute numerically stable log for sparse or dense input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ArrayLike` | Input tensor. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Elementwise log with sparse support preserved. |
multidimensional_outer(arrs: list[ArrayLike]) -> ArrayLike
Compute the outer product of a list of arrays by iterative expansion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `arrs` | `list[ArrayLike]` | List of arrays to combine. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Outer product tensor. |
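Iterative expansion means appending a trailing axis to the running product and broadcasting against the next array. A small NumPy sketch of that pattern, with `outer_product` as a hypothetical helper name and 1-D inputs assumed:

```python
import numpy as np


def outer_product(arrs):
    """Hypothetical helper: outer product of 1-D arrays by repeated broadcasting."""
    out = np.asarray(arrs[0], dtype=float)
    for a in arrs[1:]:
        # A new trailing axis on `out` broadcasts against the next vector.
        out = np.expand_dims(out, -1) * np.asarray(a, dtype=float)
    return out


q1 = np.array([0.7, 0.3])
q2 = np.array([0.2, 0.5, 0.3])
q3 = np.array([0.6, 0.4])

joint = outer_product([q1, q2, q3])
```

When the inputs are factorized marginals, the result is the joint distribution implied by treating the factors as independent.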
spm_dot_sparse(X: JAXSparse, x: list[ArrayLike], dims: Optional[list[tuple[int]]], keep_dims: Optional[list[tuple[int]]]) -> ArrayLike
Sparse contraction helper used by factor_dot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `JAXSparse` | Sparse tensor to contract. | required |
| `x` | `list[ArrayLike]` | Factors to contract against | required |
| `dims` | `Optional[list[tuple[int]]]` | Input axes in | required |
| `keep_dims` | `Optional[list[tuple[int]]]` | Axes preserved in the output. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Contraction result. |
spm_wnorm(A: ArrayLike, exact_param_info_gain: bool = True) -> ArrayLike
Returns the weight matrix used in PyMDP's parameter information-gain term.
Historically this was the heuristic 1/Σα − 1/α. If exact_param_info_gain is True, the function instead returns the exact weight matrix used in the information-gain computation, as defined in _exact_wnorm. The original function signature is kept so that the rest of the codebase remains unchanged.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `ArrayLike` | Dirichlet concentration-like array. | required |
| `exact_param_info_gain` | `bool` | Choose exact ( | `True` |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Parameter information-gain weight matrix. |
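The historical heuristic 1/Σα − 1/α can be written down directly. The sketch below assumes the sum runs over axis 0 (the outcome axis) and adds a small floor to avoid division by zero; both choices, along with the helper name `heuristic_wnorm`, are assumptions. The exact variant defined in `_exact_wnorm` is not reproduced here.

```python
import numpy as np


def heuristic_wnorm(alpha, eps=1e-16):
    """Hypothetical helper: 1 / sum(alpha) - 1 / alpha, summing over axis 0."""
    alpha = np.clip(np.asarray(alpha, dtype=float), eps, None)
    return 1.0 / alpha.sum(axis=0, keepdims=True) - 1.0 / alpha


alpha = np.array([[1.0, 2.0],
                  [3.0, 2.0]])
weights = heuristic_wnorm(alpha)
```

With strictly positive counts every entry is non-positive, because each individual count is at most the column total, and entries with small counts receive the largest magnitude.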
stable_cross_entropy(x: ArrayLike, y: ArrayLike) -> ArrayLike
Compute cross-entropy between two tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ArrayLike` | Source distribution tensor. | required |
| `y` | `ArrayLike` | Target distribution tensor. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Cross-entropy value. |
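A plain NumPy sketch of a clipped cross-entropy is shown below. It assumes the convention H(x, y) = -Σ x ln y, with `x` as the distribution the expectation is taken under; that argument order, the `eps` floor, and the helper name `cross_entropy` are assumptions rather than a statement of pymdp's behaviour.

```python
import numpy as np


def cross_entropy(x, y, eps=1e-16):
    """Hypothetical helper: -sum(x * log(y)) with y clipped away from zero."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(-np.sum(x * np.log(np.clip(y, eps, None))))


x = np.array([0.5, 0.5])
y = np.array([0.25, 0.75])

ce_xy = cross_entropy(x, y)
ce_xx = cross_entropy(x, x)   # reduces to the entropy of x, ln(2)
```

The cross-entropy is never smaller than the entropy of `x`; the difference is the KL divergence from `x` to `y`.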
stable_entropy(x: ArrayLike) -> ArrayLike
Compute entropy of a probability-like tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ArrayLike` | Tensor of probabilities. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Entropy value. |
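The following NumPy sketch shows the Shannon entropy with the same clipping trick, so that zero-probability entries contribute nothing instead of producing `nan`. The helper name `entropy` and the `eps` floor are assumptions for illustration.

```python
import numpy as np


def entropy(x, eps=1e-16):
    """Hypothetical helper: -sum(x * log(x)) with the log argument clipped."""
    x = np.asarray(x, dtype=float)
    return float(-np.sum(x * np.log(np.clip(x, eps, None))))


uniform_entropy = entropy([0.25, 0.25, 0.25, 0.25])   # maximal: ln(4)
one_hot_entropy = entropy([1.0, 0.0, 0.0, 0.0])       # minimal: 0
```

A uniform belief over four states has the largest possible entropy, while a fully certain belief has none.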
stable_xlogx(x: ArrayLike) -> ArrayLike
Compute x log(x) with non-zero clipping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ArrayLike` | Input tensor. | required |

Returns:
| Type | Description |
|---|---|
| `ArrayLike` | Elementwise |
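The term x log(x) is undefined at zero in floating point (`0 * -inf` is `nan`), which is why the argument of the logarithm is clipped first. A minimal NumPy sketch, with the helper name `xlogx` and the `eps` floor as assumptions:

```python
import numpy as np


def xlogx(x, eps=1e-16):
    """Hypothetical helper: x * log(x), with the log argument clipped from below."""
    x = np.asarray(x, dtype=float)
    return x * np.log(np.clip(x, eps, None))


values = xlogx(np.array([0.0, 0.5, 1.0]))
```

The result at zero is exactly zero, matching the limit of x log(x) as x approaches 0 from above.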