pymdp.envs.env

Environment interfaces and POMDP-backed environment utilities.

This module provides:

- Env: an abstract JAX-compatible environment interface used by rollout(),
- PymdpEnv: a concrete environment driven by categorical A, B, and D,
- make(...): a convenience constructor for PymdpEnv and optional params.

Env

Bases: ABC

Abstract JAX-compatible environment interface used by rollout().

generate_env_params(key: Array | None = None, batch_size: int | None = None) -> dict[str, list[Array]] | None

Generate optional environment parameter pytrees.

Parameters:

Name Type Description Default
key Array | None

Optional JAX PRNG key (unused by the default implementation).

None
batch_size int | None

Optional batch size for parameter generation.

None

Returns:

Type Description
dict[str, list[Array]] | None

Environment parameters or None if not implemented.

reset(key: Array, state: list[Array] | None = None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]] abstractmethod

Reset environment state and return initial observation/state.

Parameters:

Name Type Description Default
key Array

JAX PRNG key.

required
state list[Array] | None

Optional explicit initial hidden state.

None
env_params dict[str, list[Array]] | None

Optional runtime override for environment parameters.

None

Returns:

Type Description
tuple[list[Array], list[Array]]

Initial observations and hidden state.

step(key: Array, state: list[Array], action: Array | None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]] abstractmethod

Advance one environment step and return new observation/state.

Parameters:

Name Type Description Default
key Array

JAX PRNG key.

required
state list[Array]

Current hidden state.

required
action Array | None

Action sampled by the agent. None can be used for no-op updates.

required
env_params dict[str, list[Array]] | None

Optional runtime override for environment parameters.

None

Returns:

Type Description
tuple[list[Array], list[Array]]

Next observations and hidden state.
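As an illustration of the reset/step contract described above, the following sketch implements a toy environment against a local stand-in base class. `EnvSketch` and `NoisyCounter` are hypothetical names introduced here; they only mirror the documented signatures and are not the pymdp classes. The loop at the end shows one common way to split a PRNG key per call; it is a hand-written loop, not rollout().

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class EnvSketch(ABC):
    """Local stand-in mirroring the documented reset/step signatures."""

    @abstractmethod
    def reset(self, key, state=None, env_params=None): ...

    @abstractmethod
    def step(self, key, state, action, env_params=None): ...


class NoisyCounter(EnvSketch):
    """Hypothetical toy environment: a counter observed with additive noise."""

    def reset(self, key, state=None, env_params=None):
        if state is None:
            # One hidden-state factor, stored as a list of arrays.
            state = [jnp.zeros((1,), dtype=jnp.int32)]
        return self._observe(key, state), state

    def step(self, key, state, action, env_params=None):
        if action is not None:
            state = [state[0] + action]
        return self._observe(key, state), state

    def _observe(self, key, state):
        # With probability 0.1 the observation overshoots the counter by one.
        noise = jax.random.bernoulli(key, 0.1, shape=state[0].shape)
        return [state[0] + noise.astype(jnp.int32)]


env = NoisyCounter()
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)
for _ in range(3):
    # Use a fresh subkey for every call so samples are independent.
    key, step_key = jax.random.split(key)
    obs, state = env.step(step_key, state, jnp.array([1]))
```

Both methods return observations first and hidden state second, each as a list of arrays, matching the return types documented above.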

PymdpEnv

Bases: Env

Environment whose dynamics are defined by categorical A, B, and D.

PymdpEnv is useful when the environment is isomorphic to a discrete POMDP generative process:

- A[m]: observation likelihoods per modality,
- B[f]: transitions per hidden-state factor,
- D[f]: initial-state priors per hidden-state factor.
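To make the roles of A, B, and D concrete, the sketch below builds toy arrays for one modality and one hidden-state factor. The index layouts (A[m] as [observation, state], B[f] as [next_state, state, action], each normalized over its first axis) and the absence of a leading batch axis are assumptions made for illustration; check them against the shapes your pymdp version expects.

```python
import jax.numpy as jnp

num_obs, num_states, num_actions = 3, 2, 2

# A[0][o, s]: probability of observation o given hidden state s (assumed layout).
A = [jnp.array([[0.8, 0.1],
                [0.1, 0.8],
                [0.1, 0.1]])]

# B[0][s_next, s, a]: probability of s_next given state s and action a (assumed layout).
stay = jnp.eye(num_states)          # action 0 keeps the state
flip = jnp.array([[0.0, 1.0],
                  [1.0, 0.0]])      # action 1 swaps the two states
B = [jnp.stack([stay, flip], axis=-1)]

# D[0][s]: prior probability of starting in state s.
D = [jnp.array([0.5, 0.5])]

# Each conditional distribution sums to 1 over its first axis.
col_sums_A = A[0].sum(axis=0)
col_sums_B = B[0].sum(axis=0)
```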

__init__(A: Sequence[Array] | Sequence[Distribution] | None = None, B: Sequence[Array] | Sequence[Distribution] | None = None, D: Sequence[Array] | Sequence[Distribution] | None = None, A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, categorical_obs: bool = False, **kwargs: Any) -> None

Initialize PymdpEnv.

Parameters:

Name Type Description Default
A Sequence[Array] | Sequence[Distribution] | None

Observation likelihood tensors.

None
B Sequence[Array] | Sequence[Distribution] | None

Transition tensors.

None
D Sequence[Array] | Sequence[Distribution] | None

Initial-state priors.

None
A_dependencies list[list[int]] | None

Modality-to-state dependencies for A.

None
B_dependencies list[list[int]] | None

State-to-state dependencies for B.

None
categorical_obs bool

If True, emit one-hot categorical observation vectors with shape (1, num_obs) per modality. Otherwise emit discrete indices with shape (1,).

False
**kwargs Any

Accepted for forward compatibility.

{}
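The two observation formats selected by categorical_obs can be compared with a small sketch. It uses only generic JAX calls and a made-up observation value; it does not call PymdpEnv itself.

```python
import jax
import jax.numpy as jnp

num_obs = 4

# Discrete-index form for one modality: shape (1,).
obs_index = jnp.array([2])

# One-hot categorical form of the same observation: shape (1, num_obs).
obs_one_hot = jax.nn.one_hot(obs_index, num_obs)

# The two forms carry the same information: argmax recovers the index.
recovered = jnp.argmax(obs_one_hot, axis=-1)
```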

generate_env_params(key: Array | None = None, batch_size: int | None = None) -> dict[str, list[Array]]

Return default environment params, optionally broadcast to batch.

Parameters:

Name Type Description Default
key Array | None

Optional JAX PRNG key (unused).

None
batch_size int | None

If provided, broadcast each parameter leaf with leading shape (batch_size, ...).

None

Returns:

Type Description
dict[str, list[Array]]

Dictionary with keys "A", "B", and "D".
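The effect of batch_size on each parameter leaf can be sketched with generic JAX pytree utilities. The helper `broadcast_leaves` and the toy parameter shapes are hypothetical; the library may treat an existing leading axis differently, so this only illustrates the resulting (batch_size, ...) shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical unbatched parameters using the documented keys "A", "B", "D".
env_params = {
    "A": [jnp.ones((3, 2)) / 3.0],
    "B": [jnp.ones((2, 2, 2)) / 2.0],
    "D": [jnp.ones((2,)) / 2.0],
}


def broadcast_leaves(params, batch_size):
    """Give every array leaf a leading axis of length batch_size."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), params
    )


batched = broadcast_leaves(env_params, batch_size=5)
```

Every batch entry holds the same values, so each batch element starts from identical parameters.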

reset(key: Array, state: list[Array] | None = None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]]

Reset state and emit an initial observation sample.

If state is omitted, states are sampled from D.
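The reset behavior can be approximated by the sketch below for a single modality and a single hidden-state factor, without batch axes. `reset_sketch`, the toy A and D, and the [observation, state] layout of A are assumptions for illustration; this is not the library's implementation.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[0.9, 0.2],
               [0.1, 0.8]])   # A[o, s], assumed layout
D = jnp.array([1.0, 0.0])     # prior that always starts in state 0


def reset_sketch(key, state=None):
    key_state, key_obs = jax.random.split(key)
    if state is None:
        # Sample the initial hidden state from the prior D.
        state = jax.random.categorical(key_state, jnp.log(D))
    # Sample an observation from the column of A selected by the state.
    obs = jax.random.categorical(key_obs, jnp.log(A[:, state]))
    return obs, state


obs, state = reset_sketch(jax.random.PRNGKey(0))
obs_fixed, state_fixed = reset_sketch(jax.random.PRNGKey(0), state=jnp.array(1))
```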

step(key: Array, state: list[Array], action: Array | None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]]

Advance the process by one timestep.

If action is provided, the next hidden states are sampled from B. When action is None, no transition is sampled (a no-op update). Observations are then sampled from A conditioned on the resulting state.
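A single transition can be sketched as follows for one modality and one factor, without batch axes. `step_sketch`, the toy A and B, and the [next_state, state, action] layout of B are assumptions; keeping the state unchanged when action is None is likewise an assumed reading of the no-op case, not confirmed library behavior.

```python
import jax
import jax.numpy as jnp

A = jnp.eye(2)                              # A[o, s]: the state is observed exactly
stay = jnp.eye(2)                           # action 0 keeps the state
flip = jnp.array([[0.0, 1.0],
                  [1.0, 0.0]])              # action 1 swaps the two states
B = jnp.stack([stay, flip], axis=-1)        # B[s_next, s, a], assumed layout


def step_sketch(key, state, action):
    key_state, key_obs = jax.random.split(key)
    if action is not None:
        # Sample the next state from the column of B picked by (state, action).
        state = jax.random.categorical(key_state, jnp.log(B[:, state, action]))
    # Sample the observation conditioned on the (possibly updated) state.
    obs = jax.random.categorical(key_obs, jnp.log(A[:, state]))
    return obs, state


key = jax.random.PRNGKey(0)
obs_flip, state_flip = step_sketch(key, jnp.array(0), jnp.array(1))
obs_noop, state_noop = step_sketch(key, jnp.array(0), None)
```

Because the toy A and B contain only zeros and ones, the outcome here is deterministic: the flip action moves state 0 to state 1, and the no-op call leaves it at 0.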

make(A: Sequence[Array] | Sequence[Distribution], B: Sequence[Array] | Sequence[Distribution], D: Sequence[Array] | Sequence[Distribution], A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, make_env_params: bool = False, **kwargs: Any) -> tuple[PymdpEnv, dict[str, list[Array]] | None]

Construct a PymdpEnv (and optionally environment parameters).

Parameters:

Name Type Description Default
A Sequence[Array] | Sequence[Distribution]

Observation likelihood tensors, one per observation modality.

required
B Sequence[Array] | Sequence[Distribution]

Transition tensors, one per hidden-state factor.

required
D Sequence[Array] | Sequence[Distribution]

Initial-state priors, one per hidden-state factor.

required
A_dependencies list[list[int]] | None

Explicit modality-to-state dependencies. If None, dependencies are inferred when possible.

None
B_dependencies list[list[int]] | None

Explicit state-transition dependencies. If None, dependencies are inferred when possible.

None
make_env_params bool

If True, also return env_params={"A": ..., "B": ..., "D": ...} with Distribution entries converted to dense arrays.

False
**kwargs Any

Additional keyword arguments forwarded to PymdpEnv.

{}

Returns:

Type Description
tuple[PymdpEnv, dict[str, list[Array]] | None]

Constructed environment and optional unbatched environment parameters. To broadcast parameters to a larger batch, call env.generate_env_params(batch_size=...).