NumPy/legacy to JAX Migration

This guide is for users moving from pymdp.legacy (NumPy/object-array style) to the JAX backend of pymdp.

Key concept shifts

Legacy NumPy                                 | JAX
numpy.ndarray(dtype=object) collections      | pytrees/lists of jax.Array
np.random global RNG                         | explicit jr.PRNGKey threading
stateful loops with implicit mutation        | functional updates with explicit returns
mutable attribute assignment (agent.x = ...) | functional pytree updates (eqx.tree_at(...))
pymdp.legacy.* modules                       | pymdp.* modern JAX-based modules
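
The first and third rows can be illustrated without any pymdp calls. The following is a minimal sketch using only NumPy and JAX; the matrix values and the names legacy_A, jax_A, and updated_A0 are illustrative assumptions, not part of the pymdp API.

```python
import numpy as np
import jax.numpy as jnp

# Legacy style: an object array holding one likelihood matrix per modality.
legacy_A = np.empty(2, dtype=object)
legacy_A[0] = np.full((3, 2), 1.0 / 3.0)  # modality 0: 3 outcomes, 2 states
legacy_A[1] = np.full((4, 2), 0.25)       # modality 1: 4 outcomes, 2 states

# JAX style: a plain Python list of jax.Array leaves, which JAX treats as a pytree.
jax_A = [jnp.asarray(a) for a in legacy_A]

# In-place mutation (legacy_A[0][0, 0] = 0.5) becomes a functional update that
# returns a new array and leaves the original untouched. The value 0.5 only
# demonstrates the update syntax; the column is no longer normalized afterwards.
updated_A0 = jax_A[0].at[0, 0].set(0.5)
```

Because jax_A[0] is unchanged after the update, any code that still holds the old array keeps seeing the old values, which is the behavioral difference to watch for when porting loops that relied on mutation.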

API differences to update

  1. Policy inference now takes current beliefs as inputs, since they are no longer stored internally on the agent.
    # legacy NumPy
    q_pi, G = agent.infer_policies()
    
    # JAX
    q_pi, neg_efe = agent.infer_policies(qs)
    

Note: in SPM-style notation this same policy score is often called G (G = neg_efe = -EFE).

  2. Stochastic functions require explicit random keys. Common examples in pymdp include action/policy sampling (for example agent.sample_action(..., rng_key=...)), random generative model initialization utilities such as utils.random_A_array(...), utils.random_B_array(...), and utils.random_factorized_categorical(...), rollout execution (rollout(..., rng_key=...)), and stochastic environment methods (env.reset(key, ...), env.step(key, ...)).

    # legacy NumPy
    action = agent.sample_action()
    
    # JAX
    keys = jr.split(rng_key, agent.batch_size + 1)
    action = agent.sample_action(q_pi, rng_key=keys[1:])
    
  3. Keep observation preprocessing consistent.

    • If categorical_obs=False, pass discrete indices.
    • If categorical_obs=True, pass normalized categorical vectors.
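
The two observation formats in the last item can be related with a one-hot encoding. This is a minimal sketch using only JAX; num_obs and the resulting array layout are assumptions for illustration, so check the shape your Agent actually expects before relying on it.

```python
import jax
import jax.numpy as jnp

num_obs = 4  # assumed number of outcomes for this modality

# Discrete indices for 3 parallel agents, shape (batch_size, 1).
obs_idx = jnp.array([[0], [2], [3]])

# The same observations as normalized categorical vectors,
# shape (batch_size, 1, num_obs). Each vector sums to 1.
obs_cat = jax.nn.one_hot(obs_idx, num_obs)

# Going back from categorical vectors to indices.
recovered_idx = jnp.argmax(obs_cat, axis=-1)
```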

Batching and batch_size

Most of pymdp's JAX APIs expect arrays (for example state/parameter priors and posteriors) to carry a leading batch_size dimension. This makes it easy to parallelize across agents with jax.vmap.

Defaults to keep in mind:

  • Agent(..., batch_size=1) is the default.
  • By default, Env objects are typically unbatched (for example env.A/B/D have no leading batch axis), while Agent leaves are batch-shaped.
  • In the default single-agent case, many arrays therefore appear as (1, ...).
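
To see why a leading batch axis pairs well with jax.vmap, here is a toy sketch. It is not pymdp's inference routine: posterior_single is a hypothetical exact Bayes update for one agent with one modality and one state factor, and the numbers are made up.

```python
import jax
import jax.numpy as jnp

def posterior_single(A, prior, obs_idx):
    """Toy per-agent update: p(s | o) is proportional to A[o, s] * prior[s]."""
    unnorm = A[obs_idx] * prior
    return unnorm / unnorm.sum()

batch_size = 3
A = jnp.array([[0.9, 0.1],
               [0.1, 0.9]])                            # (num_obs, num_states), unbatched
A_batched = jnp.broadcast_to(A, (batch_size, 2, 2))    # add the leading batch axis
prior = jnp.full((batch_size, 2), 0.5)                 # (batch_size, num_states)
obs = jnp.array([0, 1, 0])                             # one observation index per agent

# vmap maps the single-agent function over the leading axis of every argument.
qs = jax.vmap(posterior_single)(A_batched, prior, obs)  # (batch_size, num_states)
```

With batch_size = 1 the same call works unchanged and returns an array of shape (1, num_states), which matches the (1, ...) shapes seen in the default single-agent case.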

When migrating from legacy single-agent code, keep these points in mind:

  • Multi-agent simulations (including parallel parameter sweeps) require batched per-agent information (observations, beliefs, parameters, preferences, and some hyperparameters).
  • In Agent.__init__, if you pass unbatched A/B with batch_size=N, the tensors are broadcast to (N, ...).
  • If inputs are already batched with leading dimension 1 and batch_size>1, they are also broadcast to batch_size.
  • If inputs are batched but their leading dimension is neither 1 nor batch_size, a ValueError is raised.
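
The broadcasting rules in the list above can be sketched as a small standalone function. broadcast_to_batch is a hypothetical helper written for this guide, not the code inside Agent.__init__, and it assumes the caller knows the number of dimensions of the unbatched tensor.

```python
import jax.numpy as jnp

def broadcast_to_batch(x, unbatched_ndim, batch_size):
    """Hypothetical helper mirroring the broadcasting rules described above."""
    x = jnp.asarray(x)
    if x.ndim == unbatched_ndim:
        # Unbatched input: prepend a batch axis and expand it.
        return jnp.broadcast_to(x, (batch_size,) + x.shape)
    if x.ndim != unbatched_ndim + 1:
        raise ValueError(f"unexpected number of dimensions: {x.ndim}")
    if x.shape[0] == batch_size:
        # Already batched to the right size: nothing to do.
        return x
    if x.shape[0] == 1:
        # Batched with leading dimension 1: expand to batch_size.
        return jnp.broadcast_to(x, (batch_size,) + x.shape[1:])
    raise ValueError(
        f"leading dimension {x.shape[0]} is neither 1 nor batch_size={batch_size}"
    )

A = jnp.ones((3, 2)) / 3.0                          # unbatched (num_obs, num_states)
A_from_unbatched = broadcast_to_batch(A, 2, 4)      # shape (4, 3, 2)
A_from_one = broadcast_to_batch(A[None], 2, 4)      # shape (4, 3, 2)
```

Calling broadcast_to_batch(jnp.ones((2, 3, 2)), 2, 4) raises ValueError, because the leading dimension 2 matches neither 1 nor the requested batch size.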

Typical multi-agent setup pattern:

batch_size = 3
agent = Agent(A=A, B=B, C=C, D=D, batch_size=batch_size)

# One discrete modality for 3 parallel agents: shape (batch_size, 1)
obs = [jnp.array([[0], [2], [3]])]
qs = agent.infer_states(obs, empirical_prior=agent.D)

For environment-side batching, either:

  • Keep a single (unbatched) environment and run batched agents against shared env.A/B/D.
  • Pass batched env_params (for example via env.generate_env_params(batch_size=...)) when using rollout(). Parameters inside env_params should also carry a leading batch_size dimension aligned to the Agent.

Updating Agent fields in JAX

The Agent class is an Equinox module. In practice, this means you should avoid mutable-style setter updates like:

agent.A = new_A
agent.B = new_B

Instead, use Equinox-style functional updates with eqx.tree_at(...):

import equinox as eqx

agent = eqx.tree_at(lambda x: (x.A,), agent, (new_A,))
agent = eqx.tree_at(lambda x: (x.B,), agent, (new_B,))

If you change static Agent fields (for example model dimensions, dependency structure, or static hyperparameters such as number of policies), JAX will trigger recompiles of agent-specific JIT-compiled methods such as infer_states, infer_policies, and related methods used during rollouts.

The eqx.tree_at(...) pattern shown above is the one used near the end of infer_parameters() in pymdp/agent.py: after a learning update, the new A/B (and, when enabled, I) arrays are written back into a new Agent instance.
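
The recompilation triggered by changing static fields is general jax.jit behavior and can be observed without an Agent. The sketch below uses a hypothetical function, uniform_policy_prior, with a static num_policies argument standing in for a static hyperparameter; the Python-side trace_log only grows when JAX traces the function again.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # Python side effects run only while JAX traces the function

@partial(jax.jit, static_argnames="num_policies")
def uniform_policy_prior(scale, num_policies):
    trace_log.append(num_policies)  # executed once per trace, not once per call
    return jnp.ones(num_policies) * scale / num_policies

uniform_policy_prior(jnp.array(1.0), num_policies=2)  # first call: trace and compile
uniform_policy_prior(jnp.array(2.0), num_policies=2)  # same static value: cache hit
uniform_policy_prior(jnp.array(1.0), num_policies=3)  # new static value: trace again

# trace_log is now [2, 3]: two compilations for three calls.
```

Changing only array values (the second call) reuses the compiled function, which is why updating array leaves such as A or B through eqx.tree_at is cheap compared with changing static structure.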

Randomness migration

Use explicit key flow everywhere:

rng_key = jr.PRNGKey(0)
rng_key, key_infer, key_action = jr.split(rng_key, 3)

qs = agent.infer_states(obs, empirical_prior=agent.D)
q_pi, _ = agent.infer_policies(qs)
keys = jr.split(key_action, agent.batch_size + 1)
action = agent.sample_action(q_pi, rng_key=keys[1:])

Other frequently used stochastic APIs that also require explicit keys:

  • utils.random_A_array(...)
  • utils.random_B_array(...)
  • utils.random_factorized_categorical(...)
  • utils.generate_agent_spec(..., key=...)
  • rollout(..., rng_key=...)
  • stochastic env.reset(key, ...) / env.step(key, ...)

Avoid np.random when using the new JAX backend (see also this post, which encourages caution when working with np.random).
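
The split-and-slice key pattern used in this section can be tried in isolation. This is a minimal sketch with plain JAX: sample_policy is a hypothetical stand-in for a per-agent sampling step, and the q_pi values are made up.

```python
import jax
import jax.numpy as jnp
from jax import random as jr

batch_size = 3
rng_key = jr.PRNGKey(0)

# Split once per step: carry keys[0] forward, hand keys[1:] to the batch.
keys = jr.split(rng_key, batch_size + 1)
rng_key, agent_keys = keys[0], keys[1:]

# Toy policy posterior shared by all agents, shape (batch_size, num_policies).
q_pi = jnp.tile(jnp.array([0.7, 0.2, 0.1]), (batch_size, 1))

def sample_policy(key, probs):
    """Hypothetical per-agent sampler: draw one policy index from probs."""
    return jr.categorical(key, jnp.log(probs))

# One key per agent, mapped over the leading batch axis.
policy_idx = jax.vmap(sample_policy)(agent_keys, q_pi)    # shape (batch_size,)

# Reusing the same keys reproduces the same samples exactly.
policy_again = jax.vmap(sample_policy)(agent_keys, q_pi)
```

Reproducibility comes from the keys alone, so a run can be replayed by starting from the same jr.PRNGKey seed and splitting in the same order.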

NumPy-to-JAX worked conversion

# legacy NumPy
from pymdp.legacy.agent import Agent as LegacyAgent
from pymdp.legacy import utils as legacy_utils
A = legacy_utils.random_A_matrix([3], [2])
B = legacy_utils.random_B_matrix([2], [2])
agent = LegacyAgent(A=A, B=B)
obs = [1]
qs = agent.infer_states(obs)
q_pi, G = agent.infer_policies()
action = agent.sample_action()

# JAX
from jax import random as jr
from pymdp.agent import Agent
from pymdp import utils

key = jr.PRNGKey(0)
keys = jr.split(key, 3)
A = utils.random_A_array(keys[0], [3], [2])
B = utils.random_B_array(keys[1], [2], [2])
agent = Agent(A=A, B=B)
obs = [1]
qs = agent.infer_states(obs, empirical_prior=agent.D)
q_pi, neg_efe = agent.infer_policies(qs)
action_keys = jr.split(keys[2], agent.batch_size + 1)
action = agent.sample_action(q_pi, rng_key=action_keys[1:])

Common migration pitfalls

  1. Missing batch-size dimension for observations, actions, qs, or parameters (A, B, C, D).
  2. Missing empirical_prior argument in infer_states.
  3. Missing rng_key/key in stochastic calls (for example sample_action, random model constructors, rollout, or stochastic env.reset/env.step).
  4. Mixing numpy.ndarray into JAX-only paths.
  5. Forgetting sequence action-history shape (T-1, num_factors).
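
For the last pitfall, the shape arithmetic can be checked with a pure-JAX sketch. It makes no pymdp calls and ignores any batch axis; T, num_factors, and the action values are illustrative assumptions.

```python
import jax.numpy as jnp

T = 4            # number of observations in the sequence (assumed)
num_factors = 2  # number of hidden-state factors (assumed)

# One action vector per transition; T observations imply T - 1 transitions.
actions = [jnp.array([t % 2, 0]) for t in range(T - 1)]

# Stack along a new leading time axis: shape (T - 1, num_factors).
action_history = jnp.stack(actions)
```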

Migration done checklist

  1. No imports from pymdp.legacy in active scripts/notebooks.
  2. Randomness uses explicit jr.PRNGKey flow.
  3. infer_policies(qs) and sample_action(q_pi, rng_key=...) usage updated.
  4. Tests pass (pytest test).
  5. Notebook/docs examples run with the JAX API.