
Generative Model Construction with Model and Distribution

This tutorial walks through the use of the Model and Distribution classes. The Model class wraps the A, B, C, and D arrays into a single model object, whose structure can be generated automatically from a JSON-like configuration dict. This configuration dict, often named model_description, lets you assign string-valued names to the random variables of the model (observations, hidden states, and control states). It also lets you specify the dependencies between those variables, via the "depends_on" key (and "controlled_by" for control states) within each variable-specific sub-dict.

Within the Model object, each component distribution is an instance of the Distribution class. A Distribution can use string- or integer-valued labels both for its axes and for specific indices along those axes. The intention of these classes and their named dimensions and values is to give users a more interpretable, user-friendly entry point for building discrete generative models in pymdp.

In summary, the advantages of the Model and Distribution classes are:

  • Reduce the memory burden on the user (and the indexing errors that often ensue) by working with named axes and elements (e.g., "left", "right" instead of 0, 1)
  • Increase interpretability (if applicable) and intuition in the generation and inspection of the POMDP model
  • Flexible specification of dependencies between random variables via string labels
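To make the idea of named axes concrete, here is a minimal sketch of how string labels can be translated into integer positions before indexing a NumPy array. The LabeledArray class is hypothetical. It only illustrates the idea and makes no claims about pymdp internals; it is not pymdp's Distribution.

```python
import numpy as np


class LabeledArray:
    """Hypothetical sketch of label-based indexing (not pymdp's Distribution)."""

    def __init__(self, axes):
        # axes: list of (axis_name, list_of_element_labels) pairs
        self.axis_names = [name for name, _ in axes]
        self.labels = [list(elements) for _, elements in axes]
        self.data = np.zeros([len(elements) for elements in self.labels])

    def _to_index(self, key):
        # translate each string label into its integer position; ints and slices pass through unchanged
        if not isinstance(key, tuple):
            key = (key,)
        return tuple(
            self.labels[axis].index(k) if isinstance(k, str) else k
            for axis, k in enumerate(key)
        )

    def __getitem__(self, key):
        return self.data[self._to_index(key)]

    def __setitem__(self, key, value):
        self.data[self._to_index(key)] = value


positions = ["left", "center_left", "center_right", "right"]
A = LabeledArray([("position_obs", positions), ("position", positions)])
A["left", "left"] = 1.0                # same effect as A.data[0, 0] = 1.0
A["center_left", "center_left"] = 1.0  # same effect as A.data[1, 1] = 1.0
print(A.data[1, 1], A["right", :])     # label and slice can be mixed
```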

Tutorial Structure

  1. Basic Example: A simple grid navigation task built from a structured description using labels, in which the agent has to travel to a goal location.
  2. A More Advanced Example: A simple foraging task built from a structured description using labels, in which the agent has to search for apples and eat them while they spawn at a set rate.

Example 1: Grid World Navigation

Let's start with a simple example: an agent moving horizontally in a 1D grid world. The agent can be in one of four positions ("left", "center_left", "center_right", or "right") and can take the actions "move_right" or "move_left". We first set up a single agent, and then run three agents in parallel using the batch_size argument to the Agent constructor.

import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import jax.tree_util as jtu
from jax import numpy as jnp
from jax import random as jr
from pymdp.agent import Agent
from pymdp.distribution import compile_model
from pymdp.envs.env import Env
from pymdp.envs import rollout
positions = ["left", "center_left", "center_right", "right"]
actions = ["move_left", "move_right"]

model_description = {
    "observations": {
        "position_obs": {
            "elements": positions, 
            "depends_on": ["position"] # we specify that the observation depends on the "position" state factor
        },
    },
    "controls": {
        "movement": {"elements": actions} # we specify the available actions
    },
    "states": {
        "position": {
            "elements": positions, 
            "depends_on": ["position"],  # our current position depends on previous position...
            "controlled_by": ["movement"]  # ...and the movement action taken
        },
    },
}

# compile the model structure from the description
model = compile_model(model_description)

We have built a generative model structure from the model description. However, the model's parameters are currently arrays of zeros. We can now fill in the parameter tensors, using the axis and element labels defined in the model description dict to set values at particular indices of these arrays.

# fill in the likelihood (A) tensor
# the observations have an identical mapping to the states (i.e., the agent will perfectly observe its position)
model.A["position_obs"]["left", "left"] = 1.0
model.A["position_obs"]["center_left", "center_left"] = 1.0
model.A["position_obs"]["center_right", "center_right"] = 1.0
model.A["position_obs"]["right", "right"] = 1.0
# model.A["position_obs"].data = jnp.eye(len(positions)) # you could also use the .data attribute to set the identity mapping directly

# fill in the transition model (B) tensor
# note that it's specified as ["to", "from", "action"]
# moving right
model.B["position"]["center_left", "left", "move_right"] = 1.0     
model.B["position"]["center_right", "center_left", "move_right"] = 1.0  
model.B["position"]["right", "center_right", "move_right"] = 1.0    
model.B["position"]["right", "right", "move_right"] = 1.0           

# moving left  
model.B["position"]["left", "left", "move_left"] = 1.0              
model.B["position"]["left", "center_left", "move_left"] = 1.0       
model.B["position"]["center_left", "center_right", "move_left"] = 1.0  
model.B["position"]["center_right", "right", "move_left"] = 1.0    

# set preferences (C) tensor - prefer to be at "center_left"
model.C["position_obs"]["center_left"] = 1.0
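For comparison, the sketch below builds the same transition tensor with plain NumPy and integer indices. This is the bookkeeping that the labels save you from. The sketch assumes only the ["to", "from", "action"] axis order described above, and it checks that every column sums to one.

```python
import numpy as np

positions = ["left", "center_left", "center_right", "right"]
actions = ["move_left", "move_right"]
n = len(positions)

# B[to, from, action], built with raw integer indices instead of labels
B = np.zeros((n, n, len(actions)))
for s in range(n):
    B[max(s - 1, 0), s, 0] = 1.0      # move_left: one step left, clamped at the left wall
    B[min(s + 1, n - 1), s, 1] = 1.0  # move_right: one step right, clamped at the right wall

# every (from, action) column is a valid probability distribution over "to"
assert np.allclose(B.sum(axis=0), 1.0)

# spot-check against the labelled assignments above
assert B[positions.index("center_left"), positions.index("left"), actions.index("move_right")] == 1.0
assert B[positions.index("center_right"), positions.index("right"), actions.index("move_left")] == 1.0
```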

Now, let's create the agent (via the Agent class), have it infer which state it is in from an observation, and select an action according to its goal.

gamma = 10 # deterministic behavior; make gamma smaller for stochastic behavior

# create agent
agent = Agent(**model, gamma=gamma)

# set up initial observation to be "left"
observation = jnp.zeros((agent.batch_size, 1)) # broadcast to agent's batch size (defaults to 1 agent) and add a time dimension

# get the prior
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 1), agent.D) # qs needs a time dimension too

# print initial beliefs, goal, and action chosen
qs = agent.infer_states([observation], qs_init)
print(f"Current belief about position: {positions[jnp.argmax(qs[0][0])]}")
qs = [jnp.squeeze(q, 1) for q in qs]

print(f"Goal position: {positions[jnp.argmax(agent.C[0])]}")

q_pi, G = agent.infer_policies(qs)
action_idx = agent.sample_action(q_pi)
print(f"Action chosen: {actions[action_idx[0][0]]}")
Current belief about position: left
Goal position: center_left
Action chosen: move_right
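As a rough illustration of how gamma shapes the action choice, the sketch below scores the two one-step policies from the "left" position. It then converts those scores into policy probabilities with the standard active inference softmax over -gamma * G. This is a deliberate simplification: it drops the epistemic term and treats C directly as log-preferences. It is not pymdp's exact expected free energy computation.

```python
import numpy as np

positions = ["left", "center_left", "center_right", "right"]
actions = ["move_left", "move_right"]


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


# predicted observations from "left" under each one-step policy
# (A is the identity here, so predicted observations equal predicted states)
predicted_obs = {
    "move_left": np.array([1.0, 0.0, 0.0, 0.0]),   # stays at "left"
    "move_right": np.array([0.0, 1.0, 0.0, 0.0]),  # moves to "center_left"
}
C = np.array([0.0, 1.0, 0.0, 0.0])  # preference for "center_left"

# simplified G: negative expected preference only (no information-gain term)
G = np.array([-(predicted_obs[a] @ C) for a in actions])

for gamma in (10.0, 1.0):
    q_pi = softmax(-gamma * G)  # larger gamma -> sharper, more deterministic policy posterior
    print(gamma, dict(zip(actions, q_pi.round(4))))
```

With gamma = 10 almost all probability mass falls on "move_right". With gamma = 1 the split is much softer, which is what the comment about making gamma smaller for stochastic behavior refers to.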

We can also run multiple agents (or independent trials) in parallel, each with a different initial observation (i.e., a different initial position), by using the batch_size argument:

batch_size = 3 # running 3 trials or agents in parallel
gamma = 10     # deterministic behavior; make gamma smaller for stochastic behavior

# create agent
agent = Agent(**model, batch_size=batch_size, gamma=gamma)

# set up different initial observations for each agent: "left", "center_right", and "right"
observation = [jnp.array([[0], [2], [3]])] # wrap in a list to indicate the single modality; observation[0].shape = (batch_size, 1)

# get the prior
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 1), agent.D) # qs needs a time dimension too

# print goal and initial beliefs
qs = agent.infer_states(observation, qs_init)
for a in range(batch_size): 
    print(f"Agent {a}'s current belief about position: {positions[jnp.argmax(qs[0][a])]}")
qs = [jnp.squeeze(q, 1) for q in qs]

print(f"\nGoal position for all agents: {positions[jnp.argmax(agent.C[0])]}\n")

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
for a in range(batch_size): 
    print(f"Agent {a}'s action chosen: {actions[action[a][0]]}")
Agent 0's current belief about position: left
Agent 1's current belief about position: center_right
Agent 2's current belief about position: right

Goal position for all agents: center_left

Agent 0's action chosen: move_right
Agent 1's action chosen: move_left
Agent 2's action chosen: move_left
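Batching can be pictured as the same arithmetic carried out along an extra leading axis. The sketch below is a simplified, pragmatic-only one-step lookahead in NumPy, not pymdp's implementation. It assumes a uniform prior and the identity A from above, and it scores both actions for all three agents at once with einsum. Under this simplification, the agent starting at "right" faces a tie, since both actions leave it away from "center_left". Here argmax resolves the tie toward the first action.

```python
import numpy as np

positions = ["left", "center_left", "center_right", "right"]
actions = ["move_left", "move_right"]
n = len(positions)

# transition tensor B[to, from, action] for the 1D grid
B = np.zeros((n, n, 2))
for s in range(n):
    B[max(s - 1, 0), s, 0] = 1.0
    B[min(s + 1, n - 1), s, 1] = 1.0
C = np.eye(n)[positions.index("center_left")]  # preference vector for "center_left"

obs = np.array([0, 2, 3])  # one observation index per agent, as in the batched cell above
qs = np.eye(n)[obs]        # identity A + uniform prior: posterior is one-hot on the observed position, shape (batch, n)

# predicted next-state distribution for every agent and action: shape (batch, n_to, n_actions)
qs_next = np.einsum("tfa,bf->bta", B, qs)
value = np.einsum("bta,t->ba", qs_next, C)  # expected preference per agent and action
best = value.argmax(axis=1)                 # argmax picks the first action on ties
print([actions[i] for i in best])
```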

Example 2: Apple Foraging Task

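Now let's look at a slightly more complex example: an apple foraging task. Here we have a 1x3 grid with a "left", "center", and "right" cell. The movement transitions defined below wrap around, so moving left from "left" leads to "right". Each cell is an orchard in which an apple can grow at a fixed rate (1/3 per timestep). The agent's objective is to find apples and eat them, since eating an apple yields a reward. At each timestep the agent chooses a movement action ("stay", "move_left", or "move_right") and an eating action ("noop" or "eat").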
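To get a feel for what a spawn rate of 1/3 means, the sketch below propagates a single empty cell forward in time. It assumes that an apple, once spawned, stays until it is eaten, and that the agent is elsewhere. The 2x2 matrix T is a hypothetical single-cell view, not one of the model's tensors. Under these assumptions the probability of an apple after k steps is 1 - (2/3)^k.

```python
import numpy as np

apple_spawn_rate = 1 / 3

# hypothetical per-cell transition while the agent is elsewhere: T[to, from], order ["orchard", "apple"]
# an orchard cell spawns an apple with probability apple_spawn_rate; an apple persists (assumption)
T = np.array([
    [1 - apple_spawn_rate, 0.0],
    [apple_spawn_rate,     1.0],
])

p = np.array([1.0, 0.0])  # start with an empty orchard cell
for k in range(1, 4):
    p = T @ p
    print(k, p[1])  # probability that an apple is present after k steps
```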

num_locations = 3 # these correspond to "left", "center", and "right" and you can just specify ["left", "center", "right"] but we use numbers to show how to use the 'size' key instead of 'elements'
item_list = ["orchard", "apple"]

model_description = {
    "observations": {
        "location_obs": {"size": num_locations, # if you want to use numbers instead of strings, you can use the 'size' key instead of 'elements'
                         "depends_on": ["location_state"],
        },
        "item_obs": {"elements": item_list, # "elements" key for strings
                     "depends_on": ["location_state", "left_state", "center_state", "right_state"],
        },
        "reward_obs": {"elements": ["no_reward", "reward"],
                       "depends_on": ["reward_state"],
        },
    },
    "controls": {
        "move": {"elements": ["stay", "move_left", "move_right"],
        },
        "eat": {"elements": ["noop", "eat"], # noop = no-operation
        # note that if you cannot control a state, you still need to add 
        # an action for it (e.g., with elements: ["null"]) for the model to be initialized 
        # with the correct dimensions
        },
    },
    "states": {
        "location_state": {"size": num_locations,
                           "depends_on": ["location_state"],
                           "controlled_by": ["move"],
        },
        "reward_state": {"elements": ["no_reward", "reward"],
                            # if you have more than one dependency, the first dependency is its own state factor (at the previous timestep), 
                            # then add the other dependencies in the order they are specified (you can skip over some state factors)
                            "depends_on": ["reward_state", "location_state", 
                                           "left_state", "center_state", "right_state"],
                            "controlled_by": ["eat"],
        },
        "left_state": {"elements": item_list,
                        "depends_on": ["left_state", "location_state"], 
                        "controlled_by": ["eat"],
        },
        "center_state": {"elements": item_list,
                        "depends_on": ["center_state", "location_state"],
                        "controlled_by": ["eat"],
        },
        "right_state": {"elements": item_list,
                        "depends_on": ["right_state", "location_state"],
                        "controlled_by": ["eat"],
        },
    },
}

model = compile_model(model_description)
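The depends_on and controlled_by lists fix the number and order of axes in each tensor, which is why the reward_state indexing in the next cell uses seven positions. As a rough sketch, the shapes can be read off the description: for A, the variable's own size followed by its dependency sizes; for B, the same with the control sizes appended. The tensor_shapes helper is hypothetical and not part of pymdp. Its axis order is inferred from the indexing used in this tutorial.

```python
def tensor_shapes(description):
    """Hypothetical helper: infer A/B shapes implied by a model description (not pymdp's compile_model)."""
    def size(spec):
        return spec["size"] if "size" in spec else len(spec["elements"])

    states, obs, controls = description["states"], description["observations"], description["controls"]
    A_shapes = {
        name: (size(spec),) + tuple(size(states[d]) for d in spec["depends_on"])
        for name, spec in obs.items()
    }
    B_shapes = {
        name: (size(spec),)
        + tuple(size(states[d]) for d in spec["depends_on"])
        + tuple(size(controls[c]) for c in spec.get("controlled_by", []))
        for name, spec in states.items()
    }
    return A_shapes, B_shapes


items = ["orchard", "apple"]
cells = ["left_state", "center_state", "right_state"]
description = {
    "observations": {
        "location_obs": {"size": 3, "depends_on": ["location_state"]},
        "item_obs": {"elements": items, "depends_on": ["location_state"] + cells},
        "reward_obs": {"elements": ["no_reward", "reward"], "depends_on": ["reward_state"]},
    },
    "controls": {
        "move": {"elements": ["stay", "move_left", "move_right"]},
        "eat": {"elements": ["noop", "eat"]},
    },
    "states": {
        "location_state": {"size": 3, "depends_on": ["location_state"], "controlled_by": ["move"]},
        "reward_state": {"elements": ["no_reward", "reward"],
                         "depends_on": ["reward_state", "location_state"] + cells,
                         "controlled_by": ["eat"]},
        **{c: {"elements": items, "depends_on": [c, "location_state"], "controlled_by": ["eat"]} for c in cells},
    },
}

A_shapes, B_shapes = tensor_shapes(description)
print(A_shapes["item_obs"])      # (2, 3, 2, 2, 2)
print(B_shapes["reward_state"])  # (2, 2, 3, 2, 2, 2, 2)
```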

As before, the compiled model's parameters are initially arrays of zeros, so we now fill in the parameter tensors using the axis and element labels from the model description dict. Some variables were declared with "size" rather than "elements" (here, location_state and location_obs). Their indices are plain integers, which is why 0, 1, and 2 appear alongside string labels in the cell below.

'''
SPECIFY THE A TENSOR
'''
# identity mapping for the observations regarding location and reward
model.A["location_obs"].data = jnp.eye(len(model.A["location_obs"].data))
model.A["reward_obs"].data = jnp.eye(len(model.A["reward_obs"].data))

# in any of the locations, the agent may observe apple or orchard
model.A["item_obs"]["apple", 0, "apple", :, :] = 1.0
model.A["item_obs"]["apple", 1, :, "apple", :] = 1.0
model.A["item_obs"]["apple", 2, :, :, "apple"] = 1.0
model.A["item_obs"]["orchard", 0, "orchard", :, :] = 1.0
model.A["item_obs"]["orchard", 1, :, "orchard", :] = 1.0
model.A["item_obs"]["orchard", 2, :, :, "orchard"] = 1.0
model.A["item_obs"].data = model.A["item_obs"].data + 1e-3 # add a small amount of noise to the observations

'''
SPECIFY THE B TENSOR
'''

# for moving between locations
# (to, from, action); the three locations wrap around, so moving left from "left" leads to "right"
valid_transitions = [
    # from 0 (left)
    (0, 0, "stay"), # from left to left, stay
    (1, 0, "move_right"), # from left to center, move right
    (2, 0, "move_left"), # from left to right, move left (wraps around)

    # from 1 (center)
    (0, 1, "move_left"), # from center to left, move left
    (1, 1, "stay"), # from center to center, stay
    (2, 1, "move_right"), # from center to right, move right

    # from 2 (right)
    (0, 2, "move_right"), # from right to left, move right (wraps around)
    (1, 2, "move_left"), # from right to center, move left
    (2, 2, "stay"), # from right to right, stay
]

for to_state, from_state, action in valid_transitions:
    model.B["location_state"][to_state, from_state, action] = 1.0

# again, remember the reward states will be set as ["to", "from", ...dependencies..., "action"]
# if the agent sees an apple and does not eat the apple (i.e., noop), it does not get a reward
model.B["reward_state"]["no_reward", "no_reward", 0, "apple", :, :, "noop"] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 1, :, "apple", :, "noop"] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "apple", "noop"] = 1.0

# if the agent sees an orchard, it does not get a reward regardless of its actions
model.B["reward_state"]["no_reward", "no_reward", 0, "orchard", :, :, :] = 1.0 
model.B["reward_state"]["no_reward", "no_reward", 1, :, "orchard", :, :] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "orchard", :] = 1.0

# from a reward state, there will always be no reward in the next timestep regardless of the action
model.B["reward_state"]["no_reward", "reward", 0, :, :, :, :] = 1.0
model.B["reward_state"]["no_reward", "reward", 1, :, :, :, :] = 1.0
model.B["reward_state"]["no_reward", "reward", 2, :, :, :, :] = 1.0

# if the agent sees an orchard and eats, it will not get a reward
model.B["reward_state"]["reward", "no_reward", 0, "orchard", :, :, "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 1, :, "orchard", :, "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 2, :, :, "orchard", "eat"] = 0.0

# if the agent sees an apple and eats it, it always gets a reward (the no_reward outcome gets zero probability)
model.B["reward_state"]["no_reward", "no_reward", 0, "apple", :, :, "eat"] = 0.0 
model.B["reward_state"]["no_reward", "no_reward", 1, :, "apple", :, "eat"] = 0.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "apple", "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 0, "apple", :, :, "eat"] = 1.0 
model.B["reward_state"]["reward", "no_reward", 1, :, "apple", :, "eat"] = 1.0
model.B["reward_state"]["reward", "no_reward", 2, :, :, "apple", "eat"] = 1.0

apple_spawn_locations = ["left_state", "center_state", "right_state"]
apple_spawn_rate = 1/3
for i, state in enumerate(apple_spawn_locations):
    # an orchard cell spawns an apple with probability apple_spawn_rate, whatever the agent does
    model.B[state]["orchard", "orchard", :, :] = 1.0 - apple_spawn_rate # no spawn
    model.B[state]["apple", "orchard", :, :] = apple_spawn_rate # spawn
    # an apple stays in its cell by default (otherwise these columns would be left empty and become uniform after normalization)
    model.B[state]["apple", "apple", :, :] = 1.0
    # if the agent is in this cell (location index i) and eats the apple, the cell becomes an orchard again
    model.B[state]["apple", "apple", i, "eat"] = 0.0
    model.B[state]["orchard", "apple", i, "eat"] = 1.0
    model.B[state].data = model.B[state].data + 1e-3 # add a small amount of noise to the transitions

'''
SPECIFY THE C TENSOR
'''
model.C["reward_obs"]["reward"] = 1.0

'''
NORMALISE THE TENSORS
'''

model.A["location_obs"].normalize()
model.A["item_obs"].normalize()
model.A["reward_obs"].normalize()

model.B["location_state"].normalize()
model.B["reward_state"].normalize()
model.B["left_state"].normalize()
model.B["center_state"].normalize()
model.B["right_state"].normalize()
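The noise-then-normalize pattern used above has a useful side effect. The 1e-3 offset guarantees that no column sums to zero before normalizing. The flip side is that any column that was never filled in ends up uniform. The sketch below shows both effects on a toy conditional distribution. It assumes normalize() divides by the sum over the first (outcome) axis, the usual convention for A and B tensors. The normalize_columns function is a stand-in written for this example.

```python
import numpy as np


def normalize_columns(tensor):
    # divide by the sum over axis 0 so each conditional distribution p(. | parents) sums to 1
    return tensor / tensor.sum(axis=0, keepdims=True)


# toy conditional with shape (2 outcomes, 3 parent configurations)
P = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
])  # the third column was never filled in

P_noisy = normalize_columns(P + 1e-3)
print(P_noisy.round(6))
```

The first column becomes approximately [0.999, 0.001], while the untouched third column becomes [0.5, 0.5].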
gamma = 1.0

agent = Agent(**model, learn_A=False, learn_B=False, gamma=gamma, sampling_mode="full")

We can also turn the compiled model into an environment using PymdpEnv:

from pymdp.envs import PymdpEnv

env = PymdpEnv(**model)
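Conceptually, an environment built from A, B, and D tensors plays the role of the generative process. It samples hidden states from D and B, and observations from A. The sketch below illustrates this for the 1D grid world from Example 1 using plain NumPy. The reset and step functions here are hypothetical stand-ins. They are not PymdpEnv's actual methods or signatures.

```python
import numpy as np

n = 4  # "left", "center_left", "center_right", "right"
A = np.eye(n)              # A[obs, state]: perfect observation of position
B = np.zeros((n, n, 2))    # B[to, from, action]; action 0 = move_left, 1 = move_right
for s in range(n):
    B[max(s - 1, 0), s, 0] = 1.0
    B[min(s + 1, n - 1), s, 1] = 1.0
D = np.full(n, 1.0 / n)    # uniform initial-state distribution


def reset(rng):
    # hypothetical: sample s_0 ~ D, then o_0 ~ A[:, s_0]
    state = rng.choice(n, p=D)
    return state, rng.choice(n, p=A[:, state])


def step(state, action, rng):
    # hypothetical: sample s_t ~ B[:, s_{t-1}, a], then o_t ~ A[:, s_t]
    next_state = rng.choice(n, p=B[:, state, action])
    return next_state, rng.choice(n, p=A[:, next_state])


rng = np.random.default_rng(0)
state, obs = reset(rng)
state, obs = step(state, 1, rng)  # take "move_right"
```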

The environment can then be used to roll out the agent with rollout():

from pymdp.envs import rollout

info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0))
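Stripped of JIT compilation and batching, a rollout is an observe, infer, act, step loop repeated for num_timesteps. The sketch below writes that loop out in plain Python for the 1D grid world from Example 1. A greedy one-step chooser stands in for the Agent. Both simple_rollout and choose_action are hypothetical and do not reproduce the internals or return values of pymdp's rollout().

```python
import numpy as np

n = 4
B = np.zeros((n, n, 2))  # B[to, from, action]; 0 = move_left, 1 = move_right
for s in range(n):
    B[max(s - 1, 0), s, 0] = 1.0
    B[min(s + 1, n - 1), s, 1] = 1.0
C = np.array([0.0, 1.0, 0.0, 0.0])  # prefer "center_left"


def choose_action(obs):
    # greedy one-step choice on expected preference (A is the identity, so obs reveals the state)
    return int(np.argmax(B[:, obs, :].T @ C))


def simple_rollout(initial_state, num_timesteps, rng):
    # hypothetical loop: observe -> choose action -> environment transitions -> repeat
    state, history = initial_state, []
    for _ in range(num_timesteps):
        obs = state  # identity A: observation equals hidden state
        action = choose_action(obs)
        state = rng.choice(n, p=B[:, state, action])
        history.append((obs, action))
    return history, state


history, final_state = simple_rollout(initial_state=3, num_timesteps=4, rng=np.random.default_rng(0))
```

This grid world has no "stay" action, so the greedy agent cannot remain at "center_left" once it arrives and instead moves back and forth around it.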

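There is also a make() utility function that builds an env (and, optionally, env_params) from the compiled model. Whether env_params is returned is controlled by the make_env_params argument. Here it is set to False, and the two rollout() calls that follow are made without and with the env_params argument.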

from pymdp.envs import make

env, env_params = make(**model, make_env_params=False)
info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0))
info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0), env_params=env_params)

If you pass make_env_params=True to make(), it returns the environment parameters as a separate dict (e.g., with A, B, and D as keys). The batch size of these parameters matches that of the input tensors A, B, and D.

Passing env_params as input arguments to Env methods like reset() and step(), rather than storing them internally as class attributes, enables the following:

  • different environments per agent: if you want a different environmental parameterization per agent (e.g., different transition probabilities per environment), this can be expressed in the leading (batch) dimension of any array-valued leaf of the env_params pytree.
  • custom changes to environmental parameters: you can change environmental parameters in a custom way without constructing a new Env. For example, env_params["B"][factor_idx] = env_params["B"][factor_idx].at[batch_indices, :, state_idx].set(new_transition_parameters) replaces the transition distribution out of state state_idx for the selected environments. Simply modify the parameters (as long as the shapes stay the same) and pass them back into the same environment's step() method.
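The per-environment update described in the second bullet can be sketched with NumPy as follows. The slip probabilities are made up for illustration. The copy-then-assign step mirrors JAX's functional B.at[...].set(...) update, which returns a new array because JAX arrays cannot be modified in place.

```python
import numpy as np

batch_size, n = 3, 4
# B for one state factor, batched per environment: shape (batch, to, from, action)
B_single = np.zeros((n, n, 2))
for s in range(n):
    B_single[max(s - 1, 0), s, 0] = 1.0      # move_left
    B_single[min(s + 1, n - 1), s, 1] = 1.0  # move_right
B_batched = np.broadcast_to(B_single, (batch_size,) + B_single.shape)  # read-only view

# make move_right from state 0 slip with probability 0.2, but only in environments 1 and 2
batch_indices = np.array([1, 2])
new_column = np.array([0.2, 0.8, 0.0, 0.0])  # p(to | from = 0, action = move_right)

B_updated = B_batched.copy()  # functional-style update: leave the original untouched
B_updated[batch_indices, :, 0, 1] = new_column
# JAX equivalent: B_updated = B_batched.at[batch_indices, :, 0, 1].set(new_column)
```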
# here, just broadcast the compiled model's A, B, and D tensors across the batch dimension of the agent to show the point
A_env_batched = [jnp.broadcast_to(a.data, (agent.batch_size, ) + a.data.shape) for a in model.A]
B_env_batched = [jnp.broadcast_to(b.data, (agent.batch_size, ) + b.data.shape) for b in model.B]
D_env_batched = [jnp.broadcast_to(d.data, (agent.batch_size, ) + d.data.shape) for d in model.D]

env, env_params = make(A_env_batched, B_env_batched, D_env_batched, agent.A_dependencies, agent.B_dependencies, make_env_params=True)

In this case, also pass env_params to rollout() so that both agents and environments are vmapped in parallel. Note that env_params must have the same batch size as the Agent.
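A batch-size mismatch would likely surface as a shape error inside a vmapped call, so it can help to check the leading dimension up front. Below is a minimal sketch that assumes env_params is a dict of lists of arrays, with A, B, and D as keys. The check_batch_size helper is hypothetical. In JAX code, jax.tree_util.tree_leaves could replace the hand-written traversal.

```python
import numpy as np


def check_batch_size(env_params, batch_size):
    """Hypothetical helper: confirm every array leaf has a leading dimension equal to batch_size."""
    def leaves(tree):
        if isinstance(tree, dict):
            for value in tree.values():
                yield from leaves(value)
        elif isinstance(tree, (list, tuple)):
            for value in tree:
                yield from leaves(value)
        elif isinstance(tree, np.ndarray):
            yield tree

    mismatched = [leaf.shape for leaf in leaves(env_params) if leaf.shape[:1] != (batch_size,)]
    if mismatched:
        raise ValueError(f"leading dims do not match batch_size={batch_size}: {mismatched}")


batch_size = 3
env_params = {
    "A": [np.ones((batch_size, 3, 3))],
    "B": [np.ones((batch_size, 3, 3, 3))],
    "D": [np.ones((batch_size, 3))],
}
check_batch_size(env_params, batch_size)  # passes silently
```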

info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0), env_params=env_params)
