Sophisticated Inference Planning
pymdp.planning.si
Tree
Bases: Module
A tree structure that holds the planning nodes and their data.
Because it is an Equinox module (and therefore a JAX pytree), it can be passed through JAX transformations such as jit.
Memory is pre-allocated for up to max_nodes nodes, each with up to max_branching children.
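The pre-allocation described above can be pictured with the pure-Python sketch below. It is not pymdp's `Tree`: the class name `FixedTree`, its fields, and the `-1` sentinel for empty slots are hypothetical, and plain lists stand in for the fixed-shape JAX arrays a jit-compatible version would use. Since `jax.jit` traces code for fixed array shapes, reserving capacity up front (rather than growing the structure) is a common pattern in JAX code.

```python
EMPTY = -1  # hypothetical sentinel marking an unused slot


class FixedTree:
    """Hypothetical fixed-capacity tree; all storage is allocated up front."""

    def __init__(self, max_nodes: int, max_branching: int) -> None:
        self.max_nodes = max_nodes
        self.max_branching = max_branching
        # One parent slot per node and max_branching child slots per node.
        self.parent = [EMPTY] * max_nodes
        self.children = [[EMPTY] * max_branching for _ in range(max_nodes)]
        self.num_nodes = 0  # number of node slots claimed so far

    def add_node(self, parent_idx: int = EMPTY) -> int:
        """Claim the next free node slot and link it under parent_idx."""
        if self.num_nodes >= self.max_nodes:
            raise RuntimeError("node capacity exhausted")
        if parent_idx != EMPTY and EMPTY not in self.children[parent_idx]:
            raise RuntimeError("branching capacity exhausted")
        idx = self.num_nodes
        self.parent[idx] = parent_idx
        if parent_idx != EMPTY:
            slots = self.children[parent_idx]
            slots[slots.index(EMPTY)] = idx  # fill the first free child slot
        self.num_nodes += 1
        return idx


tree = FixedTree(max_nodes=4, max_branching=2)
root = tree.add_node()
left = tree.add_node(root)
right = tree.add_node(root)
print(tree.children[root])  # [1, 2]
```

Both limits are checked before any slot is written, so a failed insertion leaves the tree unchanged.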
infer_fn(agent: Any, obs: list[jnp.ndarray], qs: list[jnp.ndarray]) -> list[jnp.ndarray]
Infer posterior states for a candidate observation in SI expansion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| agent | Any | Active inference agent. | required |
| obs | list[ndarray] | Candidate observation(s) for each modality. | required |
| qs | list[ndarray] | Prior/predictive beliefs to condition on for state inference. | required |
Returns:
| Type | Description |
|---|---|
| list[ndarray] | Posterior beliefs over hidden-state factors. |
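As a rough illustration of the kind of update infer_fn performs, the sketch below computes an exact Bayesian posterior for a single hidden-state factor. It is a simplification under stated assumptions, not pymdp's implementation: the real function receives the agent and lists of JAX arrays over possibly several factors, whereas the hypothetical helper `infer_single_factor` takes nested Python lists with the likelihood indexed as `A[m][o][s]` (modality, outcome, state).

```python
import math


def infer_single_factor(A: list[list[list[float]]], obs: list[int],
                        prior: list[float]) -> list[float]:
    """Exact posterior q(s | o) for a single hidden-state factor.

    A[m][o][s] is the likelihood P(o_m = o | s) for modality m, obs[m] is the
    observed outcome index for modality m, and prior is the predictive belief q(s).
    """
    # Accumulate log prior + log likelihoods to avoid underflow across modalities.
    log_post = [math.log(p) if p > 0 else -math.inf for p in prior]
    for A_m, o_m in zip(A, obs):
        for s in range(len(prior)):
            lik = A_m[o_m][s]
            log_post[s] += math.log(lik) if lik > 0 else -math.inf
    m = max(log_post)
    if m == -math.inf:
        raise ValueError("observation has zero probability under the prior")
    # Normalise with the log-sum-exp trick.
    unnorm = [math.exp(lp - m) for lp in log_post]
    z = sum(unnorm)
    return [u / z for u in unnorm]


# One modality with two outcomes (rows) and two states (columns).
A = [[[0.9, 0.2],
      [0.1, 0.8]]]
print(infer_single_factor(A, obs=[0], prior=[0.5, 0.5]))  # ≈ [0.818, 0.182]
```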
optimized_tree_search(agent: Any, tree: Tree, horizon: int, policy_prune_threshold: float = 1 / 16, observation_prune_threshold: float = 1 / 16, entropy_stop_threshold: float = 0.5, neg_efe_stop_threshold: float = 10000000000.0, kl_threshold: float = -1, prune_penalty: float = 512, gamma: float = 1, topk_obsspace: int = 10000, infer_fn: Callable = infer_fn, predict_fn: Callable = predict_fn) -> Tree
Perform sophisticated-inference tree expansion from the current root.
Expansion continues until one of the stop conditions is met:
- the planning horizon is reached
- the root policy entropy drops below entropy_stop_threshold
- the root negative expected free energy reaches or exceeds neg_efe_stop_threshold
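The three stop conditions listed above can be expressed as a single predicate. This is a hypothetical sketch: `should_stop` and its arguments are not part of pymdp, entropy is measured in nats (the documentation does not state the log base), and the root's negative expected free energy is passed in as a precomputed scalar because the documentation does not say how it is aggregated over policies.

```python
import math


def entropy(p: list[float]) -> float:
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def should_stop(depth: int, horizon: int, root_policy_probs: list[float],
                root_neg_efe: float, entropy_stop_threshold: float = 0.5,
                neg_efe_stop_threshold: float = 1e10) -> bool:
    """True if any of the three documented stop conditions holds."""
    if depth >= horizon:  # planning horizon reached
        return True
    if entropy(root_policy_probs) < entropy_stop_threshold:  # root policy nearly decided
        return True
    # Root negative expected free energy reaches or exceeds the threshold.
    return root_neg_efe >= neg_efe_stop_threshold


print(should_stop(1, 5, [0.5, 0.5], 0.0))    # False: entropy is ln 2 ≈ 0.693
print(should_stop(1, 5, [0.95, 0.05], 0.0))  # True: entropy ≈ 0.199
```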
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| agent | Any | Agent instance containing model parameters and policy set. | required |
| tree | Tree | Initial planning tree to expand. | required |
| horizon | int | Maximum planning depth. | required |
| policy_prune_threshold | float | Minimum policy probability required to expand a policy branch. | 1/16 |
| observation_prune_threshold | float | Minimum observation probability required to expand an observation branch. | 1/16 |
| entropy_stop_threshold | float | Root-policy entropy threshold used as an early-stop criterion. | 0.5 |
| neg_efe_stop_threshold | float | Root negative expected free energy threshold used as an early-stop criterion (…). | 1e10 |
| kl_threshold | float | KL threshold for reusing existing observation nodes with similar beliefs. A value of … | -1 |
| prune_penalty | float | Penalty assigned to pruned branches. | 512 |
| gamma | float | Precision (inverse temperature) for policy softmax updates. | 1 |
| topk_obsspace | int | Maximum number of top observation combinations considered per expansion. | 10000 |
| infer_fn | Callable | Function used to infer posterior states from candidate observations. | infer_fn |
| predict_fn | Callable | Function used to predict next-state/observation beliefs and policy scores (…). | predict_fn |
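To make the role of kl_threshold concrete, the sketch below reuses an existing node when the KL divergence from the new beliefs to the existing ones is at or below the threshold. The direction of the divergence, the `<=` comparison, and the helper names `kl_divergence` and `find_reusable_node` are assumptions, and a single factor is used for brevity (for fully factorised beliefs the per-factor KL divergences could simply be summed). Under this convention a negative threshold such as the default -1 can never be met, because KL divergence is non-negative.

```python
import math


def kl_divergence(q: list[float], p: list[float]) -> float:
    """KL(q || p) in nats for two categorical distributions on the same support."""
    total = 0.0
    for qi, pi in zip(q, p):
        if qi > 0:
            if pi == 0:
                return math.inf  # q puts mass where p has none
            total += qi * math.log(qi / pi)
    return total


def find_reusable_node(new_qs: list[float], existing: list[list[float]],
                       kl_threshold: float) -> int:
    """Index of the first existing belief within kl_threshold of new_qs, else -1."""
    for idx, qs in enumerate(existing):
        if kl_divergence(new_qs, qs) <= kl_threshold:
            return idx
    return -1


beliefs = [[0.9, 0.1], [0.5, 0.5]]
print(find_reusable_node([0.5, 0.5], beliefs, kl_threshold=-1))    # -1: no reuse
print(find_reusable_node([0.5, 0.5], beliefs, kl_threshold=0.01))  # 1
```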
Returns:
| Type | Description |
|---|---|
| Tree | Expanded planning tree. |
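The interplay between topk_obsspace and observation_prune_threshold in optimized_tree_search might look roughly like the sketch below: it enumerates joint outcomes across modalities, keeps the topk_obsspace most probable ones, then drops any whose probability falls below observation_prune_threshold. Treating the joint probability as a product of per-modality marginals is a simplifying assumption of this sketch, as is the helper name `candidate_observations`; how pruned branches receive prune_penalty is not modelled.

```python
import heapq
import itertools
import math


def candidate_observations(qo: list[list[float]], topk_obsspace: int,
                           observation_prune_threshold: float) -> list[tuple[tuple[int, ...], float]]:
    """Joint observations worth expanding, most probable first.

    qo[m] is a predicted categorical distribution over the outcomes of modality m.
    ASSUMPTION: joint probabilities are products of the per-modality marginals.
    """
    combos = (
        (obs, math.prod(qo[m][o] for m, o in enumerate(obs)))
        for obs in itertools.product(*(range(len(q)) for q in qo))
    )
    # Keep only the topk_obsspace most probable joint observations.
    top = heapq.nlargest(topk_obsspace, combos, key=lambda c: c[1])
    # Discard branches whose probability is below the pruning threshold.
    return [(obs, p) for obs, p in top if p >= observation_prune_threshold]


qo = [[0.7, 0.3], [0.9, 0.1]]  # two modalities, two outcomes each
for obs, p in candidate_observations(qo, topk_obsspace=3, observation_prune_threshold=1 / 16):
    print(obs, round(p, 2))  # (0, 0) 0.63, then (1, 0) 0.27, then (0, 1) 0.07
```

Applying the top-k cut before the threshold bounds the work per expansion even when the joint observation space is large.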
predict_fn(agent: Any, qs: list[jnp.ndarray]) -> tuple[list[jnp.ndarray], list[jnp.ndarray], jnp.ndarray]
Predict one-step beliefs and policy scores for all policies.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| agent | Any | Active inference agent containing … | required |
| qs | list[ndarray] | Current posterior beliefs over hidden-state factors. | required |
Returns:
| Type | Description |
|---|---|
| tuple[list[ndarray], list[ndarray], ndarray] | Predicted next-state beliefs, predicted next-observation beliefs, and policy scores (…). |
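The policy scores returned by predict_fn are what a policy softmax (governed by gamma) and policy pruning (governed by policy_prune_threshold) would operate on. A minimal sketch, assuming the scores are negative expected free energies (higher is better) and that policy probabilities are softmax(gamma * scores); `policy_posterior` and `policies_to_expand` are hypothetical helpers, not pymdp functions.

```python
import math


def policy_posterior(scores: list[float], gamma: float = 1.0) -> list[float]:
    """Softmax of gamma * scores, computed stably by subtracting the maximum."""
    scaled = [gamma * s for s in scores]
    m = max(scaled)
    w = [math.exp(x - m) for x in scaled]
    z = sum(w)
    return [wi / z for wi in w]


def policies_to_expand(scores: list[float], gamma: float = 1.0,
                       policy_prune_threshold: float = 1 / 16) -> list[int]:
    """Indices of policies whose probability meets the pruning threshold."""
    q_pi = policy_posterior(scores, gamma)
    return [i for i, p in enumerate(q_pi) if p >= policy_prune_threshold]


print(policies_to_expand([0.0, -1.0, -5.0]))  # [0, 1]
```

Raising gamma sharpens the distribution: with the same scores, gamma=4 leaves only policy 0 above the threshold, while gamma=0 makes all three policies equally likely and keeps them all.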
root_idx(tree: Tree) -> jnp.ndarray
Return the index of the root observation node in a planning tree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Tree | Planning tree instance. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar index of the root node, or … |
si_policy_search(horizon: int = 5, max_nodes: int = 5000, max_branching: int = 10, policy_prune_threshold: float = 1 / 16, observation_prune_threshold: float = 1 / 16, entropy_stop_threshold: float = 0.5, neg_efe_stop_threshold: float = 10000000000.0, kl_threshold: float = -1, prune_penalty: float = 512, gamma: float = 1, topk_obsspace: int = 10000, infer_fn: Callable = infer_fn, predict_fn: Callable = predict_fn) -> Callable
Create a sophisticated-inference policy-search function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| horizon | int | Maximum tree-expansion horizon. | 5 |
| max_nodes | int | Maximum number of nodes preallocated in the planning tree. | 5000 |
| max_branching | int | Maximum number of children per node. | 10 |
| policy_prune_threshold | float | Minimum policy probability required for expansion. | 1/16 |
| observation_prune_threshold | float | Minimum observation-branch probability required for expansion. | 1/16 |
| entropy_stop_threshold | float | Stop-expansion threshold on root policy entropy. | 0.5 |
| neg_efe_stop_threshold | float | Stop-expansion threshold on root negative expected free energy (…). | 1e10 |
| kl_threshold | float | Optional KL threshold for node reuse. | -1 |
| prune_penalty | float | Penalty assigned to pruned/dead-end branches. | 512 |
| gamma | float | Precision (inverse temperature) used in policy softmax updates. | 1 |
| topk_obsspace | int | Maximum observation-combination budget per expansion. | 10000 |
| infer_fn | Callable | State-inference function used during observation expansion. | infer_fn |
| predict_fn | Callable | One-step prediction/valuation function for policy expansion. | predict_fn |
Returns:
| Type | Description |
|---|---|
| Callable | Search function compatible with … |