pymdp.agent

Agent API for Active Inference with the modern JAX backend.

Agent

Bases: Module

The Agent class, the highest-level API that wraps together processes for action, perception, and learning under active inference.

Examples:

A single timestep of active inference:

from jax import random as jr
from pymdp.agent import Agent

rng_key = jr.PRNGKey(0)  # seed for the PRNG key; 0 is an arbitrary choice
my_agent = Agent(A=A, B=B, C=C, <more_params>)
observation = env.step(initial_action)
qs = my_agent.infer_states(observation, empirical_prior=my_agent.D)
q_pi, neg_efe = my_agent.infer_policies(qs)
keys = jr.split(rng_key, my_agent.batch_size + 1)  # one key per batch element, plus one spare
next_action = my_agent.sample_action(q_pi, rng_key=keys[1:])
next_observation = env.step(next_action)

This is one timestep of an active inference process. Wrapping it in a loop, with an Env() class that takes actions as input and returns observations, yields a dynamic agent-environment interaction.

Observation Formats

Observations can be provided in two formats:

  1. Discrete observations (default, categorical_obs=False): Each observations[m] is an integer observation index for modality m. These are converted to one-hot vectors internally.

  2. Categorical observations (categorical_obs=True): Each observations[m] is a probability vector over observations for modality m.
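As a rough illustration of the two formats, the sketch below builds the same two-modality observation both ways using NumPy. The helper name to_one_hot and the modality sizes are assumptions made for this example; they are not part of the pymdp API, and the library's internal conversion may differ in detail.

```python
import numpy as np

def to_one_hot(index: int, num_obs: int) -> np.ndarray:
    """Hypothetical helper: one-hot vector of length num_obs with a 1 at index."""
    vec = np.zeros(num_obs)
    vec[index] = 1.0
    return vec

num_obs = [3, 2]  # assumed number of outcomes per modality

# Format 1: discrete observations, one integer index per modality
discrete_obs = [0, 1]

# Format 2: categorical observations, one probability vector per modality
categorical_obs = [to_one_hot(o, n) for o, n in zip(discrete_obs, num_obs)]
```

Here categorical_obs carries the same information as discrete_obs, only expressed as distributions that place all probability mass on the observed index.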

Advanced preprocessing

You can override default preprocessing with preprocess_fn (set on the agent or per infer_states call). If provided, this function should return categorical observations and takes precedence over default discrete/categorical handling.

decode_multi_actions(action: Array) -> Array

Decode flattened multi-actions back to factor-wise actions.

Parameters:

  • action (Array, required): Flattened multi-action indices.

Returns:

  • Array: Array of shape (batch_size, num_controls_multi) containing decoded actions per control factor.

encode_multi_actions(action_multi: Array) -> Array

Encode factor-wise multi-actions into flattened actions.

Parameters:

  • action_multi (Array, required): Array of actions per control factor.

Returns:

  • Array: Flattened action indices with shape (batch_size, num_controls).
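One common way to flatten factor-wise actions is mixed-radix (row-major) indexing. The sketch below shows that idea for one group of control factors, using NumPy's ravel_multi_index and unravel_index. The ordering convention and the example num_controls values are assumptions; pymdp's own encoding may group or order factors differently.

```python
import numpy as np

num_controls = (3, 2)  # assumed number of actions per control factor

# Factor-wise actions for a batch of two agents: shape (batch_size, num_factors)
action_multi = np.array([[2, 1], [0, 1]])

# Encode: one flat index per batch element (row-major ordering assumed)
flat = np.ravel_multi_index(tuple(action_multi.T), num_controls)

# Decode: recover the factor-wise actions from the flat indices
decoded = np.stack(np.unravel_index(flat, num_controls), axis=-1)
```

Under this convention, encoding followed by decoding returns the original factor-wise actions, which is the round-trip property the two methods are meant to provide.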

get_model_dimensions() -> dict[str, Any]

Collect key model dimensions in a single object.

Returns:

  • dict[str, Any]: Dictionary containing model shape metadata. Includes:
      • num_obs: list[int]
      • num_states: list[int]
      • num_controls: list[int]
      • num_modalities: int
      • num_factors: int
      • num_policies: int
      • policy_len: int
      • inference_horizon: int | None
      • A_dependencies: list[list[int]]
      • B_dependencies: list[list[int]]

infer_parameters(beliefs_A: list[Array], observations: list[Array], actions: Array | None, beliefs_B: list[Array] | None = None, lr_pA: float = 1.0, lr_pB: float = 1.0, **kwargs: Any) -> Agent

Update Dirichlet parameters for A and/or B models from data.

Parameters:

  • beliefs_A (list[Array], required): Marginal state beliefs used when updating the observation model parameters.
  • observations (list[Array], required): Observation histories for each modality.
  • actions (Array | None, required): Action history aligned to time. For multi-action agents this should be shaped (batch, T, num_factors).
  • beliefs_B (list[Array] | None, default: None): Optional sequence of beliefs used for transition updates. If None, transition updates are skipped.
  • lr_pA (float, default: 1.0): Learning-rate multiplier for A updates.
  • lr_pB (float, default: 1.0): Learning-rate multiplier for B updates.
  • **kwargs (Any, default: {}): Reserved for future arguments.

Returns:

  • Agent: Agent instance with updated pA, A, pB, and B where learning is enabled.
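The kind of update involved can be sketched for a single modality and a single hidden-state factor: Dirichlet counts grow by the outer product of each observation with the matching state belief, scaled by the learning rate. This is a simplified NumPy illustration of a standard Dirichlet count update, not pymdp's implementation; the array shapes and the name update_pA are assumptions.

```python
import numpy as np

def update_pA(pA, obs_onehot, qs, lr=1.0):
    """Hypothetical helper: add lr-scaled co-occurrence counts to Dirichlet parameters.

    pA:         (num_obs, num_states) Dirichlet counts
    obs_onehot: (T, num_obs) one-hot observations over T timesteps
    qs:         (T, num_states) marginal state beliefs over T timesteps
    """
    counts = obs_onehot.T @ qs  # sum over time of outer(o_t, q_t)
    return pA + lr * counts

pA = np.ones((2, 2))
obs_onehot = np.array([[1.0, 0.0], [0.0, 1.0]])
qs = np.array([[0.9, 0.1], [0.2, 0.8]])

pA_new = update_pA(pA, obs_onehot, qs, lr=1.0)
A_new = pA_new / pA_new.sum(axis=0, keepdims=True)  # expected likelihood, columns sum to 1
```

A smaller lr makes each observation shift the counts less, which is the role the lr_pA and lr_pB multipliers play in the signature above.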

infer_policies(qs: list[Array]) -> tuple[Array, Array]

Perform policy inference by optimizing a posterior (categorical) distribution over policies. This distribution is computed as the softmax of neg_efe * gamma + lnE where neg_efe is the negative expected free energy of policies, gamma is a policy precision and lnE is the (log) prior probability of policies. In SPM-style notation this same quantity is often written as G, with G = neg_efe = -EFE. This function returns the posterior over policies as well as the negative expected free energy of each policy.

Parameters:

  • qs (list[Array], required): Posterior beliefs over hidden states (typically output of infer_states), including the most recent timestep.

Returns:

  • q_pi (Array): Posterior beliefs over policies with shape (batch_size, num_policies).
  • neg_efe (Array): Negative expected free energies of policies with shape (batch_size, num_policies).
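The softmax in the description of infer_policies can be written out directly. The following NumPy sketch uses made-up values for neg_efe, gamma, and E for a single batch element; it illustrates the formula only and is not the library's code.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis."""
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

neg_efe = np.array([-1.0, -2.0, -0.5])  # assumed negative expected free energies
gamma = 16.0                            # assumed policy precision
E = np.ones(3) / 3                      # assumed uniform prior over policies

# Posterior over policies: softmax(neg_efe * gamma + lnE)
q_pi = softmax(neg_efe * gamma + np.log(E))
```

With a uniform E, the policy with the largest neg_efe receives the most posterior mass, and a larger gamma concentrates the distribution more sharply on it.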

infer_states(observations: list[Array] | list[int], empirical_prior: list[Array], *, past_actions: Array | None = None, qs_hist: list[Array] | None = None, valid_steps: int | Array | None = None, mask: list[Array] | None = None, preprocess_fn: Callable | None = None, return_info: bool = False) -> list[Array] | tuple[list[Array], dict[str, Any]]

Update the approximate posterior over hidden states by solving a variational inference problem, given an observation.

Parameters:

  • observations (list[Array] | list[int], required): Observation input in one of two formats:
      • Discrete observations (default): each observations[m] is an integer index for modality m.
      • Categorical observations: each observations[m] is a probability vector over observations for modality m.
    If preprocess_fn is provided, it should map the raw input to categorical observations and takes precedence over default handling.
  • empirical_prior (list[Array], required): Empirical prior beliefs over hidden states. Depending on the inference algorithm chosen, empirical_prior may be a matrix (or list[Array]) with additional dimensions that encode extra conditioning variables such as timepoint and policy.
  • past_actions (Array | None, default: None): Action history aligned to time. For single-batch sequence inference this should be shaped (T-1, num_factors). For batched calls it should be shaped (batch, T-1, num_factors).
  • qs_hist (list[Array] | None, default: None): History of posterior beliefs over hidden states.
  • valid_steps (int | Array | None, default: None): Number of valid (unpadded) timesteps when using fixed-size sequence windows. If provided, sequence inference methods (mmp, vmp) ignore padded prefix timesteps and transitions.
  • mask (list[Array] | None, default: None): Mask for observations.
  • preprocess_fn (Callable | None, default: None): Optional preprocessing function to convert observations into distributional form. If None, defaults to self.process_obs. The callable should accept observations and return distributional observations.
  • return_info (bool, default: False): If True, also return canonical VFE diagnostics for the inferred posterior (vfe_t, vfe, and component terms). For ovf / exact this remains a forward-filtering diagnostic; to score the full smoothed sequence, call pymdp.inference.smoothing_ovf(...) or pymdp.inference.smoothing_exact(...) explicitly and then pass the resulting pairwise joints into pymdp.maths.calc_vfe(..., joint_qs=...).

Notes

categorical_obs is no longer an argument to infer_states. Set it when constructing the agent or supply a preprocess_fn. If you provide a custom preprocessing function, ensure self.categorical_obs matches the output format, since it is still used by learning and planning code paths that consume raw observations.

Returns:

  • qs (list[Array] or tuple[list[Array], dict[str, Any]]): Posterior beliefs over hidden states. Depending on the inference algorithm chosen, qs has additional sub-structure reflecting whether beliefs are also conditioned on timepoint and policy. For example, when self.inference_algo == 'MMP', the indexing structure is policy->timepoint->factor, so that qs[p_idx][t_idx][f_idx] refers to the marginal beliefs about factor f_idx expected under policy p_idx at timepoint t_idx. If return_info=True, a second return value contains VFE diagnostics.

Examples:

Discrete observations:

>>> obs = [0, 1]  # Modality 0: observation index 0; modality 1: observation index 1
>>> qs = agent.infer_states(obs, prior)

Categorical observations:

>>> import jax.numpy as jnp
>>> obs = [
...     jnp.array([0.7, 0.2, 0.1]),  # Modality 0: distribution peaked on observation 0
...     jnp.array([0.5, 0.5])        # Modality 1: flat distribution over two observations
... ]
>>> agent_cat = Agent(..., categorical_obs=True)
>>> qs = agent_cat.infer_states(obs, prior)

make_categorical(observations: list[Array] | list[int]) -> list[Array]

Convert discrete index observations into one-hot categorical distributions.

Parameters:

  • observations (list[Array] | list[int], required): Each entry observations[m] is an integer index for modality m.

Returns:

  • o_vec (list): One-hot categorical distributions for each modality.

multiaction_probabilities(q_pi: Array) -> Array

Compute probabilities of unique multi-actions from the posterior over policies.

Parameters:

  • q_pi (Array, required): Posterior beliefs over policies for one batch element.

Returns:

  • Array: Probability vector over unique multi-actions.
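One plausible reading of this computation is to sum the posterior mass of all policies that begin with the same multi-action. The NumPy sketch below does that for a small hand-written policy set; the policies array layout (num_policies, policy_len, num_factors) and the use of the first timestep are assumptions rather than confirmed library behavior.

```python
import numpy as np

# Assumed layout: (num_policies, policy_len, num_factors)
policies = np.array([
    [[0, 0], [1, 0]],
    [[0, 0], [0, 1]],
    [[1, 0], [0, 0]],
])
q_pi = np.array([0.5, 0.3, 0.2])  # posterior over the three policies

first_actions = policies[:, 0, :]  # multi-action each policy takes first
unique_actions, inverse = np.unique(first_actions, axis=0, return_inverse=True)
inverse = inverse.reshape(-1)      # flatten: the shape varies across NumPy versions

# Sum policy probabilities that share the same first multi-action
probs = np.bincount(inverse, weights=q_pi, minlength=len(unique_actions))
```

In this example the first two policies share the multi-action [0, 0], so their probabilities are pooled into a single entry of probs.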

process_obs(observations: list[Array] | list[int]) -> list[Array]

Preprocess observations into the distributional format expected by the inference routines.

Parameters:

  • observations (list[Array] | list[int], required): The observation input. Format depends on the default preprocessing:
      • If self.categorical_obs=False (default): Each entry observations[m] is an integer index representing the discrete observation for modality m.
      • If self.categorical_obs=True: Each entry observations[m] is a 1D array representing a probability distribution over observations for modality m.

Returns:

  • o_vec (list[Array]): Observations in distributional form (one-hot vectors or categorical distributions).

Notes

If self.preprocess_fn is set on the agent, it takes precedence over the default categorical/discrete handling and will be used instead of the logic based on self.categorical_obs. This override only affects preprocessing; self.categorical_obs is still used by learning and planning code paths that consume raw observations. Ensure self.categorical_obs matches the output format of your preprocessing (or per-call preprocess_fn) to keep those paths consistent.
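The precedence rule in this note can be summarized as a small dispatch function. This is a hypothetical stand-alone sketch in NumPy, not the method's actual source: the function name process_obs_sketch and the num_obs argument are assumptions introduced for illustration.

```python
import numpy as np

def process_obs_sketch(observations, num_obs, categorical_obs=False, preprocess_fn=None):
    """Hypothetical stand-in illustrating the precedence of preprocess_fn."""
    if preprocess_fn is not None:
        # A custom function overrides the default handling entirely
        return preprocess_fn(observations)
    if categorical_obs:
        # Already probability vectors: pass through as float arrays
        return [np.asarray(o, dtype=float) for o in observations]
    # Discrete indices: convert each to a one-hot vector
    return [np.eye(n)[o] for o, n in zip(observations, num_obs)]

# Default path: integer indices become one-hot vectors
one_hot = process_obs_sketch([0, 1], num_obs=[3, 2])

# Custom path: preprocess_fn wins regardless of categorical_obs
custom = process_obs_sketch(
    [0, 1], num_obs=[3, 2],
    preprocess_fn=lambda obs: [np.full(2, 0.5) for _ in obs],
)
```

The sketch leaves out the point the note stresses: categorical_obs is still read elsewhere, so a custom preprocess_fn does not remove the need to set that flag consistently.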

sample_action(q_pi: Array, rng_key: Array | None = None) -> Array

Sample or select a discrete action from the posterior over control states.

Parameters:

  • q_pi (Array, required): Posterior over policies for each batch element (usually from infer_policies).
  • rng_key (Array | None, default: None): Required for stochastic action selection. For batched agents, pass a key array with one key per batch element.

Returns:

  • action (Array): Action indices per batch element and control factor.

update_empirical_prior(action: Array, qs: list[Array]) -> list[Array]

Compute the empirical prior used for the next state-inference step.

Parameters:

  • action (Array, required): Action sampled at the current timestep for each control factor.
  • qs (list[Array], required): Posterior beliefs over hidden states for the current timestep/history.

Returns:

  • pred (list[Array]): Predicted prior over hidden states for the next inference step. For sequence methods (mmp, vmp), this returns self.D to preserve sequence-inference semantics.
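For a single hidden-state factor, the predicted prior is the transition model for the chosen action applied to the current posterior. A minimal NumPy sketch, assuming B is indexed as B[next_state, current_state, action] and using illustrative values (batching, multiple factors, and factor dependencies are left out):

```python
import numpy as np

# Assumed layout: B[next_state, current_state, action]
B = np.zeros((2, 2, 2))
B[:, :, 0] = np.eye(2)                # action 0: stay in place
B[:, :, 1] = np.array([[0.0, 1.0],
                       [1.0, 0.0]])   # action 1: swap states

qs = np.array([0.9, 0.1])  # current posterior over the factor's states
action = 1

pred = B[:, :, action] @ qs  # empirical prior for the next timestep
```

The resulting pred would then be passed as empirical_prior to the next infer_states call in place of the initial prior.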