Inference

The inference.py module contains the functions for performing inference of discrete hidden states (categorical distributions) in POMDP generative models.

pymdp.inference.average_states_over_policies(qs_pi, q_pi)

This function computes an expected posterior over hidden states with respect to the posterior over policies, also known as the ‘Bayesian model average of states with respect to policies’.

Parameters
  • qs_pi (numpy.ndarray of dtype object) – Posterior beliefs over hidden states for each policy. Nesting structure is policies, factors, where e.g. qs_pi[p][f] stores the marginal belief about factor f under policy p.

  • q_pi (1D numpy.ndarray) – Posterior beliefs about policies, where len(q_pi) = num_policies.

Returns

qs_bma – Marginal posterior over hidden states for the current timepoint, averaged across policies according to their posterior probabilities given by q_pi.

Return type

numpy.ndarray of dtype object
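The averaging is a probability-weighted sum over policies, computed separately for each hidden state factor: qs_bma[f] = sum over p of q_pi[p] * qs_pi[p][f]. The sketch below reproduces that computation with plain NumPy so the nesting of qs_pi is easy to see. bayesian_model_average and to_obj_array are hypothetical helpers written for illustration (they are not part of pymdp), and q_pi is assumed to be a 1D float array.

```python
import numpy as np

def to_obj_array(arrays):
    """Hypothetical helper: pack a list of 1D arrays into a numpy object array."""
    out = np.empty(len(arrays), dtype=object)
    for i, a in enumerate(arrays):
        out[i] = np.asarray(a, dtype=float)
    return out

def bayesian_model_average(qs_pi, q_pi):
    """Hypothetical helper: qs_bma[f] = sum_p q_pi[p] * qs_pi[p][f]."""
    num_factors = len(qs_pi[0])
    qs_bma = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        # Weight each policy's marginal for factor f by that policy's probability
        qs_bma[f] = sum(q_pi[p] * qs_pi[p][f] for p in range(len(q_pi)))
    return qs_bma

# Two policies, two hidden state factors with 2 and 3 levels
qs_pi = np.empty(2, dtype=object)
qs_pi[0] = to_obj_array([[0.9, 0.1], [0.2, 0.3, 0.5]])
qs_pi[1] = to_obj_array([[0.5, 0.5], [1.0, 0.0, 0.0]])
q_pi = np.array([0.75, 0.25])

qs_bma = bayesian_model_average(qs_pi, q_pi)
print(qs_bma[0], qs_bma[1])
```

With these numbers the averaged belief about factor 0 is [0.8, 0.2] and about factor 1 is [0.4, 0.225, 0.375] (up to floating-point rounding); each averaged marginal still sums to 1 because q_pi does.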

pymdp.inference.update_posterior_states(A, obs, prior=None, **kwargs)

Update the marginal posterior over hidden states using mean-field fixed-point iteration (FPI).

See the following links for details: http://www.cs.cmu.edu/~guestrin/Class/10708/recitations/r9/VI-view.pdf (slides 13-18) and http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.137.221&rep=rep1&type=pdf (slides 24-38).

Parameters
  • A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] of this object array stores an np.ndarray multidimensional array for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...

  • obs (1D numpy.ndarray, numpy.ndarray of dtype object, int or tuple) – The observation (generated by the environment). For a single modality, this can be a 1D np.ndarray (one-hot vector representation) or an int (observation index). For multiple modalities, this can be an np.ndarray of dtype object whose entries are 1D one-hot vectors, or a tuple of ints.

  • prior (1D numpy.ndarray or numpy.ndarray of dtype object, default None) – Prior beliefs about hidden states, to be integrated with the marginal likelihood to obtain a posterior distribution. If not provided, the prior is set to a flat categorical distribution (at the level of the individual inference functions).

  • **kwargs (keyword arguments) – Keyword arguments passed on as parameter values to the fixed-point iteration function algos.fpi.run_vanilla_fpi.
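All of the obs formats listed above carry the same information and can be reduced to one-hot form. The helper below, obs_to_onehot, is hypothetical and is not pymdp's own internal conversion routine; it is a sketch that assumes num_obs lists the number of outcome levels for each modality.

```python
import numpy as np

def obs_to_onehot(obs, num_obs):
    """Hypothetical helper (not a pymdp function): convert obs to one-hot form.

    num_obs lists the number of outcome levels for each modality.
    """
    if isinstance(obs, np.ndarray):
        # Already a one-hot vector or an object array of one-hot vectors
        return obs
    if isinstance(obs, (int, np.integer)):
        # Single modality given as an observation index
        onehot = np.zeros(num_obs[0])
        onehot[obs] = 1.0
        return onehot
    # Multiple modalities given as a tuple of indices, one per modality
    out = np.empty(len(obs), dtype=object)
    for m, o_m in enumerate(obs):
        out[m] = np.zeros(num_obs[m])
        out[m][o_m] = 1.0
    return out

single = obs_to_onehot(2, [4])
multi = obs_to_onehot((1, 0), [3, 2])
print(single)              # [0. 0. 1. 0.]
print(multi[0], multi[1])  # [0. 1. 0.] [1. 0.]
```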

Returns

qs – Marginal posterior beliefs over hidden states at the current timepoint

Return type

1D numpy.ndarray or numpy.ndarray of dtype object
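For intuition about the mean-field fixed-point update, the sketch below handles the simplest case with more than one factor: a single observation modality and two hidden state factors. Each factor's posterior is repeatedly set to softmax(ln prior[f] + expected log-likelihood), where the expectation is taken over the current belief about the other factor, and the factors are updated in turn until the beliefs stop changing. mean_field_fpi is a hypothetical function written for illustration, not algos.fpi.run_vanilla_fpi; the initialisation at the prior, the update order, the iteration cap and the tolerance are all assumptions. With only one factor the same update reduces to Bayes' rule in a single step.

```python
import numpy as np

EPS = 1e-16  # keeps log() finite where the likelihood is exactly zero

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def mean_field_fpi(A_m, obs, prior, num_iter=10, tol=1e-8):
    """Hypothetical sketch: mean-field FPI for ONE modality and TWO factors.

    A_m:   array of shape (num_obs, n0, n1), A_m[o, i, j] = P(o | s0=i, s1=j)
    obs:   observation index o
    prior: [prior0, prior1], one 1D categorical distribution per factor
    """
    log_lik = np.log(A_m[obs] + EPS)                # shape (n0, n1)
    log_prior = [np.log(p + EPS) for p in prior]
    qs = [np.array(p, dtype=float) for p in prior]  # start from the prior (assumption)
    for _ in range(num_iter):
        qs_old = [q.copy() for q in qs]
        # Factor 0: expected log-likelihood under the current belief about factor 1
        qs[0] = softmax(log_prior[0] + log_lik @ qs[1])
        # Factor 1: expected log-likelihood under the updated belief about factor 0
        qs[1] = softmax(log_prior[1] + qs[0] @ log_lik)
        if all(np.abs(q - q_old).max() < tol for q, q_old in zip(qs, qs_old)):
            break
    return qs

# The outcome reports factor 0 exactly and says nothing about factor 1
A_m = np.zeros((2, 2, 2))
A_m[0, 0, :] = 1.0
A_m[1, 1, :] = 1.0
flat = [np.ones(2) / 2, np.ones(2) / 2]
qs = mean_field_fpi(A_m, obs=1, prior=flat)
print(np.round(qs[0], 6), np.round(qs[1], 6))
```

Because the observation identifies factor 0 but is uninformative about factor 1, the posterior for factor 0 concentrates on level 1 while factor 1 stays at its flat prior.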

pymdp.inference.update_posterior_states_full(A, B, prev_obs, policies, prev_actions=None, prior=None, policy_sep_prior=True, **kwargs)

Update the posterior over hidden states under each policy using marginal message passing.

Parameters
  • A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] of this object array stores a numpy.ndarray multidimensional array for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...

  • B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.

  • prev_obs (list) – List of observations over time. Each observation in the list can be an int, a list of ints, a tuple of ints, a one-hot vector or an object array of one-hot vectors.

  • policies (list of 2D numpy.ndarray) – List that stores each policy in policies[p_idx]. Shape of policies[p_idx] is (num_timesteps, num_factors) where num_timesteps is the temporal depth of the policy and num_factors is the number of control factors.

  • prior (numpy.ndarray of dtype object, default None) – If provided, this is a numpy.ndarray of dtype object, with one sub-array per hidden state factor, that stores the prior beliefs about initial states. If None, this defaults to a flat (uninformative) prior over hidden states.

  • policy_sep_prior (bool, default True) – Flag determining whether the prior beliefs from the past are unconditioned on policy, or separated by (conditioned on) the policy variable.

  • **kwargs (keyword arguments) – Optional keyword arguments for the function algos.mmp.run_mmp
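The layout of B and policies can be illustrated with a forward roll-out: at each step t, the action policy[t, f] selects the matrix B[f][:, :, u] that maps beliefs about factor f at the previous time onto the current time. This sketch only shows how the documented shapes index into each other; rollout_policy is hypothetical and is not the marginal message passing carried out by algos.mmp.run_mmp. It also assumes the policy array has one column per hidden state factor.

```python
import numpy as np

def rollout_policy(B, qs0, policy):
    """Hypothetical helper: push factorised beliefs forward under one policy.

    B[f][s, v, u] = P(s_t = s | s_{t-1} = v, u_{t-1} = u)
    qs0: initial beliefs, one 1D array per hidden state factor
    policy: int array of shape (num_timesteps, num_factors), assumed here to
            have one column per hidden state factor
    """
    qs = [np.asarray(q, dtype=float) for q in qs0]
    trajectory = [qs]
    for t in range(policy.shape[0]):
        # Pick the transition matrix for the action taken on each factor at step t
        qs = [B[f][:, :, policy[t, f]] @ qs[f] for f in range(len(B))]
        trajectory.append(qs)
    return trajectory

# One factor with 3 locations; action 0 = stay, action 1 = step right (cyclic)
stay = np.eye(3)
step_right = np.roll(np.eye(3), 1, axis=0)   # step_right[s, v] = 1 iff s == (v + 1) % 3
B = np.empty(1, dtype=object)
B[0] = np.stack([stay, step_right], axis=2)  # shape (3, 3, 2), indexed [s, v, u]

policy = np.array([[1], [1], [0]])           # step, step, stay
trajectory = rollout_policy(B, [np.array([1.0, 0.0, 0.0])], policy)
print(trajectory[-1][0])                     # [0. 0. 1.]
```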

Returns

  • qs_seq_pi (numpy.ndarray of dtype object) – Posterior beliefs over hidden states for each policy. Nesting structure is policies, timepoints, factors, where e.g. qs_seq_pi[p][t][f] stores the marginal belief about factor f at timepoint t under policy p.

  • F (1D numpy.ndarray) – Vector of variational free energies, one per policy.
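Because qs_seq_pi is nested as policies, timepoints, factors, slicing out a single timepoint yields an object array with the same qs_pi[p][f] layout described for average_states_over_policies. The sketch below shows that slicing with hypothetical helpers (obj_array, beliefs_at_timepoint) and made-up belief values. Since variational free energy is an upper bound on negative log evidence, a lower entry of F indicates that the observations are better explained under that policy; the F values used here are illustrative only.

```python
import numpy as np

def obj_array(items):
    """Hypothetical helper: pack a Python list into a 1D numpy object array."""
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item
    return out

def beliefs_at_timepoint(qs_seq_pi, t):
    """Hypothetical helper: turn qs_seq_pi[p][t][f] into qs_pi[p][f] for one t."""
    return obj_array([qs_seq_pi[p][t] for p in range(len(qs_seq_pi))])

# 2 policies x 2 timepoints x 1 hidden state factor with 2 levels (made-up values)
qs_seq_pi = obj_array([
    obj_array([obj_array([np.array([0.5, 0.5])]), obj_array([np.array([0.9, 0.1])])]),
    obj_array([obj_array([np.array([0.5, 0.5])]), obj_array([np.array([0.2, 0.8])])]),
])
F = np.array([3.2, 1.7])  # made-up free energies, one per policy

qs_pi_t1 = beliefs_at_timepoint(qs_seq_pi, t=1)
print(qs_pi_t1[1][0])         # [0.2 0.8]
best_fit = int(np.argmin(F))  # index of the policy with the lowest free energy
print(best_fit)               # 1
```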