
T-Maze Task with Distractors

In this tutorial, we set up a T-maze task with distractors, in which an agent navigates a grid to locate a reward. The agent can move in four directions: up, down, left, or right.

The agent can visit cues to obtain information about two potential reward locations. The cues guide the agent's search, but they do not guarantee that a real reward is present, because some cues are distractors. The agent may follow a distractor cue at first, but it should recognize when a location does not contain a real reward and continue exploring.

The task is intended as a test of sophisticated inference.

Initial Setup

To run this notebook, it's recommended to use a virtual environment. You can refer to this guide to learn how to set up a virtual environment.

Then, install the current repo as a package in editable mode by running the following command in your terminal:

pip install -e .

And then, install mediapy by running the following command in your terminal:

pip install mediapy
In [ ]:
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[nb]" -q
In [ ]:
%load_ext autoreload
%autoreload 2

from jax import numpy as jnp, random as jr

from pymdp.envs.generalized_tmaze import (
    GeneralizedTMazeEnv, parse_maze, get_maze_matrix 
)

from pymdp.envs.rollout import rollout
from pymdp.agent import Agent
from pymdp.utils import list_array_uniform

Creating the Environment

The environment consists of the following elements, which are represented in the figure:

  • Agent (Green Square): The agent starts at a designated position and explores the grid to locate rewards.
  • Cues and Reward Sets: Each colour represents a set that includes one cue (marked by a black cross) and two potential reward locations associated with that cue.
    • Green Set (True Reward Set): Contains real outcomes. Each reward location has a secondary dot to indicate the type of outcome:
      • Red Dot: Indicates a reward.
      • Blue Dot: Indicates a punishment.
    • Orange and Blue Sets (Distractor Sets): These sets do not contain real rewards. Their reward locations lack a secondary dot, which marks them as false reward locations: the cues point to them, but they hold no actual reward.
In [1]:
M = get_maze_matrix()
key = jr.PRNGKey(0)

key, subkey = jr.split(key)
env_info_m = parse_maze(M, subkey)

tmaze_env = GeneralizedTMazeEnv(env_info_m)

init_obs, init_state = tmaze_env.reset(key)
tmaze_env.render(states=init_state, mode="human")
[Figure: rendered maze showing the agent, the cues, and the potential reward locations]

Creating the Agent

State Factors:

  • Position: one state per grid cell (25 in this maze)
  • Reward: location of the reward, with one two-state factor per cue-reward set (three in total)

Observation Modalities:

  • Position: one outcome per grid cell (25 in this maze)
  • Cued info: null, first location, second location (one modality per cue-reward set)
  • Reward: null, reward, punishment (one modality per cue-reward set)

The PymdpEnv class stores the environment parameters as POMDP parameters (A, B, and D). We initialize the agent's generative model with the same A and B tensors, so the agent has full knowledge of the environment's transitions and likelihoods. We give the agent a flat prior D, i.e. it does not know where it is or where the reward is. Finally, we set C so that the agent has preferences only over the last reward modality (rewarding_modality = -1): outcome index 1 is preferred (value 1.0), outcome index 2 is aversive (value -3.0), and every other entry is zero.
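To make the flat prior and the preference vector concrete, here is a minimal pure-Python sketch. It assumes that the entries of C act as unnormalized log-preferences, so that a softmax turns them into an implied distribution over outcomes. This is a common convention in active inference, but it is an assumption here, not a statement about pymdp internals. The helper `softmax` and all variable names are illustrative and not part of pymdp.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Flat prior over the 25 position states: every cell is equally likely.
num_position_states = 25
flat_prior = [1.0 / num_position_states] * num_position_states

# Preferences for the last reward modality, as set in the agent cell:
# index 0 stays at 0.0, index 1 is set to 1.0, index 2 is set to -3.0.
reward_preferences = [0.0, 1.0, -3.0]

# Assumption: read C as log-preferences and normalize with a softmax.
implied_distribution = softmax(reward_preferences)

for index, prob in enumerate(implied_distribution):
    print(f"outcome {index}: implied preference probability {prob:.3f}")
```

Under this reading, the outcome at index 1 is the most preferred, the outcome at index 0 sits in between, and the outcome at index 2 is strongly avoided.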

In [ ]:
A, B = tmaze_env.A, tmaze_env.B
A_dependencies, B_dependencies = tmaze_env.A_dependencies, tmaze_env.B_dependencies

# [position], [cue], [reward]
C = [jnp.zeros(a.shape[0]) for a in A]

rewarding_modality = -1

C[rewarding_modality] = C[rewarding_modality].at[1].set(1.0)
C[rewarding_modality] = C[rewarding_modality].at[2].set(-3.0)

D = list_array_uniform([b.shape[0] for b in B])

# make 9 different agents to simulate in parallel
batch_size = 9


agent = Agent(
    A, B, C, D, 
    E=None,
    pA=None,
    pB=None,
    policy_len=5,
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    use_utility=True,
    use_states_info_gain=True,
    sampling_mode='full',
    action_selection='stochastic',
    gamma=4.0,
    batch_size=batch_size,
    learn_A=False,
    learn_B=False
)
In [2]:
# print the shape of every tensor in the generative model
for name in ("A", "B", "C", "D"):
    print(f"{name} tensors")
    for tensor in getattr(agent, name):
        print(tensor.shape)
    print()

print("A and B dependencies")
print(agent.A_dependencies)
print(agent.B_dependencies)
A tensors
(9, 25, 25)
(9, 3, 25, 2)
(9, 3, 25, 2)
(9, 3, 25, 2)
(9, 3, 25, 2)
(9, 3, 25, 2)
(9, 3, 25, 2)

B tensors
(9, 25, 25, 5)
(9, 2, 2, 1)
(9, 2, 2, 1)
(9, 2, 2, 1)

C tensors
(9, 25)
(9, 3)
(9, 3)
(9, 3)
(9, 3)
(9, 3)
(9, 3)

D tensors
(9, 25)
(9, 2)
(9, 2)
(9, 2)

A and B dependencies
[[0], [0, 1], [0, 2], [0, 3], [0, 1], [0, 2], [0, 3]]
[[0], [1], [2], [3]]
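The printed shapes follow from the dependency lists. As a sanity check, the sketch below rebuilds them in plain Python. It assumes the layout suggested by the output above: a leading batch dimension, then the outcome (or next-state) dimension, then one dimension per state factor listed in the dependencies, with B ending in a control dimension. The dimension lists are read off the printed output, and the helper names are hypothetical.

```python
batch_size = 9

# Sizes read off the printed output (assumed, not queried from pymdp).
num_states = [25, 2, 2, 2]          # position + one factor per cue-reward set
num_obs = [25, 3, 3, 3, 3, 3, 3]    # position, three cue and three reward modalities
num_controls = [5, 1, 1, 1]         # only the position factor is controllable

A_dependencies = [[0], [0, 1], [0, 2], [0, 3], [0, 1], [0, 2], [0, 3]]
B_dependencies = [[0], [1], [2], [3]]


def expected_A_shapes():
    """One likelihood tensor per modality: (batch, outcomes, *dependent factors)."""
    return [
        (batch_size, n_obs, *[num_states[f] for f in deps])
        for n_obs, deps in zip(num_obs, A_dependencies)
    ]


def expected_B_shapes():
    """One transition tensor per factor: (batch, next state, *dependent factors, controls)."""
    return [
        (batch_size, num_states[f], *[num_states[d] for d in deps], n_ctrl)
        for f, (deps, n_ctrl) in enumerate(zip(B_dependencies, num_controls))
    ]


A_shapes = expected_A_shapes()
B_shapes = expected_B_shapes()
print(A_shapes)
print(B_shapes)
```

Each cue and reward modality depends on the position factor and on exactly one reward-location factor, which is why those A tensors have the shape (9, 3, 25, 2).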

Rolling Out an Agent Episode

The rollout function runs the active inference agent in this environment for a specified number of discrete timesteps, using the parameters set above.
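Conceptually, a rollout is an action-perception loop that records what happened at every timestep. The toy sketch below illustrates that structure in plain Python. It is not pymdp's implementation: `toy_rollout`, `select_action`, and `step` are hypothetical stand-ins for the agent and the environment, and the one-dimensional corridor replaces the maze.

```python
GOAL = 4  # right-most cell of a toy corridor with cells 0..4


def select_action(state):
    """Toy policy: move right (1) until the goal is reached, then stay (0)."""
    return 1 if state < GOAL else 0


def step(state, action):
    """Toy environment transition, clipped at the goal cell."""
    return min(state + action, GOAL)


def toy_rollout(select_action, step, init_state, num_timesteps):
    """Run the loop for num_timesteps and record states and actions."""
    state = init_state
    info = {"env_state": [], "action": []}
    for _ in range(num_timesteps):
        action = select_action(state)
        info["env_state"].append(state)
        info["action"].append(action)
        state = step(state, action)
    return state, info


final_state, info = toy_rollout(select_action, step, init_state=0, num_timesteps=10)
print(info["env_state"])
print(final_state)
```

As in the cell that follows, the call returns a final value together with an info record that holds one entry per timestep, which is what later cells slice to render each frame.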

In [ ]:
key = jr.PRNGKey(1)
T = 10
_, info = rollout(agent, tmaze_env, num_timesteps=T, rng_key=key)
In [3]:
images = []
for t in range(T):
    env_state_t = [s[:, t] for s in info['env_state']]
    images.append(tmaze_env.render(states=env_state_t, mode="rgb_array"))
<Figure size 640x480 with 0 Axes>
In [4]:
# make animation
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


def animate(images, savefile=None, interval=1000):
    # Make a bigger figure (pick whatever looks good)
    fig = plt.figure(figsize=(6, 6), dpi=150)

    # Axes that fills the entire figure
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_axis_off()

    im = ax.imshow(images[0], animated=True)

    def update(k):
        im.set_data(images[k])
        return (im,)

    ani = animation.FuncAnimation(
        fig, update, frames=len(images), interval=interval, blit=True, repeat_delay=1000
    )

    if savefile is not None:
        ani.save(savefile)

    plt.close(fig)
    return ani

ani = animate(images)
HTML(ani.to_html5_video())