diff options
Diffstat (limited to 'test_polymatrix/test_expression')
-rw-r--r-- | test_polymatrix/test_expression/__init__.py | 0 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_addition.py | 64 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_blockdiag.py | 58 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_derivative.py | 52 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_divergence.py | 39 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_eval.py | 41 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_linearin.py | 50 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_matrixmult.py | 64 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_quadraticin.py | 62 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_substitude.py | 43 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_subtractmonomials.py | 55 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_sum.py | 35 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_symmetric.py | 53 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_toconstant.py | 46 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_truncate.py | 50 | ||||
-rw-r--r-- | test_polymatrix/test_expression/test_vstack.py | 58 |
16 files changed, 770 insertions, 0 deletions
diff --git a/test_polymatrix/test_expression/__init__.py b/test_polymatrix/test_expression/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test_polymatrix/test_expression/__init__.py diff --git a/test_polymatrix/test_expression/test_addition.py b/test_polymatrix/test_expression/test_addition.py new file mode 100644 index 0000000..623b13f --- /dev/null +++ b/test_polymatrix/test_expression/test_addition.py @@ -0,0 +1,64 @@ +import unittest + +from polymatrix.expression.init.initadditionexpr import init_addition_expr +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr + + +class TestAddition(unittest.TestCase): + + def test_1(self): + left_terms = { + (0, 0): { + tuple(): 1.0, + ((0, 1),): 1.0, + }, + (1, 0): { + ((0, 2),): 1.0, + }, + } + + right_terms = { + (0, 0): { + tuple(): 3.0, + ((1, 1),): 2.0, + }, + (1, 1): { + tuple(): 1.0, + }, + } + + left = init_from_terms_expr( + terms=left_terms, + shape=(2, 2), + ) + + right = init_from_terms_expr( + terms=right_terms, + shape=(2, 2), + ) + + expr = init_addition_expr( + left=left, + right=right, + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictContainsSubset({ + tuple(): 4.0, + ((0, 1),): 1.0, + ((1, 1),): 2.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictContainsSubset({ + ((0, 2),): 1.0, + }, data) + + data = val.get_poly(1, 1) + self.assertDictContainsSubset({ + tuple(): 1.0, + }, data) diff --git a/test_polymatrix/test_expression/test_blockdiag.py b/test_polymatrix/test_expression/test_blockdiag.py new file mode 100644 index 0000000..69141a4 --- /dev/null +++ b/test_polymatrix/test_expression/test_blockdiag.py @@ -0,0 +1,58 @@ +import unittest + +from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr + + +class TestBlockDiag(unittest.TestCase): + + def test_1(self): + terms1 = { + (0, 0): { + ((1, 1),): 1.0, + }, + (1, 0): { + tuple(): 2.0, + }, + } + + terms2 = { + (0, 0): { + tuple(): 3.0, + }, + (1, 1): { + tuple(): 4.0, + }, + } + + expr = init_block_diag_expr( + underlying=( + init_from_terms_expr(terms=terms1, shape=(2, 2),), + init_from_terms_expr(terms=terms2, shape=(2, 2),), + ), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + ((1, 1),): 1.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual({ + tuple(): 2.0, + }, data) + + data = val.get_poly(2, 2) + self.assertDictEqual({ + tuple(): 3.0, + }, data) + + data = val.get_poly(3, 3) + self.assertDictEqual({ + tuple(): 4.0, + }, data) + diff --git a/test_polymatrix/test_expression/test_derivative.py b/test_polymatrix/test_expression/test_derivative.py new file mode 100644 index 0000000..a4fc6f6 --- /dev/null +++ b/test_polymatrix/test_expression/test_derivative.py @@ -0,0 +1,52 @@ +import unittest +from polymatrix.expression.init.initderivativeexpr import init_derivative_expr +from polymatrix.expression.init.initdivergenceexpr import init_divergence_expr + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr + + +class TestDerivative(unittest.TestCase): + + def test_1(self): + underlying_terms = { + (0, 0): { + ((0, 1),): 2.0, + ((1, 2),): 3.0, + }, + (1, 0): { + tuple(): 5.0, + ((0, 1), (2, 3)): 4.0, + }, + } + + expr = init_derivative_expr( + underlying=init_from_terms_expr(terms=underlying_terms, shape=(2, 1)), + variables=(0, 1, 2), + ) + + state = init_expression_state(n_param=3) + state, val = expr.apply(state) + + self.assertTupleEqual(val.shape, (2, 3)) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 2.0, + }, data) + + data = val.get_poly(0, 1) + self.assertDictEqual({ + ((1, 1),): 6.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual({ + ((2, 3),): 4.0, + }, data) + + data = val.get_poly(1, 2) + self.assertDictEqual({ + ((0, 1), (2, 2)): 12.0, + }, data) diff --git a/test_polymatrix/test_expression/test_divergence.py b/test_polymatrix/test_expression/test_divergence.py new file mode 100644 index 0000000..412dcca --- /dev/null +++ b/test_polymatrix/test_expression/test_divergence.py @@ -0,0 +1,39 @@ +import unittest +from polymatrix.expression.init.initdivergenceexpr import init_divergence_expr + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr + + +class TestDivergence(unittest.TestCase): + + def test_1(self): + underlying_terms = { + (0, 0): { + ((0, 1),): 2.0, + ((1, 1),): 3.0, + }, + (1, 0): { + tuple(): 5.0, + ((0, 1),): 3.0, + }, + (2, 0): { + ((0, 1),): 2.0, + ((1, 1), (2, 3)): 3.0, + }, + } + + expr = init_divergence_expr( + underlying=init_from_terms_expr(terms=underlying_terms, shape=(3, 1)), + variables=(0, 1, 2), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 2.0, + ((1, 1), (2, 2)): 9.0, + }, data) diff --git a/test_polymatrix/test_expression/test_eval.py b/test_polymatrix/test_expression/test_eval.py new file mode 100644 index 0000000..662322a --- /dev/null +++ b/test_polymatrix/test_expression/test_eval.py @@ -0,0 +1,41 @@ +import unittest + +from polymatrix.expression.init.initevalexpr import init_eval_expr +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr + + +class TestEval(unittest.TestCase): + + def test_1(self): + terms = { + (0, 0): { + ((0, 1), (2, 1)): 2.0, + ((0, 1), (1, 1), (3, 1)): 3.0, + }, (1, 0):{ + tuple(): 1.0, + ((1, 2),): 1.0, + ((2, 1),): 1.0, + }, + } + + expr = init_eval_expr( + underlying=init_from_terms_expr(terms=terms, shape=(2, 1)), + variables=(0, 1), + values=(2.0, 3.0), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + ((2, 1),): 4.0, + ((3, 1),): 18.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual({ + tuple(): 10.0, + ((2, 1),): 1, + }, data) diff --git a/test_polymatrix/test_expression/test_linearin.py b/test_polymatrix/test_expression/test_linearin.py new file mode 100644 index 0000000..67f2c3d --- /dev/null +++ b/test_polymatrix/test_expression/test_linearin.py @@ -0,0 +1,50 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr + + +class TestLinearIn(unittest.TestCase): + + def test_1(self): + underlying_terms = { + (0, 0): { + ((0, 1),): 2.0, + ((1, 1),): 3.0, + }, + } + + monomial_terms = { + (0, 0): { + ((0, 1),): 1.0, + }, + (1, 0): { + ((2, 1),): 1.0, + }, + (2, 0): { + ((1, 1),): 1.0, + }, + (3, 0): { + ((3, 1),): 1.0, + }, + } + + expr = init_linear_in_expr( + underlying=init_from_terms_expr(terms=underlying_terms, shape=(2, 1)), + monomials=init_from_terms_expr(terms=monomial_terms, shape=(4, 1),), + variables=(0, 1), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 2.0, + }, data) + + data = val.get_poly(0, 2) + self.assertDictEqual({ + tuple(): 3.0, + }, data) diff --git a/test_polymatrix/test_expression/test_matrixmult.py b/test_polymatrix/test_expression/test_matrixmult.py new file mode 100644 index 0000000..6facb48 --- /dev/null +++ b/test_polymatrix/test_expression/test_matrixmult.py @@ -0,0 +1,64 @@ +import unittest +from polymatrix.expression.init.initadditionexpr import init_addition_expr +from polymatrix.expression.init.initexpressionstate import init_expression_state + +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initmatrixmultexpr import init_matrix_mult_expr + + +class TestMatrixMult(unittest.TestCase): + + def test_1(self): + left_terms = { + (0, 0): { + tuple(): 1.0, + ((0, 1),): 1.0, + }, + (0, 1): { + ((0, 1),): 1.0, + }, + (1, 1): { + ((0, 2),): 1.0, + }, + } + + right_terms = { + (0, 0): { + tuple(): 3.0, + ((1, 1),): 2.0, + }, + (1, 0): { + tuple(): 1.0, + }, + } + + left = init_from_terms_expr( + terms=left_terms, + shape=(2, 2), + ) + + right = init_from_terms_expr( + terms=right_terms, + shape=(2, 1), + ) + + expr = init_matrix_mult_expr( + left=left, + right=right, + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 3.0, + ((0, 1),): 4.0, + ((1, 1),): 2.0, + ((0, 1), (1, 1),): 2.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual({ + ((0, 2),): 1.0, + }, data) diff --git a/test_polymatrix/test_expression/test_quadraticin.py b/test_polymatrix/test_expression/test_quadraticin.py new file mode 100644 index 0000000..c340d7f --- /dev/null +++ b/test_polymatrix/test_expression/test_quadraticin.py @@ -0,0 +1,62 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr + + +class TestQuadraticIn(unittest.TestCase): + + def test_1(self): + underlying_terms = { + (0, 0): { + ((0, 1),): 1.0, # x1 + ((0, 1), (2, 1)): 2.0, # x1 + ((0, 2), (3, 1)): 3.0, # x1 x1 + ((0, 2), (1, 2), (4, 1)): 4.0, # x1 x1 x2 x2 + ((0, 2), (1, 1), (5, 1)): 5.0, # x1 x1 x2 + } + } + + monomial_terms = { + (0, 0): { + tuple(): 1.0, + }, + (1, 0): { + ((0, 1),): 1.0, + }, + (2, 0): { + ((0, 1), (1, 1)): 1.0, + }, + } + + expr = init_quadratic_in_expr( + underlying=init_from_terms_expr(terms=underlying_terms, shape=(1, 1)), + monomials=init_from_terms_expr(terms=monomial_terms, shape=(3, 1)), + variables=(0, 1), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 1) + self.assertDictContainsSubset({ + tuple(): 1.0, + ((2, 1),): 2.0, + }, data) + + data = val.get_poly(1, 1) + self.assertDictContainsSubset({ + ((3, 1),): 3.0, + }, data) + + data = val.get_poly(2, 2) + self.assertDictContainsSubset({ + ((4, 1),): 4.0, + }, data) + + data = val.get_poly(1, 2) + self.assertDictContainsSubset({ + ((5, 1),): 5.0, + }, data) +
\ No newline at end of file diff --git a/test_polymatrix/test_expression/test_substitude.py b/test_polymatrix/test_expression/test_substitude.py new file mode 100644 index 0000000..9a6e875 --- /dev/null +++ b/test_polymatrix/test_expression/test_substitude.py @@ -0,0 +1,43 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initsubstituteexpr import init_substitute_expr + + +class TestEval(unittest.TestCase): + + def test_1(self): + terms = { + (0, 0): { + tuple(): 2.0, + ((0, 2),): 3.0, + ((1, 1),): 1.0, + ((2, 2),): 1.0, + }, + } + + substitution = { + (0, 0): { + ((1, 1),): 1.0, + ((2, 1),): 1.0, + }, + } + + expr = init_substitute_expr( + underlying=init_from_terms_expr(terms=terms, shape=(1, 1)), + variables=(0,), + substitutions=(init_from_terms_expr(terms=substitution, shape=(1, 1)),), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 2.0, + ((1, 1),): 1.0, + ((1, 2),): 3.0, + ((1, 1), (2, 1)): 6.0, + ((2, 2),): 4.0 + }, data) diff --git a/test_polymatrix/test_expression/test_subtractmonomials.py b/test_polymatrix/test_expression/test_subtractmonomials.py new file mode 100644 index 0000000..f80f76a --- /dev/null +++ b/test_polymatrix/test_expression/test_subtractmonomials.py @@ -0,0 +1,55 @@ +import unittest +from polymatrix.expression.init.initderivativeexpr import init_derivative_expr +from polymatrix.expression.init.initdivergenceexpr import init_divergence_expr + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr +from polymatrix.expression.init.initsubtractmonomialsexpr import init_subtract_monomials_expr + + +class TestDerivative(unittest.TestCase): + + def test_1(self): + monomials1 = { + (0, 0): { + ((0, 1),): 1.0, + }, + (1, 0): { + ((0, 1), (1, 2)): 1.0, + }, + } + + monomials2 = { + (0, 0): { + ((0, 1),): 1.0, + }, + (1, 0): { + ((1, 1),): 1.0, + }, + } + + expr = init_subtract_monomials_expr( + underlying=init_from_terms_expr(terms=monomials1, shape=(2, 1)), + monomials=init_from_terms_expr(terms=monomials2, shape=(2, 1)), + ) + + state = init_expression_state(n_param=3) + state, val = expr.apply(state) + + self.assertTupleEqual(val.shape, (3, 1)) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 1.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual({ + ((1, 2),): 1.0, + }, data) + + data = val.get_poly(2, 0) + self.assertDictEqual({ + ((0, 1), (1, 1)): 1.0, + }, data) diff --git a/test_polymatrix/test_expression/test_sum.py b/test_polymatrix/test_expression/test_sum.py new file mode 100644 index 0000000..48ae073 --- /dev/null +++ b/test_polymatrix/test_expression/test_sum.py @@ -0,0 +1,35 @@ +import unittest + +from polymatrix.expression.init.initevalexpr import init_eval_expr +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initsumexpr import init_sum_expr + + +class TestSum(unittest.TestCase): + + def test_1(self): + terms = { + (0, 0): { + tuple(): 2.0, + ((0, 1),): 3.0, + }, + (0, 1):{ + tuple(): 1.0, + ((0, 2),): 1.0, + }, + } + + expr = init_sum_expr( + underlying=init_from_terms_expr(terms=terms, shape=(1, 2)), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 3.0, + ((0, 1),): 3.0, + ((0, 2),): 1.0, + }, data) diff --git a/test_polymatrix/test_expression/test_symmetric.py b/test_polymatrix/test_expression/test_symmetric.py new file mode 100644 index 0000000..ac5eba6 --- /dev/null +++ b/test_polymatrix/test_expression/test_symmetric.py @@ -0,0 +1,53 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr +from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr + + +class TestQuadraticIn(unittest.TestCase): + + def test_1(self): + terms = { + (0, 0): { + ((0, 1),): 1.0, + }, + (1, 0): { + ((1, 1),): 1.0, + }, + (0, 1): { + ((1, 1),): 1.0, + ((2, 1),): 1.0, + }, + } + + underlying = init_from_terms_expr( + terms=terms, + shape=(2, 2), + ) + + expr = init_symmetric_expr( + underlying=underlying, + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictContainsSubset({ + ((0, 1),): 1.0, + }, data) + + data = val.get_poly(0, 1) + self.assertDictContainsSubset({ + ((1, 1),): 1.0, + ((2, 1),): 0.5, + }, data) + + data = val.get_poly(1, 0) + self.assertDictContainsSubset({ + ((1, 1),): 1.0, + ((2, 1),): 0.5, + }, data) +
\ No newline at end of file diff --git a/test_polymatrix/test_expression/test_toconstant.py b/test_polymatrix/test_expression/test_toconstant.py new file mode 100644 index 0000000..784aeec --- /dev/null +++ b/test_polymatrix/test_expression/test_toconstant.py @@ -0,0 +1,46 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr +from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr +from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr +from polymatrix.expression.init.inittruncateexpr import init_truncate_expr + + +class TestToConstant(unittest.TestCase): + + def test_1(self): + terms = { + (0, 0): { + tuple(): 2.0, + ((0, 1),): 1.0, + }, + (1, 0): { + ((0, 2), (1, 1)): 1.0, + ((0, 3), (1, 1)): 1.0, + }, + (0, 1): { + tuple(): 5.0, + ((0, 2), (2, 1),): 1.0, + ((3, 1),): 1.0, + }, + } + + expr = init_to_constant_expr( + underlying=init_from_terms_expr(terms=terms, shape=(2, 2)), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + tuple(): 2.0, + }, data) + + data = val.get_poly(0, 1) + self.assertDictEqual({ + tuple(): 5.0, + }, data) +
\ No newline at end of file diff --git a/test_polymatrix/test_expression/test_truncate.py b/test_polymatrix/test_expression/test_truncate.py new file mode 100644 index 0000000..e944229 --- /dev/null +++ b/test_polymatrix/test_expression/test_truncate.py @@ -0,0 +1,50 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr +from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr +from polymatrix.expression.init.inittruncateexpr import init_truncate_expr + + +class TestTruncate(unittest.TestCase): + + def test_1(self): + terms = { + (0, 0): { + ((0, 1),): 1.0, # x1 x1 + }, + (1, 0): { + ((0, 2), (1, 1)): 1.0, # x1 x1 x2 + ((0, 3), (1, 1)): 1.0, # x1 x1 x1 x2 + }, + (0, 1): { + ((0, 2), (2, 1),): 1.0, # x1 x1 + ((3, 1),): 1.0, + }, + } + + expr = init_truncate_expr( + underlying=init_from_terms_expr(terms=terms, shape=(2, 2)), + variables=(0, 1), + degrees=(1, 2), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + ((0, 1),): 1.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual( + {}, data + ) + + data = val.get_poly(0, 1) + self.assertDictEqual({ + ((0, 2), (2, 1)): 1.0, # x1 x1 + }, data) +
\ No newline at end of file diff --git a/test_polymatrix/test_expression/test_vstack.py b/test_polymatrix/test_expression/test_vstack.py new file mode 100644 index 0000000..a50267f --- /dev/null +++ b/test_polymatrix/test_expression/test_vstack.py @@ -0,0 +1,58 @@ +import unittest + +from polymatrix.expression.init.initexpressionstate import init_expression_state +from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr +from polymatrix.expression.init.initvstackexpr import init_v_stack_expr + + +class TestVStack(unittest.TestCase): + + def test_1(self): + terms1 = { + (0, 0): { + ((1, 1),): 1.0, + }, + (1, 0): { + tuple(): 2.0, + }, + } + + terms2 = { + (0, 0): { + tuple(): 3.0, + }, + (1, 1): { + tuple(): 4.0, + }, + } + + expr = init_v_stack_expr( + underlying=( + init_from_terms_expr(terms=terms1, shape=(2, 2),), + init_from_terms_expr(terms=terms2, shape=(2, 2),), + ), + ) + + state = init_expression_state(n_param=2) + state, val = expr.apply(state) + + data = val.get_poly(0, 0) + self.assertDictEqual({ + ((1, 1),): 1.0, + }, data) + + data = val.get_poly(1, 0) + self.assertDictEqual({ + tuple(): 2.0, + }, data) + + data = val.get_poly(2, 0) + self.assertDictEqual({ + tuple(): 3.0, + }, data) + + data = val.get_poly(3, 1) + self.assertDictEqual({ + tuple(): 4.0, + }, data) + |