Inference
The inference.py module contains the functions for performing inference of discrete hidden states (categorical distributions) in POMDP generative models.
- pymdp.inference.average_states_over_policies(qs_pi, q_pi)
This function computes an expected posterior over hidden states with respect to the posterior over policies, also known as the ‘Bayesian model average of states with respect to policies’.
- Parameters
qs_pi (numpy.ndarray of dtype object) – Posterior beliefs over hidden states for each policy. Nesting structure is policies, factors, where e.g. qs_pi[p][f] stores the marginal belief about factor f under policy p.
q_pi (numpy.ndarray of dtype object) – Posterior beliefs about policies, where len(q_pi) = num_policies.
- Returns
qs_bma – Marginal posterior over hidden states for the current timepoint, averaged across policies according to their posterior probability given by q_pi.
- Return type
numpy.ndarray of dtype object
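The averaging described above is a weighted sum. For each factor f, the policy-conditioned marginals qs_pi[p][f] are summed with weights q_pi[p]. The sketch below is a minimal illustration under that reading. The helpers make_obj_array and bayesian_model_average are hypothetical names, not pymdp API, and q_pi is assumed to be a plain 1D array of policy probabilities.

```python
import numpy as np

def make_obj_array(arrays):
    """Pack 1D arrays into a numpy object array (one entry per factor)."""
    out = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        out[i] = np.asarray(arr, dtype=float)
    return out

def bayesian_model_average(qs_pi, q_pi):
    """Hypothetical helper: weight each policy's state beliefs by q_pi[p] and sum."""
    num_factors = len(qs_pi[0])
    qs_bma = make_obj_array([np.zeros_like(qs_pi[0][f]) for f in range(num_factors)])
    for p, weight in enumerate(q_pi):
        for f in range(num_factors):
            qs_bma[f] += weight * qs_pi[p][f]
    return qs_bma

# Two policies, two hidden state factors with 2 and 3 levels
qs_pi = [
    make_obj_array([[0.9, 0.1], [1.0, 0.0, 0.0]]),
    make_obj_array([[0.2, 0.8], [0.0, 0.5, 0.5]]),
]
q_pi = np.array([0.75, 0.25])
qs_bma = bayesian_model_average(qs_pi, q_pi)
```

Each averaged marginal remains a normalized distribution, because q_pi sums to one.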
- pymdp.inference.update_posterior_states(A, obs, prior=None, **kwargs)
Update the marginal posterior over hidden states using mean-field fixed-point iteration (FPI).
See the following links for details: http://www.cs.cmu.edu/~guestrin/Class/10708/recitations/r9/VI-view.pdf, slides 13-18, and http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.137.221&rep=rep1&type=pdf, slides 24-38.
- Parameters
A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] stores a multidimensional np.ndarray for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...
obs (1D numpy.ndarray, numpy.ndarray of dtype object, int or tuple) – The observation (generated by the environment). If single modality, this can be a 1D np.ndarray (one-hot vector representation) or an int (observation index). If multi-modality, this can be an np.ndarray of dtype object whose entries are 1D one-hot vectors, or a tuple (of int).
prior (1D numpy.ndarray or numpy.ndarray of dtype object, default None) – Prior beliefs about hidden states, to be integrated with the marginal likelihood to obtain a posterior distribution. If not provided, prior is set to be equal to a flat categorical distribution (at the level of the individual inference functions).
**kwargs (keyword arguments) – Keyword arguments corresponding to parameter values for the fixed-point iteration algorithm algos.fpi.run_vanilla_fpi
- Returns
qs – Marginal posterior beliefs over hidden states at current timepoint
- Return type
1D numpy.ndarray or numpy.ndarray of dtype object
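Mean-field FPI repeatedly updates each factor's marginal. Each update combines the log prior with the expected log likelihood under the current marginals of the other factors, followed by a softmax. The sketch below makes some assumptions: one observation modality, exactly two hidden state factors, a fixed number of iterations, and no free-energy convergence check. mean_field_fpi is a hypothetical helper, not pymdp's run_vanilla_fpi, which generalizes this update to many factors and modalities.

```python
import numpy as np

def softmax(x):
    """Normalized exponential of a 1D array (numerically stabilized)."""
    e = np.exp(x - x.max())
    return e / e.sum()

def mean_field_fpi(A, obs_idx, priors, num_iter=10, eps=1e-16):
    """Hypothetical mean-field FPI for one modality and two hidden state factors.

    A has shape (num_obs, ns0, ns1); obs_idx is the observed level;
    priors is a list of two 1D prior vectors.
    """
    log_lik = np.log(A[obs_idx] + eps)                # shape (ns0, ns1)
    qs = [np.ones_like(p) / p.size for p in priors]   # start from flat marginals
    for _ in range(num_iter):
        # Factor 0: expected log likelihood under the current q(s1)
        qs[0] = softmax(np.log(priors[0] + eps) + log_lik @ qs[1])
        # Factor 1: expected log likelihood under the current q(s0)
        qs[1] = softmax(np.log(priors[1] + eps) + qs[0] @ log_lik)
    return qs

# Observation 0 is most likely when both factors are in level 0
A = np.zeros((2, 2, 2))
A[0] = np.array([[0.9, 0.5], [0.5, 0.1]])
A[1] = 1.0 - A[0]
priors = [np.array([0.5, 0.5]), np.array([0.5, 0.5])]
qs = mean_field_fpi(A, obs_idx=0, priors=priors)
```

With a single hidden state factor, there is no other marginal to take an expectation over. The update then reduces to the exact Bayesian posterior after one step.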
- pymdp.inference.update_posterior_states_factorized(A, obs, num_obs, num_states, mb_dict, prior=None, **kwargs)
Update the marginal posterior over hidden states using mean-field fixed-point iteration (FPI). This version identifies the Markov blanket of each factor using A_factor_list.
See the following links for details: http://www.cs.cmu.edu/~guestrin/Class/10708/recitations/r9/VI-view.pdf, slides 13-18, and http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.137.221&rep=rep1&type=pdf, slides 24-38.
- Parameters
A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] stores a multidimensional np.ndarray for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...
obs (1D numpy.ndarray, numpy.ndarray of dtype object, int or tuple) – The observation (generated by the environment). If single modality, this can be a 1D np.ndarray (one-hot vector representation) or an int (observation index). If multi-modality, this can be an np.ndarray of dtype object whose entries are 1D one-hot vectors, or a tuple (of int).
num_obs (list of int) – List of dimensionalities of each observation modality.
num_states (list of int) – List of dimensionalities of each hidden state factor.
mb_dict (Dict) – Dictionary with two keys (A_factor_list and A_modality_list) that stores the factor indices that influence each modality (A_factor_list) and the modality indices influenced by each factor (A_modality_list).
prior (1D numpy.ndarray or numpy.ndarray of dtype object, default None) – Prior beliefs about hidden states, to be integrated with the marginal likelihood to obtain a posterior distribution. If not provided, prior is set to be equal to a flat categorical distribution (at the level of the individual inference functions).
**kwargs (keyword arguments) – Keyword arguments corresponding to parameter values for the fixed-point iteration algorithm algos.fpi.run_vanilla_fpi
- Returns
qs – Marginal posterior beliefs over hidden states at current timepoint
- Return type
1D numpy.ndarray or numpy.ndarray of dtype object
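The two lists in mb_dict are inverses of each other: A_modality_list can be derived from A_factor_list by asking, for each factor, which modalities list it. The sketch below shows that derivation. build_mb_dict is a hypothetical helper, not pymdp API. The shape rule for A[m] is also an assumption: it supposes that A[m] has one axis for observation levels followed by one axis per factor in A_factor_list[m].

```python
def build_mb_dict(A_factor_list, num_factors):
    """Hypothetical helper: derive A_modality_list (modalities each factor influences)."""
    A_modality_list = [
        [m for m, factors in enumerate(A_factor_list) if f in factors]
        for f in range(num_factors)
    ]
    return {"A_factor_list": A_factor_list, "A_modality_list": A_modality_list}

# Modality 0 depends on factor 0 only; modality 1 depends on factors 0 and 1
mb_dict = build_mb_dict([[0], [0, 1]], num_factors=2)

# Assumed layout: A[m] has shape (num_obs[m], *num_states of the factors in A_factor_list[m])
num_obs = [3, 4]
num_states = [2, 5]
A_shapes = [
    (num_obs[m], *[num_states[f] for f in factors])
    for m, factors in enumerate(mb_dict["A_factor_list"])
]
```

Here factor 0 appears in both modalities and factor 1 only in modality 1, so A_modality_list is [[0, 1], [1]].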
- pymdp.inference.update_posterior_states_full(A, B, prev_obs, policies, prev_actions=None, prior=None, policy_sep_prior=True, **kwargs)
Update the posterior over hidden states using marginal message passing.
- Parameters
A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] stores a multidimensional numpy.ndarray for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...
B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.
prev_obs (list) – List of observations over time. Each observation in the list can be an int, a list of ints, a tuple of ints, a one-hot vector or an object array of one-hot vectors.
policies (list of 2D numpy.ndarray) – List that stores each policy in policies[p_idx]. Shape of policies[p_idx] is (num_timesteps, num_factors), where num_timesteps is the temporal depth of the policy and num_factors is the number of control factors.
prior (numpy.ndarray of dtype object, default None) – If provided, this is a numpy.ndarray of dtype object, with one sub-array per hidden state factor, that stores the prior beliefs about initial states. If None, this defaults to a flat (uninformative) prior over hidden states.
policy_sep_prior (Bool, default True) – Flag determining whether the prior beliefs from the past are unconditioned on policy, or separated by/conditioned on the policy variable.
**kwargs (keyword arguments) – Optional keyword arguments for the function algos.mmp.run_mmp
- Returns
qs_seq_pi (numpy.ndarray of dtype object) – Posterior beliefs over hidden states for each policy. Nesting structure is policies, timepoints, factors, where e.g. qs_seq_pi[p][t][f] stores the marginal belief about factor f at timepoint t under policy p.
F (1D numpy.ndarray) – Vector of variational free energies for each policy.
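The indexing conventions for B and policies can be made concrete with a small example. B[f][:, :, u] is a column-stochastic matrix, so multiplying it by the previous belief gives the predicted belief. Each policy is an integer array of shape (num_timesteps, num_factors). The sketch below assumes one hidden state factor with three levels and two hypothetical actions ("stay" and "move"). rollout is a hypothetical helper that only performs forward prediction. It is not the marginal message passing in algos.mmp.run_mmp, which also incorporates observations and backward messages.

```python
import numpy as np

num_states = [3]     # one hidden state factor with 3 levels
num_controls = [2]   # hypothetical actions: 0 = stay, 1 = move to the next level

# B0[s, v, u]: probability of next state s given previous state v and action u
B0 = np.zeros((num_states[0], num_states[0], num_controls[0]))
B0[:, :, 0] = np.eye(3)                            # "stay" keeps the state
B0[:, :, 1] = np.roll(np.eye(3), shift=1, axis=0)  # "move" maps v -> (v + 1) % 3

# Three two-step policies over a single control factor: shape (2, 1) each
policies = [np.array([[0], [0]]), np.array([[1], [1]]), np.array([[0], [1]])]

def rollout(qs0, B_f, policy, f=0):
    """Hypothetical helper: propagate a belief over factor f through policy[:, f]."""
    qs_seq = [qs0]
    for t in range(policy.shape[0]):
        u = policy[t, f]
        qs_seq.append(B_f[:, :, u] @ qs_seq[-1])
    return qs_seq

qs_seq = rollout(np.array([1.0, 0.0, 0.0]), B0, policies[1])
```

Under policies[1], which chooses "move" twice, a belief that starts certain of level 0 ends certain of level 2.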
- pymdp.inference.update_posterior_states_full_factorized(A, mb_dict, B, B_factor_list, prev_obs, policies, prev_actions=None, prior=None, policy_sep_prior=True, **kwargs)
Update the posterior over hidden states using marginal message passing.
- Parameters
A (numpy.ndarray of dtype object) – Sensory likelihood mapping or ‘observation model’, mapping from hidden states to observations. Each element A[m] stores a multidimensional numpy.ndarray for observation modality m, whose entries A[m][i, j, k, ...] store the probability of observation level i given hidden state levels j, k, ...
mb_dict (Dict) – Dictionary with two keys (A_factor_list and A_modality_list) that stores the factor indices that influence each modality (A_factor_list) and the modality indices influenced by each factor (A_modality_list).
B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.
B_factor_list (list of list of int) – For each hidden state factor, the list of hidden state factors its dynamics depend on. Each element B_factor_list[i] is a list of the indices of the factors that factor i’s dynamics depend on.
prev_obs (list) – List of observations over time. Each observation in the list can be an int, a list of ints, a tuple of ints, a one-hot vector or an object array of one-hot vectors.
policies (list of 2D numpy.ndarray) – List that stores each policy in policies[p_idx]. Shape of policies[p_idx] is (num_timesteps, num_factors), where num_timesteps is the temporal depth of the policy and num_factors is the number of control factors.
prior (numpy.ndarray of dtype object, default None) – If provided, this is a numpy.ndarray of dtype object, with one sub-array per hidden state factor, that stores the prior beliefs about initial states. If None, this defaults to a flat (uninformative) prior over hidden states.
policy_sep_prior (Bool, default True) – Flag determining whether the prior beliefs from the past are unconditioned on policy, or separated by/conditioned on the policy variable.
**kwargs (keyword arguments) – Optional keyword arguments for the function algos.mmp.run_mmp
- Returns
qs_seq_pi (numpy.ndarray of dtype object) – Posterior beliefs over hidden states for each policy. Nesting structure is policies, timepoints, factors, where e.g. qs_seq_pi[p][t][f] stores the marginal belief about factor f at timepoint t under policy p.
F (1D numpy.ndarray) – Vector of variational free energies for each policy.
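When B_factor_list lets a factor's dynamics depend on several factors, predicting that factor's next state means contracting its transition tensor with the marginals of every parent factor. Under a mean-field approximation, the joint over the parents is the product of their marginals. The sketch below rests on an assumed tensor layout: B[f] has shape (num_states[f], *[num_states[i] for i in B_factor_list[f]], num_controls[f]), with the next state on axis 0 and the action on the last axis. random_transition and predict_factor are hypothetical helpers, not pymdp API.

```python
import numpy as np

num_states = [2, 3]
num_controls = [1, 2]
B_factor_list = [[0], [0, 1]]  # factor 1's dynamics depend on factors 0 and 1

rng = np.random.default_rng(0)

def random_transition(shape):
    """Hypothetical helper: random tensor normalized over axis 0 (next-state axis)."""
    x = rng.random(shape)
    return x / x.sum(axis=0, keepdims=True)

# Assumed layout: (next state of f, *previous states of parent factors, action)
B = [
    random_transition((num_states[f], *[num_states[i] for i in B_factor_list[f]], num_controls[f]))
    for f in range(len(num_states))
]

def predict_factor(B_f, parents_qs, u):
    """Expected next-state marginal of one factor under mean-field parent marginals."""
    T = B_f[..., u]                 # shape (ns_f, *parent dims)
    for q in reversed(parents_qs):  # each matmul contracts the current last axis
        T = T @ q
    return T

qs = [np.array([0.3, 0.7]), np.array([0.2, 0.5, 0.3])]
q1_next = predict_factor(B[1], [qs[i] for i in B_factor_list[1]], u=1)
```

The parents are contracted from the last axis inward, so the order of parents_qs must match the order of B_factor_list[f].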