import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
import jax.numpy as jnp
from jax import tree_util as jtu, vmap, jit
from jax.experimental import sparse
from pymdp.agent import Agent
import matplotlib.pyplot as plt
import seaborn as sns
from pymdp.inference import smoothing_ovf
from pymdp.algos import hmm_smoother_scan_colstoch
from pymdp.maths import log_stable
Set up the generative model and a sequence of observations. The A tensors, B tensors and observations are specified in such a way that only later observations ($o_{t > 1}$) help disambiguate hidden states at earlier time points. This will demonstrate the importance of "smoothing", or retrospective inference.
# Model dimensions: two hidden-state factors (3 and 2 levels), one observation
# modality with 2 outcomes, and a batch of 2 identically specified agents.
num_states = [3, 2]
num_obs = [2]
n_batch = 2

# Likelihood: the outer product of the two per-factor components gives an
# (outcome, factor-1 state, factor-2 state) tensor.
A_1 = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
A_2 = jnp.array([[1.0, 1.0], [1.0, 0.0]])
A_tensor = A_1[:, :, None] * A_2[:, None, :]
# normalise over the outcome axis so every state combination sums to one
A_tensor = A_tensor / A_tensor.sum(axis=0)
A = [jnp.broadcast_to(A_tensor, (n_batch, num_obs[0], *num_states))]

# create two transition matrices, one for each state factor
cycle_3 = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
swap_2 = jnp.array([[0.0, 1.0], [1.0, 0.0]])
B_1 = jnp.broadcast_to(cycle_3, (n_batch, num_states[0], num_states[0]))
B_2 = jnp.broadcast_to(swap_2, (n_batch, num_states[1], num_states[1]))
# append a trailing axis of length 1 (presumably the action axis -- confirm against pymdp)
B = [B_1[..., None], B_2[..., None]]

# for the single modality, a sequence over time of observations (one hot vectors)
obs_sequence = jnp.array(
    [
        [1.0, 0.0],  # observation 0 is ambiguous with respect to the state factors
        [1.0, 0.0],  # observation 0 is ambiguous with respect to the state factors
        [1.0, 0.0],  # observation 0 is ambiguous with respect to the state factors
        [0.0, 1.0],  # observation 1 provides information about the exact state of both factors
    ]
)
# insert the batch axis in the middle: (time, batch, outcome)
obs = [
    jnp.broadcast_to(
        obs_sequence[:, None, :],
        (obs_sequence.shape[0], n_batch, num_obs[0]),
    )
]

C = [jnp.zeros((n_batch, num_obs[0]))]  # flat preferences
D = [jnp.ones((n_batch, n_levels)) / float(n_levels) for n_levels in num_states]  # flat prior
E = jnp.ones((n_batch, 1))
Construct the Agent
# pA / pB are passed to the agent as None
pA, pB = None, None

# generative-model tensors defined in the setup above
model_tensors = dict(A=A, B=B, C=C, D=D, E=E, pA=pA, pB=pB)

# planning / inference configuration for the batched OVF agent
agent_settings = dict(
    policy_len=3,
    categorical_obs=True,
    action_selection="deterministic",
    sampling_mode="full",
    inference_algo="ovf",
    num_iter=16,
    batch_size=n_batch,
)

agents = Agent(**model_tensors, **agent_settings)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61062/464010401.py:4: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agents = Agent(
OVF Smoothing
Using obs and policies, pass in the arguments outcomes, past_actions, empirical_prior and qs_hist to agent.infer_states(...)
Run first timestep of inference using obs[0], no past actions, empirical prior set to actual prior, no qs_hist
# Filtering pass: step through the observation sequence one timestep at a
# time, carrying the empirical prior and the belief history forward.
prior = agents.D
action_hist = []
qs_hist = None
num_timesteps = len(obs[0])
for t in range(num_timesteps):
    # observations up to and including timestep t, with the batch axis moved to the front
    first_obs = jtu.tree_map(lambda x: jnp.moveaxis(x[: t + 1], 0, 1), obs)
    beliefs = agents.infer_states(first_obs, prior, qs_hist=qs_hist)
    # Every agent in the batch takes the first action of the first policy.
    # The target shape is derived from the batch size and the policy entry
    # itself rather than a hard-coded (2, 2), so it follows the model setup.
    first_action = agents.policies[0, 0]
    actions = jnp.broadcast_to(first_action, (n_batch, *jnp.shape(first_action)))
    prior = agents.update_empirical_prior(actions, beliefs)
    qs_hist = beliefs
    # the action chosen at the final timestep is not recorded
    if t < num_timesteps - 1:
        action_hist.append(actions)

# Smoothing pass: vmap smoothing_ovf over the batch axis and jit-compile it.
v_jso = jit(vmap(smoothing_ovf), backend='cpu')
# stack the per-timestep actions along axis 1, right after the batch axis
actions_seq = jnp.stack(action_hist, 1)
smoothed_beliefs = v_jso(beliefs, agents.B, actions_seq)
%timeit v_jso(beliefs, agents.B, actions_seq)[0][0].block_until_ready()
24.1 μs ± 2.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Try the version of smoothing_ovf with sparse tensors
def _dense_to_bcoo(dense_b):
    """Convert one dense transition tensor to BCOO, treating its leading axis as a batch dimension."""
    return sparse.BCOO.fromdense(dense_b, n_batch=1)


# same smoothing call as before, but with every B tensor stored sparsely
sparse_B = jtu.tree_map(_dense_to_bcoo, agents.B)
smoothed_beliefs_sparse = v_jso(beliefs, sparse_B, actions_seq)
%timeit v_jso(beliefs, sparse_B, actions_seq)[0][0].block_until_ready()
28.4 μs ± 361 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Now we can plot the pair of filtering / smoothing distributions for a single batch element (i.e. a single agent) from the runs above
def _plot_filtered_vs_smoothed(filtered, smoothed, batch_idx=0):
    """Plot filtered (left column) against smoothed (right column) beliefs.

    One row per state factor, for the batch element `batch_idx`. Each belief
    array is transposed with `.mT` before plotting. Returns the figure and
    its 2x2 axes array.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)
    heatmap_kwargs = dict(cbar=False, vmax=1., vmin=0., cmap='viridis')
    for factor in range(2):
        sns.heatmap(filtered[factor][batch_idx].mT, ax=axes[factor, 0], **heatmap_kwargs)
    for factor in range(2):
        sns.heatmap(smoothed[factor][batch_idx].mT, ax=axes[factor, 1], **heatmap_kwargs)
    axes[0, 0].set_title('Filtered beliefs')
    axes[0, 1].set_title('Smoothed beliefs')
    plt.show()
    return fig, axes


# with dense matrices
fig, axes = _plot_filtered_vs_smoothed(beliefs, smoothed_beliefs[0])

# with sparse matrices
fig, axes = _plot_filtered_vs_smoothed(beliefs, smoothed_beliefs_sparse[0])
Compare to marginal message passing
# Build a second agent that differs from the OVF one in using
# inference_algo="mmp". It is bound to its own name only, so the OVF agent
# held in `agents` stays available to the earlier cells.
mmp_agents = Agent(
    A=A,
    B=B,
    C=C,
    D=D,
    E=E,
    pA=pA,
    pB=pB,
    policy_len=3,
    control_fac_idx=None,
    categorical_obs=True,
    action_selection="deterministic",
    sampling_mode="full",
    inference_algo="mmp",
    num_iter=16,
    batch_size=n_batch,
)

# the whole observation sequence at once, with the batch axis moved to the front
mmp_obs = [jnp.moveaxis(obs[0], 0, 1)]
post_marg_beliefs = mmp_agents.infer_states(mmp_obs, mmp_agents.D, past_actions=jnp.stack(action_hist, 1))

# MMP posteriors for batch element 0, one panel per state factor
fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True)
sns.heatmap(post_marg_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis')
sns.heatmap(post_marg_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis')
fig.suptitle('Marginal smoothed beliefs');
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61062/1188368462.py:1: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. mmp_agents = agents = Agent(
Compare to variational message passing
# Build a third agent that differs from the OVF one in using
# inference_algo="vmp". It is bound to its own name only, so the OVF agent
# held in `agents` stays available to the earlier cells.
vmp_agents = Agent(
    A=A,
    B=B,
    C=C,
    D=D,
    E=E,
    pA=pA,
    pB=pB,
    policy_len=3,
    control_fac_idx=None,
    categorical_obs=True,
    action_selection="deterministic",
    sampling_mode="full",
    inference_algo="vmp",
    num_iter=16,
    batch_size=n_batch,
)

# the whole observation sequence at once, with the batch axis moved to the front
vmp_obs = [jnp.moveaxis(obs[0], 0, 1)]
post_vmp_beliefs = vmp_agents.infer_states(vmp_obs, vmp_agents.D, past_actions=jnp.stack(action_hist, 1))

# VMP posteriors for batch element 0, one panel per state factor
fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True)
sns.heatmap(post_vmp_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis')
sns.heatmap(post_vmp_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis')
fig.suptitle('VMP smoothed beliefs')
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61062/4088240323.py:1: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. vmp_agents = agents = Agent(
Text(0.5, 0.98, 'VMP smoothed beliefs')
Associative-scan HMM (B-native, single-factor reduction)
The associative-scan implementation expects a single hidden-state factor. For this multi-factor model, we collapse the two factors into a joint state space via a Cartesian product and run exact filtering/smoothing using the B-native (column-stochastic) formulation, then marginalize back to factor-wise beliefs for visualization.
# Work on one batch element and collapse both state factors into a single
# joint state of size s1 * s2.
batch_idx = 0
T = obs[0].shape[0]
s1, s2 = num_states

# single-modality observations for one batch
obs_single = obs[0][:, batch_idx, :]  # (T, obs_dim)

# collapse A over factors into a single-state emission matrix; the row-major
# reshape makes the factor-1 index the slow one in the joint state
A0 = A[0][batch_idx]  # (obs_dim, s1, s2)
A_big = jnp.reshape(A0, (num_obs[0], s1 * s2))
# one-hot observations pick out the log-likelihood row of the observed outcome
log_likelihoods = obs_single @ log_stable(A_big)

# build joint transition (column-stochastic) via Kronecker product;
# kron(B1, B2) uses the same factor-1-major ordering as the reshape above
B1 = B[0][batch_idx, :, :, 0]  # (s1,s1) column-stochastic
B2 = B[1][batch_idx, :, :, 0]  # (s2,s2) column-stochastic
B_big = jnp.kron(B1, B2)

# joint prior, again in factor-1-major order
D1 = D[0][batch_idx]
D2 = D[1][batch_idx]
prior_big = jnp.kron(D1, D2)

mll, filt_big, pred_big, smooth_big, cond_big = hmm_smoother_scan_colstoch(
    prior_big, B_big, log_likelihoods
)

# reshape joint beliefs to factor-wise marginals: summing out one factor's
# axis leaves the marginal over the other
filt_joint = filt_big.reshape(T, s1, s2)
smooth_joint = smooth_big.reshape(T, s1, s2)
filt_fac1, filt_fac2 = filt_joint.sum(axis=2), filt_joint.sum(axis=1)
smooth_fac1, smooth_fac2 = smooth_joint.sum(axis=2), smooth_joint.sum(axis=1)

# left column: filtered, right column: smoothed; one row per state factor
fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)
panels = {
    (0, 0): filt_fac1,
    (1, 0): filt_fac2,
    (0, 1): smooth_fac1,
    (1, 1): smooth_fac2,
}
for (row, col), marginal in panels.items():
    sns.heatmap(marginal.mT, ax=axes[row, col], cbar=False, vmax=1., vmin=0., cmap='viridis')
axes[0, 0].set_title('Associative scan filtered beliefs (B-native)')
axes[0, 1].set_title('Associative scan smoothed beliefs (B-native)')
plt.show()