pymdp.utils

Utility functions for model construction, data shaping, and sampling.

A_dep_factors_dist(num_states: Sequence[int], A_dep_len: int) -> Array

Probability distribution over hidden-state factors, used when building A dependencies.

Parameters:
    num_states (Sequence[int], required): Hidden-state dimensionalities.
    A_dep_len (int, required): Candidate dependency-list length.

Returns:
    Array: Probability vector over factors, with the same length as num_states.

A_dep_len_dist(choices: Array, curr_sf_dim: int, max_sf_dim: int) -> Array

Distribution over A-dependency list lengths.

Parameters:
    choices (Array, required): Candidate dependency-list lengths.
    curr_sf_dim (int, required): Dimensionality of the currently assigned hidden-state factor.
    max_sf_dim (int, required): Maximum hidden-state dimensionality.

Returns:
    Array: Normalized weights over choices.

A_dep_len_dist_unconditional(choices: Array) -> Array

Unconditional prior for A-dependency list lengths.

Parameters:
    choices (Array, required): Candidate dependency-list lengths.

Returns:
    Array: Normalized weights over choices.

apply_A_end2end_padding_batched(A: list[Array]) -> Array

Pad A tensors for end-to-end batched processing.

Parameters:
    A (list[Array], required): A tensors, one per modality.

Returns:
    Array: Batched, padded A tensor.

apply_obs_end2end_padding_batched(obs: list[Array], max_obs_dim: int) -> Array

Pad observations for end-to-end batched processing.

Parameters:
    obs (list[Array], required): Observation tensors.
    max_obs_dim (int, required): Size to which the observation axis is padded.

Returns:
    Array: Batched, padded observation tensor.

apply_padding_batched(xs: list[Array]) -> Array

Pad variable-size arrays to a common shape and stack them along a new batch axis.

Parameters:
    xs (list[Array], required): Arrays to pad and stack.

Returns:
    Array: Padded batch tensor.
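A minimal NumPy sketch of the pad-and-stack idea, assuming zero padding at the end of every axis up to the largest size seen along that axis, and that all inputs have the same number of dimensions. pymdp's function works on JAX arrays and may pad differently; `pad_and_stack` is a hypothetical name used only for illustration.

```python
import numpy as np


def pad_and_stack(xs):
    """Hypothetical sketch: zero-pad arrays to a common shape, then stack them."""
    # Largest size along each axis across all inputs (assumes equal ndim).
    target = np.max([x.shape for x in xs], axis=0)
    padded = [
        # Pad only at the end of each axis so existing entries keep their indices.
        np.pad(x, [(0, int(t) - s) for s, t in zip(x.shape, target)])
        for x in xs
    ]
    # The new leading axis indexes the original arrays.
    return np.stack(padded, axis=0)


batch = pad_and_stack([np.ones((2, 3)), np.ones((4, 1))])
print(batch.shape)  # (2, 4, 3)
```

Padding at the end rather than the start keeps index 0 meaning the same thing in every padded array, which matters when the axes are outcome or state indices.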

build_block_diag_A(A_list: list[Array]) -> tuple[Array, tuple[tuple[int, ...], ...], tuple[int, ...]]

Build a block-diagonal representation from modality-wise likelihood tensors.

Parameters:
    A_list (list[Array], required): List of likelihood tensors.

Returns:
    tuple[Array, tuple[tuple[int, ...], ...], tuple[int, ...]]: (A_big, state_shapes, cuts).

concatenate_observations_block_diag(obs_list: list[Array]) -> Array

Concatenate observation vectors for block-diagonal processing.

Parameters:
    obs_list (list[Array], required): One-hot encoded observations, one per modality.

Returns:
    Array: Concatenated observation tensor.
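Assuming each modality's observation is a one-hot vector with the outcome axis last, concatenation for block-diagonal processing amounts to joining those vectors along that axis. A NumPy sketch; the helper name is hypothetical and pymdp's function may add batching or reshaping on top of this.

```python
import numpy as np


def concat_one_hot(obs_list):
    """Hypothetical sketch: join per-modality one-hot vectors along the last axis."""
    return np.concatenate(obs_list, axis=-1)


o1 = np.array([[0.0, 1.0]])        # modality 0: 2 outcomes, outcome 1 observed
o2 = np.array([[1.0, 0.0, 0.0]])   # modality 1: 3 outcomes, outcome 0 observed
joint = concat_one_hot([o1, o2])
print(joint)  # [[0. 1. 1. 0. 0.]]
```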

create_controllable_B(num_states: int | Sequence[int], num_controls: int | Sequence[int]) -> list[Array]

Create deterministic fully-controllable transition matrices.

Parameters:
    num_states (int | Sequence[int], required): Number of hidden states per factor.
    num_controls (int | Sequence[int], required): Number of controls per factor.

Returns:
    list[Array]: Fully controllable transition tensors, one per factor.
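One common construction of a fully controllable transition tensor is sketched below. It assumes B[f] has shape (num_states[f], num_states[f], num_controls[f]) with index order (next state, current state, action), and that action a moves the factor deterministically to state a from any current state, which requires num_controls[f] <= num_states[f]. The helper name is hypothetical; check pymdp's source for the exact convention it uses.

```python
import numpy as np


def controllable_B_sketch(num_states, num_controls):
    """Hypothetical sketch of deterministic, fully controllable B tensors."""
    B = []
    for ns, nc in zip(num_states, num_controls):
        b = np.zeros((ns, ns, nc))
        for a in range(nc):
            # Action a sends every current state (middle axis) to next state a.
            b[a, :, a] = 1.0
        B.append(b)
    return B


B = controllable_B_sketch([3], [3])
# Each column B[0][:, s, a] is a one-hot distribution over next states.
print(B[0][:, 2, 1])  # [0. 1. 0.]
```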

fig2img(fig: Any) -> np.ndarray

Convert a Matplotlib figure to an RGB image array.

Parameters:
    fig (Any, required): Matplotlib figure object.

Returns:
    np.ndarray: RGB image array extracted from the figure.

generate_agent_spec(num_factors: int, num_modalities: int, state_dim_limits: tuple[int, int], obs_dim_limits: tuple[int, int], A_dep_len_limits: tuple[int, int], dim_sampling_type: str, A_dep_len_prior: str = 'uniform', key: Array | None = None) -> tuple[list[int], list[int], list[list[int]]]

Generate a random agent specification from high-level constraints.

Parameters:
    num_factors (int, required): Total number of hidden-state factors.
    num_modalities (int, required): Total number of observation modalities.
    state_dim_limits (tuple[int, int], required): Inclusive lower/upper bounds for state dimensions.
    obs_dim_limits (tuple[int, int], required): Inclusive lower/upper bounds for observation dimensions.
    A_dep_len_limits (tuple[int, int], required): Lower/upper bounds for the A-dependency list length.
    dim_sampling_type (str, required): Sampling strategy used for dimensionalities.
    A_dep_len_prior (str, default 'uniform'): Prior for the A-dependency list length.
    key (Array | None, default None): Optional PRNG key.

Returns:
    tuple[list[int], list[int], list[list[int]]]: (num_states, num_obs, A_dependencies).

generate_agent_specs_from_parameter_sets(parameter_sets: Sequence[tuple[int, int, int, int, str, str]], num_agents_per_set: int = 1, max_A_dependency_list_size: int = 10, output_file: str | None = 'agent_specs.json', seed: int | None = None) -> dict[str, list[dict[str, Any]]]

Generate multiple agent specs from parameter grids.

Parameters:
    parameter_sets (Sequence[tuple[int, int, int, int, str, str]], required): Tuples of (num_factors, num_modalities, state_dim_upper_limit, obs_dim_upper_limit, dim_sampling_type, label).
    num_agents_per_set (int, default 1): Number of specs to sample for each parameter set.
    max_A_dependency_list_size (int, default 10): Maximum allowed A-dependency list size.
    output_file (str | None, default 'agent_specs.json'): Optional path to which the generated specs are saved.
    seed (int | None, default None): RNG seed.

Returns:
    dict[str, list[dict[str, Any]]]: Mapping from string keys to lists of generated specification records.

get_combination_index(x: jax.Array | np.ndarray, dims: Sequence[int]) -> jax.Array | np.ndarray

Compute the combination (flat) index of an array of categorical values, given the array of categorical dimensions.

Parameters:
    x (jax.Array | np.ndarray, required): Categorical values to be converted into a combination index.
    dims (Sequence[int], required): Categorical dimensions used for the conversion.

Returns:
    index (jax.Array | np.ndarray): Index of the combination, of shape (batch_size,).
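The combination index is a mixed-radix encoding of the per-factor values. The sketch below assumes row-major (C-order) ordering, in which the last factor varies fastest; that is also the default of NumPy's `np.ravel_multi_index`. pymdp may use a different ordering, so treat `combination_index` (a hypothetical helper) as illustrative.

```python
import numpy as np


def combination_index(x, dims):
    """Sketch: flat index of categorical values x, shape (batch, len(dims))."""
    x = np.asarray(x)
    # Stride of factor i is the product of the dimensions after it (row-major).
    strides = np.cumprod([1] + list(dims[:0:-1]))[::-1]
    return x @ strides


x = np.array([[1, 2], [0, 1]])
dims = [2, 3]
print(combination_index(x, dims))  # [5 1]
# Same result via NumPy's built-in mixed-radix conversion.
print(np.ravel_multi_index(x.T, dims))  # [5 1]
```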

get_sample_obs(num_obs: Sequence[int], batch_size: int = 1) -> list[Array]

Generate random observations for each modality.

Parameters:
    num_obs (Sequence[int], required): Number of outcomes per modality.
    batch_size (int, default 1): Number of samples per modality.

Returns:
    list[Array]: Random observations, one array of shape (batch_size, 1) per modality.
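A NumPy sketch of the returned structure, assuming observations are integer outcome indices drawn uniformly at random. pymdp's function is JAX-based and draws from its own PRNG; `sample_obs_sketch` and its `seed` argument are hypothetical.

```python
import numpy as np


def sample_obs_sketch(num_obs, batch_size=1, seed=0):
    """Hypothetical sketch: one (batch_size, 1) integer array per modality."""
    rng = np.random.default_rng(seed)
    # integers(low, high) samples from [low, high), i.e. valid outcome indices.
    return [rng.integers(0, n, size=(batch_size, 1)) for n in num_obs]


obs = sample_obs_sketch([3, 5], batch_size=4)
print([o.shape for o in obs])  # [(4, 1), (4, 1)]
```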

index_to_combination(index: jax.Array | np.ndarray, dims: Sequence[int]) -> jax.Array

Convert a combination index back into an array of categorical values, given the array of categorical dimensions.

Parameters:
    index (jax.Array | np.ndarray, required): Index of the combination.
    dims (Sequence[int], required): Categorical dimensions used for the conversion.

Returns:
    x (jax.Array | np.ndarray): Categorical values for each factor, of shape (batch_size, act_dims), where act_dims = len(dims).
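The inverse mapping decodes each flat index digit by digit. Assuming row-major (C-order) ordering with the last factor varying fastest, this matches what `np.unravel_index` computes; pymdp's ordering convention is an assumption here, and `index_to_values` is a hypothetical helper.

```python
import numpy as np


def index_to_values(index, dims):
    """Sketch: decode flat indices into categorical values, shape (batch, len(dims))."""
    index = np.asarray(index)
    values = []
    # Peel off digits from the last (fastest-varying) factor backwards.
    for d in reversed(dims):
        values.append(index % d)
        index = index // d
    return np.stack(values[::-1], axis=-1)


dims = [2, 3]
print(index_to_values(np.array([5, 1]), dims))
# [[1 2]
#  [0 1]]
print(np.stack(np.unravel_index([5, 1], dims), axis=-1))  # same values
```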

init_A_and_D_from_spec(num_obs: Sequence[int], num_states: Sequence[int], A_dependencies: list[list[int]], A_sparsity_level: float | None = None, batch_size: int = 1) -> tuple[list[Array], list[Array]]

Create initial A and D tensors from explicit model metadata.

Parameters:
    num_obs (Sequence[int], required): Observation cardinalities.
    num_states (Sequence[int], required): Hidden-state cardinalities.
    A_dependencies (list[list[int]], required): Modality-to-state-factor dependencies.
    A_sparsity_level (float | None, default None): Optional sparsity level used when constructing A.
    batch_size (int, default 1): Number of sampled model instances.

Returns:
    tuple[list[Array], list[Array]]: Initialized A and D arrays.

list_array_norm_dist(dist_list: list[Array]) -> list[Array]

Normalize a list of Categorical probability distributions.

Parameters:
    dist_list (list[Array], required): List of unnormalized Categorical distributions.

Returns:
    list[Array]: List of normalized distributions.

list_array_scaled(shape_list: Sequence[Sequence[int]], scale: float = 1.0) -> list[Array]

Create arrays filled with a constant scale value.

Parameters:
    shape_list (Sequence[Sequence[int]], required): Target tensor shapes.
    scale (float, default 1.0): Fill value.

Returns:
    list[Array]: Arrays filled with the value scale, one per shape.

list_array_uniform(shape_list: Sequence[Sequence[int]]) -> list[Array]

Create uniform Categorical arrays for each requested shape.

Parameters:
    shape_list (Sequence[Sequence[int]], required): Target tensor shapes.

Returns:
    list[Array]: Uniform distributions, one per shape.
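Assuming each array is a Categorical distribution over its first axis (so every column sums to one), a uniform array of a given shape is simply the constant 1 / shape[0]. A NumPy sketch with a hypothetical helper name:

```python
import numpy as np


def list_uniform_sketch(shape_list):
    """Hypothetical sketch: uniform Categoricals, normalized over axis 0."""
    return [np.ones(shape) / shape[0] for shape in shape_list]


arrs = list_uniform_sketch([(4,), (2, 3)])
print(arrs[0])         # [0.25 0.25 0.25 0.25]
print(arrs[1].sum(0))  # [1. 1. 1.]
```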

list_array_zeros(shape_list: Sequence[Sequence[int]]) -> list[Array]

Create zero arrays for each requested shape.

Parameters:
    shape_list (Sequence[Sequence[int]], required): Target tensor shapes.

Returns:
    list[Array]: Zero-filled arrays, one per shape.

make_A_full(A_reduced: list[Array], A_dependencies: list[list[int]], num_obs: list[int], num_states: list[int]) -> list[Array]

Lift reduced likelihood tensors into full modality tensors.

Parameters:
    A_reduced (list[Array], required): Reduced likelihood tensors.
    A_dependencies (list[list[int]], required): Dependency structure between modalities and state factors.
    num_obs (list[int], required): Observation dimensions.
    num_states (list[int], required): State dimensions.

Returns:
    list[Array]: Full likelihood tensors with the redundant factor dimensions restored.
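Lifting a reduced tensor amounts to inserting singleton axes for the factors a modality does not depend on and broadcasting across them. The NumPy sketch below assumes each reduced tensor has shape (num_obs[m], *[num_states[f] for f in A_dependencies[m]]) with each dependency list sorted in ascending factor order; the helper name is hypothetical.

```python
import numpy as np


def lift_A_sketch(A_reduced, A_dependencies, num_obs, num_states):
    """Hypothetical sketch: broadcast reduced A tensors over all state factors."""
    A_full = []
    for m, a in enumerate(A_reduced):
        deps = A_dependencies[m]
        # Size-1 axes stand in for the factors this modality ignores.
        expanded_shape = [num_obs[m]] + [
            num_states[f] if f in deps else 1 for f in range(len(num_states))
        ]
        full_shape = [num_obs[m]] + list(num_states)
        # broadcast_to returns a read-only view; copy it if it must be mutated.
        A_full.append(np.broadcast_to(a.reshape(expanded_shape), full_shape))
    return A_full


A_red = [np.array([[0.9, 0.2], [0.1, 0.8]])]  # depends on factor 1 only
A = lift_A_sketch(A_red, [[1]], num_obs=[2], num_states=[3, 2])
print(A[0].shape)  # (2, 3, 2)
```

The lifted tensor is constant along the ignored factor's axis, so it encodes exactly the same likelihood as the reduced one.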

norm_dist(dist: Array) -> Array

Normalize a Categorical probability distribution.

Parameters:
    dist (Array, required): Unnormalized Categorical distribution.

Returns:
    Array: Normalized distribution.
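A NumPy sketch of Categorical normalization, assuming the distribution runs along the first axis (so each column of a conditional table is rescaled to sum to one) and that no column sums to zero; the helper name is hypothetical.

```python
import numpy as np


def norm_dist_sketch(dist):
    """Hypothetical sketch: rescale entries so each slice along axis 0 sums to 1."""
    return dist / dist.sum(axis=0, keepdims=True)


print(norm_dist_sketch(np.array([1.0, 3.0])))  # [0.25 0.75]
cols = norm_dist_sketch(np.array([[1.0, 2.0], [3.0, 2.0]]))
print(cols.sum(axis=0))  # [1. 1.]
```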

prepare_obs_for_block_diag(obs: list[Array], num_obs: Sequence[int]) -> list[Array]

Prepare observation vectors for block-diagonal calculations.

Parameters:
    obs (list[Array], required): Raw observation tensors.
    num_obs (Sequence[int], required): Observation cardinalities.

Returns:
    list[Array]: One-hot encoded observation arrays prepared for block-diagonal use.
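Assuming the raw observations are integer outcome indices, one-hot encoding with the given cardinalities can be done by indexing an identity matrix. A NumPy sketch with a hypothetical helper name; pymdp's function may also reshape or pad its output.

```python
import numpy as np


def one_hot_obs_sketch(obs, num_obs):
    """Hypothetical sketch: integer outcome indices -> one-hot vectors per modality."""
    # Row o of the n x n identity matrix is the one-hot vector for outcome o.
    return [np.eye(n)[np.asarray(o)] for o, n in zip(obs, num_obs)]


encoded = one_hot_obs_sketch([np.array([2]), np.array([0])], num_obs=[3, 2])
print(encoded[0])  # [[0. 0. 1.]]
print(encoded[1])  # [[1. 0.]]
```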

preprocess_A_for_block_diag(A: list[Array]) -> tuple[Array, tuple[tuple[int, ...], ...], tuple[int, ...]]

Preprocess A matrices for block-diagonal likelihood evaluation.

Parameters:
    A (list[Array], required): Likelihood tensors.

Returns:
    tuple[Array, tuple[tuple[int, ...], ...], tuple[int, ...]]: Block-diagonal representation and auxiliary metadata.

random_A_array(key: Array, num_obs: int | Sequence[int], num_states: int | Sequence[int], A_dependencies: list[list[int]] | None = None) -> list[Array]

Create random observation likelihood tensors.

Parameters:
    key (Array, required): PRNG key for sampling.
    num_obs (int | Sequence[int], required): Number of discrete observation outcomes per modality.
    num_states (int | Sequence[int], required): Number of hidden states per factor.
    A_dependencies (list[list[int]] | None, default None): Optional dependency structure per modality.

Returns:
    list[Array]: Randomized A tensors.
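One simple way to build random likelihood tensors is to draw nonnegative entries and normalize over the outcome axis. The sketch below assumes uniform draws, a dependency of every modality on every state factor (it ignores A_dependencies), and NumPy's generator in place of a JAX key; pymdp may sample differently, and the helper name is hypothetical.

```python
import numpy as np


def random_A_sketch(seed, num_obs, num_states):
    """Hypothetical sketch: random A tensors normalized over the outcome axis."""
    rng = np.random.default_rng(seed)
    A = []
    for n in num_obs:
        a = rng.uniform(size=(n, *num_states))
        # Each column A[m][:, s1, s2, ...] becomes a distribution over outcomes.
        A.append(a / a.sum(axis=0, keepdims=True))
    return A


A = random_A_sketch(0, num_obs=[3, 2], num_states=[4, 5])
print([a.shape for a in A])  # [(3, 4, 5), (2, 4, 5)]
```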

random_B_array(key: Array, num_states: int | Sequence[int], num_controls: int | Sequence[int], B_dependencies: list[list[int]] | None = None, B_action_dependencies: list[list[int]] | None = None) -> list[Array]

Create random transition tensors.

Parameters:
    key (Array, required): PRNG key for sampling.
    num_states (int | Sequence[int], required): Number of states per hidden-state factor.
    num_controls (int | Sequence[int], required): Number of controls per factor.
    B_dependencies (list[list[int]] | None, default None): Optional state-factor dependency structure per hidden-state factor.
    B_action_dependencies (list[list[int]] | None, default None): Optional action-factor dependency structure per hidden-state factor.

Returns:
    list[Array]: Randomized B tensors.

random_factorized_categorical(key: Array, dims_per_var: Sequence[int]) -> list[Array]

Create random factorized Categorical distributions.

Parameters:
    key (Array, required): PRNG key for sampling.
    dims_per_var (Sequence[int], required): Number of levels per variable.

Returns:
    list[Array]: Sampled Categorical vectors, one per variable.

resolve_a_dependencies(num_factors: int, num_modalities: int, A_dependencies: list[list[int]] | None = None) -> list[list[int]]

Return modality-to-factor dependencies, filling in the fully connected default.

resolve_b_action_dependencies(num_factors: int, B_action_dependencies: list[list[int]] | None = None) -> list[list[int]]

Return control-factor dependencies, filling in the per-factor default.

resolve_b_dependencies(num_factors: int, B_dependencies: list[list[int]] | None = None) -> list[list[int]]

Return factor-to-factor transition dependencies, filling in local defaults.
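The defaults described for these three resolvers can be sketched as follows, assuming that "fully connected" means every modality depends on every factor, and that the per-factor and local defaults mean factor f depends only on itself (for B_dependencies) or only on control factor f (for B_action_dependencies). These helpers are hypothetical and are not pymdp's verified code.

```python
def resolve_a_sketch(num_factors, num_modalities, A_dependencies=None):
    """Hypothetical sketch: every modality depends on every factor by default."""
    if A_dependencies is not None:
        return A_dependencies
    return [list(range(num_factors)) for _ in range(num_modalities)]


def resolve_local_sketch(num_factors, deps=None):
    """Hypothetical sketch of a per-factor default: factor f depends only on f."""
    if deps is not None:
        return deps
    return [[f] for f in range(num_factors)]


print(resolve_a_sketch(2, 3))   # [[0, 1], [0, 1], [0, 1]]
print(resolve_local_sketch(3))  # [[0], [1], [2]]
```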

validate_normalization(tensor: Array, axis: int = 1, tensor_name: str = 'tensor') -> None

Validate that a probability tensor has normalized distributions along a given axis.

Parameters:
    tensor (Array, required): Tensor to validate.
    axis (int, default 1): Axis along which the probability distributions should be normalized.
    tensor_name (str, default 'tensor'): Human-readable name used in error messages.

Returns:
    None. In eager mode, raises ValueError if any distribution along the given axis sums to zero or is not normalized. Under JAX tracing/JIT, invalid distributions are instead signaled via eqx.error_if.
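An eager-mode NumPy sketch of this check, assuming a tolerance-based comparison (the `atol` value is a hypothetical choice) and the two failure modes named above. It does not model the JIT path, which uses `eqx.error_if`; the helper name is hypothetical.

```python
import numpy as np


def validate_normalization_sketch(tensor, axis=1, tensor_name="tensor", atol=1e-5):
    """Hypothetical eager-mode sketch: raise ValueError on invalid distributions."""
    sums = np.sum(tensor, axis=axis)
    if np.any(sums == 0):
        raise ValueError(f"{tensor_name} has distributions that sum to zero along axis {axis}")
    if not np.allclose(sums, 1.0, atol=atol):
        raise ValueError(f"{tensor_name} is not normalized along axis {axis}")


good = np.full((2, 4, 3), 0.25)  # batch of 2, 4 outcomes on axis 1, 3 states
validate_normalization_sketch(good, axis=1)  # passes silently

try:
    validate_normalization_sketch(2 * good, axis=1, tensor_name="A")
except ValueError as err:
    print(err)  # A is not normalized along axis 1
```

The default axis=1 fits batched tensors whose leading axis is the batch dimension, so the outcome axis sits at position 1.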