Complex action dependencies¶
In this notebook, we show some examples of how to specify and run agents with complex action dependencies. Complex action dependencies refer to situations where a state variable depends on multiple actions or on no action at all. The corresponding state transition tensors have shapes of the form: [state_dim, *prev_state_dims, *prev_action_dims].
The general strategy for dealing with this is to flatten the prev_action_dims when initializing the agent, so that the new B tensor shapes are [state_dim, *prev_state_dims, math.prod(prev_action_dims)]. If a state has no action dependency, the new B tensor will have shape [state_dim, *prev_state_dims, 1], where 1 stands for a dummy action. All computations are done on the flattened B tensors, and actions are sampled in the flattened action dimensions. After a flattened action is sampled, it can be converted back to the original action dimensions by calling agent.decode_multi_actions. Conversely, to flatten multi-dimensional actions (for example, from collected data), call agent.encode_multi_actions.
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
from pprint import pprint
import itertools
import numpy as np
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu
from pymdp.agent import Agent
from pymdp import distribution
Multiple action dependencies¶
In this example, some states depend on multiple actions.
# Model specification for the "multiple action dependencies" example:
# one observation modality ("o1"), two control factors ("c1", "c2") and two
# hidden-state factors ("s1", "s2"). Here "s1" is controlled by both
# controls, while "s2" is controlled by "c1" only.
model_description = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {
        "c1": {"elements": ["up", "down"]},
        "c2": {"elements": ["left", "right", "stay"]},
    },
    "states": {
        "s1": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s1"],
            "controlled_by": ["c1", "c2"],
        },
        "s2": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s2"],
            "controlled_by": ["c1"],
        },
    },
}

# Control names in declaration order. Built once here rather than rebuilding
# the key list for every single index lookup below.
control_names = list(model_description["controls"])

# For each state factor, the positions (within the declared controls) of the
# controls listed in its "controlled_by" entry.
B_action_dependencies = [
    [control_names.index(name) for name in state_spec["controlled_by"]]
    for state_spec in model_description["states"].values()
]

# Number of elements (actions) of each control factor, in declaration order.
num_controls = [
    len(control_spec["elements"])
    for control_spec in model_description["controls"].values()
]
model = distribution.compile_model(model_description)

# initialize tensor values
# Likelihood: pair every "o1" element with the "s1" element of the same label.
for label in ("A", "B", "C", "D"):
    model.A["o1"][label, label] = 1.0

# Transitions, indexed as (next_state, prev_state, *actions):
# A -> B -> C -> D, with D mapping onto itself.
transitions = (("B", "A"), ("C", "B"), ("D", "C"), ("D", "D"))

for factor_idx, factor_spec in enumerate(model_description["states"].values()):
    # Element lists of the controls that drive this state factor.
    action_elements = [
        model_description["controls"][name]["elements"]
        for name in factor_spec["controlled_by"]
    ]
    # The same transition structure is written for every combination of
    # the controlling actions.
    for combo in itertools.product(*action_elements):
        for next_state, prev_state in transitions:
            model.B[factor_idx][(next_state, prev_state, *combo)] = 1.0
agent = Agent(
    model.A,
    model.B,
    num_controls=num_controls,
    B_action_dependencies=B_action_dependencies,
)

# dummy history
# The initial action and observation are drawn with jax.random from
# explicit keys split off a fixed seed.
action_key, obs_key = jr.split(jr.PRNGKey(0))

# Pick one of the agent's policies at random.
policy_index = jr.choice(action_key, len(agent.policies))
action = agent.policies[policy_index]

# One random outcome index per observation modality, each of shape (1, 1).
observation = [
    jr.randint(obs_key, shape=(1, 1), minval=0, maxval=num_outcomes)
    for num_outcomes in agent.num_obs
]

# Prepend an axis to every entry of agent.D before passing it as the
# belief history (presumably the time axis -- TODO confirm).
qs_hist = jtu.tree_map(lambda leaf: jnp.expand_dims(leaf, 0), agent.D)

# One step: empirical prior -> state inference -> policy inference -> action.
prior = agent.update_empirical_prior(action, qs_hist)
qs = agent.infer_states(observations=observation, empirical_prior=prior)
q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)

# Round trip: flattened action -> original action dimensions -> flattened.
action_multi = agent.decode_multi_actions(action)
action_reconstruct = agent.encode_multi_actions(action_multi)
# Dependency structure, control dimensions and B tensor diagnostics.
# The normalization check sums over axis 0 for the model tensors and over
# axis 1 for the agent tensors.
summary = (
    ("A_dependencies", agent.A_dependencies),
    ("B_dependencies", agent.B_dependencies),
    ("B_action_dependencies", agent.B_action_dependencies),
    ("original control dims", agent.num_controls_multi),
    ("flattened control dims", agent.num_controls),
    ("original B shapes", [b.data.shape for b in model.B]),
    ("flattened B shapes", [b.shape for b in agent.B]),
    ("B normalized", [jnp.isclose(b.data.sum(0), 1.).all() for b in model.B]),
    ("B flat normalized", [jnp.isclose(b.sum(1), 1.).all() for b in agent.B]),
)
for label, value in summary:
    print(label, value)

print("\n")

# Beliefs (rounded to two decimals) and the sampled action in its
# flattened, decoded and re-encoded forms.
details = (
    ("prior", [p.round(2) for p in prior]),
    ("post", [p.round(2) for p in qs]),
    ("action", action),
    ("action_multi", action_multi),
    ("action_reconstruct", action_reconstruct),
)
for heading, payload in details:
    print(heading)
    pprint(payload)
A_dependencies [[0]] B_dependencies [[0], [1]] B_action_dependencies [[0, 1], [0]] original control dims [2, 3] flattened control dims [6, 2] original B shapes [(4, 4, 2, 3), (4, 4, 2)] flattened B shapes [(1, 4, 4, 6), (1, 4, 4, 2)] B normalized [Array(True, dtype=bool), Array(True, dtype=bool)] B flat normalized [Array(True, dtype=bool), Array(True, dtype=bool)] prior [Array([[0. , 0.25, 0.25, 0.5 ]], dtype=float32), Array([[0. , 0.25, 0.25, 0.5 ]], dtype=float32)] post [Array([[[0., 0., 1., 0.]]], dtype=float32), Array([[[0. , 0.25, 0.25, 0.5 ]]], dtype=float32)] action Array([[0, 0]], dtype=int32) action_multi Array([[0, 0]], dtype=int32) action_reconstruct Array([[0, 0]], dtype=int32)
No action dependency¶
In this example, some states do not depend on any action.
# Model specification for the "no action dependency" example:
# one observation modality ("o1"), two control factors ("c1", "c2") and two
# hidden-state factors ("s1", "s2"). Here "s1" is controlled by both
# controls, while "s2" has an empty "controlled_by" list.
model_description = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {
        "c1": {"elements": ["up", "down"]},
        "c2": {"elements": ["left", "right", "stay"]},
    },
    "states": {
        "s1": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s1"],
            "controlled_by": ["c1", "c2"],
        },
        "s2": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s2"],
            "controlled_by": [],
        },
    },
}

# Control names in declaration order. Built once here rather than rebuilding
# the key list for every single index lookup below.
control_names = list(model_description["controls"])

# For each state factor, the positions (within the declared controls) of the
# controls listed in its "controlled_by" entry; an uncontrolled state factor
# yields an empty list.
B_action_dependencies = [
    [control_names.index(name) for name in state_spec["controlled_by"]]
    for state_spec in model_description["states"].values()
]

# Number of elements (actions) of each control factor, in declaration order.
num_controls = [
    len(control_spec["elements"])
    for control_spec in model_description["controls"].values()
]
model = distribution.compile_model(model_description)

# initialize tensor values
# Likelihood: pair every "o1" element with the "s1" element of the same label.
for label in ("A", "B", "C", "D"):
    model.A["o1"][label, label] = 1.0

# Transitions, indexed as (next_state, prev_state, *actions):
# A -> B -> C -> D, with D mapping onto itself.
transitions = (("B", "A"), ("C", "B"), ("D", "C"), ("D", "D"))

for factor_idx, factor_spec in enumerate(model_description["states"].values()):
    # Element lists of the controls that drive this state factor.
    action_elements = [
        model_description["controls"][name]["elements"]
        for name in factor_spec["controlled_by"]
    ]
    # With no controlling actions, itertools.product() yields a single empty
    # tuple, so the uncontrolled factor is still filled exactly once.
    for combo in itertools.product(*action_elements):
        for next_state, prev_state in transitions:
            model.B[factor_idx][(next_state, prev_state, *combo)] = 1.0
agent = Agent(
    model.A,
    model.B,
    num_controls=num_controls,
    B_action_dependencies=B_action_dependencies,
)

# dummy history
# The initial action and observation are drawn with jax.random from
# explicit keys split off a fixed seed.
action_key, obs_key = jr.split(jr.PRNGKey(0))

# Pick one of the agent's policies at random.
policy_index = jr.choice(action_key, len(agent.policies))
action = agent.policies[policy_index]

# One random outcome index per observation modality, each of shape (1, 1).
observation = [
    jr.randint(obs_key, shape=(1, 1), minval=0, maxval=num_outcomes)
    for num_outcomes in agent.num_obs
]

# Prepend an axis to every entry of agent.D before passing it as the
# belief history (presumably the time axis -- TODO confirm).
qs_hist = jtu.tree_map(lambda leaf: jnp.expand_dims(leaf, 0), agent.D)

# One step: empirical prior -> state inference -> policy inference -> action.
prior = agent.update_empirical_prior(action, qs_hist)
qs = agent.infer_states(observations=observation, empirical_prior=prior)
q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)

# Round trip: flattened action -> original action dimensions -> flattened.
action_multi = agent.decode_multi_actions(action)
action_reconstruct = agent.encode_multi_actions(action_multi)
# Dependency structure, control dimensions and B tensor diagnostics.
# The normalization check sums over axis 0 for the model tensors and over
# axis 1 for the agent tensors.
summary = (
    ("A_dependencies", agent.A_dependencies),
    ("B_dependencies", agent.B_dependencies),
    ("B_action_dependencies", agent.B_action_dependencies),
    ("original control dims", agent.num_controls_multi),
    ("flattened control dims", agent.num_controls),
    ("original B shapes", [b.data.shape for b in model.B]),
    ("flattened B shapes", [b.shape for b in agent.B]),
    ("B normalized", [jnp.isclose(b.data.sum(0), 1.).all() for b in model.B]),
    ("B flat normalized", [jnp.isclose(b.sum(1), 1.).all() for b in agent.B]),
)
for label, value in summary:
    print(label, value)

print("\n")

# Beliefs (rounded to two decimals) and the sampled action in its
# flattened, decoded and re-encoded forms.
details = (
    ("prior", [p.round(2) for p in prior]),
    ("post", [p.round(2) for p in qs]),
    ("action", action),
    ("action_multi", action_multi),
    ("action_reconstruct", action_reconstruct),
)
for heading, payload in details:
    print(heading)
    pprint(payload)
A_dependencies [[0]] B_dependencies [[0], [1]] B_action_dependencies [[0, 1], []] original control dims [2, 3] flattened control dims [6, 1] original B shapes [(4, 4, 2, 3), (4, 4)] flattened B shapes [(1, 4, 4, 6), (1, 4, 4, 1)] B normalized [Array(True, dtype=bool), Array(True, dtype=bool)] B flat normalized [Array(True, dtype=bool), Array(True, dtype=bool)] prior [Array([[0. , 0.25, 0.25, 0.5 ]], dtype=float32), Array([[0. , 0.25, 0.25, 0.5 ]], dtype=float32)] post [Array([[[0., 0., 1., 0.]]], dtype=float32), Array([[[0. , 0.25, 0.25, 0.5 ]]], dtype=float32)] action Array([[0, 0]], dtype=int32) action_multi Array([[0, 0]], dtype=int32) action_reconstruct Array([[0, 0]], dtype=int32)