MCTS Planning

pymdp.planning.mcts

compute_neg_efe(qs: list[jnp.ndarray], action: jnp.ndarray, A: list[jnp.ndarray], A_dependencies: list[list[int]], B: list[jnp.ndarray], B_dependencies: list[list[int]], C: list[jnp.ndarray], use_states_info_gain: bool = True, use_utility: bool = True) -> tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]

Compute one-step negative expected free energy under an action.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| qs | list[ndarray] | Current posterior beliefs over hidden-state factors. | required |
| action | ndarray | Candidate action (multi-action index vector per batch element). | required |
| A | list[ndarray] | Observation model tensors. | required |
| A_dependencies | list[list[int]] | Modality-to-factor dependencies for A. | required |
| B | list[ndarray] | Transition model tensors. | required |
| B_dependencies | list[list[int]] | Factor-to-factor transition dependencies for B. | required |
| C | list[ndarray] | Preferences over observations. | required |
| use_states_info_gain | bool | Whether to include state information gain in EFE. | True |
| use_utility | bool | Whether to include expected utility in EFE. | True |

Returns:

| Type | Description |
| --- | --- |
| tuple[ndarray, list[ndarray], list[ndarray]] | Negative EFE, predicted next-state beliefs, and predicted next-observation beliefs. |
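
To make the three returned quantities concrete, the following is a minimal sketch of a one-step negative EFE for a single hidden-state factor and a single modality. It is not the library's implementation: `neg_efe_one_step` is a hypothetical helper, it ignores batching and the `A_dependencies`/`B_dependencies` arguments, and it assumes that `C` holds log-preferences over observations and that `B` is indexed as (next state, state, action).

```python
import jax.numpy as jnp

def neg_efe_one_step(qs, action, A, B, C,
                     use_states_info_gain=True, use_utility=True):
    """Toy negative EFE for one factor and one modality (hypothetical helper)."""
    eps = 1e-16
    # Predicted next-state beliefs; B is assumed to be (next_state, state, action).
    qs_next = B[:, :, action] @ qs
    # Predicted next-observation beliefs; A is assumed to be (obs, state).
    qo_next = A @ qs_next
    neg_efe = jnp.asarray(0.0)
    if use_utility:
        # Expected utility: predicted observations weighted by preferences C
        # (assumed here to be log-preferences).
        neg_efe = neg_efe + jnp.dot(qo_next, C)
    if use_states_info_gain:
        # State information gain: H[q(o)] - E_{q(s)}[H[p(o|s)]].
        entropy_qo = -jnp.sum(qo_next * jnp.log(qo_next + eps))
        entropy_A = -jnp.sum(A * jnp.log(A + eps), axis=0)
        neg_efe = neg_efe + entropy_qo - jnp.dot(entropy_A, qs_next)
    return neg_efe, qs_next, qo_next

# Two states, two observations, two actions (0 = stay, 1 = switch state).
A = jnp.eye(2)
B = jnp.stack([jnp.eye(2), jnp.array([[0.0, 1.0], [1.0, 0.0]])], axis=-1)
C = jnp.array([0.0, 1.0])   # observation 1 is preferred
qs = jnp.array([1.0, 0.0])  # agent believes it is in state 0

stay, _, _ = neg_efe_one_step(qs, 0, A, B, C)
move, qs_move, qo_move = neg_efe_one_step(qs, 1, A, B, C)
```

In this toy setup the switching action leads to the preferred observation, so it receives the higher negative EFE.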

get_prob_single_modality(o_m: jnp.ndarray, po_m: jnp.ndarray, distr_obs: bool) -> jnp.ndarray

Compute the observation likelihood for a single modality from the observation o_m and the likelihood po_m.
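
As an illustration of what the `distr_obs` flag could mean, here is a minimal sketch under the assumption that `o_m` is a probability vector when `distr_obs` is true and an integer index otherwise. `prob_single_modality` is a hypothetical stand-in, not the library function itself.

```python
import jax.numpy as jnp

def prob_single_modality(o_m, po_m, distr_obs):
    """Likelihood of one modality's observation (hypothetical sketch)."""
    if distr_obs:
        # Observation given as a distribution: expected likelihood under o_m.
        return jnp.sum(o_m * po_m)
    # Observation given as a discrete index: look up its probability.
    return po_m[o_m]

po_m = jnp.array([0.1, 0.7, 0.2])
p_index = prob_single_modality(1, po_m, distr_obs=False)
p_onehot = prob_single_modality(jnp.array([0.0, 1.0, 0.0]), po_m, distr_obs=True)
```

A one-hot observation and the matching index yield the same likelihood, which is why the two encodings can be used interchangeably in this sketch.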

make_aif_recurrent_fn() -> Callable[[Any, jnp.ndarray, jnp.ndarray, Any], tuple[mctx.RecurrentFnOutput, Any]]

Return a recurrent_fn for an active inference (AIF) agent.

Build an MCTS-based policy-search callable for Agent planning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| search_algo | Callable | MCTS search routine from mctx used to evaluate policy actions. | None |
| max_depth | int | Maximum planning depth for the tree search. | 6 |
| num_simulations | int | Number of MCTS simulations per planning call. | 4096 |

Returns:

| Type | Description |
| --- | --- |
| Callable[[Any, list[ndarray], ndarray], tuple[ndarray, Any]] | Function with signature (agent, beliefs, rng_key) -> (q_pi, info) returning policy weights and raw search output. |

rollout(policy_search: Callable, agent: Any, env: Any, num_timesteps: int, rng_key: jnp.ndarray) -> tuple[dict[str, Any], dict[str, Any], Any]

Run a policy-search rollout loop for MCTS-based planning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| policy_search | Callable | Planning callable that maps (rng_key, agent, qs) to policy weights. | required |
| agent | Any | Active inference agent. | required |
| env | Any | Environment exposing step(...) for batched transitions. | required |
| num_timesteps | int | Number of rollout steps. | required |
| rng_key | ndarray | Root JAX PRNG key. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[dict[str, Any], dict[str, Any], Any] | Final carry dictionary, per-step rollout traces, and final environment. |
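
The carry-plus-traces return shape is typical of a loop written with `jax.lax.scan`. The sketch below shows that pattern with toy stand-ins. `toy_policy_search` and `toy_rollout` are hypothetical, the agent and environment are omitted (so only two values are returned rather than three), and the belief update is a placeholder. Whether the library's `rollout` is implemented this way is an assumption.

```python
import jax
import jax.numpy as jnp
from jax import lax, random as jr

def toy_policy_search(rng_key, qs):
    """Hypothetical planner: random policy weights over two actions."""
    return jax.nn.softmax(jr.normal(rng_key, (2,)))

def toy_rollout(policy_search, qs0, num_timesteps, rng_key):
    """Scan-based rollout returning (final carry, per-step traces)."""
    def step_fn(carry, _):
        # Split the key so planning and action sampling use fresh randomness.
        key, plan_key, act_key = jr.split(carry["rng_key"], 3)
        q_pi = policy_search(plan_key, carry["qs"])
        action = jr.categorical(act_key, jnp.log(q_pi))
        # Placeholder belief update: one-hot on the sampled action.
        qs_next = jax.nn.one_hot(action, 2)
        new_carry = {"qs": qs_next, "rng_key": key}
        return new_carry, {"action": action, "q_pi": q_pi}

    init = {"qs": qs0, "rng_key": rng_key}
    return lax.scan(step_fn, init, None, length=num_timesteps)

last, traces = toy_rollout(
    toy_policy_search, jnp.array([1.0, 0.0]), 5, jr.PRNGKey(0)
)
```

Each entry in `traces` gains a leading time axis of length `num_timesteps`, while `last` holds only the state after the final step.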