pymdp.inference

State inference and smoothing utilities for modern JAX-based pymdp agents.

This module provides:

- one-step posterior updates (fpi, exact, ovf)
- sequence-based inference (mmp, vmp)
- backward smoothing utilities for transition/posterior learning

All public functions operate on JAX arrays and pytrees and are designed to work with batched agent execution (vmap) and fixed-window sequence buffers.
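Fixed-window buffers keep array shapes constant. A function under `jax.jit` is retraced and recompiled whenever its input shapes change, so a constant window length avoids that cost. Below is a minimal sketch of a branch-free window update. It assumes valid entries occupy the first `valid_steps` rows and the remaining rows are padding. `push_to_window` is a hypothetical helper, not part of pymdp, and pymdp's actual padding layout may differ.

```python
import numpy as np

def push_to_window(buffer, valid_steps, item):
    """Hypothetical fixed-window update; buffer has shape (W, ...).

    Assumption: rows [0, valid_steps) hold data and the remaining rows are padding.
    Avoids Python branching on data so it maps onto jnp.where / .at[].set in JAX.
    """
    window = buffer.shape[0]
    full = valid_steps >= window
    # When the window is full, rotate the oldest row to the end so it gets overwritten.
    rolled = np.where(full, np.roll(buffer, -1, axis=0), buffer)
    slot = np.minimum(valid_steps, window - 1)
    updated = rolled.copy()
    updated[slot] = item  # in JAX: rolled.at[slot].set(item)
    return updated, np.minimum(valid_steps + 1, window)

# Window of 3 one-hot observations over 2 outcomes, initially all padding.
buf, valid = np.zeros((3, 2)), 0
for o in [0, 1, 1, 0]:
    buf, valid = push_to_window(buf, valid, np.eye(2)[o])
# buf now holds the last three observations (1, 1, 0) and valid == 3.
```

The returned count plays the role of `valid_steps`: it tells downstream code how many rows carry real data while the array shape stays fixed.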

update_posterior_states(A: list[Array], B: list[Array] | None, obs: list[Array], past_actions: Array | None, prior: list[Array] | None = None, qs_hist: list[Array] | None = None, A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, num_iter: int = 16, method: str = 'fpi', distr_obs: bool = True, inference_horizon: int | None = None, valid_steps: int | Array | None = None, return_info: bool = False) -> list[Array] | tuple[list[Array], VFEInfo]

Infer posterior beliefs over hidden states from observations.

Parameters:

- A (list[Array], required): Observation likelihood tensors, one per modality.
- B (list[Array] | None, required): Transition model tensors, one per hidden-state factor. One-step methods take these unchanged; for sequence methods with non-None past_actions, the transitions are conditioned on the action taken at each timestep.
- obs (list[Array], required): Observation sequence or single-step observation in distributional form (for example, one one-hot vector per modality).
- past_actions (Array | None, required): Action history with shape (T-1, num_factors) for sequence methods. Pass None when no valid history is available.
- prior (list[Array] | None, default None): Prior beliefs over hidden states. Required when return_info=True, so that the canonical VFE diagnostics can be computed.
- qs_hist (list[Array] | None, default None): Existing posterior history buffer. If provided, one-step updates append to it.
- A_dependencies (list[list[int]] | None, default None): Sparse modality-to-factor dependency mapping.
- B_dependencies (list[list[int]] | None, default None): Sparse transition-factor dependency mapping.
- num_iter (int, default 16): Number of variational update iterations.
- method (str, default "fpi"): Inference routine to execute; one of "fpi", "ovf", "mmp", "vmp", or "exact".
- distr_obs (bool, default True): Whether obs is already in distributional form.
- inference_horizon (int | None, default None): Optional truncation horizon for sequence inference.
- valid_steps (int | Array | None, default None): Number of valid (unpadded) timesteps in fixed-window sequence inputs.
- return_info (bool, default False): If True, also return an info dictionary (VFEInfo) with canonical VFE diagnostics: vfe_t, vfe, and the component terms. For the forward-only methods (fpi, ovf, exact), these diagnostics describe the current posterior or history returned by the inference call. To obtain a full smoothed-sequence VFE for ovf or exact, first run smoothing_ovf(...) or smoothing_exact(...), then call pymdp.maths.calc_vfe(..., joint_qs=...).
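To illustrate what num_iter controls, here is a minimal NumPy sketch of mean-field fixed-point iteration (coordinate ascent on a factorized posterior) for one modality and two hidden-state factors. It is a generic textbook version, not pymdp's fpi implementation; the name `fpi_two_factors` and the dense A layout `(num_obs, K1, K2)` are assumptions.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def fpi_two_factors(A, obs, prior, num_iter=16, eps=1e-16):
    """Hypothetical mean-field fixed-point iteration (not pymdp's fpi).

    A:     (num_obs, K1, K2) likelihood p(o | s1, s2)
    obs:   (num_obs,) distributional observation, e.g. one-hot
    prior: [p(s1) of shape (K1,), p(s2) of shape (K2,)]
    """
    # ln p(o | s1, s2) for the observed outcome, shape (K1, K2).
    log_lik = np.log(np.tensordot(obs, A, axes=(0, 0)) + eps)
    log_prior = [np.log(p + eps) for p in prior]
    qs = [p.copy() for p in prior]
    for _ in range(num_iter):
        # Each factor takes the expected log-likelihood under the other factor's belief.
        qs[0] = softmax(log_prior[0] + log_lik @ qs[1])
        qs[1] = softmax(log_prior[1] + qs[0] @ log_lik)
    return qs

# Factor 2 does not affect the outcome here, so q(s1) should equal the exact
# Bayesian posterior and q(s2) should stay at its prior.
L = np.array([[0.9, 0.2],
              [0.1, 0.8]])                     # p(o | s1)
A = np.repeat(L[:, :, None], 3, axis=2)       # (2, 2, 3), constant over s2
obs = np.array([1.0, 0.0])                     # observed outcome 0
prior = [np.array([0.5, 0.5]), np.array([0.2, 0.3, 0.5])]
q1, q2 = fpi_two_factors(A, obs, prior)
```

With coupled factors the updates interact, and more iterations let the two marginals settle toward a mutually consistent fixed point.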

Returns:

- list[Array] or tuple[list[Array], VFEInfo]: Posterior state beliefs, with shape semantics that depend on the method. One-step methods (fpi, ovf, exact) return posteriors with a time axis, appended to qs_hist when it is provided; sequence methods (mmp, vmp) return posteriors over the full sequence. If return_info=True, a second return value contains the canonical VFE diagnostics.
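As a point of reference for the VFE diagnostics, here is a sketch of the one-step, single-factor, single-modality case, F = KL[q(s) || p(s)] - E_q[ln p(o | s)] (complexity minus accuracy). The function name and dictionary keys are hypothetical and do not reproduce the VFEInfo layout. The check relies on the standard identity that F equals the negative log evidence -ln p(o) when q is the exact posterior, and is larger for any other q.

```python
import numpy as np

def one_step_vfe(q, prior, A, obs, eps=1e-16):
    """Hypothetical single-factor VFE: complexity minus accuracy."""
    lik = obs @ A                                   # p(o | s) for the observed outcome, shape (K,)
    complexity = np.sum(q * (np.log(q + eps) - np.log(prior + eps)))
    accuracy = np.sum(q * np.log(lik + eps))
    return {"vfe": complexity - accuracy, "complexity": complexity, "accuracy": accuracy}

A = np.array([[0.9, 0.2],
              [0.1, 0.8]])          # p(o | s), columns sum to 1
prior = np.array([0.3, 0.7])
obs = np.array([0.0, 1.0])          # observed outcome 1

evidence = obs @ A @ prior          # p(o)
q_exact = (obs @ A) * prior / evidence
info_exact = one_step_vfe(q_exact, prior, A, obs)
info_prior = one_step_vfe(prior, prior, A, obs)   # a worse approximate posterior
```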

joint_dist_factor(b: ArrayLike, filtered_qs: list[Array], actions: ArrayLike | None = None) -> tuple[Array, Array]

Compute smoothed marginals and pairwise joints for one factor.

Parameters:

- b (ArrayLike, required): Either an action-conditioned transition sequence of shape (T-1, K_next, K_curr) or a transition tensor with an action axis, of shape (..., K_next, K_curr, n_actions).
- filtered_qs (list[Array], required): Filtered posterior sequence for this factor, with a leading time axis.
- actions (ArrayLike | None, default None): Optional action sequence used to select transitions from b.

Returns:

- tuple[Array, Array]: (smoothed_marginals, pairwise_joints), where:
  - smoothed_marginals has shape (T, K)
  - pairwise_joints has shape (T-1, K_next, K_curr)
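A minimal NumPy sketch of one standard backward recursion that produces both outputs from filtered marginals, assuming column-stochastic transitions B[next, curr, action] and a 1-D actions array. `backward_smooth_factor` is a hypothetical name, and pymdp's internals may differ. Each step forms the one-step joint prediction, conditions s_t on s_{t+1}, and reweights by the already-smoothed belief at t+1.

```python
import numpy as np

def backward_smooth_factor(b, filtered, actions=None, eps=1e-16):
    """Hypothetical backward pass for one factor.

    b:        (T-1, K_next, K_curr) action-conditioned transitions, or
              (K_next, K_curr, n_actions) when `actions` is given
    filtered: (T, K) filtered marginals p(s_t | o_{0:t})
    """
    if actions is not None:
        # Pick B[:, :, a_t] for each step and move time to the front.
        b = np.moveaxis(b[:, :, actions], -1, 0)
    T, K = filtered.shape
    marginals = np.zeros((T, K))
    joints = np.zeros((T - 1, K, K))
    marginals[-1] = filtered[-1]                   # last filtered marginal is already smoothed
    for t in range(T - 2, -1, -1):
        joint_pred = b[t] * filtered[t][None, :]   # p(s_{t+1}, s_t | o_{0:t})
        pred_next = joint_pred.sum(axis=1)         # p(s_{t+1} | o_{0:t})
        joints[t] = joint_pred * (marginals[t + 1] / (pred_next + eps))[:, None]
        marginals[t] = joints[t].sum(axis=0)       # sum out s_{t+1}
    return marginals, joints

# Two states; action 0 keeps the state, action 1 swaps it.
B = np.stack([np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])], axis=-1)  # (2, 2, 2)
filtered = np.array([[0.6, 0.4], [0.3, 0.7], [0.1, 0.9]])
marginals, joints = backward_smooth_factor(B, filtered, actions=np.array([1, 0]))
```

Here the deterministic swap at the first step makes the smoothed belief at t=0 the mirror image of the final belief, so marginals[0] comes out as [0.9, 0.1].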

smoothing_ovf(filtered_post: list[Array], B: list[Array], past_actions: Array | None) -> tuple[list[Array], list[Array]]

Run backward smoothing for factorized online variational filtering history.

Parameters:

- filtered_post (list[Array], required): Filtering posteriors, one per factor, each with shape (T, K_f) (or the equivalent leading-time layout within a batch element).
- B (list[Array], required): Transition tensors, one per factor.
- past_actions (Array | None, required): Action history with shape (T-1, num_factors).

Returns:

- tuple[list[Array], list[Array]]: (marginals, joints), each a list with one entry per factor.
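With several factors, past_actions has one column per factor, so each factor's transition sequence is gathered from its own column before a per-factor backward pass. Below is a sketch of that gathering step. It assumes every B[f] has shape (K_f, K_f, n_actions_f) and depends only on its own factor (no B_dependencies). `select_factor_transitions` and `random_B` are hypothetical helpers. Because the outputs are per factor, they do not represent correlations between factors.

```python
import numpy as np

def select_factor_transitions(B, past_actions):
    """Hypothetical helper: one (T-1, K_next, K_curr) sequence per factor.

    B:            list of (K_f, K_f, n_actions_f) column-stochastic tensors
    past_actions: (T-1, num_factors) integer action history
    """
    past_actions = np.asarray(past_actions, dtype=int)
    return [np.moveaxis(B_f[:, :, past_actions[:, f]], -1, 0) for f, B_f in enumerate(B)]

rng = np.random.default_rng(0)

def random_B(k, n_actions):
    B_f = rng.random((k, k, n_actions))
    return B_f / B_f.sum(axis=0, keepdims=True)   # normalize over the next state

B = [random_B(2, 3), random_B(4, 1)]
past_actions = np.array([[0, 0], [2, 0], [1, 0]])   # T-1 = 3 steps, 2 factors
B_seqs = select_factor_transitions(B, past_actions)
```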

smoothing_exact(filtered_post: list[Array], B: list[Array], past_actions: Array | None) -> tuple[list[Array], list[Array]]

Exact single-factor HMM backward smoothing from online filtering history.

Parameters:

- filtered_post (list[Array], required): List containing one (T, K) array of filtering marginals.
- B (list[Array], required): List containing one transition tensor in pymdp's column-stochastic orientation.
- past_actions (Array | None, required): Action history of shape (T-1, 1) or (T-1,), used to select transitions.

Returns:

- tuple[list[Array], list[Array]]: (marginals, joints), where:
  - marginals is a list with one (T, K) smoothed marginal array
  - joints is a list with one (T-1, K_next, K_curr) pairwise posterior array
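For small problems, exact smoothing can be checked against brute-force enumeration of every state path. The sketch below builds that reference for a single-factor HMM, along with a forward filter that produces the filtered_post input. It assumes an observation matrix A of shape (num_obs, K), an initial prior D, integer observations, and B[next, curr, action]. All names are hypothetical, and neither function is part of pymdp. The cost grows as K**T, so this is only suitable as a test oracle.

```python
import itertools
import numpy as np

def forward_filter(A, B, D, obs, actions):
    """Filtered marginals p(s_t | o_{0:t}), shape (T, K)."""
    q = A[obs[0]] * D
    filtered = [q / q.sum()]
    for t in range(1, len(obs)):
        pred = B[:, :, actions[t - 1]] @ filtered[-1]   # p(s_t | o_{0:t-1})
        q = A[obs[t]] * pred
        filtered.append(q / q.sum())
    return np.array(filtered)

def brute_force_smoothing(A, B, D, obs, actions):
    """Exact marginals (T, K) and pairwise joints (T-1, K_next, K_curr) by enumeration."""
    T, K = len(obs), D.shape[0]
    marginals, joints = np.zeros((T, K)), np.zeros((T - 1, K, K))
    for path in itertools.product(range(K), repeat=T):
        # Unnormalized probability of this state path together with the observations.
        p = D[path[0]] * A[obs[0], path[0]]
        for t in range(1, T):
            p *= B[path[t], path[t - 1], actions[t - 1]] * A[obs[t], path[t]]
        for t in range(T):
            marginals[t, path[t]] += p
        for t in range(T - 1):
            joints[t, path[t + 1], path[t]] += p
    evidence = marginals[0].sum()   # p(o_{0:T-1})
    return marginals / evidence, joints / evidence

A = np.array([[0.8, 0.3], [0.2, 0.7]])
B = np.stack([np.array([[0.9, 0.2], [0.1, 0.8]]),
              np.array([[0.1, 0.7], [0.9, 0.3]])], axis=-1)
D = np.array([0.5, 0.5])
obs, actions = [0, 1, 1, 0], [1, 0, 1]

filtered = forward_filter(A, B, D, obs, actions)
marginals, joints = brute_force_smoothing(A, B, D, obs, actions)
```

The final filtered marginal always equals the final smoothed marginal, which gives a cheap sanity check. For an exact smoother, one would expect passing `filtered`, `[B]`, and `actions` to reproduce `marginals` and `joints` up to numerical tolerance.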