Validating Sophisticated Inference (SI) Planning Algorithm using the T-Maze Task¶

Overview¶

This notebook provides a comparative test of two planning algorithms within the active inference framework:

  1. Vanilla Active Inference Planning: The standard, default planning approach in pymdp.control.infer_policies(). It computes the expected free energy of all policies in parallel, and the log probability of each policy is proportional to its negative expected free energy. The expected free energy of each policy is computed by propagating beliefs about hidden states forward in time under that policy (a typical 'posterior predictive rollout'), without updating those beliefs with the counterfactual observations expected under those future states. This means the belief over hidden states at timestep t in a given rollout depends only on:
  • the belief over hidden states at the previous timestep t-1 and
  • the action entailed by the policy in question at time t-1.
  2. Sophisticated Inference (SI) Planning: A more complex but more thorough planning approach that evaluates policies using recursive expected free energy calculations, in which the agent considers how its beliefs would change under counterfactual observations encountered in the future. This algorithm has higher complexity because it branches on both observations and actions in the future, but it more accurately captures how uncertainty over hidden states changes in the future as a function of (possibly encountered) observations.
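
The difference between the two belief propagation schemes can be illustrated with a minimal NumPy sketch on a toy two-state model. Everything here (the matrices `A` and `B`, the prior `q_prev`, and the helper `entropy`) is an illustrative assumption and not pymdp's internal implementation. The vanilla-style step only pushes the belief through the transition model, while the SI-style step also branches on each possible observation and computes the posterior the agent would hold after seeing it.

```python
import numpy as np

# Toy model: 2 hidden states, 2 observations, 1 action (all values hypothetical).
A = np.array([[0.9, 0.1],   # p(o=0 | s)
              [0.1, 0.9]])  # p(o=1 | s)
B = np.eye(2)               # the single action leaves the state unchanged
q_prev = np.array([0.5, 0.5])

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

# Vanilla-style step: propagate the belief through the transition model only.
q_pred = B @ q_prev

# SI-style step: branch on every observation the agent could receive and
# compute the posterior belief it would hold after each one.
p_obs = A @ q_pred                          # predictive distribution p(o)
posteriors = (A * q_pred) / p_obs[:, None]  # row o holds q(s | o)
expected_posterior_entropy = float(
    sum(p_obs[o] * entropy(posteriors[o]) for o in range(len(p_obs)))
)
```

In this toy case the predictive belief `q_pred` stays maximally uncertain, whereas the expected entropy after conditioning on observations is lower. That reduction in uncertainty is what the recursive SI evaluation can account for at future timesteps.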

This notebook uses a simplified version of the T-Maze environment (SimplifiedTMaze), as described in the sophisticated inference paper, to validate the SI implementation.

The T-Maze Task¶

The T-maze is a classic sequential decision-making problem where the agent (analogized to a rat) faces a fundamental exploration vs. exploitation dilemma. The agent begins at the center of a T-shaped maze, where there is also a left arm, right arm, and cue location (bottom arm). One arm contains reward (cheese), the other punishment (shock), but the agent does not know which arm contains what. A cue at the bottom provides information about reward location with a given accuracy. This version of the T-Maze differs from that described in the T-Maze Demo because there is no "middle" location between the left and right arms, and because every location is reachable from every other location. Additionally, location and cue observations are embedded into a single modality, rather than separated out into two separate modalities.

The agent must choose between immediate but risky exploitation, by directly committing to one of the reward arms (50% chance of success), and gathering information (exploring) first by visiting the cue location at the expense of time, then making an informed choice. When cue validity > 50%, the optimal strategy is to first gather information at the cue location, then select the indicated rewarding arm. This notebook tests whether both planning algorithms can discover this optimal policy, and how they evaluate it.
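
The value of visiting the cue can be checked with a short Bayes-rule calculation. The sketch below is a hedged illustration in plain Python: the function names are hypothetical, and it assumes a uniform prior over the reward location and an agent that follows the cue.

```python
def posterior_after_cue(cue_validity, prior_left=0.5):
    """Posterior probability that the reward is on the left after a 'left' cue."""
    joint_left = cue_validity * prior_left
    joint_right = (1.0 - cue_validity) * (1.0 - prior_left)
    return joint_left / (joint_left + joint_right)

def success_if_following_cue(cue_validity):
    """Chance of reaching the rewarding arm when choosing the more probable arm."""
    # By symmetry, a 'right' cue yields the same success probability.
    p_left = posterior_after_cue(cue_validity)
    return max(p_left, 1.0 - p_left)

p_direct = 0.5                                  # committing without information
p_informed = success_if_following_cue(0.95)     # visiting the cue first
p_useless_cue = success_if_following_cue(0.5)   # a cue at chance level
```

With a uniform prior, the success probability after the cue equals the cue validity, so the detour pays off only when validity exceeds 50% and the extra timestep is affordable.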

Notebook Structure¶

  1. Environment Setup: Set up the generative process for the T-Maze environment
  2. Agent Setup: Set up the generative model for the agent
  3. Active Inference Rollout: Run active inference rollouts with vanilla and sophisticated inference planning algorithms, with optional visualizations of the agents' behavior
  4. Results Analysis: Compare actions selected and policy evaluations
In [ ]:
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[nb]" -q
In [ ]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp
import jax.random as jr

import matplotlib.pyplot as plt
import numpy as np
import mediapy

from pymdp.envs import SimplifiedTMaze, rollout
from pymdp.agent import Agent
from pymdp.planning.si import si_policy_search

Setting up the T-Maze environment (Generative Process)¶

  • The reward_condition parameter determines the reward location: 0 for the left arm, 1 for the right arm, or None for random allocation.
  • The cue_validity parameter (default 0.95) represents the accuracy of the cues as a probability.
  • The reward_probability parameter sets the probability a of receiving a reward in the correct arm.
  • With dependent_outcomes=True, the remaining probability (1-a) becomes punishment in the correct arm (and reward in the incorrect arm). With dependent_outcomes=False, the remaining probability in the correct arm is no-outcome, and punishment in the incorrect arm is set by punishment_probability (with no-outcome as the remainder).
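
The interaction between these parameters can be written out as a small function. This is a sketch of the description above, not the environment's actual code: `outcome_distribution` is a hypothetical helper, and the incorrect-arm behavior under dependent_outcomes=True (punishment with probability a) is an assumption inferred from the bullet text.

```python
def outcome_distribution(in_correct_arm, reward_probability,
                         dependent_outcomes=True, punishment_probability=0.0):
    """Return [p(no outcome), p(reward), p(punishment)] for one arm."""
    a = reward_probability
    if dependent_outcomes:
        # Assumption: the incorrect arm mirrors the correct arm.
        reward, punishment = (a, 1.0 - a) if in_correct_arm else (1.0 - a, a)
        return [0.0, reward, punishment]
    if in_correct_arm:
        # The remainder in the correct arm is no-outcome.
        return [1.0 - a, a, 0.0]
    # Incorrect arm: punishment with its own probability, otherwise no-outcome.
    return [1.0 - punishment_probability, 0.0, punishment_probability]

correct_dependent = outcome_distribution(True, 0.8)
incorrect_independent = outcome_distribution(False, 0.8, False, 0.5)
```

Each returned list is ordered like the Outcome modality used later in this notebook (no outcome, reward, punishment).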
How the generative process is set up:

States and Observations¶

State Factors:

  1. Location (4 states):
    • 0: center (start location)
    • 1: left arm
    • 2: right arm
    • 3: cue location (bottom arm)
  2. Reward Location (2 states):
    • 0: reward in left arm
    • 1: reward in right arm

Control State Factors:

  1. Location (4 actions):
    • 0: move to center (start location)
    • 1: move to left arm
    • 2: move to right arm
    • 3: move to cue location (bottom arm)

Observation Modalities:

  1. Location (5 observations):
    • 0: center (start location)
    • 1: left arm
    • 2: right arm
    • 3: cue location (bottom arm), cue indicates left arm
    • 4: cue location (bottom arm), cue indicates right arm
  2. Outcome (3 observations):
    • 0: no outcome
    • 1: reward (cheese)
    • 2: punishment (shock)

Environment Parameters¶

Observation Likelihood Model (A):

  • A[0]: Location observations (5x4x2 tensor)
    • Perfect mapping between true and observed locations.
    • At the cue location (bottom), the rewarding arm is cued with accuracy set by the cue_validity parameter.
  • A[1]: Outcome observations (3x4x2 tensor)
    • In the rewarding arm (set by reward_condition), reward is presented with a likelihood determined by the reward_probability parameter.
    • Whether the remaining probability is assigned to punishment or no-outcome depends on dependent_outcomes (True or False) and, when it is False, on the punishment_probability parameter.
    • No-outcome is observed in the center/start location and cue location.

Transition Model (B):

  • B[0]: Location transitions (4x4x4 tensor)
    • Agent can move to any one of the four locations from any location regardless of adjacency.
  • B[1]: Reward location (2x2x1 tensor)
    • Reward location remains fixed throughout trial.

Initial Conditions (D):

  • D[0]: Starting location (4x1 tensor)
    • Agent always begins in center location
  • D[1]: Reward placement (2x1 tensor)
    • Default: Equal chance (50/50) of reward in either arm (reward_condition=None)
    • Optional: Can fix reward to specific arm, by setting reward_condition to 0 (for left arm) or 1 (for right arm)
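
As a rough illustration of the tensors described above, the following NumPy sketch builds a location likelihood and a location transition tensor with the stated shapes. The function names are hypothetical, the index order `[observation, location, reward_location]` and `[next, current, action]` is an assumption of this sketch, and any batch dimension used by pymdp is omitted; see si_tmaze.py in pymdp/envs for the actual construction.

```python
import numpy as np

def build_location_likelihood(cue_validity):
    """A0[o, loc, reward_loc] = p(location observation o | loc, reward_loc)."""
    A0 = np.zeros((5, 4, 2))
    for reward_loc in range(2):
        for loc in range(3):  # center, left arm, right arm are observed exactly
            A0[loc, loc, reward_loc] = 1.0
        # At the cue location (3): observation 3 cues left, observation 4 cues right.
        A0[3 + reward_loc, 3, reward_loc] = cue_validity
        A0[4 - reward_loc, 3, reward_loc] = 1.0 - cue_validity
    return A0

def build_location_transitions(num_locations=4):
    """B0[next, current, action] = 1 when next == action (all locations reachable)."""
    B0 = np.zeros((num_locations,) * 3)
    for action in range(num_locations):
        B0[action, :, action] = 1.0
    return B0

A0 = build_location_likelihood(cue_validity=0.95)
B0 = build_location_transitions()
```

Every column of both tensors sums to one, which is a quick sanity check worth running on any hand-built generative model.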
In [ ]:
# setting the parameters for the environment
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
cue_validity = 0.95 # 95% valid cues

reward_probability = 1.0 # 100% chance of reward in the correct arm
# If True, punishment is the complement of reward in the correct arm
# (e.g., reward_probability=0.8 gives a 20% chance of punishment there).
# If False, the remainder in the correct arm is no outcome (e.g., 20% no outcome), and punishment
# occurs only in the other (non-rewarding) arm, with probability punishment_probability.
dependent_outcomes = True
In [ ]:
# initializing the environment. see si_tmaze.py in pymdp/envs for the implementation details
env = SimplifiedTMaze(
    reward_condition=reward_condition,
    cue_validity=cue_validity,  
    reward_probability=reward_probability,     
    dependent_outcomes=dependent_outcomes,
)

Setting up the Agents¶

In [ ]:
# creating C tensors filled with zeros, one per observation modality ([location/cue], [outcome]), based on A shapes for the Agent
C = [jnp.zeros(a.shape[0], dtype=jnp.float32) for a in env.A] 

# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0)    # prefer reward
C[1] = C[1].at[2].set(-6.0)   # avoid punishment

# slight cost of observing a cue
# C[0] = C[0].at[3].set(-1.0).at[4].set(-1.0)
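
To get a feel for what these preference values imply, one can read them as unnormalized log-preferences and pass them through a softmax. This reading is an assumption of the sketch below (check how your pymdp version normalizes C), and the `softmax` helper is written here in plain Python rather than taken from pymdp.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Same ordering as the outcome modality: no outcome, reward, punishment.
outcome_preferences = [0.0, 2.0, -6.0]
preference_probs = softmax(outcome_preferences)
```

Under this reading, reward carries most of the preference mass, no-outcome is mildly acceptable, and punishment is strongly dispreferred, which is the asymmetry that makes a blind guess at an arm unattractive.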
In [ ]:
# flat D tensors [location], [reward] based on B shapes for the agent
D = [jnp.ones(b.shape[0], dtype=jnp.float32) / b.shape[0] for b in env.B]
In [1]:
# Note: the two agents are initialized with different policy lengths, even though both end up planning
# with a horizon of 2. The vanilla agent plans over full 2-step policies (policy_len=2), whereas the
# sophisticated inference agent uses policy_len=1, because its horizon of 2 is specified when
# initializing the planning algorithm that is passed to `rollout`.

# action_selection="deterministic" selects the action with the highest probability under the policy posterior (q_pi), i.e. the argmax
# sampling_mode="full" evaluates the whole action sequence of each policy and executes its first action (as opposed to "marginal", where the agent evaluates each action type)
policy_len = 2
agent_vanilla = Agent(
    env.A, env.B, C, D, 
    A_dependencies=env.A_dependencies, 
    B_dependencies=env.B_dependencies,
    policy_len=policy_len,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=3.0
)

agent_si = Agent(
    env.A, env.B, C, D, 
    A_dependencies=env.A_dependencies, 
    B_dependencies=env.B_dependencies,
    policy_len=1,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=3.0
)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61048/282069077.py:8: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent_vanilla = Agent(
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61048/282069077.py:20: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent_si = Agent(
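
The role of the gamma parameter passed to both agents can be sketched as a temperature on the policy posterior. The snippet below is a simplified illustration that assumes q(pi) = softmax(-gamma * G) with no policy prior or habit term; the function name and the expected free energy values are hypothetical.

```python
import numpy as np

def policy_posterior(efe, gamma):
    """q(pi) = softmax(-gamma * G): lower expected free energy, higher probability."""
    logits = -gamma * np.asarray(efe, dtype=float)
    logits -= logits.max()  # subtract the maximum for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()

G = [1.0, 1.5, 3.0]  # hypothetical expected free energies of three policies
q_soft = policy_posterior(G, gamma=1.0)
q_sharp = policy_posterior(G, gamma=3.0)
```

A larger gamma concentrates probability on the policy with the lowest expected free energy, so the same evaluations produce a more decisive posterior.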

Running the active inference rollouts¶

In [ ]:
key = jr.PRNGKey(0) 
T = 3
In [ ]:
si_search = si_policy_search(
    horizon=policy_len, # plans 2 timesteps ahead
    max_nodes=5000, # maximum number of nodes allowed in the tree
    max_branching=10, # maximum number of children allowed per node (moderating the branching factor)
    policy_prune_threshold=0.0, # no pruning of unlikely policies
    observation_prune_threshold=0.0, # no pruning of unlikely observations
    entropy_stop_threshold=0.0, # disabling halting of expansion if agent is certain enough
    neg_efe_stop_threshold=1e10, # disabling efe value based halting of expansion
    kl_threshold=-1, # disabling node reuse if agent is in similar states after an action
    prune_penalty=512, # default value for prune penalty
    gamma=3, # temperature parameter; lower value (--> 1) prunes policies less aggressively as probabilities are flattened while higher value (--> 16) prunes more aggressively
    topk_obsspace=10000, # max number of top observation combinations - this default value just means we want to consider all the observation combinations
)
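
The limits on nodes and branching above exist because SI branches on observations as well as actions. As a back-of-the-envelope sketch (not how pymdp counts tree nodes), the following compares the number of vanilla policies with a crude upper bound on the number of leaves of an unpruned SI tree. The function names are hypothetical, and the bound treats every joint observation as possible, which overestimates the real tree because many observations have zero probability.

```python
def num_vanilla_policies(num_actions, horizon):
    """Number of action sequences of the given length."""
    return num_actions ** horizon

def si_tree_leaf_bound(num_actions, num_observations, horizon):
    """Upper bound on leaves when every action and observation is expanded."""
    return (num_actions * num_observations) ** horizon

# Simplified T-Maze: 4 actions; 5 location x 3 outcome observations; horizon 2.
vanilla_count = num_vanilla_policies(4, 2)
si_bound = si_tree_leaf_bound(4, 5 * 3, 2)
```

The gap between the two counts grows exponentially with the horizon, which is why the pruning and stopping thresholds matter for deeper searches even though they are disabled here.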
In [ ]:
_, info_vanilla = rollout(agent_vanilla, env, num_timesteps=T, rng_key=key) # default policy search is vanilla
_, info_si = rollout(agent_si, env, num_timesteps=T, rng_key=key, policy_search=si_search)
In [ ]:
def make_gif(info):
    frames = []
    for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
        # get observations for this timestep
        observations_t = [
            info["observation"][0][:, t, :],
            info["observation"][1][:, t, :],  
        ]
        
        frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
        frame = np.asarray(frame, dtype=np.uint8)
        plt.close()  # close the figure to prevent memory leak
        frames.append(frame)

    frames = np.array(frames, dtype=np.uint8)
    mediapy.show_video(frames, fps=1)
In [2]:
make_gif(info_vanilla)
[Video: rendered rollout of the vanilla agent]
In [3]:
make_gif(info_si)
[Video: rendered rollout of the SI agent]

Result analysis¶

In [ ]:
# qpi is a posterior over whole policies (action sequences).
# To get the probability of the *current* action, we marginalize over policies
# that share the same first action and sum their qpi values.
# helper functions for:
# - printing out policies and respective probabilities of selecting those policies
# - printing out action and observation info for each timestep

np.set_printoptions(precision=2, suppress=True)

def print_qpi(agent, info, print_t1=False):
    qpi_values = info["qpi"]

    action_names = {
        0: "move to center",
        1: "move to left arm",
        2: "move to right arm",
        3: "move to cue",
    }
    max_timesteps = 1 if print_t1 else qpi_values.shape[1]

    # unique_multiactions returns the unique first-step actions across policies
    # for a single control factor, this is just a list of action indices
    unique_actions = agent.unique_multiactions[:, 0]

    for t in range(max_timesteps):
        print(f"Timestep {t}:")
        action_probs = agent.multiaction_probabilities(qpi_values[:, t, :])[0]

        for action_idx, total_prob in zip(unique_actions.tolist(), action_probs.tolist()):
            if action_idx < 0:
                continue
            action_name = action_names.get(action_idx, f"action_{action_idx}")
            print(f"  {action_name}: {total_prob:.3f}")
        print()

def print_agent_behavior(info):

    action_names = {0: "move to center", 1: "move to left arm", 2: "move to right arm", 3: "move to cue"}

    location_obs = {0: "center loc", 1: "left arm loc", 2: "right arm loc", 3: "cue-left-arm", 4: "cue-right-arm"}
    outcome_obs = {0: "no_outcome", 1: "reward", 2: "punishment"}

    actions = info["action"]
    observations = info["observation"]

    num_timesteps = actions.shape[1]

    for t in range(num_timesteps):
        action_idx = int(actions[0, t, 0])  # [batch, timestep, action_dim]
        action_name = action_names.get(action_idx, f"action_{action_idx}")

        location_obs_idx = int(observations[0][0, t, 0])  # [modality][batch, timestep, obs_dim]
        outcome_obs_idx = int(observations[1][0, t, 0])

        location_name = location_obs.get(location_obs_idx)
        outcome_name = outcome_obs.get(outcome_obs_idx)

        print(f"t={t}: observed=({location_name}, {outcome_name}) -> action={action_name}")
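
The marginalization mentioned in the comments of the helper cell (summing the posterior of all policies that share the same first action) can be sketched independently of pymdp as follows. The function name, the toy policy table, and the probabilities are illustrative assumptions; the helper above relies on the agent's own method instead.

```python
import numpy as np

def first_action_marginal(policies, q_pi, num_actions):
    """Sum q_pi over all policies that start with the same action."""
    policies = np.asarray(policies)
    q_pi = np.asarray(q_pi, dtype=float)
    marginal = np.zeros(num_actions)
    # np.add.at accumulates correctly even when first actions repeat.
    np.add.at(marginal, policies[:, 0], q_pi)
    return marginal

# Four hypothetical 2-step policies over two actions and their posterior.
policies = [[0, 0], [0, 1], [1, 0], [1, 1]]
q_pi = [0.1, 0.2, 0.3, 0.4]
marginal = first_action_marginal(policies, q_pi, num_actions=2)
```

Here the two policies beginning with action 0 contribute 0.1 + 0.2 and the two beginning with action 1 contribute 0.3 + 0.4.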

Both agents select the optimal action sequence: they go to the cue first and then to the left arm to collect the reward.

In [4]:
print_agent_behavior(info_vanilla)
t=0: observed=(center loc, no_outcome) -> action=move to cue
t=1: observed=(cue-left-arm, no_outcome) -> action=move to left arm
t=2: observed=(left arm loc, reward) -> action=move to left arm
t=3: observed=(left arm loc, reward) -> action=move to left arm
In [5]:
print_agent_behavior(info_si)
t=0: observed=(center loc, no_outcome) -> action=move to cue
t=1: observed=(cue-left-arm, no_outcome) -> action=move to left arm
t=2: observed=(left arm loc, reward) -> action=move to left arm
t=3: observed=(left arm loc, reward) -> action=move to left arm

Now, to see how the two planning algorithms differ, let's examine how they evaluate policies...

While both agents select the optimal strategy (information gathering first), they differ significantly in their confidence: the SI agent shows a much stronger preference for the information-gathering strategy.

  • Vanilla Agent: 81% probability for cue-seeking, 18% for staying at center
  • SI Agent: 98% probability for cue-seeking, <1% for staying at center
In [6]:
print_qpi(agent_vanilla, info_vanilla, print_t1=True)
Timestep 0:
  move to center: 0.183
  move to left arm: 0.004
  move to right arm: 0.004
  move to cue: 0.809

In [7]:
print_qpi(agent_si, info_si, print_t1=True)
Timestep 0:
  move to center: 0.003
  move to left arm: 0.008
  move to right arm: 0.008
  move to cue: 0.980
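
One way to quantify the difference in confidence is the Shannon entropy of the two first-action distributions printed above. The sketch below uses the rounded values from those outputs and renormalizes them, since rounding means they do not sum exactly to one; the `entropy` helper is a plain Python illustration.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) after renormalizing rounded probabilities."""
    total = sum(probs)
    normalized = [p / total for p in probs]
    return -sum(p * math.log(p) for p in normalized if p > 0)

# Order: move to center, move to left arm, move to right arm, move to cue.
vanilla_first_action = [0.183, 0.004, 0.004, 0.809]
si_first_action = [0.003, 0.008, 0.008, 0.980]

h_vanilla = entropy(vanilla_first_action)
h_si = entropy(si_first_action)
```

The SI distribution has the lower entropy, which matches the observation that it commits more strongly to visiting the cue.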

Repeating the experiment with the extended TMaze environment used in other demos¶

This section mirrors the preceding section, but uses pymdp's default T-Maze environment as used in the T-Maze Demo. It differs from the SimplifiedTMaze environment used above in two ways: its restricted spatial geometry (a middle connector that the agent must pass through to reach the two arms) and the separation of cue and location observations into distinct modalities. We set cue validity (cue probability) to 1.0, use policy_len = 4 (to account for the additional planning depth required for navigating the T-Maze), and set observation_prune_threshold = 1e-4 for the SI search.

In [ ]:
from pymdp.envs import TMaze

Setting up the T-Maze environment (Generative Process)¶

The rules and hyperparameters of this T-Maze are identical to those of the SimplifiedTMaze above, except that the generative process parameters (states and observations) are slightly different because the joint cue+location modality (embedded together in SimplifiedTMaze) is factorized into two separate modalities.

How the generative process for the `TMaze` env (non-simplified) is set up:

States and Observations¶

State Factors:

  1. Location (5 states):
    • 0: center (start location)
    • 1: left arm
    • 2: right arm
    • 3: cue location (bottom arm)
    • 4: middle (junction between center and arms)
  2. Reward Location (2 states):
    • 0: reward in left arm
    • 1: reward in right arm

Control State Factors:

  1. Location (5 actions):
    • 0: move to center (start location)
    • 1: move to left arm
    • 2: move to right arm
    • 3: move to cue location (bottom arm)
    • 4: move to middle (junction)

Observation Modalities:

  1. Location (5 observations):
    • 0: center (start location)
    • 1: left arm
    • 2: right arm
    • 3: cue location (bottom arm)
    • 4: middle (junction)
  2. Outcome (3 observations):
    • 0: no outcome
    • 1: reward (cheese)
    • 2: punishment (shock)
  3. Cue (3 observations):
    • 0: no cue
    • 1: cue indicates left arm
    • 2: cue indicates right arm

Environment Parameters¶

Observation Likelihood Model (A):

  • A[0]: Location observations (5x5 tensor)
    • Perfect mapping between true and observed locations.
  • A[1]: Outcome observations (3x5x2 tensor)
    • In the rewarding arm (set by reward_condition), reward is presented with a likelihood determined by the reward_probability parameter.
    • Whether the remaining probability is assigned to punishment or no-outcome depends on dependent_outcomes (True or False) and, when it is False, on the punishment_probability parameter.
    • No-outcome is observed in the center/start location, cue location, and middle.
  • A[2]: Cue observations (3x5x2 tensor)
    • The cue is only observed at the cue location, with accuracy set by cue_validity.
    • No cue is observed at all other locations.

Transition Model (B):

  • B[0]: Location transitions (5x5x5 tensor)
    • Agent can move between adjacent locations in the T-maze; invalid moves leave the agent in place.
  • B[1]: Reward location (2x2x1 tensor)
    • Reward location remains fixed throughout trial.
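
The rule "invalid moves leave the agent in place" can be sketched with an adjacency map. The adjacency used below is purely hypothetical and chosen for illustration; the real layout is defined in tmaze.py in pymdp/envs. The function name and the `[next, current, action]` index order are also assumptions of this sketch.

```python
import numpy as np

# Hypothetical adjacency, for illustration only.
# 0: center, 1: left arm, 2: right arm, 3: cue location, 4: middle
ADJACENT = {0: {3, 4}, 1: {4}, 2: {4}, 3: {0}, 4: {0, 1, 2}}

def build_restricted_transitions(adjacent):
    """B0[next, current, action]: move if the target is adjacent, else stay put."""
    n = len(adjacent)
    B0 = np.zeros((n, n, n))
    for current in range(n):
        for action in range(n):
            valid = action == current or action in adjacent[current]
            next_loc = action if valid else current
            B0[next_loc, current, action] = 1.0
    return B0

B0 = build_restricted_transitions(ADJACENT)
```

Under this assumed layout, trying to move from the center directly to the left arm leaves the agent at the center, which is why reaching an arm after visiting the cue takes more steps than in the simplified environment.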

Initial Conditions (D):

  • D[0]: Starting location (5x1 tensor)
    • Agent always begins in center location
  • D[1]: Reward placement (2x1 tensor)
    • Default: Equal chance (50/50) of reward in either arm (reward_condition=None)
    • Optional: Can fix reward to specific arm, by setting reward_condition to 0 (for left arm) or 1 (for right arm)
In [ ]:
# setting the parameters for the environment
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
cue_validity = 1.0 # 100% valid cues (cue probability)

reward_probability = 1.0 # 100% chance of reward in the correct arm
# If True, punishment is the complement of reward in the correct arm
# (e.g., reward_probability=0.8 gives a 20% chance of punishment there).
# If False, the remainder in the correct arm is no outcome (e.g., 20% no outcome), and punishment
# occurs only in the other (non-rewarding) arm, with probability punishment_probability.
dependent_outcomes = True
punishment_probability = 1.0 # 100% chance of punishment in the other arm
In [ ]:
# initializing the environment. see tmaze.py in pymdp/envs for the implementation details
env = TMaze(
    reward_condition=reward_condition,
    cue_validity=cue_validity,  
    reward_probability=reward_probability,
    punishment_probability=punishment_probability,     
    dependent_outcomes=dependent_outcomes,
)

Setting up the Agents¶

In [ ]:
# creating C tensors filled with zeros, one per observation modality ([location], [outcome], [cue]), based on A shapes for the Agent
C = [jnp.zeros(a.shape[0], dtype=jnp.float32) for a in env.A] 

# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0)    # prefer reward
C[1] = C[1].at[2].set(-6.0)   # avoid punishment

# slight cost of observing a cue
C[2] = C[2].at[1].set(-0.5) 
C[2] = C[2].at[2].set(-0.5)
In [ ]:
# D tensors [location], [reward] based on B shapes for the agent
# - agent starts in the center location
# - reward location prior is uniform
D_loc = jnp.zeros(env.B[0].shape[0], dtype=jnp.float32)
D_loc = D_loc.at[0].set(1.0)

D_reward = jnp.ones(env.B[1].shape[0], dtype=jnp.float32)
D_reward = D_reward / jnp.sum(D_reward, axis=0, keepdims=True)

D = [D_loc, D_reward]
In [8]:
# Note: the two agents are initialized with different policy lengths, even though both end up planning
# with a horizon of 4. The vanilla agent plans over full 4-step policies (policy_len=4), whereas the
# sophisticated inference agent uses policy_len=1, because its horizon of 4 is specified when
# initializing the planning algorithm that is passed to `rollout`.

# action_selection="deterministic" selects the action with the highest probability under the policy posterior (q_pi), i.e. the argmax
# sampling_mode="full" evaluates the whole action sequence of each policy and executes its first action (as opposed to "marginal", where the agent evaluates each action type)

gamma = 3.0
policy_len = 4 
agent_vanilla = Agent(
    env.A, env.B, C, D, 
    A_dependencies=env.A_dependencies, 
    B_dependencies=env.B_dependencies,
    policy_len=policy_len,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=gamma,
)

agent_si = Agent(
    env.A, env.B, C, D, 
    A_dependencies=env.A_dependencies, 
    B_dependencies=env.B_dependencies,
    policy_len=1,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=gamma,
)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61048/2056066623.py:10: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent_vanilla = Agent(
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61048/2056066623.py:22: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent_si = Agent(
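
To make the roles of `gamma` and `action_selection="deterministic"` concrete, here is a minimal NumPy sketch. It assumes the usual active inference form in which the policy posterior is a softmax of the negative expected free energy scaled by `gamma`; the `neg_efe` values are hypothetical, and the `softmax` helper is written here for illustration rather than taken from pymdp.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    x = np.asarray(x, dtype=np.float64)
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical negative expected free energies for three policies
neg_efe = np.array([-1.0, -0.5, -2.0])

for gamma in (1.0, 3.0, 16.0):
    # Assumed form of the policy posterior: q_pi = softmax(gamma * neg_efe)
    q_pi = softmax(gamma * neg_efe)
    # Deterministic selection: pick the most probable policy
    chosen = int(np.argmax(q_pi))
    print(f"gamma={gamma}: q_pi={np.round(q_pi, 3)}, chosen policy={chosen}")
```

A larger `gamma` concentrates probability on the best policy without changing which policy the argmax picks.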

Running the active inference rollouts¶

In [ ]:
key = jr.PRNGKey(0) 
T = 5
In [ ]:
si_search = si_policy_search(
    horizon=policy_len, # plans 4 timesteps ahead
    max_nodes=5000, # maximum number of nodes allowed in the tree
    max_branching=45, # maximum number of children allowed per node (moderating the branching factor)
    policy_prune_threshold=0.0, # no pruning of unlikely policies
    observation_prune_threshold=1e-4, # prune only observations with predicted probability below 1e-4 (near-negligible pruning)
    entropy_stop_threshold=0.0, # disabling halting of expansion if agent is certain enough
    neg_efe_stop_threshold=1e10, # disabling efe value based halting of expansion
    kl_threshold=-1, # disabling node reuse if agent is in similar states after an action
    prune_penalty=512, # default value for prune penalty
    gamma=gamma, # policy precision (inverse temperature); a lower value (---> 1) flattens the probabilities and prunes policies less aggressively, while a higher value (---> 16) prunes more aggressively
    topk_obsspace=10000, # max number of top observation combinations - this default value just means we want to consider all the observation combinations
)
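To illustrate what a threshold such as `observation_prune_threshold` can do during tree expansion, here is a minimal NumPy sketch. The helper `prune_observations` is hypothetical and is not pymdp's implementation. It assumes that observations whose predicted probability falls below the threshold are dropped and that the remainder are renormalized.

```python
import numpy as np

def prune_observations(probs, threshold):
    """Hypothetical helper: drop outcomes below `threshold`, renormalize the rest."""
    probs = np.asarray(probs, dtype=np.float64)
    keep = probs >= threshold
    if not keep.any():
        # Fall back to the single most likely outcome so the branch is never empty
        keep = probs == probs.max()
    kept = np.where(keep, probs, 0.0)
    return keep, kept / kept.sum()

# Hypothetical predicted probabilities for three observation combinations
predicted = np.array([0.69995, 0.3, 0.00005])

keep, renormalized = prune_observations(predicted, threshold=1e-4)
print(keep)          # which observation branches would be expanded
print(renormalized)  # branch weights after pruning
```

With a threshold this small, only outcomes the agent considers essentially impossible are removed, so the search tree stays close to exhaustive.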
In [ ]:
_, info_vanilla = rollout(agent_vanilla, env, num_timesteps=T, rng_key=key) # default policy search is vanilla
_, info_si = rollout(agent_si, env, num_timesteps=T, rng_key=key, policy_search=si_search)
In [ ]:
def make_gif(info):
    frames = []
    for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
        # get observations for this timestep
        observations_t = [
            info["observation"][0][:, t, :],
            info["observation"][1][:, t, :],  
            info["observation"][2][:, t, :],
        ]
        
        frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
        frame = np.asarray(frame, dtype=np.uint8)
        plt.close()  # close the figure to prevent memory leak
        frames.append(frame)

    frames = np.array(frames, dtype=np.uint8)
    mediapy.show_video(frames, fps=1)
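The slicing inside `make_gif` relies on each observation modality being stored as an array indexed by `[batch, timestep, obs_dim]`. The following NumPy sketch reproduces that indexing on dummy data. The shapes (batch size 1, six timesteps, one observation dimension) are assumptions based on the indexing comments in this notebook, and the location indices are copied from the rollout printed later.

```python
import numpy as np

# Assumed layout: one array per modality, each shaped [batch, timestep, obs_dim]
batch_size, num_steps = 1, 6
observations = [
    np.zeros((batch_size, num_steps, 1), dtype=np.int32)
    for _ in range(3)  # location, outcome, cue
]

# Location indices: center, cue, center, middle, left arm, left arm
observations[0][0, :, 0] = [0, 3, 0, 4, 1, 1]

# Select a single timestep from every modality, keeping the batch dimension
t = 1
observations_t = [obs[:, t, :] for obs in observations]

shapes = [o.shape for o in observations_t]
print(shapes)                       # one [batch, obs_dim] slice per modality
print(int(observations_t[0][0, 0])) # location index at timestep t
```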
In [9]:
make_gif(info_vanilla)
[Video output: rendered T-maze rollout for the vanilla agent]
In [10]:
make_gif(info_si)
[Video output: rendered T-maze rollout for the sophisticated inference agent]

Result analysis¶

In [ ]:
# helper functions for:
# - printing out the probability of each first action, derived from the policy posterior
# - printing out action and observation info for each timestep
#
# qpi is a posterior over whole policies (action sequences).
# To get the probability of the *current* action, we marginalize over policies
# that share the same first action and sum their qpi values.

np.set_printoptions(precision=2, suppress=True)

def print_qpi(agent, info, print_t1=False):
    qpi_values = info["qpi"]

    action_names = {
        0: "move to center",
        1: "move to left arm",
        2: "move to right arm",
        3: "move to cue",
        4: "move to middle",
    }
    max_timesteps = 1 if print_t1 else qpi_values.shape[1]

    # unique_multiactions returns the unique first-step actions across policies
    # for a single control factor, this is just a list of action indices
    unique_actions = agent.unique_multiactions[:, 0]

    for t in range(max_timesteps):
        print(f"Timestep {t}:")
        action_probs = agent.multiaction_probabilities(qpi_values[:, t, :])[0]

        for action_idx, total_prob in zip(unique_actions.tolist(), action_probs.tolist()):
            if action_idx < 0:
                continue
            action_name = action_names.get(action_idx, f"action_{action_idx}")
            print(f"  {action_name}: {total_prob:.3f}")
        print()

def print_agent_behavior(info):

    action_names = {
        0: "move to center",
        1: "move to left arm",
        2: "move to right arm",
        3: "move to cue",
        4: "move to middle",
    }
    
    location_obs = {
        0: "center loc",
        1: "left arm loc",
        2: "right arm loc",
        3: "cue loc",
        4: "middle loc",
    }
    outcome_obs = {0: "no_outcome", 1: "reward", 2: "punishment"}
    cue_obs = {0: "no cue", 1: "cue-left", 2: "cue-right"}
    
    actions = info["action"]
    observations = info["observation"]
    
    num_timesteps = actions.shape[1]
    
    for t in range(num_timesteps):
        action_idx = int(actions[0, t, 0])  # [batch, timestep, action_dim]
        action_name = action_names.get(action_idx, f"action_{action_idx}")
        
        location_obs_idx = int(observations[0][0, t, 0])  # [modality][batch, timestep, obs_dim]
        outcome_obs_idx = int(observations[1][0, t, 0])
        cue_obs_idx = int(observations[2][0, t, 0])
        
        location_name = location_obs.get(location_obs_idx)
        outcome_name = outcome_obs.get(outcome_obs_idx)
        cue_name = cue_obs.get(cue_obs_idx)
        
        print(f"t={t}: observed=({location_name}, {outcome_name}, {cue_name}) -> action={action_name}")
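The marginalization described in the comments above (summing `qpi` over all policies that share a first action) can be sketched in plain NumPy as follows. The policies and probabilities are hypothetical, and this is an illustration of the idea rather than the code behind `agent.multiaction_probabilities`.

```python
import numpy as np

# Hypothetical example: 4 policies of length 2 over a single control factor
# (3 = move to cue, 1 = move to left arm, 2 = move to right arm, ...)
policies = np.array([
    [3, 0],
    [3, 4],
    [1, 0],
    [2, 0],
])
q_pi = np.array([0.4, 0.3, 0.2, 0.1])  # posterior over the 4 policies

# Group policies by their first action and sum the posterior mass in each group
first_actions = policies[:, 0]
unique_actions = np.unique(first_actions)
action_probs = np.array([q_pi[first_actions == a].sum() for a in unique_actions])

for a, p in zip(unique_actions.tolist(), action_probs.tolist()):
    print(f"action {a}: {p:.3f}")
```

The resulting action probabilities still sum to one, because every policy contributes its mass to exactly one first action.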

With the classic T-maze's adjacency constraints (and the extra middle location), reaching the reward after sampling the cue takes multiple steps. With the current horizon (4) and T=5, you should see cue-seeking and movement toward the chosen arm; increase policy_len and/or T if you want longer sequences.

In [11]:
print_agent_behavior(info_vanilla)
t=0: observed=(center loc, no_outcome, no cue) -> action=move to cue
t=1: observed=(cue loc, no_outcome, cue-left) -> action=move to center
t=2: observed=(center loc, no_outcome, no cue) -> action=move to middle
t=3: observed=(middle loc, no_outcome, no cue) -> action=move to left arm
t=4: observed=(left arm loc, reward, no cue) -> action=move to center
t=5: observed=(left arm loc, reward, no cue) -> action=move to center
In [12]:
print_agent_behavior(info_si)
t=0: observed=(center loc, no_outcome, no cue) -> action=move to cue
t=1: observed=(cue loc, no_outcome, cue-left) -> action=move to center
t=2: observed=(center loc, no_outcome, no cue) -> action=move to middle
t=3: observed=(middle loc, no_outcome, no cue) -> action=move to left arm
t=4: observed=(left arm loc, reward, no cue) -> action=move to center
t=5: observed=(left arm loc, reward, no cue) -> action=move to center

Both agents execute the same sequence of actions in this rollout. To see how the two planning algorithms differ, let's examine how they evaluate policies.

With cue_validity = 1.0, both planners should prefer information gathering, but the SI planner typically yields a sharper (more confident) posterior over policies. Compare the probabilities each planner assigns to the first action at timestep 0.

In [13]:
print_qpi(agent_vanilla, info_vanilla, print_t1=True)
Timestep 0:
  move to center: 0.142
  move to left arm: 0.142
  move to right arm: 0.142
  move to cue: 0.539
  move to middle: 0.036

In [14]:
print_qpi(agent_si, info_si, print_t1=True)
Timestep 0:
  move to center: 0.001
  move to left arm: 0.001
  move to right arm: 0.001
  move to cue: 0.810
  move to middle: 0.186
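
One way to build intuition for why branching on counterfactual observations can sharpen the preference for visiting the cue is the toy calculation below. It is a deliberately simplified sketch, not the SI algorithm in pymdp: it considers only the preference (pragmatic) term, ignores the epistemic term and the intermediate moves, and reuses the preference values from this notebook (2.0 for reward, -6.0 for punishment) with a fully valid cue.

```python
import numpy as np

reward_pref, punishment_pref = 2.0, -6.0  # preference values from the C tensor
prior = np.array([0.5, 0.5])              # belief: reward in left / right arm
cue_validity = 1.0

# likelihood[o, s] = P(cue says o | reward is in arm s)
likelihood = np.array([
    [cue_validity, 1.0 - cue_validity],
    [1.0 - cue_validity, cue_validity],
])

def best_arm_value(belief):
    """Expected preference of entering the better arm under `belief`."""
    values = belief * reward_pref + (1.0 - belief) * punishment_pref
    return float(values.max())

# No branching: the belief about the reward location stays at the prior
value_without_branching = best_arm_value(prior)

# Branching: update the belief for each possible cue, then average over cues
p_obs = likelihood @ prior
value_with_branching = 0.0
for o in range(2):
    posterior = likelihood[o] * prior / p_obs[o]
    value_with_branching += p_obs[o] * best_arm_value(posterior)

print(value_without_branching, value_with_branching)
```

Under these assumptions, entering an arm looks costly when the belief is left at the prior, whereas a plan that conditions the arm choice on the cue looks rewarding, which is consistent with the higher probability the SI planner assigns to moving to the cue.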
