pymdp.learning

Dirichlet-parameter learning updates for JAX-based pymdp models.

update_obs_likelihood_dirichlet_m(pA_m: Array, obs_m: Array, qs: list[Array], dependencies_m: list[int], lr: float = 1.0) -> tuple[Array, Array]

Update one modality's Dirichlet parameters for the observation model.

Parameters:

Name Type Description Default
pA_m Array

Current Dirichlet concentration parameters for modality m.

required
obs_m Array

Observation sequence for modality m in categorical/distributional form with leading time axis.

required
qs list[Array]

Posterior beliefs over hidden-state factors.

required
dependencies_m list[int]

Hidden-state factors that modality m depends on.

required
lr float

Learning-rate multiplier for concentration updates.

1.0

Returns:

Type Description
tuple[Array, Array]

Updated concentration parameters and expected likelihood tensor for modality m.
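The following is a minimal sketch of the kind of count-based update this function describes, not the library's implementation. It assumes the update adds the time-summed outer product of observations and state beliefs to the prior concentrations, and that the expected likelihood is the concentration tensor normalized over the outcome axis (axis 0). The helper name dirichlet_obs_update_sketch and the toy shapes are hypothetical.

```python
import jax.numpy as jnp


def dirichlet_obs_update_sketch(pA_m, obs_m, qs_dep, lr=1.0):
    """Hypothetical helper (not a pymdp function): count-based Dirichlet update.

    pA_m:   (num_obs, S_1, ..., S_k) concentration parameters
    obs_m:  (T, num_obs) observation distributions, one row per time step
    qs_dep: list of (T, S_i) beliefs for the factors this modality depends on
    """
    counts = obs_m
    for q in qs_dep:
        # Reshape q to (T, 1, ..., 1, S_i) so it broadcasts onto a new trailing axis.
        q_b = q.reshape((q.shape[0],) + (1,) * (counts.ndim - 1) + (q.shape[1],))
        counts = counts[..., None] * q_b
    counts = counts.sum(axis=0)  # accumulate evidence over the time axis
    qA_m = pA_m + lr * counts  # assumed update rule: prior plus scaled counts
    E_qA_m = qA_m / qA_m.sum(axis=0, keepdims=True)  # Dirichlet mean per column
    return qA_m, E_qA_m


# Toy data: 2 outcomes, one factor with 2 states, 3 time steps, certain beliefs.
pA_m = jnp.ones((2, 2))
obs_m = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
qs_dep = [jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])]
qA_m, E_qA_m = dirichlet_obs_update_sketch(pA_m, obs_m, qs_dep)
```

With these toy inputs, outcome 0 co-occurs with state 0 twice and outcome 1 with state 1 once, so those two entries of the concentration tensor grow while the others keep their prior value.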

update_obs_likelihood_dirichlet(pA: list[Array | None], A: list[Array], obs: list[Array], qs: list[Array], *, A_dependencies: list[list[int]], categorical_obs: bool, num_obs: list[int], lr: float) -> tuple[list[Array | None], list[Array]]

Update Dirichlet parameters of the observation likelihood (A matrix) given observations and beliefs.

JAX version of pymdp.learning.update_obs_likelihood_dirichlet

Parameters:

Name Type Description Default
pA list[Array | None]

Prior Dirichlet concentration parameters for the A matrices, one entry per modality

required
A list[Array]

Current A matrices (observation likelihoods)

required
obs list[Array]

Observations (either discrete indices or categorical distributions depending on categorical_obs)

required
qs list[Array]

Posterior beliefs over hidden states

required
A_dependencies list[list[int]]

For each observation modality, the indices of the hidden-state factors it depends on

required
categorical_obs bool

If True, observations are probability distributions; if False, discrete indices

required
num_obs list[int]

Number of possible observation outcomes for each modality

required
lr float

Learning-rate multiplier for concentration updates

required

Returns:

Name Type Description
qA list[Array | None]

Updated Dirichlet parameters

E_qA list[Array]

Expected values (updated A matrices)
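The two observation formats selected by categorical_obs can be related with a one-hot encoding. The sketch below only illustrates that relationship using jax.nn.one_hot; whether the library performs exactly this conversion internally is an assumption, and the variable names and toy values are hypothetical.

```python
import jax.numpy as jnp
from jax import nn

# Two modalities with 3 and 2 possible outcomes respectively.
num_obs = [3, 2]

# Discrete-index form (categorical_obs=False): one integer per time step.
obs_idx = [jnp.array([0, 2, 1]), jnp.array([1, 1, 0])]

# Distributional form (categorical_obs=True): one probability vector per time step.
obs_dist = [nn.one_hot(o, n) for o, n in zip(obs_idx, num_obs)]
```

Each entry of obs_dist has a leading time axis and a trailing axis of length num_obs[m], with every row summing to one.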

update_state_transition_dirichlet_f(pB_f: Array, actions_f: Array, joint_qs_f: Array | list[Array], lr: float = 1.0) -> tuple[Array, Array]

Update one factor's Dirichlet parameters for the transition model.

Parameters:

Name Type Description Default
pB_f Array

Current Dirichlet concentration parameters for hidden-state factor f.

required
actions_f Array

One-hot action history for factor f with leading time axis.

required
joint_qs_f Array | list[Array]

Pairwise state beliefs (for example, q(s_t, s_{t-1})) used for transition learning.

required
lr float

Learning-rate multiplier for concentration updates.

1.0

Returns:

Type Description
tuple[Array, Array]

Updated concentration parameters and expected transition tensor.
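A minimal sketch of a count-based transition update follows, under stated assumptions rather than as the library's implementation. It assumes the factor depends only on its own previous state, that joint_qs_f is a single array of shape (T-1, S, S) holding q(s_t, s_{t-1}), and that the transition tensor is indexed [s_t, s_{t-1}, u]. The helper name dirichlet_transition_update_sketch is hypothetical.

```python
import jax.numpy as jnp


def dirichlet_transition_update_sketch(pB_f, actions_f, joint_qs_f, lr=1.0):
    """Hypothetical helper (not a pymdp function).

    pB_f:       (S, S, U) concentration parameters, indexed [s_t, s_{t-1}, u]
    actions_f:  (T-1, U) one-hot actions
    joint_qs_f: (T-1, S, S) pairwise beliefs q(s_t, s_{t-1})
    """
    # Sum over time of (pairwise belief) outer (one-hot action).
    counts = jnp.einsum("tij,tu->iju", joint_qs_f, actions_f)
    qB_f = pB_f + lr * counts  # assumed update rule: prior plus scaled counts
    E_qB_f = qB_f / qB_f.sum(axis=0, keepdims=True)  # normalize over next state
    return qB_f, E_qB_f


# Toy data: 2 states, 2 actions, 2 transitions observed with certainty.
pB_f = jnp.ones((2, 2, 2))
joint_qs_f = jnp.array(
    [
        [[0.0, 0.0], [1.0, 0.0]],  # s_{t-1}=0 -> s_t=1
        [[0.0, 1.0], [0.0, 0.0]],  # s_{t-1}=1 -> s_t=0
    ]
)
actions_f = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # action 0, then action 1
qB_f, E_qB_f = dirichlet_transition_update_sketch(pB_f, actions_f, joint_qs_f)
```

Only the two entries matching an observed (next state, previous state, action) triple gain a count; every other entry keeps its prior concentration.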

update_state_transition_dirichlet(pB: list[Array], B: list[Array], joint_beliefs: list[Array], actions: Array, *, num_controls: list[int], lr: float, factors_to_update: str | list[int] = 'all') -> tuple[list[Array], list[Array]]

Update posterior Dirichlet parameters of the state transition likelihood model (B) given the joint beliefs over hidden states and actions.

Supports learning only selected hidden-state factors via factors_to_update (either "all" or a list[int] of factor indices).

Parameters:

Name Type Description Default
pB list[Array]

Dirichlet concentration parameters for transition model factors.

required
B list[Array]

Current expected transition tensors.

required
joint_beliefs list[Array]

Time-aligned joint beliefs per factor for transition learning.

required
actions Array

Integer action history with shape (batch, T-1, num_factors).

required
num_controls list[int]

Number of control states per factor.

required
lr float

Learning-rate multiplier for concentration updates.

required
factors_to_update str | list[int]

Which hidden-state factors should be updated.

"all"

Returns:

Type Description
tuple[list[Array], list[Array]]

Updated concentration parameters and expected transition tensors.
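The sketch below illustrates two preparatory steps implied by the parameters above: resolving factors_to_update into a list of factor indices, and turning the integer action history into per-factor one-hot arrays sized by num_controls. It is an illustration under assumptions, not the library's code; resolve_factors_sketch and the toy values are hypothetical.

```python
import jax.numpy as jnp
from jax import nn


def resolve_factors_sketch(factors_to_update, num_factors):
    """Hypothetical helper: expand "all" into every factor index."""
    if factors_to_update == "all":
        return list(range(num_factors))
    return list(factors_to_update)


num_controls = [3, 2]

# Integer action history with shape (batch=1, T-1=2, num_factors=2).
actions = jnp.array([[[0, 1], [2, 0]]])

# Update only factor 1 in this toy example.
selected = resolve_factors_sketch([1], len(num_controls))

# One-hot actions per selected factor, shape (batch, T-1, num_controls[f]).
onehot_actions = {
    f: nn.one_hot(actions[..., f], num_controls[f]) for f in selected
}
```

Factors left out of the selection would simply keep their existing concentration parameters in a scheme like this.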