
In [1]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [ ]:
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[modelfit]" -q
In [ ]:
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu
from jax import vmap, nn, lax

import matplotlib.pyplot as plt
import numpy as np

from pymdp.agent import Agent
from pymdp.envs import TMaze

from pybefit.inference import (
    run_nuts,
    run_svi,
    default_dict_nuts,
    default_dict_numpyro_svi,
)

from pybefit.inference import NumpyroModel, NumpyroGuide
from pybefit.inference import Normal, NormalPosterior
from pybefit.inference.numpyro.likelihoods import pymdp_likelihood as likelihood
from numpyro.infer import Predictive

Fitting the parameters of active inference agents performing in the T-Maze environment

We set up a TMaze instance with a fully reliable cue (cue_validity=1.0), so the cue unambiguously reveals the rewarded arm. Outcomes in the arms are probabilistic: with reward_probability=0.8 and dependent_outcomes=True, the rewarded arm yields reward with probability 0.8 and punishment with probability 0.2, and these probabilities are inverted in the punished arm.
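
As a quick check of the coupling described above, the following sketch tabulates the outcome probabilities in each arm. The dictionary layout and the name `outcome_probs` are illustrative assumptions and are not part of the TMaze API.

```python
# Illustrative sketch only: tabulate outcome probabilities per arm
# when reward and punishment are coupled (dependent_outcomes=True).
reward_probability = 0.8

outcome_probs = {
    "rewarded arm": {
        "reward": reward_probability,
        "punishment": 1.0 - reward_probability,
    },
    "punished arm": {
        "reward": 1.0 - reward_probability,
        "punishment": reward_probability,
    },
}

for arm, probs in outcome_probs.items():
    print(f"{arm}: P(reward)={probs['reward']:.1f}, P(punishment)={probs['punishment']:.1f}")
```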

We simulate 10 agent-environment pairs in parallel (batch_size=10), with 100 experimental blocks (episodes) per agent and 5 timesteps per block.

Explanatory note on batch_size, subject, block, timestep, and trial terminology

Experimental data collected from individual participants or subjects in psychological and psychiatric studies often follows a multi-level temporal structure. Typically, each subject participates in a task independently from other participants, and each subject performs multiple episodes or blocks sequentially (for example, 20 blocks of 100 trials per block).

Depending on the experimental design, blocks may be independent or may include between-block correlations (for example, regular changes in initial task state across blocks). Within each block, individual "trials" or "timesteps" occur sequentially. There can also be within-block, between-trial correlation structure, and this is an assumption we make in pymdp: a single active inference process unfolds across these timesteps.

Because the word "trial" often refers to events that have their own internal temporal structure, this notebook uses "timestep" instead to reduce ambiguity. This aligns with reinforcement learning and active inference workflows, where trials are naturally treated as timesteps and blocks as episodes.
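
To make the terminology concrete, here is a minimal sketch of how data with this multi-level structure can be indexed. The array name `choices` and the axis order `[block, agent, timestep]` are assumptions chosen for illustration; they mirror how outcomes are sliced later in this notebook once the leading predictive-sample axis is dropped.

```python
import numpy as np

num_blocks, num_agents, num_timesteps = 100, 10, 5

# Hypothetical integer-coded choices with axes [block, agent, timestep].
rng = np.random.default_rng(0)
choices = rng.integers(0, 3, size=(num_blocks, num_agents, num_timesteps))

first_block = choices[0]         # all agents, all timesteps of block 0
one_subject = choices[:, 3]      # every block and timestep for agent 3
one_timestep = choices[0, :, 2]  # third timestep of block 0, all agents

print(first_block.shape, one_subject.shape, one_timestep.shape)
```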

In [ ]:
seed_key = jr.PRNGKey(101)

# setting the parameters for the environment
num_blocks = 100  # number of blocks (number of multi-timestep episodes that are run sequentially, per agent)
num_timesteps = 5 # number of timesteps (or "trials") per block

reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 0.8 # 80% chance of reward in the correct arm
punishment_probability = 0.0 # only used when dependent_outcomes=False
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = True # if True, reward and punishment probabilities are coupled across arms via reward_probability. If False, punishment occurs with the fixed punishment_probability.

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
task = TMaze( 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes
)

batch_size = 10 # batch_size, which in this case corresponds to the number of agents to fit in parallel
key, _key = jr.split(seed_key)
init_obs, init_states = vmap(task.reset)(jr.split(_key, batch_size))
In [2]:
# Constrained recoverable setup: infer only reward-probability beliefs.
# Keep preference strength (lambda) and initial reward-state prior fixed.
num_params = 1
num_agents = batch_size
prior = Normal(num_params, num_agents, backend="numpyro")


def transform(z):
    na, _ = z.shape

    reward_prob = 0.5 + 0.5 * nn.sigmoid(z[..., 0]) # constrain reward probability to be between 0.5 and 1
    lam = jnp.full((na,), 1.2)

    A = lax.stop_gradient(task.A)
    A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (na,) + x.shape), A)
    B = lax.stop_gradient(task.B)
    B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (na,) + x.shape), B)

    one_minus = 1.0 - reward_prob

    reward_left = jnp.stack([reward_prob, one_minus], -1)
    punish_left = jnp.stack([one_minus, reward_prob], -1)
    reward_right = jnp.stack([one_minus, reward_prob], -1)
    punish_right = jnp.stack([reward_prob, one_minus], -1)
    zeros = jnp.zeros_like(reward_left)

    side = jnp.broadcast_to(jnp.array([[1., 1.], [0., 0.], [0., 0.]]), (na, 3, 2))
    left_col = jnp.stack([zeros, reward_left, punish_left], axis=-2)
    right_col = jnp.stack([zeros, reward_right, punish_right], axis=-2)
    A[1] = jnp.stack([side, left_col, right_col, side, side], axis=-2)

    C = [
        jnp.zeros((na, A[0].shape[1])),
        jnp.expand_dims(lam, -1) * jnp.array([0., 1., -1.]),
        jnp.zeros((na, A[2].shape[1])),
    ]
    D = [
        jnp.zeros((na, B[0].shape[1])).at[:, 0].set(1.0),
        jnp.full((na, 2), 0.5),
    ]

    return Agent(
        A,
        B,
        C=C,
        D=D,
        policy_len=2,
        A_dependencies=task.A_dependencies,
        B_dependencies=task.B_dependencies,
        batch_size=na,
        action_selection="stochastic",
    )


key, _key = jr.split(seed_key)
# Use a broad, deterministic latent grid so recoverability is visible across agents.
z = jnp.linspace(-2.5, 2.5, num_agents).reshape(num_agents, num_params)

agents = transform(z)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  return Agent(
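
The stacking inside `transform` can be hard to follow, so the sketch below rebuilds the reward-modality likelihood for a single agent in plain NumPy and checks that every column is a valid probability distribution. The axis interpretation (outcome, location, reward condition), the ordering of the three outcomes, and the helper name `build_reward_likelihood` are assumptions read off the code above, not a documented pymdp API.

```python
import numpy as np

def build_reward_likelihood(reward_prob):
    """Rebuild the reward-modality likelihood for one agent.

    Assumed axes: [outcome, location, reward condition], with outcomes
    (none, reward, punishment) and reward conditions (left, right).
    """
    p, q = reward_prob, 1.0 - reward_prob
    zeros = np.zeros(2)

    # Locations without reward feedback always emit the "none" outcome.
    side = np.array([[1.0, 1.0], [0.0, 0.0], [0.0, 0.0]])
    # Left arm: reward is likely when the reward condition is "left".
    left_col = np.stack([zeros, [p, q], [q, p]], axis=-2)
    # Right arm: reward is likely when the reward condition is "right".
    right_col = np.stack([zeros, [q, p], [p, q]], axis=-2)

    return np.stack([side, left_col, right_col, side, side], axis=-2)

A_reward = build_reward_likelihood(0.8)
print(A_reward.shape)        # (3, 5, 2)
print(A_reward.sum(axis=0))  # sums over outcomes, one per (location, condition)
```

Summing over the outcome axis should give 1 for every location and reward condition. In the notebook, the same construction carries an extra leading agent axis.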

Sample from the TMaze environment and visualize the results from one block, showing the expected behavior (visit the cue, then choose an arm)

In [3]:
opts_task = {
    "task": task,
    "num_blocks": num_blocks,
    "num_trials": num_timesteps,
    "num_agents": num_agents,
}
opts_model = {"prior": {}, "transform": {}, "likelihood": opts_task}

model = NumpyroModel(prior, transform, likelihood, opts=opts_model)

pred = Predictive(model, num_samples=1)
key, _key = jr.split(key)
samples = pred(_key)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  return Agent(
In [4]:
frames = []
for t in range(num_timesteps):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        samples["outcomes"][0][0,0,:, t], # subset by predictive sample (first leading dimension) and block (second leading dimension)
        samples["outcomes"][1][0,0,:, t],  
        samples["outcomes"][2][0,0,:, t]   
    ]
    observations_t = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), observations_t)  # add a trailing dimension, as task.step() does before returning
       
    frame = task.render(mode="rgb_array", observations=observations_t).astype(jnp.uint8) # render the environment using the observations for this timestep
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = jnp.array(frames, dtype=jnp.uint8)

## Make a panel of subplots showing the frames at different timesteps of the video sequence (don't assume mediapy dependency)
fig, axes = plt.subplots(1, num_timesteps, figsize=(15, 5))
for i in range(num_timesteps):
    axes[i].imshow(frames[i])
    axes[i].axis('off')
    axes[i].set_title(f'Timestep {i+1}')
plt.show()

Inference Method 1: HMC with NUTS

Use the No-U-Turn Sampler (NUTS), an adaptive variant of Hamiltonian Monte Carlo (HMC), to draw samples from the parameter posterior. pybefit provides convenient wrappers for setting up a NUTS-HMC run.
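
NUTS builds on Hamiltonian Monte Carlo by choosing the trajectory length adaptively. As a rough illustration of the underlying mechanics, the sketch below implements plain fixed-length HMC for a one-dimensional standard normal target. It is not NUTS and it is not what pybefit or NumPyro run internally; the function names, step size, and number of leapfrog steps are illustrative choices.

```python
import numpy as np

def log_prob(z):
    return -0.5 * z**2  # standard normal target, up to a constant

def grad_log_prob(z):
    return -z

def hmc_step(z, rng, step_size=0.2, num_leapfrog=8):
    """One HMC transition: leapfrog integration plus a Metropolis correction."""
    p = rng.normal()  # resample the momentum
    z_new, p_new = z, p

    # Leapfrog: half momentum step, alternating full steps, half momentum step.
    p_new += 0.5 * step_size * grad_log_prob(z_new)
    for i in range(num_leapfrog):
        z_new += step_size * p_new
        if i < num_leapfrog - 1:
            p_new += step_size * grad_log_prob(z_new)
    p_new += 0.5 * step_size * grad_log_prob(z_new)

    # Accept or reject based on the change in total energy.
    h_old = -log_prob(z) + 0.5 * p**2
    h_new = -log_prob(z_new) + 0.5 * p_new**2
    if np.log(rng.uniform()) < h_old - h_new:
        return z_new
    return z

rng = np.random.default_rng(0)
z = 3.0
draws = []
for _ in range(2500):
    z = hmc_step(z, rng)
    draws.append(z)
draws = np.array(draws[500:])  # discard warm-up draws

print(f"sample mean={draws.mean():.3f}, sample variance={draws.var():.3f}")
```

The warm-up draws discarded here play the same role as `num_warmup` in the sampling options, although NUTS additionally uses warm-up to tune the step size.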

In [5]:
# perform inference on parameters using no-u-turn sampler (NUTS)
# opts_sampling dictionary can be used to specify various parameters
# either for the NUTS kernel or MCMC sampler
measurements = {
    "outcomes": [outcomes[0] for outcomes in samples["outcomes"]],
    "multiactions": samples["multiactions"][0],
}

# merge into a new dictionary instead of mutating the imported defaults in place
opts_sampling = default_dict_nuts | {
    "num_warmup": 400,
    "num_samples": 100,
    "sampler_kwargs": {"kernel": {}, "mcmc": {"progress_bar": True}},
}
print(opts_sampling)

mcmc_samples, mcmc = run_nuts(model, measurements, opts=opts_sampling)
{'seed': 0, 'num_samples': 100, 'num_warmup': 400, 'sampler_kwargs': {'kernel': {}, 'mcmc': {'progress_bar': True}}}
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  return Agent(
sample: 100%|██████████| 500/500 [05:56<00:00,  1.40it/s, 7 steps of size 7.24e-01. acc. prob=0.82] 

Plot each ground-truth parameter against its posterior mean (mean taken over the HMC samples)

In [6]:
plt.figure(figsize=(16, 5))
ground_truth_z = samples['z'][0]
for i in range(num_params):
    plt.scatter(ground_truth_z[:, i], mcmc_samples["z"].mean(0)[:, i], label=i)

plt.plot((ground_truth_z.min(), ground_truth_z.max()), (ground_truth_z.min(), ground_truth_z.max()), "k--")
plt.ylabel("posterior mean (MCMC)")
plt.xlabel("true value")
plt.legend(title="parameter id")
Out[6]:
<matplotlib.legend.Legend at 0x3666c1850>

Transform the latent parameter corresponding to the reward probability into probability space and investigate overlap between ground-truth and inferred parameter
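
The mapping from the unconstrained latent variable to the reward probability, and its inverse, can be written out explicitly as follows. The helper names `to_reward_prob` and `to_latent` are hypothetical; the forward formula is the one used in `transform`.

```python
import numpy as np

def to_reward_prob(z):
    """Map an unconstrained latent z to a probability in (0.5, 1)."""
    return 0.5 + 0.5 / (1.0 + np.exp(-z))

def to_latent(p):
    """Inverse map: the logit of the rescaled probability 2p - 1."""
    u = 2.0 * p - 1.0
    return np.log(u / (1.0 - u))

z_grid = np.linspace(-2.5, 2.5, 10)  # same grid as the simulated agents
p_grid = to_reward_prob(z_grid)

print(np.round(p_grid, 3))
print(to_latent(0.8))  # latent value that corresponds to a reward probability of 0.8
```

Because the map is monotonic, ordering agents by latent value or by reward probability gives the same ranking.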

In [7]:
inferred_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(mcmc_samples["z"].mean(0)[:, 0])
ground_truth_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0])

plt.figure(figsize=(16, 5))
plt.scatter(ground_truth_reward_probabilities, inferred_reward_probabilities, label="reward probability")

plt.plot((ground_truth_reward_probabilities.min(), ground_truth_reward_probabilities.max()), (ground_truth_reward_probabilities.min(), ground_truth_reward_probabilities.max()), "k--")
plt.ylabel("posterior mean (NUTS-HMC)")
plt.xlabel("true value")
plt.legend(title="parameter id")

corr = np.corrcoef(
    np.array(ground_truth_reward_probabilities),
    np.array(inferred_reward_probabilities),
)[0, 1]

bimodality_score = np.clip(
    np.mean((np.array(inferred_reward_probabilities) < 0.2) | (np.array(inferred_reward_probabilities) > 0.8))
    - np.mean((np.array(inferred_reward_probabilities) >= 0.35) & (np.array(inferred_reward_probabilities) <= 0.65)),
    0.0,
    1.0,
)

print(f"reward-probability Pearson r: {corr:.3f}")
print(f"reward-probability bimodality score: {bimodality_score:.3f}")
reward-probability Pearson r: 0.906
reward-probability bimodality score: 0.100
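
The bimodality score computed above is the fraction of estimates in the extreme ranges (below 0.2 or above 0.8) minus the fraction in the middle range (0.35 to 0.65), clipped to [0, 1]. Because `transform` constrains reward probabilities to lie between 0.5 and 1, the `< 0.2` branch cannot fire in this setup. The worked example below uses made-up estimates, and the function name `bimodality_score_of` is hypothetical.

```python
import numpy as np

def bimodality_score_of(probs):
    """Fraction of extreme estimates minus fraction of middling ones, clipped to [0, 1]."""
    probs = np.asarray(probs, dtype=float)
    extreme = np.mean((probs < 0.2) | (probs > 0.8))
    middle = np.mean((probs >= 0.35) & (probs <= 0.65))
    return float(np.clip(extreme - middle, 0.0, 1.0))

# Made-up estimates: 5 of 10 exceed 0.8 and 3 of 10 fall in [0.35, 0.65].
example = [0.55, 0.60, 0.70, 0.85, 0.90, 0.95, 0.52, 0.75, 0.82, 0.99]
score = bimodality_score_of(example)
print(f"{score:.3f}")  # 0.5 - 0.3 = 0.200
```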
In [8]:
# Inspect full NUTS posterior shape per subject (not just posterior means)
z_samples = np.array(mcmc_samples["z"])

# Handle optional chain dimension: [chains, draws, agents, params] -> [draws_total, agents, params]
if z_samples.ndim == 4:
    z_samples = z_samples.reshape(-1, z_samples.shape[-2], z_samples.shape[-1])
elif z_samples.ndim != 3:
    raise ValueError(f"Unexpected mcmc_samples['z'] shape: {z_samples.shape}")

reward_prob_samples = 0.5 + 0.5 / (1.0 + np.exp(-z_samples[..., 0]))  # [draws, agents]
true_reward_prob = np.array(0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0]))
posterior_mean = reward_prob_samples.mean(axis=0)

subject_ids = np.arange(num_agents)
sort_idx = np.argsort(true_reward_prob)
subject_ids_sorted = subject_ids[sort_idx]
true_reward_prob_sorted = true_reward_prob[sort_idx]
posterior_mean_sorted = posterior_mean[sort_idx]
reward_prob_samples_sorted = reward_prob_samples[:, sort_idx]

plot_positions = np.arange(num_agents)

fig, axes = plt.subplots(2, 1, figsize=(16, 10), constrained_layout=True)

# 1) Violin plot: per-subject posterior distributions + true values (sorted by true reward prob)
parts = axes[0].violinplot(
    [reward_prob_samples_sorted[:, i] for i in range(num_agents)],
    positions=plot_positions,
    showmeans=False,
    showmedians=True,
    showextrema=False,
)
for pc in parts['bodies']:
    pc.set_alpha(0.35)

axes[0].scatter(plot_positions, true_reward_prob_sorted, color='crimson', marker='x', s=80, label='true reward prob')
axes[0].scatter(plot_positions, posterior_mean_sorted, color='black', s=20, label='posterior mean')
axes[0].set_ylabel('reward probability')
axes[0].set_xlabel('subject id (sorted by true reward probability)')
axes[0].set_ylim(0.5, 1.0)
axes[0].set_xticks(plot_positions)
axes[0].set_xticklabels(subject_ids_sorted)
axes[0].set_title('NUTS posterior distribution per subject (violin, sorted)')
axes[0].legend(loc='upper left')

# 2) Density heatmap: posterior mass over probability bins by subject (same sorted order)
bins = np.linspace(0.0, 1.0, 60)
density = np.stack([
    np.histogram(reward_prob_samples_sorted[:, i], bins=bins, density=True)[0]
    for i in range(num_agents)
], axis=1)  # [num_bins-1, num_agents]

im = axes[1].imshow(
    density,
    aspect='auto',
    origin='lower',
    extent=[-0.5, num_agents - 0.5, bins[0], bins[-1]],
    cmap='viridis',
)
axes[1].plot(plot_positions, true_reward_prob_sorted, color='crimson', linestyle='--', marker='x', label='true reward prob')
axes[1].set_ylabel('reward probability')
axes[1].set_xlabel('subject id (sorted by true reward probability)')
axes[1].set_ylim(0.5, 1.0)
axes[1].set_xticks(plot_positions)
axes[1].set_xticklabels(subject_ids_sorted)
axes[1].set_title('NUTS posterior density by subject (sorted)')
axes[1].legend(loc='upper left')
fig.colorbar(im, ax=axes[1], label='density')

plt.show()
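
The violin plot can be complemented with numerical summaries. The sketch below computes a 90% equal-tailed credible interval per subject from posterior draws, and shows that transforming the mean latent value is not the same as averaging the transformed draws, because the sigmoid is nonlinear. The draws are synthetic stand-ins with the `[draws, agents]` layout; with the notebook's objects one could pass `z_samples[..., 0]` instead.

```python
import numpy as np

rng = np.random.default_rng(0)
num_draws, num_agents = 1000, 10

# Synthetic latent draws with layout [draws, agents].
z_draws = rng.normal(
    loc=np.linspace(1.0, 2.5, num_agents),
    scale=1.0,
    size=(num_draws, num_agents),
)

def to_reward_prob(z):
    return 0.5 + 0.5 / (1.0 + np.exp(-z))

prob_draws = to_reward_prob(z_draws)

# Equal-tailed 90% credible interval for each subject.
lower, upper = np.quantile(prob_draws, [0.05, 0.95], axis=0)

mean_of_transformed = prob_draws.mean(axis=0)            # average in probability space
transformed_mean = to_reward_prob(z_draws.mean(axis=0))  # transform of the latent mean

for i in range(num_agents):
    print(
        f"subject {i}: mean={mean_of_transformed[i]:.3f}, "
        f"90% CI=[{lower[i]:.3f}, {upper[i]:.3f}], "
        f"transform-of-mean={transformed_mean[i]:.3f}"
    )
```

The scatter plots in this notebook use the transform of the mean latent value, whereas the violin plots use the mean of the transformed draws, so the two summaries can differ slightly for the same subject.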

Inference Method 2: Black-Box Stochastic Variational Inference

Use NumPyro's SVI functionality to run variational inference with a MultivariateNormal variational posterior (that is, a Normal guide). SVI runs a black-box variational inference procedure where ELBO gradients are estimated using samples from the guide. This allows less constrained likelihood modeling than traditional mean-field variational Bayesian treatments, where likelihoods are often limited to conjugate-exponential forms.
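
To illustrate how ELBO gradients can be estimated from guide samples, the sketch below fits a Normal guide to a toy conjugate Gaussian model using reparameterized samples and plain gradient ascent. This is a hand-rolled NumPy illustration under stated assumptions (toy model, analytic gradients, fixed learning rate), not the procedure NumPyro or pybefit runs internally. The exact posterior is known for this model, which makes the result easy to check.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy conjugate model: z ~ N(0, 1), y_i ~ N(z, 1).
y = rng.normal(loc=1.5, scale=1.0, size=20)
n, y_sum = y.size, y.sum()
exact_mean = y_sum / (n + 1)
exact_std = 1.0 / np.sqrt(n + 1)

def grad_log_joint(z):
    """d/dz of log p(z) + sum_i log p(y_i | z)."""
    return y_sum - (n + 1) * z

# Normal guide q(z) = N(mu, sigma^2) with sigma = exp(rho).
mu, rho = 0.0, 0.0
learning_rate, num_particles = 0.01, 10

for step in range(2000):
    sigma = np.exp(rho)
    eps = rng.normal(size=num_particles)
    z = mu + sigma * eps                       # reparameterized guide samples
    g = grad_log_joint(z)
    grad_mu = g.mean()                         # Monte Carlo estimate of dELBO/dmu
    grad_rho = (g * sigma * eps).mean() + 1.0  # dELBO/drho; +1 comes from the guide entropy
    mu += learning_rate * grad_mu              # gradient ascent on the ELBO
    rho += learning_rate * grad_rho

print(f"guide: mean={mu:.3f}, std={np.exp(rho):.3f}")
print(f"exact: mean={exact_mean:.3f}, std={exact_std:.3f}")
```

Here `num_particles` plays the same role as the `num_particles` entry in the SVI options: more particles give less noisy gradient estimates at a higher cost per step.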

In [9]:
# perform inference on parameters using black-box stochastic variational inference (SVI in numpyro)
# opts_svi dictionary can be used to specify various parameters
# for the SVI optimization algorithm
measurements = {
    "outcomes": [outcomes[0] for outcomes in samples["outcomes"]],
    "multiactions": samples["multiactions"][0],
}

posterior = NumpyroGuide(NormalPosterior(num_params, num_agents, backend="numpyro"))

# perform inference using stochastic variational inference
opts_svi = default_dict_numpyro_svi | {"iter_steps": 1_000}
print(opts_svi)

svi_samples, svi, results = run_svi(model, posterior, measurements, opts=opts_svi)
{'seed': 0, 'enumerate': False, 'iter_steps': 1000, 'optim': None, 'optim_kwargs': {'learning_rate': 0.001}, 'elbo_kwargs': {'num_particles': 10, 'max_plate_nesting': 1}, 'svi_kwargs': {'progress_bar': True, 'stable_update': True}, 'sample_kwargs': {'num_samples': 100}}
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_95807/2592818604.py:42: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  return Agent(
 76%|███████▌  | 757/1000 [09:16<02:59,  1.35it/s, init loss: 7789.5132, avg. loss [701-750]: 7768.9961]

Plot the variational free energy (negative ELBO) over optimization steps

In [10]:
plt.plot(results.losses)
Out[10]:
[<matplotlib.lines.Line2D at 0x3ec6cf790>]
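
Because the ELBO is estimated from a finite number of guide samples, the loss curve is noisy. One common way to read the trend, sketched here, is a moving average. The losses below are synthetic; with the notebook's objects one could pass `np.asarray(results.losses)` instead, and the helper name `moving_average` is hypothetical.

```python
import numpy as np

def moving_average(values, window=50):
    """Average each run of `window` consecutive values."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

# Synthetic noisy loss curve that decays towards a plateau.
rng = np.random.default_rng(0)
steps = np.arange(1000)
synthetic_losses = (
    7770.0
    + 20.0 * np.exp(-steps / 200.0)
    + rng.normal(scale=5.0, size=steps.size)
)

smoothed = moving_average(synthetic_losses, window=50)
print(smoothed.shape)  # (951,)
print(smoothed[0], smoothed[-1])
```

With `mode="valid"`, the smoothed curve is shorter than the input by `window - 1` points, so it should be plotted against a correspondingly shortened step axis.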

Plot each ground-truth parameter against its posterior mean (mean taken over posterior samples from the guide)

In [11]:
plt.figure(figsize=(16, 5))
ground_truth_z = samples['z'][0]
for i in range(num_params):
    plt.scatter(ground_truth_z[:, i], svi_samples["z"].mean(0)[:, i], label=i)

plt.plot((ground_truth_z.min(), ground_truth_z.max()), (ground_truth_z.min(), ground_truth_z.max()), "k--")
plt.ylabel("posterior mean (SVI)")
plt.xlabel("true value")
plt.legend(title="parameter id")
Out[11]:
<matplotlib.legend.Legend at 0x3ecc6cd10>

Transform the latent parameter corresponding to the reward probability into probability space and investigate overlap between ground-truth and inferred parameter

In [12]:
inferred_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(svi_samples["z"].mean(0)[:, 0])
ground_truth_reward_probabilities = 0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0])

plt.figure(figsize=(16, 5))
plt.scatter(ground_truth_reward_probabilities, inferred_reward_probabilities, label="reward probability")

plt.plot((ground_truth_reward_probabilities.min(), ground_truth_reward_probabilities.max()), (ground_truth_reward_probabilities.min(), ground_truth_reward_probabilities.max()), "k--")
plt.ylabel("posterior mean (SVI)")
plt.xlabel("true value")
plt.legend(title="parameter id")

corr = np.corrcoef(
    np.array(ground_truth_reward_probabilities),
    np.array(inferred_reward_probabilities),
)[0, 1]

bimodality_score = np.clip(
    np.mean((np.array(inferred_reward_probabilities) < 0.2) | (np.array(inferred_reward_probabilities) > 0.8))
    - np.mean((np.array(inferred_reward_probabilities) >= 0.35) & (np.array(inferred_reward_probabilities) <= 0.65)),
    0.0,
    1.0,
)

print(f"reward-probability Pearson r: {corr:.3f}")
print(f"reward-probability bimodality score: {bimodality_score:.3f}")
reward-probability Pearson r: 0.905
reward-probability bimodality score: 0.300
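
Pearson's r measures linear association rather than agreement, so estimates that are systematically shrunk or shifted can still correlate almost perfectly with the truth. The sketch below, using made-up numbers, adds bias and root-mean-square error as complementary recovery metrics. The function name `recovery_metrics` is hypothetical.

```python
import numpy as np

def recovery_metrics(true_values, estimates):
    """Correlation, mean signed error, and root-mean-square error."""
    true_values = np.asarray(true_values, dtype=float)
    estimates = np.asarray(estimates, dtype=float)
    error = estimates - true_values
    return {
        "pearson_r": float(np.corrcoef(true_values, estimates)[0, 1]),
        "bias": float(error.mean()),
        "rmse": float(np.sqrt(np.mean(error**2))),
    }

# Made-up example: estimates shrunk halfway towards 0.8.
true_probs = np.array([0.6, 0.7, 0.8, 0.9])
shrunk_estimates = 0.8 + 0.5 * (true_probs - 0.8)

metrics = recovery_metrics(true_probs, shrunk_estimates)
print(metrics)
```

In this example the correlation is perfect even though every estimate except one is off, which is why it can help to report an error measure alongside r when comparing NUTS and SVI.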
In [13]:
# Inspect full SVI posterior shape per subject (not just posterior means)
z_samples = np.array(svi_samples["z"])

reward_prob_samples = 0.5 + 0.5 / (1.0 + np.exp(-z_samples[..., 0]))  # [draws, agents]
true_reward_prob = np.array(0.5 + 0.5 * nn.sigmoid(ground_truth_z[:, 0]))
posterior_mean = reward_prob_samples.mean(axis=0)

subject_ids = np.arange(num_agents)
sort_idx = np.argsort(true_reward_prob)
subject_ids_sorted = subject_ids[sort_idx]
true_reward_prob_sorted = true_reward_prob[sort_idx]
posterior_mean_sorted = posterior_mean[sort_idx]
reward_prob_samples_sorted = reward_prob_samples[:, sort_idx]

plot_positions = np.arange(num_agents)

fig, axes = plt.subplots(2, 1, figsize=(16, 10), constrained_layout=True)

# 1) Violin plot: per-subject posterior distributions + true values (sorted by true reward prob)
parts = axes[0].violinplot(
    [reward_prob_samples_sorted[:, i] for i in range(num_agents)],
    positions=plot_positions,
    showmeans=False,
    showmedians=True,
    showextrema=False,
)
for pc in parts['bodies']:
    pc.set_alpha(0.35)

axes[0].scatter(plot_positions, true_reward_prob_sorted, color='crimson', marker='x', s=80, label='true reward prob')
axes[0].scatter(plot_positions, posterior_mean_sorted, color='black', s=20, label='posterior mean')
axes[0].set_ylabel('reward probability')
axes[0].set_xlabel('subject id (sorted by true reward probability)')
axes[0].set_ylim(0.5, 1.0)
axes[0].set_xticks(plot_positions)
axes[0].set_xticklabels(subject_ids_sorted)
axes[0].set_title('SVI posterior distribution per subject (violin, sorted)')
axes[0].legend(loc='upper left')

# 2) Density heatmap: posterior mass over probability bins by subject (same sorted order)
bins = np.linspace(0.0, 1.0, 60)
density = np.stack([
    np.histogram(reward_prob_samples_sorted[:, i], bins=bins, density=True)[0]
    for i in range(num_agents)
], axis=1)  # [num_bins-1, num_agents]

im = axes[1].imshow(
    density,
    aspect='auto',
    origin='lower',
    extent=[-0.5, num_agents - 0.5, bins[0], bins[-1]],
    cmap='viridis',
)
axes[1].plot(plot_positions, true_reward_prob_sorted, color='crimson', linestyle='--', marker='x', label='true reward prob')
axes[1].set_ylabel('reward probability')
axes[1].set_xlabel('subject id (sorted by true reward probability)')
axes[1].set_ylim(0.5, 1.0)
axes[1].set_xticks(plot_positions)
axes[1].set_xticklabels(subject_ids_sorted)
axes[1].set_title('SVI posterior density by subject (sorted)')
axes[1].legend(loc='upper left')
fig.colorbar(im, ax=axes[1], label='density')

plt.show()
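[Figure: two stacked panels. Top: "SVI posterior distribution per subject (violin, sorted)", with true reward probability and posterior mean overlaid. Bottom: "SVI posterior density by subject (sorted)", with true reward probability overlaid. Both panels show reward probability against subject id, sorted by true reward probability.]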
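The violin and heatmap panels show posterior shape visually. A numeric companion is to check whether each subject's true value falls inside a central credible interval computed from the same `[draws, agents]` sample array. The sketch below is one way to do that, using synthetic samples rather than the notebook's `svi_samples`. The function name `interval_coverage` and the 90% level are assumptions chosen for illustration.

```python
import numpy as np


def interval_coverage(samples, truth, level=0.9):
    """Central credible interval per subject and whether it contains the true value.

    samples: array of shape [draws, agents]
    truth:   array of shape [agents]
    """
    alpha = (1.0 - level) / 2.0
    lower = np.quantile(samples, alpha, axis=0)
    upper = np.quantile(samples, 1.0 - alpha, axis=0)
    covered = (truth >= lower) & (truth <= upper)
    return lower, upper, covered


# Synthetic stand-in for posterior draws (illustration only, not real results).
rng = np.random.default_rng(0)
truth = np.linspace(0.55, 0.95, 5)
samples = np.clip(truth + 0.02 * rng.standard_normal((2000, 5)), 0.5, 1.0)

lower, upper, covered = interval_coverage(samples, truth)
for i, (lo, hi, ok) in enumerate(zip(lower, upper, covered)):
    print(f"subject {i}: [{lo:.3f}, {hi:.3f}] contains truth: {bool(ok)}")
```

With the notebook's variables, the same function could be called as `interval_coverage(reward_prob_samples, true_reward_prob)`, assuming those arrays have the shapes noted in the cell above.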
