Graph worlds

This environment demonstrates agents that navigate a graph to find an object. The object is only visible when the agent is at the same location as the object.
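The visibility rule can be written as a likelihood over an "object seen" outcome that depends jointly on the agent location and the object location. The sketch below is a minimal illustration, not GraphEnv's actual A tensor: it assumes a binary outcome (0 = not seen, 1 = seen) and seven locations (the number of location states in the belief arrays printed later). The helper name `object_visibility_likelihood` is hypothetical.

```python
import jax.numpy as jnp

def object_visibility_likelihood(num_locations):
    """Hypothetical sketch of p(object_seen | agent_location, object_location).

    Returns an array of shape (2, num_locations, num_locations):
    axis 0 is the outcome (0 = not seen, 1 = seen), axis 1 the agent
    location and axis 2 the object location.
    """
    # same[i, j] is 1.0 when the agent (i) and the object (j) share a location
    same = jnp.eye(num_locations, dtype=jnp.float32)
    return jnp.stack([1.0 - same, same], axis=0)

A_obj = object_visibility_likelihood(7)
print(A_obj[:, 3, 3])  # agent and object both at node 3: object seen
print(A_obj[:, 0, 3])  # agent at node 0, object at node 3: object not seen
```

Each column `A_obj[:, i, j]` is a one-hot distribution over the two outcomes, so the likelihood is deterministic: the agent sees the object exactly when it stands on the object's node.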

import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
# a way to edit and run code and see the effects in the notebook without having to restart the kernel
%load_ext autoreload
%autoreload 2

from jax import numpy as jnp, random as jr, jit
from jax import vmap
import networkx as nx

from pymdp.envs import GraphEnv
from pymdp.envs.graph_worlds import generate_connected_clusters
from pymdp.agent import Agent
from pymdp.envs.rollout import rollout
from pymdp.utils import list_array_uniform, list_array_zeros

import matplotlib.pyplot as plt

key = jr.PRNGKey(0)

Start by generating a graph of locations.

graph, _ = generate_connected_clusters(cluster_size=3, connections=2)
nx.draw(graph, with_labels=True, font_weight="bold")
[Figure: the generated graph of locations, drawn with node labels]
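This notebook does not show how GraphEnv turns the graph into a transition model. As a rough mental model only, here is a hypothetical sketch of a location transition tensor `B[next, current, action]`. It assumes nodes are labelled 0..n-1 and that action k means "move to node k", which succeeds only if k is the current node or one of its neighbours (otherwise the agent stays put). The step-by-step output at the end of the notebook loosely suggests this (an action of 4 is followed by observing location 4), but the rule is an assumption, not pymdp's implementation. `transition_tensor_from_graph` is a hypothetical helper.

```python
import numpy as np
import jax.numpy as jnp
import networkx as nx

def transition_tensor_from_graph(graph):
    """Hypothetical B[next, current, action] for moving between graph nodes.

    Assumes nodes are labelled 0..n-1 and that action k means "go to node k";
    the move only succeeds if k is the current node or a neighbour of it.
    """
    n = graph.number_of_nodes()
    B = np.zeros((n, n, n), dtype=np.float32)
    for current in range(n):
        reachable = set(graph.neighbors(current)) | {current}
        for target in range(n):
            next_node = target if target in reachable else current
            B[next_node, current, target] = 1.0
    return jnp.asarray(B)

toy = nx.path_graph(4)  # 0 - 1 - 2 - 3
B_loc = transition_tensor_from_graph(toy)
print(B_loc[:, 0, 1])  # from node 0, "go to 1" succeeds
print(B_loc[:, 0, 3])  # from node 0, node 3 is not adjacent, so the agent stays at 0
```

Building the tensor in NumPy and converting once at the end avoids creating a new JAX array on every assignment.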

Now we can create a GraphEnv given this graph.

We then specify two object locations and two agent locations. Passing them to env.generate_env_params(...) produces parameters for two environments that differ in their initial state priors (their D vectors): each one encodes a precise prior over a particular initial object location and agent location.

env = GraphEnv(graph)

object_locations=[3, 5]
agent_locations=[0, 1]

env_params = env.generate_env_params(key=key, graph=graph, object_locations=object_locations, agent_locations=agent_locations)
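A precise prior over an initial location is a one-hot vector. The sketch below shows what two batched D vectors of that kind could look like, assuming 7 locations and that batch entry i pairs agent_locations[i] with object_locations[i]. It is an illustration only; the actual structure of the env_params returned by generate_env_params may differ.

```python
import jax
import jax.numpy as jnp

num_locations = 7                     # assumed number of location states
agent_locations = jnp.array([0, 1])   # one entry per environment in the batch
object_locations = jnp.array([3, 5])

# One one-hot prior per batch element and per hidden state factor
D_agent = jax.nn.one_hot(agent_locations, num_locations)    # shape (2, 7)
D_object = jax.nn.one_hot(object_locations, num_locations)  # shape (2, 7)
D_env = [D_agent, D_object]
```

Each row puts all of its probability mass on a single location, which is what makes the two environments start from different, fully specified states.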

To create an Agent, we reuse the environment's A and B tensors, but give the agent uniform initial beliefs over every hidden state factor (so it does not know where the object is) and a preference for finding (seeing) the object.

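# Reuse the environment's observation (A) and transition (B) models and their dependencies
A, B = env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies

# Preferences: zero everywhere except outcome 1 of modality 1 (the object is seen)
C = list_array_zeros([a.shape[0] for a in A])
C[1] = C[1].at[1].set(1.0)

# Uniform initial beliefs over every hidden state factor
D = list_array_uniform([b.shape[0] for b in B])

agent = Agent(A, B, C, D, A_dependencies=A_dependencies, B_dependencies=B_dependencies, policy_len=2, batch_size=2)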
UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
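For readers new to these helpers: list_array_zeros and list_array_uniform are assumed here to return one zero vector and one uniform distribution per requested size. The plain-JAX sketch below uses hypothetical stand-ins (`zeros_list`, `uniform_list`) and assumed sizes (7 locations, a binary "object seen" modality). It also shows why C[1] is updated with `.at[...].set(...)`: JAX arrays are immutable, so the update returns a new array.

```python
import jax.numpy as jnp

def zeros_list(sizes):
    """Hypothetical stand-in for list_array_zeros: one zero vector per size."""
    return [jnp.zeros(n) for n in sizes]

def uniform_list(sizes):
    """Hypothetical stand-in for list_array_uniform: one uniform distribution per size."""
    return [jnp.ones(n) / n for n in sizes]

num_obs = [7, 2]     # assumed: location modality, object-seen modality
num_states = [7, 7]  # assumed: agent location factor, object location factor

C = zeros_list(num_obs)
C[1] = C[1].at[1].set(1.0)  # .at[...].set returns a new array; C[1] is rebound to it
D = uniform_list(num_states)
```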

Using the rollout function, we can simulate both agents in parallel for 10 timesteps. When jit-compiling rollout, the env and num_timesteps arguments (positions 1 and 2) are marked as static.

rollout_jitted = jit(rollout, static_argnums=[1,2])
last, result = rollout_jitted(agent, env, num_timesteps=10, rng_key=key, env_params=env_params)
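The reason for static_argnums is that some arguments cannot be traced: a Python object (like the environment) and a step count that fixes the output shape. The toy below mimics this with a hypothetical `toy_rollout` built on `jax.lax.scan`; it is not pymdp's rollout. When only static_argnums is given, JAX uses the function signature to also treat the matching keyword arguments as static, which is why passing `num_timesteps=10` by keyword above still works.

```python
import jax
import jax.numpy as jnp

def toy_rollout(x0, step_fn, num_timesteps):
    """Toy stand-in for a rollout: apply step_fn num_timesteps times with lax.scan."""
    def body(x, _):
        x_next = step_fn(x)
        return x_next, x_next  # (new carry, value stacked into the trajectory)
    return jax.lax.scan(body, x0, None, length=num_timesteps)

def double(x):
    return 2.0 * x

# step_fn is a Python function and num_timesteps sets the scan length,
# so both are marked static (positions 1 and 2), mirroring env and num_timesteps.
toy_rollout_jitted = jax.jit(toy_rollout, static_argnums=(1, 2))
last, traj = toy_rollout_jitted(jnp.array(1.0), double, 3)
```

Each new value of a static argument triggers a fresh compilation, so changing num_timesteps between calls recompiles the function.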

To continue the simulation, call rollout again and pass the final carry of the previous call as initial_carry=last.

last, result2 = rollout_jitted(agent, env, num_timesteps=10, rng_key=key, env_params=env_params, initial_carry=last)
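Why resuming from a carry works can be seen with a deterministic toy scan: two chained 10-step runs produce the same trajectory as one 20-step run. `toy_rollout` here is a hypothetical stand-in, not pymdp's rollout, and its step has no randomness, so the sketch says nothing about how rollout handles random keys across calls.

```python
import jax
import jax.numpy as jnp

def toy_rollout(num_timesteps, initial_carry):
    """Toy scan whose final carry can seed a later call (hypothetical stand-in)."""
    def body(carry, _):
        carry = carry + 1.0
        return carry, carry
    return jax.lax.scan(body, initial_carry, None, length=num_timesteps)

last, first_half = toy_rollout(10, jnp.array(0.0))
last, second_half = toy_rollout(10, last)  # continue from the previous final carry
_, full = toy_rollout(20, jnp.array(0.0))

# Two chained 10-step runs reproduce one 20-step run
assert jnp.allclose(jnp.concatenate([first_half, second_half]), full)
```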

The result dict contains the negative expected free energy (neg_efe), the executed actions, the observations, the environment state, the empirical prior, and the beliefs over states (qs) and policies (qpi).

result.keys()
dict_keys(['action', 'empirical_prior', 'env_state', 'neg_efe', 'observation', 'qpi', 'qs'])

result["qs"] is a list with one array per state factor (agent location and object location), each of shape [batch_size x time x factor_size]. For this 10-step rollout the time axis has 11 entries.

print(len(result["qs"]))
print(result["qs"][0].shape)
2
(2, 11, 7)
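With this layout, common summaries are one-liners. The sketch uses dummy beliefs of the same shape as result["qs"][0], generated with a softmax over random logits, so only the indexing pattern (not the values) carries over; with the real result, replace `qs_location` with `result["qs"][0]`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Dummy beliefs shaped like result["qs"][0]: (batch_size, time, num_states)
qs_location = jax.nn.softmax(jr.normal(jr.PRNGKey(0), (2, 11, 7)), axis=-1)

# Most likely state per agent and per time step: shape (batch_size, time)
map_locations = jnp.argmax(qs_location, axis=-1)

# Beliefs of agent 0 at the final time step: shape (num_states,)
final_belief_agent0 = qs_location[0, -1]
```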

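We can plot each agent's beliefs over time: the most likely agent location at each step (blue dots) on top of the beliefs about the object location (greyscale).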

def plot_results(result, agent_idx=0):
    """Plot one agent's location beliefs and object-location beliefs over time."""
    fig, ax = plt.subplots()
    ax.set_title(f"Agent {agent_idx}")

    # Most likely agent location at each time step (argmax of the factor-0 beliefs), as blue dots
    qs_location = result["qs"][0][agent_idx]  # shape (time, num_locations)
    T = qs_location.shape[0]
    locations = jnp.argmax(qs_location, axis=-1)
    ax.scatter(jnp.arange(T), locations, c="tab:blue")

    # Object location beliefs (factor 1) as greyscale intensity: rows are locations, columns are time steps
    ax.imshow(result["qs"][1][agent_idx].T, cmap="gray_r", vmin=0.0, vmax=1.0)
    ax.set_xlabel("time step")
    ax.set_ylabel("location")

plot_results(result, agent_idx=0)
plot_results(result, agent_idx=1)
[Figures: belief plots titled "Agent 0" and "Agent 1"]
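A numeric complement to these plots is the entropy of the object-location belief at each step, which should drop as an agent localizes the object. The sketch uses dummy beliefs shaped like result["qs"][1]; with the real result, call `belief_entropy(result["qs"][1])`. The small `eps` guards against `log(0)` and is an arbitrary choice.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def belief_entropy(qs_factor, eps=1e-16):
    """Shannon entropy (in nats) of categorical beliefs along the last axis."""
    return -jnp.sum(qs_factor * jnp.log(qs_factor + eps), axis=-1)

# Dummy object-location beliefs shaped like result["qs"][1]: (batch_size, time, num_states)
qs_object = jax.nn.softmax(jr.normal(jr.PRNGKey(1), (2, 11, 7)), axis=-1)
H = belief_entropy(qs_object)  # shape (batch_size, time)
```

A uniform belief over 7 locations has entropy log(7), about 1.95 nats, and a one-hot belief has entropy 0, so these two values bound each curve.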

If you cannot use the rollout function, for example because you are using a third-party environment that is not compatible with jax.jit, you can write your own loop around the infer_and_plan function.

from jax import vmap
import jax.tree_util as jtu
from pymdp.envs.rollout import infer_and_plan

rng_key = jr.PRNGKey(0)

# start with a placeholder action of -1 (no action taken yet), and expand agent.D with a time dimension
action = -jnp.ones((agent.batch_size, len(agent.num_controls)), dtype=jnp.int32)
qs = jtu.tree_map(lambda x: jnp.expand_dims(x, -2), agent.D)

# reset the environments (one key per batch element) and get the initial observations
keys = jr.split(rng_key, agent.batch_size + 1)
rng_key = keys[0]
observation, state = vmap(env.reset)(keys[1:], env_params=env_params)

for i in range(10):
    # fresh keys: keys[0] carries the chain forward, keys[1] is for infer_and_plan,
    # and keys[2:] gives one key per environment for env.step
    keys = jr.split(rng_key, agent.batch_size + 2)
    rng_key = keys[0]

    # infer states from the latest observation and plan the next action
    agent, action, qs, _ = infer_and_plan(agent, qs, observation, action, keys[1])

    # step every environment in the batch with its own key
    observation, state = vmap(env.step)(keys[2:], state, action, env_params=env_params)

    print(f"Step {i+1}: Action taken: {action}, Observation: {observation}")
Step 1: Action taken: [[1 0]
 [0 0]], Observation: [Array([[1.],
       [0.]], dtype=float32), Array([[0.],
       [0.]], dtype=float32)]
Step 2: Action taken: [[0 0]
 [2 0]], Observation: [Array([[0.],
       [2.]], dtype=float32), Array([[0.],
       [0.]], dtype=float32)]
Step 3: Action taken: [[2 0]
 [0 0]], Observation: [Array([[2.],
       [0.]], dtype=float32), Array([[0.],
       [0.]], dtype=float32)]
Step 4: Action taken: [[0 0]
 [3 0]], Observation: [Array([[0.],
       [3.]], dtype=float32), Array([[0.],
       [0.]], dtype=float32)]
Step 5: Action taken: [[3 0]
 [4 0]], Observation: [Array([[3.],
       [4.]], dtype=float32), Array([[1.],
       [0.]], dtype=float32)]
Step 6: Action taken: [[1 0]
 [3 0]], Observation: [Array([[3.],
       [3.]], dtype=float32), Array([[1.],
       [0.]], dtype=float32)]
Step 7: Action taken: [[1 0]
 [5 0]], Observation: [Array([[3.],
       [5.]], dtype=float32), Array([[1.],
       [1.]], dtype=float32)]
Step 8: Action taken: [[1 0]
 [0 0]], Observation: [Array([[3.],
       [5.]], dtype=float32), Array([[1.],
       [1.]], dtype=float32)]
Step 9: Action taken: [[1 0]
 [0 0]], Observation: [Array([[3.],
       [5.]], dtype=float32), Array([[1.],
       [1.]], dtype=float32)]
Step 10: Action taken: [[1 0]
 [0 0]], Observation: [Array([[3.],
       [5.]], dtype=float32), Array([[1.],
       [1.]], dtype=float32)]
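When the environment is an ordinary Python object, the main job is to convert its outputs into the format the loop above passes to infer_and_plan: judging from the printed output, a list with one (batch_size, 1) float32 array per modality. The sketch below is hypothetical throughout: `PythonGraphEnv`, its adjacency dict, its move rule and `to_pymdp_observation` are illustrative names and behaviour, not part of pymdp, and it assumes column 0 of each action row is the move.

```python
import jax.numpy as jnp

class PythonGraphEnv:
    """Hypothetical non-JAX environment: plain Python state, not jit-compatible."""

    def __init__(self, adjacency, agent_location, object_location):
        self.adjacency = adjacency  # dict: node -> set of neighbouring nodes
        self.agent_location = agent_location
        self.object_location = object_location

    def observe(self):
        seen = int(self.agent_location == self.object_location)
        return self.agent_location, seen

    def step(self, target):
        # Move to target if it is the current node or a neighbour, otherwise stay put
        if target == self.agent_location or target in self.adjacency[self.agent_location]:
            self.agent_location = target
        return self.observe()

def to_pymdp_observation(raw_observations):
    """Stack per-environment (location, seen) tuples into one (batch_size, 1) array per modality."""
    locations, seen = zip(*raw_observations)
    return [
        jnp.array(locations, dtype=jnp.float32)[:, None],
        jnp.array(seen, dtype=jnp.float32)[:, None],
    ]

adjacency = {0: {1}, 1: {0, 2}, 2: {1}}
envs = [PythonGraphEnv(adjacency, 0, 2), PythonGraphEnv(adjacency, 1, 2)]
observation = to_pymdp_observation([e.observe() for e in envs])

# Stand-in for an action array from infer_and_plan, shape (batch_size, num_factors)
action = jnp.array([[1, 0], [2, 0]])
observation = to_pymdp_observation([e.step(int(a[0])) for e, a in zip(envs, action)])
```

Inside the custom loop, the `vmap(env.step)` call would be replaced by the last line, with `action` coming from infer_and_plan.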
