Imports¶
In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import jax.tree_util as jtu
from jax import numpy as jnp, random as jr
from jax import nn, vmap
from equinox import tree_at
import numpy as np
from pymdp.envs import GridWorld, rollout
from pymdp import control
from pymdp.agent import Agent
import jax.tree_util as jtu
from jax import numpy as jnp, random as jr
from jax import nn, vmap
from equinox import tree_at
import numpy as np
from pymdp.envs import GridWorld, rollout
from pymdp import control
from pymdp.agent import Agent
Grid world generative model¶
In [ ]:
Copied!
# size of the grid world
grid_shape = (7, 7)
# start in the middle of the grid
env = GridWorld(shape=grid_shape, initial_position=(3,3), include_stay=False)
desired_state = (6,6) # bottom right corner
# get linear index of desired state
desired_state_id = env.coords_to_index(shape=grid_shape, coord=desired_state)
# create helpful num_obs and num_states lists (lists of observation dimensions per modality, and state dimensions per factor, respectively)
num_obs = [a.shape[0] for a in env.A]
num_states = [b.shape[0] for b in env.B]
num_controls = [b.shape[-1] for b in env.B]
# size of the grid world
grid_shape = (7, 7)
# start in the middle of the grid
env = GridWorld(shape=grid_shape, initial_position=(3,3), include_stay=False)
desired_state = (6,6) # bottom right corner
# get linear index of desired state
desired_state_id = env.coords_to_index(shape=grid_shape, coord=desired_state)
# create helpful num_obs and num_states lists (lists of observation dimensions per modality, and state dimensions per factor, respectively)
num_obs = [a.shape[0] for a in env.A]
num_states = [b.shape[0] for b in env.B]
num_controls = [b.shape[-1] for b in env.B]
Planning and inductive inference parameters¶
In [ ]:
Copied!
# number of agents
batch_size = 5
planning_horizon, inductive_threshold = 1, 0.1
inductive_depth = 7
policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon)
# inductive planning goal states
H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
# number of agents
batch_size = 5
planning_horizon, inductive_threshold = 1, 0.1
inductive_depth = 7
policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon)
# inductive planning goal states
H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
Initialize an Agent()¶
In [1]:
Copied!
# create agent, using generative process parameters from the environment to initialize the generative model
A = env.A
B = env.B
C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f]))
D = env.D
agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,
inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,
H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
# create agent, using generative process parameters from the environment to initialize the generative model
A = env.A
B = env.B
C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f]))
D = env.D
agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,
inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,
H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61059/4163094735.py:6: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,
Run active inference¶
In [ ]:
Copied!
T = 7
last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0))
T = 7
last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0))
In [2]:
Copied!
info['env_state'][0].shape
info['env_state'][0].shape
Out[2]:
(5, 8)
In [3]:
Copied!
agent_id_to_track = 1
for t in range(T):
state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t])
print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
agent_id_to_track = 1
for t in range(T):
state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t])
print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
Grid position for agent 2 at time 0: (3, 3) Grid position for agent 2 at time 1: (3, 4) Grid position for agent 2 at time 2: (3, 5) Grid position for agent 2 at time 3: (3, 6) Grid position for agent 2 at time 4: (4, 6) Grid position for agent 2 at time 5: (5, 6) Grid position for agent 2 at time 6: (6, 6)
Now the agent starts further from the goal and thus need more timesteps to reach it¶
In [ ]:
Copied!
# size of the grid world
grid_shape = (7, 7)
# number of agents
batch_size = 5
# start in the upper left corner this time
upper_left_initial_states = [jnp.repeat(jnp.array(env.coords_to_index(shape=grid_shape, coord=(0,0))), batch_size, axis=0)]
# size of the grid world
grid_shape = (7, 7)
# number of agents
batch_size = 5
# start in the upper left corner this time
upper_left_initial_states = [jnp.repeat(jnp.array(env.coords_to_index(shape=grid_shape, coord=(0,0))), batch_size, axis=0)]
Increase inductive planning depth in order to compute the needed inductive planning matrix¶
In [ ]:
Copied!
planning_horizon, inductive_threshold = 1, 0.1
inductive_depth = 14
policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon)
# inductive planning goal states
H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
planning_horizon, inductive_threshold = 1, 0.1
inductive_depth = 14
policy_matrix = control.construct_policies(num_states, num_controls, policy_len=planning_horizon)
# inductive planning goal states
H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (batch_size, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))
In [4]:
Copied!
# create agent, using generative process parameters from the environment to initialize the generative model
A, B = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in env.A], [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in env.B]
C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f]))
D = [nn.one_hot(upper_left_initial_states[0], num_states[0])] # need to do this since the D of the environment won't match the env state if you reset to a different state
agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,
inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,
H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
# create agent, using generative process parameters from the environment to initialize the generative model
A, B = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in env.A], [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in env.B]
C = [jnp.repeat(nn.one_hot(desired_state_id, num_states[0])[None, :], batch_size, axis=0)] # preferred outcomes (shape of each is (n_batches, num_obs[f]))
D = [nn.one_hot(upper_left_initial_states[0], num_states[0])] # need to do this since the D of the environment won't match the env state if you reset to a different state
agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,
inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,
H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61059/1146799413.py:5: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agent = Agent(A, B, C, D, batch_size=batch_size, policies=policy_matrix, policy_len=planning_horizon,
Run active inference¶
In [ ]:
Copied!
T = 14
# construct a partial initial_carry that only overwrites env_state/observation
init_obs, init_env_state = vmap(env.reset)(
jr.split(jr.PRNGKey(1), batch_size), state=upper_left_initial_states
)
initial_carry_overwrite = {"observation": init_obs, "env_state": init_env_state}
last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0), initial_carry=initial_carry_overwrite)
T = 14
# construct a partial initial_carry that only overwrites env_state/observation
init_obs, init_env_state = vmap(env.reset)(
jr.split(jr.PRNGKey(1), batch_size), state=upper_left_initial_states
)
initial_carry_overwrite = {"observation": init_obs, "env_state": init_env_state}
last, info = rollout(agent, env, num_timesteps=T, rng_key = jr.PRNGKey(0), initial_carry=initial_carry_overwrite)
In [5]:
Copied!
agent_id_to_track = 1
for t in range(T):
state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t])
print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
agent_id_to_track = 1
for t in range(T):
state_time_t = env.index_to_coords(shape=grid_shape, idx=info['env_state'][0][agent_id_to_track, t])
print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {state_time_t}")
Grid position for agent 2 at time 0: (0, 0) Grid position for agent 2 at time 1: (0, 1) Grid position for agent 2 at time 2: (0, 2) Grid position for agent 2 at time 3: (0, 3) Grid position for agent 2 at time 4: (0, 4) Grid position for agent 2 at time 5: (0, 5) Grid position for agent 2 at time 6: (0, 6) Grid position for agent 2 at time 7: (1, 6) Grid position for agent 2 at time 8: (2, 6) Grid position for agent 2 at time 9: (3, 6) Grid position for agent 2 at time 10: (4, 6) Grid position for agent 2 at time 11: (5, 6) Grid position for agent 2 at time 12: (6, 6) Grid position for agent 2 at time 13: (6, 6)