%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp[modelfit]" -q
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu
from jax import vmap, nn, lax
import matplotlib.pyplot as plt
import numpy as np
from pymdp.agent import Agent
from pymdp.envs import TMaze
from pybefit.inference import (
run_nuts,
run_svi,
default_dict_nuts,
default_dict_numpyro_svi,
)
from pybefit.inference import NumpyroModel, NumpyroGuide
from pybefit.inference import Normal, NormalPosterior
from pybefit.inference.numpyro.likelihoods import pymdp_likelihood as likelihood
from numpyro.infer import Predictive
Fitting the parameters of active inference agents performing in the T-Maze environment¶
We set up a TMaze instance with a fully reliable cue-to-reward mapping, so the cue unambiguously reveals the rewarding arm. We also set the rewarded-arm outcome mapping to be probabilistic at 80%. This means that in the rewarded arm, reward is observed with probability 0.8 and punishment with probability 0.2, while these probabilities are inverted in the punished arm (dependent_outcomes=True).
We also parameterize 10 different agent-environment pairs (batch_size=10), 100 experimental blocks (episodes) per agent, and 5 timesteps per block.
Explanatory note on batch_size, subject, block, timestep, and trial terminology
Experimental data collected from individual participants or subjects in psychological and psychiatric studies often follows a multi-level temporal structure. Typically, each subject participates in a task independently from other participants, and each subject performs multiple episodes or blocks sequentially (for example, 20 blocks of 100 trials per block).
Depending on the experimental design, blocks may be independent or may include between-block correlations (for example, regular changes in initial task state across blocks). Within each block, individual "trials" or "timesteps" occur sequentially. There can also be within-block, between-trial correlation structure, and this is an assumption we make in pymdp: a single active inference process unfolds across these timesteps.
Because the word "trial" often refers to events that have their own internal temporal structure, this notebook uses timesteps to reduce ambiguity. This aligns with reinforcement learning and active inference workflows, where these are naturally treated as timesteps and blocks are naturally treated as episodes.
# Environment configuration shared by every simulated agent.
seed_key = jr.PRNGKey(101)

num_blocks = 100  # episodes run back-to-back for each agent
num_timesteps = 5  # timesteps ("trials") inside a single block

reward_condition = None  # 0: reward in left arm, 1: reward in right arm, None: random allocation
reward_probability = 0.8  # chance of reward in the correct arm
punishment_probability = 0.0  # consulted only when dependent_outcomes is False
cue_validity = 1.0  # the cue always points at the rewarded arm
# True: reward/punishment probabilities of both arms are coupled via reward_probability.
# False: punishment instead occurs with the fixed punishment_probability.
dependent_outcomes = True
# Environment: the T-maze task (implementation details in pymdp/envs/tmaze.py).
task = TMaze(
    reward_probability=reward_probability,
    punishment_probability=punishment_probability,
    cue_validity=cue_validity,
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes,
)

# Number of agent-environment pairs simulated (and later fitted) in parallel.
batch_size = 10

# Reset one environment copy per batch element, each with its own PRNG subkey.
key, _key = jr.split(seed_key)
reset_keys = jr.split(_key, batch_size)
init_obs, init_states = vmap(task.reset)(reset_keys)
# Constrained, recoverable setup: the only inferred quantity is each agent's
# reward-probability belief. Preference strength (lambda) and the initial
# reward-state prior are held fixed inside `transform`.
num_agents = batch_size
num_params = 1  # one unconstrained latent per agent, mapped to a reward probability
prior = Normal(num_params, num_agents, backend="numpyro")
def transform(z):
    """Turn a 2-D latent array into a batch of pymdp agents.

    Row ``i`` of ``z`` parameterizes agent ``i``; only column 0 is read. It is
    squashed to a believed reward probability ``p = 0.5 + 0.5 * sigmoid(z)``
    in (0.5, 1) and written into the reward-modality likelihood ``A[1]``.
    Everything else -- the other ``A`` tensors, ``B``, the preference
    strength 1.2 and the initial beliefs ``D`` -- is fixed and shared.

    Returns a stochastic-action ``Agent`` whose ``batch_size`` equals the
    number of rows of ``z``, planning with ``policy_len=2``.
    """
    num_rows, _ = z.shape  # unpacking also enforces that z is 2-D

    def with_batch_axis(tensors):
        # Task tensors are constants: block gradients, then prepend an agent axis.
        return jtu.tree_map(
            lambda arr: jnp.broadcast_to(arr, (num_rows,) + arr.shape),
            lax.stop_gradient(tensors),
        )

    A = with_batch_axis(task.A)
    B = with_batch_axis(task.B)

    p = 0.5 + 0.5 * nn.sigmoid(z[..., 0])
    q = 1.0 - p
    # Probability vectors over the reward condition (last axis):
    # favors_left is high under condition 0, favors_right under condition 1.
    favors_left = jnp.stack([p, q], axis=-1)
    favors_right = jnp.stack([q, p], axis=-1)
    blank = jnp.zeros_like(favors_left)

    # Rows index outcomes of modality 1: row 1 = reward, row 2 = punishment,
    # row 0 = the only outcome emitted away from the two arms.
    at_left_arm = jnp.stack([blank, favors_left, favors_right], axis=-2)
    at_right_arm = jnp.stack([blank, favors_right, favors_left], axis=-2)
    off_arm = jnp.broadcast_to(
        jnp.array([[1., 1.], [0., 0.], [0., 0.]]), (num_rows, 3, 2)
    )
    # Five entries along (presumably) the location axis; index 1 is the left
    # arm and index 2 the right arm.
    A[1] = jnp.stack(
        [off_arm, at_left_arm, at_right_arm, off_arm, off_arm], axis=-2
    )

    preference_strength = jnp.full((num_rows,), 1.2)
    C = [
        jnp.zeros((num_rows, A[0].shape[1])),
        preference_strength[:, None] * jnp.array([0., 1., -1.]),
        jnp.zeros((num_rows, A[2].shape[1])),
    ]
    D = [
        jnp.zeros((num_rows, B[0].shape[1])).at[:, 0].set(1.0),  # certain of state 0
        jnp.full((num_rows, 2), 0.5),  # uniform over the two reward conditions
    ]

    return Agent(
        A,
        B,
        C=C,
        D=D,
        policy_len=2,
        A_dependencies=task.A_dependencies,
        B_dependencies=task.B_dependencies,
        batch_size=num_rows,
        action_selection="stochastic",
    )
# Reset `key` to the value derived from `seed_key` at environment setup, so
# re-running the notebook from this cell reproduces the same Predictive draw.
# Only the first subkey is kept: the second one equals the subkey already
# consumed by `task.reset`, so binding it again would be a PRNG-key reuse.
key = jr.split(seed_key)[0]

# Deterministic latent grid used to build a demonstration batch of agents.
# NOTE(review): the ground-truth parameters used for fitting below come from
# `Predictive(model, num_samples=1)`, which samples `z` from the prior; this
# grid is NOT passed to it, so it does not control which values are recovered.
z = jnp.linspace(-2.5, 2.5, num_agents).reshape(num_agents, num_params)
agents = transform(z)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. return Agent(
Sample from the TMaze environment and visualize the results from one block, showing the expected behavior (visit the cue, then choose an arm)¶
# Task settings consumed by the pymdp likelihood; note the likelihood calls
# timesteps "trials".
opts_task = dict(
    task=task,
    num_blocks=num_blocks,
    num_trials=num_timesteps,
    num_agents=num_agents,
)
opts_model = dict(prior={}, transform={}, likelihood=opts_task)
model = NumpyroModel(prior, transform, likelihood, opts=opts_model)

# Draw one synthetic dataset from the generative model (prior -> agents -> behaviour).
key, _key = jr.split(key)
samples = Predictive(model, num_samples=1)(_key)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. return Agent(
# Render the first simulated block timestep by timestep and lay the frames out
# side by side (deliberately avoids a mediapy dependency).
frames = []
for t in range(num_timesteps):
    # Per modality: [predictive sample 0, block 0, all agents, timestep t].
    observations_t = [samples["outcomes"][m][0, 0, :, t] for m in range(3)]
    # Add the trailing axis that task.step() puts on its observations.
    observations_t = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), observations_t)
    frame = task.render(mode="rgb_array", observations=observations_t)
    # Keep frames as host-side NumPy arrays; they are only used for plotting.
    frames.append(np.asarray(frame, dtype=np.uint8))
    plt.close()  # close the figure created while rendering to prevent a memory leak
frames = np.stack(frames)

# squeeze=False keeps `axes` 2-D, so this also works when num_timesteps == 1.
fig, axes = plt.subplots(1, num_timesteps, figsize=(15, 5), squeeze=False)
for i, (ax, frame) in enumerate(zip(axes[0], frames)):
    ax.imshow(frame)
    ax.axis("off")
    ax.set_title(f"Timestep {i + 1}")
plt.show()
Inference Method 1: HMC with NUTS¶
Use the No-U-Turn Sampler (NUTS), an adaptive Hamiltonian Monte Carlo (HMC) method, to draw samples from the parameter posterior. pybefit provides convenient wrappers for setting up a NUTS-HMC run.
# Perform inference on the parameters with the No-U-Turn Sampler (NUTS).
# opts_sampling holds the run settings; "sampler_kwargs" carries options for
# the NUTS kernel ("kernel") and the MCMC sampler ("mcmc").
measurements = {
    # [0] selects the single predictive sample drawn above.
    "outcomes": [outcomes[0] for outcomes in samples["outcomes"]],
    "multiactions": samples["multiactions"][0],
}
# Build a NEW dict: assigning keys on `default_dict_nuts` itself would mutate
# pybefit's module-level defaults and leak these settings into later runs.
opts_sampling = default_dict_nuts | {
    "num_warmup": 400,
    "num_samples": 100,
    "sampler_kwargs": {"kernel": {}, "mcmc": {"progress_bar": True}},
}
print(opts_sampling)
mcmc_samples, mcmc = run_nuts(model, measurements, opts=opts_sampling)
{'seed': 0, 'num_samples': 100, 'num_warmup': 400, 'sampler_kwargs': {'kernel': {}, 'mcmc': {'progress_bar': True}}}
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. return Agent( sample: 100%|██████████| 500/500 [05:56<00:00, 1.40it/s, 7 steps of size 7.24e-01. acc. prob=0.82]
Plot each ground-truth parameter against its posterior mean (averaged over HMC samples and any parallel chains)¶
# Ground-truth latents vs. their NUTS posterior means (averaged over draws).
ground_truth_z = samples["z"][0]  # drop the predictive-sample axis
posterior_mean_z = mcmc_samples["z"].mean(0)
z_lo, z_hi = ground_truth_z.min(), ground_truth_z.max()

plt.figure(figsize=(16, 5))
for param_id in range(num_params):
    plt.scatter(ground_truth_z[:, param_id], posterior_mean_z[:, param_id], label=param_id)
plt.plot((z_lo, z_hi), (z_lo, z_hi), "k--")  # identity line
plt.ylabel("posterior mean (MCMC)")
plt.xlabel("true value")
plt.legend(title="parameter id")
<matplotlib.legend.Legend at 0x3666c1850>
Transform the latent parameter corresponding to the reward probability into probability space and investigate overlap between ground-truth and inferred parameter¶
# Map latents to reward probabilities with the same squashing as `transform`
# (0.5 + 0.5 * sigmoid). "Inferred" is the squashed posterior-mean latent.
inferred_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(mcmc_samples["z"].mean(0)[:, 0])
ground_truth_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0])
# Convert once; everything below works on host-side NumPy arrays.
true_p = np.asarray(ground_truth_reward_probabilities)
inferred_p = np.asarray(inferred_reward_probabilities)

plt.figure(figsize=(16, 5))
plt.scatter(true_p, inferred_p, label="reward probability")
plt.plot((true_p.min(), true_p.max()), (true_p.min(), true_p.max()), "k--")  # identity line
plt.ylabel("posterior mean (NUTS-HMC)")
plt.xlabel("true value")
plt.legend()  # single series; a "parameter id" title would be misleading here

corr = np.corrcoef(true_p, inferred_p)[0, 1]
# NOTE(review): these thresholds look inherited from a [0, 1]-valued parameter.
# Inferred probabilities here lie in [0.5, 1], so the "< 0.2" branch can never
# fire and only values in [0.5, 0.65] count against the score.
extreme = (inferred_p < 0.2) | (inferred_p > 0.8)
central = (inferred_p >= 0.35) & (inferred_p <= 0.65)
bimodality_score = np.clip(np.mean(extreme) - np.mean(central), 0.0, 1.0)
print(f"reward-probability Pearson r: {corr:.3f}")
print(f"reward-probability bimodality score: {bimodality_score:.3f}")
reward-probability Pearson r: 0.906 reward-probability bimodality score: 0.100
# Inspect the full NUTS posterior per subject (not just posterior means).
z_samples = np.array(mcmc_samples["z"])
# Collapse an optional chain axis: [chains, draws, agents, params] -> [draws_total, agents, params].
if z_samples.ndim == 4:
    z_samples = z_samples.reshape(-1, z_samples.shape[-2], z_samples.shape[-1])
elif z_samples.ndim != 3:
    raise ValueError(f"Unexpected mcmc_samples['z'] shape: {z_samples.shape}")

# Support of the reward probability produced by 0.5 + 0.5 * sigmoid(z) in `transform`.
prob_lo, prob_hi = 0.5, 1.0
reward_prob_samples = 0.5 + 0.5 / (1.0 + np.exp(-z_samples[..., 0]))  # [draws, agents]
true_reward_prob = np.array(0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0]))
posterior_mean = reward_prob_samples.mean(axis=0)

# Order subjects by true reward probability; both panels share this order.
subject_ids = np.arange(num_agents)
sort_idx = np.argsort(true_reward_prob)
subject_ids_sorted = subject_ids[sort_idx]
true_reward_prob_sorted = true_reward_prob[sort_idx]
posterior_mean_sorted = posterior_mean[sort_idx]
reward_prob_samples_sorted = reward_prob_samples[:, sort_idx]
plot_positions = np.arange(num_agents)

fig, axes = plt.subplots(2, 1, figsize=(16, 10), constrained_layout=True)

# 1) Violin plot: per-subject posterior distributions + true values.
parts = axes[0].violinplot(
    [reward_prob_samples_sorted[:, i] for i in range(num_agents)],
    positions=plot_positions,
    showmeans=False,
    showmedians=True,
    showextrema=False,
)
for pc in parts["bodies"]:
    pc.set_alpha(0.35)
axes[0].scatter(plot_positions, true_reward_prob_sorted, color="crimson", marker="x", s=80, label="true reward prob")
axes[0].scatter(plot_positions, posterior_mean_sorted, color="black", s=20, label="posterior mean")
axes[0].set_title("NUTS posterior distribution per subject (violin, sorted)")

# 2) Density heatmap: posterior mass over probability bins by subject.
# Bin edges span exactly the support [0.5, 1.0], so every bin is visible within
# the y-limits and none straddles the lower axis limit.
bins = np.linspace(prob_lo, prob_hi, 31)  # bin width ~0.017
density = np.stack(
    [
        np.histogram(reward_prob_samples_sorted[:, i], bins=bins, density=True)[0]
        for i in range(num_agents)
    ],
    axis=1,
)  # [num_bins - 1, num_agents]
im = axes[1].imshow(
    density,
    aspect="auto",
    origin="lower",
    extent=[-0.5, num_agents - 0.5, bins[0], bins[-1]],
    cmap="viridis",
)
axes[1].plot(plot_positions, true_reward_prob_sorted, color="crimson", linestyle="--", marker="x", label="true reward prob")
axes[1].set_title("NUTS posterior density by subject (sorted)")

# Shared axis formatting for both panels.
for ax in axes:
    ax.set_ylabel("reward probability")
    ax.set_xlabel("subject id (sorted by true reward probability)")
    ax.set_ylim(prob_lo, prob_hi)
    ax.set_xticks(plot_positions)
    ax.set_xticklabels(subject_ids_sorted)
    ax.legend(loc="upper left")
fig.colorbar(im, ax=axes[1], label="density")
plt.show()
Inference Method 2: Black-Box Stochastic Variational Inference¶
Use NumPyro's SVI functionality to run variational inference with a MultivariateNormal variational posterior (that is, a Normal guide). SVI runs a black-box variational inference procedure where ELBO gradients are estimated using samples from the guide. This allows less constrained likelihood modeling than traditional mean-field variational Bayesian treatments, where likelihoods are often limited to conjugate-exponential forms.
# Fit the same synthetic data with black-box stochastic variational inference
# (NumPyro SVI). opts_svi carries the optimisation settings: iteration count,
# optimiser, ELBO and posterior-sampling options.
measurements = dict(
    outcomes=[modality[0] for modality in samples["outcomes"]],  # predictive sample 0
    multiactions=samples["multiactions"][0],
)
posterior = NumpyroGuide(NormalPosterior(num_params, num_agents, backend="numpyro"))
opts_svi = {**default_dict_numpyro_svi, "iter_steps": 1_000}
print(opts_svi)
svi_samples, svi, results = run_svi(model, posterior, measurements, opts=opts_svi)
{'seed': 0, 'enumerate': False, 'iter_steps': 1000, 'optim': None, 'optim_kwargs': {'learning_rate': 0.001}, 'elbo_kwargs': {'num_particles': 10, 'max_plate_nesting': 1}, 'svi_kwargs': {'progress_bar': True, 'stable_update': True}, 'sample_kwargs': {'num_samples': 100}}
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. return Agent( 76%|███████▌ | 757/1000 [09:16<02:59, 1.35it/s, init loss: 7789.5132, avg. loss [701-750]: 7768.9961]
Plot the variational free energy over time (negative ELBO)¶
# SVI loss trace: negative ELBO (variational free energy) per optimisation step.
plt.plot(results.losses)
plt.xlabel("SVI iteration")
plt.ylabel("negative ELBO")
plt.show()
[<matplotlib.lines.Line2D at 0x3ec6cf790>]
Plot each ground-truth parameter against its posterior mean (averaged over posterior samples drawn from the guide)¶
# Ground-truth latents vs. their SVI posterior means (averaged over guide samples).
ground_truth_z = samples["z"][0]  # drop the predictive-sample axis
posterior_mean_z = svi_samples["z"].mean(0)
z_lo, z_hi = ground_truth_z.min(), ground_truth_z.max()

plt.figure(figsize=(16, 5))
for param_id in range(num_params):
    plt.scatter(ground_truth_z[:, param_id], posterior_mean_z[:, param_id], label=param_id)
plt.plot((z_lo, z_hi), (z_lo, z_hi), "k--")  # identity line
plt.ylabel("posterior mean (SVI)")
plt.xlabel("true value")
plt.legend(title="parameter id")
<matplotlib.legend.Legend at 0x3ecc6cd10>
Transform the latent parameter corresponding to the reward probability into probability space and investigate overlap between ground-truth and inferred parameter¶
# Map latents to reward probabilities with the same squashing as `transform`
# (0.5 + 0.5 * sigmoid). "Inferred" is the squashed posterior-mean latent.
inferred_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(svi_samples["z"].mean(0)[:, 0])
ground_truth_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0])
# Convert once; everything below works on host-side NumPy arrays.
true_p = np.asarray(ground_truth_reward_probabilities)
inferred_p = np.asarray(inferred_reward_probabilities)

plt.figure(figsize=(16, 5))
plt.scatter(true_p, inferred_p, label="reward probability")
plt.plot((true_p.min(), true_p.max()), (true_p.min(), true_p.max()), "k--")  # identity line
plt.ylabel("posterior mean (SVI)")
plt.xlabel("true value")
plt.legend()  # single series; a "parameter id" title would be misleading here

corr = np.corrcoef(true_p, inferred_p)[0, 1]
# NOTE(review): these thresholds look inherited from a [0, 1]-valued parameter.
# Inferred probabilities here lie in [0.5, 1], so the "< 0.2" branch can never
# fire and only values in [0.5, 0.65] count against the score.
extreme = (inferred_p < 0.2) | (inferred_p > 0.8)
central = (inferred_p >= 0.35) & (inferred_p <= 0.65)
bimodality_score = np.clip(np.mean(extreme) - np.mean(central), 0.0, 1.0)
print(f"reward-probability Pearson r: {corr:.3f}")
print(f"reward-probability bimodality score: {bimodality_score:.3f}")
reward-probability Pearson r: 0.905 reward-probability bimodality score: 0.300
# Inspect the full SVI posterior per subject (not just posterior means).
z_samples = np.array(svi_samples["z"])
# Same shape guard as the NUTS inspection: collapse an optional leading axis
# [chains, draws, agents, params] -> [draws_total, agents, params].
if z_samples.ndim == 4:
    z_samples = z_samples.reshape(-1, z_samples.shape[-2], z_samples.shape[-1])
elif z_samples.ndim != 3:
    raise ValueError(f"Unexpected svi_samples['z'] shape: {z_samples.shape}")

# Support of the reward probability produced by 0.5 + 0.5 * sigmoid(z) in `transform`.
prob_lo, prob_hi = 0.5, 1.0
reward_prob_samples = 0.5 + 0.5 / (1.0 + np.exp(-z_samples[..., 0]))  # [draws, agents]
true_reward_prob = np.array(0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0]))
posterior_mean = reward_prob_samples.mean(axis=0)

# Order subjects by true reward probability; both panels share this order.
subject_ids = np.arange(num_agents)
sort_idx = np.argsort(true_reward_prob)
subject_ids_sorted = subject_ids[sort_idx]
true_reward_prob_sorted = true_reward_prob[sort_idx]
posterior_mean_sorted = posterior_mean[sort_idx]
reward_prob_samples_sorted = reward_prob_samples[:, sort_idx]
plot_positions = np.arange(num_agents)

fig, axes = plt.subplots(2, 1, figsize=(16, 10), constrained_layout=True)

# 1) Violin plot: per-subject posterior distributions + true values.
parts = axes[0].violinplot(
    [reward_prob_samples_sorted[:, i] for i in range(num_agents)],
    positions=plot_positions,
    showmeans=False,
    showmedians=True,
    showextrema=False,
)
for pc in parts["bodies"]:
    pc.set_alpha(0.35)
axes[0].scatter(plot_positions, true_reward_prob_sorted, color="crimson", marker="x", s=80, label="true reward prob")
axes[0].scatter(plot_positions, posterior_mean_sorted, color="black", s=20, label="posterior mean")
axes[0].set_title("SVI posterior distribution per subject (violin, sorted)")

# 2) Density heatmap: posterior mass over probability bins by subject.
# Bin edges span exactly the support [0.5, 1.0], so every bin is visible within
# the y-limits and none straddles the lower axis limit.
bins = np.linspace(prob_lo, prob_hi, 31)  # bin width ~0.017
density = np.stack(
    [
        np.histogram(reward_prob_samples_sorted[:, i], bins=bins, density=True)[0]
        for i in range(num_agents)
    ],
    axis=1,
)  # [num_bins - 1, num_agents]
im = axes[1].imshow(
    density,
    aspect="auto",
    origin="lower",
    extent=[-0.5, num_agents - 0.5, bins[0], bins[-1]],
    cmap="viridis",
)
axes[1].plot(plot_positions, true_reward_prob_sorted, color="crimson", linestyle="--", marker="x", label="true reward prob")
axes[1].set_title("SVI posterior density by subject (sorted)")

# Shared axis formatting for both panels.
for ax in axes:
    ax.set_ylabel("reward probability")
    ax.set_xlabel("subject id (sorted by true reward probability)")
    ax.set_ylim(prob_lo, prob_hi)
    ax.set_xticks(plot_positions)
    ax.set_xticklabels(subject_ids_sorted)
    ax.legend(loc="upper left")
fig.colorbar(im, ax=axes[1], label="density")
plt.show()