Learning

The learning.py module contains functions for updating the parameters of Dirichlet posteriors (which parameterise the categorical priors and likelihoods) in POMDP generative models.
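As background, a Dirichlet distribution with concentration parameters alpha parameterises a categorical distribution through its expected value, which is alpha normalised to sum to one. The minimal NumPy sketch below shows this, and shows how adding pseudo-counts shifts the implied probabilities. The helper name `dirichlet_expectation` is hypothetical and is not part of pymdp.

```python
import numpy as np


def dirichlet_expectation(alpha):
    """Hypothetical helper: expected categorical distribution under Dirichlet(alpha).

    Normalises along axis 0, so for a likelihood array each column
    (one per hidden-state configuration) becomes a probability vector.
    """
    alpha = np.asarray(alpha, dtype=float)
    return alpha / alpha.sum(axis=0, keepdims=True)


# Uniform prior pseudo-counts over a 2-outcome categorical
alpha_prior = np.array([1.0, 1.0])
print(dirichlet_expectation(alpha_prior))  # [0.5 0.5]

# Observing outcome 0 once adds one pseudo-count to that entry (lr = 1.0)
alpha_post = alpha_prior + np.array([1.0, 0.0])
print(dirichlet_expectation(alpha_post))  # approximately [0.667 0.333]
```

The larger the total pseudo-count, the less a single new count moves the expected probabilities. This is why the lr parameters below act as a scale on how strongly each observation updates the model.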

pymdp.learning.update_obs_likelihood_dirichlet(pA, A, obs, qs, lr=1.0, modalities='all')

Update Dirichlet parameters of the observation likelihood distribution.

Parameters
  • pA (numpy.ndarray of dtype object) – Prior Dirichlet parameters over observation model (same shape as A)

  • A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] of this object array stores a multidimensional numpy.ndarray for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...

  • obs (1D numpy.ndarray, numpy.ndarray of dtype object, int or tuple) – The observation (generated by the environment). If single modality, this can be a 1D numpy.ndarray (one-hot vector representation) or an int (observation index). If multi-modality, this can be a numpy.ndarray of dtype object whose entries are 1D one-hot vectors, or a tuple (of int).

  • qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.

  • lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.

  • modalities (list, default “all”) – Indices (ranging from 0 to n_modalities - 1) of the observation modalities to include in learning. Defaults to “all”, meaning that modality-specific sub-arrays of pA are all updated using the corresponding observations.

Returns

qA – Posterior Dirichlet parameters over observation model (same shape as A), after having updated it with observations.

Return type

numpy.ndarray of dtype object
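The observation-model update can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not pymdp's source. The function name is hypothetical, plain Python lists stand in for object arrays, and obs is taken as a tuple of int indices. It is also assumed that entries where A is zero receive no counts. The pseudo-counts added to pA[m] are the outer product of the one-hot observation with the factor posteriors in qs, scaled by lr.

```python
import numpy as np


def obs_likelihood_update_sketch(pA, A, obs, qs, lr=1.0, modalities="all"):
    """Hypothetical sketch of a Dirichlet pseudo-count update for the A model.

    pA, A : lists of arrays, one per modality; A[m] has shape (n_obs_m, *n_states).
    obs   : tuple of int observation indices, one per modality.
    qs    : list of 1D posterior vectors, one per hidden state factor.
    """
    if modalities == "all":
        modalities = range(len(pA))
    qA = [p.copy() for p in pA]
    for m in modalities:
        # One-hot encode the observation index for modality m
        o = np.zeros(A[m].shape[0])
        o[obs[m]] = 1.0
        # Outer product o x qs[0] x qs[1] x ... matches the shape of A[m]
        dfda = o
        for q in qs:
            dfda = np.multiply.outer(dfda, q)
        # Assumption: entries where A[m] is zero receive no counts
        dfda = dfda * (A[m] > 0)
        qA[m] = qA[m] + lr * dfda
    return qA


# Usage: one modality with 2 outcomes, one hidden state factor with 2 levels
A = [np.array([[0.9, 0.1], [0.1, 0.9]])]
pA = [np.ones((2, 2))]
qs = [np.array([0.8, 0.2])]
qA = obs_likelihood_update_sketch(pA, A, (0,), qs)
# qA[0] is [[1.8, 1.2], [1.0, 1.0]]: only the row for the observed outcome grows
```

Because counts are weighted by qs, an uncertain state posterior spreads the single observation across several columns of the likelihood instead of committing it to one.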

pymdp.learning.update_state_likelihood_dirichlet(pB, B, actions, qs, qs_prev, lr=1.0, factors='all')

Update Dirichlet parameters of the transition distribution.

Parameters
  • pB (numpy.ndarray of dtype object) – Prior Dirichlet parameters over transition model (same shape as B)

  • B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.

  • actions (1D numpy.ndarray) – A vector with length equal to the number of control factors, where each element contains the index of the action (for that control factor) performed at a given timestep.

  • qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.

  • qs_prev (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the previous timepoint.

  • lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.

  • factors (list, default “all”) – Indices (ranging from 0 to n_factors - 1) of the hidden state factors to include in learning. Defaults to “all”, meaning that factor-specific sub-arrays of pB are all updated using the corresponding hidden state distributions and actions.

Returns

qB – Posterior Dirichlet parameters over transition model (same shape as B), after having updated it with state beliefs and actions.

Return type

numpy.ndarray of dtype object
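A corresponding sketch for the transition model is given below. It is a hypothetical illustration rather than pymdp's implementation, and it rests on several assumptions. Lists stand in for object arrays, and each hidden state factor has exactly one entry in actions. Zero-probability transitions in B are assumed to receive no counts. Only the B[f][:, :, u] slice for the action taken is updated, using the outer product of the current and previous state posteriors, which matches the B[f][s, v, u] indexing described above.

```python
import numpy as np


def transition_update_sketch(pB, B, actions, qs, qs_prev, lr=1.0, factors="all"):
    """Hypothetical sketch of a Dirichlet pseudo-count update for the B model.

    pB, B       : lists of arrays, one per factor; B[f] has shape (n_s, n_s, n_u).
    actions     : sequence of action indices, assumed one per hidden state factor.
    qs, qs_prev : lists of 1D posteriors at the current and previous timepoints.
    """
    if factors == "all":
        factors = range(len(pB))
    qB = [p.copy() for p in pB]
    for f in factors:
        u = int(actions[f])
        # Counts for (s at t, v at t-1): outer product of current and previous beliefs
        dfdb = np.outer(qs[f], qs_prev[f])
        # Assumption: transitions with zero probability under B receive no counts
        dfdb = dfdb * (B[f][:, :, u] > 0)
        # Only the slice for the action actually taken is updated
        qB[f][:, :, u] += lr * dfdb
    return qB


# Usage: one factor with 2 states and 2 actions (0 = stay, 1 = switch)
B = [np.stack([np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])], axis=-1)]
pB = [np.ones((2, 2, 2))]
qs_prev = [np.array([1.0, 0.0])]
qs = [np.array([0.1, 0.9])]
qB = transition_update_sketch(pB, B, np.array([1]), qs, qs_prev)
# qB[0][:, :, 1] is [[1.0, 1.0], [1.9, 1.0]]; qB[0][:, :, 0] stays all ones
```

In this example, the 0.1 of posterior mass on "stayed in state 0" is discarded by the mask because the "switch" action in B assigns that transition zero probability.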

pymdp.learning.update_state_prior_dirichlet(pD, qs, lr=1.0, factors='all')

Update Dirichlet parameters of the initial hidden state distribution (prior beliefs about hidden states at the beginning of the inference window).

Parameters
  • pD (numpy.ndarray of dtype object) – Prior Dirichlet parameters over initial hidden state prior (same shape as qs)

  • qs (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the current timepoint.

  • lr (float, default 1.0) – Learning rate, scale of the Dirichlet pseudo-count update.

  • factors (list, default “all”) – Indices (ranging from 0 to n_factors - 1) of the hidden state factors to include in learning. Defaults to “all”, meaning that factor-specific sub-vectors of pD are all updated using the corresponding hidden state distributions.

Returns

qD – Posterior Dirichlet parameters over initial hidden state prior (same shape as qs), after having updated it with state beliefs.

Return type

numpy.ndarray of dtype object
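The initial-state update is the simplest of the three: each factor's pseudo-count vector gains the scaled posterior over that factor. The sketch below is a hypothetical illustration, not pymdp's source. It assumes lists in place of object arrays and a plain unmasked addition, and it shows how the factors argument restricts learning to a subset of factors.

```python
import numpy as np


def state_prior_update_sketch(pD, qs, lr=1.0, factors="all"):
    """Hypothetical sketch of a Dirichlet pseudo-count update for the D vector.

    pD : list of 1D prior pseudo-count vectors, one per hidden state factor.
    qs : list of 1D posterior beliefs (same shapes as pD).
    """
    if factors == "all":
        factors = range(len(pD))
    qD = [p.copy() for p in pD]
    for f in factors:
        # Each state level gains pseudo-counts in proportion to its posterior mass
        qD[f] = qD[f] + lr * qs[f]
    return qD


# Usage: two factors, but only factor 0 is learned
pD = [np.array([1.0, 1.0, 1.0]), np.array([2.0, 2.0])]
qs = [np.array([0.7, 0.2, 0.1]), np.array([0.5, 0.5])]
qD = state_prior_update_sketch(pD, qs, lr=0.5, factors=[0])
# qD[0] is [1.35, 1.1, 1.05]; qD[1] is unchanged at [2.0, 2.0]
```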