pymdp logo
  • Home
  • Getting Started
    • Installation
    • Quickstart (JAX)
  • Guides
    • Using rollout() for compiled active inference loops
    • PymdpEnv and Custom Environments
    • Generative Model Structure
  • Migration
    • NumPy/legacy to JAX
  • Tutorials
    • Overview
    • Notebook Gallery
  • API Reference
    • Overview
    • Agent
    • Inference
    • Control
    • Learning
    • Algorithms
    • Utils
    • Maths
    • Environment
    • Env Rollout
    • Sophisticated Inference Planning
    • MCTS Planning
  • Legacy
    • Legacy/NumPy Archive
  • Development
    • Viewing Docs Locally
    • Release Notes
  • infer-actively/pymdp
    • GitHub tag (latest by date)
    • GitHub Repo stars
    • GitHub forks

Open In Colab

In [ ]:
Copied!
# Install pymdp only when this notebook runs inside Google Colab; in any other
# environment the package is expected to be installed already.
import sys
if "google.colab" in sys.modules:
    %pip install "inferactively-pymdp" -q
import sys if "google.colab" in sys.modules: %pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import numpy as np
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.experimental.sparse as jsparse
from jax import nn, vmap, jit, block_until_ready
from functools import partial

from pymdp.utils import init_A_and_D_from_spec, get_sample_obs, generate_agent_specs_from_parameter_sets

# Hybrid
from pymdp.utils import apply_padding_batched
from pymdp.maths import compute_log_likelihoods_padded, deconstruct_lls

# Hybrid block
from pymdp.utils import build_block_diag_A, preprocess_A_for_block_diag, prepare_obs_for_block_diag, concatenate_observations_block_diag
from pymdp.maths import compute_log_likelihoods_block_diag, deconstruct_log_likelihoods_block_diag

from pymdp.algos import run_factorized_fpi_hybrid # For hybrid and hybrid block

# End2end padded
from pymdp.utils import apply_A_end2end_padding_batched, apply_obs_end2end_padding_batched
from pymdp.maths import compute_log_likelihood_per_modality_end2end2_padded
from pymdp.algos import run_factorized_fpi_end2end_padded
import numpy as np import jax.numpy as jnp import jax.tree_util as jtu import jax.experimental.sparse as jsparse from jax import nn, vmap, jit, block_until_ready from functools import partial from pymdp.utils import init_A_and_D_from_spec, get_sample_obs, generate_agent_specs_from_parameter_sets # Hybrid from pymdp.utils import apply_padding_batched from pymdp.maths import compute_log_likelihoods_padded, deconstruct_lls # Hybrid block from pymdp.utils import build_block_diag_A, preprocess_A_for_block_diag, prepare_obs_for_block_diag, concatenate_observations_block_diag from pymdp.maths import compute_log_likelihoods_block_diag, deconstruct_log_likelihoods_block_diag from pymdp.algos import run_factorized_fpi_hybrid # For hybrid and hybrid block # End2end padded from pymdp.utils import apply_A_end2end_padding_batched, apply_obs_end2end_padding_batched from pymdp.maths import compute_log_likelihood_per_modality_end2end2_padded from pymdp.algos import run_factorized_fpi_end2end_padded
In [1]:
Copied!
# Define coordinated parameter sets
# (num_factors, num_modalities, state_dim_upper_limit, obs_dim_upper_limit, dim_sampling_type, label)
parameter_sets = [
    (5, 5, 5, 5, 'uniform', 'low'),
    (10, 10, 10, 10, 'uniform', 'medium'),
    (25, 25, 25, 25, 'uniform', 'high'),
    # (125, 125, 125, 125, 'uniform', 'extreme'),  # Uncomment to include extreme cases
]

# Generate agent specs without dumping to file
specs = generate_agent_specs_from_parameter_sets(
    parameter_sets,
    num_agents_per_set=1,
    output_file=None  # Don't save to file
)

# Pick the single spec that every inference method below is run on.
# In the output below, index 1 resolves to the 'medium' configuration
# (10 factors, 10 modalities).
spec = specs['arbitrary dependencies'][1]
spec
# Define coordinated parameter sets # (num_factors, num_modalities, state_dim_upper_limit, obs_dim_upper_limit, dim_sampling_type, label) parameter_sets = [ (5, 5, 5, 5, 'uniform', 'low'), (10, 10, 10, 10, 'uniform', 'medium'), (25, 25, 25, 25, 'uniform', 'high'), # (125, 125, 125, 125, 'uniform', 'extreme'), # Uncomment to include extreme cases ] # Generate agent specs without dumping to file specs = generate_agent_specs_from_parameter_sets( parameter_sets, num_agents_per_set=1, output_file=None # Don't save to file ) spec = specs['arbitrary dependencies'][1] spec
Out[1]:
{'num_factors': 10,
 'num_modalities': 10,
 'num_states': [3, 6, 2, 9, 7, 4, 7, 5, 7, 5],
 'num_obs': [4, 2, 4, 9, 7, 8, 4, 5, 7, 4],
 'A_dependencies': [[1],
  [4],
  [5, 6],
  [0, 8],
  [0, 2, 7],
  [3, 5],
  [6],
  [2, 4, 7, 9],
  [0, 2, 7, 8],
  [7]],
 'metadata': {'num_factors': 'medium',
  'num_modalities': 'medium',
  'state_dim_upper_limit': 'medium',
  'obs_dim_upper_limit': 'medium',
  'dim_sampling_type': 'uniform'}}
In [2]:
Copied!
num_iter = 8  # fixed-point iterations; the same value is passed to every method below
batch_size = 4
# E.g., 0.8 for 80% sparsity. When not None, the method cells below also convert
# their (preprocessed) A to a sparse jsparse.BCOO array.
A_sparsity_level = None # E.g., 0.8 for 80% sparsity

A, D = init_A_and_D_from_spec(
    spec['num_obs'],
    spec['num_states'],
    spec['A_dependencies'],
    A_sparsity_level=A_sparsity_level,
    batch_size=batch_size
)

obs = get_sample_obs(spec['num_obs'], batch_size=batch_size)
# One-hot encode each modality's observation using that modality's number of outcomes.
o_vec = [nn.one_hot(o, spec['num_obs'][m]) for m, o in enumerate(obs)]

# NOTE(review): where this slicing happens matters. x[-1] indexes axis 0 (the batch
# axis). The later cells vmap over axis 0 and squeeze axis 1 of each o_vec entry, so
# applying it here would change the shapes they expect. Kept commented out.
# place where this happens is important!
# o_vec = jtu.tree_map(lambda x: x[-1], o_vec)
num_iter = 8 batch_size = 4 A_sparsity_level = None # E.g., 0.8 for 80% sparsity A, D = init_A_and_D_from_spec( spec['num_obs'], spec['num_states'], spec['A_dependencies'], A_sparsity_level=A_sparsity_level, batch_size=batch_size ) obs = get_sample_obs(spec['num_obs'], batch_size=batch_size) o_vec = [nn.one_hot(o, spec['num_obs'][m]) for m, o in enumerate(obs)] # place where this happens is important! # o_vec = jtu.tree_map(lambda x: x[-1], o_vec)
WARNING:2025-11-20 23:52:19,919:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Original method (imported directly from pymdp)

In [ ]:
Copied!
# Reference implementation: pymdp's own update_posterior_states. vmap (default
# in_axes=0) maps it over the leading batch axis of every positional argument;
# None arguments contain no array leaves and are simply passed through.
from pymdp.inference import update_posterior_states

infer_states_orig_pymdp = vmap(
    partial(
        update_posterior_states,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
        method='fpi'  # the 'fpi' scheme, matching the run_factorized_fpi_* routines used below
    )
)
from pymdp.inference import update_posterior_states infer_states_orig_pymdp = vmap( partial( update_posterior_states, A_dependencies=spec['A_dependencies'], num_iter=num_iter, method='fpi' ) )
In [3]:
Copied!
# Reference posteriors: the alternative methods below should reproduce these values.
qs = infer_states_orig_pymdp(A, None, o_vec, None, D)
[q.shape for q in qs], qs
qs = infer_states_orig_pymdp(A, None, o_vec, None, D) [q.shape for q in qs], qs
Out[3]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.1689375 , 0.160291  , 0.1127065 , 0.18204276, 0.21946688,
           0.08239742, 0.07415786]],
  
         [[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
           0.08600897, 0.1934056 ]],
  
         [[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
           0.33044106, 0.02134874]],
  
         [[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
           0.07777937, 0.0781489 ]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
  
         [[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
  
         [[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]],      dtype=float32),
  Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
           0.2847777 , 0.08775268]],
  
         [[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
           0.3041632 , 0.13611032]],
  
         [[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
           0.15282223, 0.11434983]],
  
         [[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
           0.0673811 , 0.16289519]]], dtype=float32),
  Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
  
         [[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
  
         [[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]],      dtype=float32)])

Hybrid method

In [ ]:
Copied!
def infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies, num_iter):
    """Run state inference with the hybrid (padded log-likelihood) method.

    The log-likelihoods are first computed on the padded observation and
    likelihood arrays. They are then split back into one array per modality
    using ``A_shapes``. Finally they are passed, together with ``D``, to
    ``run_factorized_fpi_hybrid``, which is vectorised over the leading
    (batch) axis with ``vmap``.

    Args:
        obs_padded: padded observations (output of ``apply_padding_batched``).
        A_padded: padded likelihood, either a dense array or a ``BCOO`` array.
        D: initial-state prior as returned by ``init_A_and_D_from_spec``,
            batched along axis 0.
        A_shapes: shapes of the unpadded ``A`` arrays, used to undo the padding.
        A_dependencies: state-factor indices that each modality depends on.
        num_iter: number of fixed-point iterations.

    Returns:
        The vmapped output of ``run_factorized_fpi_hybrid``. In this notebook
        that is one posterior array per state factor.

    Note:
        The argument order is (observations, A). ``infer_states_end2end_padded``
        takes (A, observations) instead.
    """
    solve_one = partial(
        run_factorized_fpi_hybrid,
        A_dependencies=A_dependencies,
        num_iter=num_iter,
    )
    padded_lls = compute_log_likelihoods_padded(obs_padded, A_padded)
    per_modality_lls = deconstruct_lls(padded_lls, A_shapes)
    return vmap(solve_one)(per_modality_lls, D)
def infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies, num_iter): lls_padded = compute_log_likelihoods_padded(obs_padded, A_padded) log_likelihoods = deconstruct_lls(lls_padded, A_shapes) return vmap(partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter))(log_likelihoods, D)
In [ ]:
Copied!
# Pad the A tensors for the hybrid method. Keep the unpadded shapes so the
# log-likelihoods can be split back into per-modality arrays afterwards.
A_padded = apply_padding_batched(A)
A_shapes = [a.shape for a in A]

# Optionally store the padded A as a sparse BCOO array; n_batch=1 keeps axis 0
# as a (dense) batch dimension.
if A_sparsity_level is not None:
    A_padded = jsparse.BCOO.fromdense(A_padded, n_batch=1)

# obs preprocessing: squeeze axis 1 of each observation, then pad them the same way as A
obs_padded = apply_padding_batched(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))
A_padded = apply_padding_batched(A) A_shapes = [a.shape for a in A] if A_sparsity_level is not None: A_padded = jsparse.BCOO.fromdense(A_padded, n_batch=1) # obs preprocessing obs_padded = apply_padding_batched(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))
In [4]:
Copied!
# Non-jitted hybrid inference; the output should match the reference posteriors above.
qs = infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies=spec['A_dependencies'], num_iter=num_iter)
[q.shape for q in qs], qs
qs = infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies=spec['A_dependencies'], num_iter=num_iter) [q.shape for q in qs], qs
Out[4]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.1689375 , 0.160291  , 0.1127065 , 0.18204276, 0.21946688,
           0.08239742, 0.07415786]],
  
         [[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
           0.08600897, 0.1934056 ]],
  
         [[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
           0.33044106, 0.02134874]],
  
         [[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
           0.07777937, 0.0781489 ]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
  
         [[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
  
         [[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]],      dtype=float32),
  Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
           0.2847777 , 0.08775268]],
  
         [[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
           0.3041632 , 0.13611032]],
  
         [[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
           0.15282223, 0.11434983]],
  
         [[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
           0.0673811 , 0.16289519]]], dtype=float32),
  Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
  
         [[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
  
         [[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]],      dtype=float32)])
In [5]:
Copied!
# JIT-compile the observation padding and the hybrid inference.
# apply_padding_batched needs no bound arguments, so it is jitted directly.
# A_shapes, A_dependencies and num_iter are bound with `partial`, which keeps
# them as closed-over constants rather than traced jit arguments (A_shapes is
# plain Python shape data used to undo the padding).
apply_padding_batched_jit = jit(apply_padding_batched)
infer_states_hybrid_jit = jit(partial(infer_states_hybrid, A_shapes=A_shapes, A_dependencies=spec['A_dependencies'], num_iter=num_iter))
obs_padded = apply_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))

qs = infer_states_hybrid_jit(obs_padded, A_padded, D)
[q.shape for q in qs], qs
# JIT apply_padding_batched_jit = jit(partial(apply_padding_batched)) infer_states_hybrid_jit = jit(partial(infer_states_hybrid, A_shapes=A_shapes, A_dependencies=spec['A_dependencies'], num_iter=num_iter)) obs_padded = apply_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec)) qs = infer_states_hybrid_jit(obs_padded, A_padded, D) [q.shape for q in qs], qs
Out[5]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.1689375 , 0.160291  , 0.1127065 , 0.18204276, 0.21946688,
           0.08239742, 0.07415786]],
  
         [[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
           0.08600897, 0.1934056 ]],
  
         [[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
           0.33044106, 0.02134874]],
  
         [[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
           0.07777937, 0.0781489 ]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
  
         [[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
  
         [[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]],      dtype=float32),
  Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
           0.2847777 , 0.08775268]],
  
         [[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
           0.3041632 , 0.13611032]],
  
         [[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
           0.15282223, 0.11434983]],
  
         [[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
           0.0673811 , 0.16289519]]], dtype=float32),
  Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
  
         [[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
  
         [[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]],      dtype=float32)])

Hybrid block method

In [ ]:
Copied!
# Infer states hybrid block
def infer_states_hybrid_block(obs, A_big, D, state_shapes, cuts, A_dependencies, num_iter, use_einsum=False):
    """Hybrid inference using a block-diagonal A for the log-likelihood step.

    The log-likelihoods of all modalities are computed from the block-diagonal
    likelihood ``A_big``. They are then handed, together with ``D``, to
    ``run_factorized_fpi_hybrid``, which is vmapped over the leading (batch)
    axis.

    Args:
        obs: concatenated observations (output of
            ``concatenate_observations_block_diag``).
        A_big: block-diagonal likelihood from ``preprocess_A_for_block_diag``;
            either a dense array or a ``BCOO`` sparse array.
        D: initial-state prior as returned by ``init_A_and_D_from_spec``,
            batched along axis 0.
        state_shapes: layout metadata returned by
            ``preprocess_A_for_block_diag`` alongside ``A_big``.
        cuts: layout metadata returned by ``preprocess_A_for_block_diag``
            alongside ``A_big``.
        A_dependencies: state-factor indices that each modality depends on.
        num_iter: number of fixed-point iterations.
        use_einsum: forwarded unchanged to ``compute_log_likelihoods_block_diag``.

    Returns:
        The vmapped output of ``run_factorized_fpi_hybrid``. In this notebook
        that is one posterior array per state factor.
    """
    fpi = partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter)
    lls = compute_log_likelihoods_block_diag(
        A_big, obs, state_shapes, cuts, use_einsum=use_einsum
    )
    return vmap(fpi)(lls, D)
# Infer states hybrid block def infer_states_hybrid_block(obs, A_big, D, state_shapes, cuts, A_dependencies, num_iter, use_einsum=False): """Hybrid inference using block diagonal approach for log-likelihood computation.""" log_likelihoods = compute_log_likelihoods_block_diag(A_big, obs, state_shapes, cuts, use_einsum=use_einsum) return vmap(partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter))(log_likelihoods, D)
In [ ]:
Copied!
# Move axis 1 of each A tensor (presumably the observation axis, right after the
# batch axis — TODO confirm) to the end for the block-diagonal method.
# jnp.moveaxis returns a new array, so the original A used by the other methods is unchanged.
A_moveaxis = [jnp.moveaxis(a, 1, -1) for a in A]
# Preprocess A matrices for block diagonal approach
A_big, state_shapes, cuts = preprocess_A_for_block_diag(A_moveaxis)

# Optionally store A_big as a sparse BCOO array; n_batch=1 keeps axis 0 as a batch dimension.
if A_sparsity_level is not None:
    A_big = jsparse.BCOO.fromdense(A_big, n_batch=1)

# Squeeze axis 1 of each observation, then concatenate the modalities for the block-diagonal layout.
obs_tmp = jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec)
obs_big = concatenate_observations_block_diag(obs_tmp)
# Create a copy with moved axes for block diagonal method (don't modify original A) A_moveaxis = [jnp.moveaxis(a, 1, -1) for a in A] # Preprocess A matrices for block diagonal approach A_big, state_shapes, cuts = preprocess_A_for_block_diag(A_moveaxis) if A_sparsity_level is not None: A_big = jsparse.BCOO.fromdense(A_big, n_batch=1) obs_tmp = jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec) obs_big = concatenate_observations_block_diag(obs_tmp)
In [6]:
Copied!
# Non-jitted hybrid block inference; the output should match the reference posteriors above.
qs = infer_states_hybrid_block(obs_big, A_big, D, 
    state_shapes=state_shapes, cuts=cuts, A_dependencies=spec['A_dependencies'], 
    num_iter=num_iter, use_einsum=False
)

[q.shape for q in qs], qs
qs = infer_states_hybrid_block(obs_big, A_big, D, state_shapes=state_shapes, cuts=cuts, A_dependencies=spec['A_dependencies'], num_iter=num_iter, use_einsum=False ) [q.shape for q in qs], qs
Out[6]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.1689375 , 0.160291  , 0.1127065 , 0.18204276, 0.21946688,
           0.08239742, 0.07415786]],
  
         [[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
           0.08600897, 0.1934056 ]],
  
         [[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
           0.33044106, 0.02134874]],
  
         [[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
           0.07777937, 0.0781489 ]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
  
         [[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
  
         [[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]],      dtype=float32),
  Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
           0.2847777 , 0.08775268]],
  
         [[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
           0.3041632 , 0.13611032]],
  
         [[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
           0.15282223, 0.11434983]],
  
         [[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
           0.0673811 , 0.16289519]]], dtype=float32),
  Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
  
         [[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
  
         [[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]],      dtype=float32)])
In [7]:
Copied!
# JIT-compile the observation concatenation and the hybrid block inference.
# The layout metadata (state_shapes, cuts), A_dependencies, num_iter and
# use_einsum are bound with `partial`, so they are closed-over constants rather
# than traced jit arguments.
use_einsum = False
concatenate_observations_block_diag_jit = jit(concatenate_observations_block_diag)  # observations only
infer_states_hybrid_block_jit = jit(partial(infer_states_hybrid_block, state_shapes=state_shapes, cuts=cuts, A_dependencies=spec['A_dependencies'], num_iter=num_iter, use_einsum=use_einsum))
# Same observation preprocessing as the non-jitted cell (concatenation, not padding), now jitted.
obs_big = concatenate_observations_block_diag_jit(obs_tmp)

qs = infer_states_hybrid_block_jit(obs_big, A_big, D)
[q.shape for q in qs], qs
# JIT use_einsum=False concatenate_observations_block_diag_jit = jit(partial(concatenate_observations_block_diag)) # just for obs infer_states_hybrid_block_jit = jit(partial(infer_states_hybrid_block, state_shapes=state_shapes, cuts=cuts, A_dependencies=spec['A_dependencies'], num_iter=num_iter, use_einsum=use_einsum)) obs_big = concatenate_observations_block_diag_jit(obs_tmp) # add padding of obs before running the infer states qs = infer_states_hybrid_block_jit(obs_big, A_big, D) [q.shape for q in qs], qs
Out[7]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.1689375 , 0.160291  , 0.1127065 , 0.18204276, 0.21946688,
           0.08239742, 0.07415786]],
  
         [[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
           0.08600897, 0.1934056 ]],
  
         [[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
           0.33044106, 0.02134874]],
  
         [[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
           0.07777937, 0.0781489 ]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
  
         [[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
  
         [[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]],      dtype=float32),
  Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
           0.2847777 , 0.08775268]],
  
         [[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
           0.3041632 , 0.13611032]],
  
         [[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
           0.15282223, 0.11434983]],
  
         [[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
           0.0673811 , 0.16289519]]], dtype=float32),
  Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
  
         [[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
  
         [[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]],      dtype=float32)])

End-to-end padded method

In [ ]:
Copied!
def infer_states_end2end_padded(A_padded, obs_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter, sparsity=None):
    """Run state inference with the end-to-end padded method.

    Per-modality log-likelihoods are computed on the end-to-end padded arrays
    and fed straight into ``run_factorized_fpi_end2end_padded``. Unlike the
    hybrid variants, nothing here is wrapped in ``vmap``. In this notebook the
    returned posteriors are still batched (leading axis of size
    ``batch_size``), so batching is handled inside the end2end routines.

    Args:
        A_padded: likelihood from ``apply_A_end2end_padding_batched``; either a
            dense array or a ``BCOO`` array.
        obs_padded: observations from ``apply_obs_end2end_padding_batched``.
        D: initial-state prior as returned by ``init_A_and_D_from_spec``.
        A_dependencies: state-factor indices that each modality depends on.
        max_obs_dim: padded observation size.
        max_state_dim: padded state size.
        num_iter: number of fixed-point iterations.
        sparsity: forwarded only to the log-likelihood computation (the
            notebook uses ``'ll_only'``).

    Returns:
        The output of ``run_factorized_fpi_end2end_padded``. In this notebook
        that is one posterior array per state factor.

    Note:
        The argument order is (A, observations), the reverse of
        ``infer_states_hybrid``.
    """
    return run_factorized_fpi_end2end_padded(
        compute_log_likelihood_per_modality_end2end2_padded(
            obs_padded, A_padded, sparsity=sparsity
        ),
        D,
        A_dependencies,
        max_obs_dim,
        max_state_dim,
        num_iter,
    )
def infer_states_end2end_padded(A_padded, obs_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter, sparsity=None): lls_padded = compute_log_likelihood_per_modality_end2end2_padded(obs_padded, A_padded, sparsity=sparsity) return run_factorized_fpi_end2end_padded(lls_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter)
In [ ]:
Copied!
# Pad the A tensors for the end-to-end method. The result is handled below as
# a single array (it has `.shape` and is passed to BCOO.fromdense), so the
# per-modality A's are presumably stacked after padding (TODO confirm).
A_padded = apply_A_end2end_padding_batched(A)
# Convert to JAX's sparse BCOO format only when a sparsity level was requested.
A_padded = A_padded if A_sparsity_level is None else jsparse.BCOO.fromdense(A_padded)

# Read the padded sizes off the tensor. Axis 2 is taken as the observation
# axis, and the largest trailing axis as the state size
# (layout assumption -- TODO confirm against apply_A_end2end_padding_batched).
max_obs_dim = A_padded.shape[2]
max_state_dim = max(A_padded.shape[3:])

# obs preprocessing: squeeze axis 1 out of every observation, then pad each
# one to max_obs_dim.
obs_padded = apply_obs_end2end_padding_batched(
    jtu.tree_map(lambda o: jnp.squeeze(o, axis=1), o_vec),
    max_obs_dim,
)
A_padded = apply_A_end2end_padding_batched(A) if A_sparsity_level is not None: A_padded = jsparse.BCOO.fromdense(A_padded) max_obs_dim = A_padded.shape[2] max_state_dim = max(A_padded.shape[3:]) # obs preprocessing obs_padded = apply_obs_end2end_padding_batched(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec), max_obs_dim)
In [8]:
Copied!
# Run the end-to-end padded inference eagerly (without jit).
# NOTE(review): sparsity='ll_only' is passed even when A_sparsity_level is
# None, i.e. when A_padded is still dense -- confirm this is intended.
qs = infer_states_end2end_padded(
    A_padded,
    obs_padded,
    D,
    A_dependencies=spec['A_dependencies'],
    max_obs_dim=max_obs_dim,
    max_state_dim=max_state_dim,
    num_iter=num_iter,
    sparsity='ll_only',
)

# Display the shape of each returned array next to the arrays themselves.
[arr.shape for arr in qs], qs
qs = infer_states_end2end_padded(A_padded, obs_padded, D, spec['A_dependencies'], max_obs_dim, max_state_dim, num_iter, sparsity='ll_only') [q.shape for q in qs], qs
Out[8]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.09112392, 0.70235604, 0.20652005]],
  
         [[0.202104  , 0.22807139, 0.56982464]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.37516478, 0.27490932, 0.34992588]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.55582505, 0.44417495]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.3831266 , 0.6168734 ]],
  
         [[0.5388056 , 0.4611944 ]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.16893746, 0.16029103, 0.11270652, 0.18204279, 0.21946692,
           0.0823974 , 0.07415791]],
  
         [[0.10790369, 0.2240197 , 0.21666422, 0.06581873, 0.10617904,
           0.08600904, 0.1934056 ]],
  
         [[0.03150751, 0.13815625, 0.12641439, 0.14349648, 0.20863557,
           0.33044103, 0.02134874]],
  
         [[0.09764872, 0.29433855, 0.18542336, 0.14082211, 0.12583902,
           0.07777937, 0.07814886]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077335, 0.01719547, 0.18029204, 0.19461544, 0.17712368]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322947]],
  
         [[0.34862235, 0.24600424, 0.12305134, 0.11314452, 0.16917753]],
  
         [[0.15003254, 0.08879873, 0.19839816, 0.5253277 , 0.03744284]]],      dtype=float32),
  Array([[[0.11571389, 0.04172199, 0.14292377, 0.21811737, 0.1089925 ,
           0.2847778 , 0.08775266]],
  
         [[0.11719004, 0.01083917, 0.22239207, 0.11398614, 0.09531904,
           0.30416313, 0.13611037]],
  
         [[0.05459085, 0.16350462, 0.13142684, 0.20870967, 0.17459588,
           0.15282223, 0.11434983]],
  
         [[0.1454746 , 0.09393647, 0.19433096, 0.22500303, 0.11097853,
           0.06738115, 0.16289523]]], dtype=float32),
  Array([[[0.23350713, 0.21891789, 0.16771066, 0.17027755, 0.20958683]],
  
         [[0.19165839, 0.2190793 , 0.20649076, 0.19125327, 0.19151816]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361015, 0.19487286]],
  
         [[0.24136762, 0.17386994, 0.22222117, 0.1608011 , 0.20174012]]],      dtype=float32)])
In [9]:
Copied!
# JIT: compile the end-to-end padded pipeline.
# The non-array settings are bound with `partial`, so the jitted callables
# only receive the arrays at call time.
# The names are end2end-specific on purpose. Generic names such as
# `infer_states_partially_padded_jit` / `apply_obs_padding_batched_jit` would
# mislabel this variant, and could silently overwrite similarly named helpers
# that other sections of this notebook may define.
apply_obs_end2end_padding_batched_jit = jit(
    partial(apply_obs_end2end_padding_batched, max_obs_dim=max_obs_dim)
)
infer_states_end2end_padded_jit = jit(
    partial(
        infer_states_end2end_padded,
        A_dependencies=spec['A_dependencies'],
        max_obs_dim=max_obs_dim,
        max_state_dim=max_state_dim,
        num_iter=num_iter,
        sparsity='ll_only',
    )
)

# Squeeze axis 1 out of every observation, then pad (jitted).
obs_padded = apply_obs_end2end_padding_batched_jit(
    jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec)
)

qs = infer_states_end2end_padded_jit(A_padded, obs_padded, D)
[q.shape for q in qs], qs
# JIT apply_obs_padding_batched_jit = jit(partial(apply_obs_end2end_padding_batched, max_obs_dim=max_obs_dim)) infer_states_partially_padded_jit = jit(partial(infer_states_end2end_padded, A_dependencies=spec['A_dependencies'], max_obs_dim=max_obs_dim, max_state_dim=max_state_dim, num_iter=num_iter, sparsity='ll_only')) obs_padded = apply_obs_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec)) qs = infer_states_partially_padded_jit(A_padded, obs_padded, D) [q.shape for q in qs], qs
Out[9]:
([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.09112392, 0.70235604, 0.20652005]],
  
         [[0.202104  , 0.22807139, 0.56982464]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.37516478, 0.27490932, 0.34992588]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.55582505, 0.44417495]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.3831266 , 0.6168734 ]],
  
         [[0.5388056 , 0.4611944 ]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
           0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
  
         [[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
           0.04707498, 0.12128211, 0.10770807, 0.08129089]],
  
         [[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
           0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
  
         [[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
           0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
  Array([[[0.16893746, 0.16029103, 0.11270652, 0.18204279, 0.21946692,
           0.0823974 , 0.07415791]],
  
         [[0.10790369, 0.2240197 , 0.21666422, 0.06581873, 0.10617904,
           0.08600904, 0.1934056 ]],
  
         [[0.03150751, 0.13815625, 0.12641439, 0.14349648, 0.20863557,
           0.33044103, 0.02134874]],
  
         [[0.09764872, 0.29433855, 0.18542336, 0.14082211, 0.12583902,
           0.07777937, 0.07814886]]], dtype=float32),
  Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
  
         [[0.13832287, 0.417142  , 0.29778096, 0.14675426]],
  
         [[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
  
         [[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
  Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
           0.10489379, 0.17765091]],
  
         [[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
           0.2958903 , 0.16056576]],
  
         [[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
           0.18881004, 0.21084957]],
  
         [[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
           0.06183277, 0.10635528]]], dtype=float32),
  Array([[[0.43077335, 0.01719547, 0.18029204, 0.19461544, 0.17712368]],
  
         [[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322947]],
  
         [[0.34862235, 0.24600424, 0.12305134, 0.11314452, 0.16917753]],
  
         [[0.15003254, 0.08879873, 0.19839816, 0.5253277 , 0.03744284]]],      dtype=float32),
  Array([[[0.11571389, 0.04172199, 0.14292377, 0.21811737, 0.1089925 ,
           0.2847778 , 0.08775266]],
  
         [[0.11719004, 0.01083917, 0.22239207, 0.11398614, 0.09531904,
           0.30416313, 0.13611037]],
  
         [[0.05459085, 0.16350462, 0.13142684, 0.20870967, 0.17459588,
           0.15282223, 0.11434983]],
  
         [[0.1454746 , 0.09393647, 0.19433096, 0.22500303, 0.11097853,
           0.06738115, 0.16289523]]], dtype=float32),
  Array([[[0.23350713, 0.21891789, 0.16771066, 0.17027755, 0.20958683]],
  
         [[0.19165839, 0.2190793 , 0.20649076, 0.19125327, 0.19151816]],
  
         [[0.20866643, 0.18651229, 0.23633833, 0.17361015, 0.19487286]],
  
         [[0.24136762, 0.17386994, 0.22222117, 0.1608011 , 0.20174012]]],      dtype=float32)])

Made with Dracula Theme for MkDocs