pymdp.inference
State inference and smoothing utilities for modern JAX-based pymdp agents.
This module provides:
- one-step posterior updates (fpi, exact, ovf),
- sequence-based inference (mmp, vmp),
- backward smoothing utilities for transition/posterior learning.
All public functions operate on JAX arrays and pytrees and are designed to
work with batched agent execution (vmap) and fixed-window sequence buffers.
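To make the batching remark concrete, the following is a minimal sketch in plain JAX, not the pymdp implementation. It lifts a single-factor, single-modality Bayesian update over a batch of agents with `jax.vmap`. The helper name `bayes_update` and all array shapes are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def bayes_update(A, obs, prior):
    """One-step posterior for a single agent: q(s) ∝ p(o | s) p(s).

    Assumed shapes (illustrative only):
      A:     (num_obs, num_states), each column a distribution over observations
      obs:   (num_obs,) one-hot or soft observation vector
      prior: (num_states,) prior over hidden states
    """
    likelihood = obs @ A              # likelihood of the observation per state
    unnorm = likelihood * prior       # unnormalized posterior
    return unnorm / unnorm.sum()      # normalize to a distribution


# Two agents share one likelihood but see different observations.
A = jnp.array([[0.9, 0.2],
               [0.1, 0.8]])
obs_batch = jnp.array([[1.0, 0.0],
                       [0.0, 1.0]])
prior_batch = jnp.array([[0.5, 0.5],
                         [0.5, 0.5]])

# in_axes=(None, 0, 0): broadcast A, map over the leading batch axis of the rest.
batched_update = jax.vmap(bayes_update, in_axes=(None, 0, 0))
qs_batch = batched_update(A, obs_batch, prior_batch)  # shape (2, 2)
```

The same pattern applies to any function written for a single agent: the per-agent code stays free of batch indices, and `in_axes` decides which arguments are shared.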
update_posterior_states(A: list[Array], B: list[Array] | None, obs: list[Array], past_actions: Array | None, prior: list[Array] | None = None, qs_hist: list[Array] | None = None, A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, num_iter: int = 16, method: str = 'fpi', distr_obs: bool = True, inference_horizon: int | None = None, valid_steps: int | Array | None = None, return_info: bool = False) -> list[Array] | tuple[list[Array], VFEInfo]
Infer posterior beliefs over hidden states from observations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `list[Array]` | Observation likelihood tensors per modality. | required |
| `B` | `list[Array] \| None` | Transition model tensors per hidden-state factor. For one-step methods, this can be provided unchanged. For sequence methods and non- | required |
| `obs` | `pytree` | Observation sequence or single-step observation in distributional form (for example, one-hot vectors per modality). | required |
| `past_actions` | `Array \| None` | Action history with shape | required |
| `prior` | `list[Array]` | Prior beliefs over hidden states. Required when | `None` |
| `qs_hist` | `list[Array]` | Existing posterior history buffer. If provided, one-step updates append to this history. | `None` |
| `A_dependencies` | `list[list[int]]` | Sparse modality-to-factor dependency mapping. | `None` |
| `B_dependencies` | `list[list[int]]` | Sparse transition-factor dependency mapping. | `None` |
| `num_iter` | `int` | Number of variational update iterations. | `16` |
| `method` | (fpi, ovf, mmp, vmp, exact) | Inference routine to execute. | `"fpi"` |
| `distr_obs` | `bool` | Whether observations are already distributional. | `True` |
| `inference_horizon` | `int \| None` | Optional truncation horizon for sequence inference. | `None` |
| `valid_steps` | `int \| Array \| None` | Number of valid (unpadded) timesteps for fixed-window sequence inputs. | `None` |
| `return_info` | `bool` | If | `False` |
Returns:
| Type | Description |
|---|---|
| `list[Array]` or `tuple[list[Array], VFEInfo]` | Posterior state beliefs, with shape semantics depending on |
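As a rough illustration of what an iterative one-step update with `num_iter` iterations can look like, here is a minimal mean-field fixed-point sketch for two hidden-state factors and one modality. It is written in plain JAX under stated assumptions and should not be read as the library's `fpi` routine. The function name `fpi_two_factors`, the likelihood layout `(num_obs, n1, n2)`, and the small `eps` constant are all hypothetical.

```python
import jax
import jax.numpy as jnp


def fpi_two_factors(A, obs, prior1, prior2, num_iter=16):
    """Mean-field fixed-point iteration for two hidden-state factors.

    Assumed shapes (illustrative only):
      A:      (num_obs, n1, n2) likelihood p(o | s1, s2)
      obs:    (num_obs,) distributional observation
      prior1: (n1,), prior2: (n2,)
    """
    eps = 1e-16  # guards log(0)
    # Log-likelihood of the observation for every state combination: (n1, n2).
    log_lik = jnp.einsum("o,oij->ij", obs, jnp.log(A + eps))
    log_p1 = jnp.log(prior1 + eps)
    log_p2 = jnp.log(prior2 + eps)

    def step(carry, _):
        q1, q2 = carry
        # Each factor is updated with the other factor's beliefs averaged out.
        q1 = jax.nn.softmax(log_lik @ q2 + log_p1)
        q2 = jax.nn.softmax(q1 @ log_lik + log_p2)
        return (q1, q2), None

    (q1, q2), _ = jax.lax.scan(step, (prior1, prior2), None, length=num_iter)
    return q1, q2


# Toy model in which the observation depends only on the first factor.
A1 = jnp.array([[0.9, 0.2],
                [0.1, 0.8]])
A = jnp.broadcast_to(A1[:, :, None], (2, 2, 3))
obs = jnp.array([1.0, 0.0])
prior1 = jnp.array([0.5, 0.5])
prior2 = jnp.array([0.2, 0.3, 0.5])

q1, q2 = fpi_two_factors(A, obs, prior1, prior2)
```

In this toy model the observation carries no information about the second factor, so its posterior stays at the prior while the first factor follows Bayes' rule.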
joint_dist_factor(b: ArrayLike, filtered_qs: list[Array], actions: ArrayLike | None = None) -> tuple[Array, Array]
Compute smoothed marginals and pairwise joints for one factor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `b` | `Array` | Either an action-conditioned transition sequence | required |
| `filtered_qs` | `list[Array]` | Filtered posterior sequence for this factor with leading time axis. | required |
| `actions` | `Array \| None` | Optional action sequence to select transitions from | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, Array]` | |
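For readers who want to see how smoothed marginals and pairwise joints relate, below is a minimal sketch of a standard HMM backward pass over filtered beliefs, written in plain JAX. It is an assumption-laden illustration and may differ from what `joint_dist_factor` does internally. The name `backward_smooth`, the transition layout `(T-1, ns, ns)` with `b[t, j, i] = p(s_{t+1}=j | s_t=i)`, and the stacked `(T, ns)` filtered array are all hypothetical.

```python
import jax
import jax.numpy as jnp


def backward_smooth(b, filtered_qs):
    """Backward smoothing for one factor.

    Assumed shapes (illustrative only):
      b:           (T-1, ns, ns), b[t, j, i] = p(s_{t+1}=j | s_t=i)
      filtered_qs: (T, ns) filtered posteriors with leading time axis
    Returns smoothed marginals (T, ns) and pairwise joints (T-1, ns, ns),
    where joints[t, j, i] = p(s_{t+1}=j, s_t=i | all observations).
    """
    def step(next_smoothed, inputs):
        b_t, filt_t = inputs
        pred = b_t @ filt_t                                # one-step prediction
        ratio = next_smoothed / jnp.maximum(pred, 1e-16)   # avoid division by zero
        joint = ratio[:, None] * b_t * filt_t[None, :]     # axes: (next, current)
        smoothed = joint.sum(axis=0)                       # sum out the next state
        return smoothed, (smoothed, joint)

    last = filtered_qs[-1]  # the final filtered belief is already smoothed
    _, (smoothed, joints) = jax.lax.scan(
        step, last, (b, filtered_qs[:-1]), reverse=True
    )
    marginals = jnp.concatenate([smoothed, last[None]], axis=0)
    return marginals, joints


b_single = jnp.array([[0.7, 0.4],
                      [0.3, 0.6]])
b = jnp.stack([b_single, b_single])          # T-1 = 2 transitions
filtered_qs = jnp.array([[0.5, 0.5],
                         [0.8, 0.2],
                         [0.3, 0.7]])

marginals, joints = backward_smooth(b, filtered_qs)
```

A useful sanity check on any such routine is that each pairwise joint sums to the smoothed marginals of both timesteps it connects.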
smoothing_ovf(filtered_post: list[Array], B: list[Array], past_actions: Array | None) -> tuple[list[Array], list[Array]]
Run backward smoothing for factorized online variational filtering history.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filtered_post` | `list[Array]` | Filtering posteriors per factor, each with shape | required |
| `B` | `list[Array]` | Transition tensors per factor. | required |
| `past_actions` | `Array` | Action history with shape | required |
Returns:
| Type | Description |
|---|---|
| `tuple[list[Array], list[Array]]` | |
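Factorized smoothing needs, for each factor, the transition that was actually taken at each step. The sketch below shows one way to pick those out of action-conditioned tensors in plain JAX. The layouts are assumptions made for this example only and are not confirmed by this page: each `B[f]` has shape `(ns_f, ns_f, nu_f)` and `past_actions` has shape `(T-1, num_factors)`. The helper `select_transitions` is hypothetical.

```python
import jax.numpy as jnp


def select_transitions(B_f, actions_f):
    """Pick one transition matrix per timestep for a single factor.

    Assumed shapes (illustrative only):
      B_f:       (ns, ns, nu) action-conditioned transitions
      actions_f: (T-1,) integer action indices for this factor
    Returns an array of shape (T-1, ns, ns).
    """
    picked = jnp.take(B_f, actions_f, axis=-1)   # (ns, ns, T-1)
    return jnp.moveaxis(picked, -1, 0)           # time axis first


# Factor 0: two states, actions "stay" and "flip".
flip = jnp.array([[0.0, 1.0],
                  [1.0, 0.0]])
B0 = jnp.stack([jnp.eye(2), flip], axis=-1)              # (2, 2, 2)

# Factor 1: three states, actions "stay" and "cyclic shift".
shift = jnp.roll(jnp.eye(3), 1, axis=0)
B1 = jnp.stack([jnp.eye(3), shift], axis=-1)             # (3, 3, 2)

B = [B0, B1]
past_actions = jnp.array([[0, 1],
                          [1, 0]])                       # (T-1, num_factors)

# One transition sequence per factor, each handled independently.
b_seq = [select_transitions(B_f, past_actions[:, f]) for f, B_f in enumerate(B)]
```

Each entry of `b_seq` can then be smoothed on its own, which is what makes a factorized backward pass cheap compared with smoothing over the joint state space.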
smoothing_exact(filtered_post: list[Array], B: list[Array], past_actions: Array | None) -> tuple[list[Array], list[Array]]
Exact single-factor HMM backward smoothing from online filtering history.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filtered_post` | `list[Array]` | List containing one | required |
| `B` | `list[Array]` | List containing one transition tensor in pymdp column-stochastic orientation. | required |
| `past_actions` | `Array \| None` | | required |
Returns:
| Type | Description |
|---|---|
| (marginals, joints): | marginals: list with one |
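The column-stochastic orientation mentioned for `B` can be illustrated with a small sketch. It assumes the usual reading, in which entry `[j, i]` is `p(next=j | current=i)`, so every column sums to one and a belief vector is propagated by a matrix-vector product. The helper `is_column_stochastic` is hypothetical and not part of pymdp.

```python
import jax.numpy as jnp


def is_column_stochastic(M, atol=1e-6):
    """True if M is non-negative and every column sums to one."""
    return bool(jnp.all(M >= 0) & jnp.allclose(M.sum(axis=0), 1.0, atol=atol))


# B[j, i] = p(next state = j | current state = i), for a single fixed action.
B_single = jnp.array([[0.9, 0.3],
                      [0.1, 0.7]])

q_now = jnp.array([1.0, 0.0])    # certain of being in state 0
q_next = B_single @ q_now        # predicted belief over the next state

col_ok = is_column_stochastic(B_single)
# Transposing gives the row-stochastic convention, which is generally
# not column-stochastic.
row_as_col_ok = is_column_stochastic(B_single.T)
```

Mixing up the two conventions usually shows up as predicted beliefs that no longer sum to one, so a check like this is a cheap guard before running a smoother.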