import unittest

from polymatrix.expressionstate.init.initexpressionstate import init_expression_state
import polymatrix.expression.initexpressionbase


class TestBlockDiag(unittest.TestCase):
    """Tests for ``init_block_diag_expr``."""

    def test_1(self):
        """Place two 2x2 term blocks along the diagonal and check each entry.

        The first block's entries keep their positions.  The second block's
        entries are expected at positions shifted by (2, 2).
        """
        base = polymatrix.expression.initexpressionbase

        # First block: one non-constant entry at (0, 0) and a constant at (1, 0).
        # NOTE(review): ((1, 1),) is presumably a monomial key
        # (variable index, exponent) — confirm against the terms format.
        upper_left_terms = {
            (0, 0): {((1, 1),): 1.0},
            (1, 0): {tuple(): 2.0},
        }
        # Second block: constants at (0, 0) and (1, 1).
        lower_right_terms = {
            (0, 0): {tuple(): 3.0},
            (1, 1): {tuple(): 4.0},
        }

        block_diag = base.init_block_diag_expr(
            underlying=(
                base.init_from_terms_expr(terms=upper_left_terms, shape=(2, 2)),
                base.init_from_terms_expr(terms=lower_right_terms, shape=(2, 2)),
            ),
        )

        state = init_expression_state(n_param=2)
        state, result = block_diag.apply(state)

        # (row, col) in the combined result -> expected polynomial terms.
        expected = (
            ((0, 0), {((1, 1),): 1.0}),
            ((1, 0), {tuple(): 2.0}),
            # Entries of the second block appear offset by the first block's shape.
            ((2, 2), {tuple(): 3.0}),
            ((3, 3), {tuple(): 4.0}),
        )
        for (row, col), terms in expected:
            self.assertDictEqual(terms, result.get_poly(row, col))