pymdp.control
Policy construction, expected free energy, and action sampling utilities.
Policies
Bases: Module
A class for storing an array of policies and their properties.
construct_policies(num_states: Sequence[int], num_controls: Sequence[int] | None = None, policy_len: int = 1, control_fac_idx: Sequence[int] | None = None) -> Array
Generate an exhaustive policy matrix for the specified planning horizon.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_states` | `Sequence[int]` | Dimensionalities of each hidden state factor. | *required* |
| `num_controls` | `Sequence[int] \| None` | Dimensionalities of each control state factor. If | `None` |
| `policy_len` | `int` | Temporal depth ("planning horizon") of policies. | `1` |
| `control_fac_idx` | `Sequence[int] \| None` | Indices of controllable hidden state factors (factors | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `policies` | `Array` | Policy matrix with shape |
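For intuition, the exhaustive enumeration can be sketched in plain Python/NumPy. This is not pymdp's implementation (which uses JAX): the helper name `enumerate_policies` is hypothetical, and the output layout `(num_policies, policy_len, num_factors)` is an assumption made for illustration. In this sketch an uncontrollable factor is simply given a single control state.

```python
import itertools

import numpy as np


def enumerate_policies(num_controls, policy_len=1):
    """Hypothetical helper: exhaustively enumerate multi-factor policies.

    Returns an int array of assumed shape (num_policies, policy_len, num_factors).
    """
    # Every joint action available at one timestep (one index per control factor)
    per_step = list(itertools.product(*[range(n) for n in num_controls]))
    # Every sequence of joint actions over the planning horizon
    sequences = itertools.product(per_step, repeat=policy_len)
    return np.array([list(seq) for seq in sequences], dtype=int)


policies = enumerate_policies(num_controls=[2, 3], policy_len=2)
print(policies.shape)  # (36, 2, 2)
```

The number of policies is the product of `num_controls` raised to the power `policy_len`, so exhaustive enumeration is only practical for small action spaces and short horizons.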
sample_action(policies: Array, num_controls: Sequence[int], q_pi: Array, action_selection: str = 'deterministic', alpha: float = 16.0, rng_key: Array | None = None) -> Array
Samples an action from posterior marginals, one action per control factor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policies` | `Array` | Policy matrix with shape | *required* |
| `num_controls` | `Sequence[int]` | Dimensionalities of each control state factor. | *required* |
| `q_pi` | `Array` | Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy. | *required* |
| `action_selection` | `str` | Whether the selected action is the mode of the posterior marginal over actions (`'deterministic'`) or is sampled from that marginal (`'stochastic'`). | `'deterministic'` |
| `alpha` | `float` | Action selection precision: the inverse temperature of the softmax used to scale the action marginals before sampling. This is only used if | `16.0` |
| `rng_key` | `Array \| None` | PRNG key required when | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `selected_policy` | 1D `Array` | Vector containing the index of the selected action for each control factor. |
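The two selection modes can be illustrated with a NumPy sketch that takes per-factor action marginals as input. It assumes that stochastic selection samples from a softmax of `alpha`-scaled log marginals, and it uses a NumPy `Generator` in place of a JAX PRNG key; `select_actions` is a hypothetical helper, not part of pymdp.

```python
import numpy as np


def select_actions(action_marginals, action_selection="deterministic", alpha=16.0, rng=None):
    """Hypothetical helper: choose one action index per control factor."""
    actions = []
    for p in action_marginals:
        p = np.asarray(p, dtype=float)
        if action_selection == "deterministic":
            # Mode of this factor's marginal posterior over actions
            actions.append(int(np.argmax(p)))
        elif action_selection == "stochastic":
            if rng is None:
                raise ValueError("stochastic selection needs a random generator")
            # softmax(alpha * log p): a larger alpha sharpens the distribution
            logits = alpha * np.log(p + 1e-16)
            probs = np.exp(logits - logits.max())
            probs /= probs.sum()
            actions.append(int(rng.choice(len(probs), p=probs)))
        else:
            raise ValueError(f"unknown action_selection: {action_selection!r}")
    return np.array(actions)


marginals = [np.array([0.2, 0.8]), np.array([0.5, 0.3, 0.2])]
print(select_actions(marginals))  # [1 0]
print(select_actions(marginals, "stochastic", rng=np.random.default_rng(0)))
```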
sample_policy(policies: Array, q_pi: Array, action_selection: str = 'deterministic', alpha: float = 16.0, rng_key: Array | None = None) -> Array
Select or sample a policy, then return its first-step multi-action.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policies` | `Array` | Policy matrix with shape | *required* |
| `q_pi` | `Array` | Posterior over policies for one batch element. | *required* |
| `action_selection` | `str` (`'deterministic'` or `'stochastic'`) | Selection mode for choosing a policy. | `'deterministic'` |
| `alpha` | `float` | Precision (inverse temperature) used for stochastic sampling. | `16.0` |
| `rng_key` | `Array \| None` | PRNG key required for | `None` |
Returns:
| Type | Description |
|---|---|
| `Array` | First-step action vector for all control factors. |
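A NumPy sketch of "select a policy, then take its first step" follows. It assumes `policies` has shape `(num_policies, policy_len, num_factors)` and that stochastic selection samples from a softmax of `alpha`-scaled log policy probabilities; the helper name is hypothetical and a NumPy `Generator` replaces the JAX PRNG key.

```python
import numpy as np


def first_action_of_selected_policy(policies, q_pi, action_selection="deterministic",
                                    alpha=16.0, rng=None):
    """Hypothetical helper: pick one policy, then return its first-step actions."""
    q_pi = np.asarray(q_pi, dtype=float)
    if action_selection == "deterministic":
        idx = int(np.argmax(q_pi))                # most probable policy
    elif action_selection == "stochastic":
        if rng is None:
            raise ValueError("stochastic selection needs a random generator")
        logits = alpha * np.log(q_pi + 1e-16)     # precision-weighted log posterior
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        idx = int(rng.choice(len(probs), p=probs))
    else:
        raise ValueError(f"unknown action_selection: {action_selection!r}")
    return policies[idx, 0, :]                    # one action per control factor


policies = np.array([[[0, 1], [1, 1]],
                     [[1, 0], [0, 0]]])           # 2 policies, horizon 2, 2 factors
q_pi = np.array([0.3, 0.7])
print(first_action_of_selected_policy(policies, q_pi))  # [1 0]
```

Unlike choosing each factor's action from its own marginal, this route always returns a joint action that belongs to a single policy.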
get_marginals(q_pi: Array, policies: Array, num_controls: Sequence[int]) -> list[Array]
Computes the marginal posterior(s) over actions by integrating their posterior probability under the policies that they appear within.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `q_pi` | `Array` | Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy. | *required* |
| `policies` | `Array` | Policy matrix with shape | *required* |
| `num_controls` | `Sequence[int]` | Dimensionalities of each control state factor. | *required* |
Returns:
| Name | Type | Description |
|---|---|---|
| `action_marginals` | `list[Array]` | Marginal posterior over actions for each control factor. |
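The marginalization can be sketched in NumPy by summing the probability of every policy whose first step uses a given action. The sketch assumes the `(num_policies, policy_len, num_factors)` layout and that marginals are taken over the first timestep; `action_marginals` is a hypothetical name.

```python
import numpy as np


def action_marginals(q_pi, policies, num_controls):
    """Hypothetical helper: marginal posterior over first-step actions, per factor."""
    q_pi = np.asarray(q_pi, dtype=float)
    marginals = []
    for f, n in enumerate(num_controls):
        # Action each policy takes on factor f at the first timestep
        first_actions = policies[:, 0, f]
        # Sum q(pi) over all policies that contain each action
        marginals.append(np.bincount(first_actions, weights=q_pi, minlength=n))
    return marginals


policies = np.array([[[0, 1]], [[1, 1]], [[1, 0]]])  # 3 policies, horizon 1, 2 factors
q_pi = np.array([0.5, 0.3, 0.2])
print(action_marginals(q_pi, policies, num_controls=[2, 2]))
# approximately [0.5, 0.5] for factor 0 and [0.2, 0.8] for factor 1
```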
update_posterior_policies(policy_matrix: Array, qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], E: Array, pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], gamma: float = 16.0, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False) -> tuple[Array, Array]
Compute posterior over policies and policy-wise negative expected free energy.
Notes
The returned policy score is neg_efe = -EFE. This same quantity is
often denoted by G.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policy_matrix` | `Array` | Policy tensor with shape | *required* |
| `qs_init` | `list[Array]` | Current marginal beliefs over hidden states. | *required* |
| `A` | `list[Array]` | Observation likelihood tensors. | *required* |
| `B` | `list[Array]` | Transition tensors. | *required* |
| `C` | `list[Array]` | Prior preferences over observations. | *required* |
| `E` | `Array` | Prior over policies. | *required* |
| `pA` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `pB` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and hidden-state factors. | *required* |
| `B_dependencies` | `list[list[int]]` | Transition dependencies between hidden-state factors. | *required* |
| `gamma` | `float` | Policy precision parameter. | `16.0` |
| `use_utility` | `bool` | Whether to include expected utility in EFE. | `True` |
| `use_states_info_gain` | `bool` | Whether to include state epistemic value. | `True` |
| `use_param_info_gain` | `bool` | Whether to include parameter epistemic value. | `False` |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, Array]` | Posterior over policies and policy-wise negative expected free energies. |
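The policy posterior combines the policy scores with the prior `E`. A form common in active inference, assumed here rather than taken from the pymdp source, is q(pi) = softmax(gamma * neg_efe + ln E). A NumPy sketch with the hypothetical name `policy_posterior`:

```python
import numpy as np


def policy_posterior(neg_efe, E, gamma=16.0):
    """Hypothetical helper: q(pi) = softmax(gamma * neg_efe + ln E)."""
    logits = gamma * np.asarray(neg_efe, dtype=float) + np.log(np.asarray(E, dtype=float) + 1e-16)
    logits -= logits.max()                    # subtract the max for numerical stability
    q_pi = np.exp(logits)
    return q_pi / q_pi.sum()


neg_efe = np.array([-1.0, -1.5, -3.0])        # higher (less negative) is better
E = np.ones(3) / 3                            # flat prior over policies
print(policy_posterior(neg_efe, E, gamma=1.0))
print(policy_posterior(neg_efe, E, gamma=16.0))
```

Because `neg_efe` is the negative EFE, policies with higher scores receive more probability mass, and a larger `gamma` makes the posterior more deterministic.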
update_posterior_policies_inductive(policy_matrix: Array, qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], E: Array, pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], I: list[Array], gamma: float = 16.0, inductive_epsilon: float = 0.001, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False, use_inductive: bool = True) -> tuple[Array, Array]
Compute policy posterior and negative expected free energy with optional inductive terms.
Notes
The returned policy score is neg_efe = -EFE. This same quantity is
often denoted by G.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policy_matrix` | `Array` | Policy tensor with shape | *required* |
| `qs_init` | `list[Array]` | Current marginal state beliefs. | *required* |
| `A` | `list[Array]` | Observation likelihood models. | *required* |
| `B` | `list[Array]` | Transition models. | *required* |
| `C` | `list[Array]` | Prior preference vectors. | *required* |
| `E` | `Array` | Policy prior over the policy space. | *required* |
| `pA` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `pB` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and state factors. | *required* |
| `B_dependencies` | `list[list[int]]` | Transition dependencies between hidden-state factors and control factors. | *required* |
| `I` | `list[Array]` | Inductive planning matrices. | *required* |
| `gamma` | `float` | Policy precision for softmax policy posterior. | `16.0` |
| `inductive_epsilon` | `float` | Inductive value scale factor. | `0.001` |
| `use_utility` | `bool` | Include utility term in expected free energy. | `True` |
| `use_states_info_gain` | `bool` | Include epistemic state-information gain term. | `True` |
| `use_param_info_gain` | `bool` | Include epistemic parameter-information gain term. | `False` |
| `use_inductive` | `bool` | Include inductive value term. | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `q_pi` | `Array` | Posterior over policies. |
| `neg_efe_all_policies` | `Array` | Policy-wise negative expected free energies. |
compute_expected_state(qs_prior: list[Array], B: list[Array], u_t: Array | Sequence[int], B_dependencies: list[list[int]] | None = None) -> list[Array]
Compute posterior over next state, given belief about previous state, transition model and action...
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs_prior` | `list[Array]` | Marginal beliefs over hidden states at time | *required* |
| `B` | `list[Array]` | Transition model tensors. | *required* |
| `u_t` | `Array \| Sequence[int]` | Action indices for each control factor. | *required* |
| `B_dependencies` | `list[list[int]] \| None` | Optional dependencies used to marginalize transition tensors. If | `None` |
Returns:
| Type | Description |
|---|---|
| `list[Array]` | Marginal beliefs over hidden states at the next timestep. |
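The one-step prediction can be sketched in NumPy by slicing out the chosen action and then marginalizing each parent factor in turn. The tensor layout is an assumption for illustration: `B[f]` has shape `(S_f, S_parent_1, ..., S_parent_k, U_f)` with parents listed in `B_dependencies[f]`, and by default each factor depends only on itself. `predict_states` is a hypothetical name.

```python
import numpy as np


def predict_states(qs_prior, B, u_t, B_dependencies=None):
    """Hypothetical helper: one-step predictive marginals for each hidden-state factor."""
    if B_dependencies is None:
        B_dependencies = [[f] for f in range(len(B))]
    qs_next = []
    for f, (B_f, u) in enumerate(zip(B, u_t)):
        T = np.asarray(B_f)[..., int(u)]      # transition tensor for the chosen action
        for parent in reversed(B_dependencies[f]):
            T = T @ qs_prior[parent]          # marginalize the trailing parent axis
        qs_next.append(T)
    return qs_next


B0 = np.zeros((2, 2, 2))
B0[:, :, 0] = np.eye(2)                           # action 0: stay
B0[:, :, 1] = np.array([[0.0, 1.0], [1.0, 0.0]])  # action 1: switch
print(predict_states([np.array([0.9, 0.1])], [B0], u_t=[1]))  # [array([0.1, 0.9])]
```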
compute_expected_state_and_Bs(qs_prior: list[Array], B: list[Array], u_t: Array | Sequence[int]) -> tuple[list[Array], list[Array]]
Compute one-step predictive states and selected transition matrices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs_prior` | `list[Array]` | Marginal beliefs over hidden states at time | *required* |
| `B` | `list[Array]` | Transition model tensors for each hidden-state factor. | *required* |
| `u_t` | `Array \| Sequence[int]` | Action indices for each control factor at time | *required* |
Returns:
| Type | Description |
|---|---|
| `tuple[list[Array], list[Array]]` | One-step predictive state marginals and the transition matrices selected by `u_t`. |
compute_expected_obs(qs: list[Array], A: list[Array], A_dependencies: list[list[int]]) -> list[Array]
Compute the expected observations Q(o|pi), taking into account sparse dependencies between observation modalities and hidden-state factors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs` | `list[Array]` | Beliefs over hidden states. | *required* |
| `A` | `list[Array]` | Observation likelihood models. | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and state factors. | *required* |
Returns:
| Type | Description |
|---|---|
| `list[Array]` | Predictive beliefs over observations for each modality. |
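Under factorized state beliefs, Q(o_m) can be obtained by contracting `A[m]` with the beliefs of only those factors listed in `A_dependencies[m]`. The NumPy sketch below assumes `A[m]` has shape `(O_m, S_dep_1, ..., S_dep_k)`; `predict_obs` is a hypothetical name.

```python
import numpy as np


def predict_obs(qs, A, A_dependencies):
    """Hypothetical helper: predictive distribution over each modality's observations."""
    qo = []
    for A_m, deps in zip(A, A_dependencies):
        T = np.asarray(A_m)
        for f in reversed(deps):
            T = T @ qs[f]                     # marginalize one hidden-state factor at a time
        qo.append(T)
    return qo


qs = [np.array([0.5, 0.5]), np.array([1.0, 0.0, 0.0])]
A0 = np.array([[0.9, 0.2],
               [0.1, 0.8]])                   # modality 0 depends only on factor 0
A1 = np.ones((2, 2, 3)) / 2                   # modality 1 depends on factors 0 and 1 (uninformative)
qo = predict_obs(qs, [A0, A1], A_dependencies=[[0], [0, 1]])
print(qo)  # modality 0 is approximately [0.55, 0.45]; modality 1 is [0.5, 0.5]
```

Contracting only the listed factors avoids building the full joint state space for every modality.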
compute_info_gain(qs: list[Array], qo: list[Array], A: list[Array], A_dependencies: list[list[int]]) -> Array
Compute expected state-information gain term of expected free energy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs` | `list[Array]` | Predicted hidden-state beliefs. | *required* |
| `qo` | `list[Array]` | Predicted observation beliefs. | *required* |
| `A` | `list[Array]` | Observation likelihood tensors. | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and hidden-state factors. | *required* |
Returns:
| Type | Description |
|---|---|
| `Array` | Scalar epistemic value from expected information gain. |
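The state-information gain is commonly written as the entropy of the predicted observations minus the expected entropy of the likelihood (the "ambiguity"), summed over modalities. A NumPy sketch under the same assumed `A[m]` layout `(O_m, S_dep_1, ..., S_dep_k)`, with hypothetical helper names:

```python
import numpy as np


def entropy(p, axis=0):
    """Shannon entropy in nats along `axis`; the 1e-16 guards against log(0)."""
    p = np.asarray(p, dtype=float)
    return -(p * np.log(p + 1e-16)).sum(axis=axis)


def state_info_gain(qs, qo, A, A_dependencies):
    """Hypothetical helper: sum over m of H[Q(o_m)] - E_{Q(s)} H[P(o_m | s)]."""
    total = 0.0
    for A_m, qo_m, deps in zip(A, qo, A_dependencies):
        expected_ambiguity = entropy(A_m, axis=0)  # one entropy per joint parent state
        for f in reversed(deps):
            expected_ambiguity = expected_ambiguity @ qs[f]
        total += float(entropy(qo_m) - expected_ambiguity)
    return total


qs = [np.array([0.5, 0.5])]
precise, ambiguous = np.eye(2), np.full((2, 2), 0.5)
print(state_info_gain(qs, [precise @ qs[0]], [precise], [[0]]))      # approximately ln 2
print(state_info_gain(qs, [ambiguous @ qs[0]], [ambiguous], [[0]]))  # approximately 0
```

A precise likelihood resolves all uncertainty about the state, while a completely ambiguous one yields no information gain.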
compute_expected_utility(qo: list[Array], C: list[Array], t: int = 0) -> Array
Compute expected utility from predictive observations and preferences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qo` | `list[Array]` | Predicted observations for each modality. | *required* |
| `C` | `list[Array]` | Prior preferences per modality. Each modality can be static | *required* |
| `t` | `int` | Planning timestep used when | `0` |
Returns:
| Type | Description |
|---|---|
| `Array` | Scalar expected utility contribution. |
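Expected utility can be sketched as the sum over modalities of the dot product between predicted observations and preferences. Two assumptions are made here and are not confirmed by the text above: `C[m]` holds log-preferences used directly, and a time-varying `C[m]` is a 2D array with time on axis 0, indexed by `t`. `expected_utility` is a hypothetical name.

```python
import numpy as np


def expected_utility(qo, C, t=0):
    """Hypothetical helper: sum over modalities of Q(o_m) . C_m."""
    util = 0.0
    for qo_m, C_m in zip(qo, C):
        C_m = np.asarray(C_m, dtype=float)
        # Assumed layout: a 2D C_m stores one preference vector per planning step (axis 0)
        C_t = C_m[t] if C_m.ndim > 1 else C_m
        util += float(np.dot(qo_m, C_t))
    return util


qo = [np.array([0.7, 0.3])]
print(expected_utility(qo, [np.array([2.0, -2.0])]))   # approximately 0.8
C_time = np.array([[0.0, 0.0], [2.0, -2.0]])           # preferences switch on at t=1
print(expected_utility(qo, [C_time], t=0), expected_utility(qo, [C_time], t=1))
```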
calc_negative_pA_info_gain(pA: list[Array], qo: list[Array], qs: list[Array], A_dependencies: list[list[int]]) -> Array
Compute the negative expected Dirichlet information gain about pA.
Notes
This helper returns the negative of the parameter epistemic-value term.
Subtract its return value when adding parameter information gain to
neg_efe.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pA` | `list[Array]` | Dirichlet parameters over observation model (same shape as | *required* |
| `qo` | `list[Array]` | Predictive posterior beliefs over observations; stores the beliefs about observations expected under the policy at some arbitrary time | *required* |
| `qs` | `list[Array]` | Predictive posterior beliefs over hidden states; stores the beliefs about hidden states expected under the policy at some arbitrary time | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and hidden-state factors. | *required* |
Returns:
| Name | Type | Description |
|---|---|---|
| `neg_infogain_pA` | `Array` | Negative expected information gain (scalar JAX array) for the pair of predictive distributions |
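A NumPy sketch of one way to compute this novelty term follows. The weighting `W = 1/sum(a) - 1/a` over the Dirichlet counts `a`, the masking of zero counts, and the helper name are all assumptions for illustration; pymdp's JAX implementation may differ in constants or details. Since `W` is non-positive, the returned value is non-positive, consistent with the "negative information gain" convention in the Notes above.

```python
import numpy as np


def neg_pA_info_gain(pA, qo, qs, A_dependencies, eps=1e-16):
    """Hypothetical helper: negative expected information gain about Dirichlet counts pA."""
    total = 0.0
    for a, qo_m, deps in zip(pA, qo, A_dependencies):
        a = np.asarray(a, dtype=float)
        # W <= 0 everywhere; |W| is large where counts are small (novel mappings)
        W = 1.0 / (a.sum(axis=0, keepdims=True) + eps) - 1.0 / (a + eps)
        W = W * (a > 0)                       # ignore structurally zero entries
        for f in reversed(deps):
            W = W @ qs[f]                     # expectation under predicted states
        total += float(qo_m @ W)              # expectation under predicted observations
    return total


qs = [np.array([0.5, 0.5])]
qo = [np.array([0.5, 0.5])]
few = neg_pA_info_gain([np.ones((2, 2))], qo, qs, [[0]])          # approximately -0.5
many = neg_pA_info_gain([100 * np.ones((2, 2))], qo, qs, [[0]])   # approximately -0.005
print(few, many)
```

Mappings backed by few counts give a larger information gain, so subtracting this value from `neg_efe` favours policies that visit them.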
calc_negative_pB_info_gain(pB: list[Array], qs_t: list[Array], qs_t_minus_1: list[Array], B_dependencies: list[list[int]], u_t_minus_1: Array | Sequence[int]) -> Array
Compute the negative expected Dirichlet information gain about pB.
Notes
This helper returns the negative of the parameter epistemic-value term.
Subtract its return value when adding parameter information gain to
neg_efe.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pB` | `list[Array]` | Dirichlet parameters over transition model (same shape as | *required* |
| `qs_t` | `list[Array]` | Predictive posterior beliefs over hidden states expected under the policy at time | *required* |
| `qs_t_minus_1` | `list[Array]` | Posterior over hidden states at time | *required* |
| `B_dependencies` | `list[list[int]]` | For each state factor, indices of the state factors that its transition model depends on. | *required* |
| `u_t_minus_1` | `Array \| Sequence[int]` | Actions taken at time step t-1, one per control factor. | *required* |
Returns:
| Name | Type | Description |
|---|---|---|
| `neg_infogain_pB` | `Array` | Negative expected information gain (scalar JAX array) under the policy in question. |
compute_neg_efe_policy(qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], policy_i: Array, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False) -> Array
Compute policy-wise negative expected free energy for one policy.
Notes
This function computes neg_efe = -EFE for a single policy. This policy
score (neg_efe) is commonly denoted G.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs_init` | `list[Array]` | Initial hidden-state marginals at the current timestep. | *required* |
| `A` | `list[Array]` | Observation likelihood tensors. | *required* |
| `B` | `list[Array]` | Transition tensors. | *required* |
| `C` | `list[Array]` | Prior preferences over observations. | *required* |
| `pA` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `pB` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and hidden-state factors. | *required* |
| `B_dependencies` | `list[list[int]]` | Transition dependencies between hidden-state factors. | *required* |
| `policy_i` | `Array` | Single policy trajectory with shape | *required* |
| `use_utility` | `bool` | Include expected utility term. | `True` |
| `use_states_info_gain` | `bool` | Include state-information-gain term. | `True` |
| `use_param_info_gain` | `bool` | Include parameter-information-gain term. | `False` |
Returns:
| Type | Description |
|---|---|
| `Array` | Scalar negative expected free energy for |
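The per-policy computation can be read as a rollout over the policy horizon: predict states, predict observations, then accumulate utility plus state information gain at each step. The NumPy sketch below assumes a single hidden-state factor and a single modality, treats `C` as log-preferences, omits the parameter-information-gain terms, and uses a hypothetical helper name. It illustrates the accumulation only and is not pymdp's implementation.

```python
import numpy as np


def _entropy(p, axis=0):
    return -(p * np.log(p + 1e-16)).sum(axis=axis)


def neg_efe_single_factor(qs_init, A, B, C, policy, use_utility=True, use_states_info_gain=True):
    """Hypothetical helper: neg_efe = -EFE for one factor, one modality."""
    qs = np.asarray(qs_init, dtype=float)
    neg_efe = 0.0
    for u in policy:                          # one action index per timestep
        qs = B[:, :, u] @ qs                  # predicted hidden states
        qo = A @ qs                           # predicted observations
        if use_utility:
            neg_efe += qo @ C                 # pragmatic value
        if use_states_info_gain:
            neg_efe += _entropy(qo) - _entropy(A) @ qs  # epistemic value
    return float(neg_efe)


A = np.eye(2)                                 # fully informative observations
C = np.array([0.0, 3.0])                      # observation 1 is preferred
B = np.zeros((2, 2, 2))
B[:, :, 0] = np.eye(2)                        # action 0: stay
B[:, :, 1] = np.array([[0.0, 1.0], [1.0, 0.0]])  # action 1: switch
qs0 = np.array([1.0, 0.0])
print(neg_efe_single_factor(qs0, A, B, C, [1]), neg_efe_single_factor(qs0, A, B, C, [0]))
```

Switching into the preferred state early scores higher because utility is accumulated at every step of the horizon.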
compute_neg_efe_policy_inductive(qs_init: list[Array], A: list[Array], B: list[Array], C: list[Array], pA: list[Array] | None, pB: list[Array] | None, A_dependencies: list[list[int]], B_dependencies: list[list[int]], I: list[Array], policy_i: Array, inductive_epsilon: float = 0.001, use_utility: bool = True, use_states_info_gain: bool = True, use_param_info_gain: bool = False, use_inductive: bool = False) -> Array
Compute policy-wise negative expected free energy with inductive planning.
Notes
This function computes neg_efe = -EFE for a single policy with optional
inductive-value terms. This score is commonly denoted G, so here
G = neg_efe = -EFE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs_init` | `list[Array]` | Initial hidden-state marginals at the current timestep. | *required* |
| `A` | `list[Array]` | Observation likelihood tensors. | *required* |
| `B` | `list[Array]` | Transition tensors. | *required* |
| `C` | `list[Array]` | Prior preferences over observations. | *required* |
| `pA` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `pB` | `list[Array] \| None` | Optional posterior Dirichlet parameters for | *required* |
| `A_dependencies` | `list[list[int]]` | Observation dependencies between modalities and hidden-state factors. | *required* |
| `B_dependencies` | `list[list[int]]` | Transition dependencies between hidden-state factors. | *required* |
| `I` | `list[Array]` | Inductive reachability matrices. | *required* |
| `policy_i` | `Array` | Single policy trajectory with shape | *required* |
| `inductive_epsilon` | `float` | Scale of the inductive-value contribution. | `0.001` |
| `use_utility` | `bool` | Include expected utility term. | `True` |
| `use_states_info_gain` | `bool` | Include state-information-gain term. | `True` |
| `use_param_info_gain` | `bool` | Include parameter-information-gain term. | `False` |
| `use_inductive` | `bool` | Include inductive-value term. | `False` |
Returns:
| Type | Description |
|---|---|
| `Array` | Scalar negative expected free energy for |
generate_I_matrix(H: list[Array], B: list[Array], threshold: float, depth: int) -> list[Array]
Generate inductive reachability matrices using backward state reachability.
These matrices store whether state j (columns) can still reach the
intended state set after i backward steps (rows).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `H` | `list[Array]` | Indicator of desired states for each factor: 1 if the state should be reached, 0 otherwise. | *required* |
| `B` | `list[Array]` | Dynamics likelihood mapping or transition model, mapping from hidden states at | *required* |
| `threshold` | `float` | Probability threshold below which transitions are pruned. | *required* |
| `depth` | `int` | Temporal depth of the backward induction. | *required* |
Returns:
| Name | Type | Description |
|---|---|---|
| `I` | `list[Array]` | For each state factor, contains a 2D indicator array whose element |
calc_inductive_value_t(qs: list[Array], qs_next: list[Array], I: list[Array], epsilon: float = 0.001) -> Array
Computes the inductive value of a state at a particular time (translation of @tverbele's numpy implementation of inductive planning, formerly
called calc_inductive_cost).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qs` | `list[Array]` | Marginal posterior beliefs over hidden states at a given timepoint. | *required* |
| `qs_next` | `list[Array]` | Predictive posterior beliefs over hidden states expected under the policy. | *required* |
| `I` | `list[Array]` | For each state factor, contains a 2D array whose element `i,j` yields the probability of reaching the goal state backwards from state `j` after `i` steps. | *required* |
| `epsilon` | `float` | Value that tunes the strength of the inductive value (how much it contributes to the expected free energy of policies). | `0.001` |
Returns:
| Name | Type | Description |
|---|---|---|
| `inductive_val` | `Array` | Scalar value (negative inductive cost) of visiting this state using backwards induction under the policy in question. |
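A NumPy sketch of how such an inductive value could be computed for near-certain beliefs follows. It assumes each `I[f]` is a 0/1 array of shape `(depth + 1, S_f)` whose row `i` marks states that can reach the goal within `i` steps, and it charges `log(epsilon)` to predicted states that fall off the shortest remaining path. The helper name and these conventions are illustrative assumptions, not pymdp's code.

```python
import numpy as np


def inductive_value(qs, qs_next, I, epsilon=1e-3):
    """Hypothetical helper: non-positive value penalizing predicted drift away from the goal."""
    log_eps = np.log(epsilon)
    value = 0.0
    for q, q_next, I_f in zip(qs, qs_next, I):
        s = int(np.argmax(q))                 # assume the current state is (nearly) certain
        column = I_f[:, s]
        if column.sum() == 0:                 # goal unreachable from s: no contribution
            continue
        steps = int(np.argmax(column))        # fewest steps from s to the goal set
        m = max(steps - 1, 0)
        # States that cannot reach the goal within m steps cost log(epsilon)
        penalty = (1.0 - I_f[m]) * log_eps
        value += float(penalty @ q_next)
    return value


I_chain = np.array([[0, 0, 1],
                    [0, 1, 1],
                    [1, 1, 1]], dtype=float)  # goal is state 2 in a 3-state chain
qs = [np.array([1.0, 0.0, 0.0])]              # currently in state 0
print(inductive_value(qs, [np.array([0.0, 1.0, 0.0])], [I_chain]))  # moving toward the goal: no penalty
print(inductive_value(qs, [np.array([1.0, 0.0, 0.0])], [I_chain]))  # staying put: log(1e-3)
```

Scaling by `epsilon` through its logarithm means a smaller `epsilon` makes off-path predictions more costly.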