Validating Sophisticated Inference (SI) Planning Algorithm using the T-Maze Task
Overview
This notebook provides a comparative test of two planning algorithms within the active inference framework:
- Vanilla Active Inference Planning: the standard, default planning approach in pymdp.control.infer_policies(). It computes the expected free energy of all policies in parallel, and the log probability of each policy is proportional to its negative expected free energy. The expected free energy of each policy is computed by propagating beliefs about hidden states forward in time under that policy (a typical 'posterior predictive rollout'), without updating those beliefs with the counterfactual observations expected under those future states. This means the belief over hidden states at timestep t in a given rollout depends only on:
  - the belief over hidden states at the previous timestep (t-1), and
  - the action entailed by the policy in question at timestep t-1.
- Sophisticated Inference (SI) Planning: a more expensive but more thorough planning approach that evaluates policies using a recursive expected free energy calculation, in which the agent considers how its beliefs would change under counterfactual observations encountered in the future. The algorithm has higher complexity because it branches on both future observations and future actions, but it more accurately captures how uncertainty over hidden states evolves as a function of the observations the agent might encounter.
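To make the difference concrete, here is a minimal NumPy sketch (not pymdp code) that follows only the reward-location factor when the agent visits the cue location, assuming cue_validity = 0.95 and a 50/50 prior. A vanilla rollout carries a single predicted belief forward to later timesteps, whereas SI-style planning branches on each possible cue and carries one Bayes-updated posterior per branch. The names q_reward, B_reward and A_cue are illustrative.

```python
import numpy as np

# Toy belief over the reward-location factor: [reward left, reward right]
q_reward = np.array([0.5, 0.5])

# The reward location never changes, so its transition matrix is the identity
B_reward = np.eye(2)

# Illustrative cue likelihood at the cue location (cue_validity = 0.95)
# rows: observed cue (left, right); columns: true reward location (left, right)
cue_validity = 0.95
A_cue = np.array([
    [cue_validity, 1.0 - cue_validity],
    [1.0 - cue_validity, cue_validity],
])

# Vanilla: propagate the belief forward without conditioning on future observations
q_vanilla = B_reward @ q_reward

# SI-style: branch on each possible cue and update the belief with Bayes' rule
q_pred = B_reward @ q_reward
p_obs = A_cue @ q_pred  # predicted probability of each cue
posteriors = [A_cue[o] * q_pred / p_obs[o] for o in range(2)]


def entropy(p):
    """Shannon entropy in nats, ignoring zero-probability entries."""
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())


expected_posterior_entropy = sum(p_obs[o] * entropy(posteriors[o]) for o in range(2))

print("vanilla predicted belief:", q_vanilla)
for o, post in enumerate(posteriors):
    print(f"SI posterior after cue {o} (p={p_obs[o]:.2f}):", post)
print("entropy carried forward, vanilla:", entropy(q_vanilla))
print("expected entropy carried forward, SI:", expected_posterior_entropy)
```

The vanilla belief stays at 50/50 for the rest of the rollout, while each SI branch plans from a belief that already reflects the cue.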
This notebook uses a simplified version of the T-Maze environment, SimplifiedTMaze, as described in the sophisticated inference paper, to validate the SI implementation.
The T-Maze Task
The T-maze is a classic sequential decision-making problem in which the agent (analogized to a rat) faces a fundamental exploration vs. exploitation dilemma. The agent begins at the center of a T-shaped maze with a left arm, a right arm, and a cue location (bottom arm). One arm contains a reward (cheese) and the other a punishment (shock), but the agent does not know which is which. A cue at the bottom arm indicates the reward location with a given accuracy. This version of the T-Maze differs from the one in the T-Maze Demo in three ways: there is no "middle" location between the left and right arms, every location is reachable from every other location, and location and cue observations are combined into a single modality rather than split into two.
The agent must choose between immediate but risky exploitation, committing directly to one of the arms (a 50% chance of success), and exploration, visiting the cue location first at the cost of an extra timestep and then making an informed choice. When cue validity exceeds 50%, the optimal strategy is to gather information at the cue location first and then select the indicated arm. This notebook tests whether both planning algorithms discover this optimal policy and compares how each one evaluates it.
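A back-of-the-envelope check of this claim, counting only preferences over outcomes. The values 2.0 for reward and -6.0 for punishment are taken from the C vector set up later in this notebook; the sketch assumes reward_probability = 1.0 with dependent_outcomes=True (so the wrong arm always delivers punishment) and ignores epistemic value and the cost of the extra timestep. It is an illustration, not how pymdp scores policies.

```python
# Log-preferences over outcomes, copied from the C vector used later in this notebook
C_REWARD, C_PUNISH = 2.0, -6.0


def eu_direct():
    """Expected utility of committing to an arm under a 50/50 prior over reward location."""
    return 0.5 * C_REWARD + 0.5 * C_PUNISH


def eu_informed(cue_validity):
    """Expected utility of visiting the cue first and then entering the cued arm.

    The cued arm is the rewarding arm with probability `cue_validity`.
    """
    return cue_validity * C_REWARD + (1.0 - cue_validity) * C_PUNISH


for v in (0.5, 0.75, 0.95, 1.0):
    print(f"cue_validity={v:.2f}: direct={eu_direct():+.2f}, cue first={eu_informed(v):+.2f}")
```

At cue_validity = 0.5 both options are worth the same (-2.0), and every increase above 0.5 favors visiting the cue first.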
Notebook Structure
- Environment Setup: Set up the generative process for the T-Maze environment
- Agent Setup: Set up the generative model for the agent
- Active Inference Rollout: Run active inference rollouts with vanilla and sophisticated inference planning algorithms, with optional visualizations of the agents' behavior
- Results Analysis: Compare actions selected and policy evaluations
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[nb]" -q
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import mediapy
from pymdp.envs import SimplifiedTMaze, rollout
from pymdp.agent import Agent
from pymdp.planning.si import si_policy_search
Setting up the T-Maze environment (Generative Process)
- The reward_condition parameter determines the reward location: 0 for the left arm, 1 for the right arm, or None for random allocation.
- The cue_validity parameter (default 0.95) is the probability that the cue correctly indicates the rewarding arm.
- The reward_probability parameter sets the probability a of receiving a reward in the correct arm.
- With dependent_outcomes=True, the remaining probability (1-a) becomes punishment in the correct arm (and reward in the incorrect arm). With dependent_outcomes=False, the remaining probability in the correct arm is no-outcome, and punishment in the incorrect arm is set by punishment_probability (with no-outcome as the remainder).
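The outcome rule above can be written as a small function. arm_outcome_probs is a hypothetical helper (it is not part of pymdp) reflecting one reading of the description; in particular, the incorrect-arm distribution for dependent_outcomes=True is assumed to mirror the correct arm. See si_tmaze.py in pymdp/envs for the actual construction.

```python
def arm_outcome_probs(reward_probability, dependent_outcomes, punishment_probability=0.0):
    """Return p(outcome) as [no outcome, reward, punishment] for (correct arm, incorrect arm).

    Hypothetical helper illustrating the rule described in the text.
    """
    a = reward_probability
    if dependent_outcomes:
        correct = [0.0, a, 1.0 - a]     # leftover mass becomes punishment
        incorrect = [0.0, 1.0 - a, a]   # ASSUMED mirror image: leftover mass becomes reward
    else:
        p = punishment_probability
        correct = [1.0 - a, a, 0.0]     # leftover mass becomes no outcome
        incorrect = [1.0 - p, 0.0, p]   # punishment with its own probability, else no outcome
    return correct, incorrect


print(arm_outcome_probs(0.8, dependent_outcomes=True))
print(arm_outcome_probs(0.8, dependent_outcomes=False, punishment_probability=0.6))
```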
States and Observations
State Factors:
- Location (4 states):
- 0: center (start location)
- 1: left arm
- 2: right arm
- 3: cue location (bottom arm)
- Reward Location (2 states):
- 0: reward in left arm
- 1: reward in right arm
Control State Factors:
- Location (4 actions):
- 0: move to center (start location)
- 1: move to left arm
- 2: move to right arm
- 3: move to cue location (bottom arm)
Observation Modalities:
- Location (5 observations):
- 0: center (start location)
- 1: left arm
- 2: right arm
- 3: cued left arm (bottom arm)
- 4: cued right arm (bottom arm)
- Outcome (3 observations):
- 0: no outcome
- 1: reward (cheese)
- 2: punishment (shock)
Environment Parameters
Observation Likelihood Model (A):
- A[0]: Location observations (5x4x2 tensor)
  - Perfect mapping between true and observed locations.
  - At the cue location (bottom), the rewarding arm is cued with accuracy set by the cue_validity parameter.
- A[1]: Outcome observations (3x4x2 tensor)
  - In the rewarding arm (set by reward_condition), reward is presented with a likelihood determined by the reward_probability parameter.
  - Punishment and/or no-outcome are presented with likelihoods that depend on whether dependent_outcomes is True or False (and, if False, on the punishment_probability parameter).
  - No-outcome is observed in the center/start location and cue location.
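As an illustration, the location/cue likelihood A[0] described above can be built as a 5x4x2 array. This is a sketch that follows the bullet points, not the code in si_tmaze.py; observation indices 3 and 4 are the "cued left arm" and "cued right arm" observations listed earlier.

```python
import numpy as np


def simplified_location_likelihood(cue_validity=0.95):
    """Sketch of A[0] with shape (5 observations, 4 locations, 2 reward locations)."""
    A0 = np.zeros((5, 4, 2))
    for reward_loc in range(2):
        # center (0), left arm (1) and right arm (2) are observed perfectly
        for loc in range(3):
            A0[loc, loc, reward_loc] = 1.0
        # at the cue location (3): observation 3 = "cued left", 4 = "cued right"
        correct_cue = 3 + reward_loc
        wrong_cue = 4 - reward_loc
        A0[correct_cue, 3, reward_loc] = cue_validity
        A0[wrong_cue, 3, reward_loc] = 1.0 - cue_validity
    return A0


A0 = simplified_location_likelihood()
print(A0[:, 3, :])  # cue-location slice: rows = observations, columns = reward location
```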
Transition Model (B):
- B[0]: Location transitions (4x4x4 tensor)
- Agent can move to any one of the four locations from any location regardless of adjacency.
- B[1]: Reward location (2x2x1 tensor)
- Reward location remains fixed throughout trial.
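Using pymdp's convention of indexing B as [next state, current state, action], the "move anywhere" transition model reduces to placing a 1 at the target location for every action. A sketch, assuming action u simply means "move to location u":

```python
import numpy as np

num_locations = 4
# B0[next_location, current_location, action]: action u moves the agent to location u
B0 = np.zeros((num_locations, num_locations, num_locations))
for u in range(num_locations):
    B0[u, :, u] = 1.0  # from every current location, action u lands on u

# Reward location never changes: a 2x2 identity with a single (trivial) action
B1 = np.eye(2)[:, :, None]
```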
Initial Conditions (D):
- D[0]: Starting location (4x1 tensor)
  - Agent always begins in center location
- D[1]: Reward placement (2x1 tensor)
  - Default: Equal chance (50/50) of reward in either arm (reward_condition=None)
  - Optional: Can fix reward to a specific arm by setting reward_condition to 0 (left arm) or 1 (right arm)
# setting the parameters for the environment
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
cue_validity = 0.95 # 95% valid cues
reward_probability = 1.0 # 100% chance of reward in the correct arm
dependent_outcomes = True # if True, punishment occurs as a function of reward probability (i.e., if reward probability is 0.8, then 20% punishment). If False, punishment occurs with set probability (i.e., 20% no outcome and punishment will only occur in the other (non-rewarding) arm depending on the punishment_probability parameter)
# initializing the environment. see si_tmaze.py in pymdp/envs for the implementation details
env = SimplifiedTMaze(
    reward_condition=reward_condition,
    cue_validity=cue_validity,
    reward_probability=reward_probability,
    dependent_outcomes=dependent_outcomes,
)
Setting up the Agents
# creating C tensors filled with zeros for [location/cue], [outcome] based on A shapes for the Agent
C = [jnp.zeros(a.shape[0], dtype=jnp.float32) for a in env.A]
# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0) # prefer reward
C[1] = C[1].at[2].set(-6.0) # avoid punishment
# slight cost of observing a cue
# C[0] = C[0].at[3].set(-1.0).at[4].set(-1.0)
# flat D tensors [location], [reward] based on B shapes for the agent
D = [jnp.ones(b.shape[0], dtype=jnp.float32) / b.shape[0] for b in env.B]
# note that we initialize agents with different policy lengths for the vanilla vs sophisticated inference planning algorithms
# even though both will eventually end up planning with a horizon of 2. The sophisticated inference planning algorithm requires
# a policy length of 1 in the Agent as we specify horizon length of 2 when initializing the planning algorithm in the `rollout`.
# action_selection="deterministic" means selecting an action by taking the arg-max of the policy probability distribution (q_pi)
# sampling_mode="full" means evaluating the whole action sequence in each policy and executing the first action (as opposed to "marginal", where each control factor's action is chosen from its own marginal distribution)
policy_len = 2
agent_vanilla = Agent(
    env.A, env.B, C, D,
    A_dependencies=env.A_dependencies,
    B_dependencies=env.B_dependencies,
    policy_len=policy_len,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=3.0,
)

agent_si = Agent(
    env.A, env.B, C, D,
    A_dependencies=env.A_dependencies,
    B_dependencies=env.B_dependencies,
    policy_len=1,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=3.0,
)
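The two agents reach a planning depth of 2 in different ways. The sketch below counts the policies the vanilla agent enumerates with policy_len = 2, and the size of a naive SI tree that branches on every action and every joint observation (5 location/cue observations x 3 outcomes) with no pruning or node reuse. The exact node accounting inside si_policy_search may differ; the numbers only illustrate why limits such as max_nodes and max_branching exist.

```python
import itertools

num_actions = 4          # move to center / left arm / right arm / cue
num_joint_obs = 5 * 3    # location-or-cue observations x outcome observations
horizon = 2

# Vanilla planning enumerates every action sequence of length policy_len up front
vanilla_policies = list(itertools.product(range(num_actions), repeat=horizon))
print(len(vanilla_policies))  # 16

# A naive SI tree branches on actions, then on joint observations, at every level
per_level = num_actions * num_joint_obs
naive_leaves = per_level ** horizon
print(per_level, naive_leaves)  # 60 3600
```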
Running the active inference rollouts
key = jr.PRNGKey(0)
T = 3
si_search = si_policy_search(
    horizon=policy_len,  # plans 2 timesteps ahead
    max_nodes=5000,  # maximum number of nodes allowed in the tree
    max_branching=10,  # maximum number of children allowed per node (moderating the branching factor)
    policy_prune_threshold=0.0,  # no pruning of unlikely policies
    observation_prune_threshold=0.0,  # no pruning of unlikely observations
    entropy_stop_threshold=0.0,  # disabling halting of expansion if agent is certain enough
    neg_efe_stop_threshold=1e10,  # disabling EFE-value-based halting of expansion
    kl_threshold=-1,  # disabling node reuse if agent is in similar states after an action
    prune_penalty=512,  # default value for prune penalty
    gamma=3,  # policy precision (inverse temperature); lower values (toward 1) flatten policy probabilities and prune less aggressively, higher values (toward 16) sharpen them and prune more aggressively
    topk_obsspace=10000,  # max number of top observation combinations; this default value means all observation combinations are considered
)
_, info_vanilla = rollout(agent_vanilla, env, num_timesteps=T, rng_key=key) # default policy search is vanilla
_, info_si = rollout(agent_si, env, num_timesteps=T, rng_key=key, policy_search=si_search)
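The gamma argument acts as a precision on the policy posterior, q(pi) = softmax(-gamma * G). A minimal sketch with made-up expected free energies G, ignoring habit priors and any other terms pymdp may add, shows how larger values of gamma sharpen the distribution:

```python
import numpy as np


def policy_posterior(G, gamma):
    """q(pi) = softmax(-gamma * G) for a vector of expected free energies G."""
    logits = -gamma * np.asarray(G, dtype=float)
    logits -= logits.max()  # subtract the max for numerical stability
    q = np.exp(logits)
    return q / q.sum()


G = [1.0, 1.5, 3.0]  # hypothetical expected free energies for three policies
for gamma in (1.0, 3.0, 16.0):
    print(gamma, np.round(policy_posterior(G, gamma), 3))
```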
def make_gif(info):
    frames = []
    for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
        # get observations for this timestep
        observations_t = [
            info["observation"][0][:, t, :],
            info["observation"][1][:, t, :],
        ]
        frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
        frame = np.asarray(frame, dtype=np.uint8)
        plt.close()  # close the figure to prevent memory leak
        frames.append(frame)
    frames = np.array(frames, dtype=np.uint8)
    mediapy.show_video(frames, fps=1)
make_gif(info_vanilla)
make_gif(info_si)
Result analysis
# qpi is a posterior over whole policies (action sequences).
# To get the probability of the *current* action, we marginalize over policies
# that share the same first action and sum their qpi values.
# helper functions for:
# - printing out the probability of each first action, marginalized from the policy posterior
# - printing out action and observation info for each timestep
np.set_printoptions(precision=2, suppress=True)


def print_qpi(agent, info, print_t1=False):
    qpi_values = info["qpi"]
    action_names = {
        0: "move to center",
        1: "move to left arm",
        2: "move to right arm",
        3: "move to cue",
    }
    max_timesteps = 1 if print_t1 else qpi_values.shape[1]
    # unique_multiactions holds the unique first-step actions across policies;
    # for a single control factor, this is just a list of action indices
    unique_actions = agent.unique_multiactions[:, 0]
    for t in range(max_timesteps):
        print(f"Timestep {t}:")
        action_probs = agent.multiaction_probabilities(qpi_values[:, t, :])[0]
        for action_idx, total_prob in zip(unique_actions.tolist(), action_probs.tolist()):
            if action_idx < 0:
                continue
            action_name = action_names.get(action_idx, f"action_{action_idx}")
            print(f"  {action_name}: {total_prob:.3f}")
        print()


def print_agent_behavior(info):
    action_names = {0: "move to center", 1: "move to left arm", 2: "move to right arm", 3: "move to cue"}
    location_obs = {0: "center loc", 1: "left arm loc", 2: "right arm loc", 3: "cue-left-arm", 4: "cue-right-arm"}
    outcome_obs = {0: "no_outcome", 1: "reward", 2: "punishment"}
    actions = info["action"]
    observations = info["observation"]
    num_timesteps = actions.shape[1]
    for t in range(num_timesteps):
        action_idx = int(actions[0, t, 0])  # [batch, timestep, action_dim]
        action_name = action_names.get(action_idx, f"action_{action_idx}")
        location_obs_idx = int(observations[0][0, t, 0])  # [modality][batch, timestep, obs_dim]
        outcome_obs_idx = int(observations[1][0, t, 0])
        location_name = location_obs.get(location_obs_idx)
        outcome_name = outcome_obs.get(outcome_obs_idx)
        print(f"t={t}: observed=({location_name}, {outcome_name}) -> action={action_name}")
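The marginalization described in the comments at the top of this cell can be reproduced by hand. This sketch uses a randomly generated, hypothetical q_pi over the 16 policies of length 2 and sums it by first action; it illustrates the idea and is not a reimplementation of agent.multiaction_probabilities.

```python
import itertools

import numpy as np

num_actions, policy_len = 4, 2
# every action sequence of length 2, shape (16, 2)
policies = np.array(list(itertools.product(range(num_actions), repeat=policy_len)))

rng = np.random.default_rng(0)
q_pi = rng.dirichlet(np.ones(len(policies)))  # hypothetical posterior over the 16 policies

# p(first action = a) = sum of q(pi) over all policies whose first action is a
first_action_probs = np.zeros(num_actions)
np.add.at(first_action_probs, policies[:, 0], q_pi)
print(first_action_probs)
```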
Both agents select the optimal actions: they go to the cue first and then to the left arm to collect the reward.
print_agent_behavior(info_vanilla)
t=0: observed=(center loc, no_outcome) -> action=move to cue
t=1: observed=(cue-left-arm, no_outcome) -> action=move to left arm
t=2: observed=(left arm loc, reward) -> action=move to left arm
t=3: observed=(left arm loc, reward) -> action=move to left arm
print_agent_behavior(info_si)
t=0: observed=(center loc, no_outcome) -> action=move to cue
t=1: observed=(cue-left-arm, no_outcome) -> action=move to left arm
t=2: observed=(left arm loc, reward) -> action=move to left arm
t=3: observed=(left arm loc, reward) -> action=move to left arm
Now, to see how the two planning algorithms differ, let's examine how they evaluate policies...
While both agents select the optimal strategy (information gathering first), they differ markedly in confidence: the SI agent shows a much stronger preference for the information-gathering strategy.
- Vanilla Agent: 81% probability for cue-seeking, 18% for staying at center
- SI Agent: 98% probability for cue-seeking, <1% for staying at center
print_qpi(agent_vanilla, info_vanilla, print_t1=True)
Timestep 0:
  move to center: 0.183
  move to left arm: 0.004
  move to right arm: 0.004
  move to cue: 0.809
print_qpi(agent_si, info_si, print_t1=True)
Timestep 0:
  move to center: 0.003
  move to left arm: 0.008
  move to right arm: 0.008
  move to cue: 0.980
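One way to quantify "confidence" is the entropy of each first-step action distribution. This sketch plugs in the rounded probabilities printed above (renormalized, since rounding leaves the SI values summing to 0.999); lower entropy means a sharper distribution.

```python
import numpy as np

# First-timestep action probabilities printed above: center, left arm, right arm, cue
vanilla = np.array([0.183, 0.004, 0.004, 0.809])
si = np.array([0.003, 0.008, 0.008, 0.980])


def entropy_nats(p):
    """Shannon entropy in nats after renormalizing the rounded values."""
    p = p / p.sum()
    return float(-(p * np.log(p)).sum())


print(f"vanilla: {entropy_nats(vanilla):.3f} nats, SI: {entropy_nats(si):.3f} nats")
```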
Doing the same experiment with the extended TMaze environment used in other demos
This section mirrors the preceding section, but uses pymdp's default T-Maze environment as used in the T-Maze Demo. It differs from the SimplifiedTMaze environment used before due to the restricted spatial geometry (middle connector that the agent must pass through to reach the two arms) and the fact that the cue and location modalities are split out into separate modalities. We set cue validity (cue probability) to 1.0, use policy_len = 4 (to account for the additional planning depth required for navigating the T-Maze), and observation_prune_threshold = 1e-4 for the SI search.
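Why policy_len = 4? Assuming the extended maze's adjacency is center-cue, center-middle, middle-left arm and middle-right arm (an assumption inferred from the description of the middle connector and from the trajectories printed later in this notebook, not checked against tmaze.py), a breadth-first search shows that visiting the cue and then reaching an arm takes four moves:

```python
from collections import deque

# ASSUMED adjacency for the extended T-maze; check tmaze.py in pymdp/envs for the real graph
CENTER, LEFT, RIGHT, CUE, MIDDLE = range(5)
ADJACENT = {
    CENTER: {CUE, MIDDLE},
    CUE: {CENTER},
    MIDDLE: {CENTER, LEFT, RIGHT},
    LEFT: {MIDDLE},
    RIGHT: {MIDDLE},
}


def shortest_path(start, goal):
    """Breadth-first search returning the list of locations from start to goal."""
    parents = {start: None}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            path = []
            while node is not None:
                path.append(node)
                node = parents[node]
            return path[::-1]
        for nxt in ADJACENT[node]:
            if nxt not in parents:
                parents[nxt] = node
                queue.append(nxt)
    return None


# visit the cue first, then walk to the left arm
to_cue = shortest_path(CENTER, CUE)
cue_to_arm = shortest_path(CUE, LEFT)
num_moves = (len(to_cue) - 1) + (len(cue_to_arm) - 1)
print(to_cue, cue_to_arm, num_moves)
```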
from pymdp.envs import TMaze
Setting up the T-Maze environment (Generative Process)
The rules/hyperparameters of this T-Maze are identical to those of the SimplifiedTMaze above, except that the generative process (states and observations) differs: the location factor gains a middle junction, and the joint cue+location modality (combined in SimplifiedTMaze) is split into two separate modalities.
States and Observations
State Factors:
- Location (5 states):
- 0: center (start location)
- 1: left arm
- 2: right arm
- 3: cue location (bottom arm)
- 4: middle (junction between center and arms)
- Reward Location (2 states):
- 0: reward in left arm
- 1: reward in right arm
Control State Factors:
- Location (5 actions):
- 0: move to center (start location)
- 1: move to left arm
- 2: move to right arm
- 3: move to cue location (bottom arm)
- 4: move to middle (junction)
Observation Modalities:
- Location (5 observations):
- 0: center (start location)
- 1: left arm
- 2: right arm
- 3: cue location (bottom arm)
- 4: middle (junction)
- Outcome (3 observations):
- 0: no outcome
- 1: reward (cheese)
- 2: punishment (shock)
- Cue (3 observations):
- 0: no cue
- 1: cue indicates left arm
- 2: cue indicates right arm
Environment Parameters
Observation Likelihood Model (A):
- A[0]: Location observations (5x5 tensor)
  - Perfect mapping between true and observed locations.
- A[1]: Outcome observations (3x5x2 tensor)
  - In the rewarding arm (set by reward_condition), reward is presented with a likelihood determined by the reward_probability parameter.
  - Punishment and/or no-outcome are presented with likelihoods that depend on whether dependent_outcomes is True or False (and, if False, on the punishment_probability parameter).
  - No-outcome is observed in the center/start location, cue location, and middle.
- A[2]: Cue observations (3x5x2 tensor)
  - The cue is only observed at the cue location, with accuracy set by cue_validity.
  - No cue is observed at all other locations.
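A sketch of the cue likelihood A[2] described above, with shape 3x5x2. The layout ("no cue" everywhere except the cue location, where cue_validity sets the accuracy) follows the bullet points; this is not the code in tmaze.py.

```python
import numpy as np

CUE_LOCATION = 3  # location index of the cue (bottom) arm


def cue_likelihood(cue_validity=1.0, num_locations=5):
    """Sketch of A[2] with shape (3 cue observations, 5 locations, 2 reward locations)."""
    A2 = np.zeros((3, num_locations, 2))
    A2[0, :, :] = 1.0               # "no cue" at every location by default
    A2[:, CUE_LOCATION, :] = 0.0    # ...except at the cue location
    for reward_loc in range(2):
        correct = 1 + reward_loc    # 1 = cue indicates left, 2 = cue indicates right
        wrong = 2 - reward_loc
        A2[correct, CUE_LOCATION, reward_loc] = cue_validity
        A2[wrong, CUE_LOCATION, reward_loc] = 1.0 - cue_validity
    return A2


A2 = cue_likelihood(cue_validity=1.0)
print(A2[:, CUE_LOCATION, :])
```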
Transition Model (B):
- B[0]: Location transitions (5x5x5 tensor)
- Agent can move between adjacent locations in the T-maze; invalid moves leave the agent in place.
- B[1]: Reward location (2x2x1 tensor)
- Reward location remains fixed throughout trial.
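The "invalid moves leave the agent in place" rule can be expressed directly when building B[0]. The sketch below assumes an adjacency graph (center-cue, center-middle, middle-left arm, middle-right arm) that is not confirmed against tmaze.py, and assumes that action u means "move to location u".

```python
import numpy as np

# ASSUMED adjacency (not confirmed against tmaze.py in pymdp/envs)
CENTER, LEFT, RIGHT, CUE, MIDDLE = range(5)
ADJACENT = {
    CENTER: {CUE, MIDDLE},
    CUE: {CENTER},
    MIDDLE: {CENTER, LEFT, RIGHT},
    LEFT: {MIDDLE},
    RIGHT: {MIDDLE},
}

num_locations = 5
# B0[next_location, current_location, action]; action u means "move to location u"
B0 = np.zeros((num_locations, num_locations, num_locations))
for current in range(num_locations):
    for u in range(num_locations):
        # a move to an adjacent location succeeds; anything else leaves the agent where it is
        target = u if u in ADJACENT[current] else current
        B0[target, current, u] = 1.0
```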
Initial Conditions (D):
- D[0]: Starting location (5x1 tensor)
  - Agent always begins in center location
- D[1]: Reward placement (2x1 tensor)
  - Default: Equal chance (50/50) of reward in either arm (reward_condition=None)
  - Optional: Can fix reward to a specific arm by setting reward_condition to 0 (left arm) or 1 (right arm)
# setting the parameters for the environment
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
cue_validity = 1.0 # 100% valid cues (cue probability)
reward_probability = 1.0 # 100% chance of reward in the correct arm
dependent_outcomes = True # if True, punishment occurs as a function of reward probability (i.e., if reward probability is 0.8, then 20% punishment). If False, punishment occurs with set probability (i.e., 20% no outcome and punishment will only occur in the other (non-rewarding) arm depending on the punishment_probability parameter)
punishment_probability = 1.0 # 100% chance of punishment in the other arm
# initializing the environment. see tmaze.py in pymdp/envs for the implementation details
env = TMaze(
    reward_condition=reward_condition,
    cue_validity=cue_validity,
    reward_probability=reward_probability,
    punishment_probability=punishment_probability,
    dependent_outcomes=dependent_outcomes,
)
Setting up the Agents
# creating C tensors filled with zeros for [location], [outcome], [cue] based on A shapes for the Agent
C = [jnp.zeros(a.shape[0], dtype=jnp.float32) for a in env.A]
# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0) # prefer reward
C[1] = C[1].at[2].set(-6.0) # avoid punishment
# slight cost of observing a cue
C[2] = C[2].at[1].set(-0.5)
C[2] = C[2].at[2].set(-0.5)
# D tensors [location], [reward] based on B shapes for the agent
# - agent starts in the center location
# - reward location prior is uniform
D_loc = jnp.zeros(env.B[0].shape[0], dtype=jnp.float32)
D_loc = D_loc.at[0].set(1.0)
D_reward = jnp.ones(env.B[1].shape[0], dtype=jnp.float32)
D_reward = D_reward / jnp.sum(D_reward, axis=0, keepdims=True)
D = [D_loc, D_reward]
# note that we initialize agents with different policy lengths for the vanilla vs sophisticated inference planning algorithms
# even though both will eventually end up planning with a horizon of 4. The sophisticated inference planning algorithm requires
# a policy length of 1 in the Agent as we specify horizon length of 4 when initializing the planning algorithm in the `rollout`.
# action_selection="deterministic" means selecting an action by taking the arg-max of the policy probability distribution (q_pi)
# sampling_mode="full" means evaluating the whole action sequence in each policy and executing the first action (as opposed to "marginal", where each control factor's action is chosen from its own marginal distribution)
gamma = 3.0
policy_len = 4
agent_vanilla = Agent(
    env.A, env.B, C, D,
    A_dependencies=env.A_dependencies,
    B_dependencies=env.B_dependencies,
    policy_len=policy_len,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=gamma,
)

agent_si = Agent(
    env.A, env.B, C, D,
    A_dependencies=env.A_dependencies,
    B_dependencies=env.B_dependencies,
    policy_len=1,
    learn_A=False,
    learn_B=False,
    action_selection="deterministic",
    sampling_mode="full",
    gamma=gamma,
)
Running the active inference rollouts
key = jr.PRNGKey(0)
T = 5
si_search = si_policy_search(
    horizon=policy_len,  # plans 4 timesteps ahead
    max_nodes=5000,  # maximum number of nodes allowed in the tree
    max_branching=45,  # maximum number of children allowed per node (moderating the branching factor)
    policy_prune_threshold=0.0,  # no pruning of unlikely policies
    observation_prune_threshold=1e-4,  # prune observation branches whose predicted probability is below 1e-4
    entropy_stop_threshold=0.0,  # disabling halting of expansion if agent is certain enough
    neg_efe_stop_threshold=1e10,  # disabling EFE-value-based halting of expansion
    kl_threshold=-1,  # disabling node reuse if agent is in similar states after an action
    prune_penalty=512,  # default value for prune penalty
    gamma=gamma,  # policy precision (inverse temperature); lower values (toward 1) flatten policy probabilities and prune less aggressively, higher values (toward 16) sharpen them and prune more aggressively
    topk_obsspace=10000,  # max number of top observation combinations; this default value means all observation combinations are considered
)
_, info_vanilla = rollout(agent_vanilla, env, num_timesteps=T, rng_key=key) # default policy search is vanilla
_, info_si = rollout(agent_si, env, num_timesteps=T, rng_key=key, policy_search=si_search)
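Observation pruning is what keeps a horizon-4 SI search manageable. As a worked example, assuming the search compares the predicted probability of each joint (location, outcome, cue) observation with observation_prune_threshold, the sketch below computes that distribution at the cue location under cue_validity = 1.0 and a 50/50 belief about the reward location. There are 5 x 3 x 3 = 45 joint observations (the same number as max_branching above), but only two survive the threshold.

```python
import numpy as np

observation_prune_threshold = 1e-4
q_reward = np.array([0.5, 0.5])  # reward location still unknown when the agent reaches the cue


def obs_at_cue(reward_loc):
    """Per-modality observation distributions at the cue location (cue_validity = 1.0)."""
    loc = np.eye(5)[3]               # the cue location is observed perfectly
    outcome = np.eye(3)[0]           # no outcome at the cue location
    cue = np.eye(3)[1 + reward_loc]  # cue-left or cue-right, never wrong
    return loc, outcome, cue


# Joint predicted probability of every (location, outcome, cue) combination
joint = np.zeros((5, 3, 3))
for r, w in enumerate(q_reward):
    loc, outcome, cue = obs_at_cue(r)
    joint += w * np.einsum("i,j,k->ijk", loc, outcome, cue)

kept = int((joint > observation_prune_threshold).sum())
print(f"{joint.size} joint observations, {kept} above the threshold")
```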
def make_gif(info):
    frames = []
    for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
        # get observations for this timestep
        observations_t = [
            info["observation"][0][:, t, :],
            info["observation"][1][:, t, :],
            info["observation"][2][:, t, :],
        ]
        frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
        frame = np.asarray(frame, dtype=np.uint8)
        plt.close()  # close the figure to prevent memory leak
        frames.append(frame)
    frames = np.array(frames, dtype=np.uint8)
    mediapy.show_video(frames, fps=1)
make_gif(info_vanilla)
make_gif(info_si)
Result analysis
# qpi is a posterior over whole policies (action sequences).
# To get the probability of the *current* action, we marginalize over policies
# that share the same first action and sum their qpi values.
# helper functions for:
# - printing out the probability of each first action, marginalized from the policy posterior
# - printing out action and observation info for each timestep
np.set_printoptions(precision=2, suppress=True)


def print_qpi(agent, info, print_t1=False):
    qpi_values = info["qpi"]
    action_names = {
        0: "move to center",
        1: "move to left arm",
        2: "move to right arm",
        3: "move to cue",
        4: "move to middle",
    }
    max_timesteps = 1 if print_t1 else qpi_values.shape[1]
    # unique_multiactions holds the unique first-step actions across policies;
    # for a single control factor, this is just a list of action indices
    unique_actions = agent.unique_multiactions[:, 0]
    for t in range(max_timesteps):
        print(f"Timestep {t}:")
        action_probs = agent.multiaction_probabilities(qpi_values[:, t, :])[0]
        for action_idx, total_prob in zip(unique_actions.tolist(), action_probs.tolist()):
            if action_idx < 0:
                continue
            action_name = action_names.get(action_idx, f"action_{action_idx}")
            print(f"  {action_name}: {total_prob:.3f}")
        print()


def print_agent_behavior(info):
    action_names = {
        0: "move to center",
        1: "move to left arm",
        2: "move to right arm",
        3: "move to cue",
        4: "move to middle",
    }
    location_obs = {
        0: "center loc",
        1: "left arm loc",
        2: "right arm loc",
        3: "cue loc",
        4: "middle loc",
    }
    outcome_obs = {0: "no_outcome", 1: "reward", 2: "punishment"}
    cue_obs = {0: "no cue", 1: "cue-left", 2: "cue-right"}
    actions = info["action"]
    observations = info["observation"]
    num_timesteps = actions.shape[1]
    for t in range(num_timesteps):
        action_idx = int(actions[0, t, 0])  # [batch, timestep, action_dim]
        action_name = action_names.get(action_idx, f"action_{action_idx}")
        location_obs_idx = int(observations[0][0, t, 0])  # [modality][batch, timestep, obs_dim]
        outcome_obs_idx = int(observations[1][0, t, 0])
        cue_obs_idx = int(observations[2][0, t, 0])
        location_name = location_obs.get(location_obs_idx)
        outcome_name = outcome_obs.get(outcome_obs_idx)
        cue_name = cue_obs.get(cue_obs_idx)
        print(f"t={t}: observed=({location_name}, {outcome_name}, {cue_name}) -> action={action_name}")
With the classic T-maze's adjacency constraints (and the extra middle location), reaching the reward after sampling the cue takes multiple steps. With the current horizon (4) and T=5, you should see cue-seeking and movement toward the chosen arm; increase policy_len and/or T if you want longer sequences.
print_agent_behavior(info_vanilla)
t=0: observed=(center loc, no_outcome, no cue) -> action=move to cue
t=1: observed=(cue loc, no_outcome, cue-left) -> action=move to center
t=2: observed=(center loc, no_outcome, no cue) -> action=move to middle
t=3: observed=(middle loc, no_outcome, no cue) -> action=move to left arm
t=4: observed=(left arm loc, reward, no cue) -> action=move to center
t=5: observed=(left arm loc, reward, no cue) -> action=move to center
print_agent_behavior(info_si)
t=0: observed=(center loc, no_outcome, no cue) -> action=move to cue
t=1: observed=(cue loc, no_outcome, cue-left) -> action=move to center
t=2: observed=(center loc, no_outcome, no cue) -> action=move to middle
t=3: observed=(middle loc, no_outcome, no cue) -> action=move to left arm
t=4: observed=(left arm loc, reward, no cue) -> action=move to center
t=5: observed=(left arm loc, reward, no cue) -> action=move to center
Now, to see how the two planning algorithms differ, let's examine how they evaluate policies...
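With cue_validity = 1.0, both planners prefer information gathering at the first timestep, but the SI planner yields a sharper (more confident) posterior: it assigns 0.810 to moving to the cue (vs. 0.539 for the vanilla planner) and places most of the remaining mass on moving to the middle, whereas the vanilla planner spreads it evenly over staying at the center and moving to either arm.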
print_qpi(agent_vanilla, info_vanilla, print_t1=True)
Timestep 0:
  move to center: 0.142
  move to left arm: 0.142
  move to right arm: 0.142
  move to cue: 0.539
  move to middle: 0.036
print_qpi(agent_si, info_si, print_t1=True)
Timestep 0:
  move to center: 0.001
  move to left arm: 0.001
  move to right arm: 0.001
  move to cue: 0.810
  move to middle: 0.186