pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

Imports¶

In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [1]:
Copied!
%load_ext autoreload
%autoreload 2

from jax import nn, vmap, lax, jit
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu
from jax.scipy.special import gammaln, digamma

import numpy as np

from pymdp.envs import GridWorld, rollout
from pymdp.agent import Agent
from pymdp.maths import dirichlet_expected_value

import matplotlib.pyplot as plt
import seaborn as sns
%load_ext autoreload %autoreload 2 from jax import nn, vmap, lax, jit from jax import numpy as jnp, random as jr from jax import tree_util as jtu from jax.scipy.special import gammaln, digamma import numpy as np from pymdp.envs import GridWorld, rollout from pymdp.agent import Agent from pymdp.maths import dirichlet_expected_value import matplotlib.pyplot as plt import seaborn as sns
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Grid world generative model¶

Here we will explore learning of the generative model inside a simple grid 7x7 world environment, where at each state agent can move into 4 possible directions. The agent can explore the environment for 100 time steps, after which it is reaturned to the original position. We will start first with an example where likelihood is fixed, and state transitions are unkown. Next we will explore the example where likelihood is unknown but the state transitions are known, and finally we will look at learning under joint uncertainty over likelihood and transitions (we would expect this case not to work in general with flat priors on both components).

In [ ]:
Copied!
# size of the grid world
grid_shape = (7, 7)

# number of agents
batch_size = 20

env = GridWorld(shape=grid_shape, initial_position=(3,3), include_stay=False)
# size of the grid world grid_shape = (7, 7) # number of agents batch_size = 20 env = GridWorld(shape=grid_shape, initial_position=(3,3), include_stay=False)

Define KL divergence between Dirichlet distributions¶

In [ ]:
Copied!
@jit
def kl_div_dirichlet(alpha1, alpha2):
    alpha0 = alpha1.sum(1, keepdims=True)
    kl = gammaln(alpha0.squeeze(1)) - gammaln(alpha2.sum(1))
    kl += jnp.sum(gammaln(alpha2) - gammaln(alpha1) + (alpha1 - alpha2) * (digamma(alpha1) - digamma(alpha0)), 1)

    return kl
@jit def kl_div_dirichlet(alpha1, alpha2): alpha0 = alpha1.sum(1, keepdims=True) kl = gammaln(alpha0.squeeze(1)) - gammaln(alpha2.sum(1)) kl += jnp.sum(gammaln(alpha2) - gammaln(alpha1) + (alpha1 - alpha2) * (digamma(alpha1) - digamma(alpha0)), 1) return kl

Initialize different sets of agents using Agent() class, use a different alpha (action selection temperature) for each set¶

Here, the agents have to learn B matrix, with A matrix being fixed to the true observation mapping of generative process, and flat prior oevr initial hidden states¶

In [ ]:
Copied!
# create agent with A matrix being fixed to the A of the generative process
num_obs = [a.shape[0] for a in env.A]
num_states = [b.shape[0] for b in env.B]

_A = [jnp.array(a) for a in env.A]
C = [jnp.zeros(num_obs[0])]
pB = [jnp.ones_like(env.B[0]) / num_states[0]]
_B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB)
_D = [jnp.ones(num_states[0])/num_states[0]]  # flat prior over initial states

agents = []
for i in range(5):
    agents.append( 
        Agent(
            _A,
            _B,
            C,
            _D,
            E=None,
            pA=None,
            pB=pB,
            policy_len=3,
            use_utility=False,
            use_states_info_gain=True,
            use_param_info_gain=True,
            gamma=jnp.ones(batch_size),
            alpha=jnp.ones(batch_size) * i * .2,
            categorical_obs=False,
            action_selection="stochastic",
            inference_algo="exact",
            num_iter=1,
            learn_A=False,
            learn_B=True,
            learn_D=False,
            batch_size=batch_size,
            learning_mode="offline",
        )
    )
# create agent with A matrix being fixed to the A of the generative process num_obs = [a.shape[0] for a in env.A] num_states = [b.shape[0] for b in env.B] _A = [jnp.array(a) for a in env.A] C = [jnp.zeros(num_obs[0])] pB = [jnp.ones_like(env.B[0]) / num_states[0]] _B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB) _D = [jnp.ones(num_states[0])/num_states[0]] # flat prior over initial states agents = [] for i in range(5): agents.append( Agent( _A, _B, C, _D, E=None, pA=None, pB=pB, policy_len=3, use_utility=False, use_states_info_gain=True, use_param_info_gain=True, gamma=jnp.ones(batch_size), alpha=jnp.ones(batch_size) * i * .2, categorical_obs=False, action_selection="stochastic", inference_algo="exact", num_iter=1, learn_A=False, learn_B=True, learn_D=False, batch_size=batch_size, learning_mode="offline", ) )

Run multiple blocks of active inference on grid world for each batch of agents that was initialized different alpha values.¶

Since the agents are using offline learning (learning_mode = "offline"), this means that parameter updates to the B tensors are performed at the end of each block.¶

In [ ]:
Copied!
pB_ground_truth = 1e4 * env.B[0] + 1e-4
num_timesteps = 50
num_blocks = 40
key = jr.PRNGKey(0)
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))
divs = {i : [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        init_obs, init_states = vmap(env.reset)(block_keys[:-1])
        last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states})
        agents[i] = last['agent']
        divs[i].append(kl_div_dirichlet(last['agent'].pB[0], pB_ground_truth).sum(-1).mean(-1))
pB_ground_truth = 1e4 * env.B[0] + 1e-4 num_timesteps = 50 num_blocks = 40 key = jr.PRNGKey(0) block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1)) divs = {i : [] for i in range(len(agents))} for block in range(num_blocks): block_keys = block_and_batch_keys[block] for i, agent in enumerate(agents): init_obs, init_states = vmap(env.reset)(block_keys[:-1]) last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states}) agents[i] = last['agent'] divs[i].append(kl_div_dirichlet(last['agent'].pB[0], pB_ground_truth).sum(-1).mean(-1))

Plot the KL divergence between the true parameters and believed parameters over time for the different groups of agents (agents with different levels of action stochasticity)¶

In [2]:
Copied!
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True)
for i in range(len(agents)):
    p = axes.plot(jnp.stack(divs[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes.plot(jnp.stack(divs[i]), color=p[0].get_color(), alpha=.2)

axes.legend(title='alpha')
axes.set_ylabel('KL divergence')
axes.set_xlabel('epoch')
fig.tight_layout()
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True) for i in range(len(agents)): p = axes.plot(jnp.stack(divs[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes.plot(jnp.stack(divs[i]), color=p[0].get_color(), alpha=.2) axes.legend(title='alpha') axes.set_ylabel('KL divergence') axes.set_xlabel('epoch') fig.tight_layout()
No description has been provided for this image

Visualize the learned B tensor alongside the true environmental parameters after training¶

In [3]:
Copied!
num_actions = env.B[0].shape[-1]
base_labels = ["Up", "Right", "Down", "Left", "Stay"]
action_labels = base_labels[:num_actions]

row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')


fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True)


for i in range(num_actions):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(action_labels[i])

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center')


fig.tight_layout()
num_actions = env.B[0].shape[-1] base_labels = ["Up", "Right", "Down", "Left", "Stay"] action_labels = base_labels[:num_actions] row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True) for i in range(num_actions): for j, agent in enumerate(agents): sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(action_labels[i]) for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Here, the agents have to learn A matrix, with B matrix being fixed to the true observation mapping of generative process, and precise (and accurate) prior over initial hidden states¶

In [ ]:
Copied!
C = [jnp.zeros((batch_size, num_obs[0]))]
pA = [jnp.ones_like(a) / a.shape[0] for a in env.A]
_A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA)
B = [jnp.array(b) for b in env.B]
D = [jnp.array(d) for d in env.D]  

agents = []
for i in range(5):
    agents.append( 
        Agent(
            _A,
            B,
            C,
            D,
            E=None,
            pA=pA,
            pB=None,
            policy_len=3,
            use_utility=False,
            use_states_info_gain=True,
            use_param_info_gain=True,
            gamma=jnp.ones(batch_size),
            alpha=jnp.ones(batch_size) * i * .2,
            categorical_obs=False,
            action_selection="stochastic",
            inference_algo="exact",
            num_iter=1,
            learn_A=True,
            learn_B=False,
            learn_D=False,
            batch_size=batch_size,
            learning_mode="offline",
        )
    )
C = [jnp.zeros((batch_size, num_obs[0]))] pA = [jnp.ones_like(a) / a.shape[0] for a in env.A] _A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA) B = [jnp.array(b) for b in env.B] D = [jnp.array(d) for d in env.D] agents = [] for i in range(5): agents.append( Agent( _A, B, C, D, E=None, pA=pA, pB=None, policy_len=3, use_utility=False, use_states_info_gain=True, use_param_info_gain=True, gamma=jnp.ones(batch_size), alpha=jnp.ones(batch_size) * i * .2, categorical_obs=False, action_selection="stochastic", inference_algo="exact", num_iter=1, learn_A=True, learn_B=False, learn_D=False, batch_size=batch_size, learning_mode="offline", ) )

Run multiple blocks of active inference on grid world for each batch of agents that was initialized different alpha values.¶

Since the agents are using offline learning (learning_mode = "offline"), this means that parameter updates to the A matrix are performed at the end of each block.¶

In [ ]:
Copied!
pA_ground_truth = 1e4 * env.A[0] + 1e-4
num_timesteps = 50
num_blocks = 20
key = jr.PRNGKey(0)
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))

divs = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        init_obs, init_states = vmap(env.reset)(block_keys[:-1])
        last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states})
        agents[i] = last['agent']
        divs[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))
pA_ground_truth = 1e4 * env.A[0] + 1e-4 num_timesteps = 50 num_blocks = 20 key = jr.PRNGKey(0) block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1)) divs = {i: [] for i in range(len(agents))} for block in range(num_blocks): block_keys = block_and_batch_keys[block] for i, agent in enumerate(agents): init_obs, init_states = vmap(env.reset)(block_keys[:-1]) last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states}) agents[i] = last['agent'] divs[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))

Plot the KL divergence between the true parameters and believed parameters over time for the different groups of agents (agents with different levels of action stochasticity)¶

In [4]:
Copied!
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True)
for i in range(len(agents)):
    p = axes.plot(jnp.stack(divs[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes.plot(jnp.stack(divs[i]), color=p[0].get_color(), alpha=.2)

axes.legend(title='alpha')
axes.set_ylabel('KL divergence')
axes.set_xlabel('epoch')
fig.tight_layout()
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True) for i in range(len(agents)): p = axes.plot(jnp.stack(divs[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes.plot(jnp.stack(divs[i]), color=p[0].get_color(), alpha=.2) axes.legend(title='alpha') axes.set_ylabel('KL divergence') axes.set_xlabel('epoch') fig.tight_layout()
No description has been provided for this image

Visualize the learned A matrices alongside the true environmental parameters after training¶

In [5]:
Copied!
n_batches_to_show = min(5, agents[0].A[0].shape[0])
num_rows = len(agents) + 1
row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')

fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True)

for i in range(n_batches_to_show):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(f'batch={i + 1}')

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center')

fig.tight_layout()
n_batches_to_show = min(5, agents[0].A[0].shape[0]) num_rows = len(agents) + 1 row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True) for i in range(n_batches_to_show): for j, agent in enumerate(agents): sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(f'batch={i + 1}') for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Here, once again the agents have to learn the A matrix, with B matrix being fixed to the true B matrix of generative process. This time, agents have a flat prior over initial hidden states¶

In [ ]:
Copied!
# create agent with B matrix being fixed to the B of the generative process, and precise initial beliefs about hidden states
C = [jnp.zeros(num_obs[0])]
pA = [jnp.ones_like(a) / a.shape[0] for a in env.A]
_A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA)
B = [jnp.array(b) for b in env.B]
_D = [jnp.ones(num_states[0])/num_states[0]]  # flat prior over initial states

agents = []
for i in range(5):
    agents.append( 
        Agent(
            _A,
            B,
            C,
            _D,
            E=None,
            pA=pA,
            pB=None,
            policy_len=3,
            use_utility=False,
            use_states_info_gain=True,
            use_param_info_gain=True,
            gamma=jnp.ones(batch_size),
            alpha=jnp.ones(batch_size) * i * .2,
            categorical_obs=False,
            action_selection="stochastic",
            inference_algo="exact",
            num_iter=1,
            learn_A=True,
            learn_B=False,
            learn_D=False,
            batch_size=batch_size,
            learning_mode="offline",
        )
    )
# create agent with B matrix being fixed to the B of the generative process, and precise initial beliefs about hidden states C = [jnp.zeros(num_obs[0])] pA = [jnp.ones_like(a) / a.shape[0] for a in env.A] _A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA) B = [jnp.array(b) for b in env.B] _D = [jnp.ones(num_states[0])/num_states[0]] # flat prior over initial states agents = [] for i in range(5): agents.append( Agent( _A, B, C, _D, E=None, pA=pA, pB=None, policy_len=3, use_utility=False, use_states_info_gain=True, use_param_info_gain=True, gamma=jnp.ones(batch_size), alpha=jnp.ones(batch_size) * i * .2, categorical_obs=False, action_selection="stochastic", inference_algo="exact", num_iter=1, learn_A=True, learn_B=False, learn_D=False, batch_size=batch_size, learning_mode="offline", ) )

Run multiple blocks of active inference on grid world for each batch of agents that was initialized different alpha values.¶

Since the agents are using offline learning (learning_mode = "offline"), this means that parameter updates to the A matrix are performed at the end of each block.¶

In [ ]:
Copied!
pA_ground_truth = 1e4 * env.A[0] + 1e-4
num_timesteps = 50
num_blocks = 20
key = jr.PRNGKey(0)
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))

divs = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        init_obs, init_states = vmap(env.reset)(block_keys[:-1])
        last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states})
        agents[i] = last['agent']
        divs[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))
pA_ground_truth = 1e4 * env.A[0] + 1e-4 num_timesteps = 50 num_blocks = 20 key = jr.PRNGKey(0) block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1)) divs = {i: [] for i in range(len(agents))} for block in range(num_blocks): block_keys = block_and_batch_keys[block] for i, agent in enumerate(agents): init_obs, init_states = vmap(env.reset)(block_keys[:-1]) last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states}) agents[i] = last['agent'] divs[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))

Plot the KL divergence between the true parameters and believed parameters over time for the different groups of agents (agents with different levels of action stochasticity)¶

In [6]:
Copied!
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True)
for i in range(len(agents)):
    p = axes.plot(jnp.stack(divs[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes.plot(jnp.stack(divs[i]), color=p[0].get_color(), alpha=.2)

axes.legend(title='alpha')
axes.set_ylabel('KL divergence')
axes.set_xlabel('epoch')
fig.tight_layout()
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True) for i in range(len(agents)): p = axes.plot(jnp.stack(divs[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes.plot(jnp.stack(divs[i]), color=p[0].get_color(), alpha=.2) axes.legend(title='alpha') axes.set_ylabel('KL divergence') axes.set_xlabel('epoch') fig.tight_layout()
No description has been provided for this image

Visualize the learned A matrices alongside the true environmental parameters after training¶

In [7]:
Copied!
n_batches_to_show = min(5, agents[0].A[0].shape[0])
num_rows = len(agents) + 1
row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')

fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True)

for i in range(n_batches_to_show):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(f'batch={i + 1}')

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center')

fig.tight_layout()
n_batches_to_show = min(5, agents[0].A[0].shape[0]) num_rows = len(agents) + 1 row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True) for i in range(n_batches_to_show): for j, agent in enumerate(agents): sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(f'batch={i + 1}') for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Here, once again the agents have to learn the A matrix, but now have a flat/uninformative prior over the B matrix¶

We use a precise and accurate prior over hidden states¶

In [ ]:
Copied!
C = [jnp.zeros(num_obs[0])]
pA = [jnp.ones_like(env.A[0]) / num_obs[0]]
_A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA)
pB = [jnp.ones_like(env.B[0]) / num_states[0]]
_B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB)
D = [jnp.array(d) for d in env.D]

agents = []
for i in range(5):
    agents.append( 
        Agent(
            _A,
            _B,
            C,
            D,
            E=None,
            pA=pA,
            pB=pB,
            policy_len=3,
            use_utility=False,
            use_states_info_gain=True,
            use_param_info_gain=True,
            gamma=jnp.ones(batch_size),
            alpha=jnp.ones(batch_size) * i * .2,
            categorical_obs=False,
            action_selection="stochastic",
            inference_algo="exact",
            num_iter=1,
            learn_A=True,
            learn_B=False,
            learn_D=False,
            batch_size=batch_size,
            learning_mode="offline",
        )
    )
C = [jnp.zeros(num_obs[0])] pA = [jnp.ones_like(env.A[0]) / num_obs[0]] _A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA) pB = [jnp.ones_like(env.B[0]) / num_states[0]] _B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB) D = [jnp.array(d) for d in env.D] agents = [] for i in range(5): agents.append( Agent( _A, _B, C, D, E=None, pA=pA, pB=pB, policy_len=3, use_utility=False, use_states_info_gain=True, use_param_info_gain=True, gamma=jnp.ones(batch_size), alpha=jnp.ones(batch_size) * i * .2, categorical_obs=False, action_selection="stochastic", inference_algo="exact", num_iter=1, learn_A=True, learn_B=False, learn_D=False, batch_size=batch_size, learning_mode="offline", ) )

Run multiple blocks of active inference on grid world for each batch of agents that was initialized different alpha values.¶

Since the agents are using offline learning (learning_mode = "offline"), this means that parameter updates to the A matrix are performed at the end of each block.¶

In [ ]:
Copied!
pA_ground_truth = 1e4 * env.A[0] + 1e-4
pB_ground_truth = 1e4 * env.B[0] + 1e-4
num_timesteps = 50
num_blocks = 100
key = jr.PRNGKey(0)
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))

divs1 = {i: [] for i in range(len(agents))}
divs2 = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        init_obs, init_states = vmap(env.reset)(block_keys[:-1])
        last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states})
        agents[i] = last['agent']
        divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))
        divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))
pA_ground_truth = 1e4 * env.A[0] + 1e-4 pB_ground_truth = 1e4 * env.B[0] + 1e-4 num_timesteps = 50 num_blocks = 100 key = jr.PRNGKey(0) block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1)) divs1 = {i: [] for i in range(len(agents))} divs2 = {i: [] for i in range(len(agents))} for block in range(num_blocks): block_keys = block_and_batch_keys[block] for i, agent in enumerate(agents): init_obs, init_states = vmap(env.reset)(block_keys[:-1]) last, info = jit(rollout, static_argnums=[1, 2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states}) agents[i] = last['agent'] divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1)) divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))

Plot the KL divergence between the true parameters and believed parameters over time for the different groups of agents (agents with different levels of action stochasticity)¶

In [8]:
Copied!
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False)
for i in range(len(agents)):
    p = axes[0].plot(jnp.stack(divs1[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes[0].plot(jnp.stack(divs1[i]), color=p[0].get_color(), alpha=.2)

    p = axes[1].plot(jnp.stack(divs2[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes[1].plot(jnp.stack(divs2[i]), color=p[0].get_color(), alpha=.2)

axes[0].legend(title='alpha')
axes[0].set_ylabel('KL divergence')
axes[0].set_xlabel('epoch')
axes[1].set_xlabel('epoch')
axes[0].set_title('A matrix')
axes[1].set_title('B matrix')
fig.tight_layout()
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False) for i in range(len(agents)): p = axes[0].plot(jnp.stack(divs1[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes[0].plot(jnp.stack(divs1[i]), color=p[0].get_color(), alpha=.2) p = axes[1].plot(jnp.stack(divs2[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes[1].plot(jnp.stack(divs2[i]), color=p[0].get_color(), alpha=.2) axes[0].legend(title='alpha') axes[0].set_ylabel('KL divergence') axes[0].set_xlabel('epoch') axes[1].set_xlabel('epoch') axes[0].set_title('A matrix') axes[1].set_title('B matrix') fig.tight_layout()
No description has been provided for this image

Visualize the learned A matrices alongside the true environmental parameters after training¶

In [9]:
Copied!
n_batches_to_show = min(5, agents[0].A[0].shape[0])
num_rows = len(agents) + 1
row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')

fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True)

for i in range(n_batches_to_show):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(f'batch={i + 1}')

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center')

fig.tight_layout()
n_batches_to_show = min(5, agents[0].A[0].shape[0]) num_rows = len(agents) + 1 row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True) for i in range(n_batches_to_show): for j, agent in enumerate(agents): sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(f'batch={i + 1}') for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Visualize the B tensor (just the expected value of the posterior which in absence of learning is same as the prior) alongside the true environmental parameters after training¶

In [10]:
Copied!
num_actions = env.B[0].shape[-1]
base_labels = ["Up", "Right", "Down", "Left", "Stay"]
action_labels = base_labels[:num_actions]

row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')

fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True)

for i in range(num_actions):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(action_labels[i])

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center')


fig.tight_layout()
num_actions = env.B[0].shape[-1] base_labels = ["Up", "Right", "Down", "Left", "Stay"] action_labels = base_labels[:num_actions] row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True) for i in range(num_actions): for j, agent in enumerate(agents): sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(action_labels[i]) for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Here, once again the agents have to learn the A matrix, but now have an informative prior over the B matrix: we fix it to the expected transition distribution under a flat prior over actions, i.e. the agent knows generally that motion on the grid is restricted to the locality, but have no sense of how particular actions relate to particular transitions.¶

We use a flat prior over initial hidden states¶

In [ ]:
Copied!
C = [jnp.zeros(num_obs[0])]
pA = [jnp.ones_like(env.A[0]) / num_obs[0]]
_A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA)
B_collapsed_actions = jnp.clip(env.B[0].sum(-1), max=1)
pB = [jnp.expand_dims(B_collapsed_actions, -1) + jnp.ones_like(env.B[0]) / num_states[0]]
_B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB)

agents = []
for i in range(5):
    agents.append( 
        Agent(
            _A,
            _B,
            C,
            _D,
            E=None,
            pA=pA,
            pB=pB,
            policy_len=3,
            use_utility=False,
            use_states_info_gain=True,
            use_param_info_gain=True,
            gamma=jnp.ones(batch_size),
            alpha=jnp.ones(batch_size) * i * .2,
            categorical_obs=False,
            action_selection="stochastic",
            inference_algo="exact",
            num_iter=1,
            learn_A=True,
            learn_B=False,
            learn_D=False,
            batch_size=batch_size,
            learning_mode="offline",
        )
    )
C = [jnp.zeros(num_obs[0])] pA = [jnp.ones_like(env.A[0]) / num_obs[0]] _A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA) B_collapsed_actions = jnp.clip(env.B[0].sum(-1), max=1) pB = [jnp.expand_dims(B_collapsed_actions, -1) + jnp.ones_like(env.B[0]) / num_states[0]] _B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB) agents = [] for i in range(5): agents.append( Agent( _A, _B, C, _D, E=None, pA=pA, pB=pB, policy_len=3, use_utility=False, use_states_info_gain=True, use_param_info_gain=True, gamma=jnp.ones(batch_size), alpha=jnp.ones(batch_size) * i * .2, categorical_obs=False, action_selection="stochastic", inference_algo="exact", num_iter=1, learn_A=True, learn_B=False, learn_D=False, batch_size=batch_size, learning_mode="offline", ) )

Run multiple blocks of active inference on grid world for each batch of agents that was initialized different alpha values.¶

Since the agents are using offline learning (learning_mode = "offline"), this means that parameter updates to the A matrix are performed at the end of each block.¶

In [ ]:
Copied!
pA_ground_truth = 1e4 * env.A[0] + 1e-4
pB_ground_truth = 1e4 * env.B[0] + 1e-4
num_timesteps = 50
num_blocks = 100
key = jr.PRNGKey(0)
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))

divs1 = {i: [] for i in range(len(agents))}
divs2 = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        init_obs, init_states = vmap(env.reset)(block_keys[:-1])
        last, info = jit(rollout, static_argnums=[1,2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states})
        agents[i] = last['agent']
        divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))
        divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))
pA_ground_truth = 1e4 * env.A[0] + 1e-4 pB_ground_truth = 1e4 * env.B[0] + 1e-4 num_timesteps = 50 num_blocks = 100 key = jr.PRNGKey(0) block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1)) divs1 = {i: [] for i in range(len(agents))} divs2 = {i: [] for i in range(len(agents))} for block in range(num_blocks): block_keys = block_and_batch_keys[block] for i, agent in enumerate(agents): init_obs, init_states = vmap(env.reset)(block_keys[:-1]) last, info = jit(rollout, static_argnums=[1,2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states}) agents[i] = last['agent'] divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1)) divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))

Plot the KL divergence between the true parameters and believed parameters over time for the different groups of agents (agents with different levels of action stochasticity)¶

In [11]:
Copied!
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False)
for i in range(len(agents)):
    p = axes[0].plot(jnp.stack(divs1[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes[0].plot(jnp.stack(divs1[i]), color=p[0].get_color(), alpha=.2)

    p = axes[1].plot(jnp.stack(divs2[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes[1].plot(jnp.stack(divs2[i]), color=p[0].get_color(), alpha=.2)

axes[0].legend(title='alpha')
axes[0].set_ylabel('KL divergence')
axes[0].set_xlabel('epoch')
axes[1].set_xlabel('epoch')
axes[0].set_title('A matrix')
axes[1].set_title('B matrix')
fig.tight_layout()
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False) for i in range(len(agents)): p = axes[0].plot(jnp.stack(divs1[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes[0].plot(jnp.stack(divs1[i]), color=p[0].get_color(), alpha=.2) p = axes[1].plot(jnp.stack(divs2[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes[1].plot(jnp.stack(divs2[i]), color=p[0].get_color(), alpha=.2) axes[0].legend(title='alpha') axes[0].set_ylabel('KL divergence') axes[0].set_xlabel('epoch') axes[1].set_xlabel('epoch') axes[0].set_title('A matrix') axes[1].set_title('B matrix') fig.tight_layout()
No description has been provided for this image

Visualize the learned A matrices alongside the true environmental parameters after training¶

In [12]:
Copied!
n_batches_to_show = min(5, agents[0].A[0].shape[0])
num_rows = len(agents) + 1
row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')

fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True)

for i in range(n_batches_to_show):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis')
    sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(f'batch={i + 1}')

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center')

fig.tight_layout()
n_batches_to_show = min(5, agents[0].A[0].shape[0]) num_rows = len(agents) + 1 row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True) for i in range(n_batches_to_show): for j, agent in enumerate(agents): sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis') sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(f'batch={i + 1}') for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Visualize the B tensor (just the expected value of the posterior which in absence of learning is same as the prior) alongside the true environmental parameters after training¶

In [13]:
Copied!
num_actions = env.B[0].shape[-1]
base_labels = ["Up", "Right", "Down", "Left", "Stay"]
action_labels = base_labels[:num_actions]

row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')


fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True)


for i in range(num_actions):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(action_labels[i])

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center')


fig.tight_layout()
num_actions = env.B[0].shape[-1] base_labels = ["Up", "Right", "Down", "Left", "Stay"] action_labels = base_labels[:num_actions] row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True) for i in range(num_actions): for j, agent in enumerate(agents): sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(action_labels[i]) for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Here, once again the agents have to learn the A matrix, but now have an informative prior over the B matrix: we fix it to the expected transition distribution under a flat prior over actions, i.e. the agent knows generally that motion on the grid is restricted to the locality, but have no sense of how particular actions relate to particular transitions.¶

We use a precise and accruate prior over initial hidden states¶

In [ ]:
Copied!
C = [jnp.zeros(num_obs[0])]
pA = [jnp.ones_like(env.A[0]) / num_obs[0]]
_A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA)
B_collapsed_actions = jnp.clip(env.B[0].sum(-1), max=1)
pB = [jnp.expand_dims(B_collapsed_actions, -1) + jnp.ones_like(env.B[0]) / num_states[0]]
_B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB)

agents = []
for i in range(5):
    agents.append( 
        Agent(
            _A,
            _B,
            C,
            D,
            E=None,
            pA=pA,
            pB=pB,
            policy_len=3,
            use_utility=False,
            use_states_info_gain=True,
            use_param_info_gain=True,
            gamma=jnp.ones(batch_size),
            alpha=jnp.ones(batch_size) * i * .2,
            categorical_obs=False,
            action_selection="stochastic",
            inference_algo="exact",
            num_iter=1,
            learn_A=True,
            learn_B=False,
            learn_D=False,
            batch_size=batch_size,
            learning_mode="offline",
        )
    )
C = [jnp.zeros(num_obs[0])] pA = [jnp.ones_like(env.A[0]) / num_obs[0]] _A = jtu.tree_map(lambda a: dirichlet_expected_value(a), pA) B_collapsed_actions = jnp.clip(env.B[0].sum(-1), max=1) pB = [jnp.expand_dims(B_collapsed_actions, -1) + jnp.ones_like(env.B[0]) / num_states[0]] _B = jtu.tree_map(lambda b: dirichlet_expected_value(b), pB) agents = [] for i in range(5): agents.append( Agent( _A, _B, C, D, E=None, pA=pA, pB=pB, policy_len=3, use_utility=False, use_states_info_gain=True, use_param_info_gain=True, gamma=jnp.ones(batch_size), alpha=jnp.ones(batch_size) * i * .2, categorical_obs=False, action_selection="stochastic", inference_algo="exact", num_iter=1, learn_A=True, learn_B=False, learn_D=False, batch_size=batch_size, learning_mode="offline", ) )

Run multiple blocks of active inference on grid world for each batch of agents that was initialized different alpha values.¶

Since the agents are using offline learning (learning_mode = "offline"), this means that parameter updates to the A matrix are performed at the end of each block.¶

In [ ]:
Copied!
pA_ground_truth = 1e4 * env.A[0] + 1e-4
pB_ground_truth = 1e4 * env.B[0] + 1e-4
num_timesteps = 50
num_blocks = 100
key = jr.PRNGKey(0)
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))

divs1 = {i: [] for i in range(len(agents))}
divs2 = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        init_obs, init_states = vmap(env.reset)(block_keys[:-1])
        last, info = jit(rollout, static_argnums=[1,2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states})
        agents[i] = last['agent']
        divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1))
        divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))
pA_ground_truth = 1e4 * env.A[0] + 1e-4 pB_ground_truth = 1e4 * env.B[0] + 1e-4 num_timesteps = 50 num_blocks = 100 key = jr.PRNGKey(0) block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1)) divs1 = {i: [] for i in range(len(agents))} divs2 = {i: [] for i in range(len(agents))} for block in range(num_blocks): block_keys = block_and_batch_keys[block] for i, agent in enumerate(agents): init_obs, init_states = vmap(env.reset)(block_keys[:-1]) last, info = jit(rollout, static_argnums=[1,2])(agent, env, num_timesteps, block_keys[-1], initial_carry={"observation": init_obs, "env_state": init_states}) agents[i] = last['agent'] divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA_ground_truth).mean(-1)) divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))

Plot the KL divergence between the true parameters and believed parameters over time for the different groups of agents (agents with different levels of action stochasticity)¶

In [14]:
Copied!
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False)
for i in range(len(agents)):
    p = axes[0].plot(jnp.stack(divs1[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes[0].plot(jnp.stack(divs1[i]), color=p[0].get_color(), alpha=.2)

    p = axes[1].plot(jnp.stack(divs2[i]).mean(-1), lw=3, label=agents[i].alpha.mean())
    axes[1].plot(jnp.stack(divs2[i]), color=p[0].get_color(), alpha=.2)

axes[0].legend(title='alpha')
axes[0].set_ylabel('KL divergence')
axes[0].set_xlabel('epoch')
axes[1].set_xlabel('epoch')
axes[0].set_title('A matrix')
axes[1].set_title('B matrix')
fig.tight_layout()
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False) for i in range(len(agents)): p = axes[0].plot(jnp.stack(divs1[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes[0].plot(jnp.stack(divs1[i]), color=p[0].get_color(), alpha=.2) p = axes[1].plot(jnp.stack(divs2[i]).mean(-1), lw=3, label=agents[i].alpha.mean()) axes[1].plot(jnp.stack(divs2[i]), color=p[0].get_color(), alpha=.2) axes[0].legend(title='alpha') axes[0].set_ylabel('KL divergence') axes[0].set_xlabel('epoch') axes[1].set_xlabel('epoch') axes[0].set_title('A matrix') axes[1].set_title('B matrix') fig.tight_layout()
No description has been provided for this image

Visualize the learned A matrices alongside the true environmental parameters after training¶

In [15]:
Copied!
n_batches_to_show = min(5, agents[0].A[0].shape[0])
num_rows = len(agents) + 1
row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')

fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True)

for i in range(n_batches_to_show):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(f'batch={i + 1}')

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center')

fig.tight_layout()
n_batches_to_show = min(5, agents[0].A[0].shape[0]) num_rows = len(agents) + 1 row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(num_rows, n_batches_to_show, figsize=(16, 8), sharex=True, sharey=True) for i in range(n_batches_to_show): for j, agent in enumerate(agents): sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.A[0], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(f'batch={i + 1}') for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true A', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Visualize the B tensor (just the expected value of the posterior which in absence of learning is same as the prior) alongside the true environmental parameters after training¶

In [16]:
Copied!
num_actions = env.B[0].shape[-1]
base_labels = ["Up", "Right", "Down", "Left", "Stay"]
action_labels = base_labels[:num_actions]

row_labels = []
for agent in agents:
    alpha_value = float(agent.alpha.mean(0).squeeze())
    row_labels.append(f'alpha={alpha_value:.2f}')


fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True)


for i in range(num_actions):
    for j, agent in enumerate(agents):
        sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(action_labels[i])

for j, label in enumerate(row_labels):
    axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center')
axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center')


fig.tight_layout()
num_actions = env.B[0].shape[-1] base_labels = ["Up", "Right", "Down", "Left", "Stay"] action_labels = base_labels[:num_actions] row_labels = [] for agent in agents: alpha_value = float(agent.alpha.mean(0).squeeze()) row_labels.append(f'alpha={alpha_value:.2f}') fig, axes = plt.subplots(len(agents)+1, num_actions, figsize=(16, 8), sharex=True, sharey=True) for i in range(num_actions): for j, agent in enumerate(agents): sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.) sns.heatmap(env.B[0][..., i], ax=axes[-1, i], cmap='viridis', vmax=1., vmin=0.) axes[0, i].set_title(action_labels[i]) for j, label in enumerate(row_labels): axes[j, 0].set_ylabel(label, rotation=25, labelpad=60, ha='left', va='center') axes[-1, 0].set_ylabel('true B', rotation=0, labelpad=40, ha='left', va='center') fig.tight_layout()
No description has been provided for this image

Made with Dracula Theme for MkDocs