MCTS Planning
pymdp.planning.mcts
compute_neg_efe(qs: list[jnp.ndarray], action: jnp.ndarray, A: list[jnp.ndarray], A_dependencies: list[list[int]], B: list[jnp.ndarray], B_dependencies: list[list[int]], C: list[jnp.ndarray], use_states_info_gain: bool = True, use_utility: bool = True) -> tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]
Compute one-step negative expected free energy under an action.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| qs | list[ndarray] | Current posterior beliefs over hidden-state factors. | required |
| action | ndarray | Candidate action (multi-action index vector per batch element). | required |
| A | list[ndarray] | Observation model tensors. | required |
| A_dependencies | list[list[int]] | Modality-to-factor dependencies for | required |
| B | list[ndarray] | Transition model tensors. | required |
| B_dependencies | list[list[int]] | Factor-to-factor transition dependencies for | required |
| C | list[ndarray] | Preferences over observations. | required |
| use_states_info_gain | bool | Whether to include state information gain in EFE. | True |
| use_utility | bool | Whether to include expected utility in EFE. | True |
Returns:
| Type | Description |
|---|---|
| tuple[ndarray, list[ndarray], list[ndarray]] | Negative EFE, predicted next-state beliefs, and predicted next-observation beliefs. |
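To make the three returned quantities concrete, here is a minimal sketch of a one-step negative EFE for a single hidden-state factor and a single observation modality. It is not the library's implementation: `toy_neg_efe`, `matvec`, and `entropy` are hypothetical helpers, `C` is assumed to hold log-preferences, and information gain is written in its common form as the mutual information between next states and next observations. It uses only the standard library so it runs without JAX; the real function works on lists of batched `jnp` arrays and uses the dependency lists.

```python
import math


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def entropy(p):
    """Shannon entropy in nats, treating 0 * log 0 as 0."""
    return -sum(p_i * math.log(p_i) for p_i in p if p_i > 0.0)


def toy_neg_efe(qs, B_a, A, C, use_states_info_gain=True, use_utility=True):
    """One-step negative EFE for one state factor and one modality.

    qs:  current belief over states, length S
    B_a: transition matrix for the chosen action, B_a[s_next][s]
    A:   observation matrix, A[o][s]
    C:   log-preferences over observations, length O (assumed convention)
    """
    qs_next = matvec(B_a, qs)  # predicted next-state belief
    qo = matvec(A, qs_next)    # predicted next-observation belief
    neg_efe = 0.0
    if use_utility:
        # Expected utility: preferences averaged under predicted observations.
        neg_efe += sum(qo_i * c_i for qo_i, c_i in zip(qo, C))
    if use_states_info_gain:
        # Mutual information: H[q(o)] - E_q(s)[ H[p(o|s)] ].
        columns = list(zip(*A))  # columns[s] = p(o | s)
        ambiguity = sum(q * entropy(col) for q, col in zip(qs_next, columns))
        neg_efe += entropy(qo) - ambiguity
    return neg_efe, qs_next, qo


# Uniform belief, identity dynamics, unambiguous observations, flat preferences:
# only the information-gain term contributes.
identity = [[1.0, 0.0], [0.0, 1.0]]
g, qs_next, qo = toy_neg_efe([0.5, 0.5], identity, identity, [0.0, 0.0])
```

With flat preferences the result reduces to the entropy of the predicted observations, since an identity `A` has no ambiguity. Switching off `use_states_info_gain` leaves only the utility term, which mirrors the two boolean flags in the signature above.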
get_prob_single_modality(o_m: jnp.ndarray, po_m: jnp.ndarray, distr_obs: bool) -> jnp.ndarray
Compute the observation likelihood for a single modality from the observation (o_m) and that modality's likelihood (po_m).
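The role of `distr_obs` is not spelled out here, so the following is a sketch under a stated assumption: when `distr_obs` is true, `o_m` is taken to be a distribution over outcomes (for example a one-hot vector) and the likelihood is its inner product with `po_m`; otherwise `o_m` is taken to be an integer outcome index. `toy_prob_single_modality` is a hypothetical stand-in written with plain Python lists, not the library function.

```python
def toy_prob_single_modality(o_m, po_m, distr_obs):
    """Likelihood of one modality's observation under predicted outcome probabilities.

    Assumed convention (not confirmed by the reference above):
      distr_obs=True  -> o_m is a vector over outcomes, e.g. one-hot
      distr_obs=False -> o_m is an integer outcome index
    """
    if distr_obs:
        # Inner product of the observation vector with the predicted distribution.
        return sum(o * p for o, p in zip(o_m, po_m))
    # Discrete observation: look up the probability of the observed outcome.
    return po_m[o_m]


po_m = [0.1, 0.7, 0.2]
p_index = toy_prob_single_modality(1, po_m, distr_obs=False)
p_onehot = toy_prob_single_modality([0.0, 1.0, 0.0], po_m, distr_obs=True)
```

Under this assumption, an index observation and its one-hot encoding give the same likelihood, while a soft observation vector yields a weighted average of outcome probabilities.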
make_aif_recurrent_fn() -> Callable[[Any, jnp.ndarray, jnp.ndarray, Any], tuple[mctx.RecurrentFnOutput, Any]]
Return a recurrent_fn, in the form expected by mctx, for an active inference (AIF) agent.
mcts_policy_search(search_algo: Callable | None = None, max_depth: int = 6, num_simulations: int = 4096) -> Callable[[Any, list[jnp.ndarray], jnp.ndarray], tuple[jnp.ndarray, Any]]
Build an MCTS-based policy-search callable for Agent planning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| search_algo | Callable | MCTS search routine from | None |
| max_depth | int | Maximum planning depth for the tree search. | 6 |
| num_simulations | int | Number of MCTS simulations per planning call. | 4096 |
Returns:
| Type | Description |
|---|---|
| Callable[[Any, list[ndarray], ndarray], tuple[ndarray, Any]] | Function with signature |
rollout(policy_search: Callable, agent: Any, env: Any, num_timesteps: int, rng_key: jnp.ndarray) -> tuple[dict[str, Any], dict[str, Any], Any]
Run a policy-search rollout loop for MCTS-based planning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| policy_search | Callable | Planning callable that maps | required |
| agent | Any | Active inference agent. | required |
| env | Any | Environment exposing | required |
| num_timesteps | int | Number of rollout steps. | required |
| rng_key | ndarray | Root JAX PRNG key. | required |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], Any] | Final carry dictionary, per-step rollout traces, and final environment. |