pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

Sophisticated inference on Graph Worlds¶

This notebook demonstrates tree searching policies with sophisticated active inference.

We create a GraphEnv, where an agent can move between locations on a graph, and see objects at the location it is at.

In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [ ]:
Copied!
from jax import numpy as jnp, random as jr
from jax import vmap
import jax.tree_util as jtu
from pymdp.utils import list_array_zeros, list_array_norm_dist
from jax import numpy as jnp, random as jr from jax import vmap import jax.tree_util as jtu from pymdp.utils import list_array_zeros, list_array_norm_dist
In [ ]:
Copied!
import networkx as nx
from pymdp.envs import GraphEnv
from pymdp.envs.graph_worlds import generate_connected_clusters

graph, _ = generate_connected_clusters(cluster_size=3, connections=2)
env = GraphEnv(graph, object_location=4, agent_location=0)
import networkx as nx from pymdp.envs import GraphEnv from pymdp.envs.graph_worlds import generate_connected_clusters graph, _ = generate_connected_clusters(cluster_size=3, connections=2) env = GraphEnv(graph, object_location=4, agent_location=0)
In [1]:
Copied!
nx.draw(graph, with_labels=True, font_weight="bold")
nx.draw(graph, with_labels=True, font_weight="bold")
No description has been provided for this image

Let's create an agent, we give the agent a prior on the object location to showcase the planning depth and pruning. We let the agent start at location 0, and give it a prior about where the object is (location 4).

In [ ]:
Copied!
from pymdp.agent import Agent

A, B = env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies

C = list_array_zeros([a.shape[0] for a in A])
C[1] = C[1].at[1].set(10.0)

D = [jnp.ones(d.shape[0]) for d in env.D]
D[0] = D[0].at[0].set(100.0)
D[1] = D[1].at[4].set(100.0)
D = list_array_norm_dist(D)

agent = Agent(A, B, C, D, A_dependencies=A_dependencies, B_dependencies=B_dependencies, policy_len=1)
from pymdp.agent import Agent A, B = env.A, env.B A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies C = list_array_zeros([a.shape[0] for a in A]) C[1] = C[1].at[1].set(10.0) D = [jnp.ones(d.shape[0]) for d in env.D] D[0] = D[0].at[0].set(100.0) D[1] = D[1].at[4].set(100.0) D = list_array_norm_dist(D) agent = Agent(A, B, C, D, A_dependencies=A_dependencies, B_dependencies=B_dependencies, policy_len=1)

Now we can call the si_policy_search method and visualize the planning tree.

In [ ]:
Copied!
from pymdp.planning.si import si_policy_search
from pymdp.envs import rollout

key = jr.PRNGKey(0)
last, info = rollout(agent, env, 5, key, policy_search=si_policy_search(horizon=3))
from pymdp.planning.si import si_policy_search from pymdp.envs import rollout key = jr.PRNGKey(0) last, info = rollout(agent, env, 5, key, policy_search=si_policy_search(horizon=3))

Given the agent prior, it should commit to going to location 4, expecting to see the object there.

In [2]:
Copied!
from pymdp.planning.visualize import visualize_plan_tree

def graph_world_observation_description(observation, model=None, graph=graph):
    obs = observation[0]
    loc_idx = int(obs[0])
    obj_val = int(obs[1])

    obj_str = "object" if obj_val == 1 else "no object"
    return f"loc: {loc_idx}\nobj: {obj_str}"


def graph_world_action_description(action, model=None, graph=graph):
    action_idx = int(action[0])
    if action_idx in graph.nodes:
        return f"go to node {action_idx}"
    return str(action_idx)


visualize_plan_tree(
    info,
    observation_description=graph_world_observation_description,
    action_description=graph_world_action_description,
)
from pymdp.planning.visualize import visualize_plan_tree def graph_world_observation_description(observation, model=None, graph=graph): obs = observation[0] loc_idx = int(obs[0]) obj_val = int(obs[1]) obj_str = "object" if obj_val == 1 else "no object" return f"loc: {loc_idx}\nobj: {obj_str}" def graph_world_action_description(action, model=None, graph=graph): action_idx = int(action[0]) if action_idx in graph.nodes: return f"go to node {action_idx}" return str(action_idx) visualize_plan_tree( info, observation_description=graph_world_observation_description, action_description=graph_world_action_description, )
No description has been provided for this image

Made with Dracula Theme for MkDocs