Agent class
- class pymdp.agent.Agent(A, B, C=None, D=None, E=None, H=None, pA=None, pB=None, pD=None, num_controls=None, policy_len=1, inference_horizon=1, control_fac_idx=None, policies=None, gamma=16.0, alpha=16.0, use_utility=True, use_states_info_gain=True, use_param_info_gain=False, action_selection='deterministic', sampling_mode='marginal', inference_algo='VANILLA', inference_params=None, modalities_to_learn='all', lr_pA=1.0, factors_to_learn='all', lr_pB=1.0, lr_pD=1.0, use_BMA=True, policy_sep_prior=False, save_belief_hist=False, A_factor_list=None, B_factor_list=None, sophisticated=False, si_horizon=3, si_policy_prune_threshold=0.0625, si_state_prune_threshold=0.0625, si_prune_penalty=512, ii_depth=10, ii_threshold=0.0625)
The Agent class, the highest-level API that wraps together processes for action, perception, and learning under active inference.
The basic usage is as follows:
>>> my_agent = Agent(A = A, B = B, <more_params>)
>>> observation = env.step(initial_action)
>>> qs = my_agent.infer_states(observation)
>>> q_pi, G = my_agent.infer_policies()
>>> next_action = my_agent.sample_action()
>>> next_observation = env.step(next_action)
This represents one timestep of an active inference process. Wrapping this step in a loop with an Env() class that returns observations and takes actions as inputs would constitute a dynamic agent-environment interaction.
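As a fuller, runnable sketch of the same loop: the model dimensions, the randomly-initialized A and B arrays, and the env_step stand-in below are illustrative assumptions, not part of the Agent API.

>>> import numpy as np
>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> # Hypothetical dimensions: one observation modality with 3 outcomes,
>>> # one hidden state factor with 3 states, and 3 available actions
>>> num_obs, num_states, num_controls = [3], [3], [3]
>>> A = utils.random_A_matrix(num_obs, num_states)       # observation model p(o | s)
>>> B = utils.random_B_matrix(num_states, num_controls)  # transition model p(s' | s, u)
>>> my_agent = Agent(A=A, B=B)
>>> def env_step(action):
...     # stand-in environment: anything mapping actions to per-modality observation indices works
...     return [np.random.randint(num_obs[0])]
>>> observation = env_step([0])                          # observation after an arbitrary initial action
>>> for t in range(5):
...     qs = my_agent.infer_states(observation)          # perception
...     q_pi, G = my_agent.infer_policies()              # policy inference
...     next_action = my_agent.sample_action()           # action selection
...     observation = env_step(next_action)              # act on the (toy) environment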
- get_future_qs()
Returns the last self.policy_len timesteps of each policy-conditioned belief over hidden states. This is a step of pre-processing that needs to be done before computing the expected free energy of policies. We do this to avoid computing the expected free energy of policies using beliefs about hidden states in the past (so-called “post-dictive” beliefs).
- Returns
future_qs_seq – Posterior beliefs over hidden states under a policy, in the future. This is a nested numpy.ndarray object array, with one sub-array future_qs_seq[p_idx] for each policy. The indexing structure is policy->timepoint->factor, so that future_qs_seq[p_idx][t_idx][f_idx] refers to beliefs about marginal factor f_idx expected under policy p_idx at future timepoint t_idx, relative to the current timestep.
- Return type
numpy.ndarray of dtype object
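An illustrative sketch of indexing the returned structure, assuming a marginal message passing (MMP) configuration; the model dimensions and observation index are hypothetical:

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> A = utils.random_A_matrix([3], [3])
>>> B = utils.random_B_matrix([3], [3])
>>> agent = Agent(A=A, B=B, inference_algo='MMP', policy_len=2, inference_horizon=3)
>>> qs = agent.infer_states([0])           # condition beliefs on an initial observation
>>> future_qs_seq = agent.get_future_qs()  # keep only beliefs about future timepoints
>>> belief = future_qs_seq[0][0][0]        # policy 0, first future timepoint, factor 0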
- infer_policies()
Perform policy inference by optimizing a posterior (categorical) distribution over policies. This distribution is computed as the softmax of G * gamma + lnE, where G is the negative expected free energy of policies, gamma is a policy precision and lnE is the (log) prior probability of policies. This function returns the posterior over policies as well as the negative expected free energy of each policy. In this version of the function, the expected free energy of policies is computed using known factorized structure in the model, which speeds up computation (particularly the state information gain calculations).
- Returns
q_pi (1D numpy.ndarray) – Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.
G (1D numpy.ndarray) – Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
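A minimal sketch of this relationship, using a hypothetical random model. With the default flat prior over policies, lnE is constant, so q_pi reduces to a softmax of gamma * G:

>>> import numpy as np
>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> from pymdp.maths import softmax
>>> A = utils.random_A_matrix([3], [3])
>>> B = utils.random_B_matrix([3], [3])
>>> agent = Agent(A=A, B=B, gamma=16.0)
>>> qs = agent.infer_states([0])            # state inference must precede policy inference
>>> q_pi, G = agent.infer_policies()        # posterior over policies, negative EFEs
>>> close = np.allclose(q_pi, softmax(agent.gamma * G))  # expected True under a flat policy prior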
- infer_states(observation, distr_obs=False)
Update the approximate posterior over hidden states by solving a variational inference problem, given an observation.
- Parameters
observation (list or tuple of ints) – The observation input. Each entry observation[m] stores the index of the discrete observation for modality m.
distr_obs (bool) – Whether the observation is a distribution over possible observations, rather than a single observation.
- Returns
qs – Posterior beliefs over hidden states. Depending on the inference algorithm chosen, the resulting qs variable will have additional sub-structure to reflect whether beliefs are additionally conditioned on timepoint and policy. For example, in the case that self.inference_algo == 'MMP', the indexing structure is policy->timepoint->factor, so that qs[p_idx][t_idx][f_idx] refers to beliefs about marginal factor f_idx expected under policy p_idx at timepoint t_idx.
- Return type
numpy.ndarray of dtype object
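For example, with a hypothetical model that has two observation modalities and two hidden state factors, and the default 'VANILLA' algorithm (so qs is indexed by factor only):

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> num_obs, num_states, num_controls = [3, 4], [2, 3], [1, 3]
>>> A = utils.random_A_matrix(num_obs, num_states)
>>> B = utils.random_B_matrix(num_states, num_controls)
>>> agent = Agent(A=A, B=B)
>>> observation = [2, 0]                  # one observation index per modality
>>> qs = agent.infer_states(observation)
>>> qs[0].shape, qs[1].shape              # one marginal per hidden state factor
((2,), (3,))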
- reset(init_qs=None)
Resets the posterior beliefs about hidden states of the agent to a uniform distribution, and resets time to the first timestep of the simulation’s temporal horizon. Returns the posterior beliefs about hidden states.
- Returns
qs – Initialized posterior over hidden states. Depending on the inference algorithm chosen and other parameters (such as the parameters stored within edge_handling_params), the resulting qs variable will have additional sub-structure to reflect whether beliefs are additionally conditioned on timepoint and policy. For example, in the case that self.inference_algo == 'MMP', the indexing structure of qs is policy->timepoint->factor, so that qs[p_idx][t_idx][f_idx] refers to beliefs about marginal factor f_idx expected under policy p_idx at timepoint t_idx. In this case, the returned qs will only have entries filled out for the first timestep, i.e. for qs[p_idx][0], for all policy indices p_idx. Subsequent entries qs[:][1, 2, ...] will be initialized to empty numpy.ndarray objects.
- Return type
numpy.ndarray of dtype object
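For instance, with the default 'VANILLA' settings and a hypothetical random model:

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> agent = Agent(A=utils.random_A_matrix([3], [3]), B=utils.random_B_matrix([3], [3]))
>>> _ = agent.infer_states([1])   # beliefs move away from uniform
>>> qs = agent.reset()            # beliefs return to uniform, time returns to the start
>>> qs[0]
array([0.33333333, 0.33333333, 0.33333333])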
- sample_action()
Sample or select a discrete action from the posterior over control states. This function both sets (or caches) the action as an internal variable of the agent and returns it. It also updates the time variable (and thus manages the consequences of updating the moving reference frame of beliefs) using self.step_time().
- Returns
action – Vector containing the indices of the actions for each control factor.
- Return type
1D numpy.ndarray
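A minimal sketch with a hypothetical random model; infer_policies() must be called first so that a posterior over policies exists:

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> agent = Agent(A=utils.random_A_matrix([3], [3]), B=utils.random_B_matrix([3], [3]))
>>> qs = agent.infer_states([0])
>>> q_pi, G = agent.infer_policies()    # must precede action selection
>>> action = agent.sample_action()      # one entry per control factor
>>> agent.curr_timestep                 # sample_action() has also advanced time via step_time()
1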
- set_latest_beliefs(last_belief=None)
Both sets and returns the penultimate belief before the first timestep of the backwards inference horizon. In the case that the inference horizon includes the first timestep of the simulation, then the latest_belief is simply the first belief of the whole simulation, or the prior (self.D). The particular structure of the latest_belief depends on the value of self.edge_handling_params['use_BMA'].
- Returns
latest_belief – Penultimate posterior beliefs over hidden states at the timestep just before the first timestep of the inference horizon. Depending on the value of self.edge_handling_params['use_BMA'], the shape of this output array will differ. If self.edge_handling_params['use_BMA'] == True, then latest_belief will be a Bayesian model average of beliefs about hidden states, where the average is taken with respect to posterior beliefs about policies. Otherwise, latest_belief will be the full, policy-conditioned belief about hidden states, and will have indexing structure policies->factors, such that latest_belief[p_idx][f_idx] refers to the penultimate belief about marginal factor f_idx under policy p_idx.
- Return type
numpy.ndarray of dtype object
- step_time()
Advances time by one step. This involves updating self.prev_actions and, in the case of a moving inference horizon, shifting the history of post-dictive beliefs forward in time (using self.set_latest_beliefs()), so that the penultimate belief before the beginning of the horizon is correctly indexed.
- Returns
curr_timestep – The index in absolute simulation time of the current timestep.
- Return type
int
- update_A(obs)
Update approximate posterior beliefs about Dirichlet parameters that parameterise the observation likelihood or A array.
- Parameters
observation (list or tuple of ints) – The observation input. Each entry observation[m] stores the index of the discrete observation for modality m.
- Returns
qA – Posterior Dirichlet parameters over observation model (same shape as A), after having updated it with observations.
- Return type
numpy.ndarray of dtype object
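Learning the observation model requires passing a Dirichlet prior pA at construction. A minimal sketch, with a hypothetical random model and observation index:

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> A = utils.random_A_matrix([3], [3])
>>> B = utils.random_B_matrix([3], [3])
>>> pA = utils.dirichlet_like(A)          # Dirichlet prior over A (counts proportional to A)
>>> agent = Agent(A=A, B=B, pA=pA, lr_pA=1.0)
>>> obs = [1]
>>> qs = agent.infer_states(obs)          # posterior state beliefs used in the update
>>> qA = agent.update_A(obs)              # posterior Dirichlet counts over the observation model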
- update_B(qs_prev)
Update posterior beliefs about Dirichlet parameters that parameterise the transition likelihood (B array).
- Parameters
qs_prev (1D numpy.ndarray or numpy.ndarray of dtype object) – Marginal posterior beliefs over hidden states at the previous timepoint.
- Returns
qB – Posterior Dirichlet parameters over transition model (same shape as B), after having updated it with state beliefs and actions.
- Return type
numpy.ndarray of dtype object
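Learning the transition model requires a Dirichlet prior pB, the state beliefs from the previous timepoint, and the action just taken. A minimal sketch with a hypothetical random model and observation indices:

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> A = utils.random_A_matrix([3], [3])
>>> B = utils.random_B_matrix([3], [3])
>>> pB = utils.dirichlet_like(B)          # Dirichlet prior over B
>>> agent = Agent(A=A, B=B, pB=pB)
>>> qs_prev = agent.infer_states([0])     # beliefs at time t
>>> q_pi, G = agent.infer_policies()
>>> action = agent.sample_action()        # the B update uses the action just taken
>>> qs = agent.infer_states([2])          # beliefs at time t+1, after acting
>>> qB = agent.update_B(qs_prev)          # posterior Dirichlet counts over the transition model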
- update_D(qs_t0=None)
Update Dirichlet parameters of the initial hidden state distribution (prior beliefs about hidden states at the beginning of the inference window).
- Parameters
qs_t0 (1D numpy.ndarray, numpy.ndarray of dtype object, or None) – Marginal posterior beliefs over hidden states at current timepoint. If None, the value of qs_t0 is set to self.qs_hist[0] (i.e. the initial hidden state beliefs at the first timepoint). If self.inference_algo == "MMP", then qs_t0 is set to be the Bayesian model average of beliefs about hidden states at the first timestep of the backwards inference horizon, where the average is taken with respect to posterior beliefs about policies.
- Returns
qD – Posterior Dirichlet parameters over initial hidden state prior (same shape as qs_t0), after having updated it with state beliefs.
- Return type
numpy.ndarray of dtype object
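Learning the initial-state prior requires a Dirichlet prior pD. With the default 'VANILLA' algorithm, the simplest approach is to pass qs_t0 explicitly (or, alternatively, to construct the agent with save_belief_hist=True so that self.qs_hist[0] is available). A minimal sketch with a hypothetical random model:

>>> from pymdp import utils
>>> from pymdp.agent import Agent
>>> A = utils.random_A_matrix([3], [3])
>>> B = utils.random_B_matrix([3], [3])
>>> D = utils.obj_array_uniform([3])      # uniform prior over the single hidden state factor
>>> pD = utils.dirichlet_like(D)          # Dirichlet prior over D
>>> agent = Agent(A=A, B=B, D=D, pD=pD)
>>> qs = agent.infer_states([0])
>>> qD = agent.update_D(qs_t0=qs)         # posterior Dirichlet counts over the initial-state prior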