pymdp.agent
Agent API for Active Inference with the modern JAX backend.
Agent
Bases: Module
The Agent class, the highest-level API that wraps together processes for action, perception, and learning under active inference.
Examples:
A single timestep of active inference:
from jax import random as jr
my_agent = Agent(A=A, B=B, C=C, <more_params>)
observation = env.step(initial_action)
qs = my_agent.infer_states(observation, empirical_prior=my_agent.D)
q_pi, neg_efe = my_agent.infer_policies(qs)
keys = jr.split(rng_key, my_agent.batch_size + 1)
next_action = my_agent.sample_action(q_pi, rng_key=keys[1:])
next_observation = env.step(next_action)
This represents one timestep of an active inference process. Wrapping this step in a loop, with an Env() class that takes actions as inputs and returns
observations, yields a dynamic agent-environment interaction.
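The single step above can be extended to a loop roughly as follows. This is a minimal sketch rather than a utility shipped with pymdp: run_loop and step_keys are hypothetical names, env is assumed to expose the same step(action) interface used in the example, and the empirical prior is assumed to be carried forward with update_empirical_prior as described further down this page.

```python
def run_loop(agent, env, initial_action, step_keys):
    """Run one perception-action cycle per entry of `step_keys`.

    Hypothetical helper: `agent` is expected to provide the Agent methods
    documented on this page, `env.step(action)` is assumed to return the next
    observation, and `step_keys` holds the RNG key(s) for each timestep
    (for example prepared by the caller with jax.random.split).
    """
    observation = env.step(initial_action)
    prior = agent.D  # initial empirical prior, as in the single-step example
    history = []
    for key in step_keys:
        # Perception: posterior over hidden states given the observation.
        qs = agent.infer_states(observation, empirical_prior=prior)
        # Planning: posterior over policies and their negative EFE.
        q_pi, neg_efe = agent.infer_policies(qs)
        # Action selection.
        action = agent.sample_action(q_pi, rng_key=key)
        # Carry beliefs forward: predicted prior for the next inference step.
        prior = agent.update_empirical_prior(action, qs)
        observation = env.step(action)
        history.append((action, qs, neg_efe))
    return history
```

Passing the per-step keys in from outside keeps the helper independent of how the caller manages random state.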
Observation Formats
Observations can be provided in two formats:
- Discrete observations (default, categorical_obs=False): Each observations[m] is an integer observation index for modality m. These are converted to one-hot vectors internally.
- Categorical observations (categorical_obs=True): Each observations[m] is a probability vector over observations for modality m.
Advanced preprocessing
You can override default preprocessing with preprocess_fn (set on the
agent or per infer_states call). If provided, this function should
return categorical observations and takes precedence over default
discrete/categorical handling.
decode_multi_actions(action: Array) -> Array
Decode flattened multi-actions back to factor-wise actions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| action | Array | Flattened multi-action indices. | required |
Returns:
| Type | Description |
|---|---|
| Array | Array of shape |
encode_multi_actions(action_multi: Array) -> Array
Encode factor-wise multi-actions into flattened actions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| action_multi | Array | Array of actions per control factor. | required |
Returns:
| Type | Description |
|---|---|
| Array | Flattened action indices with shape |
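The relationship between factor-wise and flattened actions can be illustrated with a mixed-radix encoding. The sketch below is plain Python for a single action, not the Agent implementation: the function names, the num_controls argument (number of actions per control factor), and the row-major ordering are all assumptions made for illustration, and the library may order indices differently.

```python
def encode_multi_action(action_multi, num_controls):
    """Flatten one factor-wise action into a single index.

    Assumes row-major order: the last control factor varies fastest.
    """
    flat = 0
    for a, n in zip(action_multi, num_controls):
        if not 0 <= a < n:
            raise ValueError(f"action {a} out of range for factor with {n} controls")
        flat = flat * n + a
    return flat


def decode_multi_action(flat, num_controls):
    """Invert encode_multi_action under the same row-major assumption."""
    actions = []
    for n in reversed(num_controls):
        flat, a = divmod(flat, n)
        actions.append(a)
    return actions[::-1]
```

Under these assumptions, two control factors with 2 and 3 actions give 6 flattened actions, and decoding an encoded action returns the original factor-wise action.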
get_model_dimensions() -> dict[str, Any]
Collect key model dimensions in a single object.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing model shape metadata. Includes: |
infer_parameters(beliefs_A: list[Array], observations: list[Array], actions: Array | None, beliefs_B: list[Array] | None = None, lr_pA: float = 1.0, lr_pB: float = 1.0, **kwargs: Any) -> Agent
Update Dirichlet parameters for A and/or B models from data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| beliefs_A | list[Array] | Marginal state beliefs used when updating the observation model parameters. | required |
| observations | list[Array] | Observation histories for each modality. | required |
| actions | Array \| None | Action history aligned to time. For multi-action agents this should be shaped | required |
| beliefs_B | list[Array] \| None | Optional sequence of beliefs used for transition updates. If | None |
| lr_pA | float | Learning-rate multiplier for | 1.0 |
| lr_pB | float | Learning-rate multiplier for | 1.0 |
| **kwargs | Any | Reserved for future arguments. | {} |
Returns:
| Type | Description |
|---|---|
| Agent | Agent instance with updated |
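For intuition, a Dirichlet update of an observation model is commonly written as adding a scaled outer product of the observation and the state belief to the existing concentration parameters. The sketch below shows that textbook form for one modality, one state factor, and one timestep in plain Python. It is an assumption about the general technique, not a transcription of infer_parameters; update_dirichlet_A and its arguments are hypothetical names, and lr stands in for a learning-rate multiplier such as lr_pA.

```python
def update_dirichlet_A(pA, obs_onehot, qs, lr=1.0):
    """Return pA + lr * outer(obs_onehot, qs).

    pA[o][s]    : Dirichlet concentration for observation o given state s
    obs_onehot  : observation as a one-hot (or categorical) vector
    qs          : marginal belief over hidden states at the same timestep
    """
    return [
        [pA[o][s] + lr * obs_onehot[o] * qs[s] for s in range(len(qs))]
        for o in range(len(obs_onehot))
    ]
```

Only the row matching the observed outcome grows, and it grows in proportion to how strongly each state was believed to be occupied.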
infer_policies(qs: list[Array]) -> tuple[Array, Array]
Perform policy inference by optimizing a posterior (categorical) distribution over policies.
This distribution is computed as the softmax of neg_efe * gamma + lnE, where
neg_efe is the negative expected free energy of each policy, gamma is the policy
precision, and lnE is the log prior probability of policies.
In SPM-style notation the negative expected free energy is often written as G, so that
G = neg_efe = -EFE.
This function returns the posterior over policies together with the negative expected free energy of each policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| qs | list[Array] | Posterior beliefs over hidden states (typically output of | required |
Returns:
| Name | Type | Description |
|---|---|---|
| q_pi | Array | Posterior beliefs over policies with shape |
| neg_efe | Array | Negative expected free energies of policies with shape |
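The softmax relation stated in the description can be written out directly. The following is a minimal plain-Python sketch for a single, unbatched agent; policy_posterior is a hypothetical name, and the real method computes neg_efe itself instead of receiving it as an argument.

```python
import math


def policy_posterior(neg_efe, gamma, lnE):
    """Softmax of neg_efe * gamma + lnE over policies."""
    logits = [g * gamma + e for g, e in zip(neg_efe, lnE)]
    # Subtract the maximum logit for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]
```

With a flat prior over policies, a larger gamma concentrates the posterior on the policy with the highest neg_efe, while gamma = 0 returns the prior.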
infer_states(observations: list[Array] | list[int], empirical_prior: list[Array], *, past_actions: Array | None = None, qs_hist: list[Array] | None = None, valid_steps: int | Array | None = None, mask: list[Array] | None = None, preprocess_fn: Callable | None = None, return_info: bool = False) -> list[Array] | tuple[list[Array], dict[str, Any]]
Update the approximate posterior over hidden states by solving a variational inference problem, given an observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| observations | list[Array] \| list[int] | Observation input in one of two formats: If | required |
| empirical_prior | list[Array] | Empirical prior beliefs over hidden states. Depending on the inference algorithm chosen, the resulting | required |
| past_actions | Array \| None | Action history aligned to time. For single-batch sequence inference this should be shaped | None |
| qs_hist | list[Array] \| None | History of posterior beliefs over hidden states. | None |
| valid_steps | int \| Array \| None | Number of valid (unpadded) timesteps when using fixed-size sequence windows. If provided, sequence inference methods ( | None |
| mask | list[Array] \| None | Mask for observations. | None |
| preprocess_fn | Callable \| None | Optional preprocessing function to convert observations into distributional form. If None, defaults to | None |
| return_info | bool | If | False |
Notes
categorical_obs is no longer an argument to infer_states. Set it when
constructing the agent or supply a preprocess_fn. If you provide a custom
preprocessing function, ensure self.categorical_obs matches the output format,
since it is still used by learning and planning code paths that consume raw observations.
Returns:
| Name | Type | Description |
|---|---|---|
| qs | list[Array] or tuple[list[Array], dict[str, Any]] | Posterior beliefs over hidden states. Depending on the inference algorithm chosen, the resulting |
Examples:
Discrete observations:
>>> obs = [0, 1] # Modality 0 observed observation 0, modality 1 observed observation 1
>>> qs = agent.infer_states(obs, prior)
Categorical observations:
>>> obs = [
... jnp.array([0.7, 0.2, 0.1]), # Peaked belief distribution for observation 0
... jnp.array([0.5, 0.5]) # Flat belief distribution for observation 1
... ]
>>> agent_cat = Agent(..., categorical_obs=True)
>>> qs = agent_cat.infer_states(obs, prior)
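As a point of reference for what state inference computes, the simplest case (one modality, one hidden-state factor, one timestep) reduces to Bayes' rule. The sketch below is plain Python under that assumption and is not the algorithm the Agent necessarily runs, since the result depends on the inference algorithm chosen; exact_posterior and the A[o][s] layout are hypothetical choices for illustration.

```python
def exact_posterior(A, obs, prior):
    """Posterior over states, proportional to likelihood times prior.

    A[o][s] : probability of observation o given hidden state s
    obs     : observation as a one-hot or categorical vector over o
    prior   : empirical prior over hidden states
    """
    num_states = len(prior)
    # Expected likelihood of each state under the (possibly soft) observation.
    likelihood = [
        sum(obs[o] * A[o][s] for o in range(len(obs))) for s in range(num_states)
    ]
    unnorm = [likelihood[s] * prior[s] for s in range(num_states)]
    z = sum(unnorm)
    return [u / z for u in unnorm]
```

A one-hot obs corresponds to the discrete format, and a general probability vector corresponds to the categorical format.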
make_categorical(observations: list[Array] | list[int]) -> list[Array]
Convert discrete index observations into one-hot categorical distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| observations | list[Array] \| list[int] | Each entry | required |
Returns:
| Name | Type | Description |
|---|---|---|
| o_vec | list | One-hot categorical distributions for each modality. |
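The index-to-one-hot conversion can be pictured as follows. This is a plain-Python sketch for unbatched integer observations, not the method's implementation; make_categorical_sketch and the num_obs argument (number of outcomes per modality) are hypothetical, since the Agent presumably reads those sizes from its own model.

```python
def one_hot(index, num_outcomes):
    """Return a one-hot list of length num_outcomes with a 1.0 at `index`."""
    if not 0 <= index < num_outcomes:
        raise ValueError(f"index {index} out of range for {num_outcomes} outcomes")
    return [1.0 if i == index else 0.0 for i in range(num_outcomes)]


def make_categorical_sketch(observations, num_obs):
    """Convert one integer observation per modality into one-hot vectors."""
    return [one_hot(o, n) for o, n in zip(observations, num_obs)]
```

For example, observations [0, 1] with 3 and 2 outcomes per modality become one vector of length 3 and one of length 2.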
multiaction_probabilities(q_pi: Array) -> Array
Compute probabilities of unique multi-actions from the posterior over policies.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_pi | Array | Posterior beliefs over policies for one batch element. | required |
Returns:
| Type | Description |
|---|---|
| Array | Probability vector over unique multi-actions. |
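One way to read this computation is as a marginalization: each unique multi-action receives the total posterior mass of the policies that begin with it. The sketch below implements that reading in plain Python. Both the reading itself and the names multiaction_probs and first_actions are assumptions for illustration; the source does not spell out how policies are mapped to multi-actions.

```python
def multiaction_probs(q_pi, first_actions):
    """Sum policy probabilities by the multi-action each policy starts with.

    q_pi          : posterior over policies (one float per policy)
    first_actions : first-step multi-action of each policy, as a tuple
    Returns the sorted unique multi-actions and their probabilities.
    """
    unique = sorted(set(first_actions))
    mass = {a: 0.0 for a in unique}
    for q, a in zip(q_pi, first_actions):
        mass[a] += q
    return unique, [mass[a] for a in unique]
```

Because every policy contributes to exactly one multi-action, the returned probabilities sum to the same total as q_pi.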
process_obs(observations: list[Array] | list[int]) -> list[Array]
Preprocess observations into the distributional format expected by the inference routines.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| observations | list[Array] \| list[int] | The observation input. Format depends on the default preprocessing: | required |
Returns:
| Name | Type | Description |
|---|---|---|
| o_vec | list[Array] | Observations in distributional form (one-hot vectors or categorical distributions). |
Notes
If self.preprocess_fn is set on the agent, it takes precedence over the default
categorical/discrete handling and will be used instead of the logic based on
self.categorical_obs. This override only affects preprocessing; self.categorical_obs
is still used by learning and planning code paths that consume raw observations.
Ensure self.categorical_obs matches the output format of your preprocessing
(or per-call preprocess_fn) to keep those paths consistent.
sample_action(q_pi: Array, rng_key: Array | None = None) -> Array
Sample or select a discrete action from the posterior over control states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_pi | Array | Posterior over policies for each batch element (usually from | required |
| rng_key | Array \| None | Required for stochastic action selection. For batched agents, pass a key array with one key per batch element. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| action | Array | Action indices per batch element and control factor. |
update_empirical_prior(action: Array, qs: list[Array]) -> list[Array]
Compute the empirical prior used for the next state-inference step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| action | Array | Action sampled at the current timestep for each control factor. | required |
| qs | list[Array] | Posterior beliefs over hidden states for the current timestep/history. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| pred | list[Array] | Predicted prior over hidden states for the next inference step. For sequence methods ( |
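For a single hidden-state factor, a predicted prior of this kind is typically obtained by pushing the current posterior through the transition model selected by the action. The sketch below shows that one-step prediction in plain Python. It is an illustration under stated assumptions, not the method's code: predict_next_state is a hypothetical name, and the B[action][s_next][s_prev] layout is chosen here for readability and may differ from the layout the Agent uses.

```python
def predict_next_state(B, action, qs):
    """Return the predicted belief over next states: B[action] applied to qs.

    B[action][i][j] : probability of moving to state i from state j under `action`
    qs              : current posterior over hidden states
    """
    transition = B[action]
    return [
        sum(transition[i][j] * qs[j] for j in range(len(qs)))
        for i in range(len(transition))
    ]
```

If each column of B[action] sums to one, the prediction is itself a normalized distribution and can be passed as empirical_prior to the next infer_states call.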