pymdp.algos
Core variational-inference and exact-HMM algorithm implementations.
This module contains lower-level routines used by the higher-level inference and control APIs. Public entry points include fixed-point iteration, message passing over sequences, and exact single-factor scan-based HMM smoothing.
run_factorized_fpi(A: list[Array], obs: list[Array], prior: list[Array], A_dependencies: list[list[int]], num_iter: int = 1, distr_obs: bool = True) -> list[Array]
Run the fixed-point iteration algorithm with sparse dependencies between hidden-state factors and observation modalities (stored in A_dependencies).
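As a rough illustration of what a fixed-point iteration of this kind computes, the sketch below runs mean-field updates for a single modality that depends on two hidden-state factors (the situation `A_dependencies = [[0, 1]]` would describe). The helper `mean_field_fpi` is hypothetical and is not the library routine; the initialisation at the prior and the sequential update order are assumptions made for brevity.

```python
import jax.numpy as jnp
from jax import nn


def mean_field_fpi(A, obs, prior, num_iter=8):
    """Hypothetical helper: mean-field fixed-point updates, one modality, two factors.

    A     : (n_obs, K0, K1) likelihood tensor p(o | s0, s1)
    obs   : (n_obs,) distribution over observations (one-hot for a hard observation)
    prior : [ (K0,), (K1,) ] prior beliefs per factor
    """
    eps = 1e-16
    # Log-likelihood of the (distributional) observation for every state pair.
    log_lik = jnp.log(jnp.einsum("o,oij->ij", obs, A) + eps)  # (K0, K1)
    log_p0 = jnp.log(prior[0] + eps)
    log_p1 = jnp.log(prior[1] + eps)
    q0, q1 = prior  # assumption: start the iteration at the prior
    for _ in range(num_iter):
        # Each factor sees the log-likelihood averaged under the other factor's belief.
        q0 = nn.softmax(log_p0 + log_lik @ q1)
        q1 = nn.softmax(log_p1 + q0 @ log_lik)
    return [q0, q1]


# Toy model in which the observation only depends on factor 0.
A0 = jnp.array([[0.9, 0.2], [0.1, 0.8]])        # (n_obs, K0)
A = jnp.repeat(A0[:, :, None], 3, axis=2)       # (n_obs, K0, K1)
obs = jnp.array([1.0, 0.0])
prior = [jnp.array([0.5, 0.5]), jnp.array([0.2, 0.3, 0.5])]
qs = mean_field_fpi(A, obs, prior)
```

Because the toy likelihood ignores factor 1, the belief over factor 1 stays at its prior and the belief over factor 0 matches the exact Bayesian posterior, which makes the result easy to check by hand.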
run_mmp(A: list[Array], B: list[Array] | None, obs: list[Array], prior: list[Array], A_dependencies: list[list[int]], B_dependencies: list[list[int]], num_iter: int = 1, tau: float = 1.0, distr_obs: bool = True, obs_valid_mask: Array | None = None, transition_valid_mask: Array | None = None) -> list[Array]
Run marginal message passing over a sequence window.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `list[Array]` | Model likelihood tensors. | required |
| `B` | `list[Array] \| None` | Transition tensors (or | required |
| `obs` | `list[Array]` | Observation sequence per modality. | required |
| `prior` | `list[Array]` | Sequence prior over hidden states. | required |
| `A_dependencies` | `list[list[int]]` | Sparse observation dependencies per modality. | required |
| `B_dependencies` | `list[list[int]]` | Sparse transition dependencies per factor. | required |
| `num_iter` | `int` | Number of variational update iterations. | `1` |
| `tau` | `float` | Mirror-descent step size. | `1.0` |
| `distr_obs` | `bool` | Whether observations are already distributional. | `True` |
| `obs_valid_mask` | `Array \| None` | Optional validity mask for padded observation windows. | `None` |
| `transition_valid_mask` | `Array \| None` | Optional validity mask for transitions in padded windows. | `None` |
Returns:
| Type | Description |
|---|---|
| `list[Array]` | Sequence posterior beliefs per hidden-state factor. |
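The `tau` parameter is described above only as a mirror-descent step size. One common reading, assumed here rather than confirmed by this page, is a damped update in log space: the log-belief moves a fraction `tau` of the way towards the target log-message and is then renormalised. The helper `mirror_descent_step` below is hypothetical and only illustrates that idea.

```python
import jax.numpy as jnp
from jax import nn


def mirror_descent_step(ln_qs, ln_target, tau=1.0):
    """Hypothetical helper: one damped log-space update followed by normalisation."""
    ln_qs = ln_qs + tau * (ln_target - ln_qs)
    return nn.softmax(ln_qs, axis=-1)


q_old = jnp.array([0.5, 0.5])
target = jnp.array([0.9, 0.1])

# tau = 1.0 jumps straight to the (normalised) target.
q_full = mirror_descent_step(jnp.log(q_old), jnp.log(target), tau=1.0)
# tau = 0.5 lands on the normalised geometric mean of the old belief and the target.
q_half = mirror_descent_step(jnp.log(q_old), jnp.log(target), tau=0.5)
```

Under this reading, `tau=1.0` replaces the belief with the target on every iteration, while smaller values trade convergence speed for smoother updates.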
run_vmp(A: list[Array], B: list[Array] | None, obs: list[Array], prior: list[Array], A_dependencies: list[list[int]], B_dependencies: list[list[int]], num_iter: int = 1, tau: float = 1.0, distr_obs: bool = True, obs_valid_mask: Array | None = None, transition_valid_mask: Array | None = None) -> list[Array]
Run variational message passing over a sequence window.
Parameters are identical to those of `run_mmp`.
Returns:
| Type | Description |
|---|---|
| `list[Array]` | Sequence posterior beliefs per hidden-state factor (same structure as :func: |
run_exact_single_factor_hmm_scan(obs: list[Array], A: list[Array], B: list[Array], prior: list[Array], actions: Array | None = None, distr_obs: bool = True) -> tuple[Array, list[Array], list[Array], list[Array]]
pymdp-style single-factor wrapper around the column-stochastic scan smoother.
Notes
- `A`, `B`, and `prior` are expected as singleton lists (one hidden-state factor).
- `B[0]` must be in pymdp-native column-stochastic orientation: `(K_next, K_curr[, n_actions])` (no transpose required).
- Returns `(mll, qs, ps, qss)` with `qss[0]` equal to `p(z_t | z_{t+1}, x_{1:T})` in `(T-1, K_next, K_curr)` orientation.
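To make the column-stochastic orientation concrete, the sketch below builds a small action-conditioned transition tensor of shape `(K_next, K_curr, n_actions)` and selects one matrix per time step to obtain a `(T-1, K_next, K_curr)` stack. The integer layout of `actions` (one index per transition) is an assumption for illustration; this page only gives its type as `Array | None`.

```python
import jax.numpy as jnp

K, T = 3, 5

# Action 0 keeps the state; action 1 moves state i to state (i + 1) % K.
stay = jnp.eye(K)
shift = jnp.roll(jnp.eye(K), 1, axis=0)       # shift[j, i] = 1 when j == (i + 1) % K
B0 = jnp.stack([stay, shift], axis=-1)        # (K_next, K_curr, n_actions)

# Assumed layout: one integer action per transition, length T - 1.
actions = jnp.array([0, 1, 1, 0])
B_mats = jnp.moveaxis(B0[:, :, actions], -1, 0)   # (T-1, K_next, K_curr)

# Column-stochastic means each column (fixed current state) sums to one over next states.
column_sums = B_mats.sum(axis=1)              # (T-1, K_curr)

# Propagating a belief is a plain matrix-vector product, with no transpose.
z = jnp.array([1.0, 0.0, 0.0])
for t in range(T - 1):
    z = B_mats[t] @ z
```

Starting in state 0 and applying stay, shift, shift, stay ends in state 2, so `z` is one-hot at index 2.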
hmm_filter_scan_colstoch(initial_probs: Array, B_mats: Array, log_likelihoods: Array) -> tuple[Array, Array, Array]
Exact HMM filtering via lax.associative_scan for column-stochastic transitions.
This is the pymdp-native transition orientation:
B_mats[t, j, i] = p(z_{t+1}=j | z_t=i).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `initial_probs` | `Array` | Initial distribution | required |
| `B_mats` | `Array` | Column-stochastic transitions, stationary | required |
| `log_likelihoods` | `Array` | | required |
Returns:
| Type | Description |
|---|---|
| `(marginal_loglik, filtered_probs, predicted_probs)` | |
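The sketch below illustrates why HMM filtering fits `lax.associative_scan`: the forward recursion is a chain of matrix products, and matrix multiplication is associative. It compares a plain sequential `lax.scan` filter with an unnormalised associative-scan version. Both helpers are hypothetical, the shapes `(K,)`, `(T-1, K, K)` and `(T, K)` are assumptions, and so is the convention that `predicted_probs[t]` means `p(z_t | x_{1:t-1})`. The unnormalised version is only suitable for short sequences; a practical implementation would carry normalisers to avoid underflow.

```python
import jax
import jax.numpy as jnp
from jax import lax


def filter_sequential(initial_probs, B_mats, log_likelihoods):
    """Hypothetical reference filter using a plain lax.scan."""
    def step(carry, inputs):
        predicted, loglik = carry
        B_t, ll_t = inputs
        unnorm = predicted * jnp.exp(ll_t)     # p(z_t, x_t | x_{1:t-1})
        norm = unnorm.sum()
        filtered = unnorm / norm               # p(z_t | x_{1:t})
        next_pred = B_t @ filtered             # p(z_{t+1} | x_{1:t})
        return (next_pred, loglik + jnp.log(norm)), (filtered, predicted)

    K = initial_probs.shape[0]
    # Pad with a dummy transition so every step has one; its prediction is discarded.
    B_pad = jnp.concatenate([B_mats, jnp.eye(K)[None]], axis=0)
    init = (initial_probs, jnp.zeros(()))
    (_, mll), (filtered, predicted) = lax.scan(step, init, (B_pad, log_likelihoods))
    return mll, filtered, predicted


def filter_associative(initial_probs, B_mats, log_likelihoods):
    """Hypothetical unnormalised filter built on lax.associative_scan."""
    lik = jnp.exp(log_likelihoods)                    # (T, K)
    alpha0 = initial_probs * lik[0]                   # p(z_0, x_0)
    # M[t, j, i] = p(x_{t+1} | z_{t+1}=j) * p(z_{t+1}=j | z_t=i)
    M = lik[1:, :, None] * B_mats                     # (T-1, K, K)
    # Prefix products M[t] @ ... @ M[0]; later elements multiply from the left.
    prefix = lax.associative_scan(lambda a, b: jnp.matmul(b, a), M)
    alphas = jnp.concatenate([alpha0[None], prefix @ alpha0], axis=0)  # (T, K)
    filtered = alphas / alphas.sum(axis=-1, keepdims=True)
    return jnp.log(alphas[-1].sum()), filtered


key_pi, key_B, key_ll = jax.random.split(jax.random.PRNGKey(0), 3)
T, K = 5, 3
initial_probs = jax.nn.softmax(jax.random.normal(key_pi, (K,)))
B_mats = jax.nn.softmax(jax.random.normal(key_B, (T - 1, K, K)), axis=1)
log_likelihoods = jnp.log(jax.random.uniform(key_ll, (T, K), minval=0.1, maxval=1.0))

mll_seq, filt_seq, pred_seq = filter_sequential(initial_probs, B_mats, log_likelihoods)
mll_assoc, filt_assoc = filter_associative(initial_probs, B_mats, log_likelihoods)
```

The two routes agree up to floating-point error; the associative form is the one that can be evaluated in logarithmic depth on parallel hardware.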
hmm_smoother_scan_colstoch(initial_probs: Array, B_mats: Array, log_likelihoods: Array, return_trans_probs: bool = False) -> tuple[Any, ...]
Exact HMM filtering + smoothing via associative scans for column-stochastic transitions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `initial_probs` | `Array` | Initial distribution | required |
| `B_mats` | `Array` | Column-stochastic transitions, stationary | required |
| `log_likelihoods` | `Array` | | required |
| `return_trans_probs` | `bool` | If | `False` |
Returns:
| Type | Description |
|---|---|
| `tuple` | If |
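For orientation, the sketch below shows exact smoothing on a tiny HMM with plain Python loops rather than associative scans. It forms the backward conditionals `p(z_t | z_{t+1}, x_{1:t})` in `(T-1, K_next, K_curr)` orientation (by the Markov property these equal `p(z_t | z_{t+1}, x_{1:T})`), propagates the smoothed marginals backwards, and checks the result against brute-force enumeration over all state paths. All helper names are hypothetical, and the example assumes strictly positive transitions so the normalisation is well defined.

```python
import itertools
import jax.numpy as jnp


def forward_filter(pi, B_mats, lik):
    """Hypothetical helper: loop-based filter returning p(z_t | x_{1:t}), shape (T, K)."""
    filtered, pred = [], pi
    for t in range(lik.shape[0]):
        f = pred * lik[t]
        f = f / f.sum()
        filtered.append(f)
        if t < B_mats.shape[0]:
            pred = B_mats[t] @ f          # column-stochastic: no transpose
    return jnp.stack(filtered)


def backward_smooth(filtered, B_mats):
    """Hypothetical helper: smoothed marginals plus backward conditionals."""
    # joint[t, j, i] is proportional to p(z_{t+1}=j | z_t=i) * p(z_t=i | x_{1:t})
    joint = B_mats * filtered[:-1, None, :]
    back = joint / joint.sum(axis=-1, keepdims=True)   # normalise over z_t
    smoothed = [filtered[-1]]
    for t in range(B_mats.shape[0] - 1, -1, -1):
        smoothed.insert(0, back[t].T @ smoothed[0])
    return jnp.stack(smoothed), back


def brute_force_marginals(pi, B_mats, lik):
    """Enumerate every state path; only feasible for tiny T and K."""
    T, K = lik.shape
    post = [[0.0] * K for _ in range(T)]
    total = 0.0
    for path in itertools.product(range(K), repeat=T):
        p = float(pi[path[0]] * lik[0, path[0]])
        for t in range(1, T):
            p *= float(B_mats[t - 1, path[t], path[t - 1]] * lik[t, path[t]])
        total += p
        for t, z in enumerate(path):
            post[t][z] += p
    return jnp.array(post) / total


pi = jnp.array([0.6, 0.4])
B = jnp.array([[0.7, 0.2], [0.3, 0.8]])            # columns sum to one
B_mats = jnp.stack([B, B])                         # (T-1, K_next, K_curr)
lik = jnp.array([[0.9, 0.1], [0.2, 0.7], [0.5, 0.5]])   # p(x_t | z_t), shape (T, K)

filtered = forward_filter(pi, B_mats, lik)
smoothed, back = backward_smooth(filtered, B_mats)
reference = brute_force_marginals(pi, B_mats, lik)
```

The smoothed marginal at the final step equals the filtered one, and every row of `back[t]` is a distribution over the current state given the next state.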