T-Maze Demo

In this notebook, we simulate an active inference agent solving the T-Maze task, a type of two-armed contextual bandit. Unlike the earlier demos written for the legacy, NumPy-based version of pymdp, this demo relies on pymdp's new JAX backend. The JAX backend accelerates pymdp by dispatching the core inference, learning and planning computations to CUDA-compatible GPU devices (if available). Combined with JAX's just-in-time (JIT) compilation, this lets pymdp take advantage of batch processing (i.e., running many active inference (AIF) processes in parallel) and improves memory usage and speed, especially for recurrent operations (e.g., iterative inference routines or processes that unfold over time).

The T-Maze Task

The T-Maze task implemented in this notebook is adapted from the sophisticated inference paper and was originally introduced in an active inference context in "Active Inference and Epistemic Value". The T-maze is a two-armed contextual bandit: at any given time, the agent can choose between sampling a cue (context) and choosing one of two reward arms (left vs. right). This task represents a classic problem in sequential decision-making, in which an agent (in this case, analogized to a rat) must navigate a T-shaped maze. The agent starts at the centre of the T-maze. One of the two arms (left or right) contains a preferred (i.e., rewarding; cheese) stimulus and the other contains an aversive (i.e., punishing; shock) stimulus, and these reward contingencies are initially unknown to the agent. In the bottom part of the T-maze, a cue provides information about which arm contains the rewarding stimulus.

The agent is faced with a dilemma: commit to one of the potentially rewarding arms or first seek information from the cue to identify the more rewarding option before taking action. We use the term "cue validity" to indicate the probability that the cue correctly indicates the reward's location. If the cue has information about which of the two arms is more rewarded (i.e., the cue validity is greater than 50%), then the optimal behavior entails first visiting the cue arm and then choosing one of the two reward arms.
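A single Bayesian update shows concretely why the cue is worth visiting. The sketch below uses plain NumPy and hypothetical helper names. It assumes the cue likelihood described under A[2] in the environment parameters: at the cue location, the cue points to the rewarding arm with probability cue_validity. Starting from a uniform prior, the posterior after seeing a left-arm cue is [cue_validity, 1 - cue_validity]. A validity of 0.5 therefore leaves the belief unchanged, and any validity above 0.5 makes the cue informative.

```python
import numpy as np

def cue_likelihood_at_cue_location(cue_validity):
    """P(cue observation | reward location) when the agent stands at the cue location.

    Rows follow the cue modality: 0 = no cue, 1 = left arm cued, 2 = right arm cued.
    Columns follow the reward-location factor: 0 = reward in left arm, 1 = reward in right arm.
    """
    v = cue_validity
    return np.array([
        [0.0, 0.0],      # 'no cue' is never observed at the cue location
        [v, 1.0 - v],    # left arm cued
        [1.0 - v, v],    # right arm cued
    ])

def posterior_after_cue(prior, cue_obs, cue_validity):
    """Bayes' rule: posterior over the reward location after a single cue observation."""
    likelihood = cue_likelihood_at_cue_location(cue_validity)[cue_obs]
    unnormalised = likelihood * prior
    return unnormalised / unnormalised.sum()

prior = np.array([0.5, 0.5])  # no prior knowledge of the reward location
for v in (0.5, 0.8, 1.0):
    # observation 1 = left arm cued
    print(v, posterior_after_cue(prior, cue_obs=1, cue_validity=v))
```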

Outline of this notebook

This notebook steps through the following variants of the T-Maze Demo:

1. A deterministic generative process (environment), and a single agent solving the task with standard (non-sophisticated) active inference.

2. A noisy generative process, and a single agent solving the task with standard active inference.

3. A noisy generative model with A and B learning, with correct and incorrect prior structure of those parameters, and a single agent solving the task with standard active inference.

In [ ]:
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[nb]" -q
In [ ]:
# a way to edit and run code and see the effects in the notebook without having to restart the kernel
%load_ext autoreload
%autoreload 2

# importing necessary libraries
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import mediapy

from jax import random as jr
from pymdp.envs import TMaze, rollout
from pymdp.agent import Agent
from pymdp.maths import dirichlet_expected_value

Creating the T-Maze Task (Generative Process)

  • The reward_condition parameter determines the reward location: 0 for the left arm, 1 for the right arm, or None for random allocation.
  • The reward_probability parameter sets the chance of receiving a reward in the correct arm. For example, if set at 0.80, there would be an 80% chance of reward and 20% chance of no outcome in the rewarding arm.
  • The punishment_probability parameter specifies the likelihood of punishment in the other arm. For example, if set at 0.80, there would be an 80% chance of punishment and 20% chance of no outcome in the non-rewarding arm.
  • The cue_validity parameter represents the accuracy of the cues as a probability between 0 and 1.
The generative process is set up as follows.

States and Observations

State Factors:

  1. Location (5 states):
    • 0: centre (start location)
    • 1: left arm
    • 2: right arm
    • 3: cue location (bottom)
    • 4: middle of arms (between left and right arm)
  2. Reward Location (2 states):
    • 0: reward in left arm
    • 1: reward in right arm

Observation Modalities:

  1. Location (5 observations):
    • Matches the location states exactly
  2. Outcome (3 observations):
    • 0: no outcome
    • 1: reward (cheese)
    • 2: punishment (shock)
  3. Cue (3 observations):
    • 0: no cue
    • 1: left arm cued
    • 2: right arm cued

Environment Parameters

Observation Likelihood Model (A):

  • A[0]: Location observations (5x5 tensor)
    • Perfect mapping between true and observed location.
  • A[1]: Outcome observations (3x5x2 tensor)
    • In the more rewarding arm, reward is presented with a likelihood determined by the reward_probability parameter.
    • In the less rewarding arm, punishment is presented with a likelihood determined by the punishment_probability parameter.
    • No outcomes are observed in the centre/start location, cue location, or middle of the two arms.
  • A[2]: Cue observations (3x5x2 tensor)
    • Indicates the reward location when the agent is at the cue location (bottom), with accuracy set by the cue_validity parameter.
    • No cues visible elsewhere.
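The A[1] and A[2] descriptions above can be sketched directly in NumPy. This is an illustrative reconstruction, not the code in tmaze.py, and it makes three assumptions:

- The helper names are hypothetical.
- The batch dimension is omitted.
- It follows the dependent_outcomes=False behaviour described in this notebook: the rewarding arm yields reward or no outcome, and the other arm yields punishment or no outcome.

```python
import numpy as np

# Location indices: 0 centre, 1 left arm, 2 right arm, 3 cue location, 4 middle of arms
CENTRE, LEFT, RIGHT, CUE, MIDDLE = range(5)
# Reward-location indices: 0 = reward in left arm, 1 = reward in right arm

def build_outcome_likelihood(reward_probability, punishment_probability):
    """Sketch of A[1], shape (3 outcomes, 5 locations, 2 reward locations).

    Outcome rows: 0 = no outcome, 1 = reward, 2 = punishment.
    """
    A1 = np.zeros((3, 5, 2))
    A1[0, :, :] = 1.0  # default: 'no outcome' at every location
    for reward_loc, (good_arm, bad_arm) in enumerate([(LEFT, RIGHT), (RIGHT, LEFT)]):
        # rewarding arm: reward or nothing
        A1[:, good_arm, reward_loc] = [1.0 - reward_probability, reward_probability, 0.0]
        # other arm: punishment or nothing
        A1[:, bad_arm, reward_loc] = [1.0 - punishment_probability, 0.0, punishment_probability]
    return A1

def build_cue_likelihood(cue_validity):
    """Sketch of A[2], shape (3 cues, 5 locations, 2 reward locations).

    Cue rows: 0 = no cue, 1 = left arm cued, 2 = right arm cued.
    """
    A2 = np.zeros((3, 5, 2))
    A2[0, :, :] = 1.0  # 'no cue' everywhere except the cue location
    A2[:, CUE, 0] = [0.0, cue_validity, 1.0 - cue_validity]  # reward left: left cued w.p. cue_validity
    A2[:, CUE, 1] = [0.0, 1.0 - cue_validity, cue_validity]  # reward right: right cued w.p. cue_validity
    return A2

A1 = build_outcome_likelihood(reward_probability=1.0, punishment_probability=1.0)
A2 = build_cue_likelihood(cue_validity=1.0)
print(A1[:, :, 1])  # outcome likelihoods when the reward is in the right arm
```

With reward_probability = punishment_probability = 1.0, A1[:, :, 1] has the same entries as the agent.A[1][0][:, :, 1] matrix printed later in this notebook.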

Transition Model (B):

  • B[0]: Location transitions (5x5x5 tensor)
    • Agent can move between adjacent maze cells or stay in the same cell.
  • B[1]: Reward location (2x2x1 tensor)
    • Reward location remains fixed throughout trial.
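Because the reward location never changes within a trial, B[1] behaves like an identity matrix with a single (trivial) action. The minimal sketch below assumes pymdp's usual B[next_state, current_state, action] layout, which matches the "rows are destinations, columns are origins" description in the comments of the environment cell. It shows that propagating a belief through such a matrix leaves that belief unchanged.

```python
import numpy as np

# Sketch: identity transition for the reward-location factor, shape (2, 2, 1).
# Indexing convention assumed: B1[next_state, current_state, action]; one action only.
B1 = np.eye(2)[:, :, None]

qs_reward = np.array([0.8, 0.2])  # e.g. a belief after seeing a left-arm cue with validity 0.8
for _ in range(5):
    qs_reward = B1[:, :, 0] @ qs_reward  # propagate the belief one timestep forward
print(qs_reward)  # unchanged: [0.8 0.2]
```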

Initial Conditions (D):

  • D[0]: Starting location (5x1 tensor)
    • Agent always begins in centre location
  • D[1]: Reward placement (2x1 tensor)
    • Default: Equal chance (50/50) of reward in either arm (reward_condition=None)
    • Optional: Can fix reward to specific arm, by setting reward_condition to 0 (for left arm) or 1 (for right arm)
In [ ]:
# setting the parameters for the environment
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = False # if True, punishment occurs as a function of reward probability (e.g., if reward_probability is 0.8, there is a 20% chance of punishment). If False, punishment occurs with its own set probability (e.g., if reward_probability is 0.8, there is a 20% chance of no outcome in the rewarding arm), and punishment only occurs in the other (non-rewarding) arm

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes
)

# you may print the environment parameters to see the shapes of the tensors and the values by editing and uncommenting the following lines and running the code: 

# print([a.shape for a in env.params["A"]]) # shape of all A tensors; the shape should start with the batch_size, then the rows, columns, and additional dimensions for the dependencies
# print(env.params["A"][1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
# print(env.params["A"][2][0][:,:,0]) # likelihood of observing no cue, left arm cued, or right arm cued (rows), in each location (columns), when the reward condition is 0 (left arm)

# print([b.shape for b in env.params["B"]]) # shape of all B tensors
# print(env.params["B"][0][0][:,:,4]) # probability of transitioning to each location (rows), from each location (columns), when the agent wants to move to the middle of the arms (location 4)

1. A deterministic generative process (environment), and a single agent solving the task with standard active inference.

Creating the Agent (Generative Model)

We will create the agent's generative model based on the true observation and transition models (encoded by tensors A and B, respectively) of the environment. In other words, we assume the agent knows how the environment works. It knows that the likelihood of observing a reward in the left arm is 1.0 (reward_probability=1.0) if the reward is actually in the left arm (reward_condition=0), it knows the cues are 100% accurate (cue_validity=1.0), and it knows that the reward location will be fixed throughout the trial (non-volatile environment). We can of course relax these assumptions to create a more complex agent, and we do so in the next sections. There, the environment becomes more stochastic and we add uncertainty to the agent's generative model, so that the agent has to learn whether or not the environment is deterministic.

The preference tensors (C) are sized using the dimensions of the outcome modalities, which we can read off the shapes of the A tensors. We fix the agent's preferences a priori to prefer reward and avoid punishment. The agent has no preferences over which locations or cues it observes.
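One way to read the C vector is as relative log-preferences over outcomes. Only the differences between entries matter, and a softmax turns them into an implied "preferred" outcome distribution. The NumPy sketch below applies this reading to the outcome preferences [0, 2, -3] that the code below assigns to C[1]. The softmax helper is written out here rather than imported from pymdp. The sketch illustrates the interpretation only; it is not necessarily how pymdp uses C internally when scoring policies.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax of a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Preferences over the outcome modality: [no outcome, reward, punishment]
C_outcome = np.array([0.0, 2.0, -3.0])
preferred = softmax(C_outcome)
print(np.round(preferred, 4))  # most mass on reward, almost none on punishment

# Adding a constant to every entry leaves the implied distribution unchanged,
# so only the relative values in C matter.
assert np.allclose(softmax(C_outcome + 10.0), preferred)
```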

The initial belief tensors (D; i.e., state priors) are sized using the dimensions of the hidden state factors, which we can read off the shapes of the B tensors. The agent believes a priori that it starts in the centre location, and it has no prior knowledge of the reward location, i.e., its prior over the reward location is uniform.

In [1]:
# setting A tensors from the environment parameters
A = env.A
A_dependencies = env.A_dependencies # dependencies allow you to specify which state factors each observation modality depends on, so you don't have to store all the conditional dependencies between all state factors and each modality

# setting B tensors from the environment parameters
B = env.B
B_dependencies = env.B_dependencies

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros(a.shape[0], dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0)    # prefer reward
C[1] = C[1].at[2].set(-3.0)   # avoid punishment

# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros(B[0].shape[0], dtype=jnp.float32) 
D_loc = D_loc.at[0].set(1.0)  # set centre location to 1.0
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones(B[1].shape[0], dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=0, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    policy_len=2, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    learn_A=False,
    learn_B=False
)

# print the agent's generative model parameters to inspect the shapes of the tensors and their values (edit the following lines to inspect other parameters):

print(f'Number of generated agents: {agent.batch_size}\n')
print([a.shape for a in agent.A]) # shape of all A tensors
print(agent.A[1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
print(agent.C[1]) # preferences for outcomes
Number of generated agents: 1

[(1, 5, 5), (1, 3, 5, 2), (1, 3, 5, 2)]
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]
[[ 0.  2. -3.]]
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61033/343300159.py:29: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(

Running the active inference agent

In [ ]:
key = jr.PRNGKey(0) # random key for seeding reproducible stochasticity in the AIF loop (e.g., for sampling any random processes in the TMaze, and for sampling actions from the agent's chosen policy).
T = 10 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop

# you may print the info dictionary to see the numerical results of the aif agent completing the T-maze task by editing and uncommenting the following lines and running the code: 

# print(info.keys()) # keys in the info dictionary
# print(info["action"][:,0,:]) # actions taken by the agent (locations throughout the maze) 
# print(info["observation"][0]) # observations of the locations for each batch
# print(info["observation"][2]) # observations of the cues for each batch
# print(jnp.around(info["qs"][1], decimals=2)) # posterior beliefs about the reward location
# print(info["qpi"][0].shape) # shape of the policy tensor
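
A rollout can also be summarised numerically rather than visually, by counting each agent's outcome observations. The helper below is hypothetical. It assumes info["observation"][1] is an integer array of shape (batch_size, T, 1), which is the layout implied by the [:, t, :] indexing in the rendering cell below. The example runs on made-up data so that it is self-contained.

```python
import numpy as np

OUTCOME_LABELS = ("no outcome", "reward", "punishment")  # indices of the outcome modality (A[1])

def tally_outcomes(outcome_obs):
    """Count outcome observations per agent (hypothetical helper).

    outcome_obs: integer array of shape (batch_size, T, 1).
    Returns an array of shape (batch_size, 3): counts of no outcome / reward / punishment.
    """
    flat = np.asarray(outcome_obs)[..., 0]  # drop the trailing singleton axis -> (batch_size, T)
    return np.stack([(flat == k).sum(axis=1) for k in range(len(OUTCOME_LABELS))], axis=1)

# Made-up example: 2 agents, 4 timesteps
fake_outcomes = np.array([[[0], [0], [1], [1]],
                          [[0], [2], [2], [0]]])
print(tally_outcomes(fake_outcomes))

# With a real rollout (assuming the layout above):
# counts = tally_outcomes(info["observation"][1])
```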

Rendering the Task to Visualise the Agent's Behaviour

In [2]:
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],  
        info["observation"][2][:, t, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)

# # uncomment the following lines to save the video as a gif
# import os
# from PIL import Image
# os.makedirs("figures", exist_ok=True)
# pil_frames = [Image.fromarray(frame) for frame in frames]
# reward_location = "random" if reward_condition is None else ("left" if reward_condition == 0 else "right")
# filename = os.path.join("figures", f"tmaze_{agent.batch_size}_{reward_location}.gif")
# pil_frames[0].save(
#     filename,
#     save_all=True,
#     append_images=pil_frames[1:],
#     duration=1000,  # 1000ms per frame
#     loop=0
# )

Running multiple agents in parallel using the batch_size parameter

The active inference computations within the Agent methods (inference, planning/action selection, and learning) internally expect most arrays to have an additional leading dimension of length batch_size. This extra dimension can be used to create and run multiple agents in parallel, each with its own generative model, observation and action history, and posterior beliefs. For instance, if batch_size = 10, then each A, B, C, ... array represents the corresponding generative model parameters for 10 different agents, and each agent receives its own history of observations and emits its own actions. The single-agent (default) case is therefore batch_size = 1. If you pass generative model arrays to the Agent constructor without a (batch_size,)-sized leading dimension (as we did in the previous example), they are expanded internally to have this dimension, creating batch_size identical copies of the provided model parameters.

One can also run multiple environments in parallel by passing a separate, stateful env_params pytree that stores batch_size different A, B, and D arrays, but for now we use a single fixed environment. In this default case, the environment's env.A, env.B, and env.D arrays have no leading (batch_size,) dimension, even though the Agent's arrays do.
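The code below builds batched arrays with jnp.broadcast_to and .at[...].set(...). This NumPy sketch mirrors those steps for the location prior D[0], with in-place NumPy assignment standing in for JAX's functional .at updates. It also highlights one pitfall. To set a value for every agent, index the batch axis with `:` rather than `0`; otherwise only the first agent receives a valid prior.

```python
import numpy as np

batch_size = 9
num_locations = 5

# A single (unbatched) location prior, as in the single-agent example
D_loc_single = np.zeros(num_locations)
D_loc_single[0] = 1.0  # start in the centre

# Broadcasting gives every agent an identical copy along a new leading axis
D_loc_batched = np.broadcast_to(D_loc_single, (batch_size,) + D_loc_single.shape)
print(D_loc_batched.shape)  # (9, 5)

# Building the batched prior by hand: ':' on the batch axis sets the centre for every agent
D_loc_manual = np.zeros((batch_size, num_locations))
D_loc_manual[:, 0] = 1.0
assert np.array_equal(D_loc_manual, D_loc_batched)

# Pitfall: [0, 0] only sets the first agent's prior; the other rows stay all-zero
D_loc_wrong = np.zeros((batch_size, num_locations))
D_loc_wrong[0, 0] = 1.0
print(D_loc_wrong.sum(axis=1))  # [1. 0. 0. 0. 0. 0. 0. 0. 0.]
```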

In [3]:
batch_size = 9 # number of agents to run in parallel

# setting A tensors from the environment parameters
A, B = env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies # dependencies allow you to specify which state factors each observation modality depends on, so you don't have to compute the full tensor using all state factors

# now expand A and B to have a new batch_size-length leading dimension
A = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in A]
B = [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in B]

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(2.0)    # prefer reward
C[1] = C[1].at[:,2].set(-3.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[:,0].set(1.0)  # set centre location to 1.0 for every agent in the batch
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    policy_len=2,
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    batch_size=batch_size, # note we have to pass in this batch_size parameter so the class knows to treat the first dimension as the batch size
    learn_A=False,
    learn_B=False
)

_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop

# rendering the task to visualise the agent's behaviour 
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],  
        info["observation"][2][:, t, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61033/1007283590.py:32: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(

2. A noisy generative process, and a single agent solving the task with standard active inference.

Making the world more stochastic

In this variant, we again give the agent true knowledge of the parameters of the generative process (by setting the agent's A and B tensors to those of the true generative process). Now, however, we add stochasticity to the true generative process parameters, so that the agent is no longer guaranteed to receive reward in the rewarded arm every time; it is only more likely to receive reward than punishment there. This stochasticity can be modulated using the reward_probability argument to the TMaze() environment. We can likewise modulate the probability of seeing punishment in the punishment arm by changing the punishment_probability argument to the TMaze environment. Finally, another source of uncertainty is the cue reliability or validity, which encodes the likelihood that the signal observed at the cue location of the T-maze accurately identifies which arm is the more-rewarding vs. more-punishing arm. This probability (of the cue signalling the 'correct arm') can be modulated using the cue_validity argument to the TMaze constructor.
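A back-of-the-envelope calculation shows why the cue is still worth visiting under these noisier settings. The sketch below takes the parameter values from the next cell and the preferences C[1] = [0, 3, -4] used by this section's agent. It compares the expected preference score of entering an arm at random with that of entering the arm the cue indicates. This is a simplification: it scores outcomes by their C values only, ignores the extra timestep spent at the cue, and omits the epistemic (information gain) part of the agent's policy evaluation.

```python
import numpy as np

# Noisy generative-process parameters (same values as the next cell)
reward_probability = 0.7
punishment_probability = 0.6
cue_validity = 0.95
C_outcome = np.array([0.0, 3.0, -4.0])  # [no outcome, reward, punishment]

def outcome_distribution(p_correct_arm):
    """Outcome probabilities when the chosen arm is the rewarding one with probability p_correct_arm."""
    p_reward = p_correct_arm * reward_probability
    p_punish = (1.0 - p_correct_arm) * punishment_probability
    return np.array([1.0 - p_reward - p_punish, p_reward, p_punish])

# Entering an arm directly, with a 50/50 belief about the reward location
direct = outcome_distribution(0.5)
# Visiting the cue first and following it: correct with probability cue_validity
after_cue = outcome_distribution(cue_validity)

print("direct   :", direct, "expected C:", direct @ C_outcome)
print("after cue:", after_cue, "expected C:", after_cue @ C_outcome)
```

With these values, entering an arm at random has a slightly negative expected score (-0.15), whereas following the cue scores 1.875.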

In [ ]:
# THE GENERATIVE PROCESS (NOISY)
# setting the parameters for the environment
batch_size = 4 # number of environments/agents to run in parallel
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 0.7 # 70% chance of reward in the correct arm
punishment_probability = 0.6 # 60% chance of punishment in the other arm
cue_validity = 0.95 # 95% valid cues
dependent_outcomes = False # if True, punishment occurs as a function of reward probability (e.g., if reward_probability is 0.8, there is a 20% chance of punishment). If False, punishment occurs with its own set probability (e.g., if reward_probability is 0.8, there is a 20% chance of no outcome in the rewarding arm), and punishment only occurs in the other (non-rewarding) arm

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes
)
In [4]:
# THE GENERATIVE MODEL

# Get A, B, and dependencies from the environment
A, B = env.A, env.B
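A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies # dependencies allow you to specify which state factors each observation modality (A) and each state factor's transitions (B) depend on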

# now expand A and B to have a new batch_size-length leading dimension
A = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in A]
B = [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in B]

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(3.0)    # prefer reward
C[1] = C[1].at[:,2].set(-4.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[:, 0].set(1.0)  # set centre location to 1.0 for every agent in the batch
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    policy_len=3, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    batch_size=batch_size,
    learn_A=False,
    learn_B=False
)

key = jr.PRNGKey(0) # random key for the aif loop
T = 20 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61033/34467052.py:32: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(
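The cell above relies on two pieces of shape bookkeeping: each A[m] has an outcome axis followed by one axis per hidden-state factor listed in A_dependencies[m], and every tensor is given an explicit leading batch axis. The sketch below checks both with toy sizes. The state and observation counts and the dependency lists are hypothetical values chosen for illustration, not read from TMaze.

```python
import jax.numpy as jnp

# Hypothetical sizes for illustration (not taken from the TMaze environment)
num_states = [5, 2]                      # e.g. [location, reward condition]
num_obs = [5, 3, 3]                      # one entry per observation modality
A_dependencies = [[0], [0, 1], [0, 1]]   # state factors each modality depends on
batch_size = 4

def a_shape(m):
    """Shape of A[m]: outcomes first, then one axis per dependent state factor."""
    return (num_obs[m],) + tuple(num_states[f] for f in A_dependencies[m])

# Uniform toy likelihoods with the shapes implied by the dependencies
A = [jnp.ones(a_shape(m)) / num_obs[m] for m in range(len(num_obs))]

# Add a leading batch axis, as in the cell above
A_batched = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in A]

# Batched prior over location: every agent starts in the centre (index 0).
# Indexing the batch axis with ":" sets the entry for all agents, whereas
# .at[0, 0] would only set it for the first agent in the batch.
D_loc = jnp.zeros((batch_size, num_states[0])).at[:, 0].set(1.0)

print([a.shape for a in A_batched])  # [(4, 5, 5), (4, 3, 5, 2), (4, 3, 5, 2)]
print(D_loc.sum(axis=1))             # [1. 1. 1. 1.]
```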
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],  
        info["observation"][2][:, t, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
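The rendering loop above, which is repeated in later cells, pulls one timestep per modality out of info["observation"], where each entry has three axes (batch, time, trailing) as implied by the [:, t, :] indexing. A small helper can make that slicing reusable. The name observations_at is hypothetical and not part of pymdp; the sketch exercises it on dummy arrays in place of real rollout output.

```python
import jax.numpy as jnp

def observations_at(observation_list, t):
    """Return one (batch, trailing) slice per modality for timestep t.

    Assumes each array in observation_list has shape (batch, time, trailing),
    matching the info["observation"][m][:, t, :] indexing used above.
    """
    return [obs_m[:, t, :] for obs_m in observation_list]

# Dummy stand-in for info["observation"]: 3 modalities, batch of 2, 4 timesteps
dummy_obs = [jnp.arange(8).reshape(2, 4, 1) + 10 * m for m in range(3)]

obs_t = observations_at(dummy_obs, t=2)
print([o.shape for o in obs_t])  # [(2, 1), (2, 1), (2, 1)]
print(obs_t[1].ravel())          # [12 16]

# Usage inside the rendering loop:
# frame = env.render(mode="rgb_array", observations=observations_at(info["observation"], t))
```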

3. The agent is equipped with a partially uncertain model of the true generative process and can also perform parameter learning to update its beliefs about outcome and transition contingencies over time.

You can toggle two arguments of the Agent constructor, learn_A and learn_B, between True and False. They selectively turn on or off updating of the parameters of the Dirichlet posterior over the A and B tensors, respectively. Use them when your agent does not start with a correct or confident model of the world and you want it to learn its model over time.
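To make the learning step concrete, here is a simplified sketch of the standard Dirichlet count update for a likelihood that depends on a single state factor: the counts grow by a learning rate times the outer product of the observed outcome (one-hot) and the posterior over states. The helper names update_pA and expected_likelihood and the lr parameter are hypothetical, not pymdp functions; pymdp's own update handles multiple dependent factors and may differ in details.

```python
import jax.numpy as jnp

def update_pA(pA_m, obs_onehot, qs, lr=1.0):
    """Simplified Dirichlet update for a likelihood that depends on one state factor.

    Adds lr * outer(observation one-hot, posterior over states) to the counts,
    so columns for states the agent believes it occupied gain evidence for
    the outcome it actually observed.
    """
    return pA_m + lr * jnp.outer(obs_onehot, qs)

def expected_likelihood(pA_m):
    """Expected value of the Dirichlet: normalise counts over the outcome axis."""
    return pA_m / pA_m.sum(axis=0, keepdims=True)

# Toy prior: 3 outcomes x 2 states, weakly informative counts
pA_m = jnp.array([[1.0, 0.3],
                  [0.3, 1.0],
                  [0.3, 0.3]])
obs = jnp.array([1.0, 0.0, 0.0])  # outcome 0 was observed
qs = jnp.array([0.9, 0.1])        # agent is fairly sure it is in state 0

pA_new = update_pA(pA_m, obs, qs)
print(expected_likelihood(pA_m)[:, 0])    # belief about state 0 before the update
print(expected_likelihood(pA_new)[:, 0])  # sharper belief after the update
```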

We demonstrate two variants in this section:

1. We initialize the agent with partially uncertain beliefs about the A and B parameters, but the mode of these beliefs (i.e. the value with the highest probability) still matches the true generative process.

2. We equip the agent with incorrect beliefs about the generative process.

Correct Prior Structure

# setting the parameters for the environment
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = False

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition, 
    dependent_outcomes=dependent_outcomes
)

We can encode uncertainty in the agent's prior knowledge about the state-outcome contingencies by adding non-zero values to the pA tensor (the Dirichlet prior parameters over the Categorical parameters of the A tensor). We still build 'true knowledge' into this prior, in the sense that the mode of the likelihood distribution is aligned with the true generative process, but we add uncertainty by placing non-zero values in the entries of the Dirichlet prior that are zero in the true generative process. This encodes the agent's uncertainty about whether certain relationships between hidden states (e.g., location or reward-condition states) and outcomes (cue signals or reward outcomes) exist.
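As a worked check, take one column of the true reward likelihood that is one-hot, [1, 0, 0], and add alpha_scale_pA = 0.3 to every entry. The expected value divides each count by the column total 1.9, giving 1.3/1.9 ≈ 0.684 and 0.3/1.9 ≈ 0.158, which match the columns of the agent's A tensor printed at t=0 further down. The mode stays on the true outcome. This sketch normalises explicitly rather than calling dirichlet_expected_value, so it does not depend on that function's signature.

```python
import jax.numpy as jnp

alpha_scale_pA = 0.3
true_column = jnp.array([1.0, 0.0, 0.0])     # one-hot column of the true A tensor
prior_counts = true_column + alpha_scale_pA  # Dirichlet parameters after flattening

expected = prior_counts / prior_counts.sum()  # expected value of the Dirichlet
print(expected)  # approximately [0.6842, 0.1579, 0.1579]
```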

# initializing the prior over the A tensors (pA) based on the environment parameters
num_obs = [a.shape[0] for a in env.A]  # number of observations for each modality
num_states = [b.shape[0] for b in env.B]  # number of states for each factor

# initializing pA, pB to be identical to the environmental parameters
pA, pB = [jnp.copy(a) for a in env.A], [jnp.copy(b) for b in env.B] # have to copy them here because we will be modifying pA below

# adding non-zero Dirichlet parameter values to pA to flatten the expected value of the A tensor (make it more uncertain)
alpha_scale_pA = 0.3
for m in [1, 2]:  # only modifying the outcome (m=1) and cue (m=2) observation likelihood mappings
    pA[m] = pA[m] + (alpha_scale_pA * jnp.ones_like(pA[m]))

expected_A = [dirichlet_expected_value(pa_m, event_dim=0) for pa_m in pA]  # compute the expected value of the Dirichlet parameters in pA
expected_B = [dirichlet_expected_value(pb_f, event_dim=0) for pb_f in pB]  # compute the expected value of the Dirichlet parameters in pB

# creating C tensors filled with zeros for [location], [reward], [cue] based on modality shapes
C = [jnp.zeros(no, dtype=jnp.float32) for no in num_obs] 
# setting preferences for outcomes only
C[1] = C[1].at[1].set(2.0)    # prefer reward
C[1] = C[1].at[2].set(-3.0)   # avoid punishment

# creating D tensors [location], [reward] based on B shapes
D = [jnp.zeros(ns, dtype=jnp.float32) for ns in num_states]  # creating D tensors filled with zeros for [location], [reward] based on B shapes
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros(num_states[0], dtype=jnp.float32) 
D_loc = D_loc.at[0].set(1.0)  # set centre location to 1.0
D[0] = D_loc

# D[1]: reward location - uniform distribution
D_reward = jnp.ones(num_states[1], dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=0, keepdims=True)  # normalise to get uniform distribution
D[1] = D_reward


# initialising the agent
# NOTE: We initialize the first value of A and B based on the expected values of the Dirichlet parameters in pA and pB, respectively.
agent = Agent(
    expected_A, expected_B, C, D, 
    pA=pA,
    pB=pB, 
    policy_len=5, # how long the action sequence is that the agent is evaluating
    A_dependencies=env.A_dependencies, 
    B_dependencies=env.B_dependencies,
    learn_A=True,
    learn_B=False,
    gamma=1.0,
    action_selection="stochastic"
)
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61033/1366977802.py:37: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(
key = jr.PRNGKey(0) # random key for the aif loop
T = 10 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
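rollout() runs the whole perception-action loop as compiled JAX code. Its internals are not shown in this notebook; as an illustration of the general pattern only, jax.lax.scan can carry a state through T steps and stack per-step outputs along a new time axis, and jax.vmap can add a leading batch axis. That ordering (batch first, then time) is consistent with the info[...][:, t, ...] and [0, -1, ...] indexing used in this notebook. The toy step and rollout functions below are hypothetical stand-ins for inference, planning and the environment update, not pymdp's implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp

def toy_step(carry, _):
    """Hypothetical one-step update: increment a counter and emit an 'observation'."""
    new_carry = carry + 1.0
    observation = jnp.sin(new_carry)  # stand-in for an environment outcome
    return new_carry, observation

def toy_rollout(init, num_timesteps):
    # scan stacks the per-step outputs along a new leading time axis
    final, observations = jax.lax.scan(toy_step, init, xs=None, length=num_timesteps)
    return final, observations

T = 5
# vmap over a batch of initial states, then compile; outputs are (batch, time)
batched_rollout = jax.jit(jax.vmap(partial(toy_rollout, num_timesteps=T)))

inits = jnp.arange(3.0)  # three "agents" with different initial states
finals, obs = batched_rollout(inits)
print(obs.shape)  # (3, 5)
print(finals)     # [5. 6. 7.]
```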

We can print the agent's A tensor before and after the rollout to see how it updates its beliefs about the state-outcome contingencies based on the data it collects through active inference.

print("the environment's A tensor")
print(env.A[1][:,:,1])
print()
print("the agent's A tensor at t=0")
print(agent.A[1][0][:,:,1])
print()
print(f"the agent's A tensor at t={T}")
print(info["A"][1][0,-1,:,:,1])
the environment's A tensor
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]

the agent's A tensor at t=0
[[0.68421054 0.15789475 0.15789476 0.68421054 0.68421054]
 [0.15789476 0.15789475 0.68421054 0.15789476 0.15789476]
 [0.15789476 0.68421054 0.15789476 0.15789476 0.15789476]]

the agent's A tensor at t=10
[[0.86867255 0.15789475 0.15789476 0.72054607 0.6926888 ]
 [0.06566371 0.15789475 0.68421054 0.13972697 0.15365562]
 [0.06566371 0.68421054 0.15789476 0.13972697 0.15365562]]
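To summarise the printout above, one can compute, for each location column, the L1 distance between the agent's expected likelihood and the environment's. The arrays below are copied from the printed output (rounded to four decimals); the distance metric and the column_l1 helper are choices made for illustration, not something pymdp reports.

```python
import jax.numpy as jnp

# Values copied from the printout above (reward modality, reward-condition slice 1),
# rounded to four decimals. Columns are the five location states.
env_A = jnp.array([[1., 0., 0., 1., 1.],
                   [0., 0., 1., 0., 0.],
                   [0., 1., 0., 0., 0.]])
agent_A_t0 = jnp.array([[0.6842, 0.1579, 0.1579, 0.6842, 0.6842],
                        [0.1579, 0.1579, 0.6842, 0.1579, 0.1579],
                        [0.1579, 0.6842, 0.1579, 0.1579, 0.1579]])
agent_A_t10 = jnp.array([[0.8687, 0.1579, 0.1579, 0.7205, 0.6927],
                         [0.0657, 0.1579, 0.6842, 0.1397, 0.1537],
                         [0.0657, 0.6842, 0.1579, 0.1397, 0.1537]])

def column_l1(estimate, truth):
    """L1 distance between estimated and true outcome distributions, per state column."""
    return jnp.abs(estimate - truth).sum(axis=0)

dist_t0 = column_l1(agent_A_t0, env_A)
dist_t10 = column_l1(agent_A_t10, env_A)
print("t=0 :", dist_t0)
print("t=10:", dist_t10)
```

In this run the centre column (location 0) moves closest to the true mapping, while the two arm columns (locations 1 and 2) are unchanged, consistent with the identical entries in the two printouts.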
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],  
        info["observation"][2][:, t, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)

Incorrect Prior Structure
# Random initialization test
key = jr.PRNGKey(24)

# setting A tensors from the environment parameters, but randomizing the prior over the transitions
pA, pB = [jnp.copy(a) for a in env.A], [jr.uniform(subkey, shape=b.shape) for subkey, b in zip(jr.split(key, num=2), env.B)]

# replacing the outcome (m=1) and cue (m=2) likelihood priors with random values, so the agent starts with incorrect beliefs
key1, key2 = jr.split(key, num=2)
for key_m, modality in zip([key1, key2], [1, 2]):
    pA[modality] = jr.uniform(key_m, shape=pA[modality].shape)  # random Dirichlet parameters between 0 and 1

expected_A = [dirichlet_expected_value(pa_m, event_dim=0) for pa_m in pA]  # compute the expected value of the Dirichlet parameters in pA
expected_B = [dirichlet_expected_value(pb_f, event_dim=0) for pb_f in pB]  # compute the expected value of the Dirichlet parameters in pB

# initialising the agent
agent = Agent(
    expected_A, expected_B, C, D, 
    pA=pA,
    pB=pB, 
    policy_len=5, # how long the action sequence is that the agent is evaluating
    A_dependencies=env.A_dependencies, 
    B_dependencies=env.B_dependencies,
    batch_size=batch_size, 
    learn_A=True,
    learn_B=True,
    gamma=0.1,
    action_selection="stochastic"
)

# running the active inference simulation
key = jr.PRNGKey(0) # random key for the aif loop
T = 50 # number of timesteps to rollout the aif loop for
_, info = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61033/3157206125.py:16: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(
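Both learning agents use action_selection="stochastic", and this second agent lowers gamma from 1.0 to 0.1. In active inference, gamma is the precision over policies: the posterior over policies is a softmax of negative expected free energy scaled by gamma, plus a log prior over policies. The sketch below uses made-up expected free energy values and omits the policy prior to show how a smaller gamma flattens that distribution, which makes sampled actions more random. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def policy_posterior(G, gamma):
    """q(pi) = softmax(-gamma * G), ignoring any prior over policies."""
    return jax.nn.softmax(-gamma * G)

def entropy(p):
    """Shannon entropy (in nats) of a categorical distribution."""
    return -jnp.sum(p * jnp.log(p))

# Hypothetical expected free energies for four policies (lower is better)
G = jnp.array([1.0, 2.0, 4.0, 8.0])

for gamma in (1.0, 0.1):
    q_pi = policy_posterior(G, gamma)
    print(f"gamma={gamma}: q(pi)={q_pi}, entropy={float(entropy(q_pi)):.3f}")

# Stochastic selection then samples a policy index from q(pi)
key = jax.random.PRNGKey(0)
sampled = jax.random.categorical(key, jnp.log(policy_posterior(G, 0.1)))
```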
print("the environment's A tensor")
print(env.A[1][:,:,1])
print()
print("the agent's A tensor at t=0")
print(agent.A[1][0][:,:,1])
print()
print(f"the agent's A tensor at t={T}")
print(info["A"][1][0,-1,:,:,1])
the environment's A tensor
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]

the agent's A tensor at t=0
[[0.00308274 0.07115752 0.55618274 0.42624453 0.36277273]
 [0.02319651 0.4746476  0.31333998 0.5615474  0.43671665]
 [0.9737207  0.45419487 0.13047728 0.01220808 0.20051058]]

the agent's A tensor at t=50
[[0.01081832 0.06939295 0.55391985 0.4478869  0.3744181 ]
 [0.02301651 0.48767537 0.31206512 0.5403655  0.4287357 ]
 [0.9661652  0.4429317  0.134015   0.01174758 0.19684626]]
print("the environment's B tensor")
print(env.B[1][:,:,1])
print()
print("the agent's B tensor at t=0")
print(agent.B[1][0][:,:,1])
print()
print(f"the agent's B tensor at t={T}")
print(info["B"][1][0,-1,:,:,1])
the environment's B tensor
[[1. 0.]
 [0. 1.]]

the agent's B tensor at t=0
[[0.8990159  0.44466075]
 [0.10098408 0.5553392 ]]

the agent's B tensor at t=50
[[0.9954207  0.6131649 ]
 [0.0045793  0.38683507]]
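With learn_B=True, transition counts are updated in the same spirit as the likelihood counts. A common form of this update, shown here as a simplified single-factor sketch rather than pymdp's exact implementation, adds lr times the outer product of the current and previous state posteriors to the slice of pB for the action that was taken. When the previous-state belief is uncertain, both columns receive evidence toward the currently believed state; this illustrates the mechanism and is not offered as a diagnosis of the run printed above. The helper names are hypothetical.

```python
import jax.numpy as jnp

def update_pB(pB_f, qs_now, qs_prev, action, lr=1.0):
    """Simplified Dirichlet update for one transition factor.

    pB_f has shape (next_state, previous_state, action); only the slice
    for the action actually taken receives new counts.
    """
    delta = lr * jnp.outer(qs_now, qs_prev)
    return pB_f.at[:, :, action].add(delta)

def expected_transitions(pB_f):
    """Normalise over next states (axis 0) to get the expected B tensor."""
    return pB_f / pB_f.sum(axis=0, keepdims=True)

# Toy 2-state factor with a single (uncontrollable) action and flat counts
pB_f = jnp.ones((2, 2, 1))
qs_prev = jnp.array([0.5, 0.5])  # unsure which condition held at t-1
qs_now = jnp.array([0.9, 0.1])   # fairly sure of condition 0 at t

pB_new = update_pB(pB_f, qs_now, qs_prev, action=0)
print(expected_transitions(pB_new)[:, :, 0])
```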
frames = []
for t in range(info["observation"][0].shape[1]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][:, t, :],
        info["observation"][1][:, t, :],  
        info["observation"][2][:, t, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
