Generative Model Construction with Model and Distribution¶
This tutorial walks through usage of the Model and Distribution classes. The Model class wraps the A, B, C, and D arrays into a unified model object, whose structure can be generated automatically from a JSON-like configuration dict. The configuration dict, often labelled model_description, allows you to assign string-valued names to the random variables of the model (observations, hidden states, and control states), and to specify dependencies between these random variables (via the "depends_on" key within each variable-specific sub-dict).
Within the Model object, each component distribution is an instance of the Distribution class, which can use string- or integer-valued labels for both the axes and specific indices along those axes. These classes, with their named dimensions and values, are intended to give users a friendlier, more interpretable entry point for building discrete generative models in pymdp.
In summary, the advantages of the Model and Distribution classes are:
- Reduce the memorization burden on the user (and the indexing errors that often ensue) by working with named axes and elements (e.g., "left", "right" instead of 0, 1)
- Increase interpretability and intuition when building and inspecting the POMDP model
- Flexible specification of dependencies between random variables via string labels
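To make the first advantage concrete, here is a minimal sketch of label-based indexing. The `LabeledMatrix` class is hypothetical and written only for illustration; it is not pymdp's Distribution class and makes no claim about how that class is implemented.

```python
# Hypothetical helper (not part of pymdp): a 2D table indexable by label or integer.
class LabeledMatrix:
    def __init__(self, row_labels, col_labels):
        self.rows = {name: i for i, name in enumerate(row_labels)}
        self.cols = {name: i for i, name in enumerate(col_labels)}
        self.data = [[0.0] * len(col_labels) for _ in row_labels]

    @staticmethod
    def _index(key, table):
        # accept either a string label or a plain integer index
        return table[key] if isinstance(key, str) else key

    def __setitem__(self, key, value):
        row, col = key
        self.data[self._index(row, self.rows)][self._index(col, self.cols)] = value

    def __getitem__(self, key):
        row, col = key
        return self.data[self._index(row, self.rows)][self._index(col, self.cols)]


positions = ["left", "right"]
likelihood = LabeledMatrix(positions, positions)
likelihood["left", "left"] = 1.0  # same cell as likelihood[0, 0]
likelihood["right", 1] = 1.0      # labels and integers can be mixed
```

Writing `likelihood["left", "left"]` rather than `likelihood[0, 0]` keeps the meaning of each axis visible at the call site, which is the point of the named-axis approach.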
Tutorial Structure¶
- Basic Example: A simple grid navigation task built from a structured description using labels, in which the agent has to travel to a goal location.
- A More Advanced Example: A simple foraging task built from a structured description using labels, in which the agent has to search for apples and eat them while they spawn at a set rate.
Example 1: Grid World Navigation¶
Let's start with a simple example: an agent moving horizontally in a 1D grid world. The agent can be in one of four positions ("left", "center_left", "center_right", or "right") and can take the actions "move_left" or "move_right". We first set up a single agent, and then run three agents in parallel by passing the batch_size argument to the Agent constructor.
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import jax.tree_util as jtu
from jax import numpy as jnp
from jax import random as jr
from pymdp.agent import Agent
from pymdp.distribution import compile_model
from pymdp.envs.env import Env
from pymdp.envs import rollout
positions = ["left", "center_left", "center_right", "right"]
actions = ["move_left", "move_right"]
model_description = {
    "observations": {
        "position_obs": {
            "elements": positions,
            "depends_on": ["position"]  # we specify that the observation depends on the "position" state factor
        },
    },
    "controls": {
        "movement": {"elements": actions}  # we specify the available actions
    },
    "states": {
        "position": {
            "elements": positions,
            "depends_on": ["position"],  # our current position depends on previous position...
            "controlled_by": ["movement"]  # ...and the movement action taken
        },
    },
}
# compile the model structure from the description
model = compile_model(model_description)
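Because dependencies are declared by name, a misspelled name is an easy mistake to make. The sketch below shows one way to check a description dict before compiling it. The helper `find_undefined_references` is hypothetical and not part of pymdp, and no assumption is made here about which checks compile_model itself performs.

```python
import copy


def find_undefined_references(description):
    """Return (variable, key, missing_name) for every dangling reference."""
    states = set(description["states"])
    controls = set(description["controls"])
    problems = []
    for group in ("observations", "states"):
        for name, spec in description[group].items():
            for dep in spec.get("depends_on", []):
                if dep not in states:
                    problems.append((name, "depends_on", dep))
            for ctrl in spec.get("controlled_by", []):
                if ctrl not in controls:
                    problems.append((name, "controlled_by", ctrl))
    return problems


positions = ["left", "center_left", "center_right", "right"]
description = {
    "observations": {"position_obs": {"elements": positions, "depends_on": ["position"]}},
    "controls": {"movement": {"elements": ["move_left", "move_right"]}},
    "states": {
        "position": {
            "elements": positions,
            "depends_on": ["position"],
            "controlled_by": ["movement"],
        },
    },
}

# introduce a deliberate typo to show what the check reports
broken = copy.deepcopy(description)
broken["observations"]["position_obs"]["depends_on"] = ["positon"]

clean_report = find_undefined_references(description)   # no dangling names
broken_report = find_undefined_references(broken)       # reports the typo
```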
We have built the generative model structure from the model description, but the model's parameters are currently arrays of zeros. We can now fill in the parameter tensors, using the axis and element labels defined in the model description dict to set values at particular indices of these arrays.
# fill in the likelihood (A) tensor
# the observations have an identical mapping to the states (i.e., the agent will perfectly observe its position)
model.A["position_obs"]["left", "left"] = 1.0
model.A["position_obs"]["center_left", "center_left"] = 1.0
model.A["position_obs"]["center_right", "center_right"] = 1.0
model.A["position_obs"]["right", "right"] = 1.0
# model.A["position_obs"].data = jnp.eye(len(positions)) # you could also use the .data attribute to set the identity mapping directly
# fill in the transition model (B) tensor
# note that it's specified as ["to", "from", "action"]
# moving right
model.B["position"]["center_left", "left", "move_right"] = 1.0
model.B["position"]["center_right", "center_left", "move_right"] = 1.0
model.B["position"]["right", "center_right", "move_right"] = 1.0
model.B["position"]["right", "right", "move_right"] = 1.0
# moving left
model.B["position"]["left", "left", "move_left"] = 1.0
model.B["position"]["left", "center_left", "move_left"] = 1.0
model.B["position"]["center_left", "center_right", "move_left"] = 1.0
model.B["position"]["center_right", "right", "move_left"] = 1.0
# set preferences (C) tensor - prefer to be at "center_left"
model.C["position_obs"]["center_left"] = 1.0
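A transition model is only a valid conditional distribution if, for every (from, action) pair, the probabilities over the "to" axis sum to 1. The following sketch rebuilds the transitions above as plain nested lists (an illustrative stand-in, not the Distribution object) and checks that property.

```python
positions = ["left", "center_left", "center_right", "right"]
actions = ["move_left", "move_right"]

# (to, from, action) triples copied from the transition model above
transitions = [
    ("center_left", "left", "move_right"),
    ("center_right", "center_left", "move_right"),
    ("right", "center_right", "move_right"),
    ("right", "right", "move_right"),
    ("left", "left", "move_left"),
    ("left", "center_left", "move_left"),
    ("center_left", "center_right", "move_left"),
    ("center_right", "right", "move_left"),
]

# B[to][from][action] as nested lists, initialised to zero
B = [[[0.0 for _ in actions] for _ in positions] for _ in positions]
for to_state, from_state, action in transitions:
    B[positions.index(to_state)][positions.index(from_state)][actions.index(action)] = 1.0

# every (from, action) column should sum to 1 over the "to" axis
column_sums = {
    (from_state, action): sum(B[t][f][a] for t in range(len(positions)))
    for f, from_state in enumerate(positions)
    for a, action in enumerate(actions)
}
unnormalised = [key for key, total in column_sums.items() if abs(total - 1.0) > 1e-9]
```

An empty `unnormalised` list means each of the eight (from, action) columns has exactly one destination, as intended for this deterministic grid world.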
Now, let's create the agent (via the Agent object), have it infer which state it is in from an observation, and select an action according to its goal.
gamma = 10 # deterministic behavior; make gamma smaller for stochastic behavior
# create agent
agent = Agent(**model, gamma=gamma)
# set up initial observation to be "left"
observation = jnp.zeros((agent.batch_size, 1)) # broadcast to agent's batch size (defaults to 1 agent) and add a time dimension
# get the prior
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 1), agent.D) # qs needs a time dimension too
# print initial beliefs, goal, and action chosen
qs = agent.infer_states([observation], qs_init)
print(f"Current belief about position: {positions[jnp.argmax(qs[0][0])]}")
qs = [jnp.squeeze(q, 1) for q in qs]
print(f"Goal position: {positions[jnp.argmax(agent.C[0])]}")
q_pi, G = agent.infer_policies(qs)
action_idx = agent.sample_action(q_pi)
print(f"Action chosen: {actions[action_idx[0][0]]}")
Current belief about position: left
Goal position: center_left
Action chosen: move_right
/tmp/ipykernel_1062531/1053406980.py:4: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agent = Agent(**model, gamma=gamma)
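The comment on gamma in the cell above can be illustrated with a small softmax. This is a sketch that assumes the standard active-inference form, in which the policy posterior is a softmax of the negative expected free energy scaled by gamma. pymdp's own computation may include further terms, and the function name `policy_posterior` and the example values are hypothetical.

```python
import math


def policy_posterior(neg_efe, gamma):
    """Softmax over gamma-scaled negative expected free energies."""
    scaled = [gamma * g for g in neg_efe]
    shift = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - shift) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]


neg_efe = [-1.0, -0.5]  # hypothetical values for two policies; higher is better
sharp = policy_posterior(neg_efe, gamma=10.0)  # nearly all mass on the better policy
soft = policy_posterior(neg_efe, gamma=0.5)    # mass spread across both policies
```

With a large gamma the better policy receives almost all of the probability mass, so sampling from the posterior is close to deterministic. With a small gamma the two policies are nearly equally likely.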
We can also run multiple agents or independent trials in parallel, each with a different initial observation (i.e., a different initial position), by using the batch_size argument.
batch_size = 3 # running 3 trials or agents in parallel
gamma = 10 # deterministic behavior; make gamma smaller for stochastic behavior
# create agent
agent = Agent(**model, batch_size=batch_size, gamma=gamma)
# set up different initial observations for each agent: "left", "center_right", and "right"
observation = [jnp.array([[0], [2], [3]])] # wrap in a list to indicate the single modality; observation[0].shape = (batch_size, 1)
# get the prior
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 1), agent.D) # qs needs a time dimension too
# print goal and initial beliefs
qs = agent.infer_states(observation, qs_init)
for a in range(batch_size):
    print(f"Agent {a}'s current belief about position: {positions[jnp.argmax(qs[0][a])]}")
qs = [jnp.squeeze(q, 1) for q in qs]
print(f"\nGoal position for all agents: {positions[jnp.argmax(agent.C[0])]}\n")
q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
for a in range(batch_size):
    print(f"Agent {a}'s action chosen: {actions[action[a][0]]}")
Agent 0's current belief about position: left
Agent 1's current belief about position: center_right
Agent 2's current belief about position: right

Goal position for all agents: center_left

Agent 0's action chosen: move_right
Agent 1's action chosen: move_left
Agent 2's action chosen: move_left
/tmp/ipykernel_1062531/1678643109.py:5: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agent = Agent(**model, batch_size=batch_size, gamma=gamma)
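The batched observation above was written with raw indices (0, 2, 3). A small sketch of building the same indices from labels is shown below. The helper `encode_observations` is hypothetical and uses plain lists so that it stays self-contained.

```python
positions = ["left", "center_left", "center_right", "right"]


def encode_observations(labels, vocabulary):
    """Map one label per agent to a (batch_size, 1) nested list of indices."""
    return [[vocabulary.index(label)] for label in labels]


initial = encode_observations(["left", "center_right", "right"], positions)
# initial == [[0], [2], [3]]
```

In the notebook, the result would still be converted with `jnp.array(...)` and wrapped in a list with one entry per observation modality, as in the cell above.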
Example 2: Apple Foraging Task¶
Now let's look at a slightly more complex example: an apple foraging task. Here, we have a 1x3 grid with "left", "center", and "right" cells. Each is an orchard cell in which an apple can grow at a set rate (1/3 per timestep). The agent's objective is to find and eat apples, since eating an apple yields a reward. The agent can stay, move_left, move_right, or eat.
num_locations = 3 # corresponds to "left", "center", and "right"; you could pass ["left", "center", "right"] via 'elements', but we use a number to show the 'size' key
item_list = ["orchard", "apple"]
model_description = {
    "observations": {
        "location_obs": {
            "size": num_locations,  # if you want to use numbers instead of strings, you can use the 'size' key instead of 'elements'
            "depends_on": ["location_state"],
        },
        "item_obs": {
            "elements": item_list,  # "elements" key for strings
            "depends_on": ["location_state", "left_state", "center_state", "right_state"],
        },
        "reward_obs": {
            "elements": ["no_reward", "reward"],
            "depends_on": ["reward_state"],
        },
    },
    "controls": {
        "move": {
            "elements": ["stay", "move_left", "move_right"],
        },
        "eat": {
            "elements": ["noop", "eat"],  # noop = no-operation
            # note that if you cannot control a state, you still need to add
            # an action for it (e.g., with elements: ["null"]) for the model to be initialized
            # with the correct dimensions
        },
    },
    "states": {
        "location_state": {
            "size": num_locations,
            "depends_on": ["location_state"],
            "controlled_by": ["move"],
        },
        "reward_state": {
            "elements": ["no_reward", "reward"],
            # if you have more than one dependency, the first dependency is its own state factor (at the previous timestep),
            # then add the other dependencies in the order they are specified (you can skip over some state factors)
            "depends_on": ["reward_state", "location_state",
                           "left_state", "center_state", "right_state"],
            "controlled_by": ["eat"],
        },
        "left_state": {
            "elements": item_list,
            "depends_on": ["left_state", "location_state"],
            "controlled_by": ["eat"],
        },
        "center_state": {
            "elements": item_list,
            "depends_on": ["center_state", "location_state"],
            "controlled_by": ["eat"],
        },
        "right_state": {
            "elements": item_list,
            "depends_on": ["right_state", "location_state"],
            "controlled_by": ["eat"],
        },
    },
}
model = compile_model(model_description)
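With several dependencies per variable, it helps to know how many axes each tensor has before indexing into it. The sketch below derives the expected shapes from a description dict, using the axis order seen in this tutorial's indexing: [observation, dependencies...] for A and ["to", dependencies..., "action"] for B. The helpers `cardinality` and `expected_shapes` are hypothetical, and the sketch is inferred from the tutorial's examples rather than from compile_model's internals.

```python
def cardinality(spec):
    """Number of values a variable can take, from 'size' or 'elements'."""
    return spec["size"] if "size" in spec else len(spec["elements"])


def expected_shapes(description):
    states, controls = description["states"], description["controls"]
    a_shapes = {
        name: (cardinality(spec),)
        + tuple(cardinality(states[d]) for d in spec["depends_on"])
        for name, spec in description["observations"].items()
    }
    b_shapes = {
        name: (cardinality(spec),)
        + tuple(cardinality(states[d]) for d in spec["depends_on"])
        + tuple(cardinality(controls[c]) for c in spec["controlled_by"])
        for name, spec in states.items()
    }
    return a_shapes, b_shapes


items = ["orchard", "apple"]
cells = ["left_state", "center_state", "right_state"]
# a trimmed version of the foraging description (reward variables omitted)
description = {
    "observations": {
        "item_obs": {"elements": items, "depends_on": ["location_state"] + cells},
    },
    "controls": {
        "move": {"elements": ["stay", "move_left", "move_right"]},
        "eat": {"elements": ["noop", "eat"]},
    },
    "states": {
        "location_state": {"size": 3, "depends_on": ["location_state"], "controlled_by": ["move"]},
        **{
            cell: {"elements": items, "depends_on": [cell, "location_state"], "controlled_by": ["eat"]}
            for cell in cells
        },
    },
}

a_shapes, b_shapes = expected_shapes(description)
# a_shapes["item_obs"]       == (2, 3, 2, 2, 2)
# b_shapes["location_state"] == (3, 3, 3)
# b_shapes["left_state"]     == (2, 2, 3, 2)
```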
As before, we have built the generative model structure from the model description, but the model's parameters are currently arrays of zeros. We can now fill in the parameter tensors, using the axis and element labels defined in the model description dict to set values at particular indices of these arrays.
'''
SPECIFY THE A TENSOR
'''
# identity mapping for the observations regarding location and reward
model.A["location_obs"].data = jnp.eye(len(model.A["location_obs"].data))
model.A["reward_obs"].data = jnp.eye(len(model.A["reward_obs"].data))
# in any of the locations, the agent may observe apple or orchard
model.A["item_obs"]["apple", 0, "apple", :, :] = 1.0
model.A["item_obs"]["apple", 1, :, "apple", :] = 1.0
model.A["item_obs"]["apple", 2, :, :, "apple"] = 1.0
model.A["item_obs"]["orchard", 0, "orchard", :, :] = 1.0
model.A["item_obs"]["orchard", 1, :, "orchard", :] = 1.0
model.A["item_obs"]["orchard", 2, :, :, "orchard"] = 1.0
model.A["item_obs"].data = model.A["item_obs"].data + 1e-3 # add a small constant so that no likelihood entry is exactly zero
'''
SPECIFY THE B TENSOR
'''
# for moving between locations
# (to, from, action)
valid_transitions = [
# from 0 (left)
(0, 0, "stay"), # from left to left, stay
(1, 0, "move_right"), # from left to center, move right
(2, 0, "move_left"), # from left to right, move left
# from 1 (center)
(0, 1, "move_left"), # from center to left, move left
(1, 1, "stay"), # from center to center, stay
(2, 1, "move_right"), # from center to right, move right
# from 2 (right)
(0, 2, "move_right"), # from right to left, move right
(1, 2, "move_left"), # from right to center, move left
(2, 2, "stay"), # from right to right, stay
]
for to_state, from_state, action in valid_transitions:
    model.B["location_state"][to_state, from_state, action] = 1.0
# again, remember the reward states will be set as ["to", "from", ...dependencies..., "action"]
# if the agent sees an apple and does not eat the apple (i.e., noop), it does not get a reward
model.B["reward_state"]["no_reward", "no_reward", 0, "apple", :, :, "noop"] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 1, :, "apple", :, "noop"] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "apple", "noop"] = 1.0
# if the agent sees an orchard, it does not get a reward regardless of its actions
model.B["reward_state"]["no_reward", "no_reward", 0, "orchard", :, :, :] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 1, :, "orchard", :, :] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "orchard", :] = 1.0
# from a reward state, there will always be no reward in the next timestep regardless of the action
model.B["reward_state"]["no_reward", "reward", 0, :, :, :, :] = 1.0
model.B["reward_state"]["no_reward", "reward", 1, :, :, :, :] = 1.0
model.B["reward_state"]["no_reward", "reward", 2, :, :, :, :] = 1.0
# if the agent sees an orchard and eats, it will not get a reward
model.B["reward_state"]["reward", "no_reward", 0, "orchard", :, :, "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 1, :, "orchard", :, "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 2, :, :, "orchard", "eat"] = 0.0
# if the agent sees an apple and eats it, it always gets a reward (the no_reward outcome has probability 0)
model.B["reward_state"]["no_reward", "no_reward", 0, "apple", :, :, "eat"] = 0.0
model.B["reward_state"]["no_reward", "no_reward", 1, :, "apple", :, "eat"] = 0.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "apple", "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 0, "apple", :, :, "eat"] = 1.0
model.B["reward_state"]["reward", "no_reward", 1, :, "apple", :, "eat"] = 1.0
model.B["reward_state"]["reward", "no_reward", 2, :, :, "apple", "eat"] = 1.0
apple_spawn_locations = ["left_state", "center_state", "right_state"]
apple_spawn_rate = 1/3
for i, state in enumerate(apple_spawn_locations):
    model.B[state]["orchard", "orchard", :, :] = 1.0 - apple_spawn_rate # no spawn
    model.B[state]["apple", "orchard", :, :] = apple_spawn_rate # spawn
    for agent_location in range(num_locations):
        if i == agent_location:
            # if the agent does not eat the apple (noop), the apple will stay in the cell
            model.B[state]["apple", "apple", agent_location, "noop"] = 1.0
            # if the agent eats the apple, it will become an orchard cell
            model.B[state]["orchard", "apple", agent_location, "eat"] = 1.0
    model.B[state].data = model.B[state].data + 1e-3 # add a small constant so that no transition entry is exactly zero
'''
SPECIFY THE C TENSOR.
'''
model.C["reward_obs"]["reward"] = 1.0
'''
NORMALISE THE TENSORS
'''
model.A["location_obs"].normalize()
model.A["item_obs"].normalize()
model.A["reward_obs"].normalize()
model.B["location_state"].normalize()
model.B["reward_state"].normalize()
model.B["left_state"].normalize()
model.B["center_state"].normalize()
model.B["right_state"].normalize()
gamma = 1.0
agent = Agent(**model, learn_A=False, learn_B=False, gamma=gamma, sampling_mode="full")
/tmp/ipykernel_1062531/53738707.py:3: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do. agent = Agent(**model, learn_A=False, learn_B=False, gamma=gamma, sampling_mode="full")
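The cell above adds 1e-3 to some tensors and then normalises them. The sketch below shows the effect on two columns. It assumes that normalisation divides each column by its sum over the first axis, and the helper `normalise` is a hypothetical stand-in rather than the normalize() method itself.

```python
def normalise(column):
    """Rescale a list of non-negative values so that it sums to 1."""
    total = sum(column)
    return [value / total for value in column]


eps = 1e-3
# (orchard, apple) probabilities for an orchard cell with spawn rate 1/3
spawn_column = normalise([2 / 3 + eps, 1 / 3 + eps])
# a column whose entries were all left at zero before the offset
empty_column = normalise([0.0 + eps, 0.0 + eps])
```

The spawn column stays within 1e-3 of its intended values, while any column left entirely at zero becomes uniform ([0.5, 0.5] here). This is worth keeping in mind when inspecting entries that were never set explicitly.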
We can also turn the compiled model into an environment using PymdpEnv:
from pymdp.envs import PymdpEnv
env = PymdpEnv(**model)
The environment can then be used to roll out the agent.
from pymdp.envs import rollout
info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0))
There's also a make utility function that generates an env and env_params from the model. You can control whether env_params are returned with the make_env_params argument.
from pymdp.envs import make
env, env_params = make(**model, make_env_params=False)
info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0))
info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0), env_params=env_params)
If you pass make_env_params=True to the make() function, it returns the environmental parameters as a separate dict (e.g., with A, B, D as keys), with the batch size of these parameters matched to the batch size of the input tensors A, B, and D.
Passing env_params as input arguments to Env functions like reset() and step(), rather than storing them internally as class attributes, enables the following:
- different environments per agent -- if you want to use a different environmental parameterization per agent (e.g., different transition probabilities per environment), this can be differentiated along the leading dimensions of any array-valued leaves of the env_params pytree.
- custom changes to environmental parameters -- if you want to change environmental parameters in a custom way (e.g. env_params["B"][factor_idx] = env_params["B"][factor_idx].at[batch_indices, :, state_idx].set(new_transition_parameters)) without having to construct a new Env, you can simply change the parameters (as long as the shapes align) and then pass them back into the same environment's step() function.
# here, just broadcast the compiled model's A, B, and D tensors across the batch dimension of the agent to show the point
A_env_batched = [jnp.broadcast_to(a.data, (agent.batch_size, ) + a.data.shape) for a in model.A]
B_env_batched = [jnp.broadcast_to(b.data, (agent.batch_size, ) + b.data.shape) for b in model.B]
D_env_batched = [jnp.broadcast_to(d.data, (agent.batch_size, ) + d.data.shape) for d in model.D]
env, env_params = make(A_env_batched, B_env_batched, D_env_batched, agent.A_dependencies, agent.B_dependencies, make_env_params=True)
In this case, also pass env_params to rollout so that it vmaps across both agents and environments in parallel. NOTE: this means that env_params must have the same batch size as the Agent.
info, last = rollout(agent, env, num_timesteps=10, rng_key=jr.PRNGKey(0), env_params=env_params)
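To illustrate the per-environment parameterization described above, the sketch below gives one environment in a batch a different transition column. It uses NumPy and a tiny made-up tensor so that it stays self-contained. The shapes and values are assumptions for illustration only, and with JAX arrays the in-place assignment would instead be written as a functional update via `.at[...].set(...)`.

```python
import numpy as np

batch_size = 3

# a single 2-state, 1-action transition tensor with axes (to, from, action)
B_single = np.zeros((2, 2, 1))
B_single[0, 0, 0] = 1.0
B_single[1, 1, 0] = 1.0

# broadcast_to returns a read-only view, so copy before editing
B_batched = np.broadcast_to(B_single, (batch_size,) + B_single.shape).copy()

# give environment 1 a different transition column for from-state 0
B_batched[1, :, 0, 0] = [0.2, 0.8]
```

After the edit, environments 0 and 2 keep the identity transitions while environment 1 moves from state 0 to state 1 with probability 0.8. Every column still sums to 1 over the "to" axis.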