Sparse Array Benchmarking
In [ ]:
Copied!
%load_ext autoreload
%autoreload 2
%load_ext autoreload
%autoreload 2
In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import jax.numpy as jnp
from jax import tree_util as jtu
from jax.experimental import sparse
from pymdp.agent import Agent
import matplotlib.pyplot as plt
import seaborn as sns
import time
from pymdp.inference import smoothing_ovf
import numpy as np
import tracemalloc
import jax.numpy as jnp
from jax import tree_util as jtu
from jax.experimental import sparse
from pymdp.agent import Agent
import matplotlib.pyplot as plt
import seaborn as sns
import time
from pymdp.inference import smoothing_ovf
import numpy as np
import tracemalloc
In [ ]:
Copied!
def sizeof(x):
    """Return the number of elements stored in the dense array ``x``."""
    return np.prod(x.shape)


def sizeof_sparse(x):
    """Return the number of stored elements of the BCOO array ``x``.

    Counts the explicitly stored values (``x.data``) plus every entry of the
    coordinate index array (``x.indices``). This makes it comparable with
    :func:`sizeof` when values and indices have the same item size.
    """
    return np.prod(x.data.shape) + np.prod(x.indices.shape)


def _cyclic_shift_matrix(n):
    """Return an ``(n, n)`` transition matrix that moves state ``j`` to ``j - 1``.

    Column ``j`` is the one-hot vector of state ``j - 1``; state ``0`` wraps
    around to state ``n - 1``.
    """
    B = jnp.eye(n)
    B = B.at[:, 1:].set(B[:, :-1])  # column j <- one-hot of state j - 1
    B = B.at[:, 0].set(0)
    return B.at[-1, 0].set(1)       # state 0 wraps to the last state


def get_matrices(n_batch, num_obs, n_states):
    """Build the generative model ``(A, B, C, D, E)`` used by the benchmark.

    The model has two hidden-state factors of sizes ``n_states[0]`` and
    ``n_states[1]`` and a single observation modality with ``num_obs[0]``
    outcomes. The last outcome has non-zero probability only when factor 0
    is in its last state and factor 1 is in its first state. Each factor has
    a single action that cyclically shifts its state. Every array carries a
    leading batch dimension of size ``n_batch``.

    Args:
        n_batch: size of the leading batch dimension.
        num_obs: list whose first entry is the number of outcomes.
        n_states: list whose first two entries are the factor sizes.

    Returns:
        ``(A, B, C, D, E)``:

        - ``A``: one ``(n_batch, num_obs[0], n_states[0], n_states[1])``
          likelihood, normalised over outcomes.
        - ``B``: two ``(n_batch, n, n, 1)`` transition tensors.
        - ``C``: flat preferences.
        - ``D``: uniform priors.
        - ``E``: ``(n_batch, 1)`` ones.
    """
    # Likelihood: outer product of per-factor likelihoods, normalised over outcomes.
    A_1 = jnp.ones((num_obs[0], n_states[0]))
    A_1 = A_1.at[-1, :-1].set(0)
    A_2 = jnp.ones((num_obs[0], n_states[1]))
    A_2 = A_2.at[-1, 1:].set(0)
    A_tensor = A_1[..., None] * A_2[:, None]
    A_tensor = A_tensor / A_tensor.sum(0)
    A = [jnp.broadcast_to(A_tensor, (n_batch, *A_tensor.shape))]

    # One transition matrix per state factor, with a trailing action axis of size 1.
    B = [
        jnp.broadcast_to(_cyclic_shift_matrix(n), (n_batch, n, n))[..., None]
        for n in (n_states[0], n_states[1])
    ]

    C = [jnp.zeros((n_batch, num_obs[0]))]  # flat preferences
    D = [
        jnp.ones((n_batch, n_states[0])) / n_states[0],
        jnp.ones((n_batch, n_states[1])) / n_states[1],
    ]  # flat prior
    E = jnp.ones((n_batch, 1))
    return A, B, C, D, E
In [ ]:
Copied!
def _block_until_ready(tree):
    """Wait until every JAX array leaf of ``tree`` has finished computing."""
    for leaf in jtu.tree_leaves(tree):
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()


def profile(fun, *args):
    """Call ``fun(*args)`` once and measure its wall-clock time.

    JAX dispatches work asynchronously, so a call can return before its
    computation finishes. To make the timing cover the actual computation,
    the inputs are made ready before the clock starts, and the outputs are
    waited on before it stops. If ``fun`` is jit-compiled, the first call
    for a given input shape also includes compilation time.

    tracemalloc is deliberately not used here. It only traces Python-heap
    allocations, not XLA device buffers, and its hooks add overhead to the
    timed call.

    Args:
        fun: callable to benchmark.
        *args: positional arguments forwarded to ``fun``.

    Returns:
        ``(res, stats)``: ``res`` is ``fun(*args)`` and ``stats`` is
        ``{'time': seconds}``.
    """
    _block_until_ready(args)
    start = time.perf_counter()
    res = fun(*args)
    _block_until_ready(res)
    elapsed = time.perf_counter() - start
    return res, {'time': elapsed}


def experiment(n_states):
    """Compare dense and sparse (BCOO) transition matrices in ``smoothing_ovf``.

    Builds a single-batch agent with two hidden-state factors of sizes
    ``n_states``, runs state inference step by step over a fixed 4-step
    observation sequence, and then smooths the resulting beliefs twice:
    once with dense ``B`` and once with sparse ``B``.

    Args:
        n_states: list with the number of states of each of the two factors.

    Returns:
        ``(results, [beliefs, smoothed_dense, smoothed_sparse])``.
        ``results`` maps ``time_dense``/``time_sparse`` to seconds, and
        ``size_dense``/``size_sparse`` to the number of stored elements of
        ``B``; for the sparse version this is data entries plus index entries.
    """
    results = {}
    n_batch = 1
    num_obs = [2]
    A, B, C, D, E = get_matrices(n_batch=n_batch, num_obs=num_obs, n_states=n_states)
    num_factors = len(B)

    # Single modality: a sequence over time of one-hot observations.
    # Observation 0 is ambiguous with respect to the state factors;
    # observation 1 provides information about the exact state of both factors.
    obs = [
        jnp.broadcast_to(
            jnp.array(
                [
                    [1.0, 0.0],
                    [1.0, 0.0],
                    [1.0, 0.0],
                    [0.0, 1.0],
                ]
            )[:, None],
            (4, n_batch, num_obs[0]),
        )
    ]

    agents = Agent(
        A=A,
        B=B,
        C=C,
        D=D,
        E=E,
        pA=None,
        pB=None,
        policy_len=3,
        control_fac_idx=None,
        policies=None,
        gamma=16.0,
        alpha=16.0,
        use_utility=True,
        categorical_obs=True,
        action_selection="deterministic",
        sampling_mode="full",
        inference_algo="ovf",
        num_iter=16,
        learn_A=False,
        learn_B=False,
        batch_size=n_batch,
    )

    prior = agents.D
    qs_hist = None
    action_hist = []
    for t in range(obs[0].shape[0]):
        # Observations up to and including time t, as (batch, time, ...).
        obs_so_far = jtu.tree_map(lambda x: jnp.moveaxis(x[:t + 1], 0, 1), obs)
        beliefs = agents.infer_states(obs_so_far, past_actions=None, empirical_prior=prior, qs_hist=qs_hist)
        # B has a single action per factor, so the first policy's first action is used.
        actions = jnp.broadcast_to(agents.policies[0, 0], (n_batch, num_factors))
        prior = agents.update_empirical_prior(actions, beliefs)
        qs_hist = beliefs
        action_hist.append(actions)

    # Prepend the initial prior D as the belief at t = 0.
    beliefs = jtu.tree_map(lambda x, y: jnp.concatenate([x[:, None], y], 1), agents.D, beliefs)

    def take_first(pytree):
        """Select batch element 0 from every leaf of ``pytree``."""
        return jtu.tree_map(lambda leaf: leaf[0], pytree)

    beliefs_single = take_first(beliefs)
    actions_single = jnp.stack(action_hist, 1)[0]
    dense_B_single = take_first(agents.B)
    sparse_B_single = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b), dense_B_single)

    # ====== Dense implementation
    smoothed_beliefs_dense, run_stats = profile(
        smoothing_ovf, beliefs_single, dense_B_single, actions_single
    )
    results.update({k + '_dense': v for k, v in run_stats.items()})
    results["size_dense"] = sum(sizeof(b) for b in dense_B_single)

    # ====== Sparse implementation
    smoothed_beliefs_sparse, run_stats = profile(
        smoothing_ovf, beliefs_single, sparse_B_single, actions_single
    )
    results.update({k + '_sparse': v for k, v in run_stats.items()})
    results["size_sparse"] = sum(sizeof_sparse(b) for b in sparse_B_single)

    return results, [beliefs_single, smoothed_beliefs_dense, smoothed_beliefs_sparse]
Running the experiment and visualizing the results
In [1]:
Copied!
# Run a tiny instance once and compare filtered beliefs with the dense and
# sparse smoothed beliefs. Rows are the two state factors; each heatmap has
# states on the y axis and time steps on the x axis (hence the ``.mT``).
res, (beliefs, smoothed_dense, smoothed_sparse) = experiment([2, 3])
fig, axes = plt.subplots(2, 3, figsize=(8, 4), sharex=True)
heatmap_kwargs = dict(cbar=False, vmax=1., vmin=0., cmap='viridis')
for f in range(2):
    sns.heatmap(beliefs[f].mT, ax=axes[f, 0], **heatmap_kwargs)
    # NOTE(review): element [0] of the smoothing output is assumed to hold the
    # per-factor smoothed beliefs — confirm against smoothing_ovf.
    sns.heatmap(smoothed_dense[0][f].mT, ax=axes[f, 1], **heatmap_kwargs)
    sns.heatmap(smoothed_sparse[0][f].mT, ax=axes[f, 2], **heatmap_kwargs)
    axes[f, 0].set_ylabel(f'factor {f} state')
for a in axes[-1]:
    a.set_xlabel('time step')
axes[0, 0].set_title('Filtered beliefs')
axes[0, 1].set_title('smoothed beliefs dense')
axes[0, 2].set_title('smoothed beliefs sparse')
plt.show()
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_3558/2542277956.py:35: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agents = Agent(
Out[1]:
Text(0.5, 1.0, 'smoothed beliefs sparse')
Benchmarking runtime and memory footprint
In [2]:
Copied!
# Grow both state factors linearly with the step index. Each B is a
# permutation matrix, so dense storage grows quadratically with the number
# of states while sparse storage grows linearly. Every step uses new shapes,
# so if smoothing_ovf is jit-compiled, each timing includes compilation.
n_steps = 10
res = []
for i in range(1, n_steps):
    print(f"Step {i}")
    num_states = [100 * i, 300 * i]
    print('\t', num_states)
    results, _ = experiment(num_states)
    res.append(results)
    print('\t', res[-1])
Step 1 [100, 300]
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_3558/2542277956.py:35: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agents = Agent(
{'time_dense': 0.4508378505706787, 'size_dense': 100000, 'time_sparse': 0.5985698699951172, 'size_sparse': 1600}
Step 2
[200, 600]
{'time_dense': 0.5454308986663818, 'size_dense': 400000, 'time_sparse': 0.5894217491149902, 'size_sparse': 3200}
Step 3
[300, 900]
{'time_dense': 0.36173391342163086, 'size_dense': 900000, 'time_sparse': 0.5830059051513672, 'size_sparse': 4800}
Step 4
[400, 1200]
{'time_dense': 0.44315290451049805, 'size_dense': 1600000, 'time_sparse': 0.7388842105865479, 'size_sparse': 6400}
Step 5
[500, 1500]
{'time_dense': 0.418179988861084, 'size_dense': 2500000, 'time_sparse': 0.5743489265441895, 'size_sparse': 8000}
Step 6
[600, 1800]
{'time_dense': 0.3202641010284424, 'size_dense': 3600000, 'time_sparse': 0.5391190052032471, 'size_sparse': 9600}
Step 7
[700, 2100]
{'time_dense': 0.44112205505371094, 'size_dense': 4900000, 'time_sparse': 0.6291599273681641, 'size_sparse': 11200}
Step 8
[800, 2400]
{'time_dense': 0.44248127937316895, 'size_dense': 6400000, 'time_sparse': 0.6097638607025146, 'size_sparse': 12800}
Step 9
[900, 2700]
{'time_dense': 0.3552210330963135, 'size_dense': 8100000, 'time_sparse': 0.639538049697876, 'size_sparse': 14400}
In [3]:
Copied!
# Metric names shared by the dense and sparse entries (e.g. 'size', 'time'),
# sorted so the subplot order is deterministic.
keys = sorted({k.replace("_dense", "").replace("_sparse", "") for k in res[0]})
n_plots = len(keys)
# squeeze=False keeps a 2-D axes array even when there is a single metric.
fig, ax = plt.subplots(n_plots, 1, figsize=(6, 3 * n_plots), squeeze=False)
# Tick labels must mirror the state sizes used by the benchmark loop ([100*i, 300*i]).
tick_labels = [f"[{100 * i}, {300 * i}]" for i in range(1, len(res) + 1)]
for k, a in zip(keys, ax.flatten()):
    dense_vals = [r[k + "_dense"] for r in res]
    sparse_vals = [r[k + "_sparse"] for r in res]
    pretty = k.replace('_', ' ').capitalize()
    a.plot(dense_vals, label=f"{pretty} dense")
    a.plot(sparse_vals, label=f"{pretty} sparse")
    a.set_xticks(list(range(len(res))))
    a.set_xticklabels(tick_labels, rotation=45)
    a.set_ylim([0, max(dense_vals + sparse_vals) * 1.05])
    a.legend()
plt.tight_layout()
plt.show()