Graph worlds
This environment demonstrates agents that navigate a graph to find an object. The object is only visible when the agent is at the same location as the object.
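Here is a minimal sketch of the visibility rule. It assumes two observation modalities: the agent's own location, and a binary "object seen" flag that is 1 only when the agent and the object share a node. The function name `observe` is hypothetical. In the actual environment the same mapping is expressed through its A tensors, not a Python function.

```python
def observe(agent_location: int, object_location: int) -> tuple[int, int]:
    """Return (location observation, object-seen flag) for one agent.

    Hypothetical sketch: the agent always observes its own location, but the
    object flag is 1 only when agent and object occupy the same node.
    """
    object_seen = 1 if agent_location == object_location else 0
    return agent_location, object_seen


# An agent at node 3 sees an object placed at node 3, but not one at node 5.
print(observe(3, 3))  # (3, 1)
print(observe(3, 5))  # (3, 0)
```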
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
# a way to edit and run code and see the effects in the notebook without having to restart the kernel
%load_ext autoreload
%autoreload 2
from jax import numpy as jnp, random as jr, jit
from jax import vmap
import networkx as nx
from pymdp.envs import GraphEnv
from pymdp.envs.graph_worlds import generate_connected_clusters
from pymdp.agent import Agent
from pymdp.envs.rollout import rollout
from pymdp.utils import list_array_uniform, list_array_zeros
import matplotlib.pyplot as plt
key = jr.PRNGKey(0)
Start by generating a graph of locations.
graph, _ = generate_connected_clusters(cluster_size=3, connections=2)
nx.draw(graph, with_labels=True, font_weight="bold")
Now we can create a GraphEnv given this graph.
We then specify two object locations and two agent locations. Passing them to env.generate_env_params(...) produces parameters for two environments that differ only in their initial state priors (their D vectors). Each environment has a precise prior over one initial object location and one initial agent location.
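To make "precise prior" concrete, the following plain-Python sketch builds one-hot initial priors. It assumes 7 locations and that each environment's D is a pair of vectors ordered [agent location, object location]. That ordering and the helper names `one_hot` and `make_initial_priors` are assumptions for illustration. They do not describe what generate_env_params actually returns.

```python
def one_hot(index: int, size: int) -> list[float]:
    """Return a length-`size` vector with all probability mass on `index`."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def make_initial_priors(agent_locations, object_locations, num_locations):
    """Build one [agent-location prior, object-location prior] pair per environment.

    Hypothetical sketch: a "precise" prior puts all its mass on a single state.
    """
    return [
        [one_hot(a, num_locations), one_hot(o, num_locations)]
        for a, o in zip(agent_locations, object_locations)
    ]


# Environment 0: agent starts at node 0, object at node 3 (7 locations assumed).
D_batch = make_initial_priors(agent_locations=[0, 1], object_locations=[3, 5], num_locations=7)
print(D_batch[0][1])  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```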
env = GraphEnv(graph)
object_locations = [3, 5]
agent_locations = [0, 1]
env_params = env.generate_env_params(key=key, graph=graph, object_locations=object_locations, agent_locations=agent_locations)
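The following sketch shows one plausible way a graph could become a transition tensor for the agent-location factor. It assumes each action names a target node and that the agent moves there only if that node is its current node or a neighbour. Otherwise the agent stays put. This reading is consistent with the trace printed at the end of this section, where some actions leave an agent in place. However, GraphEnv's actual construction may differ. The function `location_transitions` and the 4-node example graph are hypothetical.

```python
def location_transitions(adjacency: dict[int, set[int]]) -> list[list[list[float]]]:
    """Build B[next][current][action] for the agent-location factor.

    Hypothetical sketch: action u means "try to move to node u"; the move
    succeeds if u is the current node or one of its neighbours, otherwise
    the agent stays where it is. Every (current, action) column is one-hot.
    """
    n = len(adjacency)
    B = [[[0.0] * n for _ in range(n)] for _ in range(n)]
    for current in range(n):
        for action in range(n):
            reachable = action == current or action in adjacency[current]
            next_node = action if reachable else current
            B[next_node][current][action] = 1.0
    return B


# Hypothetical 4-node path graph: 0 - 1 - 2 - 3
adjacency = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
B_loc = location_transitions(adjacency)
print(B_loc[1][0][1])  # 1.0: from node 0, "go to 1" succeeds
print(B_loc[0][0][3])  # 1.0: from node 0, "go to 3" is blocked, so stay at 0
```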
To create an Agent, we reuse the environment's A and B tensors (and their dependencies). We give the agent uniform initial beliefs over all hidden state factors, so it does not know where the object is, and a preference for finding (seeing) the object.
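# reuse the environment's generative model (observation and transition tensors)
A, B = env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies
# preferences: zero everywhere, except outcome 1 ("object seen") of observation modality 1
C = list_array_zeros([a.shape[0] for a in A])
C[1] = C[1].at[1].set(1.0)
# uniform initial beliefs over every hidden state factor
D = list_array_uniform([b.shape[0] for b in B])
agent = Agent(A, B, C, D, A_dependencies=A_dependencies, B_dependencies=B_dependencies, policy_len=2, batch_size=2)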
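The next block builds the same C and D vectors using plain Python lists. It assumes observation modalities [location, object seen] with sizes [7, 2] and state factors [agent location, object location] with sizes [7, 7]. These sizes are inferred from the shapes and observations printed in this notebook, not read from GraphEnv. The JAX helpers list_array_zeros and list_array_uniform presumably produce the array equivalents.

```python
num_obs = [7, 2]     # assumed: location observation, object-seen flag
num_states = [7, 7]  # assumed: agent location, object location

# Preferences: zero everywhere except outcome 1 ("object seen") of modality 1.
C = [[0.0] * n for n in num_obs]
C[1][1] = 1.0

# Uniform initial beliefs: every state of every factor is equally likely.
D = [[1.0 / n] * n for n in num_states]

print(C[1])  # [0.0, 1.0]
```

Because the agent observes its own location directly, its uniform belief about that factor should be resolved after the first observation. The uniform belief about the object location only sharpens as the agent explores.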
Using the rollout function, we can simulate both agents in parallel for 10 timesteps.
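# env and num_timesteps (positional arguments 1 and 2) are marked static, i.e. treated as compile-time constants
rollout_jitted = jit(rollout, static_argnums=[1, 2])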
last, result = rollout_jitted(agent, env, num_timesteps=10, rng_key=key, env_params=env_params)
To continue the simulation from where it stopped, call rollout again with initial_carry=last:
last, result2 = rollout_jitted(agent, env, num_timesteps=10, rng_key=key, env_params=env_params, initial_carry=last)
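The initial_carry argument follows the usual carry pattern of a scan. The loop state at the end of one call becomes the starting state of the next. The following plain-Python sketch illustrates this pattern using a hypothetical `simple_rollout` and a toy step function. It is not pymdp's rollout, whose carry presumably bundles the agent, its beliefs and the environment state.

```python
from typing import Any, Callable, List, Tuple


def simple_rollout(
    step: Callable[[Any], Tuple[Any, Any]], carry: Any, num_timesteps: int
) -> Tuple[Any, List[Any]]:
    """Run `step` repeatedly, threading a carry and collecting per-step outputs.

    Hypothetical stand-in for a scan-style loop: `step` maps
    carry -> (new_carry, output). Returning the final carry lets a caller
    resume exactly where the previous call stopped.
    """
    outputs = []
    for _ in range(num_timesteps):
        carry, out = step(carry)
        outputs.append(out)
    return carry, outputs


def step(position: int) -> Tuple[int, int]:
    """Toy dynamics: move one node along a 5-node ring and report the new node."""
    next_position = (position + 1) % 5
    return next_position, next_position


# Two 3-step calls, resumed via the returned carry, match one 6-step call.
last, first_half = simple_rollout(step, 0, num_timesteps=3)
last, second_half = simple_rollout(step, last, num_timesteps=3)
_, all_at_once = simple_rollout(step, 0, num_timesteps=6)
print(first_half + second_half == all_at_once)  # True
```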
The result dict contains the executed actions, the observations, the environment states, the beliefs over hidden states (qs) and over policies (qpi), the empirical priors, and the negative expected free energy (neg_efe).
result.keys()
dict_keys(['action', 'empirical_prior', 'env_state', 'neg_efe', 'observation', 'qpi', 'qs'])
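result["qs"] is a list with one array per state factor (agent location, then object location). Each array has shape [batch_size x time x factor_size]. With 2 agents, 7 locations and a 10-step rollout, the time axis has 11 entries, presumably the initial beliefs plus one per step.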
print(len(result["qs"]))
print(result["qs"][0].shape)
2
(2, 11, 7)
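Because each factor's array is indexed [batch, time, state], taking the most probable state at every step means applying argmax over the last axis, as in jnp.argmax(..., axis=-1). The stdlib sketch below applies this to made-up toy beliefs using a hypothetical `map_estimates` helper.

```python
def map_estimates(q, batch_idx):
    """Most probable state at each time step for one batch element.

    `q` is nested as q[batch][time][state], mirroring one entry of
    result["qs"] with shape [batch_size x time x factor_size].
    """
    return [max(range(len(dist)), key=dist.__getitem__) for dist in q[batch_idx]]


# Toy beliefs for 2 agents, 3 time steps, 3 states (made-up numbers).
q_toy = [
    [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]],
    [[0.3, 0.4, 0.3], [0.3, 0.4, 0.3], [0.0, 0.0, 1.0]],
]
print(map_estimates(q_toy, 0))  # [0, 1, 2]
print(map_estimates(q_toy, 1))  # [1, 1, 2]
```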
We can plot the agent's beliefs over time.
def plot_results(result, agent_idx=0):
    """Plot one agent's beliefs about its own location and the object location over time."""
    fig, ax = plt.subplots()
    ax.title.set_text(f"Agent {agent_idx}")
    # Plot the most likely agent location at each timestep as blue dots
    T = result["qs"][0].shape[1]
    locations = jnp.argmax(result["qs"][0][agent_idx], axis=-1)
    ax.scatter(jnp.arange(T), locations, c="tab:blue")
    # Plot object location beliefs as greyscale intensity (rows: locations, columns: timesteps)
    ax.imshow(result["qs"][1][agent_idx, :, :].T, cmap="gray_r", vmin=0.0, vmax=1.0)
    ax.set_xlabel("timestep")
    ax.set_ylabel("location")
plot_results(result, agent_idx=0)
plot_results(result, agent_idx=1)
If you cannot use rollout, for example because you are using a third-party environment that is not compatible with jax.jit, you can write your own loop with the infer_and_plan function.
from jax import vmap
import jax.tree_util as jtu
from pymdp.envs.rollout import infer_and_plan
rng_key = jr.PRNGKey(0)
# start with a placeholder action of -1 (no action taken yet), and expand agent.D to add a time dimension
action = -jnp.ones((agent.batch_size, len(agent.num_controls)), dtype=jnp.int32)
qs = jtu.tree_map(lambda x: jnp.expand_dims(x, -2), agent.D)
# reset the environment (one key per batch element) and get the initial observation
keys = jr.split(rng_key, agent.batch_size + 1)
rng_key = keys[0]
observation, state = vmap(env.reset)(keys[1:], env_params=env_params)
for i in range(10):
    # split fresh keys: one to carry forward, one for the agent, one per environment
    keys = jr.split(rng_key, agent.batch_size + 2)
    rng_key = keys[0]
    # infer states and plan the next action
    agent, action, qs, _ = infer_and_plan(agent, qs, observation, action, keys[1])
    # step the environment
    observation, state = vmap(env.step)(keys[2:], state, action, env_params=env_params)
    print(f"Step {i+1}: Action taken: {action}, Observation: {observation}")
Step 1: Action taken: [[1 0]
[0 0]], Observation: [Array([[1.],
[0.]], dtype=float32), Array([[0.],
[0.]], dtype=float32)]
Step 2: Action taken: [[0 0]
[2 0]], Observation: [Array([[0.],
[2.]], dtype=float32), Array([[0.],
[0.]], dtype=float32)]
Step 3: Action taken: [[2 0]
[0 0]], Observation: [Array([[2.],
[0.]], dtype=float32), Array([[0.],
[0.]], dtype=float32)]
Step 4: Action taken: [[0 0]
[3 0]], Observation: [Array([[0.],
[3.]], dtype=float32), Array([[0.],
[0.]], dtype=float32)]
Step 5: Action taken: [[3 0]
[4 0]], Observation: [Array([[3.],
[4.]], dtype=float32), Array([[1.],
[0.]], dtype=float32)]
Step 6: Action taken: [[1 0]
[3 0]], Observation: [Array([[3.],
[3.]], dtype=float32), Array([[1.],
[0.]], dtype=float32)]
Step 7: Action taken: [[1 0]
[5 0]], Observation: [Array([[3.],
[5.]], dtype=float32), Array([[1.],
[1.]], dtype=float32)]
Step 8: Action taken: [[1 0]
[0 0]], Observation: [Array([[3.],
[5.]], dtype=float32), Array([[1.],
[1.]], dtype=float32)]
Step 9: Action taken: [[1 0]
[0 0]], Observation: [Array([[3.],
[5.]], dtype=float32), Array([[1.],
[1.]], dtype=float32)]
Step 10: Action taken: [[1 0]
[0 0]], Observation: [Array([[3.],
[5.]], dtype=float32), Array([[1.],
[1.]], dtype=float32)]
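The printed trace is easier to read once each agent is reduced to a single number. The sketch below uses a hypothetical `first_sighting` helper on the object-seen flags (observation modality 1) transcribed from the output above. Agent 0 reaches its object (location 3) at step 5, and agent 1 reaches its object (location 5) at step 7. After that, both agents keep observing the object at the same location. Different random keys may produce a different trace.

```python
# Object-seen flags (modality 1) for steps 1-10, transcribed from the output above.
object_seen = {
    0: [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    1: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
}


def first_sighting(flags):
    """Return the 1-based step at which the object is first observed, or None."""
    for step, seen in enumerate(flags, start=1):
        if seen:
            return step
    return None


for agent_idx, flags in object_seen.items():
    print(f"Agent {agent_idx} first sees the object at step {first_sighting(flags)}")
# Agent 0 first sees the object at step 5
# Agent 1 first sees the object at step 7
```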