Sophisticated Inference Planning

pymdp.planning.si

Tree

Bases: Module

A tree structure to hold the planning nodes and their data.

Tree is an Equinox module, so it is compatible with JAX transformations such as jit. It pre-allocates memory for up to max_nodes nodes and up to max_branching children per node.
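The pre-allocation idea can be illustrated with a minimal sketch in plain Python. This is not the library's Tree class: the name FixedTree, the UNUSED sentinel, and the parent/children/used fields are all hypothetical, chosen only to show how a tree with a fixed node budget and a fixed number of child slots can be stored in flat arrays whose sizes never change.

```python
UNUSED = -1  # hypothetical sentinel meaning "no node stored in this slot"


class FixedTree:
    """Illustrative fixed-capacity tree; not the pymdp Tree implementation."""

    def __init__(self, max_nodes, max_branching):
        self.max_nodes = max_nodes
        self.max_branching = max_branching
        # All storage is allocated up front, so array sizes never change.
        self.parent = [UNUSED] * max_nodes
        self.children = [[UNUSED] * max_branching for _ in range(max_nodes)]
        self.used = [False] * max_nodes
        self.size = 0

    def add_node(self, parent_idx=UNUSED):
        """Claim the next free slot and link it to its parent, if any."""
        if self.size >= self.max_nodes:
            raise IndexError("tree is full")
        idx = self.size
        if parent_idx != UNUSED:
            slots = self.children[parent_idx]
            if UNUSED not in slots:
                raise IndexError("parent has no free child slot")
            slots[slots.index(UNUSED)] = idx
        self.parent[idx] = parent_idx
        self.used[idx] = True
        self.size += 1
        return idx

    def __getitem__(self, index):
        """Return one node as a dictionary, mirroring the accessor style above."""
        return {
            "parent": self.parent[index],
            "children": list(self.children[index]),
            "used": self.used[index],
        }


tree = FixedTree(max_nodes=4, max_branching=2)
root = tree.add_node()
left = tree.add_node(root)
right = tree.add_node(root)
print(tree[root])
```

Fixed-size storage of this kind is what lets array shapes stay static, which is generally what jit-compiled code requires.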

__getitem__(index: int) -> dict[str, Any]

Get the node at the given index as a dictionary.

Works only on a non-batched Tree.

root() -> dict[str, Any]

Get the root node of the tree as a dictionary.

Works only on a non-batched Tree.

infer_fn(agent: Any, obs: list[jnp.ndarray], qs: list[jnp.ndarray]) -> list[jnp.ndarray]

Infer posterior states for a candidate observation in SI expansion.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| agent | Any | Active inference agent. | required |
| obs | list[ndarray] | Candidate observation(s) for each modality. | required |
| qs | list[ndarray] | Prior/predictive beliefs to condition on for state inference. | required |

Returns:

| Type | Description |
| --- | --- |
| list[ndarray] | Posterior beliefs over hidden-state factors. |
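As a rough illustration of what conditioning a prior on a candidate observation means, the sketch below performs an exact Bayesian update for a single hidden-state factor and a single observation modality. It is an assumption-laden simplification: the function name posterior_from_observation and the nested-list layout likelihood[o][s] are hypothetical, and the library's infer_fn may use a different (for example variational, multi-factor) scheme.

```python
def posterior_from_observation(likelihood, obs_index, prior):
    """Exact Bayes rule for one factor and one modality (illustrative only).

    likelihood[o][s] is assumed to hold P(o | s); prior[s] holds P(s).
    """
    unnormalised = [likelihood[obs_index][s] * prior[s] for s in range(len(prior))]
    total = sum(unnormalised)
    if total == 0.0:
        # The observation is impossible under the prior; fall back to the prior.
        return list(prior)
    return [u / total for u in unnormalised]


A = [[0.9, 0.2],   # P(o=0 | s=0), P(o=0 | s=1)
     [0.1, 0.8]]   # P(o=1 | s=0), P(o=1 | s=1)
prior = [0.5, 0.5]
print(posterior_from_observation(A, obs_index=0, prior=prior))
```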

Perform sophisticated-inference tree expansion from the current root.

Expansion continues until one of the stop conditions is met:

  • planning horizon is reached
  • root policy entropy drops below entropy_stop_threshold
  • root negative expected free energy reaches or exceeds neg_efe_stop_threshold
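The three stop conditions above can be sketched as a single predicate. This is a minimal illustration, not the library's code: the names policy_entropy and should_stop are hypothetical, the root policy posterior is passed as a plain list of probabilities, and the root negative expected free energy is treated as a single scalar, which is an assumption.

```python
import math


def policy_entropy(q_pi):
    """Shannon entropy (in nats) of a policy distribution."""
    return -sum(p * math.log(p) for p in q_pi if p > 0.0)


def should_stop(depth, horizon, q_pi_root, neg_efe_root,
                entropy_stop_threshold=0.5, neg_efe_stop_threshold=1e10):
    """Return True when any of the three documented stop conditions holds."""
    horizon_reached = depth >= horizon
    confident = policy_entropy(q_pi_root) < entropy_stop_threshold
    good_enough = neg_efe_root >= neg_efe_stop_threshold
    return horizon_reached or confident or good_enough


# A uniform root policy has high entropy, so expansion would continue.
print(should_stop(0, 5, [0.25, 0.25, 0.25, 0.25], neg_efe_root=-3.0))
# A sharply peaked root policy has low entropy, so expansion would stop.
print(should_stop(0, 5, [0.97, 0.01, 0.01, 0.01], neg_efe_root=-3.0))
```

With the default neg_efe_stop_threshold of 1e10, the third condition is effectively disabled unless the caller lowers it.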

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| agent | Any | Agent instance containing model parameters and policy set. | required |
| tree | Tree | Initial planning tree to expand. | required |
| horizon | int | Maximum planning depth. | required |
| policy_prune_threshold | float | Minimum policy probability required to expand a policy branch. | 1/16 |
| observation_prune_threshold | float | Minimum observation probability required to expand an observation branch. | 1/16 |
| entropy_stop_threshold | float | Root-policy entropy threshold used as an early-stop criterion. | 0.5 |
| neg_efe_stop_threshold | float | Root negative expected free energy threshold used as an early-stop criterion (neg_efe = -EFE). Expansion stops when root neg_efe is greater than or equal to this threshold. | 1e10 |
| kl_threshold | float | KL threshold for reusing existing observation nodes with similar beliefs. A value of -1 disables node reuse. | -1 |
| prune_penalty | float | Penalty assigned to pruned branches. | 512 |
| gamma | float | Precision (inverse temperature) for policy softmax updates. | 1 |
| topk_obsspace | int | Maximum number of top observation combinations considered per expansion. | 10000 |
| infer_fn | Callable | Function used to infer posterior states from candidate observations. | infer_fn |
| predict_fn | Callable | Function used to predict next-state/observation beliefs and policy scores (neg_efe = -EFE). | predict_fn |
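To make the interplay between gamma and the prune thresholds concrete, here is a small sketch of a precision-weighted softmax over policy scores followed by threshold-based branch selection. The helper names softmax and branches_to_expand are hypothetical, and the sketch does not model prune_penalty, whose exact use is not described here. It assumes a branch is kept when its probability is at least the threshold.

```python
import math


def softmax(scores, gamma=1.0):
    """Softmax of gamma * scores, shifted by the maximum for numerical stability."""
    top = max(scores)
    exps = [math.exp(gamma * (s - top)) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def branches_to_expand(probs, threshold=1 / 16):
    """Indices of branches whose probability meets the minimum required."""
    return [i for i, p in enumerate(probs) if p >= threshold]


neg_efe = [0.0, -1.0, -5.0]  # one score per policy (neg_efe = -EFE)

# With gamma = 1 the worst policy falls below 1/16 and is pruned.
print(branches_to_expand(softmax(neg_efe, gamma=1.0)))
# A lower precision flattens the distribution, so every policy survives.
print(branches_to_expand(softmax(neg_efe, gamma=0.1)))
```

Higher gamma concentrates probability on the best-scoring policies, so fewer branches clear the threshold and the tree stays narrower.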

Returns:

| Type | Description |
| --- | --- |
| Tree | Expanded planning tree. |
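The node-reuse behaviour controlled by kl_threshold can be pictured with the following sketch, which compares a new belief against the beliefs of existing nodes and returns the first one that is close enough. This is only an illustration under stated assumptions: the function names are hypothetical, beliefs are single categorical distributions, and the direction of the KL divergence (new belief relative to existing belief) is a guess rather than something the documentation specifies.

```python
import math


def kl_divergence(p, q, eps=1e-16):
    """KL(p || q) for two categorical distributions given as lists."""
    return sum(
        pi * (math.log(pi + eps) - math.log(qi + eps))
        for pi, qi in zip(p, q)
        if pi > 0.0
    )


def find_reusable_node(new_belief, existing_beliefs, kl_threshold=-1.0):
    """Index of the first existing belief within kl_threshold, or -1 if none.

    A negative threshold disables reuse, matching the documented -1 default.
    """
    if kl_threshold < 0:
        return -1
    for idx, belief in enumerate(existing_beliefs):
        if kl_divergence(new_belief, belief) <= kl_threshold:
            return idx
    return -1


existing = [[0.5, 0.5], [0.9, 0.1]]
print(find_reusable_node([0.89, 0.11], existing, kl_threshold=0.01))
print(find_reusable_node([0.89, 0.11], existing))  # reuse disabled
```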

predict_fn(agent: Any, qs: list[jnp.ndarray]) -> tuple[list[jnp.ndarray], list[jnp.ndarray], jnp.ndarray]

Predict one-step beliefs and policy scores for all policies.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| agent | Any | Active inference agent containing A, B, C, dependencies, and policy definitions. | required |
| qs | list[ndarray] | Current posterior beliefs over hidden-state factors. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[list[ndarray], list[ndarray], ndarray] | Predicted next-state beliefs, predicted next-observation beliefs, and policy scores (neg_efe = -EFE) for each policy. |
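The first two outputs, predicted next-state and next-observation beliefs, can be sketched for the simplest case of one hidden-state factor, one modality, and one-step policies that each pick a single action. The names matvec and predict_one_step and the nested-list layouts B[a][s_next][s] and A[o][s] are assumptions made for this example; the policy scores are deliberately left out because their computation is not described here.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def predict_one_step(A, B, qs):
    """Predict next-state and next-observation beliefs for every action.

    A[o][s] is assumed to hold P(o | s); B[a][s_next][s] holds P(s_next | s, a).
    """
    qs_next = [matvec(B_a, qs) for B_a in B]      # one state belief per action
    qo_next = [matvec(A, q) for q in qs_next]     # one observation belief per action
    return qs_next, qo_next


A = [[0.9, 0.1],
     [0.1, 0.9]]
B = [
    [[1.0, 0.0], [0.0, 1.0]],  # action 0: stay
    [[0.0, 1.0], [1.0, 0.0]],  # action 1: switch state
]
qs_next, qo_next = predict_one_step(A, B, qs=[1.0, 0.0])
print(qs_next)
print(qo_next)
```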

root_idx(tree: Tree) -> jnp.ndarray

Return the index of the root observation node in a planning tree.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | Tree | Planning tree instance. | required |

Returns:

| Type | Description |
| --- | --- |
| ndarray | Scalar index of the root node, or -1 if not found. |

Create a sophisticated-inference policy-search function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| horizon | int | Maximum tree-expansion horizon. | 5 |
| max_nodes | int | Maximum number of nodes preallocated in the planning tree. | 5000 |
| max_branching | int | Maximum number of children per node. | 10 |
| policy_prune_threshold | float | Minimum policy probability required for expansion. | 1/16 |
| observation_prune_threshold | float | Minimum observation-branch probability required for expansion. | 1/16 |
| entropy_stop_threshold | float | Stop-expansion threshold on root policy entropy. | 0.5 |
| neg_efe_stop_threshold | float | Stop-expansion threshold on root negative expected free energy (neg_efe = -EFE). Expansion stops when root neg_efe is greater than or equal to this threshold. | 1e10 |
| kl_threshold | float | Optional KL threshold for node reuse. | -1 |
| prune_penalty | float | Penalty assigned to pruned/dead-end branches. | 512 |
| gamma | float | Precision/temperature scale used in policy softmax updates. | 1 |
| topk_obsspace | int | Maximum observation-combination budget per expansion. | 10000 |
| infer_fn | Callable | State-inference function used during observation expansion. | infer_fn |
| predict_fn | Callable | One-step prediction/valuation function for policy expansion. | predict_fn |
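The role of topk_obsspace as a budget on observation combinations can be illustrated with a short sketch that ranks joint observations across modalities and keeps only the k most probable. The function name is hypothetical, and the sketch assumes that modalities are independent so that a joint probability is the product of per-modality probabilities; the library may rank combinations differently.

```python
import heapq
import itertools
import math


def top_k_observation_combinations(modality_probs, k):
    """Return the k most probable (probability, combination) pairs.

    modality_probs[m][o] is the predicted probability of outcome o in modality m.
    """
    combos = itertools.product(*[range(len(p)) for p in modality_probs])
    scored = (
        (math.prod(p[i] for p, i in zip(modality_probs, combo)), combo)
        for combo in combos
    )
    return heapq.nlargest(k, scored)


predicted = [[0.7, 0.3], [0.6, 0.4]]  # two modalities, two outcomes each
for prob, combo in top_k_observation_combinations(predicted, k=2):
    print(combo, round(prob, 2))
```

Capping the number of combinations bounds the work per expansion, since the number of joint observations grows multiplicatively with the number of modalities.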

Returns:

| Type | Description |
| --- | --- |
| Callable | Search function compatible with pymdp.envs.rollout.rollout policy hooks. |