Agent class

class pymdp.agent.Agent(A, B, C=None, D=None, E=None, pA=None, pB=None, pD=None, num_controls=None, policy_len=1, inference_horizon=1, control_fac_idx=None, policies=None, gamma=16.0, alpha=16.0, use_utility=True, use_states_info_gain=True, use_param_info_gain=False, action_selection='deterministic', sampling_mode='marginal', inference_algo='VANILLA', inference_params=None, modalities_to_learn='all', lr_pA=1.0, factors_to_learn='all', lr_pB=1.0, lr_pD=1.0, use_BMA=True, policy_sep_prior=False, save_belief_hist=False)

The Agent class, the highest-level API that wraps together processes for action, perception, and learning under active inference.

The basic usage is as follows:

>>> my_agent = Agent(A = A, B = B, <more_params>)
>>> observation = env.step(initial_action)
>>> qs = my_agent.infer_states(observation)
>>> q_pi, G = my_agent.infer_policies()
>>> next_action = my_agent.sample_action()
>>> next_observation = env.step(next_action)

This represents one timestep of an active inference process. Wrapping this step in a loop with an Env() class that returns observations and takes actions as inputs would constitute a dynamic agent-environment interaction; a minimal sketch is shown below.
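
For illustration, such a loop might look as follows. This is a sketch only: the env object (with reset() and step() methods) and the horizon length T are hypothetical and not part of the Agent API.

>>> observation = env.reset()
>>> for t in range(T):
...     qs = my_agent.infer_states(observation)
...     q_pi, G = my_agent.infer_policies()
...     action = my_agent.sample_action()
...     observation = env.step(action)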

get_future_qs()

Returns the last self.policy_len timesteps of each policy-conditioned belief over hidden states. This is a step of pre-processing that needs to be done before computing the expected free energy of policies. We do this to avoid computing the expected free energy of policies using beliefs about hidden states in the past (so-called “post-dictive” beliefs).

Returns

future_qs_seq – Posterior beliefs over hidden states under a policy, in the future. This is a nested numpy.ndarray object array, with one sub-array future_qs_seq[p_idx] for each policy. The indexing structure is policy->timepoint->factor, so that future_qs_seq[p_idx][t_idx][f_idx] refers to beliefs about marginal factor f_idx expected under policy p_idx at future timepoint t_idx, relative to the current timestep.

Return type

numpy.ndarray of dtype object
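
As an illustrative sketch, assuming an agent constructed with inference_algo='MMP' (so that policy-conditioned beliefs over future timepoints are maintained), the nested structure can be indexed as follows; the variable names are placeholders:

>>> future_qs_seq = my_agent.get_future_qs()
>>> # beliefs about hidden state factor 0, under policy 0, at the first future timepoint
>>> qs_future = future_qs_seq[0][0][0]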

infer_policies()

Perform policy inference by optimizing a posterior (categorical) distribution over policies. This distribution is computed as the softmax of G * gamma + lnE where G is the negative expected free energy of policies, gamma is a policy precision and lnE is the (log) prior probability of policies. This function returns the posterior over policies as well as the negative expected free energy of each policy.

Returns

  • q_pi (1D numpy.ndarray) – Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.

  • G (1D numpy.ndarray) – Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
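
A brief usage sketch, assuming numpy has been imported as np and that infer_states has already been called on the current observation:

>>> q_pi, G = my_agent.infer_policies()
>>> # q_pi and G each contain one entry per policy
>>> best_policy_idx = np.argmax(q_pi)   # index of the most probable policy under the posterior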

infer_states(observation, distr_obs=False)

Update the approximate posterior over hidden states by solving a variational inference problem, given an observation.

Parameters

observation (list or tuple of ints) – The observation input. Each entry observation[m] stores the index of the discrete observation for modality m.

Returns

qs – Posterior beliefs over hidden states. Depending on the inference algorithm chosen, the resulting qs variable will have additional sub-structure to reflect whether beliefs are additionally conditioned on timepoint and policy. For example, if self.inference_algo == 'MMP', the indexing structure is policy->timepoint->factor, so that qs[p_idx][t_idx][f_idx] refers to beliefs about marginal factor f_idx expected under policy p_idx at timepoint t_idx.

Return type

numpy.ndarray of dtype object
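
A minimal sketch, assuming a generative model with two observation modalities and the default inference_algo == 'VANILLA'; the observation indices are arbitrary placeholders:

>>> observation = [0, 2]                  # one observation index per modality
>>> qs = my_agent.infer_states(observation)
>>> qs_factor_0 = qs[0]                   # marginal posterior over the first hidden state factor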

reset(init_qs=None)

Resets the posterior beliefs about hidden states of the agent to a uniform distribution, and resets time to the first timestep of the simulation’s temporal horizon. Returns the posterior beliefs about hidden states.

Returns

qs – Initialized posterior over hidden states. Depending on the inference algorithm chosen and other parameters (such as those stored in self.edge_handling_params), the resulting qs variable will have additional sub-structure to reflect whether beliefs are additionally conditioned on timepoint and policy.

For example, if self.inference_algo == 'MMP', the indexing structure of qs is policy->timepoint->factor, so that qs[p_idx][t_idx][f_idx] refers to beliefs about marginal factor f_idx expected under policy p_idx at timepoint t_idx. In this case, the returned qs will only have entries filled out for the first timestep, i.e. qs[p_idx][0] for all policy indices p_idx. Subsequent entries qs[:][1, 2, ...] will be initialized to empty numpy.ndarray objects.

Return type

numpy.ndarray of dtype object
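
For example, at the start of a new trial one might reset the agent’s beliefs (a sketch; the trial structure is hypothetical):

>>> qs_init = my_agent.reset()   # beliefs return to uniform, time returns to the first timestep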

sample_action()

Sample or select a discrete action from the posterior over control states. This function both caches the action as an internal variable of the agent and returns it. It also updates the time variable (and thus manages the consequences of updating the moving reference frame of beliefs) using self.step_time().

Returns

action – Vector containing the indices of the actions for each control factor

Return type

1D numpy.ndarray
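
A brief sketch of how the returned action vector might be used; the env object is a hypothetical environment, not part of the Agent API:

>>> action = my_agent.sample_action()
>>> # action[f] is the index of the action selected for control factor f
>>> next_observation = env.step(action)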

set_latest_beliefs(last_belief=None)

Both sets and returns the penultimate belief before the first timestep of the backwards inference horizon. In the case that the inference horizon includes the first timestep of the simulation, then the latest_belief is simply the first belief of the whole simulation, or the prior (self.D). The particular structure of the latest_belief depends on the value of self.edge_handling_params['use_BMA'].

Returns

latest_belief – Penultimate posterior beliefs over hidden states at the timestep just before the first timestep of the inference horizon. Depending on the value of self.edge_handling_params['use_BMA'], the shape of this output array will differ. If self.edge_handling_params['use_BMA'] == True, then latest_belief will be a Bayesian model average of beliefs about hidden states, where the average is taken with respect to posterior beliefs about policies. Otherwise, latest_belief will be the full, policy-conditioned belief about hidden states, and will have indexing structure policies->factors, such that latest_belief[p_idx][f_idx] refers to the penultimate belief about marginal factor f_idx under policy p_idx.

Return type

numpy.ndarray of dtype object

step_time()

Advances time by one step. This involves updating self.prev_actions, and, in the case of a moving inference horizon, shifting the history of post-dictive beliefs forward in time (using self.set_latest_beliefs()), so that the penultimate belief before the beginning of the horizon is correctly indexed.

Returns

curr_timestep – The index in absolute simulation time of the current timestep.

Return type

int

update_A(obs)

Update approximate posterior beliefs about Dirichlet parameters that parameterise the observation likelihood or A array.

Parameters

obs (list or tuple of ints) – The observation input. Each entry obs[m] stores the index of the discrete observation for modality m.

Returns

qA – Posterior Dirichlet parameters over observation model (same shape as A), after having updated it with observations.

Return type

numpy.ndarray of dtype object
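
A sketch of observation-model learning, assuming the agent was constructed with a Dirichlet prior pA (so that learning over A is enabled). The update also uses the agent’s current posterior over hidden states, so infer_states is typically called first:

>>> qs = my_agent.infer_states(observation)
>>> qA = my_agent.update_A(observation)   # updated Dirichlet parameters over the observation model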

update_B(qs_prev)

Update posterior beliefs about Dirichlet parameters that parameterise the transition likelihood, or B array.

Parameters

qs_prev (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at previous timepoint.

Returns

qB – Posterior Dirichlet parameters over transition model (same shape as B), after having updated it with state beliefs and actions.

Return type

numpy.ndarray of dtype object
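
A sketch of transition-model learning, assuming the agent was constructed with a Dirichlet prior pB, that the hypothetical env object is as above, and that the agent’s current beliefs are exposed as self.qs (as in the VANILLA case):

>>> qs_prev = my_agent.qs                                  # beliefs from the previous timestep
>>> next_observation = env.step(my_agent.sample_action())
>>> qs = my_agent.infer_states(next_observation)
>>> qB = my_agent.update_B(qs_prev)                        # updated Dirichlet parameters over the transition model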

update_D(qs_t0=None)

Update Dirichlet parameters of the initial hidden state distribution (prior beliefs about hidden states at the beginning of the inference window).

Parameters

qs_t0 (1D numpy.ndarray, numpy.ndarray of dtype object, or None) – Marginal posterior beliefs over hidden states at the initial timepoint of the inference window. If None, the value of qs_t0 is set to self.qs_hist[0] (i.e. the initial hidden state beliefs at the first timepoint). If self.inference_algo == "MMP", then qs_t0 is set to be the Bayesian model average of beliefs about hidden states at the first timestep of the backwards inference horizon, where the average is taken with respect to posterior beliefs about policies.

Returns

qD – Posterior Dirichlet parameters over initial hidden state prior (same shape as qs_t0), after having updated it with state beliefs.

Return type

numpy.ndarray of dtype object
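
A sketch of updating the initial hidden state prior, assuming the agent was constructed with a Dirichlet prior pD and with save_belief_hist=True, so that self.qs_hist[0] is available when qs_t0 is omitted:

>>> qD = my_agent.update_D()   # defaults to using self.qs_hist[0] as the initial-state beliefs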