pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

Sophisticated inference on a Generalized T-Maze¶

This notebook demonstrates the effect of sophisticated inference on a generalized T-maze which has multiple cues, out of which only one leads to reward. Sophisticated inference correctly focuses on the cue that yields reward, whereas vanilla AIF will be distracted by the prospect of information gain of the other cues. However, due to the long planning horizon this takes a long time to compute doing the full tree search. See the mcts_generalized_tmaze example to see how to combat this using MCMC sampling.

In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [ ]:
Copied!
%load_ext autoreload
%autoreload 2

from pymdp.envs.generalized_tmaze import (
    GeneralizedTMazeEnv, get_maze_matrix, parse_maze 
)
from pymdp.agent import Agent
from pymdp.utils import list_array_zeros, list_array_uniform

from jax import random as jr, numpy as jnp
from jax import tree_util as jtu

key = jr.PRNGKey(0)
%load_ext autoreload %autoreload 2 from pymdp.envs.generalized_tmaze import ( GeneralizedTMazeEnv, get_maze_matrix, parse_maze ) from pymdp.agent import Agent from pymdp.utils import list_array_zeros, list_array_uniform from jax import random as jr, numpy as jnp from jax import tree_util as jtu key = jr.PRNGKey(0)
In [1]:
Copied!
key, subkey = jr.split(key)
M = get_maze_matrix(small=True)
env_info = parse_maze(M, subkey)
tmaze_env = GeneralizedTMazeEnv(env_info)

init_obs, init_state = tmaze_env.reset(key)
tmaze_env.render(states=init_state, mode="human")
key, subkey = jr.split(key) M = get_maze_matrix(small=True) env_info = parse_maze(M, subkey) tmaze_env = GeneralizedTMazeEnv(env_info) init_obs, init_state = tmaze_env.reset(key) tmaze_env.render(states=init_state, mode="human")
<Figure size 640x480 with 0 Axes>
No description has been provided for this image
In [ ]:
Copied!
A, B = tmaze_env.A, tmaze_env.B
A_dependencies, B_dependencies = tmaze_env.A_dependencies, tmaze_env.B_dependencies

# [position], [cue], [reward]
C = list_array_zeros([a.shape[0] for a in A])

rewarding_modality = 2 + env_info["num_cues"]

C[rewarding_modality] = C[rewarding_modality].at[1].set(2.0)
C[rewarding_modality] = C[rewarding_modality].at[2].set(-3.0)

D = list_array_uniform([b.shape[0] for b in B])
D[0] = tmaze_env.D[0].copy()

agent = Agent(
    A, B, C, D, 
    None, None, None, 
    policy_len=1,
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies
    )
A, B = tmaze_env.A, tmaze_env.B A_dependencies, B_dependencies = tmaze_env.A_dependencies, tmaze_env.B_dependencies # [position], [cue], [reward] C = list_array_zeros([a.shape[0] for a in A]) rewarding_modality = 2 + env_info["num_cues"] C[rewarding_modality] = C[rewarding_modality].at[1].set(2.0) C[rewarding_modality] = C[rewarding_modality].at[2].set(-3.0) D = list_array_uniform([b.shape[0] for b in B]) D[0] = tmaze_env.D[0].copy() agent = Agent( A, B, C, D, None, None, None, policy_len=1, A_dependencies=A_dependencies, B_dependencies=B_dependencies )
In [ ]:
Copied!
qs = agent.infer_states(
    observations=jtu.tree_map(lambda x: jnp.broadcast_to(x, (agent.batch_size,) + x.shape), init_obs),
    past_actions=None,
    empirical_prior=agent.D,
    qs_hist=None,
)
qs = agent.infer_states( observations=jtu.tree_map(lambda x: jnp.broadcast_to(x, (agent.batch_size,) + x.shape), init_obs), past_actions=None, empirical_prior=agent.D, qs_hist=None, )
In [ ]:
Copied!
from pymdp.planning.si import si_policy_search

key = jr.PRNGKey(0)
policy_search = si_policy_search(horizon=3, entropy_stop_threshold=0.0, gamma=5)
qpi, info = policy_search(agent, qs, key) 
from pymdp.planning.si import si_policy_search key = jr.PRNGKey(0) policy_search = si_policy_search(horizon=3, entropy_stop_threshold=0.0, gamma=5) qpi, info = policy_search(agent, qs, key)

When plotting the plan tree, the preferred plan is indeed to go to the right cue, and then fetch the reward.

In [2]:
Copied!
from pymdp.planning.visualize import visualize_plan_tree


def generalized_tmaze_observation_description(observation, model=None, env_info=env_info):
    obs = observation[0]
    maze_shape = env_info["maze"].shape
    num_cues = env_info["num_cues"]

    pos_idx = int(obs[0])
    row, col = divmod(pos_idx, maze_shape[1])

    cue_vals = [int(v) for v in obs[1 : 1 + num_cues]]
    reward_vals = [int(v) for v in obs[1 + num_cues : 1 + (2 * num_cues)]]

    cue_map = {0: "none", 1: "r1", 2: "r2"}
    reward_map = {0: "none", 1: "reward", 2: "punish"}

    cues_str = ", ".join(
        f"c{i + 1}:{cue_map.get(v, v)}" for i, v in enumerate(cue_vals)
    )
    rewards_str = ", ".join(
        f"r{i + 1}:{reward_map.get(v, v)}" for i, v in enumerate(reward_vals)
    )

    return (
        f"pos: {pos_idx} ({row},{col})\n"
        f"cues: {cues_str}\n"
        f"rewards: {rewards_str}"
    )


def generalized_tmaze_action_description(action, model=None):
    action_idx = int(action[0])
    actions = ["down", "up", "left", "right", "stay"]

    if 0 <= action_idx < len(actions):
        return actions[action_idx]
    return str(action_idx)


visualize_plan_tree(
    info,
    time_idx=None,
    observation_description=generalized_tmaze_observation_description,
    action_description=generalized_tmaze_action_description,
)
from pymdp.planning.visualize import visualize_plan_tree def generalized_tmaze_observation_description(observation, model=None, env_info=env_info): obs = observation[0] maze_shape = env_info["maze"].shape num_cues = env_info["num_cues"] pos_idx = int(obs[0]) row, col = divmod(pos_idx, maze_shape[1]) cue_vals = [int(v) for v in obs[1 : 1 + num_cues]] reward_vals = [int(v) for v in obs[1 + num_cues : 1 + (2 * num_cues)]] cue_map = {0: "none", 1: "r1", 2: "r2"} reward_map = {0: "none", 1: "reward", 2: "punish"} cues_str = ", ".join( f"c{i + 1}:{cue_map.get(v, v)}" for i, v in enumerate(cue_vals) ) rewards_str = ", ".join( f"r{i + 1}:{reward_map.get(v, v)}" for i, v in enumerate(reward_vals) ) return ( f"pos: {pos_idx} ({row},{col})\n" f"cues: {cues_str}\n" f"rewards: {rewards_str}" ) def generalized_tmaze_action_description(action, model=None): action_idx = int(action[0]) actions = ["down", "up", "left", "right", "stay"] if 0 <= action_idx < len(actions): return actions[action_idx] return str(action_idx) visualize_plan_tree( info, time_idx=None, observation_description=generalized_tmaze_observation_description, action_description=generalized_tmaze_action_description, )
No description has been provided for this image

If we limit the planning horizon to a single step, we can see that indeed it renders both cues equally attractive.

In [ ]:
Copied!
policy_search = si_policy_search(horizon=1, entropy_stop_threshold=0.0, gamma=5)
qpi, info = policy_search(agent, qs, key) 
policy_search = si_policy_search(horizon=1, entropy_stop_threshold=0.0, gamma=5) qpi, info = policy_search(agent, qs, key)
In [3]:
Copied!
visualize_plan_tree(
    info,
    time_idx=None,
    observation_description=generalized_tmaze_observation_description,
    action_description=generalized_tmaze_action_description,
)
visualize_plan_tree( info, time_idx=None, observation_description=generalized_tmaze_observation_description, action_description=generalized_tmaze_action_description, )
No description has been provided for this image
In [ ]:
Copied!

Made with Dracula Theme for MkDocs