pymdp.envs.env
Environment interfaces and POMDP-backed environment utilities.
This module provides:
- Env: an abstract JAX-compatible environment interface used by rollout(),
- PymdpEnv: a concrete environment driven by categorical A, B, and D,
- make(...): a convenience constructor for PymdpEnv and optional params.
Env
Bases: ABC
Abstract JAX-compatible environment interface used by rollout().
generate_env_params(key: Array | None = None, batch_size: int | None = None) -> dict[str, list[Array]] | None
Generate optional environment parameter pytrees.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array \| None` | Optional JAX PRNG key (unused by default implementation). | `None` |
| `batch_size` | `int \| None` | Optional batch size for parameter generation. | `None` |
Returns:
| Type | Description |
|---|---|
| `dict[str, list[Array]] \| None` | Environment parameters or |
reset(key: Array, state: list[Array] | None = None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]]
abstractmethod
Reset environment state and return initial observation/state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array` | JAX PRNG key. | required |
| `state` | `list[Array] \| None` | Optional explicit initial hidden state. | `None` |
| `env_params` | `dict[str, list[Array]] \| None` | Optional runtime override for environment parameters. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[list[Array], list[Array]]` | Initial observations and hidden state. |
step(key: Array, state: list[Array], action: Array | None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]]
abstractmethod
Advance one environment step and return new observation/state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array` | JAX PRNG key. | required |
| `state` | `list[Array]` | Current hidden state. | required |
| `action` | `Array \| None` | Action sampled by the agent. | required |
| `env_params` | `dict[str, list[Array]] \| None` | Optional runtime override for environment parameters. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[list[Array], list[Array]]` | Next observations and hidden state. |
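The calling convention of reset and step can be illustrated with a small stand-in class. This is only a sketch: `CounterEnv` is a hypothetical name, it does not subclass pymdp's Env, and its behaviour when action is None (state left unchanged) is an assumption made for this toy example, not a statement about the library.

```python
import jax
import jax.numpy as jnp


class CounterEnv:
    """Hypothetical stand-in that imitates the reset/step signatures."""

    def reset(self, key, state=None, env_params=None):
        # Use the explicit state if one is given, otherwise start at zero.
        if state is None:
            state = [jnp.array(0)]
        obs = [state[0] % 2]  # observe only the parity of the counter
        return obs, state

    def step(self, key, state, action, env_params=None):
        # Toy assumption: a missing action leaves the state unchanged.
        new_state = state if action is None else [state[0] + action]
        obs = [new_state[0] % 2]
        return obs, new_state


env = CounterEnv()
key = jax.random.PRNGKey(0)

# Split the key before every call so that no PRNG key is reused.
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)

for _ in range(3):
    key, step_key = jax.random.split(key)
    obs, state = env.step(step_key, state, jnp.array(1))
```

Both methods return observations first and hidden state second, each as a list of arrays, which matches the return type in the signatures above.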
PymdpEnv
Bases: Env
Environment whose dynamics are defined by categorical A, B, and D.
PymdpEnv is useful when the environment is isomorphic to a discrete POMDP
generative process:
- A[m]: observation likelihoods per modality,
- B[f]: transitions per hidden-state factor,
- D[f]: initial-state priors per hidden-state factor.
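As a rough illustration of the three ingredients, the sketch below builds A, B, and D for one observation modality and one hidden-state factor. The array layout (outcomes along the first axis, the conditioning state next, actions last) is a common convention for discrete POMDPs and is assumed here; PymdpEnv may expect additional axes, such as a leading batch dimension.

```python
import jax.numpy as jnp

# Assumed sizes: 2 observation outcomes, 3 hidden states, 2 actions.
# A[0][o, s]: probability of outcome o given hidden state s.
A = [jnp.array([[0.9, 0.5, 0.1],
                [0.1, 0.5, 0.9]])]

# B[0][s_next, s, a]: probability of s_next given current state s and action a.
stay = jnp.eye(3)                          # action 0 keeps the state
advance = jnp.roll(jnp.eye(3), 1, axis=0)  # action 1 moves s to (s + 1) mod 3
B = [jnp.stack([stay, advance], axis=-1)]

# D[0][s]: prior probability of starting in state s.
D = [jnp.array([1.0, 0.0, 0.0])]

# Each conditional distribution must sum to one over its first axis.
a_normalized = bool(jnp.allclose(A[0].sum(axis=0), 1.0))
b_normalized = bool(jnp.allclose(B[0].sum(axis=0), 1.0))
d_normalized = bool(jnp.allclose(D[0].sum(), 1.0))
```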
__init__(A: Sequence[Array] | Sequence[Distribution] | None = None, B: Sequence[Array] | Sequence[Distribution] | None = None, D: Sequence[Array] | Sequence[Distribution] | None = None, A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, categorical_obs: bool = False, **kwargs: Any) -> None
Initialize PymdpEnv.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Sequence[Array] \| Sequence[Distribution] \| None` | Observation likelihood tensors. | `None` |
| `B` | `Sequence[Array] \| Sequence[Distribution] \| None` | Transition tensors. | `None` |
| `D` | `Sequence[Array] \| Sequence[Distribution] \| None` | Initial-state priors. | `None` |
| `A_dependencies` | `list[list[int]] \| None` | Modality-to-state dependencies for | `None` |
| `B_dependencies` | `list[list[int]] \| None` | State-to-state dependencies for | `None` |
| `categorical_obs` | `bool` | If | `False` |
| `**kwargs` | `Any` | Accepted for forward compatibility. | `{}` |
generate_env_params(key: Array | None = None, batch_size: int | None = None) -> dict[str, list[Array]]
Return default environment params, optionally broadcast to batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array \| None` | Optional JAX PRNG key (unused). | `None` |
| `batch_size` | `int \| None` | If provided, broadcast each parameter leaf with leading shape | `None` |
Returns:
| Type | Description |
|---|---|
| `dict[str, list[Array]]` | Dictionary with keys |
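Broadcasting parameter leaves to a batch can be sketched with a pytree map. This is not the library's implementation: `broadcast_leaves` is a hypothetical helper, the dictionary keys are placeholders, and a leading axis of size `batch_size` is an assumed layout.

```python
import jax
import jax.numpy as jnp


def broadcast_leaves(params, batch_size):
    """Give every array leaf a new leading axis of size batch_size."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.broadcast_to(leaf, (batch_size,) + leaf.shape),
        params,
    )


# Placeholder parameter dictionary; key names are illustrative only.
params = {
    "A": [jnp.full((2, 3), 0.5)],
    "D": [jnp.array([1.0, 0.0, 0.0])],
}

batched = broadcast_leaves(params, batch_size=4)
```

Because `tree_map` preserves the container structure, the result is still a dictionary of lists of arrays; only the leaf shapes change.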
reset(key: Array, state: list[Array] | None = None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]]
Reset state and emit an initial observation sample.
If state is omitted, states are sampled from D.
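The reset logic can be sketched for a single hidden-state factor and a single modality. The helper names `sample_categorical` and `reset_sketch` are hypothetical, and the flat array layout (A indexed as `[outcome, state]`, D as `[state]`) is an assumption; the actual method works on lists of arrays and may handle batching and dependencies differently.

```python
import jax
import jax.numpy as jnp


def sample_categorical(key, probs):
    # jax.random.categorical expects logits, so take the log of the probabilities.
    return jax.random.categorical(key, jnp.log(probs))


def reset_sketch(key, A, D, state=None):
    state_key, obs_key = jax.random.split(key)
    # Sample the initial state from the prior only when none is supplied.
    if state is None:
        state = sample_categorical(state_key, D)
    # Emit an observation from the likelihood column of that state.
    obs = sample_categorical(obs_key, A[:, state])
    return obs, state


A = jnp.eye(3)                 # each state deterministically emits its own index
D = jnp.array([0.0, 1.0, 0.0])  # the prior puts all mass on state 1

obs, state = reset_sketch(jax.random.PRNGKey(0), A, D)
```

With a deterministic prior and an identity likelihood, the sampled state and observation are both fixed regardless of the key, which makes the sketch easy to check.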
step(key: Array, state: list[Array], action: Array | None, env_params: dict[str, list[Array]] | None = None) -> tuple[list[Array], list[Array]]
Advance the process by one timestep.
If action is provided, next hidden states are sampled from B.
Observations are then sampled from A conditioned on the new state.
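A single-factor, single-modality sketch of this two-stage sampling follows. The names `sample_categorical` and `step_sketch` are hypothetical, the layout `B[next_state, state, action]` is assumed, and keeping the state unchanged when action is None is an assumption of this sketch rather than documented behaviour.

```python
import jax
import jax.numpy as jnp


def sample_categorical(key, probs):
    # jax.random.categorical expects logits, so take the log of the probabilities.
    return jax.random.categorical(key, jnp.log(probs))


def step_sketch(key, A, B, state, action=None):
    state_key, obs_key = jax.random.split(key)
    # Transition only when an action is supplied (assumption of this sketch).
    if action is not None:
        state = sample_categorical(state_key, B[:, state, action])
    # Observations are conditioned on the new state.
    obs = sample_categorical(obs_key, A[:, state])
    return obs, state


A = jnp.eye(3)                             # identity likelihood: obs equals state
stay = jnp.eye(3)                          # action 0 keeps the state
advance = jnp.roll(jnp.eye(3), 1, axis=0)  # action 1 moves s to (s + 1) mod 3
B = jnp.stack([stay, advance], axis=-1)

obs, new_state = step_sketch(jax.random.PRNGKey(0), A, B, jnp.array(0), action=1)
```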
make(A: Sequence[Array] | Sequence[Distribution], B: Sequence[Array] | Sequence[Distribution], D: Sequence[Array] | Sequence[Distribution], A_dependencies: list[list[int]] | None = None, B_dependencies: list[list[int]] | None = None, make_env_params: bool = False, **kwargs: Any) -> tuple[PymdpEnv, dict[str, list[Array]] | None]
Construct a PymdpEnv (and optionally environment parameters).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Sequence[Array] \| Sequence[Distribution]` | Observation likelihood tensors, one per observation modality. | required |
| `B` | `Sequence[Array] \| Sequence[Distribution]` | Transition tensors, one per hidden-state factor. | required |
| `D` | `Sequence[Array] \| Sequence[Distribution]` | Initial-state priors, one per hidden-state factor. | required |
| `A_dependencies` | `list[list[int]] \| None` | Explicit modality-to-state dependencies. If | `None` |
| `B_dependencies` | `list[list[int]] \| None` | Explicit state-transition dependencies. If | `None` |
| `make_env_params` | `bool` | If | `False` |
| `**kwargs` | `Any` | Additional keyword arguments forwarded to | `{}` |
Returns:
| Type | Description |
|---|---|
| `tuple[PymdpEnv, dict[str, list[Array]] \| None]` | Constructed environment and optional unbatched environment parameters. To broadcast parameters to a larger batch, call |