NumPy/legacy to JAX Migration
This guide is for users moving from pymdp.legacy (NumPy/object-array style) to the JAX backend of pymdp.
Key concept shifts
| Legacy NumPy | JAX |
|---|---|
| numpy.ndarray(dtype=object) collections | pytrees/lists of jax.Array |
| np.random global RNG | explicit jr.PRNGKey threading |
| stateful loops with implicit mutation | functional updates with explicit returns |
| mutable attribute assignment (agent.x = ...) | functional pytree updates (eqx.tree_at(...)) |
| pymdp.legacy.* modules | pymdp.* modern JAX-based modules |
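
As a minimal sketch of the first row of the table, the snippet below converts a legacy-style object array into a plain list of jax.Array leaves and maps a function over it. The names legacy_A and A_jax and the array shapes are illustrative assumptions, not taken from pymdp.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Legacy style: an object array holding one likelihood array per modality.
legacy_A = np.empty(2, dtype=object)
legacy_A[0] = np.full((3, 2), 1.0 / 3.0)  # modality 0: 3 outcomes, 2 states
legacy_A[1] = np.full((4, 2), 0.25)       # modality 1: 4 outcomes, 2 states

# JAX style: a plain Python list of jax.Array leaves, which JAX treats as a pytree.
A_jax = [jnp.asarray(a, dtype=jnp.float32) for a in legacy_A]

# Pytree utilities apply a function to every leaf, replacing manual loops
# over object-array entries.
column_sums = jax.tree_util.tree_map(lambda a: a.sum(axis=0), A_jax)
shapes = [a.shape for a in A_jax]
```

Because a list of arrays is an ordinary pytree, it can be passed directly through jax.jit and jax.vmap, which object arrays cannot.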
API differences to update
- Policy inference now takes current beliefs as inputs, since they are no longer stored internally on the agent.

      # legacy NumPy
      q_pi, G = agent.infer_policies()

      # JAX
      q_pi, neg_efe = agent.infer_policies(qs)

  Note: in SPM-style notation this same policy score is often called G (G = neg_efe = -EFE).
- Stochastic functions require explicit random keys. Common examples in pymdp include action/policy sampling (for example agent.sample_action(..., rng_key=...)), random generative model initialization utilities such as utils.random_A_array(...), utils.random_B_array(...), and utils.random_factorized_categorical(...), rollout execution (rollout(..., rng_key=...)), and stochastic environment methods (env.reset(key, ...), env.step(key, ...)).

      # legacy NumPy
      action = agent.sample_action()

      # JAX
      keys = jr.split(rng_key, agent.batch_size + 1)
      action = agent.sample_action(q_pi, rng_key=keys[1:])

- Keep observation preprocessing consistent.
  - If categorical_obs=False, pass discrete indices.
  - If categorical_obs=True, pass normalized categorical vectors.
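
To illustrate the two observation formats in the last item, here is a small sketch in plain JAX that builds a batch of discrete indices and the matching normalized categorical (one-hot) vectors. The number of outcomes and the exact array shapes are assumptions for illustration; check the shapes your version of pymdp expects for categorical observations.

```python
import jax.numpy as jnp
from jax import nn

num_obs = 4  # assumed number of outcomes for a single modality

# Discrete-index style: one integer outcome per agent, shape (batch_size, 1).
obs_index = jnp.array([[0], [2], [3]])

# Categorical style: one normalized vector per agent,
# here with shape (batch_size, 1, num_obs).
obs_categorical = nn.one_hot(obs_index, num_obs)

# Each categorical vector sums to 1, and its argmax recovers the index.
row_sums = obs_categorical.sum(axis=-1)
recovered = obs_categorical.argmax(axis=-1)
```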
Batching and batch_size
Most of pymdp's JAX APIs expect a leading batch_size dimension on their array
arguments (for example state/parameter priors and posteriors). This leading
axis is what makes it easy to parallelize over agents with jax.vmap.
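
The following sketch shows why a leading batch axis pairs well with jax.vmap: a function written for one agent is mapped over the batch without any explicit loop. The function update_belief is a hypothetical single-agent Bayesian update, not a pymdp API.

```python
import jax
import jax.numpy as jnp

def update_belief(likelihood, prior):
    """Single-agent Bayesian update: posterior proportional to likelihood * prior."""
    unnormalized = likelihood * prior
    return unnormalized / unnormalized.sum()

batch_size, num_states = 3, 2
# Leading axis is the batch: one prior and one likelihood vector per agent.
priors = jnp.full((batch_size, num_states), 0.5)
likelihoods = jnp.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])

# vmap maps the single-agent function over the leading batch axis.
posteriors = jax.vmap(update_belief)(likelihoods, priors)
```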
Defaults to keep in mind:
- Agent(..., batch_size=1) is the default.
- By default, Env objects are typically unbatched (for example env.A/B/D have no leading batch axis), while Agent leaves are batch-shaped.
- In the default single-agent case, many arrays therefore appear as (1, ...).
When migrating from legacy single-agent code, keep these points in mind:
- Multi-agent simulations (including parallel parameter sweeps) require batched per-agent information (observations, beliefs, parameters, preferences, and some hyperparameters).
- In Agent.__init__, if you pass unbatched A/B with batch_size=N, the tensors are broadcast to (N, ...).
- If inputs are already batched with leading dimension 1 and batch_size>1, they are also broadcast to batch_size.
- If the leading dimension is batched but neither 1 nor batch_size, a ValueError is raised.
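
The broadcasting rules above can be sketched as a standalone function. This is a hypothetical helper (broadcast_to_batch) written only to make the three cases concrete; it is not pymdp's implementation, and it assumes the caller knows the number of dimensions an unbatched array should have.

```python
import jax.numpy as jnp

def broadcast_to_batch(x, batch_size, unbatched_ndim):
    """Hypothetical helper mirroring the rules above; not pymdp's own code."""
    if x.ndim == unbatched_ndim:
        # Unbatched input: add a leading axis and broadcast to (batch_size, ...).
        return jnp.broadcast_to(x, (batch_size,) + x.shape)
    if x.shape[0] == batch_size:
        # Already batched to the requested size.
        return x
    if x.shape[0] == 1:
        # Leading dimension of 1: broadcast to batch_size.
        return jnp.broadcast_to(x, (batch_size,) + x.shape[1:])
    raise ValueError(
        f"Leading dimension {x.shape[0]} is neither 1 nor batch_size={batch_size}."
    )

A_unbatched = jnp.ones((3, 2)) / 3.0  # (num_obs, num_states)
A_batched = broadcast_to_batch(A_unbatched, 4, unbatched_ndim=2)
A_from_one = broadcast_to_batch(A_unbatched[None], 4, unbatched_ndim=2)

# A leading dimension of 2 with batch_size=4 matches neither rule.
try:
    broadcast_to_batch(jnp.ones((2, 3, 2)), 4, unbatched_ndim=2)
    raised = False
except ValueError:
    raised = True
```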
Typical multi-agent setup pattern:
batch_size = 3
agent = Agent(A=A, B=B, C=C, D=D, batch_size=batch_size)
# One discrete modality for 3 parallel agents: shape (batch_size, 1)
obs = [jnp.array([[0], [2], [3]])]
qs = agent.infer_states(obs, empirical_prior=agent.D)
For environment-side batching, either:
- Keep a single (unbatched) environment and run batched agents against shared env.A/B/D.
- Pass batched env_params (for example via env.generate_env_params(batch_size=...)) when using rollout(). Parameters inside env_params should also carry a leading batch_size dimension aligned to the Agent.
Updating Agent fields in JAX
The Agent class is an Equinox
module. Equinox modules are immutable, so avoid mutable-style setter updates
like:
agent.A = new_A
agent.B = new_B
Instead, use Equinox-style functional updates with eqx.tree_at(...):
import equinox as eqx
agent = eqx.tree_at(lambda x: (x.A,), agent, (new_A,))
agent = eqx.tree_at(lambda x: (x.B,), agent, (new_B,))
If you change static Agent fields (for example model dimensions, dependency
structure, or static hyperparameters such as number of policies), JAX will
trigger recompiles of agent-specific JIT-compiled methods such as
infer_states, infer_policies, and related methods used during rollouts.
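
The recompilation behavior can be observed in plain JAX. In this sketch, score_policies is a hypothetical jitted function with one static argument; it is not part of pymdp, and it only demonstrates that changing a static value triggers a new trace while changing array values does not.

```python
from functools import partial
import jax
import jax.numpy as jnp

trace_log = []

@partial(jax.jit, static_argnames="num_policies")
def score_policies(x, num_policies):
    # This Python side effect runs only while JAX traces (compiles) the function.
    trace_log.append(num_policies)
    return jnp.tile(x, (num_policies, 1)).sum(axis=1)

x = jnp.arange(3.0)
score_policies(x, num_policies=2)        # first call: traces
score_policies(x + 1.0, num_policies=2)  # new values, same shapes: cache hit
traces_after_value_change = len(trace_log)
score_policies(x, num_policies=5)        # new static value: traces again
traces_after_static_change = len(trace_log)
```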
The eqx.tree_at(...) update shown above is also the pattern used near the end of
infer_parameters() in pymdp/agent.py, where after a learning update, the new A/B (and, when
enabled, I) arrays are written back into a new Agent instance.
Randomness migration
Use explicit key flow everywhere:
rng_key = jr.PRNGKey(0)
rng_key, key_infer, key_action = jr.split(rng_key, 3)
qs = agent.infer_states(obs, empirical_prior=agent.D)
q_pi, _ = agent.infer_policies(qs)
keys = jr.split(key_action, agent.batch_size + 1)
action = agent.sample_action(q_pi, rng_key=keys[1:])
Other frequently used stochastic APIs that also require explicit keys:
- utils.random_A_array(...)
- utils.random_B_array(...)
- utils.random_factorized_categorical(...)
- utils.generate_agent_spec(..., key=...)
- rollout(..., rng_key=...)
- stochastic env.reset(key, ...) / env.step(key, ...)
Avoid np.random when using the new JAX backend (see also
this post, which
encourages caution when working with np.random).
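
As a sketch of explicit key flow with a batch, the snippet below splits a key into batch_size + 1 subkeys and samples one categorical action per agent with jax.vmap. The sampler sample_one is a hypothetical stand-in for a stochastic API, and treating keys[0] as the carried key is an assumption about the intent of the + 1 in the split.

```python
import jax
import jax.numpy as jnp
from jax import random as jr

batch_size, num_actions = 3, 4
rng_key = jr.PRNGKey(0)

# Split once per step: keep keys[0] as the carried key, hand keys[1:] to the sampler.
keys = jr.split(rng_key, batch_size + 1)
rng_key, agent_keys = keys[0], keys[1:]

# One row of action probabilities per agent (uniform here, for illustration).
probs = jnp.full((batch_size, num_actions), 1.0 / num_actions)

def sample_one(key, p):
    """Sample a single action index from a categorical distribution."""
    return jr.categorical(key, jnp.log(p))

actions = jax.vmap(sample_one)(agent_keys, probs)

# Re-running with the same keys reproduces the same actions.
actions_again = jax.vmap(sample_one)(agent_keys, probs)
```

Unlike np.random, nothing here depends on hidden global state: the same keys always yield the same samples, which keeps batched simulations reproducible.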
NumPy-to-JAX worked conversion
# legacy NumPy
from pymdp.legacy.agent import Agent as LegacyAgent
from pymdp.legacy import utils as legacy_utils
A = legacy_utils.random_A_matrix([3], [2])
B = legacy_utils.random_B_matrix([2], [2])
agent = LegacyAgent(A=A, B=B)
obs = [1]
qs = agent.infer_states(obs)
q_pi, G = agent.infer_policies()
action = agent.sample_action()
# JAX
import jax.numpy as jnp
from jax import random as jr
from pymdp.agent import Agent
from pymdp import utils
key = jr.PRNGKey(0)
keys = jr.split(key, 3)
A = utils.random_A_array(keys[0], [3], [2])
B = utils.random_B_array(keys[1], [2], [2])
agent = Agent(A=A, B=B)
# One discrete modality for the default batch_size=1: shape (batch_size, 1)
obs = [jnp.array([[1]])]
qs = agent.infer_states(obs, empirical_prior=agent.D)
q_pi, neg_efe = agent.infer_policies(qs)
action_keys = jr.split(keys[2], agent.batch_size + 1)
action = agent.sample_action(q_pi, rng_key=action_keys[1:])
Common migration pitfalls
- Missing batch-size dimension for observations, actions, qs, or parameters (A, B, C, D).
- Missing empirical_prior argument in infer_states.
- Missing rng_key/key in stochastic calls (for example sample_action, random model constructors, rollout, or stochastic env.reset/env.step).
- Mixing numpy.ndarray into JAX-only paths.
- Forgetting sequence action-history shape (T-1, num_factors).
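
Two of these pitfalls (a missing batch dimension and stray NumPy arrays) can be caught early with a small check. The helper find_batch_mismatches below is hypothetical and only inspects leading axes; it is a debugging sketch, not a pymdp utility.

```python
import numpy as np
import jax
import jax.numpy as jnp

def find_batch_mismatches(tree, batch_size):
    """Hypothetical check: shapes of leaves whose leading axis != batch_size."""
    leaves = jax.tree_util.tree_leaves(tree)
    return [
        jnp.shape(x)
        for x in leaves
        if jnp.ndim(x) == 0 or jnp.shape(x)[0] != batch_size
    ]

batch_size = 3
obs_ok = [jnp.array([[0], [2], [3]])]  # (batch_size, 1)
obs_bad = [np.array([0, 2])]           # NumPy array with the wrong leading size

# Convert stray NumPy leaves to jax.Array before entering JAX-only code paths.
obs_bad_jax = jax.tree_util.tree_map(jnp.asarray, obs_bad)

mismatches_ok = find_batch_mismatches(obs_ok, batch_size)
mismatches_bad = find_batch_mismatches(obs_bad_jax, batch_size)
```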
Migration done checklist
- No imports from pymdp.legacy in active scripts/notebooks.
- Randomness uses explicit jr.PRNGKey flow.
- infer_policies(qs) and sample_action(q_pi, rng_key=...) usage updated.
- Tests pass (pytest test).
- Notebook/docs examples run with the JAX API.