T-Maze Demo
In this notebook, we simulate an active inference agent solving the T-Maze task, a type of two-armed contextual bandit. Unlike earlier demos written for the legacy version of the pymdp package, which has a NumPy backend, this demo relies on pymdp's new JAX backend. The JAX backend accelerates pymdp by dispatching the core inference, learning and planning computations to CUDA-compatible GPU devices (if available). Combined with JAX's just-in-time (JIT) compilation, this lets pymdp take advantage of batch processing (i.e., running many AIF processes in parallel) and improves memory usage and speed, especially for recurrent operations (e.g., iterative inference routines or processes that run over time).
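To make the idea of batch processing concrete, here is a minimal sketch that combines jax.vmap and jax.jit to run the same Bayesian update for several independent "agents" at once. The function bayes_update and the example arrays are illustrative assumptions, not part of pymdp's API.

```python
import jax
import jax.numpy as jnp


def bayes_update(prior, likelihood):
    """Posterior over hidden states for ONE agent (hypothetical helper)."""
    unnormalised = prior * likelihood
    return unnormalised / unnormalised.sum()


# vmap maps the single-agent function over a leading batch axis,
# jit compiles the batched function once and reuses it on later calls
batched_update = jax.jit(jax.vmap(bayes_update))

priors = jnp.full((4, 2), 0.5)  # 4 agents, 2 hidden states each
likelihoods = jnp.array([[0.9, 0.1], [0.1, 0.9], [0.5, 0.5], [0.7, 0.3]])
posteriors = batched_update(priors, likelihoods)
print(posteriors.shape)  # one posterior per agent
```

The single-agent function is written without any batch logic; the batch axis is added purely by the vmap transformation.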
The T-Maze Task
The T-Maze task implemented in this notebook is adapted from the sophisticated inference paper, and was originally introduced in an active inference context in "Active Inference and Epistemic Value". The T-maze is a two-armed contextual bandit: at any given time, the agent can choose between sampling a cue (context) or choosing between two reward arms (left vs. right). This task represents a classic problem in sequential decision-making, where an agent (in this case, analogized to a rat) must navigate a T-shaped maze. The agent starts at the centre of the T-maze. One of the two arms (left or right) contains a preferred (i.e., rewarding; cheese) stimulus and the other an aversive (i.e., punishing; shock) stimulus, with these reward contingencies initially unknown to the agent. In the bottom part of the T-Maze, a cue provides information about which arm the rewarding stimulus is in.
The agent is faced with a dilemma: commit to one of the potentially rewarding arms or first seek information from the cue to identify the more rewarding option before taking action. We use the term "cue validity" to indicate the probability that the cue correctly indicates the reward's location. If the cue has information about which of the two arms is more rewarded (i.e., the cue validity is greater than 50%), then the optimal behavior entails first visiting the cue arm and then choosing one of the two reward arms.
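The effect of cue validity on the agent's beliefs can be sketched with a small Bayes-rule calculation. The helper posterior_reward_location below is a hypothetical illustration (it is not a pymdp function) and assumes a uniform prior over the two reward locations.

```python
import numpy as np


def posterior_reward_location(cue, cue_validity, prior=(0.5, 0.5)):
    """Posterior over (reward left, reward right) after observing a cue.

    cue: 0 if the cue points to the left arm, 1 if it points to the right arm.
    cue_validity: probability that the cue points to the rewarding arm.
    """
    prior = np.asarray(prior, dtype=float)
    # likelihood of the observed cue under each hypothesis about the reward location
    likelihood = np.where(np.arange(2) == cue, cue_validity, 1.0 - cue_validity)
    unnormalised = likelihood * prior
    return unnormalised / unnormalised.sum()


for validity in (0.5, 0.8, 1.0):
    print(validity, posterior_reward_location(cue=0, cue_validity=validity))
```

With a validity of 0.5 the posterior stays at the uniform prior, so visiting the cue is uninformative; as validity approaches 1.0 the posterior concentrates on the cued arm.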
Outline of this notebook
This notebook steps through the following variants of the T-Maze Demo:
1. A deterministic generative process (environment), and a single agent solving the task with standard (non-sophisticated) active inference.
2. A noisy generative process, and a single agent solving the task with standard active inference.
3. A noisy generative model with A and B learning, with correct and incorrect prior structure of those parameters, and a single agent solving the task with standard active inference.
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[nb]" -q
# a way to edit and run code and see the effects in the notebook without having to restart the kernel
%load_ext autoreload
%autoreload 2
# importing necessary libraries
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import mediapy
from jax import random as jr
from pymdp.envs import TMaze, rollout
from pymdp.agent import Agent
from pymdp.maths import dirichlet_expected_value
Creating the T-Maze Task (Generative Process)
- The reward_condition parameter determines the reward location: 0 for the left arm, 1 for the right arm, or None for random allocation.
- The reward_probability parameter sets the chance of receiving a reward in the correct arm. For example, if set at 0.80, there would be an 80% chance of reward and 20% chance of no outcome in the rewarding arm.
- The punishment_probability parameter specifies the likelihood of punishment in the other arm. For example, if set at 0.80, there would be an 80% chance of punishment and 20% chance of no outcome in the non-rewarding arm.
- The cue_validity parameter represents the accuracy of the cues as a probability between 0 and 1.
The generative process is set up as follows.
States and Observations
State Factors:
- Location (5 states):
  - 0: centre (start location)
  - 1: left arm
  - 2: right arm
  - 3: cue location (bottom)
  - 4: middle of arms (between left and right arm)
- Reward Location (2 states):
  - 0: reward in left arm
  - 1: reward in right arm
Observation Modalities:
- Location (5 observations):
  - Matches the location states exactly
- Outcome (3 observations):
  - 0: no outcome
  - 1: reward (cheese)
  - 2: punishment (shock)
- Cue (3 observations):
  - 0: no cue
  - 1: left arm cued
  - 2: right arm cued
Environment Parameters
Observation Likelihood Model (A):
- A[0]: Location observations (5x5 tensor)
  - Perfect mapping between true and observed location.
- A[1]: Outcome observations (3x5x2 tensor)
  - In the more rewarding arm, reward is presented with a likelihood determined by the reward_probability parameter.
  - In the less rewarding arm, punishment is presented with a likelihood determined by the punishment_probability parameter.
  - No outcomes are observed in the centre/start location, cue location, or middle of the two arms.
- A[2]: Cue observations (3x5x2 tensor)
  - Indicates the reward location, at the cue location (bottom), with accuracy set by the cue_validity parameter.
  - No cues visible elsewhere.
Transition Model (B):
- B[0]: Location transitions (5x5x5 tensor)
  - Agent can move between adjacent maze cells or stay in the same cell.
- B[1]: Reward location (2x2x1 tensor)
  - Reward location remains fixed throughout trial.
Initial Conditions (D):
- D[0]: Starting location (5x1 tensor)
  - Agent always begins in centre location
- D[1]: Reward placement (2x1 tensor)
  - Default: Equal chance (50/50) of reward in either arm (reward_condition=None)
  - Optional: Can fix reward to specific arm, by setting reward_condition to 0 (for left arm) or 1 (for right arm)
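As an illustration of the outcome likelihood A[1] described above, the following sketch builds a 3x5x2 array by hand with NumPy. The function outcome_likelihood is a hypothetical helper written for this example (it is not how tmaze.py is implemented), and it assumes the independent-outcomes case in which the rewarding arm yields either reward or no outcome, and the other arm yields either punishment or no outcome.

```python
import numpy as np

NO_OUTCOME, REWARD, PUNISHMENT = 0, 1, 2  # outcome observation indices
LEFT_ARM, RIGHT_ARM = 1, 2                # location state indices


def outcome_likelihood(reward_probability, punishment_probability):
    """Return P(outcome | location, reward location) with shape (3, 5, 2)."""
    A_outcome = np.zeros((3, 5, 2))
    A_outcome[NO_OUTCOME, :, :] = 1.0  # by default nothing is observed anywhere
    arms = [(LEFT_ARM, RIGHT_ARM), (RIGHT_ARM, LEFT_ARM)]  # (rewarding, other) per reward location
    for reward_location, (good_arm, bad_arm) in enumerate(arms):
        A_outcome[REWARD, good_arm, reward_location] = reward_probability
        A_outcome[NO_OUTCOME, good_arm, reward_location] = 1.0 - reward_probability
        A_outcome[PUNISHMENT, bad_arm, reward_location] = punishment_probability
        A_outcome[NO_OUTCOME, bad_arm, reward_location] = 1.0 - punishment_probability
    return A_outcome


A_outcome = outcome_likelihood(reward_probability=1.0, punishment_probability=1.0)
print(A_outcome[:, :, 1])  # outcomes (rows) per location (columns) when the reward is in the right arm
```

Every column of the array sums to 1, because each column is a categorical distribution over the three outcomes for one combination of location and reward location.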
# setting the parameters for the environment
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = False # if True, punishment occurs as a function of reward probability (i.e., if reward probability is 0.8, then 20% punishment). If False, punishment occurs with set probability (i.e., 20% no outcome and punishment will only occur in the other (non-rewarding) arm)
# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze(
reward_probability=reward_probability,
punishment_probability=punishment_probability,
cue_validity=cue_validity,
reward_condition=reward_condition,
dependent_outcomes=dependent_outcomes
)
# you may print the environment parameters to see the shapes of the tensors and the values by editing and uncommenting the following lines and running the code:
# print([a.shape for a in env.params["A"]]) # shape of all A tensors; the shape should start with the batch_size, then the rows, columns, and additional dimensions for the dependencies
# print(env.params["A"][1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
# print(env.params["A"][2][0][:,:,0]) # likelihood of observing no cue, left arm cued, or right arm cued (rows), in each location (columns), when the reward condition is 0 (left arm)
# print([b.shape for b in env.params["B"]]) # shape of all B tensors
# print(env.params["B"][0][0][:,:,4]) # probability of transitioning to each location (rows), from each location (columns), when the agent wants to move to the middle of the arms (location 4)
1. A deterministic generative process (environment), and a single agent solving the task with standard active inference.
Creating the Agent (Generative Model)
We will create the agent's generative model from the true observation and transition models of the environment (encoded by the tensors A and B, respectively). In other words, we assume the agent knows how the environment works: it knows that the likelihood of observing a reward in the left arm is 1.0 (reward_probability=1.0) if the reward is actually in the left arm (reward_condition=0), that the cues are 100% accurate (cue_validity=1.0), and that the reward location stays fixed throughout the trial (non-volatile environment). We can of course relax these assumptions. In the next sections, the environment will be more stochastic, and we will also add uncertainty to the agent's generative model so that it has to learn whether the environment is deterministic.
The preference tensors (C) are set using the known dimensions of the outcome modalities, which we can infer from the A tensor's shape. We fix the agent's preferences a priori to prefer reward and avoid punishment. The agent is set to not have any preference to observe certain locations and cues.
The initial belief tensors (D; i.e., state priors) are set using the known dimensions of the hidden state factors, which we can infer from the B tensor's shape. The agent has the prior belief that it starts in the centre location, and it has no prior about the reward location, i.e., the prior for the reward location is uniformly distributed.
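One way to build intuition for the preference values in C is to read them as unnormalised log-preferences and pass them through a softmax, which is a common interpretation in active inference. The sketch below uses the outcome preferences from this section, [0, 2, -3]; whether pymdp applies this normalisation internally is not stated here, so treat it as an interpretive aid only.

```python
import numpy as np

C_outcome = np.array([0.0, 2.0, -3.0])  # [no outcome, reward, punishment]

# numerically stable softmax: subtract the maximum before exponentiating
shifted = C_outcome - C_outcome.max()
implied_preferences = np.exp(shifted) / np.exp(shifted).sum()
print(np.round(implied_preferences, 3))
```

Under this reading, only the differences between entries of C matter: adding a constant to all three values leaves the implied distribution unchanged.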
# setting A tensors from the environment parameters
A = env.A
A_dependencies = env.A_dependencies # dependencies allow you to specify which state factors each observation modality depends on, so you don't have to store all the conditional dependencies between all state factors and each modality
# setting B tensors from the environment parameters
B = env.B
B_dependencies = env.B_dependencies
# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros(a.shape[0], dtype=jnp.float32) for a in A]
# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0) # prefer reward
C[1] = C[1].at[2].set(-3.0) # avoid punishment
# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros(B[0].shape[0], dtype=jnp.float32)
D_loc = D_loc.at[0].set(1.0) # set centre location to 1.0
D.append(D_loc)
# D[1]: reward location - uniform distribution
D_reward = jnp.ones(B[1].shape[0], dtype=jnp.float32)
D_reward = D_reward / jnp.sum(D_reward, axis=0, keepdims=True) # normalise to get uniform distribution
D.append(D_reward)
# initialising the agent
agent = Agent(
A, B, C, D,
policy_len=2, # how long the action sequence is that the agent is evaluating
A_dependencies=A_dependencies,
B_dependencies=B_dependencies,
learn_A=False,
learn_B=False
)
# printing some of the agent's generative model parameters to inspect the shapes of the tensors and their values (edit these lines to inspect other parameters):
print(f'Number of generated agents: {agent.batch_size}\n')
print([a.shape for a in agent.A]) # shape of all A tensors
print(agent.A[1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
print(agent.C[1]) # preferences for outcomes
Number of generated agents: 1

[(1, 5, 5), (1, 3, 5, 2), (1, 3, 5, 2)]
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]
[[ 0.  2. -3.]]
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61033/343300159.py:29: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agent = Agent(
Running the active inference agent
key = jr.PRNGKey(0) # random key for seeding reproducible stochasticity in the AIF loop (e.g., for sampling any random processes in the TMaze, and for sampling actions from the agent's chosen policy).
T = 10 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
# you may print the info dictionary to see the numerical results of the aif agent completing the T-maze task by editing and uncommenting the following lines and running the code:
# print(info.keys()) # keys in the info dictionary
# print(info["action"][:,0,:]) # actions taken by the agent (locations throughout the maze)
# print(info["observation"][0]) # observations of the locations for each batch
# print(info["observation"][2]) # observations of the cues for each batch
# print(jnp.around(info["qs"][1], decimals=2)) # posterior beliefs about the reward location
# print(info["qpi"][0].shape) # shape of the policy tensor
Rendering the Task to Visualise the Agent's Behaviour
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],
        info["observation"][2][:, t, :]
    ]
    frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)
frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
# # uncomment the following lines to save the video as a gif
# # (os and PIL.Image are not imported at the top of this notebook, so they are imported here)
# import os
# from PIL import Image
# os.makedirs("figures", exist_ok=True)
# pil_frames = [Image.fromarray(frame) for frame in frames]
# reward_location = "random" if reward_condition is None else ("left" if reward_condition == 0 else "right")
# filename = os.path.join("figures", f"tmaze_{agent.batch_size}_{reward_location}.gif")
# pil_frames[0].save(
#     filename,
#     save_all=True,
#     append_images=pil_frames[1:],
#     duration=1000,  # 1000ms per frame
#     loop=0
# )
Running multiple agents in parallel using the batch_size parameter
The active inference computations taking place within the Agent methods (inference, planning/action selection, and learning) internally expect most arrays to have an additional leading dimension, with length batch_size. This extra dimension can be used to generate and run multiple agents in parallel, each with their own generative model, observation and action history, and posterior beliefs. For instance, if batch_size = 10, then each A, B, C, ... array will represent the corresponding generative model parameters for 10 different agents, and each agent will receive its own history of observations and emit its own actions. The single agent (default) case is therefore the case of batch_size = 1. If you pass in generative model arrays to the Agent constructor without a (batch_size,)-sized leading dimension (as we did in the previous example), they will be expanded internally to have this dimension, creating batch_size identical copies of the provided model parameters.
One can also run multiple environments in parallel by passing a separate, stateful env_params pytree which stores batch_size different A, B, and D arrays, but for now we are using a single fixed environment. This is the default case where the environmental attributes env.A, env.B, and env.D arrays have no leading (batch_size,) dimension, even while the Agent class does have it.
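The leading batch dimension described above can be sketched with plain JAX array operations. The arrays here are small stand-ins chosen for illustration (an identity likelihood and a 3-outcome preference vector), not the actual TMaze parameters.

```python
import jax.numpy as jnp

batch_size = 3

# stand-in for a single agent's (5 observations x 5 states) likelihood array
A_single = jnp.eye(5)
# add a leading batch axis: every agent gets an identical copy of the parameters
A_batched = jnp.broadcast_to(A_single, (batch_size,) + A_single.shape)
print(A_batched.shape)

# per-agent preference vectors: one row per agent, one column per outcome
C_batched = jnp.zeros((batch_size, 3))
C_batched = C_batched.at[:, 1].set(2.0)   # the slice ':' applies the value to every agent
C_batched = C_batched.at[0, 2].set(-3.0)  # index 0 applies the value to the first agent only
print(C_batched)
```

Note the difference between indexing with a slice and with an integer on the batch axis: the former edits all agents, the latter edits a single agent. JAX arrays are immutable, so each .at[...].set(...) call returns a new array.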
batch_size = 9 # number of agents to run in parallel
# setting A tensors from the environment parameters
A, B = env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies # dependencies allow you to specify the state factors the observation modality depends on so you don't have to compute the full tensor using all state factors
# now expand A and B to have a new batch_size-length leading dimension
A = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in A]
B = [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in B]
# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A]
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(2.0) # prefer reward
C[1] = C[1].at[:,2].set(-3.0) # avoid punishment
# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32)
D_loc = D_loc.at[:,0].set(1.0) # set centre location to 1.0 for every agent in the batch (indexing with [0,0] would only set it for the first agent)
D.append(D_loc)
# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32)
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True) # normalise to get uniform distribution
D.append(D_reward)
# initialising the agent
agent = Agent(
A, B, C, D,
policy_len=2,
A_dependencies=A_dependencies,
B_dependencies=B_dependencies,
batch_size=batch_size, # note we have to pass in this batch_size parameter so the class knows to treat the first dimension as the batch size
learn_A=False,
learn_B=False
)
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
# rendering the task to visualise the agent's behaviour
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],
        info["observation"][2][:, t, :]
    ]
    frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)
frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
2. A noisy generative process, and a single agent solving the task with standard active inference.
Making the world more stochastic
In this variant, we again give the agent true knowledge of the parameters of the generative process (by fixing the agent's A and B tensors to those of the true generative process), but we now add stochasticity to the generative process itself: the agent is no longer guaranteed to get a reward in the rewarding arm every time, only more likely to get reward than punishment there. This stochasticity is controlled by the reward_probability argument to the TMaze constructor. Likewise, the probability of seeing punishment in the punishing arm is controlled by the punishment_probability argument. A final source of uncertainty is the cue reliability or validity, i.e., the probability that the signal observed at the cue location of the TMaze correctly identifies which arm is the more rewarding one. This probability is controlled by the cue_validity argument.
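To see why the cue remains worth visiting in a noisy world, here is a worked calculation of the expected utility of entering an arm, using the parameter and preference values from this section's code (reward_probability=0.7, punishment_probability=0.6, cue_validity=0.95, outcome preferences [0, 3, -4]). This is a simplified sketch: it covers only the utility (pragmatic) part of policy evaluation and ignores the information-gain term, and the function expected_utility is a hypothetical helper.

```python
import numpy as np

reward_probability = 0.7
punishment_probability = 0.6
cue_validity = 0.95
C_outcome = np.array([0.0, 3.0, -4.0])  # [no outcome, reward, punishment]

# outcome distributions inside the rewarding arm and inside the other arm
p_good_arm = np.array([1 - reward_probability, reward_probability, 0.0])
p_bad_arm = np.array([1 - punishment_probability, 0.0, punishment_probability])

utility_good = p_good_arm @ C_outcome  # 0.7 * 3.0
utility_bad = p_bad_arm @ C_outcome    # 0.6 * -4.0


def expected_utility(p_correct_arm):
    """Expected utility of entering an arm that is the rewarding one with probability p_correct_arm."""
    return p_correct_arm * utility_good + (1 - p_correct_arm) * utility_bad


guess = expected_utility(0.5)              # choosing an arm without checking the cue
informed = expected_utility(cue_validity)  # choosing the cued arm
print(f"guessing: {guess:.3f}, following the cue: {informed:.3f}")
```

Under these assumptions, guessing has a slightly negative expected utility (-0.15), while following a 95%-valid cue has a clearly positive one (1.875).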
# THE GENERATIVE PROCESS (NOISY)
# setting the parameters for the environment
batch_size = 4 # number of environments/agents to run in parallel
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 0.7 # 70% chance of reward in the correct arm
punishment_probability = 0.6 # 60% chance of punishment in the other arm
cue_validity = 0.95 # 95% valid cues
dependent_outcomes = False # if True, punishment occurs as a function of reward probability (i.e., if reward probability is 0.8, then 20% punishment). If False, punishment occurs with set probability (i.e., 20% no outcome and punishment will only occur in the other (non-rewarding) arm)
# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze(
reward_probability=reward_probability,
punishment_probability=punishment_probability,
cue_validity=cue_validity,
reward_condition=reward_condition,
dependent_outcomes=dependent_outcomes
)
# THE GENERATIVE MODEL
# Get A, B, and dependencies from the environment
A, B = env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies # dependencies allow you to specify the state factors each observation modality depends on
# now expand A and B to have a new batch_size-length leading dimension
A = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in A]
B = [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in B]
# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A]
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(3.0) # prefer reward
C[1] = C[1].at[:,2].set(-4.0) # avoid punishment
# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32)
D_loc = D_loc.at[:,0].set(1.0) # set centre location to 1.0 for every agent in the batch (indexing with [0,0] would only set it for the first agent)
D.append(D_loc)
# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32)
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True) # normalise to get uniform distribution
D.append(D_reward)
# initialising the agent
agent = Agent(
A, B, C, D,
policy_len=3, # how long the action sequence is that the agent is evaluating
A_dependencies=A_dependencies,
B_dependencies=B_dependencies,
batch_size=batch_size,
learn_A=False,
learn_B=False
)
key = jr.PRNGKey(0) # random key for the aif loop
T = 20 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],
        info["observation"][2][:, t, :]
    ]
    frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)
frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
3. Agent is equipped with a partially-uncertain model of the true generative process, and can also perform parameter learning to update its beliefs about outcome and transition contingencies over time.
You can tweak the following arguments to the Agent constructor: learn_A and learn_B between True and False -- this selectively turns on/off the ability to update the parameters of the Dirichlet posterior over the A and B tensors, respectively. These functionalities should be used when your agent does not have a correct or confident model of the world, and you want the agent to learn its model over time.
As an example, we will do the following two variants in this section:
1. We initialize the agent to have partially-uncertain beliefs about the A and B parameters, but the mode of their beliefs (i.e. the value with the highest probability) is still based on the true generative process.
2. We will equip the agent with incorrect beliefs about the generative process.
Correct Prior Structure
# setting the parameters for the environment
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = False
# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze(
reward_probability=reward_probability,
punishment_probability=punishment_probability,
cue_validity=cue_validity,
reward_condition=reward_condition,
dependent_outcomes=dependent_outcomes
)
We can encode uncertainty in the agent's prior knowledge about the state-outcome contingencies by adding non-zero parameter values to the pA tensor (the prior Dirichlet parameters over the Categorical parameters of the A tensor). In practice, we still put 'true knowledge' into this prior, in the sense that the mode of the likelihood distribution is aligned with the true generative process parameters. We add uncertainty by placing non-zero values in the parts of the Dirichlet prior tensor that are zero in the true generative process parameters; this encodes the agent's uncertainty about whether certain relationships between hidden states (e.g., location states or reward-condition states) and outcomes (cue signals or reward outcomes) exist.
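The flattening effect of adding a constant to the Dirichlet parameters can be checked by hand. The sketch below uses a hypothetical stand-in, dirichlet_mean, which normalises the parameters over the outcome axis; it is written for illustration and is not pymdp's dirichlet_expected_value. The scale 0.3 matches the alpha_scale_pA value used in this section's code.

```python
import numpy as np


def dirichlet_mean(alpha, axis=0):
    """Expected categorical probabilities under a Dirichlet with parameters alpha."""
    return alpha / alpha.sum(axis=axis, keepdims=True)


true_column = np.array([1.0, 0.0, 0.0])  # a deterministic column of the true A tensor
for scale in (0.0, 0.3, 3.0):
    print(scale, np.round(dirichlet_mean(true_column + scale), 3))
```

With a scale of 0.3 the expected column becomes (1.3, 0.3, 0.3) / 1.9, roughly (0.684, 0.158, 0.158): the true outcome is still the most probable one, but the agent no longer treats the mapping as certain. Larger scales push the expectation further towards uniform.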
# initializing the prior over the A tensors (pA) based on the environment parameters
num_obs = [a.shape[0] for a in env.A] # number of observations for each modality
num_states = [b.shape[0] for b in env.B] # number of states for each factor
# initializing pA, pB to be identical to the environmental parameters
pA, pB = [jnp.copy(a) for a in env.A], [jnp.copy(b) for b in env.B] # have to copy them here because we will be modifying pA below
# adding non-zero Dirichlet parameter values to pA flattens the expected value of the A tensor (makes it more uncertain)
alpha_scale_pA = 0.3
for m in [1, 2]:  # only modifying the outcome (m=1) and cue (m=2) observation likelihood mappings
    pA[m] = pA[m] + (alpha_scale_pA * jnp.ones_like(pA[m]))
expected_A = [dirichlet_expected_value(pa_m, event_dim=0) for pa_m in pA]  # compute the expected value of the Dirichlet parameters in pA
expected_B = [dirichlet_expected_value(pb_f, event_dim=0) for pb_f in pB]  # compute the expected value of the Dirichlet parameters in pB
# creating C tensors filled with zeros for [location], [reward], [cue] based on modality shapes
C = [jnp.zeros(no, dtype=jnp.float32) for no in num_obs]
# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0) # prefer reward
C[1] = C[1].at[2].set(-3.0) # avoid punishment
# creating D tensors [location], [reward] based on B shapes
D = [jnp.zeros(ns, dtype=jnp.float32) for ns in num_states] # creating D tensors filled with zeros for [location], [reward] based on B shapes
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros(num_states[0], dtype=jnp.float32)
D_loc = D_loc.at[0].set(1.0) # set centre location to 1.0
D[0] = D_loc
# D[1]: reward location - uniform distribution
D_reward = jnp.ones(num_states[1], dtype=jnp.float32)
D_reward = D_reward / jnp.sum(D_reward, axis=0, keepdims=True) # normalise to get uniform distribution
D[1] = D_reward
# initialising the agent
# NOTE: We initialize the first value of A and B based on the expected values of the Dirichlet parameters in pA and pB, respectively.
agent = Agent(
expected_A, expected_B, C, D,
pA=pA,
pB=pB,
policy_len=5, # how long the action sequence is that the agent is evaluating
A_dependencies=env.A_dependencies,
B_dependencies=env.B_dependencies,
learn_A=True,
learn_B=False,
gamma=1.0,
action_selection="stochastic"
)
key = jr.PRNGKey(0) # random key for the aif loop
T = 10 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
We can print out the A tensors over time to see how the agent updates its beliefs about the state-outcome contingencies based on the data it is collecting through active inference.
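The kind of update that drives this learning can be sketched with the standard conjugate Dirichlet-categorical rule: each observation adds a count to the Dirichlet parameters, weighted by the posterior belief over hidden states. The function update_dirichlet and the small 3x2 example are assumptions made for illustration; pymdp's own learning routine may differ in details such as learning rates or how beliefs are smoothed over time.

```python
import numpy as np


def update_dirichlet(alpha, obs_index, state_posterior, learning_rate=1.0):
    """Add a belief-weighted count for one observation.

    alpha: (num_obs, num_states) Dirichlet parameters.
    obs_index: index of the observed outcome.
    state_posterior: (num_states,) posterior belief over hidden states.
    """
    one_hot_obs = np.zeros(alpha.shape[0])
    one_hot_obs[obs_index] = 1.0
    return alpha + learning_rate * np.outer(one_hot_obs, state_posterior)


alpha = np.array([[1.3, 0.3],
                  [0.3, 1.3],
                  [0.3, 0.3]])   # 3 outcomes x 2 hidden states
qs = np.array([1.0, 0.0])        # the agent is certain it is in state 0

for _ in range(3):               # observe outcome 0 three times while in state 0
    alpha = update_dirichlet(alpha, obs_index=0, state_posterior=qs)

expected = alpha / alpha.sum(axis=0, keepdims=True)
print(np.round(expected, 3))
```

Only the column for the state that carries posterior mass changes; the column for the other state keeps its prior expectation. This is why columns of A for rarely visited states can stay at their initial values over a rollout.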
print("the environment's A tensor")
print(env.A[1][:,:,1])
print()
print("the agent's A tensor at t=0")
print(agent.A[1][0][:,:,1])
print()
print(f"the agent's A tensor at t={T}")
print(info["A"][1][0,-1,:,:,1])
the environment's A tensor
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]

the agent's A tensor at t=0
[[0.68421054 0.15789475 0.15789476 0.68421054 0.68421054]
 [0.15789476 0.15789475 0.68421054 0.15789476 0.15789476]
 [0.15789476 0.68421054 0.15789476 0.15789476 0.15789476]]

the agent's A tensor at t=10
[[0.86867255 0.15789475 0.15789476 0.72054607 0.6926888 ]
 [0.06566371 0.15789475 0.68421054 0.13972697 0.15365562]
 [0.06566371 0.68421054 0.15789476 0.13972697 0.15365562]]
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],
        info["observation"][2][:, t, :]
    ]
    frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)
frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
# Random initialization test
key = jr.PRNGKey(24)
# setting A tensors from the environment parameters, but randomize the prior over the transitions
pA, pB = [jnp.copy(a) for a in env.A], [jr.uniform(subkey, shape=b.shape) for subkey, b in zip(jr.split(key, num=2), env.B)]
# adding noise to outcome and cue modalities to make priors more uncertain
key1, key2 = jr.split(key, num=2)
for key_m, modality in zip([key1, key2], [1, 2]):  # only modifying the outcome (m=1) and cue (m=2) observation likelihood mappings
    pA[modality] = jr.uniform(key_m, shape=pA[modality].shape)  # generating random values between 0 and 1
expected_A = [dirichlet_expected_value(pa_m, event_dim=0) for pa_m in pA]  # compute the expected value of the Dirichlet parameters in pA
expected_B = [dirichlet_expected_value(pb_f, event_dim=0) for pb_f in pB]  # compute the expected value of the Dirichlet parameters in pB
# initialising the agent
agent = Agent(
expected_A, expected_B, C, D,
pA=pA,
pB=pB,
policy_len=5, # how long the action sequence is that the agent is evaluating
A_dependencies=env.A_dependencies,
B_dependencies=env.B_dependencies,
batch_size=batch_size,
learn_A=True,
learn_B=True,
gamma=0.1,
action_selection="stochastic"
)
# running the active inference simulation
key = jr.PRNGKey(0) # random key for the aif loop
T = 50 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
print("the environment's A tensor")
print(env.A[1][:,:,1])
print()
print("the agent's A tensor at t=0")
print(agent.A[1][0][:,:,1])
print()
print(f"the agent's A tensor at t={T}")
print(info["A"][1][0,-1,:,:,1])
the environment's A tensor
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]

the agent's A tensor at t=0
[[0.00308274 0.07115752 0.55618274 0.42624453 0.36277273]
 [0.02319651 0.4746476  0.31333998 0.5615474  0.43671665]
 [0.9737207  0.45419487 0.13047728 0.01220808 0.20051058]]

the agent's A tensor at t=50
[[0.01081832 0.06939295 0.55391985 0.4478869  0.3744181 ]
 [0.02301651 0.48767537 0.31206512 0.5403655  0.4287357 ]
 [0.9661652  0.4429317  0.134015   0.01174758 0.19684626]]
print("the environment's B tensor")
print(env.B[1][:,:,1])
print()
print("the agent's B tensor at t=0")
print(agent.B[1][0][:,:,1])
print()
print(f"the agent's B tensor at t={T}")
print(info["B"][1][0,-1,:,:,1])
the environment's B tensor
[[1. 0.]
 [0. 1.]]

the agent's B tensor at t=0
[[0.8990159  0.44466075]
 [0.10098408 0.5553392 ]]

the agent's B tensor at t=50
[[0.9954207  0.6131649 ]
 [0.0045793  0.38683507]]
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],
        info["observation"][2][:, t, :]
    ]
    frame = env.render(mode="rgb_array", observations=observations_t)  # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)
frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)