pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

Sparse Array Benchmarking¶

In [ ]:
Copied!
%load_ext autoreload
%autoreload 2
%load_ext autoreload %autoreload 2
In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import jax.numpy as jnp
from jax import tree_util as jtu
from jax.experimental import sparse
from pymdp.agent import Agent
import matplotlib.pyplot as plt
import seaborn as sns
import time

from pymdp.inference import smoothing_ovf
import numpy as np

import tracemalloc
import jax.numpy as jnp from jax import tree_util as jtu from jax.experimental import sparse from pymdp.agent import Agent import matplotlib.pyplot as plt import seaborn as sns import time from pymdp.inference import smoothing_ovf import numpy as np import tracemalloc
In [ ]:
Copied!
def sizeof(x):
    """Return the number of entries of ``x``, i.e. the product of ``x.shape``."""
    dims = x.shape
    return np.prod(dims)


def sizeof_sparse(x):
    """Return the number of stored entries of a sparse array ``x``.

    The count is the number of elements in ``x.data`` plus the number of
    elements in ``x.indices`` (each taken as the product of its shape).
    """
    stored_parts = (x.data, x.indices)
    return sum(np.prod(part.shape) for part in stored_parts)


def get_matrices(n_batch, num_obs, n_states):
    """Build the batched model arrays ``A, B, C, D, E`` for a two-factor model.

    Args:
        n_batch: Size of the leading batch dimension given to every array.
        num_obs: Sequence whose first entry is the number of outcomes of the
            single observation modality.
        n_states: Sequence whose first two entries are the number of states
            of the two hidden-state factors.

    Returns:
        A tuple ``(A, B, C, D, E)``:

        * ``A``: one-element list, shape
          ``(n_batch, num_obs[0], n_states[0], n_states[1])``, normalised over
          the outcome axis.
        * ``B``: two-element list, shapes
          ``(n_batch, n_states[f], n_states[f], 1)``.
        * ``C``: one-element list of zeros, shape ``(n_batch, num_obs[0])``.
        * ``D``: two-element list of uniform vectors, shapes
          ``(n_batch, n_states[f])``.
        * ``E``: array of ones, shape ``(n_batch, 1)``.
    """
    n_outcomes = num_obs[0]
    n_first = n_states[0]
    n_second = n_states[1]

    # Per-factor likelihood pieces: the last outcome is only possible in the
    # last state of the first factor and in the first state of the second.
    lik_first = jnp.ones((n_outcomes, n_first)).at[-1, :-1].set(0)
    lik_second = jnp.ones((n_outcomes, n_second)).at[-1, 1:].set(0)

    # Outer product over the two state axes, then normalise over outcomes.
    joint = lik_first[:, :, None] * lik_second[:, None, :]
    joint = joint / joint.sum(0)
    A = [jnp.broadcast_to(joint, (n_batch,) + joint.shape)]

    def shift_transition(n):
        # Identity with its columns rolled right by one (wrapping around):
        # entry [i, (i + 1) % n] is 1, so the last row has its 1 in column 0.
        shifted = jnp.roll(jnp.eye(n), 1, axis=1)
        batched = jnp.broadcast_to(shifted, (n_batch, n, n))
        # Trailing singleton axis: a single action per factor.
        return batched[..., None]

    # one transition tensor per state factor
    B = [shift_transition(n_first), shift_transition(n_second)]

    C = [jnp.zeros((n_batch, n_outcomes))]  # flat preferences
    D = [jnp.ones((n_batch, n)) / n for n in (n_first, n_second)]  # flat prior
    E = jnp.ones((n_batch, 1))

    return A, B, C, D, E
def sizeof(x): return np.prod(x.shape) def sizeof_sparse(x): return np.prod(x.data.shape) + np.prod(x.indices.shape) def get_matrices(n_batch, num_obs, n_states): A_1 = jnp.ones((num_obs[0], n_states[0])) A_1 = A_1.at[-1, :-1].set(0) A_2 = jnp.ones((num_obs[0], n_states[1])) A_2 = A_2.at[-1, 1:].set(0) A_tensor = A_1[..., None] * A_2[:, None] A_tensor /= A_tensor.sum(0) A = [jnp.broadcast_to(A_tensor, (n_batch, *A_tensor.shape))] # create two transition matrices, one for each state factor B_1 = jnp.eye(n_states[0]) B_1 = B_1.at[:, 1:].set(B_1[:, :-1]) B_1 = B_1.at[:, 0].set(0) B_1 = B_1.at[-1, 0].set(1) B_1 = jnp.broadcast_to(B_1, (n_batch, n_states[0], n_states[0])) B_2 = jnp.eye(n_states[1]) B_2 = B_2.at[:, 1:].set(B_2[:, :-1]) B_2 = B_2.at[:, 0].set(0) B_2 = B_2.at[-1, 0].set(1) B_2 = jnp.broadcast_to(B_2, (n_batch, n_states[1], n_states[1])) B = [B_1[..., None], B_2[..., None]] C = [jnp.zeros((n_batch, num_obs[0]))] # flat preferences D = [jnp.ones((n_batch, n_states[0])) / n_states[0], jnp.ones((n_batch, n_states[1])) / n_states[1]] # flat prior E = jnp.ones((n_batch, 1)) return A, B, C, D, E
In [ ]:
Copied!
def profile(fun, *args):
    """Call ``fun(*args)`` and measure how long it takes.

    Args:
        fun: Callable to time.
        *args: Positional arguments forwarded to ``fun``.

    Returns:
        A tuple ``(res, stats)`` where ``res`` is the value returned by
        ``fun(*args)`` and ``stats`` is a dict with a single key ``'time'``
        holding the elapsed wall-clock time in seconds.
    """
    # Function-scope import: only ``jax.numpy`` and ``jax.tree_util`` are
    # imported at file level, not the top-level ``jax`` name.
    import jax

    # No tracemalloc here: the traced size/peak were never reported, and
    # tracing was left switched on after the call, which slows down all later
    # Python allocations (including the next timed run).

    # perf_counter is monotonic and high-resolution, unlike time.time().
    bt = time.perf_counter()
    # JAX dispatches work asynchronously; wait for the result to actually be
    # computed before stopping the clock. Non-array leaves pass through as-is.
    res = jax.block_until_ready(fun(*args))
    et = time.perf_counter()

    stats = {'time': et - bt}
    return res, stats

def experiment(n_states):
    """Filter a short observation sequence, then time dense vs. sparse smoothing.

    The model comes from ``get_matrices`` with one batch element and a single
    two-outcome modality. An ``Agent`` filters a fixed four-step observation
    sequence, after which ``smoothing_ovf`` is run twice through ``profile``:
    once with the dense transition tensors and once with the same tensors
    converted to ``sparse.BCOO``.

    Args:
        n_states: Number of states of each hidden-state factor, forwarded to
            ``get_matrices``.

    Returns:
        A tuple ``(results, beliefs)``. ``results`` is a dict holding every
        statistic returned by ``profile`` under ``<name>_dense`` and
        ``<name>_sparse``, plus ``size_dense`` and ``size_sparse`` (entry
        counts of the transition tensors handed to each run). ``beliefs`` is
        the list ``[beliefs_single, smoothed_beliefs_dense,
        smoothed_beliefs_sparse]``.
    """
    results = {}

    n_batch = 1
    num_obs = [2]

    A, B, C, D, E = get_matrices(n_batch=n_batch, num_obs=num_obs, n_states=n_states)

    # for the single modality, a sequence over time of observations (one hot vectors)
    obs = [
        jnp.broadcast_to(
            jnp.array(
                [
                    [1.0, 0.0],  # observation 0 is ambiguous with respect to the state factors
                    [1.0, 0.0],  # observation 0 is ambiguous with respect to the state factors
                    [1.0, 0.0],  # observation 0 is ambiguous with respect to the state factors
                    [0.0, 1.0],  # observation 1 provides information about exact state of both factors
                ]
            )[:, None],
            (4, n_batch, num_obs[0]),
        )
    ]

    agents = Agent(
        A=A,
        B=B,
        C=C,
        D=D,
        E=E,
        pA=None,
        pB=None,
        policy_len=3,
        control_fac_idx=None,
        policies=None,
        gamma=16.0,
        alpha=16.0,
        use_utility=True,
        categorical_obs=True,
        action_selection="deterministic",
        sampling_mode="full",
        inference_algo="ovf",
        num_iter=16,
        learn_A=False,
        learn_B=False,
        batch_size=n_batch,  # keep in sync with the arrays built above
    )

    # Filtering pass: at step t the agent sees observations 0..t.
    prior = agents.D
    qs_hist = None
    action_hist = []
    for t in range(len(obs[0])):
        # Move the time axis behind the batch axis.
        first_obs = jtu.tree_map(lambda x: jnp.moveaxis(x[:t+1], 0, 1), obs)
        beliefs = agents.infer_states(first_obs, past_actions=None, empirical_prior=prior, qs_hist=qs_hist)
        # Always apply the first action of the first policy to both factors.
        actions = jnp.broadcast_to(agents.policies[0, 0], (n_batch, 2))
        prior = agents.update_empirical_prior(actions, beliefs)
        qs_hist = beliefs
        action_hist.append(actions)

    # Prepend the initial prior D along axis 1 of the belief history.
    beliefs = jtu.tree_map(lambda x, y: jnp.concatenate([x[:, None], y], 1), agents.D, beliefs)

    # Drop the batch axis: smoothing is run on the first batch element only.
    take_first = lambda pytree: jtu.tree_map(lambda leaf: leaf[0], pytree)
    beliefs_single = take_first(beliefs)
    actions_single = jnp.stack(action_hist, 1)[0]

    # Both runs get the same un-batched transition tensors, dense and sparse.
    dense_B_single = take_first(agents.B)
    sparse_B_single = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b[0]), agents.B)

    # ======
    # Dense implementation
    smoothed_beliefs_dense, run_stats = profile(
        smoothing_ovf, *(beliefs_single, dense_B_single, actions_single)
    )
    results.update({k+'_dense': v for k, v in run_stats.items()})
    results["size_dense"] = sum([sizeof(sB) for sB in dense_B_single])
    # ======

    # ======
    # Sparse implementation
    smoothed_beliefs_sparse, run_stats = profile(
        smoothing_ovf, *(beliefs_single, sparse_B_single, actions_single)
    )
    results.update({k+'_sparse': v for k, v in run_stats.items()})
    results["size_sparse"] = sum([sizeof_sparse(sB) for sB in sparse_B_single])
    # ======

    return results, [beliefs_single, smoothed_beliefs_dense, smoothed_beliefs_sparse]
def profile(fun, *args): tracemalloc.start() tracemalloc.reset_peak() bt = time.time() res = fun(*args) et = time.time() size, peak = tracemalloc.get_traced_memory() stats = {'time': et - bt} return res, stats def experiment(n_states): results = {} n_batch = 1 num_obs = [2] A, B, C, D, E = get_matrices(n_batch=n_batch, num_obs=num_obs, n_states=n_states) # for the single modality, a sequence over time of observations (one hot vectors) obs = [ jnp.broadcast_to( jnp.array( [ [1.0, 0.0], # observation 0 is ambiguous with respect state factors [1.0, 0], # observation 0 is ambiguous with respect state factors [1.0, 0], # observation 0 is ambiguous with respect state factors [0.0, 1.0], ] )[:, None], (4, n_batch, num_obs[0]), ) ] # observation 1 provides information about exact state of both factors agents = Agent( A=A, B=B, C=C, D=D, E=E, pA=None, pB=None, policy_len=3, control_fac_idx=None, policies=None, gamma=16.0, alpha=16.0, use_utility=True, categorical_obs=True, action_selection="deterministic", sampling_mode="full", inference_algo="ovf", num_iter=16, learn_A=False, learn_B=False, batch_size=1, ) jtu.tree_map(lambda b: sparse.BCOO.fromdense(b, n_batch=n_batch), agents.B) prior = agents.D qs_hist = None action_hist = [] for t in range(len(obs[0])): first_obs = jtu.tree_map(lambda x: jnp.moveaxis(x[:t+1], 0, 1), obs) beliefs = agents.infer_states(first_obs, past_actions=None, empirical_prior=prior, qs_hist=qs_hist) actions = jnp.broadcast_to(agents.policies[0, 0], (n_batch, 2)) prior = agents.update_empirical_prior(actions, beliefs) qs_hist = beliefs action_hist.append(actions) beliefs = jtu.tree_map(lambda x, y: jnp.concatenate([x[:, None], y], 1), agents.D, beliefs) take_first = lambda pytree: jtu.tree_map(lambda leaf: leaf[0], pytree) beliefs_single = take_first(beliefs) # ====== # Dense implementation smoothed_beliefs_dense, run_stats = profile( smoothing_ovf, *(beliefs_single, take_first(agents.B), jnp.stack(action_hist, 1)[0]) ) results.update({k+'_dense': v 
for k, v in run_stats.items()}) results["size_dense"] = sum([sizeof(sB) for sB in agents.B]) # ====== sparse_B_single = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b[0]), agents.B) actions_single = jnp.stack(action_hist, 1)[0] # ====== # Sparse implementation smoothed_beliefs_sparse, run_stats = profile( smoothing_ovf, *(beliefs_single, sparse_B_single, actions_single) ) results.update({k+'_sparse': v for k, v in run_stats.items()}) results["size_sparse"] = sum([sizeof_sparse(sB) for sB in sparse_B_single]) # ====== return results, [beliefs_single, smoothed_beliefs_dense, smoothed_beliefs_sparse]

Running the experiment and visualizing the results¶

In [1]:
Copied!
res, (beliefs, smoothed_dense, smoothed_sparse) = experiment([2, 3])

fig, axes = plt.subplots(2, 3, figsize=(8, 4), sharex=True)

# One column per belief set, one row per entry (index 0 and 1) of that set.
# Every panel shares the same colour scale and is drawn transposed (.mT).
heatmap_style = dict(cbar=False, vmax=1., vmin=0., cmap='viridis')
belief_columns = (beliefs, smoothed_dense[0], smoothed_sparse[0])
for col, belief_set in enumerate(belief_columns):
    for row in range(2):
        sns.heatmap(belief_set[row].mT, ax=axes[row, col], **heatmap_style)

axes[0, 0].set_title('Filtered beliefs')
axes[0, 1].set_title('smoothed beliefs dense')
axes[0, 2].set_title('smoothed beliefs sparse')
res, (beliefs, smoothed_dense, smoothed_sparse) = experiment([2, 3]) fig, axes = plt.subplots(2, 3, figsize=(8, 4), sharex=True) sns.heatmap(beliefs[0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(beliefs[1].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_dense[0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_dense[0][1].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_sparse[0][0].mT, ax=axes[0, 2], cbar=False, vmax=1., vmin=0., cmap='viridis') sns.heatmap(smoothed_sparse[0][1].mT, ax=axes[1, 2], cbar=False, vmax=1., vmin=0., cmap='viridis') axes[0, 0].set_title('Filtered beliefs') axes[0, 1].set_title('smoothed beliefs dense') axes[0, 2].set_title('smoothed beliefs sparse')
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_3558/2542277956.py:35: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agents = Agent(
Out[1]:
Text(0.5, 1.0, 'smoothed beliefs sparse')
No description has been provided for this image

Benchmarking runtime and memory performance¶

In [2]:
Copied!
n_steps = 10

# Run the experiment for growing state spaces; range(1, n_steps) yields
# n_steps - 1 model sizes. The per-run statistics are collected in ``res``.
res = []
for i in range(1, n_steps):
    num_states = [100 * i, 300 * i]
    print(f"Step {i}")
    print('\t', num_states)
    results, bel = experiment(num_states)
    res.append(results)
    print('\t', results)
n_steps = 10 res = [] for i in range(1, n_steps): print(f"Step {i}") num_states = [100 * i, 300 * i] print('\t', num_states) results, bel = experiment(num_states) res += [results] print('\t', res[-1])
Step 1
	 [100, 300]
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_3558/2542277956.py:35: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agents = Agent(
	 {'time_dense': 0.4508378505706787, 'size_dense': 100000, 'time_sparse': 0.5985698699951172, 'size_sparse': 1600}
Step 2
	 [200, 600]
	 {'time_dense': 0.5454308986663818, 'size_dense': 400000, 'time_sparse': 0.5894217491149902, 'size_sparse': 3200}
Step 3
	 [300, 900]
	 {'time_dense': 0.36173391342163086, 'size_dense': 900000, 'time_sparse': 0.5830059051513672, 'size_sparse': 4800}
Step 4
	 [400, 1200]
	 {'time_dense': 0.44315290451049805, 'size_dense': 1600000, 'time_sparse': 0.7388842105865479, 'size_sparse': 6400}
Step 5
	 [500, 1500]
	 {'time_dense': 0.418179988861084, 'size_dense': 2500000, 'time_sparse': 0.5743489265441895, 'size_sparse': 8000}
Step 6
	 [600, 1800]
	 {'time_dense': 0.3202641010284424, 'size_dense': 3600000, 'time_sparse': 0.5391190052032471, 'size_sparse': 9600}
Step 7
	 [700, 2100]
	 {'time_dense': 0.44112205505371094, 'size_dense': 4900000, 'time_sparse': 0.6291599273681641, 'size_sparse': 11200}
Step 8
	 [800, 2400]
	 {'time_dense': 0.44248127937316895, 'size_dense': 6400000, 'time_sparse': 0.6097638607025146, 'size_sparse': 12800}
Step 9
	 [900, 2700]
	 {'time_dense': 0.3552210330963135, 'size_dense': 8100000, 'time_sparse': 0.639538049697876, 'size_sparse': 14400}
In [3]:
Copied!
# Metric names shared by the dense and sparse runs (e.g. "time", "size").
# dict.fromkeys removes duplicates but keeps first-seen order, so the subplot
# order is the same on every run (iterating a set of strings is not).
keys = list(dict.fromkeys(r.replace("_dense", "").replace("_sparse", "") for r in res[0]))
n_plots = len(keys)

# squeeze=False always returns a 2-D array of Axes, so the loop below also
# works when there is only a single metric.
fig, ax = plt.subplots(n_plots, 1, figsize=(6, 3 * n_plots), squeeze=False)
for k, a in zip(keys, ax.flatten()):
    dense_vals = [r[k + "_dense"] for r in res]
    sparse_vals = [r[k + "_sparse"] for r in res]
    label = k.replace('_', ' ').capitalize()
    a.plot(dense_vals, label=f"{label} dense")
    a.plot(sparse_vals, label=f"{label} sparse")
    a.set_xticks(list(range(len(res))))
    # The benchmark loop ran num_states = [100 * i, 300 * i] for i = 1, 2, ...
    a.set_xticklabels([f"[{100*i}, {300*i}]" for i in range(1, len(res) + 1)], rotation=45)
    a.set_ylim([0, max(dense_vals + sparse_vals) * 1.05])
    a.legend()

plt.tight_layout()
plt.show()
keys = list(set(r.replace("_dense", "").replace("_sparse", "") for r in res[0].keys())) n_plots = len(keys) fig, ax = plt.subplots(n_plots, 1, figsize=(6, 3 * n_plots)) for i, a in enumerate(ax.flatten()): k = keys[i] a.plot([r[k + "_dense"] for r in res], label=f"{k.replace('_', ' ').capitalize()} dense") a.plot([r[k + "_sparse"] for r in res], label=f"{k} sparse") a.set_xticks(list(range(0, len(res)))) a.set_xticklabels([f"[{1000*i}, {3000*i}]" for i in range(1, n_steps)], rotation=45) m = max([r[k + "_dense"] for r in res] + [r[k + "_sparse"] for r in res]) * 1.05 a.set_ylim([0, m]) plt.tight_layout() [a.legend() for a in ax.flatten()] plt.show()
No description has been provided for this image

Made with Dracula Theme for MkDocs