pymdp.control

Policy construction, expected free energy, and action sampling utilities.

Policies

Bases: Module

A class for storing an array of policies and their properties.

construct_policies(num_states: Sequence[int], num_controls: Sequence[int] | None = None, policy_len: int = 1, control_fac_idx: Sequence[int] | None = None) -> Array

Generate an exhaustive policy matrix for the specified planning horizon.

Parameters:

Name Type Description Default
num_states Sequence[int]

Dimensionalities of each hidden state factor.

required
num_controls Sequence[int] | None

Dimensionalities of each control state factor. If None, this is computed from controllable state factors.

None
policy_len int

Temporal depth ("planning horizon") of policies.

1
control_fac_idx Sequence[int] | None

Indices of controllable hidden state factors (factors i where num_controls[i] > 1).

None

Returns:

Name Type Description
policies Array

Policy matrix with shape (num_policies, policy_len, num_factors).
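As a rough illustration of what "exhaustive" means here, the NumPy sketch below enumerates every sequence of joint actions over the horizon. The helper construct_policies_sketch is hypothetical and is not pymdp's implementation; it assumes each factor's options are given directly by num_controls (an uncontrollable factor would have a single option, 0), so the number of policies is prod(num_controls) ** policy_len.

```python
import itertools

import numpy as np


def construct_policies_sketch(num_controls, policy_len=1):
    """Hypothetical helper: enumerate all policies as (num_policies, policy_len, num_factors)."""
    # All joint actions at a single timestep: one action index per factor.
    per_step = list(itertools.product(*[range(n) for n in num_controls]))
    # All sequences of policy_len joint actions (Cartesian product over time).
    sequences = itertools.product(per_step, repeat=policy_len)
    return np.array(list(sequences), dtype=int)


# Two factors with 2 and 3 actions, planning two steps ahead: (2 * 3) ** 2 = 36 policies.
policies = construct_policies_sketch(num_controls=[2, 3], policy_len=2)
print(policies.shape)  # (36, 2, 2)
```

The size grows exponentially in policy_len, which is why exhaustive enumeration is usually reserved for short horizons.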

sample_action(policies: Array, num_controls: Sequence[int], q_pi: Array, action_selection: str = 'deterministic', alpha: float = 16.0, rng_key: Array | None = None) -> Array

Samples an action from posterior marginals, one action per control factor.

Parameters:

Name Type Description Default
policies Array

Policy matrix with shape (num_policies, policy_len, num_factors).

required
num_controls Sequence[int]

Dimensionalities of each control state factor.

required
q_pi Array

Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.

required
action_selection str

Selection mode: "deterministic" picks, for each control factor, the action with the highest marginal posterior probability; "stochastic" samples each action from the posterior marginal over actions (scaled by alpha).

'deterministic'
alpha float

Action selection precision -- the inverse temperature of the softmax that is used to scale the action marginals before sampling. This is only used if action_selection is "stochastic".

16.0
rng_key Array | None

PRNG key required when action_selection='stochastic'.

None

Returns:

Name Type Description
selected_policy 1D Array

Vector containing the index of the selected action for each control factor.
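The two selection modes can be sketched in NumPy as follows. This is a hypothetical analogue (sample_action_sketch), not pymdp's code: it assumes the marginals are taken over each policy's first-step action, substitutes a NumPy Generator (rng) for the JAX rng_key, and assumes the alpha scaling takes the form softmax(alpha * log marginal).

```python
import numpy as np


def sample_action_sketch(policies, num_controls, q_pi, action_selection="deterministic",
                         alpha=16.0, rng=None):
    """Hypothetical NumPy analogue of sample_action; rng is a numpy Generator."""
    actions = []
    for f, n in enumerate(num_controls):
        # Marginal posterior over factor f's first-step action (weighted count over policies).
        marginal = np.bincount(policies[:, 0, f], weights=q_pi, minlength=n)
        if action_selection == "deterministic":
            actions.append(int(np.argmax(marginal)))
        elif action_selection == "stochastic":
            # Assumed scaling: softmax(alpha * log marginal), i.e. marginal ** alpha renormalised.
            logits = alpha * np.log(marginal + 1e-16)
            p = np.exp(logits - logits.max())
            actions.append(int(rng.choice(n, p=p / p.sum())))
        else:
            raise ValueError(f"unknown action_selection: {action_selection!r}")
    return np.array(actions)


# One control factor with 3 actions; four one-step policies (policy 1 and 3 share action 1).
policies = np.array([[[0]], [[1]], [[2]], [[1]]])
q_pi = np.array([0.2, 0.3, 0.15, 0.35])
sample_action_sketch(policies, [3], q_pi)  # marginal [0.2, 0.65, 0.15] -> array([1])
```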

sample_policy(policies: Array, q_pi: Array, action_selection: str = 'deterministic', alpha: float = 16.0, rng_key: Array | None = None) -> Array

Select or sample a policy, then return its first-step multi-action.

Parameters:

Name Type Description Default
policies Array

Policy matrix with shape (num_policies, policy_len, num_factors).

required
q_pi Array

Posterior over policies for one batch element.

required
action_selection str

Selection mode for choosing a policy: 'deterministic' or 'stochastic'.

'deterministic'
alpha float

Precision (inverse temperature) used for stochastic sampling.

16.0
rng_key Array | None

PRNG key required for action_selection='stochastic'.

None

Returns:

Type Description
Array

First-step action vector for all control factors.
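A hypothetical NumPy sketch of policy-level selection (sample_policy_sketch, not pymdp's code), again assuming the alpha scaling is softmax(alpha * log q_pi) and using a NumPy Generator in place of rng_key:

```python
import numpy as np


def sample_policy_sketch(policies, q_pi, action_selection="deterministic", alpha=16.0, rng=None):
    """Hypothetical NumPy analogue of sample_policy; rng is a numpy Generator."""
    if action_selection == "deterministic":
        idx = int(np.argmax(q_pi))
    elif action_selection == "stochastic":
        # Assumed scaling: sample a policy index from softmax(alpha * log q_pi).
        logits = alpha * np.log(q_pi + 1e-16)
        p = np.exp(logits - logits.max())
        idx = int(rng.choice(len(q_pi), p=p / p.sum()))
    else:
        raise ValueError(f"unknown action_selection: {action_selection!r}")
    # Return the chosen policy's first-step multi-action (one entry per factor).
    return policies[idx, 0, :]


# Three two-step policies over a single control factor.
policies = np.array([[[0], [1]], [[1], [0]], [[1], [1]]])
q_pi = np.array([0.4, 0.3, 0.3])
sample_policy_sketch(policies, q_pi)  # most probable policy is 0 -> array([0])
```

In this example the single most probable policy starts with action 0, whereas the marginal probability of action 0 is only 0.4 (action 1 gets 0.3 + 0.3 = 0.6), so selecting a whole policy and marginalizing over policies, as sample_action does, can choose different actions.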

get_marginals(q_pi: Array, policies: Array, num_controls: Sequence[int]) -> list[Array]

Computes the marginal posterior over actions for each control factor by summing the posterior probabilities of the policies in which each action appears.

Parameters:

Name Type Description Default
q_pi Array

Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.

required
policies Array

Policy matrix with shape (num_policies, policy_len, num_factors).

required
num_controls Sequence[int]

Dimensionalities of each control state factor.

required

Returns:

Name Type Description
action_marginals list[Array]

Marginal posterior over actions for each control factor.
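For a discrete policy set this marginalization is a weighted count. The sketch below (get_marginals_sketch, hypothetical) assumes the marginals are taken over the first timestep of each policy, i.e. policies[:, 0, f]:

```python
import numpy as np


def get_marginals_sketch(q_pi, policies, num_controls):
    """Hypothetical NumPy analogue of get_marginals over first-step actions."""
    return [
        # Sum q_pi over all policies whose first action on factor f equals each action index.
        np.bincount(policies[:, 0, f], weights=q_pi, minlength=n)
        for f, n in enumerate(num_controls)
    ]


# Four one-step policies over two binary control factors: (0,0), (0,1), (1,0), (1,1).
policies = np.array([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])
q_pi = np.array([0.1, 0.2, 0.3, 0.4])
marginals = get_marginals_sketch(q_pi, policies, num_controls=[2, 2])
# factor 0: [0.1 + 0.2, 0.3 + 0.4] = [0.3, 0.7]; factor 1: [0.1 + 0.3, 0.2 + 0.4] = [0.4, 0.6]
```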

update_posterior_policies(policy_matrix: Array, qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], E: Array, pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], gamma: float = 16.0, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False) -> tuple[Array, Array]

Compute posterior over policies and policy-wise negative expected free energy.

Notes

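The returned policy score is neg_efe = -EFE. In pymdp this quantity is often denoted G; note that much of the active inference literature uses G for the EFE itself, with the opposite sign.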

Parameters:

Name Type Description Default
policy_matrix Array

Policy tensor with shape (num_policies, policy_len, num_factors).

required
qs_init list[Array]

Current marginal beliefs over hidden states.

required
A list[Array]

Observation likelihood tensors.

required
B list[Array]

Transition tensors.

required
C list[Array]

Prior preferences over observations.

required
E Array

Prior over policies.

required
pA list[Array] | None

Optional posterior Dirichlet parameters for A. When use_param_info_gain=True, provide pA, pB, or both.

required
pB list[Array] | None

Optional posterior Dirichlet parameters for B. When use_param_info_gain=True, provide pA, pB, or both.

required
A_dependencies list[list[int]]

Observation dependencies between modalities and hidden-state factors.

required
B_dependencies list[list[int]]

Transition dependencies between hidden-state factors.

required
gamma float

Policy precision (inverse temperature of the softmax over policies).

16.0
use_utility bool

Whether to include expected utility in EFE.

True
use_states_info_gain bool

Whether to include state epistemic value.

True
use_param_info_gain bool

Whether to include parameter epistemic value.

False

Returns:

Type Description
tuple[Array, Array]

(q_pi, neg_efe_all_policies), where q_pi is the posterior over policies and neg_efe_all_policies holds the negative expected free energy of each policy.
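For orientation, the sketch below assumes the common softmax form q_pi = softmax(gamma * neg_efe + log E), in which gamma sharpens the distribution and E acts as a log-prior over policies. The helper policy_posterior_sketch is hypothetical; check the pymdp source for the exact expression it uses.

```python
import numpy as np


def policy_posterior_sketch(neg_efe, E, gamma=16.0):
    """Hypothetical helper assuming q_pi = softmax(gamma * neg_efe + log E)."""
    logits = gamma * np.asarray(neg_efe) + np.log(E)
    logits = logits - logits.max()  # subtract the max for numerical stability
    q_pi = np.exp(logits)
    return q_pi / q_pi.sum()


neg_efe = np.array([-1.0, -2.0, -1.0])
E = np.full(3, 1.0 / 3.0)  # flat prior over the three policies
q_soft = policy_posterior_sketch(neg_efe, E, gamma=1.0)    # policies 0 and 2 tie; policy 1 is less probable
q_sharp = policy_posterior_sketch(neg_efe, E, gamma=16.0)  # almost all mass on policies 0 and 2
```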

update_posterior_policies_inductive(policy_matrix: Array, qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], E: Array, pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], I: list[Array], gamma: float = 16.0, inductive_epsilon: float = 0.001, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False, use_inductive: bool = True) -> tuple[Array, Array]

Compute policy posterior and negative expected free energy with optional inductive terms.

Notes

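The returned policy score is neg_efe = -EFE. In pymdp this quantity is often denoted G; note that much of the active inference literature uses G for the EFE itself, with the opposite sign.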

Parameters:

Name Type Description Default
policy_matrix Array

Policy tensor with shape (num_policies, policy_len, num_factors).

required
qs_init list[Array]

Current marginal state beliefs.

required
A list[Array]

Observation likelihood models.

required
B list[Array]

Transition models.

required
C list[Array]

Prior preference vectors.

required
E Array

Policy prior over the policy space.

required
pA list[Array] | None

Optional posterior Dirichlet parameters for A. When use_param_info_gain=True, provide pA, pB, or both.

required
pB list[Array] | None

Optional posterior Dirichlet parameters for B. When use_param_info_gain=True, provide pA, pB, or both.

required
A_dependencies list[list[int]]

Observation dependencies between modalities and state factors.

required
B_dependencies list[list[int]]

Transition dependencies between hidden-state factors.

required
I list[Array]

Inductive reachability matrices, one per hidden-state factor (for example, as returned by generate_I_matrix).

required
gamma float

Policy precision for softmax policy posterior.

16.0
inductive_epsilon float

Inductive value scale factor.

0.001
use_utility bool

Include utility term in expected free energy.

True
use_states_info_gain bool

Include epistemic state-information gain term.

True
use_param_info_gain bool

Include epistemic parameter-information gain term.

False
use_inductive bool

Include inductive value term.

True

Returns:

Name Type Description
q_pi Array

Posterior over policies.

neg_efe_all_policies Array

Policy-wise negative expected free energies.

compute_expected_state(qs_prior: list[Array], B: list[Array], u_t: Array | Sequence[int], B_dependencies: list[list[int]] | None = None) -> list[Array]

Compute predicted beliefs over the next state, given beliefs about the previous state, the transition model, and the action.

Parameters:

Name Type Description Default
qs_prior list[Array]

Marginal beliefs over hidden states at time t.

required
B list[Array]

Transition model tensors.

required
u_t Array | Sequence[int]

Action indices for each control factor.

required
B_dependencies list[list[int]] | None

Optional dependencies used to marginalize transition tensors. If None, defaults to factor-local transitions.

None

Returns:

Type Description
list[Array]

Marginal beliefs over next-time hidden states.
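In the factor-local case (B_dependencies=None), each factor can be predicted independently by selecting the B slice for its action and applying it to the prior. The sketch below assumes B[f] is indexed as B[f][s, v, u], the convention described for generate_I_matrix; compute_expected_state_sketch is a hypothetical name.

```python
import numpy as np


def compute_expected_state_sketch(qs_prior, B, u_t):
    """Hypothetical factor-local prediction: qs_next[f] = B[f][:, :, u_t[f]] @ qs_prior[f]."""
    return [B_f[:, :, u] @ q for q, B_f, u in zip(qs_prior, B, u_t)]


# One factor with 2 states and 2 actions: action 0 = stay, action 1 = swap.
swap = np.array([[0.0, 1.0], [1.0, 0.0]])
B = [np.stack([np.eye(2), swap], axis=-1)]  # shape (next_state, current_state, action)
qs = [np.array([0.9, 0.1])]
qs_next = compute_expected_state_sketch(qs, B, u_t=[1])  # predicted beliefs: [0.1, 0.9]
```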

compute_expected_state_and_Bs(qs_prior: list[Array], B: list[Array], u_t: Array | Sequence[int]) -> tuple[list[Array], list[Array]]

Compute one-step predictive states and selected transition matrices.

Parameters:

Name Type Description Default
qs_prior list[Array]

Marginal beliefs over hidden states at time t.

required
B list[Array]

Transition model tensors for each hidden-state factor.

required
u_t Array | Sequence[int]

Action indices for each control factor at time t.

required

Returns:

Type Description
tuple[list[Array], list[Array]]

(qs_next, Bs) where qs_next are next-state marginals and Bs are the action-conditioned transition slices used for each factor.

compute_expected_obs(qs: list[Array], A: list[Array], A_dependencies: list[list[int]]) -> list[Array]

Compute the expected observations Q(o|pi) for each modality, accounting for sparse dependencies between observation modalities and hidden-state factors.

Parameters:

Name Type Description Default
qs list[Array]

Beliefs over hidden states.

required
A list[Array]

Observation likelihood models.

required
A_dependencies list[list[int]]

Observation dependencies between modalities and state factors.

required

Returns:

Type Description
list[Array]

Predictive beliefs over observations for each modality.
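With sparse dependencies, each modality only needs the beliefs of the factors it depends on. A minimal NumPy sketch, assuming A[m] has axes (obs, *factors in A_dependencies[m]) in that order; compute_expected_obs_sketch is hypothetical:

```python
import numpy as np


def compute_expected_obs_sketch(qs, A, A_dependencies):
    """Hypothetical sketch: contract each A[m] with the beliefs of its dependent factors."""
    qo = []
    for A_m, deps in zip(A, A_dependencies):
        x = A_m
        # Contract the trailing state axes one at a time, last dependent factor first.
        for f in reversed(deps):
            x = x @ qs[f]
        qo.append(x)
    return qo


qs = [np.array([0.7, 0.3]), np.array([0.2, 0.5, 0.3])]
A0 = np.eye(2)                 # modality 0 depends on factor 0 only
A1 = np.zeros((2, 2, 3))       # modality 1 depends on factors 0 and 1
A1[0, :, 0] = 1.0              # observation 0 whenever factor 1 is in state 0
A1[1, :, 1:] = 1.0             # observation 1 otherwise
qo = compute_expected_obs_sketch(qs, [A0, A1], A_dependencies=[[0], [0, 1]])
# qo[0] = [0.7, 0.3]; qo[1] = [0.2, 0.8]
```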

compute_info_gain(qs: list[Array], qo: list[Array], A: list[Array], A_dependencies: list[list[int]]) -> Array

Compute the expected state-information-gain term of the expected free energy.

Parameters:

Name Type Description Default
qs list[Array]

Predicted hidden-state beliefs.

required
qo list[Array]

Predicted observation beliefs.

required
A list[Array]

Observation likelihood tensors.

required
A_dependencies list[list[int]]

Observation dependencies between modalities and hidden-state factors.

required

Returns:

Type Description
Array

Scalar epistemic value from expected information gain.
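A common way to compute this term is as the mutual information between observations and states, summed over modalities: the entropy of Q(o_m) minus the expected entropy of the likelihood columns under Q(s). The NumPy sketch below follows that decomposition, assuming A[m] has axes (obs, *factors in A_dependencies[m]); compute_info_gain_sketch is hypothetical and not a copy of pymdp's code.

```python
import numpy as np


def entropy(p, axis=0):
    # Shannon entropy in nats along `axis`, with 0 * log 0 treated as 0.
    safe = np.where(p > 0, p, 1.0)
    return -np.sum(p * np.log(safe), axis=axis)


def compute_info_gain_sketch(qs, qo, A, A_dependencies):
    """Hypothetical sketch: sum over m of H[Q(o_m)] - E_Q(s)[H[P(o_m | s)]]."""
    total = 0.0
    for qo_m, A_m, deps in zip(qo, A, A_dependencies):
        # Entropy of each likelihood column P(o_m | s); shape is A_m.shape[1:].
        H_cond = entropy(A_m, axis=0)
        # Average that conditional entropy under the beliefs of the dependent factors.
        for f in reversed(deps):
            H_cond = H_cond @ qs[f]
        total += float(entropy(qo_m) - H_cond)
    return total


qs = [np.array([0.5, 0.5])]
A_precise = [np.eye(2)]           # observations reveal the state exactly
A_flat = [np.full((2, 2), 0.5)]   # observations carry no information about the state
qo_precise = [A_precise[0] @ qs[0]]
qo_flat = [A_flat[0] @ qs[0]]
ig_precise = compute_info_gain_sketch(qs, qo_precise, A_precise, [[0]])  # log(2), about 0.693
ig_flat = compute_info_gain_sketch(qs, qo_flat, A_flat, [[0]])           # 0.0
```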

compute_expected_utility(qo: list[Array], C: list[Array], t: int = 0) -> Array

Compute expected utility from predictive observations and preferences.

Parameters:

Name Type Description Default
qo list[Array]

Predicted observations for each modality.

required
C list[Array]

Prior preferences per modality. Each modality can be static (num_obs,) or time-indexed (policy_len, num_obs).

required
t int

Planning timestep used when C[m] is time-indexed.

0

Returns:

Type Description
Array

Scalar expected utility contribution.
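A sketch of the utility term, assuming C[m] already holds log-preferences (higher means more preferred) and that a time-indexed C[m] is selected by row t; compute_expected_utility_sketch is hypothetical:

```python
import numpy as np


def compute_expected_utility_sketch(qo, C, t=0):
    """Hypothetical sketch: sum over modalities of <qo[m], C[m]> (row t if C[m] is 2-D)."""
    utility = 0.0
    for qo_m, C_m in zip(qo, C):
        C_t = C_m[t] if C_m.ndim > 1 else C_m
        utility += float(qo_m @ C_t)
    return utility


qo = [np.array([0.25, 0.75])]
C_static = [np.array([0.0, 2.0])]
C_timed = [np.array([[0.0, 2.0], [1.0, 0.0]])]  # shape (policy_len=2, num_obs=2)
u_static = compute_expected_utility_sketch(qo, C_static)     # 0.25*0 + 0.75*2 = 1.5
u_timed = compute_expected_utility_sketch(qo, C_timed, t=1)  # 0.25*1 + 0.75*0 = 0.25
```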

calc_negative_pA_info_gain(pA: list[Array], qo: list[Array], qs: list[Array], A_dependencies: list[list[int]]) -> Array

Compute the negative expected Dirichlet information gain about pA.

Notes

This helper returns the negative of the parameter epistemic-value term. Subtract its return value when adding parameter information gain to neg_efe.

Parameters:

Name Type Description Default
pA list[Array]

Dirichlet parameters over observation model (same shape as A).

required
qo list[Array]

Predictive beliefs over observations expected under the policy at some time t.

required
qs list[Array]

Predictive beliefs over hidden states expected under the policy at the same time t.

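required
A_dependencies list[list[int]]

Observation dependencies between modalities and hidden-state factors.

required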

Returns:

Name Type Description
neg_infogain_pA Array

Negative expected information gain (scalar JAX array) for the pair of predictive distributions qo and qs.

calc_negative_pB_info_gain(pB: list[Array], qs_t: list[Array], qs_t_minus_1: list[Array], B_dependencies: list[list[int]], u_t_minus_1: Array | Sequence[int]) -> Array

Compute the negative expected Dirichlet information gain about pB.

Notes

This helper returns the negative of the parameter epistemic-value term. Subtract its return value when adding parameter information gain to neg_efe.

Parameters:

Name Type Description Default
pB list[Array]

Dirichlet parameters over transition model (same shape as B).

required
qs_t list[Array]

Predictive posterior beliefs over hidden states expected under the policy at time t.

required
qs_t_minus_1 list[Array]

Posterior over hidden states at time t-1 (before receiving observations).

required
B_dependencies list[list[int]]

For each state factor, indices of the state factors that its transition model depends on.

required
u_t_minus_1 Array | Sequence[int]

Action indices at time step t-1, one per control factor.

required

Returns:

Name Type Description
neg_infogain_pB Array

Negative expected information gain (scalar JAX array) under the policy in question.

compute_neg_efe_policy(qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], policy_i: Array, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False) -> Array

Compute policy-wise negative expected free energy for one policy.

Notes

This function computes neg_efe = -EFE for a single policy. In pymdp this score is commonly denoted G; note that much of the active inference literature uses G for the EFE itself, with the opposite sign.

Parameters:

Name Type Description Default
qs_init list[Array]

Initial hidden-state marginals at the current timestep.

required
A list[Array]

Observation likelihood tensors.

required
B list[Array]

Transition tensors.

required
C list[Array]

Prior preferences over observations.

required
pA list[Array] | None

Optional posterior Dirichlet parameters for A. When use_param_info_gain=True, provide pA, pB, or both.

required
pB list[Array] | None

Optional posterior Dirichlet parameters for B. When use_param_info_gain=True, provide pA, pB, or both.

required
A_dependencies list[list[int]]

Observation dependencies between modalities and hidden-state factors.

required
B_dependencies list[list[int]]

Transition dependencies between hidden-state factors.

required
policy_i Array

Single policy trajectory with shape (policy_len, num_factors).

required
use_utility bool

Include expected utility term.

True
use_states_info_gain bool

Include state-information-gain term.

True
use_param_info_gain bool

Include parameter-information-gain term.

False

Returns:

Type Description
Array

Scalar negative expected free energy for policy_i.
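Putting the pieces together, a single policy's score can be accumulated by rolling beliefs forward one step at a time and adding the utility and state-information-gain terms at each step. This is a hypothetical, simplified sketch (neg_efe_policy_sketch, not pymdp's implementation): it assumes factor-local transitions indexed as B[f][s, v, u], A[m] with axes (obs, *dependent factors), static C vectors holding log-preferences, and no parameter information gain.

```python
import numpy as np


def _entropy(p, axis=0):
    # Shannon entropy in nats, with 0 * log 0 treated as 0.
    safe = np.where(p > 0, p, 1.0)
    return -np.sum(p * np.log(safe), axis=axis)


def _contract(x, qs, deps):
    # Contract the trailing state axes of x with the beliefs of the dependent factors.
    for f in reversed(deps):
        x = x @ qs[f]
    return x


def neg_efe_policy_sketch(qs_init, A, B, C, A_dependencies, policy_i,
                          use_utility=True, use_states_info_gain=True):
    """Hypothetical sketch: accumulate -EFE for one policy of shape (policy_len, num_factors)."""
    qs = list(qs_init)
    neg_efe = 0.0
    for u_t in policy_i:
        # 1. Roll beliefs forward one step under this timestep's actions.
        qs = [B_f[:, :, u] @ q for q, B_f, u in zip(qs, B, u_t)]
        for A_m, C_m, deps in zip(A, C, A_dependencies):
            # 2. Predicted observations for modality m.
            qo_m = _contract(A_m, qs, deps)
            if use_utility:
                # 3. Pragmatic value: expected log-preference.
                neg_efe += float(qo_m @ C_m)
            if use_states_info_gain:
                # 4. Epistemic value: H[Q(o)] - E_Q(s)[H[P(o|s)]].
                neg_efe += float(_entropy(qo_m) - _contract(_entropy(A_m), qs, deps))
    return neg_efe


swap = np.array([[0.0, 1.0], [1.0, 0.0]])
B = [np.stack([np.eye(2), swap], axis=-1)]  # action 0 = stay, action 1 = swap
A = [np.eye(2)]                             # observations reveal the state
C = [np.array([0.0, 3.0])]                  # log-preference for observing state 1
qs_init = [np.array([1.0, 0.0])]
score_stay = neg_efe_policy_sketch(qs_init, A, B, C, [[0]], np.array([[0]]))
score_swap = neg_efe_policy_sketch(qs_init, A, B, C, [[0]], np.array([[1]]))
# score_stay is 0.0 and score_swap is 3.0: swapping reaches the preferred observation
```

Because the beliefs here are already certain, both epistemic terms vanish and the comparison is driven entirely by the utility term.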

compute_neg_efe_policy_inductive(qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], I: list[Array], policy_i: Array, inductive_epsilon: float = 0.001, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False, use_inductive: bool = False) -> Array

Compute policy-wise negative expected free energy with inductive planning.

Notes

This function computes neg_efe = -EFE for a single policy, with optional inductive-value terms. In pymdp this score is commonly denoted G (so here G = neg_efe = -EFE); much of the active inference literature instead uses G for the EFE itself.

Parameters:

Name Type Description Default
qs_init list[Array]

Initial hidden-state marginals at the current timestep.

required
A list[Array]

Observation likelihood tensors.

required
B list[Array]

Transition tensors.

required
C list[Array]

Prior preferences over observations.

required
pA list[Array] | None

Optional posterior Dirichlet parameters for A. When use_param_info_gain=True, provide pA, pB, or both.

required
pB list[Array] | None

Optional posterior Dirichlet parameters for B. When use_param_info_gain=True, provide pA, pB, or both.

required
A_dependencies list[list[int]]

Observation dependencies between modalities and hidden-state factors.

required
B_dependencies list[list[int]]

Transition dependencies between hidden-state factors.

required
I list[Array]

Inductive reachability matrices.

required
policy_i Array

Single policy trajectory with shape (policy_len, num_factors).

required
inductive_epsilon float

Scale of the inductive-value contribution.

0.001
use_utility bool

Include expected utility term.

True
use_states_info_gain bool

Include state-information-gain term.

True
use_param_info_gain bool

Include parameter-information-gain term.

False
use_inductive bool

Include inductive-value term.

False

Returns:

Type Description
Array

Scalar negative expected free energy for policy_i.

generate_I_matrix(H: list[Array], B: list[Array], threshold: float, depth: int) -> list[Array]

Generate inductive reachability matrices using backward state reachability.

These matrices store whether state j (columns) can still reach the intended state set after i backward steps (rows).

Parameters:

Name Type Description Default
H list[Array]

For each state factor, a binary vector of constraints over desired states: 1 if the state should be reached, 0 otherwise.

required
B list[Array]

Dynamics likelihood mapping or transition model, mapping from hidden states at t to hidden states at t+1, given some control state u. Each element B[f] stores a 3-D tensor for hidden state factor f, whose entries B[f][s, v, u] store the probability of hidden state level s at the current time, given hidden state level v and action u at the previous time.

required
threshold float

Probability threshold below which transitions are pruned (treated as impossible).

required
depth int

The temporal depth of the backward induction

required

Returns:

Name Type Description
I list[Array]

For each state factor, contains a 2D indicator array whose element i, j is 1 when state j can still reach the intended state set after i backward steps, and 0 otherwise.
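One way to build these indicators is a backward sweep over a thresholded reachability graph. The sketch below (generate_I_matrix_sketch, hypothetical, not pymdp's code) marks state j in row i when some sequence of exactly i above-threshold transitions leads into the goal set. It assumes each B[f] has the three axes (s, v, u) described above; when goal states have a "stay" action, as in the example, "exactly i steps" coincides with "within i steps".

```python
import numpy as np


def generate_I_matrix_sketch(H, B, threshold, depth):
    """Hypothetical sketch: one (depth, num_states) 0/1 array per factor."""
    I = []
    for H_f, B_f in zip(H, B):
        # reachable[s, v] = 1 if some action moves v -> s with probability above threshold.
        reachable = (B_f > threshold).any(axis=-1).astype(int)
        rows = [(H_f > 0).astype(int)]
        for _ in range(depth - 1):
            # State v is marked if it can step into any state marked in the previous row.
            rows.append((reachable.T @ rows[-1] > 0).astype(int))
        I.append(np.stack(rows).astype(float))
    return I


# Three states in a chain 0 -> 1 -> 2, with actions "stay" (0) and "forward" (1).
stay = np.eye(3)
forward = np.array([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 1.0, 1.0]])  # forward[s, v] = P(s | v, forward)
B = [np.stack([stay, forward], axis=-1)]
H = [np.array([0.0, 0.0, 1.0])]        # goal: state 2
I = generate_I_matrix_sketch(H, B, threshold=0.5, depth=3)
# I[0] rows: [0, 0, 1], [0, 1, 1], [1, 1, 1]
```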

calc_inductive_value_t(qs: list[Array], qs_next: list[Array], I: list[Array], epsilon: float = 0.001) -> Array

Computes the inductive value of a state at a particular time (translation of @tverbele's numpy implementation of inductive planning, formerly called calc_inductive_cost).

Parameters:

Name Type Description Default
qs list[Array]

Marginal posterior beliefs over hidden states at a given timepoint.

required
qs_next list[Array]

Predictive posterior beliefs over hidden states expected under the policy.

required
I list[Array]

For each state factor, a 2D indicator array (as returned by generate_I_matrix) whose element i, j is 1 when state j can still reach the goal state set after i backward steps, and 0 otherwise.

required
epsilon float

Scale factor controlling how strongly the inductive value contributes to the expected free energy of policies.

0.001

Returns:

Name Type Description
inductive_val Array

Scalar inductive value (the negative inductive cost) of visiting this state, computed by backward induction under the policy in question.