MMP (Marginal Message Passing)

pymdp.algos.mmp.run_mmp(lh_seq, B, policy, prev_actions=None, prior=None, num_iter=10, grad_descent=True, tau=0.25, last_timestep=False)

Marginal message passing scheme for updating marginal posterior beliefs about hidden states over time, conditioned on a particular policy.

Parameters:

  • lh_seq (numpy.ndarray of dtype object) – Log likelihoods of hidden states under a sequence of observations over time. These are assumed to already be log-transformed. Each lh_seq[t] contains the log likelihood of hidden states for the observation at time t.

  • B (numpy.ndarray of dtype object) – Dynamics likelihood mapping or ‘transition model’, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] of this object array stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.

  • policy (2D numpy.ndarray) – Matrix of shape (policy_len, num_control_factors) whose element policy[t, f] gives the index of the action (control state index) at timestep t for control factor f under a given policy.

  • prev_actions (numpy.ndarray, default None) – If provided, should be a matrix of previous actions of shape (infer_len, num_control_factors) that indicates the indices of each action (control state index) taken in the past (up until the current timestep).

  • prior (numpy.ndarray of dtype object, default None) – If provided, the prior beliefs about initial states (at t = 0, relative to infer_len). If None, this defaults to a flat (uninformative) prior over hidden states.

  • num_iter (int, default 10) – Number of variational iterations.

  • grad_descent (Bool, default True) – Flag for whether to update the posterior beliefs by gradient descent (free energy gradient updates) instead of using the fixed-point solution.

  • tau (float, default 0.25) – Decay constant used when grad_descent is True. Tunes the size of the gradient descent updates to the posterior.

  • last_timestep (Bool, default False) – Flag for whether we are at the last timestep of belief updating.
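
The difference between the grad_descent and fixed-point modes, and the role of tau, can be illustrated with a small NumPy sketch. This is not pymdp's implementation: update_belief and ln_target are hypothetical names, and the log posterior implied by the messages is assumed to be given rather than assembled from lh_seq and B.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()

def update_belief(qs, ln_target, grad_descent=True, tau=0.25):
    """One belief update for a single factor at a single timestep (illustrative only).

    ln_target is the unnormalised log posterior implied by the current messages;
    how it is assembled is outside the scope of this sketch.
    """
    if not grad_descent:
        # Fixed-point mode: jump straight to the normalised target.
        return softmax(ln_target)
    ln_qs = np.log(qs + 1e-16)                # small constant avoids log(0)
    err = ln_target - ln_qs                   # error in log space
    ln_qs = ln_qs + tau * (err - err.mean())  # step of size tau towards the target
    return softmax(ln_qs)

ln_target = np.log(np.array([0.7, 0.2, 0.1]))
uniform = np.ones(3) / 3

qs_fixed = update_belief(uniform, ln_target, grad_descent=False)
qs_one_step = update_belief(uniform, ln_target, grad_descent=True, tau=0.25)

# Repeated gradient steps approach the fixed-point solution.
qs_grad = uniform.copy()
for _ in range(50):
    qs_grad = update_belief(qs_grad, ln_target, grad_descent=True, tau=0.25)
```

In this toy setting a single gradient step moves the belief only part of the way to the target, while many steps converge to the same answer as the fixed-point update. A larger tau takes bigger steps per iteration.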


Returns:

  • qs_seq (numpy.ndarray of dtype object) – Posterior beliefs over hidden states under the policy. The nesting structure is timepoints, then factors: qs_seq[t][f] stores the marginal belief about factor f at timepoint t under the policy in question.

  • F (float) – Variational free energy of the policy.
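
The array conventions described above can be sketched with plain NumPy. The sizes, the random transition model, and the uniform placeholder beliefs below are all hypothetical; the snippet does not call run_mmp and only shows how B[f][s, v, u], policy[t, f] and qs_seq[t][f] are indexed.

```python
import numpy as np

num_states = [3, 2]     # hypothetical: two hidden state factors
num_controls = [2, 1]   # hypothetical: number of actions per control factor
policy_len = 3          # hypothetical policy length
num_timepoints = 4      # hypothetical number of timepoints in qs_seq

# B[f][s, v, u] = P(state s now | state v and action u at the previous time)
rng = np.random.default_rng(0)
B = np.empty(len(num_states), dtype=object)
for f, (ns, nu) in enumerate(zip(num_states, num_controls)):
    raw = rng.random((ns, ns, nu))
    B[f] = raw / raw.sum(axis=0, keepdims=True)  # normalise over s

# policy[t, f] = action index for control factor f at timestep t
policy = np.zeros((policy_len, len(num_controls)), dtype=int)
policy[:, 0] = [0, 1, 1]

# qs_seq[t][f] = marginal belief about factor f at timepoint t
# (uniform placeholders standing in for real posteriors)
qs_seq = np.empty(num_timepoints, dtype=object)
for t in range(num_timepoints):
    qs_t = np.empty(len(num_states), dtype=object)
    for f, ns in enumerate(num_states):
        qs_t[f] = np.ones(ns) / ns
    qs_seq[t] = qs_t

# One-step prediction for factor 0 under the first action of the policy
predicted = B[0][:, :, policy[0, 0]] @ qs_seq[0][0]

# Stack factor 0's beliefs into a (timepoints, states) array
factor0_over_time = np.stack([qs_seq[t][0] for t in range(num_timepoints)])
```

Stacking per-factor beliefs this way is convenient for plotting or inspecting how the belief about one factor evolves over the inference window.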