pymdp.learning
Dirichlet-parameter learning updates for modern JAX pymdp models.
update_obs_likelihood_dirichlet_m(pA_m: Array, obs_m: Array, qs: list[Array], dependencies_m: list[int], lr: float = 1.0) -> tuple[Array, Array]
Update one modality's Dirichlet parameters for the observation model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pA_m | Array | Current Dirichlet concentration parameters for modality | required |
| obs_m | Array | Observation sequence for modality | required |
| qs | list[Array] | Posterior beliefs over hidden-state factors. | required |
| dependencies_m | list[int] | Hidden-state factors that modality | required |
| lr | float | Learning-rate multiplier for concentration updates. | 1.0 |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Updated concentration parameters and expected likelihood tensor for modality |
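The kind of update described above can be illustrated with a minimal, self-contained sketch. It does not call pymdp and is not the library's implementation: the helper name `dirichlet_update_single_modality` is hypothetical, and it assumes one-hot observations of shape `(T, num_obs)` and a modality that depends on a single hidden-state factor with beliefs of shape `(T, num_states)`.

```python
import jax.numpy as jnp
from jax import nn


def dirichlet_update_single_modality(pA_m, obs_onehot, qs_f, lr=1.0):
    """Hypothetical helper (not part of pymdp): one-modality Dirichlet update.

    Assumes obs_onehot has shape (T, num_obs) and qs_f has shape (T, num_states).
    """
    # Sum over time of the outer product between each observation and belief.
    counts = jnp.einsum("to,ts->os", obs_onehot, qs_f)
    # Add the scaled pseudo-counts to the current concentration parameters.
    qA_m = pA_m + lr * counts
    # Expected likelihood: normalise each column over outcomes.
    E_qA_m = qA_m / qA_m.sum(axis=0, keepdims=True)
    return qA_m, E_qA_m


pA_m = jnp.ones((2, 2))                              # flat prior
obs_onehot = nn.one_hot(jnp.array([0, 1, 1]), 2)     # three time steps
qs_f = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
qA_m, E_qA_m = dirichlet_update_single_modality(pA_m, obs_onehot, qs_f)
```

With a flat prior, the updated parameters simply count how often each outcome co-occurred with each believed state, and each column of the expected tensor sums to one.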
update_obs_likelihood_dirichlet(pA: list[Array | None], A: list[Array], obs: list[Array], qs: list[Array], *, A_dependencies: list[list[int]], categorical_obs: bool, num_obs: list[int], lr: float) -> tuple[list[Array | None], list[Array]]
Update Dirichlet parameters of the observation likelihood (A matrix) given observations and beliefs.
JAX version of pymdp.learning.update_obs_likelihood_dirichlet
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pA | List[Array] | Prior Dirichlet parameters for A matrices | required |
| A | List[Array] | Current A matrices (observation likelihoods) | required |
| obs | List[Array] | Observations (either discrete indices or categorical distributions depending on categorical_obs) | required |
| qs | List[Array] | Posterior beliefs over hidden states | required |
| A_dependencies | List[List[int]] | Dependencies between observation modalities and state factors | required |
| categorical_obs | bool | If True, observations are probability distributions; if False, discrete indices | required |
| num_obs | List[int] | Number of observations for each modality | required |
| lr | float | Learning rate | required |
Returns:
| Name | Type | Description |
|---|---|---|
| qA | List[Array] | Updated Dirichlet parameters |
| E_qA | List[Array] | Expected values (updated A matrices) |
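The distinction drawn by categorical_obs can be sketched as follows. This is an illustration only, not pymdp's code: the helper `to_categorical` is hypothetical, and it assumes that discrete observations are integer arrays with one entry per time step, which num_obs then expands into one-hot distributions.

```python
import jax.numpy as jnp
from jax import nn


def to_categorical(obs, num_obs, categorical_obs):
    """Hypothetical helper (not part of pymdp): per-modality distributions."""
    if categorical_obs:
        # Observations are already probability distributions; pass through.
        return list(obs)
    # Discrete indices: expand each modality to one-hot vectors.
    return [nn.one_hot(o, n) for o, n in zip(obs, num_obs)]


obs = [jnp.array([0, 2]), jnp.array([1, 1])]   # two modalities, two time steps
num_obs = [3, 2]
dists = to_categorical(obs, num_obs, categorical_obs=False)
```

After this step every modality has an array of shape `(T, num_obs[m])`, so the same count-accumulation logic can be used for both input conventions.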
update_state_transition_dirichlet_f(pB_f: Array, actions_f: Array, joint_qs_f: Array | list[Array], lr: float = 1.0) -> tuple[Array, Array]
Update one factor's Dirichlet parameters for the transition model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pB_f | Array | Current Dirichlet concentration parameters for hidden-state factor | required |
| actions_f | Array | One-hot action history for factor | required |
| joint_qs_f | Array \| list[Array] | Pairwise state beliefs (for example, | required |
| lr | float | Learning-rate multiplier for concentration updates. | 1.0 |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Updated concentration parameters and expected transition tensor. |
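A single-factor transition update of this kind might look like the sketch below. It is a simplified stand-in rather than the library's method: the helper name is hypothetical, the concentration tensor is assumed to have shape `(num_states, num_states, num_controls)`, and the pairwise belief is approximated by the outer product of the marginal beliefs at consecutive time steps.

```python
import jax.numpy as jnp
from jax import nn


def dirichlet_update_single_factor(pB_f, qs_next, qs_prev, actions_onehot, lr=1.0):
    """Hypothetical helper (not part of pymdp): one-factor transition update.

    Assumes qs_next and qs_prev have shape (T, num_states) and
    actions_onehot has shape (T, num_controls).
    """
    # Sum over time of next-state x previous-state x action outer products.
    counts = jnp.einsum("ti,tj,tu->iju", qs_next, qs_prev, actions_onehot)
    # Add the scaled pseudo-counts to the current concentration parameters.
    qB_f = pB_f + lr * counts
    # Expected transitions: normalise over the next-state axis.
    E_qB_f = qB_f / qB_f.sum(axis=0, keepdims=True)
    return qB_f, E_qB_f


pB_f = jnp.ones((2, 2, 2))                            # flat prior
qs_prev = jnp.array([[1.0, 0.0], [0.0, 1.0]])         # beliefs at t
qs_next = jnp.array([[0.0, 1.0], [0.0, 1.0]])         # beliefs at t + 1
actions_onehot = nn.one_hot(jnp.array([0, 1]), 2)     # actions taken at t
qB_f, E_qB_f = dirichlet_update_single_factor(pB_f, qs_next, qs_prev, actions_onehot)
```

Only the entries matching an observed (next state, previous state, action) triple gain pseudo-counts; all other entries keep their prior value.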
update_state_transition_dirichlet(pB: list[Array], B: list[Array], joint_beliefs: list[Array], actions: Array, *, num_controls: list[int], lr: float, factors_to_update: str | list[int] = 'all') -> tuple[list[Array], list[Array]]
Update the posterior Dirichlet parameters of the state transition likelihood model (B) given the joint beliefs over hidden states and actions.
Supports selective learning of only particular hidden-state factors via
factors_to_update (either "all" or a List[int]).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pB | list[Array] | Dirichlet concentration parameters for transition model factors. | required |
| B | list[Array] | Current expected transition tensors. | required |
| joint_beliefs | list[Array] | Time-aligned joint beliefs per factor for transition learning. | required |
| actions | Array | Integer action history with shape | required |
| num_controls | list[int] | Number of control states per factor. | required |
| lr | float | Learning-rate multiplier for concentration updates. | required |
| factors_to_update | all \| List[int] | Which hidden-state factors should be updated. | "all" |
Returns:
| Type | Description |
|---|---|
| tuple[List[Array], List[Array]] | Updated concentration parameters and expected transition tensors. |
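Two preparatory steps implied by these parameters, resolving factors_to_update and expanding integer actions with num_controls, can be sketched as below. Both helpers are hypothetical and do not come from pymdp, and the action array is assumed here to have shape `(T, num_factors)`, which the description above does not confirm.

```python
import jax.numpy as jnp
from jax import nn


def resolve_factors(factors_to_update, num_factors):
    """Hypothetical helper (not part of pymdp): factor indices to update."""
    if factors_to_update == "all":
        return list(range(num_factors))
    return list(factors_to_update)


def onehot_actions_per_factor(actions, num_controls):
    """Hypothetical helper: one-hot action history for each factor.

    Assumes actions has shape (T, num_factors); this layout is an assumption.
    """
    return [nn.one_hot(actions[:, f], n) for f, n in enumerate(num_controls)]


actions = jnp.array([[0, 1], [2, 0]])     # two time steps, two factors
num_controls = [3, 2]
onehots = onehot_actions_per_factor(actions, num_controls)
selected = resolve_factors([1], num_factors=2)
```

Factors left out of the selection would keep their existing concentration parameters, which is the sense in which learning is selective.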