Learning
The learning.py module contains functions for updating the parameters of Dirichlet posteriors (which parameterise categorical priors and likelihoods) in POMDP generative models.
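All of the functions below follow the same pattern. They add learning-rate-scaled expected counts to the prior Dirichlet concentration parameters. The summary below uses notation introduced here for illustration only, not taken from the pymdp source. a, b and d stand for one modality or factor of pA, pB and pD, and primes mark the updated arrays (qA, qB, qD). η is lr, o_t is the one-hot observation, q(s) is the marginal posterior over one hidden state factor, and u_{t-1} is the action taken at the previous step. With several factors, the outer product runs over every relevant factor's marginal. The last line uses the standard fact that the mean of a Dirichlet distribution is its normalised concentration vector. Here A_{ij} is the probability of observation level i given hidden state level j.

```latex
\begin{aligned}
\mathbf{a}' &= \mathbf{a} + \eta\, \mathbf{o}_t \otimes \mathbf{q}(s_t) \\
\mathbf{b}'_{:,:,u_{t-1}} &= \mathbf{b}_{:,:,u_{t-1}} + \eta\, \mathbf{q}(s_t) \otimes \mathbf{q}(s_{t-1}) \\
\mathbf{d}' &= \mathbf{d} + \eta\, \mathbf{q}(s) \\
\mathbb{E}[A_{ij}] &= \frac{a'_{ij}}{\sum_k a'_{kj}}
\end{aligned}
```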
- pymdp.learning.update_obs_likelihood_dirichlet(pA, A, obs, qs, lr=1.0, modalities='all')
Update Dirichlet parameters of the observation likelihood distribution.
- Parameters
  - pA (numpy.ndarray of dtype object) – Prior Dirichlet parameters over the observation model (same shape as A).
  - A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] of this object array stores a numpy.ndarray multidimensional array for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...
  - obs (1D numpy.ndarray, numpy.ndarray of dtype object, int or tuple) – The observation (generated by the environment). For a single modality, this can be a 1D numpy.ndarray (one-hot vector representation) or an int (observation index). For multiple modalities, this can be a numpy.ndarray of dtype object whose entries are 1D one-hot vectors, or a tuple of int.
  - qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.
  - lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.
  - modalities (list, default “all”) – Indices (ranging from 0 to n_modalities - 1) of the observation modalities to include in learning. Defaults to “all”, meaning that the modality-specific sub-arrays of pA are all updated using the corresponding observations.
- Returns
  qA – Posterior Dirichlet parameters over the observation model (same shape as A), after having updated it with observations.
- Return type
  numpy.ndarray of dtype object
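The update above amounts to adding learning-rate-scaled expected counts (the outer product of the one-hot observation with the state marginals) to pA. The following is a minimal NumPy sketch of that idea, not pymdp's implementation. The name sketch_update_obs_likelihood is hypothetical. The sketch drops the A argument, which this simplified version does not need, and it accepts obs only as an int or a tuple of ints rather than every form listed above.

```python
import numpy as np


def sketch_update_obs_likelihood(pA, obs, qs, lr=1.0, modalities="all"):
    """Hypothetical helper: Dirichlet pseudo-count update of an observation model.

    pA:  object array; pA[m] has shape (n_obs[m], n_states[0], n_states[1], ...)
    obs: int (single modality) or tuple of ints (one observation index per modality)
    qs:  object array of 1D posterior vectors, one per hidden state factor
    """
    if isinstance(obs, (int, np.integer)):
        obs = (obs,)  # treat a bare index as a single-modality observation
    if modalities == "all":
        modalities = range(len(pA))

    # Copy every modality so the prior pA is left unchanged.
    qA = np.empty(len(pA), dtype=object)
    for m in range(len(pA)):
        qA[m] = pA[m].copy()

    for m in modalities:
        # One-hot encode the observed level of modality m.
        counts = np.zeros(pA[m].shape[0])
        counts[obs[m]] = 1.0
        # Outer product o ⊗ q(s_1) ⊗ q(s_2) ⊗ ... gives the expected counts.
        for q in qs:
            counts = np.multiply.outer(counts, q)
        qA[m] = qA[m] + lr * counts
    return qA


# One modality with 2 levels, one hidden state factor with 3 levels.
pA = np.empty(1, dtype=object)
pA[0] = np.ones((2, 3))
qs = np.empty(1, dtype=object)
qs[0] = np.array([0.5, 0.5, 0.0])
qA = sketch_update_obs_likelihood(pA, 1, qs, lr=2.0)
# qA[0] -> [[1, 1, 1], [2, 2, 1]]
```

Only the row of the observed level receives counts, spread over hidden states in proportion to the posterior, so states the agent is confident it occupied learn fastest.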
- pymdp.learning.update_obs_likelihood_dirichlet_factorized(pA, A, obs, qs, A_factor_list, lr=1.0, modalities='all')
Update Dirichlet parameters of the observation likelihood distribution, in the case where the observation model is reduced (factorized) and only represents the conditional dependencies between the observation modalities and particular hidden state factors (whose indices are specified in each modality-specific entry of A_factor_list).
- Parameters
  - pA (numpy.ndarray of dtype object) – Prior Dirichlet parameters over the observation model (same shape as A).
  - A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] of this object array stores a numpy.ndarray multidimensional array for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...
  - obs (1D numpy.ndarray, numpy.ndarray of dtype object, int or tuple) – The observation (generated by the environment). For a single modality, this can be a 1D numpy.ndarray (one-hot vector representation) or an int (observation index). For multiple modalities, this can be a numpy.ndarray of dtype object whose entries are 1D one-hot vectors, or a tuple of int.
  - qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.
  - A_factor_list (list of list of int) – List of lists, where the list at index m contains the indices of the hidden state factors that observation modality m depends on.
  - lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.
  - modalities (list, default “all”) – Indices (ranging from 0 to n_modalities - 1) of the observation modalities to include in learning. Defaults to “all”, meaning that the modality-specific sub-arrays of pA are all updated using the corresponding observations.
- Returns
  qA – Posterior Dirichlet parameters over the observation model (same shape as A), after having updated it with observations.
- Return type
  numpy.ndarray of dtype object
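In the factorized case, the expected counts for modality m only involve the marginals of the factors listed in A_factor_list[m]. The following is a minimal NumPy sketch of that idea, not pymdp's implementation. The name sketch_update_obs_likelihood_factorized is hypothetical, the A argument is dropped, obs is assumed to be a tuple of ints, and each pA[m] is assumed to have one axis per listed factor after the observation axis.

```python
import numpy as np


def sketch_update_obs_likelihood_factorized(pA, obs, qs, A_factor_list, lr=1.0, modalities="all"):
    """Hypothetical helper: modality m only uses the factors in A_factor_list[m].

    pA[m] is assumed to have shape (n_obs[m], *[n_states[f] for f in A_factor_list[m]]).
    obs is a tuple of observation indices, one per modality.
    """
    if modalities == "all":
        modalities = range(len(pA))

    qA = np.empty(len(pA), dtype=object)
    for m in range(len(pA)):
        qA[m] = pA[m].copy()  # leave the prior untouched

    for m in modalities:
        counts = np.zeros(pA[m].shape[0])
        counts[obs[m]] = 1.0
        # Only the factors that modality m depends on enter the outer product.
        for f in A_factor_list[m]:
            counts = np.multiply.outer(counts, qs[f])
        qA[m] = qA[m] + lr * counts
    return qA


# Two factors (2 and 3 levels); modality 0 depends on factor 1, modality 1 on factor 0.
qs = np.empty(2, dtype=object)
qs[0] = np.array([1.0, 0.0])
qs[1] = np.array([0.0, 0.5, 0.5])
pA = np.empty(2, dtype=object)
pA[0] = np.ones((2, 3))
pA[1] = np.ones((2, 2))
qA = sketch_update_obs_likelihood_factorized(pA, (0, 1), qs, A_factor_list=[[1], [0]])
# qA[0] -> [[1, 1.5, 1.5], [1, 1, 1]]; qA[1] -> [[1, 1], [2, 1]]
```

Because each pA[m] omits the factors modality m does not depend on, the arrays stay smaller than the full joint over all hidden state factors.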
- pymdp.learning.update_state_likelihood_dirichlet(pB, B, actions, qs, qs_prev, lr=1.0, factors='all')
Update Dirichlet parameters of the transition distribution.
- Parameters
  - pB (numpy.ndarray of dtype object) – Prior Dirichlet parameters over the transition model (same shape as B).
  - B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.
  - actions (1D numpy.ndarray) – A vector with length equal to the number of control factors, where each element contains the index of the action (for that control factor) performed at a given timestep.
  - qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.
  - qs_prev (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the previous timepoint.
  - lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.
  - factors (list, default “all”) – Indices (ranging from 0 to n_factors - 1) of the hidden state factors to include in learning. Defaults to “all”, meaning that the factor-specific sub-arrays of pB are all updated using the corresponding hidden state distributions and actions.
- Returns
  qB – Posterior Dirichlet parameters over the transition model (same shape as B), after having updated it with state beliefs and actions.
- Return type
  numpy.ndarray of dtype object
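For transitions, the expected counts are the outer product of the current and previous state marginals. They are added only to the slice of pB[f] that belongs to the action actually taken. The following is a minimal NumPy sketch of that idea, not pymdp's implementation. The name sketch_update_transition is hypothetical, and the B argument is dropped because this simplified version does not use it.

```python
import numpy as np


def sketch_update_transition(pB, actions, qs, qs_prev, lr=1.0, factors="all"):
    """Hypothetical helper: Dirichlet pseudo-count update of a transition model.

    pB[f] has shape (n_states[f], n_states[f], n_controls[f]), indexed [s, v, u];
    actions[f] is the index of the action taken for factor f at the previous step.
    """
    if factors == "all":
        factors = range(len(pB))

    qB = np.empty(len(pB), dtype=object)
    for f in range(len(pB)):
        qB[f] = pB[f].copy()  # leave the prior untouched

    for f in factors:
        u = int(actions[f])
        # Expected counts q(s_t) ⊗ q(s_{t-1}) go only into the slice for action u.
        qB[f][:, :, u] = qB[f][:, :, u] + lr * np.outer(qs[f], qs_prev[f])
    return qB


pB = np.empty(1, dtype=object)
pB[0] = np.ones((2, 2, 2))         # 2 states, 2 actions
qs = np.empty(1, dtype=object)
qs[0] = np.array([0.0, 1.0])       # now in state 1
qs_prev = np.empty(1, dtype=object)
qs_prev[0] = np.array([1.0, 0.0])  # previously in state 0
qB = sketch_update_transition(pB, np.array([1]), qs, qs_prev)
# qB[0][:, :, 1] -> [[1, 1], [2, 1]]; qB[0][:, :, 0] is unchanged
```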
- pymdp.learning.update_state_likelihood_dirichlet_interactions(pB, B, actions, qs, qs_prev, B_factor_list, lr=1.0, factors='all')
Update Dirichlet parameters of the transition distribution, in the case when ‘interacting’ hidden state factors are present, i.e. the dynamics of a given hidden state factor f are no longer independent of the dynamics of other hidden state factors.
- Parameters
  - pB (numpy.ndarray of dtype object) – Prior Dirichlet parameters over the transition model (same shape as B).
  - B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a tensor for hidden state factor f, whose entries B[f][s, v, ..., u] store the probability of hidden state level s at the current time, given the levels v, ... of the hidden state factors listed in B_factor_list[f] and action u at the previous time.
  - actions (1D numpy.ndarray) – A vector with length equal to the number of control factors, where each element contains the index of the action (for that control factor) performed at a given timestep.
  - qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.
  - qs_prev (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the previous timepoint.
  - B_factor_list (list of list of int) – A list of lists, where each element B_factor_list[f] is a list of indices of the hidden state factors that are needed to predict the dynamics of hidden state factor f.
  - lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.
  - factors (list, default “all”) – Indices (ranging from 0 to n_factors - 1) of the hidden state factors to include in learning. Defaults to “all”, meaning that the factor-specific sub-arrays of pB are all updated using the corresponding hidden state distributions and actions.
- Returns
  qB – Posterior Dirichlet parameters over the transition model (same shape as B), after having updated it with state beliefs and actions.
- Return type
  numpy.ndarray of dtype object
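With interacting factors, the expected counts for factor f combine the current marginal of f with the previous marginals of every factor in B_factor_list[f]. The following is a minimal NumPy sketch of that idea, not pymdp's implementation. The name sketch_update_transition_interactions is hypothetical and the B argument is dropped. The tensor layout (current level of f, then one axis per listed factor, then the action axis) is an assumption of this sketch.

```python
import numpy as np


def sketch_update_transition_interactions(pB, actions, qs, qs_prev, B_factor_list, lr=1.0, factors="all"):
    """Hypothetical helper: factor f's dynamics depend on the previous states
    of the factors in B_factor_list[f].

    pB[f] is assumed to have shape
    (n_states[f], *[n_states[g] for g in B_factor_list[f]], n_controls[f]).
    """
    if factors == "all":
        factors = range(len(pB))

    qB = np.empty(len(pB), dtype=object)
    for f in range(len(pB)):
        qB[f] = pB[f].copy()  # leave the prior untouched

    for f in factors:
        u = int(actions[f])
        # q(s_t^f) ⊗ q(s_{t-1}^{g1}) ⊗ q(s_{t-1}^{g2}) ⊗ ... over the listed factors g
        counts = qs[f]
        for g in B_factor_list[f]:
            counts = np.multiply.outer(counts, qs_prev[g])
        qB[f][..., u] = qB[f][..., u] + lr * counts
    return qB


# Factor 0 (1 action) depends on factors 0 and 1; factor 1 (2 actions) only on itself.
pB = np.empty(2, dtype=object)
pB[0] = np.ones((2, 2, 2, 1))
pB[1] = np.ones((2, 2, 2))
qs = np.empty(2, dtype=object)
qs[0] = np.array([1.0, 0.0])
qs[1] = np.array([0.0, 1.0])
qs_prev = np.empty(2, dtype=object)
qs_prev[0] = np.array([0.0, 1.0])
qs_prev[1] = np.array([1.0, 0.0])
qB = sketch_update_transition_interactions(pB, np.array([0, 1]), qs, qs_prev, [[0, 1], [1]])
# qB[0][0, 1, 0, 0] -> 2.0 (all other entries of qB[0] stay 1); qB[1][:, :, 1] -> [[1, 1], [2, 1]]
```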
- pymdp.learning.update_state_prior_dirichlet(pD, qs, lr=1.0, factors='all')
Update Dirichlet parameters of the initial hidden state distribution (prior beliefs about hidden states at the beginning of the inference window).
- Parameters
  - pD (numpy.ndarray of dtype object) – Prior Dirichlet parameters over the initial hidden state prior (same shape as qs).
  - qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.
  - lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.
  - factors (list, default “all”) – Indices (ranging from 0 to n_factors - 1) of the hidden state factors to include in learning. Defaults to “all”, meaning that the factor-specific sub-vectors of pD are all updated using the corresponding hidden state distributions.
- Returns
  qD – Posterior Dirichlet parameters over the initial hidden state prior (same shape as qs), after having updated it with state beliefs.
- Return type
  numpy.ndarray of dtype object
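For the initial state prior, the posterior state beliefs themselves serve as fractional pseudo-counts. The following is a minimal NumPy sketch of that idea, not pymdp's implementation. The name sketch_update_state_prior is hypothetical. The last lines show how the expected initial-state distribution follows by normalising the concentration vector (the mean of a Dirichlet).

```python
import numpy as np


def sketch_update_state_prior(pD, qs, lr=1.0, factors="all"):
    """Hypothetical helper: Dirichlet update of the initial hidden state prior.

    pD[f] and qs[f] are 1D vectors of the same length for each factor f.
    """
    if factors == "all":
        factors = range(len(pD))

    qD = np.empty(len(pD), dtype=object)
    for f in range(len(pD)):
        qD[f] = pD[f].copy()  # leave the prior untouched

    for f in factors:
        # Posterior state beliefs act as (fractional) pseudo-counts.
        qD[f] = qD[f] + lr * qs[f]
    return qD


pD = np.empty(2, dtype=object)
pD[0] = np.ones(3)
pD[1] = np.ones(2)
qs = np.empty(2, dtype=object)
qs[0] = np.array([0.2, 0.3, 0.5])
qs[1] = np.array([1.0, 0.0])
qD = sketch_update_state_prior(pD, qs, factors=[0])
# qD[0] -> [1.2, 1.3, 1.5]; qD[1] stays [1, 1] because factor 1 was not selected
D0_mean = qD[0] / qD[0].sum()  # expected initial-state distribution for factor 0
```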