MMP (Marginal Message Passing)
- pymdp.algos.mmp.run_mmp(lh_seq, B, policy, prev_actions=None, prior=None, num_iter=10, grad_descent=True, tau=0.25, last_timestep=False)
Marginal message passing scheme for updating marginal posterior beliefs about hidden states over time, conditioned on a particular policy.
- Parameters
  - lh_seq (numpy.ndarray of dtype object) – Log likelihoods of hidden states under a sequence of observations over time. This is assumed to already be log-transformed. Each lh_seq[t] contains the log likelihood of hidden states for a particular observation at time t.
  - B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.
  - policy (2D numpy.ndarray) – Matrix of shape (policy_len, num_control_factors) that indicates the index of each action (control state index) at timestep t and control factor f in the element policy[t, f], for a given policy.
  - prev_actions (numpy.ndarray, default None) – If provided, a matrix of previous actions of shape (infer_len, num_control_factors) that indicates the index of each action (control state index) taken in the past (up until the current timestep).
  - prior (numpy.ndarray of dtype object, default None) – If provided, the prior beliefs about initial states (at t = 0, relative to infer_len). If None, this defaults to a flat (uninformative) prior over hidden states.
  - num_iter (int, default 10) – Number of variational iterations.
  - grad_descent (bool, default True) – Flag for whether to use gradient descent (free energy gradient updates) instead of the fixed-point solution for the posterior beliefs.
  - tau (float, default 0.25) – Decay constant for use in the grad_descent version. Tunes the size of the gradient descent updates to the posterior.
  - last_timestep (bool, default False) – Flag for whether we are at the last timestep of belief updating.
- Returns
  - qs_seq (numpy.ndarray of dtype object) – Posterior beliefs over hidden states under the policy. Nesting structure is timepoints, factors, so that e.g. qs_seq[t][f] stores the marginal belief about factor f at timepoint t under the policy in question.
  - F (float) – Variational free energy of the policy.
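The following is a minimal usage sketch (not taken from the pymdp documentation) for a toy model with a single hidden-state factor and a single control factor. The model dimensions, the random transition model, and the fabricated log likelihoods are illustrative assumptions; in practice lh_seq would be computed from the observation model and the actual observations.

```python
import numpy as np
from pymdp.algos.mmp import run_mmp

rng = np.random.default_rng(0)

num_states = 3   # levels of the single hidden-state factor (illustrative)
num_actions = 2  # levels of the single control factor (illustrative)
obs_len = 2      # number of past observations

# lh_seq[t]: log likelihood over hidden states given the observation at time t.
# Fabricated here for illustration; normally derived from the observation model.
lh_seq = np.empty(obs_len, dtype=object)
for t in range(obs_len):
    likelihood = rng.dirichlet(np.ones(num_states))
    lh_seq[t] = np.log(likelihood + 1e-16)  # already log-transformed, as required

# B[f][s, v, u]: probability of state s now, given previous state v and action u.
B = np.empty(1, dtype=object)
raw = rng.dirichlet(np.ones(num_states), size=(num_states, num_actions))  # shape (v, u, s)
B[0] = np.transpose(raw, (2, 0, 1))  # reorder to (s, v, u); sums to 1 over s

# policy[t, f]: action index for control factor f at policy step t.
policy = np.array([[0], [1]])  # shape (policy_len, num_control_factors)

qs_seq, F = run_mmp(lh_seq, B, policy, num_iter=10, grad_descent=True, tau=0.25)

print(qs_seq[0][0])  # marginal belief over the single factor at the first timepoint
print(F)             # variational free energy of this policy
```

Because prior is left as None here, inference starts from a flat (uninformative) prior over hidden states, as described above.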
- pymdp.algos.mmp.run_mmp_factorized(lh_seq, mb_dict, B, B_factor_list, policy, prev_actions=None, prior=None, num_iter=10, grad_descent=True, tau=0.25, last_timestep=False)
Marginal message passing scheme for updating marginal posterior beliefs about hidden states over time, conditioned on a particular policy.
- Parameters
  - lh_seq (numpy.ndarray of dtype object) – Log likelihoods of hidden states under a sequence of observations over time. This is assumed to already be log-transformed. Each lh_seq[t] contains the log likelihood of hidden states for a particular observation at time t.
  - mb_dict (Dict) – Dictionary with two keys (A_factor_list and A_modality_list) that stores the factor indices that influence each modality (A_factor_list) and the modality indices influenced by each factor (A_modality_list).
  - B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.
  - B_factor_list (list of list of int) – List of lists of the hidden state factors each hidden state factor depends on. Each element B_factor_list[i] is a list of the factor indices that factor i’s dynamics depend on.
  - policy (2D numpy.ndarray) – Matrix of shape (policy_len, num_control_factors) that indicates the index of each action (control state index) at timestep t and control factor f in the element policy[t, f], for a given policy.
  - prev_actions (numpy.ndarray, default None) – If provided, a matrix of previous actions of shape (infer_len, num_control_factors) that indicates the index of each action (control state index) taken in the past (up until the current timestep).
  - prior (numpy.ndarray of dtype object, default None) – If provided, the prior beliefs about initial states (at t = 0, relative to infer_len). If None, this defaults to a flat (uninformative) prior over hidden states.
  - num_iter (int, default 10) – Number of variational iterations.
  - grad_descent (bool, default True) – Flag for whether to use gradient descent (free energy gradient updates) instead of the fixed-point solution for the posterior beliefs.
  - tau (float, default 0.25) – Decay constant for use in the grad_descent version. Tunes the size of the gradient descent updates to the posterior.
  - last_timestep (bool, default False) – Flag for whether we are at the last timestep of belief updating.
- Returns
  - qs_seq (numpy.ndarray of dtype object) – Posterior beliefs over hidden states under the policy. Nesting structure is timepoints, factors, so that e.g. qs_seq[t][f] stores the marginal belief about factor f at timepoint t under the policy in question.
  - F (float) – Variational free energy of the policy.
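Relative to run_mmp, the factorized variant additionally requires the Markov blanket dictionary mb_dict and the dependency list B_factor_list. The sketch below shows how these two structures might look for a hypothetical model with two hidden-state factors and two observation modalities; the particular factor/modality assignments (and the commented call at the end) are illustrative assumptions, not taken from the documentation above.

```python
# Hypothetical sparsity structure: modality 0 is generated by factor 0 only,
# while modality 1 depends on both factors. The two lists must be mutually consistent.
mb_dict = {
    "A_factor_list": [[0], [0, 1]],    # factor indices that influence each modality
    "A_modality_list": [[0, 1], [1]],  # modality indices influenced by each factor
}

# Factor 0's dynamics depend only on itself; factor 1's dynamics depend on factors 0 and 1.
B_factor_list = [[0], [0, 1]]

# These structures are then passed alongside lh_seq, B and policy (built as described above), e.g.:
# qs_seq, F = run_mmp_factorized(lh_seq, mb_dict, B, B_factor_list, policy)
```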