Using rollout() for compiled active inference loops
rollout() runs repeated inference, planning, action sampling, and environment stepping for a fixed horizon.
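Internally, it uses jax.lax.scan, so calling it directly (without an outer jit) is valid, and the loop body is still traced and compiled once rather than re-executed in Python at every step.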
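To make the scan-based structure concrete, here is a minimal sketch of a closed loop written with jax.lax.scan. It is not the pymdp implementation: toy_step, toy_rollout, the random policy, and the counter dynamics are all hypothetical stand-ins chosen to show how a carry (state plus PRNG key) is threaded through a fixed number of steps.

```python
import jax
import jax.numpy as jnp
from jax import random as jr


def toy_step(carry, _):
    """One closed-loop step: split the key, sample an action, update the state."""
    state, key = carry
    key, subkey = jr.split(key)
    # Hypothetical random policy standing in for inference and planning.
    action = jr.bernoulli(subkey).astype(jnp.int32)
    # Hypothetical environment dynamics: the state counts actions equal to 1.
    state = state + action
    return (state, key), action  # (new carry, per-step output)


def toy_rollout(num_timesteps, key):
    """Run toy_step for a fixed horizon and stack the per-step outputs."""
    init = (jnp.int32(0), key)
    (last_state, _), actions = jax.lax.scan(
        toy_step, init, None, length=num_timesteps
    )
    return last_state, actions


last_state, actions = toy_rollout(20, jr.PRNGKey(0))
```

The second return value of scan stacks the per-step outputs along a leading time axis, which is the same general pattern that lets a rollout return a full history alongside the final carry.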
When to use it
- The .step() and .reset() methods of your Env can be JIT-compiled. This includes compatibility with pymdp's native PymdpEnv, as well as existing JAX RL environment frameworks (see for example gymnax, jumanji, and navix).
- You want high-throughput simulations by compiling the full closed-loop interaction once and executing it efficiently across many rollouts.
- You want a single, consistent API for multi-step active inference rollouts with explicit PRNG key threading.
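As a rough illustration of what "JIT-compilable .step() and .reset()" means in practice, the sketch below defines a toy environment whose methods are pure functions of their inputs and use jax.numpy operations instead of Python branching. ToyCorridorEnv is hypothetical and does not implement the pymdp Env interface; it only shows the property that makes an environment traceable.

```python
import jax
import jax.numpy as jnp
from jax import random as jr


class ToyCorridorEnv:
    """Hypothetical 1-D corridor; NOT the pymdp Env interface."""

    def __init__(self, length=5):
        self.length = length

    def reset(self, key):
        # Randomness comes only from the explicit JAX key.
        return jr.randint(key, (), 0, self.length)

    def step(self, state, action):
        # action 1 moves right, anything else moves left; jnp.where replaces
        # a Python if-statement so the function can be traced.
        move = jnp.where(action == 1, 1, -1)
        # Clip keeps the state inside the corridor.
        return jnp.clip(state + move, 0, self.length - 1)


env = ToyCorridorEnv()
reset_jit = jax.jit(env.reset)
step_jit = jax.jit(env.step)

state = reset_jit(jr.PRNGKey(0))
next_state = step_jit(state, jnp.int32(1))
```

If both methods compile like this, the environment is a plausible candidate for a scan-based rollout; if either needs Python-side state or I/O, a manual loop is likely the better fit.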
Required inputs
- agent: pymdp.agent.Agent
- env: pymdp.envs.env.Env-compatible object
- num_timesteps: integer horizon
- rng_key: JAX PRNG key
Optional inputs
- initial_carry: override initial rollout carry
- policy_search: custom policy search function
- env_params: batched environment parameters
Canonical usage
from jax import random as jr
from pymdp.envs.rollout import rollout
rng_key = jr.PRNGKey(0)
last, info = rollout(agent, env, 20, rng_key)
For repeated calls with fixed env and num_timesteps, we recommend wrapping
rollout with jit so XLA can cache the compiled program:
from jax import jit
from jax import random as jr
from pymdp.envs.rollout import rollout
rng_key = jr.PRNGKey(0)
rollout_jit = jit(rollout, static_argnums=[1, 2]) # env and num_timesteps are static
last, info = rollout_jit(
    agent,
    env,
    20,
    rng_key,
)
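The caching behaviour of static arguments can be seen on a toy function that does not involve pymdp at all. In the sketch below, repeat_add and trace_count are hypothetical names; the counter increments only while JAX traces the function, so it shows when a call reuses the cached program and when a new static value forces a retrace. Static arguments to jax.jit must be hashable, which is worth checking for a custom env object.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def repeat_add(x, num_steps):
    """Add 1.0 to x num_steps times using a scan of fixed length."""
    global trace_count
    trace_count += 1  # executes during tracing only, not on cached calls

    def body(carry, _):
        return carry + 1.0, None

    out, _ = jax.lax.scan(body, x, None, length=num_steps)
    return out


# num_steps (argument index 1) is static: it fixes the scan length.
repeat_add_jit = jax.jit(repeat_add, static_argnums=(1,))

a = repeat_add_jit(jnp.float32(0.0), 20)  # first call: trace and compile
b = repeat_add_jit(jnp.float32(1.0), 20)  # same static value: cache hit
c = repeat_add_jit(jnp.float32(0.0), 30)  # new static value: traced again
```

The same reasoning suggests keeping env and num_timesteps fixed across repeated rollout_jit calls, since changing either would trigger a fresh compilation.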
Reproducible key flow
rollout() internally splits keys per step and per batch. For deterministic re-runs:
- pass in the same rng_key seed,
- keep environment params/initial carry identical,
- avoid hidden non-JAX randomness inside your environment's .step() or .reset() methods.
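The effect of explicit key threading on reproducibility can be checked with plain jax.random, independent of pymdp. In this minimal sketch, sample_sequence is a hypothetical helper that splits the key once per step, which mirrors the general per-step splitting pattern rather than the library's exact internals.

```python
import jax.numpy as jnp
from jax import random as jr


def sample_sequence(key, num_steps):
    """Draw one uniform sample per step, splitting the key each time."""
    samples = []
    for _ in range(num_steps):
        key, subkey = jr.split(key)  # never reuse a key after sampling from it
        samples.append(jr.uniform(subkey))
    return jnp.stack(samples)


run_a = sample_sequence(jr.PRNGKey(0), 5)
run_b = sample_sequence(jr.PRNGKey(0), 5)  # same seed: identical samples
run_c = sample_sequence(jr.PRNGKey(1), 5)  # different seed: different samples
```

Any randomness drawn outside this key chain (for example Python's random module or NumPy's global generator inside an environment) would not be controlled by rng_key and would break this guarantee.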
Batched runs and carry
- agent.batch_size controls parallel batch dimension.
- initial_carry can override auto-initialized state (for warm-starting).
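For intuition about a parallel batch dimension, the following sketch shows a common JAX pattern: split one key into one key per batch element and vmap a sampling function over them. It is an assumption-laden illustration, not how pymdp handles batching internally; batch_size, sample_action, and the uniform logits are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random as jr

batch_size = 4  # hypothetical; plays the role of a parallel batch dimension

# One independent key per batch element.
keys = jr.split(jr.PRNGKey(0), batch_size)


def sample_action(key):
    """Sample one of three actions from uniform logits (placeholder policy)."""
    return jr.categorical(key, jnp.zeros(3))


# vmap maps sample_action over the leading axis of keys.
actions = jax.vmap(sample_action)(keys)
```

Each batch element gets its own random stream, so the batched result has one entry per element along the leading axis.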
Relationship to manual loops
rollout() repeatedly applies the one-step helper infer_and_plan internally
using jax.lax.scan.
Use manual loops when:
- your environment's .step() and .reset() methods cannot be JIT-compiled.
- you need custom per-step side effects that don't respect the active inference logic of infer_and_plan.
For JAX-based environments, we recommend using rollout(), as it's usually simpler and less error-prone.
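For contrast, a manual loop over an environment that cannot be traced might look like the sketch below. PythonOnlyEnv and manual_loop are hypothetical, and the random action is only a placeholder for the inference and planning step; the point is the structure: a Python for-loop, explicit key splitting, and a per-step side effect (logging) that would not fit inside a compiled scan.

```python
from jax import random as jr


class PythonOnlyEnv:
    """Hypothetical environment with mutable Python state (not JIT-compatible)."""

    def __init__(self):
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        self.state += action  # in-place mutation: cannot be traced by JAX
        return self.state


def manual_loop(env, num_timesteps, key):
    """Step the environment in Python, threading the PRNG key by hand."""
    obs = env.reset()
    history = []
    for t in range(num_timesteps):
        key, subkey = jr.split(key)
        # Placeholder for inference + planning + action sampling.
        action = int(jr.randint(subkey, (), 0, 2))
        obs = env.step(action)
        history.append((t, action, obs))  # per-step side effect (logging)
    return obs, history


final_obs, history = manual_loop(PythonOnlyEnv(), 10, jr.PRNGKey(0))
```

The bookkeeping done by hand here (key splitting, history collection) is the kind of detail a single rollout call handles for you when the environment is JAX-compatible.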
Debugging checklist
- Shape mismatch: check observation/action histories and factor dimensions.
- Stochastic sampling errors: ensure valid RNG keys are threaded.
- Sequence methods (mmp, vmp): ensure past_actions and valid windows are correct.
- Learning updates: verify learn_A/learn_B flags and action alignment.
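When chasing a shape mismatch, it can help to print the shape of every array in a nested output at once. The sketch below does this with jax.tree_util.tree_map on a hypothetical stand-in for rollout output; the field names and the (time, batch, feature) layout are assumptions for illustration and may not match what rollout() actually returns.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a rollout history; real field names may differ.
info = {
    "observation": jnp.zeros((20, 1, 3)),
    "action": jnp.zeros((20, 1, 2), dtype=jnp.int32),
}

# Replace every array leaf with its shape to get a compact overview.
shapes = jax.tree_util.tree_map(lambda x: x.shape, info)

for name, shape in shapes.items():
    print(name, shape)
```

Comparing the reported leading dimensions against num_timesteps and the agent's batch size is usually a quick way to spot which history is misaligned.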