pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

Imports¶

In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import jax.tree_util as jtu
from jax import numpy as jnp, random as jr
from jax import nn, vmap
from equinox import tree_at

import numpy as np

from pymdp.envs import GridWorld, rollout
from pymdp import control
from pymdp.agent import Agent
import jax.tree_util as jtu from jax import numpy as jnp, random as jr from jax import nn, vmap from equinox import tree_at import numpy as np from pymdp.envs import GridWorld, rollout from pymdp import control from pymdp.agent import Agent

Grid world generative model¶

In [ ]:
Copied!
# size of the grid world
grid_shape = (7, 7)

# start in the middle of the grid
env = GridWorld(shape=grid_shape, initial_position=(3,3), include_stay=False)

desired_state = (6,6)  # bottom right corner
# get linear index of desired state
desired_state_id = env.coords_to_index(shape=grid_shape, coord=desired_state)

# create helpful num_obs and num_states lists (lists of observation dimensions per modality, and state dimensions per factor, respectively)
num_obs = [a.shape[0] for a in env.A]
num_states = [b.shape[0] for b in env.B]
num_controls = [b.shape[-1] for b in env.B]
# size of the grid world grid_shape = (7, 7) # start in the middle of the grid env = GridWorld(shape=grid_shape, initial_position=(3,3), include_stay=False) desired_state = (6,6) # bottom right corner # get linear index of desired state desired_state_id = env.coords_to_index(shape=grid_shape, coord=desired_state) # create helpful num_obs and num_states lists (lists of observation dimensions per modality, and state dimensions per factor, respectively) num_obs = [a.shape[0] for a in env.A] num_states = [b.shape[0] for b in env.B] num_controls = [b.shape[-1] for b in env.B]

Planning and inductive inference parameters¶

In [ ]:
Copied!
# number of agents
batch_size = 5

planning_horizon, inductive_threshold = 1, 0.1
inductive_depth = 7
policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon)

# inductive planning goal states
H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
# number of agents batch_size = 5 planning_horizon, inductive_threshold = 1, 0.1 inductive_depth = 7 policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon) # inductive planning goal states H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))

Initialize an Agent()¶

In [1]:
Copied!
# create agent, using generative process parameters from the environment to initialize the generative model
A = env.A
B = env.B
C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f]))
D = env.D
agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon, 
            inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,
            H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
# create agent, using generative process parameters from the environment to initialize the generative model A = env.A B = env.B C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f])) D = env.D agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon, inductive_depth=inductive_depth, inductive_threshold=inductive_threshold, H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61059/4163094735.py:6: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,

Run active inference¶

In [ ]:
Copied!
T = 7 
last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0))
T = 7 last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0))
In [2]:
Copied!
info['env_state'][0].shape
info['env_state'][0].shape
Out[2]:
(5, 8)
In [3]:
Copied!
agent_id_to_track = 1
for t in range(T):
    state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t])
    print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
agent_id_to_track = 1 for t in range(T): state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t]) print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
Grid position for agent 2 at time 0: (3, 3)
Grid position for agent 2 at time 1: (3, 4)
Grid position for agent 2 at time 2: (3, 5)
Grid position for agent 2 at time 3: (3, 6)
Grid position for agent 2 at time 4: (4, 6)
Grid position for agent 2 at time 5: (5, 6)
Grid position for agent 2 at time 6: (6, 6)

Now the agent starts further from the goal and thus needs more timesteps to reach it¶

In [ ]:
Copied!
# size of the grid world
grid_shape = (7, 7)

# number of agents
batch_size = 5

# start in the upper left corner this time
upper_left_initial_states = [jnp.repeat(jnp.array(env.coords_to_index(shape=grid_shape, coord=(0,0))), batch_size, axis=0)]
# size of the grid world grid_shape = (7, 7) # number of agents batch_size = 5 # start in the upper left corner this time upper_left_initial_states = [jnp.repeat(jnp.array(env.coords_to_index(shape=grid_shape, coord=(0,0))), batch_size, axis=0)]

Increase inductive planning depth in order to compute the needed inductive planning matrix¶

In [ ]:
Copied!
planning_horizon, inductive_threshold = 1, 0.1
inductive_depth = 14
policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon)

# inductive planning goal states
H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
planning_horizon, inductive_threshold = 1, 0.1 inductive_depth = 14 policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon) # inductive planning goal states H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
In [4]:
Copied!
# create agent, using generative process parameters from the environment to initialize the generative model
A, B = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in env.A], [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in env.B]
C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f]))
D = [nn.one_hot(upper_left_initial_states[0], num_states[0])] # need to do this since the D of the environment won't match the env state if you reset to a different state
agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon, 
            inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,
            H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
# create agent, using generative process parameters from the environment to initialize the generative model A, B = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in env.A], [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in env.B] C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f])) D = [nn.one_hot(upper_left_initial_states[0], num_states[0])] # need to do this since the D of the environment won't match the env state if you reset to a different state agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon, inductive_depth=inductive_depth, inductive_threshold=inductive_threshold, H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61059/1146799413.py:5: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,

Run active inference¶

In [ ]:
Copied!
T = 14
# construct a partial initial_carry that only overwrites env_state/observation
init_obs, init_env_state = vmap(env.reset)(
    jr.split(jr.PRNGKey(1), batch_size), state=upper_left_initial_states
)
initial_carry_overwrite = {"observation": init_obs, "env_state": init_env_state}
last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0), initial_carry=initial_carry_overwrite)
T = 14 # construct a partial initial_carry that only overwrites env_state/observation init_obs, init_env_state = vmap(env.reset)( jr.split(jr.PRNGKey(1), batch_size), state=upper_left_initial_states ) initial_carry_overwrite = {"observation": init_obs, "env_state": init_env_state} last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0), initial_carry=initial_carry_overwrite)
In [5]:
Copied!
agent_id_to_track = 1
for t in range(T):
    state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t])
    print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
agent_id_to_track = 1 for t in range(T): state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t]) print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
Grid position for agent 2 at time 0: (0, 0)
Grid position for agent 2 at time 1: (0, 1)
Grid position for agent 2 at time 2: (0, 2)
Grid position for agent 2 at time 3: (0, 3)
Grid position for agent 2 at time 4: (0, 4)
Grid position for agent 2 at time 5: (0, 5)
Grid position for agent 2 at time 6: (0, 6)
Grid position for agent 2 at time 7: (1, 6)
Grid position for agent 2 at time 8: (2, 6)
Grid position for agent 2 at time 9: (3, 6)
Grid position for agent 2 at time 10: (4, 6)
Grid position for agent 2 at time 11: (5, 6)
Grid position for agent 2 at time 12: (6, 6)
Grid position for agent 2 at time 13: (6, 6)

Made with Dracula Theme for MkDocs