pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

In [ ]:
Copied!
# Install pymdp (with its `nb` extra) only when the `google.colab` module is
# already loaded, i.e. when this notebook is presumably running on Colab.
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp[nb]" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp[nb]" -q
In [ ]:
Copied!
from pymdp.agent import Agent
from pymdp.utils import list_array_zeros, list_array_norm_dist
from typing import Sequence, Optional

import mctx

import jax.numpy as jnp
import jax.random as jr
import jax.nn as nn

import chex
import pygraphviz

from IPython.display import SVG
from pymdp.agent import Agent from pymdp.utils import list_array_zeros, list_array_norm_dist from typing import Sequence, Optional import mctx import jax.numpy as jnp import jax.random as jr import jax.nn as nn import chex import pygraphviz from IPython.display import SVG

Utility function that converts the MCTX search tree into a Graphviz graph, which can then be rendered as an SVG visualization of the search tree.

In [ ]:
Copied!
def convert_tree_to_graph(
    tree: mctx.Tree,
    action_labels: Optional[Sequence[str]] = None,
    batch_index: int = 0
) -> pygraphviz.AGraph:
  """Converts a search tree into a Graphviz graph.

  The root (node 0) is drawn in green. Every expanded child is drawn in red
  and connected to its parent by an edge labelled with the action label, the
  action's Q-value and its prior probability.

  Args:
    tree: A `Tree` containing a batch of search data.
    action_labels: Optional labels for edges, defaults to the action index.
    batch_index: Index of the batch element to plot. Negative values count
      from the end of the batch.

  Returns:
    A Graphviz graph representation of `tree`.

  Raises:
    ValueError: If `action_labels` does not contain exactly one label per
      action.
    IndexError: If `batch_index` does not address an element of the batch.
  """
  chex.assert_rank(tree.node_values, 2)
  batch_size = tree.node_values.shape[0]
  # JAX clamps out-of-range indices when reading instead of raising, so an
  # invalid `batch_index` would silently plot a different batch element.
  if not -batch_size <= batch_index < batch_size:
    raise IndexError(
        f"batch_index {batch_index} is out of range for a tree with "
        f"batch size {batch_size}.")
  if action_labels is None:
    action_labels = range(tree.num_actions)
  elif len(action_labels) != tree.num_actions:
    raise ValueError(
        f"action_labels {action_labels} has the wrong number of actions "
        f"({len(action_labels)}). "
        f"Expecting {tree.num_actions}.")

  def node_to_str(node_i, reward=0, discount=1):
    # The root is labelled with the default reward/discount; children pass
    # the values stored on the edge that leads to them.
    return (f"{node_i}\n"
            f"Reward: {reward:.2f}\n"
            f"Discount: {discount:.2f}\n"
            f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
            f"Visits: {tree.node_visits[batch_index, node_i]}\n")

  def edge_to_str(a_i, q_values, probs):
    return (f"{action_labels[a_i]}\n"
            f"Q: {q_values[a_i]:.2f}\n"
            f"p: {probs[a_i]:.2f}\n")

  graph = pygraphviz.AGraph(directed=True)

  # Add root
  graph.add_node(0, label=node_to_str(node_i=0), color="green")
  # Add all other nodes and connect them up.
  for node_i in range(tree.num_simulations):
    # Q-values and prior probabilities depend only on the parent node, so
    # they are computed lazily, once per parent with an expanded child,
    # instead of once per edge.
    q_values = None
    probs = None
    for a_i in range(tree.num_actions):
      # Index of children, or -1 if not expanded. Converting to a Python int
      # once avoids reusing a JAX scalar as the comparison operand, the
      # graph node key and the label value.
      child_i = int(tree.children_index[batch_index, node_i, a_i])
      if child_i < 0:
        continue
      if q_values is None:
        # `tree.qvalues` is queried with one node index per batch element.
        node_index = jnp.full([batch_size], node_i)
        q_values = tree.qvalues(node_index)[batch_index]  # pytype: disable=unsupported-operands  # always-use-return-annotations
        probs = nn.softmax(tree.children_prior_logits[batch_index, node_i])
      graph.add_node(
          child_i,
          label=node_to_str(
              node_i=child_i,
              reward=tree.children_rewards[batch_index, node_i, a_i],
              discount=tree.children_discounts[batch_index, node_i, a_i]),
          color="red")
      graph.add_edge(
          node_i, child_i, label=edge_to_str(a_i, q_values, probs))

  return graph
def convert_tree_to_graph( tree: mctx.Tree, action_labels: Optional[Sequence[str]] = None, batch_index: int = 0 ) -> pygraphviz.AGraph: """Converts a search tree into a Graphviz graph. Args: tree: A `Tree` containing a batch of search data. action_labels: Optional labels for edges, defaults to the action index. batch_index: Index of the batch element to plot. Returns: A Graphviz graph representation of `tree`. """ chex.assert_rank(tree.node_values, 2) batch_size = tree.node_values.shape[0] if action_labels is None: action_labels = range(tree.num_actions) elif len(action_labels) != tree.num_actions: raise ValueError( f"action_labels {action_labels} has the wrong number of actions " f"({len(action_labels)}). " f"Expecting {tree.num_actions}.") def node_to_str(node_i, reward=0, discount=1): return (f"{node_i}\n" f"Reward: {reward:.2f}\n" f"Discount: {discount:.2f}\n" f"Value: {tree.node_values[batch_index, node_i]:.2f}\n" f"Visits: {tree.node_visits[batch_index, node_i]}\n") def edge_to_str(node_i, a_i): node_index = jnp.full([batch_size], node_i) probs = nn.softmax(tree.children_prior_logits[batch_index, node_i]) return (f"{action_labels[a_i]}\n" f"Q: {tree.qvalues(node_index)[batch_index, a_i]:.2f}\n" # pytype: disable=unsupported-operands # always-use-return-annotations f"p: {probs[a_i]:.2f}\n") graph = pygraphviz.AGraph(directed=True) # Add root graph.add_node(0, label=node_to_str(node_i=0), color="green") # Add all other nodes and connect them up. for node_i in range(tree.num_simulations): for a_i in range(tree.num_actions): # Index of children, or -1 if not expanded children_i = tree.children_index[batch_index, node_i, a_i] if children_i >= 0: graph.add_node( children_i, label=node_to_str( node_i=children_i, reward=tree.children_rewards[batch_index, node_i, a_i], discount=tree.children_discounts[batch_index, node_i, a_i]), color="red") graph.add_edge(node_i, children_i, label=edge_to_str(node_i, a_i)) return graph

Let's test it on the graph world example as well.

In [1]:
Copied!
import networkx as nx
from pymdp.envs import GraphEnv
from pymdp.envs.graph_worlds import generate_connected_clusters

# Build the graph the environment lives on; the second return value is not
# used in this notebook.
graph, _ = generate_connected_clusters(cluster_size=3, connections=2)
# Environment on that graph, created with the object at location 4 and the
# agent at location 0.
env = GraphEnv(graph, object_location=4, agent_location=0)

# Plot the graph with its node labels.
nx.draw(graph, with_labels=True)
import networkx as nx from pymdp.envs import GraphEnv from pymdp.envs.graph_worlds import generate_connected_clusters graph, _ = generate_connected_clusters(cluster_size=3, connections=2) env = GraphEnv(graph, object_location=4, agent_location=0) nx.draw(graph, with_labels=True)
No description has been provided for this image

The environment places the agent at node 0 (`agent_location=0`), and the prior `D` concentrates the agent's belief on itself being at node 0 and on the object being at node 4. Therefore, the expected rewarding path is 0 -> 3 -> 4.

In [2]:
Copied!
# Reuse the environment's A and B arrays and their dependency lists as the
# agent's generative model.
A, B= env.A, env.B
A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies

# C: one zero array per entry of A, sized by that entry's leading dimension,
# with 1.0 written at index 1 of the second array.
# NOTE(review): presumably this is the preference for the "object seen"
# outcome of the second modality — confirm against GraphEnv.
C = list_array_zeros([a.shape[0] for a in A])
C[1] = C[1].at[1].set(1.0)

# D: start from all-ones arrays shaped like env.D, then up-weight index 0 of
# the first array and index 4 of the second before normalising.
# NOTE(review): presumably factor 0 is the agent location and factor 1 the
# object location — confirm against GraphEnv.
D = [jnp.ones(d.shape[0]) for d in env.D]
D[0] = D[0].at[0].set(100.0)
D[1] = D[1].at[4].set(10.0)
D = list_array_norm_dist(D)

agent = Agent(
    A=A,
    B=B,
    C=C,
    D=D,
    A_dependencies=A_dependencies,
    B_dependencies=B_dependencies,
    categorical_obs=False,
)
A, B= env.A, env.B A_dependencies, B_dependencies = env.A_dependencies, env.B_dependencies C = list_array_zeros([a.shape[0] for a in A]) C[1] = C[1].at[1].set(1.0) D = [jnp.ones(d.shape[0]) for d in env.D] D[0] = D[0].at[0].set(100.0) D[1] = D[1].at[4].set(10.0) D = list_array_norm_dist(D) agent = Agent( A=A, B=B, C=C, D=D, A_dependencies=A_dependencies, B_dependencies=B_dependencies, categorical_obs=False, )
/var/folders/_f/1qqqnkyd5k5g2b1pgfwzzrqm0000gn/T/ipykernel_61044/329308290.py:12: UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
  agent = Agent(
In [ ]:
Copied!
import mctx
from pymdp.planning.mcts import make_aif_recurrent_fn

# Recurrent function that is handed to the MCTX policy search.
recurrent_fn = make_aif_recurrent_fn()
# Fixed PRNG seed used for the search.
rng_key = jr.PRNGKey(111)
import mctx from pymdp.planning.mcts import make_aif_recurrent_fn recurrent_fn = make_aif_recurrent_fn() rng_key = jr.PRNGKey(111)
In [3]:
Copied!
# %%timeit
# Root of the search: prior logits are log(agent.E), the root value starts at
# zero for each of the agent.batch_size batch elements, and agent.D is used
# as the root embedding.
root = mctx.RootFnOutput(
    prior_logits=jnp.log(agent.E),
    value=jnp.zeros((agent.batch_size)),
    embedding=agent.D,
)

# Run the Gumbel MuZero search. The agent itself is passed as the first
# positional argument, so it is what MCTX forwards to `recurrent_fn`.
policy_output = mctx.gumbel_muzero_policy(
    agent,
    rng_key,
    root,
    recurrent_fn,
    num_simulations=1024,
    max_depth=3
)

tree_gumbel = policy_output.search_tree
print(policy_output.action_weights)

# Convert the search tree to a Graphviz graph, lay it out with `dot`, and
# display the resulting SVG inline.
graph = convert_tree_to_graph(tree_gumbel)
svg = graph.draw(format='svg', prog='dot').decode(graph.encoding)
SVG(svg)
# %%timeit root = mctx.RootFnOutput( prior_logits=jnp.log(agent.E), value=jnp.zeros((agent.batch_size)), embedding=agent.D, ) policy_output = mctx.gumbel_muzero_policy( agent, rng_key, root, recurrent_fn, num_simulations=1024, max_depth=3 ) tree_gumbel = policy_output.search_tree print(policy_output.action_weights) graph = convert_tree_to_graph(tree_gumbel) svg = graph.draw(format='svg', prog='dot').decode(graph.encoding) SVG(svg)
[[1.5554491e-16 6.9548679e-11 1.7227886e-17 1.3283233e-13 1.9635755e-14
  1.0000000e+00 1.0319131e-14]]
Out[3]:
No description has been provided for this image

Made with Dracula Theme for MkDocs