pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

In [ ]:
Copied!
import sys
# Install pymdp only when running inside Google Colab (detected by the
# `google.colab` module being loaded). `%pip` is an IPython magic, so this
# cell must run in a notebook kernel rather than as a plain Python script.
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import jax.numpy as jnp
from jax import tree_util as jtu, vmap, jit
from jax.experimental import sparse
from pymdp.agent import Agent
import matplotlib.pyplot as plt
import seaborn as sns

from pymdp.inference import smoothing_ovf
from pymdp.algos import hmm_smoother_scan_colstoch
from pymdp.maths import log_stable
import jax.numpy as jnp from jax import tree_util as jtu, vmap, jit from jax.experimental import sparse from pymdp.agent import Agent import matplotlib.pyplot as plt import seaborn as sns from pymdp.inference import smoothing_ovf from pymdp.algos import hmm_smoother_scan_colstoch from pymdp.maths import log_stable

Set up the generative model and a sequence of observations. The A tensors, B tensors and observations are specified in such a way that only later observations ($o_{t > 1}$) help disambiguate hidden states at earlier time points: the first three observations are ambiguous, and only the last one reveals the exact state of both factors. This demonstrates the importance of "smoothing", i.e. retrospective inference.

In [ ]:
Copied!
num_states = [3, 2]  # sizes of the two hidden-state factors
num_obs = [2]        # a single observation modality with two outcomes
n_batch = 2          # number of agents simulated in parallel

# Factor-wise likelihood pieces, combined by an outer product over factors:
# A_tensor[o, s1, s2] = A_1[o, s1] * A_2[o, s2].
A_1 = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
A_2 = jnp.array([[1.0, 1.0], [1.0, 0.0]])

A_tensor = A_1[..., None] * A_2[:, None]

# Normalize every (s1, s2) column into a distribution over observations.
# Observation 1 keeps non-zero probability only in the joint state
# (s1=2, s2=0); observation 0 is possible from every state.
A_tensor /= A_tensor.sum(0)

# Shapes are derived from num_states/num_obs rather than repeated literals,
# so they stay consistent if the model dimensions change.
A = [jnp.broadcast_to(A_tensor, (n_batch, num_obs[0], *num_states))]

# Create two column-stochastic transition matrices (B[next, prev]), one per
# state factor: factor 1 cycles 0 -> 2 -> 1 -> 0, factor 2 alternates 0 <-> 1.
B_1 = jnp.broadcast_to(
    jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]),
    (n_batch, num_states[0], num_states[0]),
)

B_2 = jnp.broadcast_to(
    jnp.array([[0.0, 1.0], [1.0, 0.0]]),
    (n_batch, num_states[1], num_states[1]),
)

# Append a trailing axis of size 1 (presumably a single control state per factor).
B = [B_1[..., None], B_2[..., None]]

# For the single modality, a sequence over time of one-hot observations:
# observation 0 (first three steps) is ambiguous with respect to the state
# factors; observation 1 (last step) reveals the exact state of both factors.
obs_seq = jnp.array([[1.0, 0.0],
                     [1.0, 0.0],
                     [1.0, 0.0],
                     [0.0, 1.0]])
# Resulting layout: (T, n_batch, num_obs), identical for every agent.
obs = [jnp.broadcast_to(obs_seq[:, None], (obs_seq.shape[0], n_batch, num_obs[0]))]

C = [jnp.zeros((n_batch, num_obs[0]))]  # flat preferences
D = [jnp.ones((n_batch, ns)) / ns for ns in num_states]  # flat priors, one per factor
E = jnp.ones((n_batch, 1))
num_states = [3, 2] num_obs = [2] n_batch = 2 A_1 = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.0, 1.]]) A_2 = jnp.array([[1.0, 1.0], [1., 0.]]) A_tensor = A_1[..., None] * A_2[:, None] A_tensor /= A_tensor.sum(0) A = [jnp.broadcast_to(A_tensor, (n_batch, num_obs[0], 3, 2)) ] # create two transition matrices, one for each state factor B_1 = jnp.broadcast_to( jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]), (n_batch, 3, 3) ) B_2 = jnp.broadcast_to( jnp.array([[0.0, 1.0], [1.0, 0.0]]), (n_batch, 2, 2) ) B = [B_1[..., None], B_2[..., None]] # for the single modality, a sequence over time of observations (one hot vectors) obs = [jnp.broadcast_to(jnp.array([[1., 0.], # observation 0 is ambiguous with respect state factors [1., 0], # observation 0 is ambiguous with respect state factors [1., 0], # observation 0 is ambiguous with respect state factors [0., 1.]])[:, None], (4, n_batch, num_obs[0]) )] # observation 1 provides information about exact state of both factors C = [jnp.zeros((n_batch, num_obs[0]))] # flat preferences D = [jnp.ones((n_batch, 3)) / 3., jnp.ones((n_batch, 2)) / 2.] # flat prior E = jnp.ones((n_batch, 1))

Construct the Agent

In [1]:
Copied!
# No Dirichlet hyperparameters over A or B are supplied.
pA = pB = None

# A batch of n_batch agents sharing the generative model above, using the
# "ovf" state-inference algorithm.
agents = Agent(
    A=A, B=B, C=C, D=D, E=E,
    pA=pA, pB=pB,
    policy_len=3,
    categorical_obs=True,
    action_selection="deterministic",
    sampling_mode="full",
    inference_algo="ovf",
    num_iter=16,
    batch_size=n_batch,
)
pA = None pB = None agents = Agent( A=A, B=B, C=C, D=D, E=E, pA=pA, pB=pB, policy_len=3, categorical_obs=True, action_selection="deterministic", sampling_mode="full", inference_algo="ovf", num_iter=16, batch_size=n_batch )
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61062/464010401.py:4: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agents = Agent(

OVF Smoothing¶

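Using `obs` and the agent's policies, pass the observations seen so far (`outcomes`), the empirical prior (`empirical_prior`) and the previous beliefs (`qs_hist`) to `agents.infer_states(...)`, collecting the actions taken along the way as the history of `past_actions` used later for smoothing.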

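The first timestep of inference uses only the first observation, no past actions, the empirical prior set to the actual prior (`agents.D`) and no `qs_hist`. Each subsequent timestep adds the next observation, updates the empirical prior with `agents.update_empirical_prior`, and passes the previous beliefs as `qs_hist`.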

In [ ]:
Copied!
# Online filtering: at step t the agents see observations 0..t. The first
# step starts from the prior D with no belief history. After every step the
# empirical prior is updated with the chosen action and the current beliefs.
# All actions except the final one are kept for smoothing.
prior = agents.D
action_hist = []
qs_hist = None
num_timesteps = len(obs[0])
for t in range(num_timesteps):
    # Move time behind the batch axis: (T, batch, obs) -> (batch, t+1, obs).
    obs_so_far = jtu.tree_map(lambda x: jnp.moveaxis(x[:t + 1], 0, 1), obs)
    beliefs = agents.infer_states(obs_so_far, prior, qs_hist=qs_hist)
    # Every agent takes agents.policies[0, 0] (presumably the first action of
    # the first policy). Broadcast it to one row per agent; the shape comes
    # from n_batch and the action's own shape rather than the literal (2, 2),
    # which only worked because n_batch equals the number of factors here.
    actions = jnp.broadcast_to(agents.policies[0, 0], (n_batch, *agents.policies[0, 0].shape))
    prior = agents.update_empirical_prior(actions, beliefs)
    qs_hist = beliefs
    if t < num_timesteps - 1:
        action_hist.append(actions)

# Vectorize the smoother over the batch of agents and JIT-compile it for CPU.
v_jso = jit(vmap(smoothing_ovf), backend='cpu')
# Stack the per-step actions along axis 1: (n_batch, T - 1, ...).
actions_seq = jnp.stack(action_hist, 1)
prior = agents.D action_hist = [] qs_hist=None for t in range(len(obs[0])): first_obs = jtu.tree_map(lambda x: jnp.moveaxis(x[:t+1], 0, 1), obs) beliefs = agents.infer_states(first_obs, prior, qs_hist=qs_hist) actions = jnp.broadcast_to(agents.policies[0, 0], (2, 2)) prior = agents.update_empirical_prior(actions, beliefs) qs_hist = beliefs if t < len(obs[0]) - 1: action_hist.append(actions) v_jso = jit(vmap(smoothing_ovf), backend='cpu') actions_seq = jnp.stack(action_hist, 1)
In [2]:
Copied!
# Smooth the filtered beliefs of every agent with the dense transition tensors.
# This first call also triggers JIT compilation, so %timeit then measures the
# compiled function; `block_until_ready()` makes the timing wait for JAX's
# asynchronous computation instead of only the dispatch.
smoothed_beliefs = v_jso(beliefs, agents.B, actions_seq)
%timeit v_jso(beliefs, agents.B, actions_seq)[0][0].block_until_ready()
smoothed_beliefs = v_jso(beliefs, agents.B, actions_seq) %timeit v_jso(beliefs, agents.B, actions_seq)[0][0].block_until_ready()
24.1 μs ± 2.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Run the same `smoothing_ovf` smoother with sparse (BCOO) transition tensors.

In [3]:
Copied!
# Convert each dense transition tensor to a JAX BCOO sparse array, keeping the
# leading axis as a batch dimension (n_batch=1).
sparse_B = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b, n_batch=1), agents.B)

# Rerun the jitted smoother on the sparse inputs (it is retraced for the new
# input types on this first call) and time it for comparison with the dense run.
smoothed_beliefs_sparse = v_jso(beliefs, sparse_B, actions_seq)
%timeit v_jso(beliefs, sparse_B, actions_seq)[0][0].block_until_ready()
sparse_B = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b, n_batch=1), agents.B) smoothed_beliefs_sparse = v_jso(beliefs, sparse_B, actions_seq) %timeit v_jso(beliefs, sparse_B, actions_seq)[0][0].block_until_ready()
28.4 μs ± 361 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Now we can plot the pair of filtering / smoothing distributions for the first agent in the batch (batch index 0).

In [4]:
Copied!
# Dense transition matrices: filtered beliefs (left column) vs smoothed
# beliefs (right column) of agent 0, one row of panels per state factor.
fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)
heatmap_kw = dict(cbar=False, vmax=1., vmin=0., cmap='viridis')

for factor in range(2):
    sns.heatmap(beliefs[factor][0].mT, ax=axes[factor, 0], **heatmap_kw)
    sns.heatmap(smoothed_beliefs[0][factor][0].mT, ax=axes[factor, 1], **heatmap_kw)

axes[0, 0].set_title('Filtered beliefs')
axes[0, 1].set_title('Smoothed beliefs')
plt.show()
# with dense matrices fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True) sns.heatmap(beliefs[0][0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(beliefs[1][0].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_beliefs[0][0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_beliefs[0][1][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') axes[0, 0].set_title('Filtered beliefs') axes[0, 1].set_title('Smoothed beliefs') plt.show()
No description has been provided for this image
In [5]:
Copied!
# Sparse (BCOO) transition matrices: same layout as the dense figure, with
# the right-hand column taken from the sparse smoother's output.
fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)

columns = (beliefs, smoothed_beliefs_sparse[0])
for col, per_factor in enumerate(columns):
    for row in (0, 1):
        sns.heatmap(
            per_factor[row][0].mT,
            ax=axes[row, col],
            cbar=False,
            vmax=1.,
            vmin=0.,
            cmap='viridis',
        )

axes[0, 0].set_title('Filtered beliefs')
axes[0, 1].set_title("Smoothed beliefs")
plt.show()
# with sparse matrices fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True) sns.heatmap(beliefs[0][0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(beliefs[1][0].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_beliefs_sparse[0][0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_beliefs_sparse[0][1][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') axes[0, 0].set_title('Filtered beliefs') axes[0, 1].set_title("Smoothed beliefs") plt.show()
No description has been provided for this image

Compare to marginal message passing¶

In [6]:
Copied!
# Same generative model, inferred with marginal message passing ("mmp"): a
# single infer_states call over the whole observation sequence, given the
# actions recorded during filtering.
# Bind only `mmp_agents`, so `agents` keeps referring to the OVF agents.
mmp_agents = Agent(
    A=A, B=B, C=C, D=D, E=E,
    pA=pA, pB=pB,
    policy_len=3,
    control_fac_idx=None,
    categorical_obs=True,
    action_selection="deterministic",
    sampling_mode="full",
    inference_algo="mmp",
    num_iter=16,
    batch_size=n_batch,
)

# (T, n_batch, num_obs) -> (n_batch, T, num_obs)
mmp_obs = [jnp.moveaxis(obs[0], 0, 1)]
# actions_seq is the stacked action history from the filtering loop.
post_marg_beliefs = mmp_agents.infer_states(mmp_obs, mmp_agents.D, past_actions=actions_seq)

# with dense matrices (agent 0, one panel per state factor)
fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True)

sns.heatmap(post_marg_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis')
sns.heatmap(post_marg_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis')

fig.suptitle('Marginal smoothed beliefs');
mmp_agents = agents = Agent( A=A, B=B, C=C, D=D, E=E, pA=pA, pB=pB, policy_len=3, control_fac_idx=None, categorical_obs=True, action_selection="deterministic", sampling_mode="full", inference_algo="mmp", num_iter=16, batch_size=n_batch, ) mmp_obs = [jnp.moveaxis(obs[0], 0, 1)] post_marg_beliefs = mmp_agents.infer_states(mmp_obs, mmp_agents.D, past_actions=jnp.stack(action_hist, 1)) #with sparse matrices fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True) sns.heatmap(post_marg_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(post_marg_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis') fig.suptitle('Marginal smoothed beliefs');
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61062/1188368462.py:1: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  mmp_agents = agents = Agent(
No description has been provided for this image

Compare to variational message passing¶

In [7]:
Copied!
# Same generative model, inferred with variational message passing ("vmp"):
# a single infer_states call over the whole observation sequence, given the
# actions recorded during filtering.
# Bind only `vmp_agents`, so `agents` keeps referring to the OVF agents.
vmp_agents = Agent(
    A=A, B=B, C=C, D=D, E=E,
    pA=pA, pB=pB,
    policy_len=3,
    control_fac_idx=None,
    categorical_obs=True,
    action_selection="deterministic",
    sampling_mode="full",
    inference_algo="vmp",
    num_iter=16,
    batch_size=n_batch,
)

# (T, n_batch, num_obs) -> (n_batch, T, num_obs)
vmp_obs = [jnp.moveaxis(obs[0], 0, 1)]
# actions_seq is the stacked action history from the filtering loop.
post_vmp_beliefs = vmp_agents.infer_states(vmp_obs, vmp_agents.D, past_actions=actions_seq)

# with dense matrices (agent 0, one panel per state factor)
fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True)

sns.heatmap(post_vmp_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis')
sns.heatmap(post_vmp_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis')

fig.suptitle('VMP smoothed beliefs')
vmp_agents = agents = Agent( A=A, B=B, C=C, D=D, E=E, pA=pA, pB=pB, policy_len=3, control_fac_idx=None, categorical_obs=True, action_selection="deterministic", sampling_mode="full", inference_algo="vmp", num_iter=16, batch_size=n_batch, ) vmp_obs = [jnp.moveaxis(obs[0], 0, 1)] post_vmp_beliefs = vmp_agents.infer_states(vmp_obs, vmp_agents.D, past_actions=jnp.stack(action_hist, 1)) #with sparse matrices fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True) sns.heatmap(post_vmp_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(post_vmp_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis') fig.suptitle('VMP smoothed beliefs')
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61062/4088240323.py:1: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  vmp_agents = agents = Agent(
Out[7]:
Text(0.5, 0.98, 'VMP smoothed beliefs')
No description has been provided for this image

Associative-scan HMM (B-native, single-factor reduction)¶

The associative-scan implementation expects a single hidden-state factor. For this multi-factor model, we collapse the two factors into a joint state space via a Cartesian product and run exact filtering/smoothing using the B-native (column-stochastic) formulation, then marginalize back to factor-wise beliefs for visualization.

In [8]:
Copied!
# Exact filtering/smoothing for one agent on the joint (Cartesian-product)
# state space of the two factors, then marginalized back to each factor.
batch_idx = 0
T = obs[0].shape[0]
s1, s2 = num_states

# One-hot observations of this agent for the single modality: (T, obs_dim).
obs_single = obs[0][:, batch_idx, :]

# Flatten A over (s1, s2) in row-major order, so the joint index is
# i1 * s2 + i2; this matches the index ordering produced by jnp.kron below.
A0 = A[0][batch_idx]
A_big = A0.reshape(num_obs[0], s1 * s2)
# Each one-hot row selects log A[o_t, :], giving (T, s1 * s2) log-likelihoods.
log_likelihoods = obs_single @ log_stable(A_big)

# Joint column-stochastic transition and joint prior via Kronecker products
# (the trailing action axis of each B is fixed at index 0).
B1, B2 = (B[f][batch_idx, :, :, 0] for f in (0, 1))
B_big = jnp.kron(B1, B2)
D1, D2 = (D[f][batch_idx] for f in (0, 1))
prior_big = jnp.kron(D1, D2)

mll, filt_big, pred_big, smooth_big, cond_big = hmm_smoother_scan_colstoch(
    prior_big, B_big, log_likelihoods
)

# Reshape to a (T, s1, s2) grid: summing out axis 2 gives factor-1 marginals,
# summing out axis 1 gives factor-2 marginals.
filt_joint, smooth_joint = (x.reshape(T, s1, s2) for x in (filt_big, smooth_big))
filt_fac1, filt_fac2 = filt_joint.sum(axis=2), filt_joint.sum(axis=1)
smooth_fac1, smooth_fac2 = smooth_joint.sum(axis=2), smooth_joint.sum(axis=1)

# Filtered (left) vs smoothed (right) marginals, one row per factor.
fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)
panels = {
    (0, 0): filt_fac1,
    (1, 0): filt_fac2,
    (0, 1): smooth_fac1,
    (1, 1): smooth_fac2,
}
for (row, col), marginals in panels.items():
    sns.heatmap(marginals.mT, ax=axes[row, col], cbar=False, vmax=1., vmin=0., cmap='viridis')

axes[0, 0].set_title('Associative scan filtered beliefs (B-native)')
axes[0, 1].set_title('Associative scan smoothed beliefs (B-native)')
plt.show()
batch_idx = 0 T = obs[0].shape[0] s1, s2 = num_states # single-modality observations for one batch obs_single = obs[0][:, batch_idx, :] # (T, obs_dim) # collapse A over factors into a single-state emission matrix A0 = A[0][batch_idx] # (obs_dim, s1, s2) A_big = A0.reshape(num_obs[0], s1 * s2) log_likelihoods = obs_single @ log_stable(A_big) # build joint transition (column-stochastic) via Kronecker product B1 = B[0][batch_idx, :, :, 0] # (s1,s1) column-stochastic B2 = B[1][batch_idx, :, :, 0] # (s2,s2) column-stochastic B_big = jnp.kron(B1, B2) # joint prior D1 = D[0][batch_idx] D2 = D[1][batch_idx] prior_big = jnp.kron(D1, D2) mll, filt_big, pred_big, smooth_big, cond_big = hmm_smoother_scan_colstoch( prior_big, B_big, log_likelihoods ) # reshape joint beliefs to factor-wise marginals filt_joint = filt_big.reshape(T, s1, s2) smooth_joint = smooth_big.reshape(T, s1, s2) filt_fac1 = filt_joint.sum(axis=2) filt_fac2 = filt_joint.sum(axis=1) smooth_fac1 = smooth_joint.sum(axis=2) smooth_fac2 = smooth_joint.sum(axis=1) fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True) sns.heatmap(filt_fac1.mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(filt_fac2.mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smooth_fac1.mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smooth_fac2.mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') axes[0, 0].set_title('Associative scan filtered beliefs (B-native)') axes[0, 1].set_title('Associative scan smoothed beliefs (B-native)') plt.show()
No description has been provided for this image

Made with Dracula Theme for MkDocs