In [ ]:
Copied!
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
import sys
if "google.colab" in sys.modules:
%pip install "inferactively-pymdp" -q
In [ ]:
Copied!
import numpy as np
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.experimental.sparse as jsparse
from jax import nn, vmap, jit, block_until_ready
from functools import partial
from pymdp.utils import init_A_and_D_from_spec, get_sample_obs, generate_agent_specs_from_parameter_sets
# Hybrid
from pymdp.utils import apply_padding_batched
from pymdp.maths import compute_log_likelihoods_padded, deconstruct_lls
# Hybrid block
from pymdp.utils import build_block_diag_A, preprocess_A_for_block_diag, prepare_obs_for_block_diag, concatenate_observations_block_diag
from pymdp.maths import compute_log_likelihoods_block_diag, deconstruct_log_likelihoods_block_diag
from pymdp.algos import run_factorized_fpi_hybrid # For hybrid and hybrid block
# End2end padded
from pymdp.utils import apply_A_end2end_padding_batched, apply_obs_end2end_padding_batched
from pymdp.maths import compute_log_likelihood_per_modality_end2end2_padded
from pymdp.algos import run_factorized_fpi_end2end_padded
import numpy as np
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.experimental.sparse as jsparse
from jax import nn, vmap, jit, block_until_ready
from functools import partial
from pymdp.utils import init_A_and_D_from_spec, get_sample_obs, generate_agent_specs_from_parameter_sets
# Hybrid
from pymdp.utils import apply_padding_batched
from pymdp.maths import compute_log_likelihoods_padded, deconstruct_lls
# Hybrid block
from pymdp.utils import build_block_diag_A, preprocess_A_for_block_diag, prepare_obs_for_block_diag, concatenate_observations_block_diag
from pymdp.maths import compute_log_likelihoods_block_diag, deconstruct_log_likelihoods_block_diag
from pymdp.algos import run_factorized_fpi_hybrid # For hybrid and hybrid block
# End2end padded
from pymdp.utils import apply_A_end2end_padding_batched, apply_obs_end2end_padding_batched
from pymdp.maths import compute_log_likelihood_per_modality_end2end2_padded
from pymdp.algos import run_factorized_fpi_end2end_padded
In [1]:
Copied!
# Define coordinated parameter sets
# Each tuple is:
# (num_factors, num_modalities, state_dim_upper_limit, obs_dim_upper_limit, dim_sampling_type, label)
parameter_sets = [
    (5, 5, 5, 5, 'uniform', 'low'),
    (10, 10, 10, 10, 'uniform', 'medium'),
    (25, 25, 25, 25, 'uniform', 'high'),
    # (125, 125, 125, 125, 'uniform', 'extreme'),  # Uncomment to include extreme cases
]

# Generate agent specs without dumping to file
specs = generate_agent_specs_from_parameter_sets(
    parameter_sets,
    num_agents_per_set=1,
    output_file=None,  # Don't save to file
)

# Entry 1 of the 'arbitrary dependencies' group; the recorded output of this
# cell shows it is the spec whose metadata is labelled 'medium'.
spec = specs['arbitrary dependencies'][1]
spec
Out[1]:
{'num_factors': 10,
'num_modalities': 10,
'num_states': [3, 6, 2, 9, 7, 4, 7, 5, 7, 5],
'num_obs': [4, 2, 4, 9, 7, 8, 4, 5, 7, 4],
'A_dependencies': [[1],
[4],
[5, 6],
[0, 8],
[0, 2, 7],
[3, 5],
[6],
[2, 4, 7, 9],
[0, 2, 7, 8],
[7]],
'metadata': {'num_factors': 'medium',
'num_modalities': 'medium',
'state_dim_upper_limit': 'medium',
'obs_dim_upper_limit': 'medium',
'dim_sampling_type': 'uniform'}}
In [2]:
Copied!
# Shared settings for every inference variant in this notebook.
num_iter = 8  # passed as `num_iter` to each inference routine below
batch_size = 4
A_sparsity_level = None  # E.g., 0.8 for 80% sparsity

# Build A and D for the selected spec with a leading batch axis.
A, D = init_A_and_D_from_spec(
    spec['num_obs'],
    spec['num_states'],
    spec['A_dependencies'],
    A_sparsity_level=A_sparsity_level,
    batch_size=batch_size,
)

# Sample one observation per modality, then one-hot encode each of them using
# that modality's number of outcomes (spec['num_obs'][m]).
obs = get_sample_obs(spec['num_obs'], batch_size=batch_size)
o_vec = [nn.one_hot(o, spec['num_obs'][m]) for m, o in enumerate(obs)]
# place where this happens is important!
# o_vec = jtu.tree_map(lambda x: x[-1], o_vec)
WARNING:2025-11-20 23:52:19,919:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Original method imported directly from PyMDP
In [ ]:
Copied!
from pymdp.inference import update_posterior_states

# Reference implementation: pymdp's own `update_posterior_states` with the
# dependency structure, iteration count and method bound up front. `vmap`
# (default in_axes) then maps the call over the leading axis of the arguments.
infer_states_orig_pymdp = vmap(
    partial(
        update_posterior_states,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
        method='fpi',
    )
)
In [3]:
Copied!
# Baseline posterior from the reference implementation. The second and fourth
# positional arguments are passed as None here.
qs = infer_states_orig_pymdp(A, None, o_vec, None, D)
[q.shape for q in qs], qs
Out[3]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.0911239 , 0.70235616, 0.20651992]],
[[0.20210403, 0.22807154, 0.56982446]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.5558248 , 0.44417518]],
[[0.549685 , 0.450315 ]],
[[0.38312683, 0.6168732 ]],
[[0.5388057 , 0.46119428]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.1689375 , 0.160291 , 0.1127065 , 0.18204276, 0.21946688,
0.08239742, 0.07415786]],
[[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
0.08600897, 0.1934056 ]],
[[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
0.33044106, 0.02134874]],
[[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
0.07777937, 0.0781489 ]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
[[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
[[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]], dtype=float32),
Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
0.2847777 , 0.08775268]],
[[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
0.3041632 , 0.13611032]],
[[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
0.15282223, 0.11434983]],
[[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
0.0673811 , 0.16289519]]], dtype=float32),
Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
[[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
[[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]], dtype=float32)])
Hybrid method
In [ ]:
Copied!
def infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies, num_iter):
    """Infer states from padded inputs, then run factorized FPI per batch element.

    The log-likelihoods are computed once on the padded inputs, split back with
    ``deconstruct_lls`` using ``A_shapes``, and handed to
    ``run_factorized_fpi_hybrid``, which is vmapped over the leading axis of
    the log-likelihoods and of ``D``.

    Args:
        obs_padded: Padded observations, forwarded to
            ``compute_log_likelihoods_padded``.
        A_padded: Padded A array, forwarded to
            ``compute_log_likelihoods_padded``.
        D: Passed alongside the log-likelihoods to the vmapped
            ``run_factorized_fpi_hybrid``.
        A_shapes: Shapes of the unpadded A arrays, used to undo the padding.
        A_dependencies: Forwarded unchanged to ``run_factorized_fpi_hybrid``.
        num_iter: Forwarded unchanged to ``run_factorized_fpi_hybrid``.

    Returns:
        The output of the vmapped ``run_factorized_fpi_hybrid``. In this
        notebook that is a list with one array per state factor, each of
        shape ``(batch, 1, num_states[f])``.
    """
    lls_padded = compute_log_likelihoods_padded(obs_padded, A_padded)
    log_likelihoods = deconstruct_lls(lls_padded, A_shapes)
    # Only the arrays are mapped; the static settings are bound via partial.
    run_fpi = partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter)
    return vmap(run_fpi)(log_likelihoods, D)
In [ ]:
Copied!
# Pad the per-modality A arrays for infer_states_hybrid and keep the unpadded
# shapes so the log-likelihoods can be split back afterwards.
A_padded = apply_padding_batched(A)
A_shapes = [a.shape for a in A]
if A_sparsity_level is not None:
    # Sparse path: convert to BCOO, with the leading axis as a batch dimension.
    A_padded = jsparse.BCOO.fromdense(A_padded, n_batch=1)

# obs preprocessing: drop the size-1 axis 1 of every one-hot observation, then pad
obs_padded = apply_padding_batched(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))
In [4]:
Copied!
# NOTE: A_padded / obs_padded must be the ones built with apply_padding_batched.
# The End2End section later rebinds both names to the outputs of
# apply_A_end2end_padding_batched / apply_obs_end2end_padding_batched, so rerun
# the hybrid preprocessing cell before re-running this one out of order.
qs = infer_states_hybrid(
    obs_padded,
    A_padded,
    D,
    A_shapes,
    A_dependencies=spec['A_dependencies'],
    num_iter=num_iter,
)
[q.shape for q in qs], qs
Out[4]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.0911239 , 0.70235616, 0.20651992]],
[[0.20210403, 0.22807154, 0.56982446]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.5558248 , 0.44417518]],
[[0.549685 , 0.450315 ]],
[[0.38312683, 0.6168732 ]],
[[0.5388057 , 0.46119428]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.1689375 , 0.160291 , 0.1127065 , 0.18204276, 0.21946688,
0.08239742, 0.07415786]],
[[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
0.08600897, 0.1934056 ]],
[[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
0.33044106, 0.02134874]],
[[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
0.07777937, 0.0781489 ]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
[[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
[[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]], dtype=float32),
Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
0.2847777 , 0.08775268]],
[[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
0.3041632 , 0.13611032]],
[[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
0.15282223, 0.11434983]],
[[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
0.0673811 , 0.16289519]]], dtype=float32),
Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
[[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
[[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]], dtype=float32)])
In [5]:
Copied!
# JIT
# The static settings (shapes, dependencies, iteration count) are bound with
# partial, so the compiled callable is invoked with obs_padded, A_padded and D only.
apply_padding_batched_jit = jit(apply_padding_batched)  # a zero-argument partial() adds nothing
infer_states_hybrid_jit = jit(
    partial(
        infer_states_hybrid,
        A_shapes=A_shapes,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
    )
)

obs_padded = apply_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))
qs = infer_states_hybrid_jit(obs_padded, A_padded, D)
[q.shape for q in qs], qs
Out[5]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.0911239 , 0.70235616, 0.20651992]],
[[0.20210403, 0.22807154, 0.56982446]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.5558248 , 0.44417518]],
[[0.549685 , 0.450315 ]],
[[0.38312683, 0.6168732 ]],
[[0.5388057 , 0.46119428]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.1689375 , 0.160291 , 0.1127065 , 0.18204276, 0.21946688,
0.08239742, 0.07415786]],
[[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
0.08600897, 0.1934056 ]],
[[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
0.33044106, 0.02134874]],
[[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
0.07777937, 0.0781489 ]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
[[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
[[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]], dtype=float32),
Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
0.2847777 , 0.08775268]],
[[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
0.3041632 , 0.13611032]],
[[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
0.15282223, 0.11434983]],
[[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
0.0673811 , 0.16289519]]], dtype=float32),
Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
[[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
[[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]], dtype=float32)])
Hybrid Block method
In [ ]:
Copied!
# Infer states hybrid block
def infer_states_hybrid_block(obs, A_big, D, state_shapes, cuts, A_dependencies, num_iter, use_einsum=False):
    """Hybrid inference using block diagonal approach for log-likelihood computation.

    The log-likelihoods come from ``compute_log_likelihoods_block_diag`` and
    are then passed, together with ``D``, to ``run_factorized_fpi_hybrid``
    vmapped over their leading axis.

    Args:
        obs: Observations, forwarded to ``compute_log_likelihoods_block_diag``
            (the notebook passes the output of
            ``concatenate_observations_block_diag``).
        A_big: Block-diagonal A array (the notebook passes the first output of
            ``preprocess_A_for_block_diag``).
        D: Passed alongside the log-likelihoods to the vmapped
            ``run_factorized_fpi_hybrid``.
        state_shapes: Forwarded to ``compute_log_likelihoods_block_diag``.
        cuts: Forwarded to ``compute_log_likelihoods_block_diag``.
        A_dependencies: Forwarded unchanged to ``run_factorized_fpi_hybrid``.
        num_iter: Forwarded unchanged to ``run_factorized_fpi_hybrid``.
        use_einsum: Forwarded as a keyword to
            ``compute_log_likelihoods_block_diag``; defaults to False.

    Returns:
        The output of the vmapped ``run_factorized_fpi_hybrid``.
    """
    log_likelihoods = compute_log_likelihoods_block_diag(A_big, obs, state_shapes, cuts, use_einsum=use_einsum)
    # Only the arrays are mapped; the static settings are bound via partial.
    run_fpi = partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter)
    return vmap(run_fpi)(log_likelihoods, D)
In [ ]:
Copied!
# Create a copy with moved axes for block diagonal method (don't modify original A):
# axis 1 of every A array is moved to the last position.
A_moveaxis = [jnp.moveaxis(a, 1, -1) for a in A]

# Preprocess A matrices for block diagonal approach
A_big, state_shapes, cuts = preprocess_A_for_block_diag(A_moveaxis)
if A_sparsity_level is not None:
    # Sparse path: convert to BCOO, with the leading axis as a batch dimension.
    A_big = jsparse.BCOO.fromdense(A_big, n_batch=1)

# Drop the size-1 axis 1 of every one-hot observation before concatenating.
# obs_tmp is reused by the JIT cell of this section.
obs_tmp = jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec)
obs_big = concatenate_observations_block_diag(obs_tmp)
In [6]:
Copied!
# Non-jitted run of the block-diagonal variant on the arrays prepared above.
qs = infer_states_hybrid_block(
    obs_big,
    A_big,
    D,
    state_shapes=state_shapes,
    cuts=cuts,
    A_dependencies=spec['A_dependencies'],
    num_iter=num_iter,
    use_einsum=False,
)
[q.shape for q in qs], qs
Out[6]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.0911239 , 0.70235616, 0.20651992]],
[[0.20210403, 0.22807154, 0.56982446]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.5558248 , 0.44417518]],
[[0.549685 , 0.450315 ]],
[[0.38312683, 0.6168732 ]],
[[0.5388057 , 0.46119428]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.1689375 , 0.160291 , 0.1127065 , 0.18204276, 0.21946688,
0.08239742, 0.07415786]],
[[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
0.08600897, 0.1934056 ]],
[[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
0.33044106, 0.02134874]],
[[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
0.07777937, 0.0781489 ]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
[[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
[[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]], dtype=float32),
Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
0.2847777 , 0.08775268]],
[[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
0.3041632 , 0.13611032]],
[[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
0.15282223, 0.11434983]],
[[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
0.0673811 , 0.16289519]]], dtype=float32),
Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
[[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
[[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]], dtype=float32)])
In [7]:
Copied!
# JIT
use_einsum = False

# just for obs; a zero-argument partial() adds nothing, so jit the function directly
concatenate_observations_block_diag_jit = jit(concatenate_observations_block_diag)

# The static settings are bound with partial, so the compiled callable is
# invoked with obs_big, A_big and D only.
infer_states_hybrid_block_jit = jit(
    partial(
        infer_states_hybrid_block,
        state_shapes=state_shapes,
        cuts=cuts,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
        use_einsum=use_einsum,
    )
)

obs_big = concatenate_observations_block_diag_jit(obs_tmp)  # add padding of obs before running the infer states
qs = infer_states_hybrid_block_jit(obs_big, A_big, D)
[q.shape for q in qs], qs
Out[7]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.0911239 , 0.70235616, 0.20651992]],
[[0.20210403, 0.22807154, 0.56982446]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.5558248 , 0.44417518]],
[[0.549685 , 0.450315 ]],
[[0.38312683, 0.6168732 ]],
[[0.5388057 , 0.46119428]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.1689375 , 0.160291 , 0.1127065 , 0.18204276, 0.21946688,
0.08239742, 0.07415786]],
[[0.10790369, 0.22401966, 0.21666421, 0.06581873, 0.10617904,
0.08600897, 0.1934056 ]],
[[0.03150751, 0.13815624, 0.12641437, 0.14349641, 0.20863566,
0.33044106, 0.02134874]],
[[0.09764873, 0.2943387 , 0.18542327, 0.14082205, 0.12583902,
0.07777937, 0.0781489 ]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077353, 0.01719548, 0.18029211, 0.19461532, 0.17712358]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322948]],
[[0.34862244, 0.24600431, 0.12305126, 0.11314465, 0.16917741]],
[[0.15003231, 0.08879875, 0.19839822, 0.52532786, 0.03744285]]], dtype=float32),
Array([[[0.11571391, 0.04172199, 0.14292379, 0.21811739, 0.10899252,
0.2847777 , 0.08775268]],
[[0.11718995, 0.01083917, 0.22239213, 0.11398617, 0.09531911,
0.3041632 , 0.13611032]],
[[0.05459082, 0.1635046 , 0.13142689, 0.20870976, 0.17459586,
0.15282223, 0.11434983]],
[[0.1454747 , 0.0939365 , 0.19433093, 0.22500299, 0.11097856,
0.0673811 , 0.16289519]]], dtype=float32),
Array([[[0.2335071 , 0.21891794, 0.16771059, 0.17027749, 0.20958687]],
[[0.19165832, 0.21907933, 0.20649089, 0.19125326, 0.19151819]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361018, 0.19487286]],
[[0.24136765, 0.17386995, 0.2222213 , 0.16080104, 0.2017401 ]]], dtype=float32)])
End2End padded method
In [ ]:
Copied!
def infer_states_end2end_padded(A_padded, obs_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter, sparsity=None):
    """Infer states while keeping inputs padded through both stages.

    Computes padded per-modality log-likelihoods and passes them straight to
    ``run_factorized_fpi_end2end_padded``. Unlike the hybrid variants, this
    function does not un-pad the log-likelihoods or apply ``vmap`` itself.

    Note the argument order: ``A_padded`` comes before ``obs_padded`` here,
    whereas ``infer_states_hybrid`` takes the observations first.

    Args:
        A_padded: Padded A array (the notebook passes the output of
            ``apply_A_end2end_padding_batched``).
        obs_padded: Padded observations (the notebook passes the output of
            ``apply_obs_end2end_padding_batched``).
        D: Forwarded unchanged to ``run_factorized_fpi_end2end_padded``.
        A_dependencies: Forwarded unchanged to
            ``run_factorized_fpi_end2end_padded``.
        max_obs_dim: Forwarded unchanged to
            ``run_factorized_fpi_end2end_padded``.
        max_state_dim: Forwarded unchanged to
            ``run_factorized_fpi_end2end_padded``.
        num_iter: Forwarded unchanged to ``run_factorized_fpi_end2end_padded``.
        sparsity: Forwarded as a keyword to
            ``compute_log_likelihood_per_modality_end2end2_padded``; defaults
            to None.

    Returns:
        The output of ``run_factorized_fpi_end2end_padded``.
    """
    lls_padded = compute_log_likelihood_per_modality_end2end2_padded(obs_padded, A_padded, sparsity=sparsity)
    return run_factorized_fpi_end2end_padded(lls_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter)
In [ ]:
Copied!
# NOTE: this rebinds A_padded / obs_padded, which the Hybrid section also uses,
# to the end2end-padded versions.
A_padded = apply_A_end2end_padding_batched(A)
if A_sparsity_level is not None:
    # NOTE(review): unlike the hybrid sections, no n_batch is passed here;
    # confirm that is intended.
    A_padded = jsparse.BCOO.fromdense(A_padded)

# max_obs_dim is read from axis 2 of the padded A; max_state_dim is the largest
# of the remaining trailing axes.
max_obs_dim = A_padded.shape[2]
max_state_dim = max(A_padded.shape[3:])

# obs preprocessing: drop the size-1 axis 1 of every one-hot observation, then pad to max_obs_dim
obs_padded = apply_obs_end2end_padding_batched(
    jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec),
    max_obs_dim,
)
In [8]:
Copied!
# Non-jitted run of the end2end-padded variant.
# NOTE(review): sparsity='ll_only' is passed regardless of A_sparsity_level
# (None in this notebook); confirm that is the intended mode for dense A.
qs = infer_states_end2end_padded(
    A_padded,
    obs_padded,
    D,
    spec['A_dependencies'],
    max_obs_dim,
    max_state_dim,
    num_iter,
    sparsity='ll_only',
)
[q.shape for q in qs], qs
Out[8]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.09112392, 0.70235604, 0.20652005]],
[[0.202104 , 0.22807139, 0.56982464]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.37516478, 0.27490932, 0.34992588]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.55582505, 0.44417495]],
[[0.549685 , 0.450315 ]],
[[0.3831266 , 0.6168734 ]],
[[0.5388056 , 0.4611944 ]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.16893746, 0.16029103, 0.11270652, 0.18204279, 0.21946692,
0.0823974 , 0.07415791]],
[[0.10790369, 0.2240197 , 0.21666422, 0.06581873, 0.10617904,
0.08600904, 0.1934056 ]],
[[0.03150751, 0.13815625, 0.12641439, 0.14349648, 0.20863557,
0.33044103, 0.02134874]],
[[0.09764872, 0.29433855, 0.18542336, 0.14082211, 0.12583902,
0.07777937, 0.07814886]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077335, 0.01719547, 0.18029204, 0.19461544, 0.17712368]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322947]],
[[0.34862235, 0.24600424, 0.12305134, 0.11314452, 0.16917753]],
[[0.15003254, 0.08879873, 0.19839816, 0.5253277 , 0.03744284]]], dtype=float32),
Array([[[0.11571389, 0.04172199, 0.14292377, 0.21811737, 0.1089925 ,
0.2847778 , 0.08775266]],
[[0.11719004, 0.01083917, 0.22239207, 0.11398614, 0.09531904,
0.30416313, 0.13611037]],
[[0.05459085, 0.16350462, 0.13142684, 0.20870967, 0.17459588,
0.15282223, 0.11434983]],
[[0.1454746 , 0.09393647, 0.19433096, 0.22500303, 0.11097853,
0.06738115, 0.16289523]]], dtype=float32),
Array([[[0.23350713, 0.21891789, 0.16771066, 0.17027755, 0.20958683]],
[[0.19165839, 0.2190793 , 0.20649076, 0.19125327, 0.19151816]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361015, 0.19487286]],
[[0.24136762, 0.17386994, 0.22222117, 0.1608011 , 0.20174012]]], dtype=float32)])
In [9]:
Copied!
# JIT
# The static settings are bound with partial, so the compiled callables are
# invoked with array arguments only.
apply_obs_padding_batched_jit = jit(
    partial(apply_obs_end2end_padding_batched, max_obs_dim=max_obs_dim)
)
infer_states_partially_padded_jit = jit(
    partial(
        infer_states_end2end_padded,
        A_dependencies=spec['A_dependencies'],
        max_obs_dim=max_obs_dim,
        max_state_dim=max_state_dim,
        num_iter=num_iter,
        sparsity='ll_only',
    )
)

obs_padded = apply_obs_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))
qs = infer_states_partially_padded_jit(A_padded, obs_padded, D)
[q.shape for q in qs], qs
Out[9]:
([(4, 1, 3),
(4, 1, 6),
(4, 1, 2),
(4, 1, 9),
(4, 1, 7),
(4, 1, 4),
(4, 1, 7),
(4, 1, 5),
(4, 1, 7),
(4, 1, 5)],
[Array([[[0.09112392, 0.70235604, 0.20652005]],
[[0.202104 , 0.22807139, 0.56982464]],
[[0.3876841 , 0.3879837 , 0.2243322 ]],
[[0.37516478, 0.27490932, 0.34992588]]], dtype=float32),
Array([[[0.15824087, 0.183674 , 0.26761115, 0.09957846, 0.20257226,
0.08832329]],
[[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
0.31559473]],
[[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
0.22736534]],
[[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
0.12064715]]], dtype=float32),
Array([[[0.55582505, 0.44417495]],
[[0.549685 , 0.450315 ]],
[[0.3831266 , 0.6168734 ]],
[[0.5388056 , 0.4611944 ]]], dtype=float32),
Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16774674,
0.09816181, 0.09280077, 0.1271883 , 0.1426509 ]],
[[0.12493826, 0.16711089, 0.15916635, 0.13118526, 0.06024319,
0.04707498, 0.12128211, 0.10770807, 0.08129089]],
[[0.1815095 , 0.10422336, 0.06684336, 0.14435901, 0.13004695,
0.0635796 , 0.05770647, 0.09602303, 0.15570885]],
[[0.04864096, 0.08837192, 0.09715893, 0.13307187, 0.1358143 ,
0.18936108, 0.10949196, 0.09659162, 0.10149743]]], dtype=float32),
Array([[[0.16893746, 0.16029103, 0.11270652, 0.18204279, 0.21946692,
0.0823974 , 0.07415791]],
[[0.10790369, 0.2240197 , 0.21666422, 0.06581873, 0.10617904,
0.08600904, 0.1934056 ]],
[[0.03150751, 0.13815625, 0.12641439, 0.14349648, 0.20863557,
0.33044103, 0.02134874]],
[[0.09764872, 0.29433855, 0.18542336, 0.14082211, 0.12583902,
0.07777937, 0.07814886]]], dtype=float32),
Array([[[0.20188354, 0.20390284, 0.35773486, 0.23647875]],
[[0.13832287, 0.417142 , 0.29778096, 0.14675426]],
[[0.14250487, 0.34247938, 0.2867788 , 0.22823687]],
[[0.33311817, 0.08985048, 0.27870888, 0.2983224 ]]], dtype=float32),
Array([[[0.25531998, 0.09805848, 0.00180036, 0.06169656, 0.30057994,
0.10489379, 0.17765091]],
[[0.2913621 , 0.11028862, 0.06915001, 0.06127482, 0.01146838,
0.2958903 , 0.16056576]],
[[0.10321112, 0.10843234, 0.09066889, 0.10762182, 0.19040616,
0.18881004, 0.21084957]],
[[0.0594087 , 0.15811725, 0.18402615, 0.22240478, 0.20785518,
0.06183277, 0.10635528]]], dtype=float32),
Array([[[0.43077335, 0.01719547, 0.18029204, 0.19461544, 0.17712368]],
[[0.20952433, 0.28238228, 0.26831535, 0.22654854, 0.01322947]],
[[0.34862235, 0.24600424, 0.12305134, 0.11314452, 0.16917753]],
[[0.15003254, 0.08879873, 0.19839816, 0.5253277 , 0.03744284]]], dtype=float32),
Array([[[0.11571389, 0.04172199, 0.14292377, 0.21811737, 0.1089925 ,
0.2847778 , 0.08775266]],
[[0.11719004, 0.01083917, 0.22239207, 0.11398614, 0.09531904,
0.30416313, 0.13611037]],
[[0.05459085, 0.16350462, 0.13142684, 0.20870967, 0.17459588,
0.15282223, 0.11434983]],
[[0.1454746 , 0.09393647, 0.19433096, 0.22500303, 0.11097853,
0.06738115, 0.16289523]]], dtype=float32),
Array([[[0.23350713, 0.21891789, 0.16771066, 0.17027755, 0.20958683]],
[[0.19165839, 0.2190793 , 0.20649076, 0.19125327, 0.19151816]],
[[0.20866643, 0.18651229, 0.23633833, 0.17361015, 0.19487286]],
[[0.24136762, 0.17386994, 0.22222117, 0.1608011 , 0.20174012]]], dtype=float32)])