Complex action dependencies¶

In this notebook, we show how to specify and run agents with complex action dependencies, i.e. situations where a state variable depends on multiple actions or on no action at all. The transition tensors of such states have shapes of the form [state_dim, *prev_state_dims, *prev_action_dims].

The general strategy is to flatten prev_action_dims when the agent is initialized, so that the new B tensor shapes are [state_dim, *prev_state_dims, math.prod(prev_action_dims)]. If a state has no action dependency, the new B tensor has shape [state_dim, *prev_state_dims, 1], where the 1 stands for a single dummy action. All computations are done on the flattened B tensors, and actions are sampled in the flattened action dimensions. After a flattened action has been sampled, it can be converted back to the original action dimensions with agent.decode_multi_actions. Conversely, to flatten multi-factor actions (for example, from collected data), call agent.encode_multi_actions.
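The flattening step can be sketched with plain NumPy. This is a minimal illustration under stated assumptions, not pymdp's internal code: it uses a random, column-normalized B tensor with the shapes of s1 from the first example (4 states, controlled by c1 with 2 options and c2 with 3 options), and it assumes row-major (C-order) flattening of the action axes, which is what np.reshape does by default. The order pymdp uses internally may differ.

```python
import math
import numpy as np

# Shapes of s1 in the first example: 4 states, depends on itself,
# controlled by c1 (2 options) and c2 (3 options).
num_states = 4
prev_action_dims = (2, 3)

# Random B tensor of shape [state_dim, prev_state_dim, *prev_action_dims],
# normalized over axis 0 so every column is a distribution over next states.
rng = np.random.default_rng(0)
B = rng.random((num_states, num_states, *prev_action_dims))
B = B / B.sum(axis=0, keepdims=True)

# Merge the trailing action axes into one axis of size prod(prev_action_dims).
# Assumption: C-order flattening (np.reshape's default).
B_flat = B.reshape(num_states, num_states, math.prod(prev_action_dims))
print(B_flat.shape)  # (4, 4, 6)

# Flattening only regroups the action axes, so columns remain normalized.
print(np.allclose(B_flat.sum(axis=0), 1.0))  # True

# Under C-order, the action pair (a1, a2) maps to flat index a1 * 3 + a2.
a1, a2 = 1, 2
print(np.array_equal(B_flat[:, :, a1 * 3 + a2], B[:, :, a1, a2]))  # True
```

Because reshaping only regroups the action axes, each column of the flattened tensor is still a probability distribution over the next state, the same property that the "B normalized" and "B flat normalized" lines in the examples print.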

import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
from pprint import pprint
import itertools
import numpy as np
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu

from pymdp.agent import Agent
from pymdp import distribution

Multiple action dependencies¶

In this example, state s1 is controlled by two control factors (c1 and c2), while s2 is controlled by c1 only.

model_description = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {"c1": {"elements": ["up", "down"]}, "c2": {"elements": ["left", "right", "stay"]}},
    "states": {
        "s1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"], "controlled_by": ["c1", "c2"]},
        "s2": {"elements": ["A", "B", "C", "D"], "depends_on": ["s2"], "controlled_by": ["c1"]},
    },
}

B_action_dependencies = [
    [list(model_description["controls"].keys()).index(i) for i in s["controlled_by"]] 
    for s in model_description["states"].values()
]
num_controls = [len(c["elements"]) for c in model_description["controls"].values()]

model = distribution.compile_model(model_description)

# initialize tensor values
model.A["o1"]["A", "A"] = 1.0
model.A["o1"]["B", "B"] = 1.0
model.A["o1"]["C", "C"] = 1.0
model.A["o1"]["D", "D"] = 1.0

for i, state in enumerate(model_description["states"].keys()):
    controls = list(itertools.product(*[
        model_description["controls"][c]["elements"] for c in model_description["states"][state]["controlled_by"]
    ]))
    for control in controls:
        model.B[i][("B", "A", *control)] = 1.0
        model.B[i][("C", "B", *control)] = 1.0
        model.B[i][("D", "C", *control)] = 1.0
        model.B[i][("D", "D", *control)] = 1.0

agent = Agent(
    model.A, model.B,
    B_action_dependencies=B_action_dependencies,
    num_controls=num_controls,
)

# dummy history: a random previous action (one of agent.policies)
# and a random observation, both sampled with jax.random
action_key, obs_key = jr.split(jr.PRNGKey(0))
action = agent.policies[jr.choice(action_key, len(agent.policies))]
observation = [jr.randint(obs_key, (1, 1), 0, d) for d in agent.num_obs]
qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, 0), agent.D)

prior = agent.update_empirical_prior(action, qs_hist)
qs = agent.infer_states(observations=observation, empirical_prior=prior)

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
action_multi = agent.decode_multi_actions(action)
action_reconstruct = agent.encode_multi_actions(action_multi)

print("A_dependencies", agent.A_dependencies)
print("B_dependencies", agent.B_dependencies)
print("B_action_dependencies", agent.B_action_dependencies)
print("original control dims", agent.num_controls_multi)
print("flattened control dims", agent.num_controls)
print("original B shapes", [a.data.shape for a in model.B])
print("flattened B shapes", [a.shape for a in agent.B])
print("B normalized", [jnp.isclose(a.data.sum(0), 1.).all() for a in model.B])
print("B flat normalized", [jnp.isclose(a.sum(1), 1.).all() for a in agent.B])

print("\n")
print("prior")
pprint([p.round(2) for p in prior])
print("post")
pprint([p.round(2) for p in qs])
print("action")
pprint(action)
print("action_multi")
pprint(action_multi)
print("action_reconstruct")
pprint(action_reconstruct)
A_dependencies [[0]]
B_dependencies [[0], [1]]
B_action_dependencies [[0, 1], [0]]
original control dims [2, 3]
flattened control dims [6, 2]
original B shapes [(4, 4, 2, 3), (4, 4, 2)]
flattened B shapes [(1, 4, 4, 6), (1, 4, 4, 2)]
B normalized [Array(True, dtype=bool), Array(True, dtype=bool)]
B flat normalized [Array(True, dtype=bool), Array(True, dtype=bool)]


prior
[Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32),
 Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32)]
post
[Array([[[0., 0., 1., 0.]]], dtype=float32),
 Array([[[0.  , 0.25, 0.25, 0.5 ]]], dtype=float32)]
action
Array([[0, 0]], dtype=int32)
action_multi
Array([[0, 0]], dtype=int32)
action_reconstruct
Array([[0, 0]], dtype=int32)
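The printed dimensions can be reproduced by hand: each state factor's flattened action axis has a size equal to the product of the sizes of the controls it depends on, so [[0, 1], [0]] with control sizes [2, 3] gives [6, 2]. The helper flattened_control_dims below is hypothetical (it is not part of pymdp), and the per-factor round trip with np.unravel_index and np.ravel_multi_index only illustrates the idea behind agent.decode_multi_actions and agent.encode_multi_actions under an assumed C-order convention. The library's own mapping works across all factors at once and may order indices differently.

```python
import math
import numpy as np

# Values printed by the first example.
num_controls_multi = [2, 3]            # c1 ("up", "down"), c2 ("left", "right", "stay")
B_action_dependencies = [[0, 1], [0]]  # s1 <- (c1, c2), s2 <- c1

def flattened_control_dims(action_deps, num_controls):
    """Hypothetical helper: size of each factor's flattened action axis,
    i.e. the product of the sizes of the controls it depends on."""
    return [math.prod(num_controls[c] for c in deps) for deps in action_deps]

print(flattened_control_dims(B_action_dependencies, num_controls_multi))  # [6, 2]

# Per-factor round trip for s1, assuming C-order flattening of (c1, c2).
dims_s1 = tuple(num_controls_multi[c] for c in B_action_dependencies[0])  # (2, 3)
flat_index = 4
c1, c2 = np.unravel_index(flat_index, dims_s1)       # decode: 4 -> (1, 1)
recovered = np.ravel_multi_index((c1, c2), dims_s1)  # encode: (1, 1) -> 4
print(int(c1), int(c2), int(recovered))  # 1 1 4
```

In the printed flattened B shapes, the tensors also carry a leading axis of size 1, which appears to be the agent's batch dimension rather than part of the transition structure.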

No action dependency¶

In this example, s1 is again controlled by both c1 and c2, while s2 is not controlled by any action (its controlled_by list is empty).

model_description = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {"c1": {"elements": ["up", "down"]}, "c2": {"elements": ["left", "right", "stay"]}},
    "states": {
        "s1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"], "controlled_by": ["c1", "c2"]},
        "s2": {"elements": ["A", "B", "C", "D"], "depends_on": ["s2"], "controlled_by": []},
    },
}

B_action_dependencies = [   
    [list(model_description["controls"].keys()).index(i) for i in s["controlled_by"]] 
    for s in model_description["states"].values()
]
num_controls = [len(c["elements"]) for c in model_description["controls"].values()]

model = distribution.compile_model(model_description)

# initialize tensor values
model.A["o1"]["A", "A"] = 1.0
model.A["o1"]["B", "B"] = 1.0
model.A["o1"]["C", "C"] = 1.0
model.A["o1"]["D", "D"] = 1.0

for i, state in enumerate(model_description["states"].keys()):
    controls = list(itertools.product(*[
        model_description["controls"][c]["elements"] for c in model_description["states"][state]["controlled_by"]
    ]))
    for control in controls:
        model.B[i][("B", "A", *control)] = 1.0
        model.B[i][("C", "B", *control)] = 1.0
        model.B[i][("D", "C", *control)] = 1.0
        model.B[i][("D", "D", *control)] = 1.0

agent = Agent(
    model.A, model.B,
    B_action_dependencies=B_action_dependencies,
    num_controls=num_controls,
)

# dummy history
action_key, obs_key = jr.split(jr.PRNGKey(0))
action = agent.policies[jr.choice(action_key, len(agent.policies))]
observation = [jr.randint(obs_key, (1, 1), 0, d) for d in agent.num_obs]
qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, 0), agent.D)

prior = agent.update_empirical_prior(action, qs_hist)
qs = agent.infer_states(observations=observation, empirical_prior=prior)

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
action_multi = agent.decode_multi_actions(action)
action_reconstruct = agent.encode_multi_actions(action_multi)

print("A_dependencies", agent.A_dependencies)
print("B_dependencies", agent.B_dependencies)
print("B_action_dependencies", agent.B_action_dependencies)
print("original control dims", agent.num_controls_multi)
print("flattened control dims", agent.num_controls)
print("original B shapes", [a.data.shape for a in model.B])
print("flattened B shapes", [a.shape for a in agent.B])
print("B normalized", [jnp.isclose(a.data.sum(0), 1.).all() for a in model.B])
print("B flat normalized", [jnp.isclose(a.sum(1), 1.).all() for a in agent.B])

print("\n")
print("prior")
pprint([p.round(2) for p in prior])
print("post")
pprint([p.round(2) for p in qs])
print("action")
pprint(action)
print("action_multi")
pprint(action_multi)
print("action_reconstruct")
pprint(action_reconstruct)
A_dependencies [[0]]
B_dependencies [[0], [1]]
B_action_dependencies [[0, 1], []]
original control dims [2, 3]
flattened control dims [6, 1]
original B shapes [(4, 4, 2, 3), (4, 4)]
flattened B shapes [(1, 4, 4, 6), (1, 4, 4, 1)]
B normalized [Array(True, dtype=bool), Array(True, dtype=bool)]
B flat normalized [Array(True, dtype=bool), Array(True, dtype=bool)]


prior
[Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32),
 Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32)]
post
[Array([[[0., 0., 1., 0.]]], dtype=float32),
 Array([[[0.  , 0.25, 0.25, 0.5 ]]], dtype=float32)]
action
Array([[0, 0]], dtype=int32)
action_multi
Array([[0, 0]], dtype=int32)
action_reconstruct
Array([[0, 0]], dtype=int32)
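The dummy action follows from the same product rule: for a state with an empty controlled_by list, the product over no control sizes is 1. An analogous convention explains why the initialization loop in this example still runs for s2: itertools.product() with no arguments yields exactly one empty tuple, so each transition entry is set once with no control index. The sketch below is a NumPy illustration, not pymdp's implementation; it reproduces the flattened control dims [6, 1] and appends a trailing dummy action axis to an action-independent (4, 4) transition tensor holding the same chain A → B → C → D → D used in the notebook.

```python
import math
import numpy as np

num_controls_multi = [2, 3]
B_action_dependencies = [[0, 1], []]  # s2 is not controlled by any action

# An empty dependency list gives an empty product, i.e. one dummy action.
flat_dims = [math.prod(num_controls_multi[c] for c in deps)
             for deps in B_action_dependencies]
print(flat_dims)  # [6, 1]

# Action-independent transitions for s2 (next state, previous state):
# A -> B, B -> C, C -> D, D -> D, matching the values set in the notebook.
B_s2 = np.zeros((4, 4))
for nxt, prev in [(1, 0), (2, 1), (3, 2), (3, 3)]:
    B_s2[nxt, prev] = 1.0

# Append a trailing dummy action axis of size 1.
B_s2_flat = B_s2[..., None]
print(B_s2_flat.shape)  # (4, 4, 1)
```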
