summaryrefslogtreecommitdiffstats
path: root/test_polymatrix/test_expression
diff options
context:
space:
mode:
Diffstat (limited to 'test_polymatrix/test_expression')
-rw-r--r--test_polymatrix/test_expression/__init__.py0
-rw-r--r--test_polymatrix/test_expression/test_addition.py64
-rw-r--r--test_polymatrix/test_expression/test_blockdiag.py58
-rw-r--r--test_polymatrix/test_expression/test_derivative.py52
-rw-r--r--test_polymatrix/test_expression/test_divergence.py39
-rw-r--r--test_polymatrix/test_expression/test_eval.py41
-rw-r--r--test_polymatrix/test_expression/test_linearin.py50
-rw-r--r--test_polymatrix/test_expression/test_matrixmult.py64
-rw-r--r--test_polymatrix/test_expression/test_quadraticin.py62
-rw-r--r--test_polymatrix/test_expression/test_substitude.py43
-rw-r--r--test_polymatrix/test_expression/test_subtractmonomials.py55
-rw-r--r--test_polymatrix/test_expression/test_sum.py35
-rw-r--r--test_polymatrix/test_expression/test_symmetric.py53
-rw-r--r--test_polymatrix/test_expression/test_toconstant.py46
-rw-r--r--test_polymatrix/test_expression/test_truncate.py50
-rw-r--r--test_polymatrix/test_expression/test_vstack.py58
16 files changed, 770 insertions, 0 deletions
diff --git a/test_polymatrix/test_expression/__init__.py b/test_polymatrix/test_expression/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test_polymatrix/test_expression/__init__.py
diff --git a/test_polymatrix/test_expression/test_addition.py b/test_polymatrix/test_expression/test_addition.py
new file mode 100644
index 0000000..623b13f
--- /dev/null
+++ b/test_polymatrix/test_expression/test_addition.py
@@ -0,0 +1,64 @@
+import unittest
+
+from polymatrix.expression.init.initadditionexpr import init_addition_expr
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+
+
+class TestAddition(unittest.TestCase):
+
+ def test_1(self):
+ left_terms = {
+ (0, 0): {
+ tuple(): 1.0,
+ ((0, 1),): 1.0,
+ },
+ (1, 0): {
+ ((0, 2),): 1.0,
+ },
+ }
+
+ right_terms = {
+ (0, 0): {
+ tuple(): 3.0,
+ ((1, 1),): 2.0,
+ },
+ (1, 1): {
+ tuple(): 1.0,
+ },
+ }
+
+ left = init_from_terms_expr(
+ terms=left_terms,
+ shape=(2, 2),
+ )
+
+ right = init_from_terms_expr(
+ terms=right_terms,
+ shape=(2, 2),
+ )
+
+ expr = init_addition_expr(
+ left=left,
+ right=right,
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictContainsSubset({
+ tuple(): 4.0,
+ ((0, 1),): 1.0,
+ ((1, 1),): 2.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictContainsSubset({
+ ((0, 2),): 1.0,
+ }, data)
+
+ data = val.get_poly(1, 1)
+ self.assertDictContainsSubset({
+ tuple(): 1.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_blockdiag.py b/test_polymatrix/test_expression/test_blockdiag.py
new file mode 100644
index 0000000..69141a4
--- /dev/null
+++ b/test_polymatrix/test_expression/test_blockdiag.py
@@ -0,0 +1,58 @@
+import unittest
+
+from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+
+
+class TestBlockDiag(unittest.TestCase):
+
+ def test_1(self):
+ terms1 = {
+ (0, 0): {
+ ((1, 1),): 1.0,
+ },
+ (1, 0): {
+ tuple(): 2.0,
+ },
+ }
+
+ terms2 = {
+ (0, 0): {
+ tuple(): 3.0,
+ },
+ (1, 1): {
+ tuple(): 4.0,
+ },
+ }
+
+ expr = init_block_diag_expr(
+ underlying=(
+ init_from_terms_expr(terms=terms1, shape=(2, 2),),
+ init_from_terms_expr(terms=terms2, shape=(2, 2),),
+ ),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ ((1, 1),): 1.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ }, data)
+
+ data = val.get_poly(2, 2)
+ self.assertDictEqual({
+ tuple(): 3.0,
+ }, data)
+
+ data = val.get_poly(3, 3)
+ self.assertDictEqual({
+ tuple(): 4.0,
+ }, data)
+
diff --git a/test_polymatrix/test_expression/test_derivative.py b/test_polymatrix/test_expression/test_derivative.py
new file mode 100644
index 0000000..a4fc6f6
--- /dev/null
+++ b/test_polymatrix/test_expression/test_derivative.py
@@ -0,0 +1,52 @@
+import unittest
+from polymatrix.expression.init.initderivativeexpr import init_derivative_expr
+from polymatrix.expression.init.initdivergenceexpr import init_divergence_expr
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
+
+
+class TestDerivative(unittest.TestCase):
+
+ def test_1(self):
+ underlying_terms = {
+ (0, 0): {
+ ((0, 1),): 2.0,
+ ((1, 2),): 3.0,
+ },
+ (1, 0): {
+ tuple(): 5.0,
+ ((0, 1), (2, 3)): 4.0,
+ },
+ }
+
+ expr = init_derivative_expr(
+ underlying=init_from_terms_expr(terms=underlying_terms, shape=(2, 1)),
+ variables=(0, 1, 2),
+ )
+
+ state = init_expression_state(n_param=3)
+ state, val = expr.apply(state)
+
+ self.assertTupleEqual(val.shape, (2, 3))
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ }, data)
+
+ data = val.get_poly(0, 1)
+ self.assertDictEqual({
+ ((1, 1),): 6.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual({
+ ((2, 3),): 4.0,
+ }, data)
+
+ data = val.get_poly(1, 2)
+ self.assertDictEqual({
+ ((0, 1), (2, 2)): 12.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_divergence.py b/test_polymatrix/test_expression/test_divergence.py
new file mode 100644
index 0000000..412dcca
--- /dev/null
+++ b/test_polymatrix/test_expression/test_divergence.py
@@ -0,0 +1,39 @@
+import unittest
+from polymatrix.expression.init.initdivergenceexpr import init_divergence_expr
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
+
+
+class TestDivergence(unittest.TestCase):
+
+ def test_1(self):
+ underlying_terms = {
+ (0, 0): {
+ ((0, 1),): 2.0,
+ ((1, 1),): 3.0,
+ },
+ (1, 0): {
+ tuple(): 5.0,
+ ((0, 1),): 3.0,
+ },
+ (2, 0): {
+ ((0, 1),): 2.0,
+ ((1, 1), (2, 3)): 3.0,
+ },
+ }
+
+ expr = init_divergence_expr(
+ underlying=init_from_terms_expr(terms=underlying_terms, shape=(3, 1)),
+ variables=(0, 1, 2),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ ((1, 1), (2, 2)): 9.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_eval.py b/test_polymatrix/test_expression/test_eval.py
new file mode 100644
index 0000000..662322a
--- /dev/null
+++ b/test_polymatrix/test_expression/test_eval.py
@@ -0,0 +1,41 @@
+import unittest
+
+from polymatrix.expression.init.initevalexpr import init_eval_expr
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+
+
+class TestEval(unittest.TestCase):
+
+ def test_1(self):
+ terms = {
+ (0, 0): {
+ ((0, 1), (2, 1)): 2.0,
+ ((0, 1), (1, 1), (3, 1)): 3.0,
+ }, (1, 0):{
+ tuple(): 1.0,
+ ((1, 2),): 1.0,
+ ((2, 1),): 1.0,
+ },
+ }
+
+ expr = init_eval_expr(
+ underlying=init_from_terms_expr(terms=terms, shape=(2, 1)),
+ variables=(0, 1),
+ values=(2.0, 3.0),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ ((2, 1),): 4.0,
+ ((3, 1),): 18.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual({
+ tuple(): 10.0,
+ ((2, 1),): 1,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_linearin.py b/test_polymatrix/test_expression/test_linearin.py
new file mode 100644
index 0000000..67f2c3d
--- /dev/null
+++ b/test_polymatrix/test_expression/test_linearin.py
@@ -0,0 +1,50 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
+
+
+class TestLinearIn(unittest.TestCase):
+
+ def test_1(self):
+ underlying_terms = {
+ (0, 0): {
+ ((0, 1),): 2.0,
+ ((1, 1),): 3.0,
+ },
+ }
+
+ monomial_terms = {
+ (0, 0): {
+ ((0, 1),): 1.0,
+ },
+ (1, 0): {
+ ((2, 1),): 1.0,
+ },
+ (2, 0): {
+ ((1, 1),): 1.0,
+ },
+ (3, 0): {
+ ((3, 1),): 1.0,
+ },
+ }
+
+ expr = init_linear_in_expr(
+ underlying=init_from_terms_expr(terms=underlying_terms, shape=(2, 1)),
+ monomials=init_from_terms_expr(terms=monomial_terms, shape=(4, 1),),
+ variables=(0, 1),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ }, data)
+
+ data = val.get_poly(0, 2)
+ self.assertDictEqual({
+ tuple(): 3.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_matrixmult.py b/test_polymatrix/test_expression/test_matrixmult.py
new file mode 100644
index 0000000..6facb48
--- /dev/null
+++ b/test_polymatrix/test_expression/test_matrixmult.py
@@ -0,0 +1,64 @@
+import unittest
+from polymatrix.expression.init.initadditionexpr import init_addition_expr
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initmatrixmultexpr import init_matrix_mult_expr
+
+
+class TestMatrixMult(unittest.TestCase):
+
+ def test_1(self):
+ left_terms = {
+ (0, 0): {
+ tuple(): 1.0,
+ ((0, 1),): 1.0,
+ },
+ (0, 1): {
+ ((0, 1),): 1.0,
+ },
+ (1, 1): {
+ ((0, 2),): 1.0,
+ },
+ }
+
+ right_terms = {
+ (0, 0): {
+ tuple(): 3.0,
+ ((1, 1),): 2.0,
+ },
+ (1, 0): {
+ tuple(): 1.0,
+ },
+ }
+
+ left = init_from_terms_expr(
+ terms=left_terms,
+ shape=(2, 2),
+ )
+
+ right = init_from_terms_expr(
+ terms=right_terms,
+ shape=(2, 1),
+ )
+
+ expr = init_matrix_mult_expr(
+ left=left,
+ right=right,
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 3.0,
+ ((0, 1),): 4.0,
+ ((1, 1),): 2.0,
+ ((0, 1), (1, 1),): 2.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual({
+ ((0, 2),): 1.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_quadraticin.py b/test_polymatrix/test_expression/test_quadraticin.py
new file mode 100644
index 0000000..c340d7f
--- /dev/null
+++ b/test_polymatrix/test_expression/test_quadraticin.py
@@ -0,0 +1,62 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
+
+
+class TestQuadraticIn(unittest.TestCase):
+
+ def test_1(self):
+ underlying_terms = {
+ (0, 0): {
+ ((0, 1),): 1.0, # x1
+ ((0, 1), (2, 1)): 2.0, # x1
+ ((0, 2), (3, 1)): 3.0, # x1 x1
+ ((0, 2), (1, 2), (4, 1)): 4.0, # x1 x1 x2 x2
+ ((0, 2), (1, 1), (5, 1)): 5.0, # x1 x1 x2
+ }
+ }
+
+ monomial_terms = {
+ (0, 0): {
+ tuple(): 1.0,
+ },
+ (1, 0): {
+ ((0, 1),): 1.0,
+ },
+ (2, 0): {
+ ((0, 1), (1, 1)): 1.0,
+ },
+ }
+
+ expr = init_quadratic_in_expr(
+ underlying=init_from_terms_expr(terms=underlying_terms, shape=(1, 1)),
+ monomials=init_from_terms_expr(terms=monomial_terms, shape=(3, 1)),
+ variables=(0, 1),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 1)
+ self.assertDictContainsSubset({
+ tuple(): 1.0,
+ ((2, 1),): 2.0,
+ }, data)
+
+ data = val.get_poly(1, 1)
+ self.assertDictContainsSubset({
+ ((3, 1),): 3.0,
+ }, data)
+
+ data = val.get_poly(2, 2)
+ self.assertDictContainsSubset({
+ ((4, 1),): 4.0,
+ }, data)
+
+ data = val.get_poly(1, 2)
+ self.assertDictContainsSubset({
+ ((5, 1),): 5.0,
+ }, data)
+ \ No newline at end of file
diff --git a/test_polymatrix/test_expression/test_substitude.py b/test_polymatrix/test_expression/test_substitude.py
new file mode 100644
index 0000000..9a6e875
--- /dev/null
+++ b/test_polymatrix/test_expression/test_substitude.py
@@ -0,0 +1,43 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initsubstituteexpr import init_substitute_expr
+
+
+class TestEval(unittest.TestCase):
+
+ def test_1(self):
+ terms = {
+ (0, 0): {
+ tuple(): 2.0,
+ ((0, 2),): 3.0,
+ ((1, 1),): 1.0,
+ ((2, 2),): 1.0,
+ },
+ }
+
+ substitution = {
+ (0, 0): {
+ ((1, 1),): 1.0,
+ ((2, 1),): 1.0,
+ },
+ }
+
+ expr = init_substitute_expr(
+ underlying=init_from_terms_expr(terms=terms, shape=(1, 1)),
+ variables=(0,),
+ substitutions=(init_from_terms_expr(terms=substitution, shape=(1, 1)),),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ ((1, 1),): 1.0,
+ ((1, 2),): 3.0,
+ ((1, 1), (2, 1)): 6.0,
+ ((2, 2),): 4.0
+ }, data)
diff --git a/test_polymatrix/test_expression/test_subtractmonomials.py b/test_polymatrix/test_expression/test_subtractmonomials.py
new file mode 100644
index 0000000..f80f76a
--- /dev/null
+++ b/test_polymatrix/test_expression/test_subtractmonomials.py
@@ -0,0 +1,55 @@
+import unittest
+from polymatrix.expression.init.initderivativeexpr import init_derivative_expr
+from polymatrix.expression.init.initdivergenceexpr import init_divergence_expr
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
+from polymatrix.expression.init.initsubtractmonomialsexpr import init_subtract_monomials_expr
+
+
+class TestDerivative(unittest.TestCase):
+
+ def test_1(self):
+ monomials1 = {
+ (0, 0): {
+ ((0, 1),): 1.0,
+ },
+ (1, 0): {
+ ((0, 1), (1, 2)): 1.0,
+ },
+ }
+
+ monomials2 = {
+ (0, 0): {
+ ((0, 1),): 1.0,
+ },
+ (1, 0): {
+ ((1, 1),): 1.0,
+ },
+ }
+
+ expr = init_subtract_monomials_expr(
+ underlying=init_from_terms_expr(terms=monomials1, shape=(2, 1)),
+ monomials=init_from_terms_expr(terms=monomials2, shape=(2, 1)),
+ )
+
+ state = init_expression_state(n_param=3)
+ state, val = expr.apply(state)
+
+ self.assertTupleEqual(val.shape, (3, 1))
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 1.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual({
+ ((1, 2),): 1.0,
+ }, data)
+
+ data = val.get_poly(2, 0)
+ self.assertDictEqual({
+ ((0, 1), (1, 1)): 1.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_sum.py b/test_polymatrix/test_expression/test_sum.py
new file mode 100644
index 0000000..48ae073
--- /dev/null
+++ b/test_polymatrix/test_expression/test_sum.py
@@ -0,0 +1,35 @@
+import unittest
+
+from polymatrix.expression.init.initevalexpr import init_eval_expr
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initsumexpr import init_sum_expr
+
+
+class TestSum(unittest.TestCase):
+
+ def test_1(self):
+ terms = {
+ (0, 0): {
+ tuple(): 2.0,
+ ((0, 1),): 3.0,
+ },
+ (0, 1):{
+ tuple(): 1.0,
+ ((0, 2),): 1.0,
+ },
+ }
+
+ expr = init_sum_expr(
+ underlying=init_from_terms_expr(terms=terms, shape=(1, 2)),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 3.0,
+ ((0, 1),): 3.0,
+ ((0, 2),): 1.0,
+ }, data)
diff --git a/test_polymatrix/test_expression/test_symmetric.py b/test_polymatrix/test_expression/test_symmetric.py
new file mode 100644
index 0000000..ac5eba6
--- /dev/null
+++ b/test_polymatrix/test_expression/test_symmetric.py
@@ -0,0 +1,53 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
+from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr
+
+
+class TestQuadraticIn(unittest.TestCase):
+
+ def test_1(self):
+ terms = {
+ (0, 0): {
+ ((0, 1),): 1.0,
+ },
+ (1, 0): {
+ ((1, 1),): 1.0,
+ },
+ (0, 1): {
+ ((1, 1),): 1.0,
+ ((2, 1),): 1.0,
+ },
+ }
+
+ underlying = init_from_terms_expr(
+ terms=terms,
+ shape=(2, 2),
+ )
+
+ expr = init_symmetric_expr(
+ underlying=underlying,
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictContainsSubset({
+ ((0, 1),): 1.0,
+ }, data)
+
+ data = val.get_poly(0, 1)
+ self.assertDictContainsSubset({
+ ((1, 1),): 1.0,
+ ((2, 1),): 0.5,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictContainsSubset({
+ ((1, 1),): 1.0,
+ ((2, 1),): 0.5,
+ }, data)
+ \ No newline at end of file
diff --git a/test_polymatrix/test_expression/test_toconstant.py b/test_polymatrix/test_expression/test_toconstant.py
new file mode 100644
index 0000000..784aeec
--- /dev/null
+++ b/test_polymatrix/test_expression/test_toconstant.py
@@ -0,0 +1,46 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
+from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr
+from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr
+from polymatrix.expression.init.inittruncateexpr import init_truncate_expr
+
+
+class TestToConstant(unittest.TestCase):
+
+ def test_1(self):
+ terms = {
+ (0, 0): {
+ tuple(): 2.0,
+ ((0, 1),): 1.0,
+ },
+ (1, 0): {
+ ((0, 2), (1, 1)): 1.0,
+ ((0, 3), (1, 1)): 1.0,
+ },
+ (0, 1): {
+ tuple(): 5.0,
+ ((0, 2), (2, 1),): 1.0,
+ ((3, 1),): 1.0,
+ },
+ }
+
+ expr = init_to_constant_expr(
+ underlying=init_from_terms_expr(terms=terms, shape=(2, 2)),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ }, data)
+
+ data = val.get_poly(0, 1)
+ self.assertDictEqual({
+ tuple(): 5.0,
+ }, data)
+ \ No newline at end of file
diff --git a/test_polymatrix/test_expression/test_truncate.py b/test_polymatrix/test_expression/test_truncate.py
new file mode 100644
index 0000000..e944229
--- /dev/null
+++ b/test_polymatrix/test_expression/test_truncate.py
@@ -0,0 +1,50 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
+from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr
+from polymatrix.expression.init.inittruncateexpr import init_truncate_expr
+
+
+class TestTruncate(unittest.TestCase):
+
+ def test_1(self):
+ terms = {
+ (0, 0): {
+ ((0, 1),): 1.0, # x1 x1
+ },
+ (1, 0): {
+ ((0, 2), (1, 1)): 1.0, # x1 x1 x2
+ ((0, 3), (1, 1)): 1.0, # x1 x1 x1 x2
+ },
+ (0, 1): {
+ ((0, 2), (2, 1),): 1.0, # x1 x1
+ ((3, 1),): 1.0,
+ },
+ }
+
+ expr = init_truncate_expr(
+ underlying=init_from_terms_expr(terms=terms, shape=(2, 2)),
+ variables=(0, 1),
+ degrees=(1, 2),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ ((0, 1),): 1.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual(
+ {}, data
+ )
+
+ data = val.get_poly(0, 1)
+ self.assertDictEqual({
+ ((0, 2), (2, 1)): 1.0, # x1 x1
+ }, data)
+ \ No newline at end of file
diff --git a/test_polymatrix/test_expression/test_vstack.py b/test_polymatrix/test_expression/test_vstack.py
new file mode 100644
index 0000000..a50267f
--- /dev/null
+++ b/test_polymatrix/test_expression/test_vstack.py
@@ -0,0 +1,58 @@
+import unittest
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
+
+
+class TestVStack(unittest.TestCase):
+
+ def test_1(self):
+ terms1 = {
+ (0, 0): {
+ ((1, 1),): 1.0,
+ },
+ (1, 0): {
+ tuple(): 2.0,
+ },
+ }
+
+ terms2 = {
+ (0, 0): {
+ tuple(): 3.0,
+ },
+ (1, 1): {
+ tuple(): 4.0,
+ },
+ }
+
+ expr = init_v_stack_expr(
+ underlying=(
+ init_from_terms_expr(terms=terms1, shape=(2, 2),),
+ init_from_terms_expr(terms=terms2, shape=(2, 2),),
+ ),
+ )
+
+ state = init_expression_state(n_param=2)
+ state, val = expr.apply(state)
+
+ data = val.get_poly(0, 0)
+ self.assertDictEqual({
+ ((1, 1),): 1.0,
+ }, data)
+
+ data = val.get_poly(1, 0)
+ self.assertDictEqual({
+ tuple(): 2.0,
+ }, data)
+
+ data = val.get_poly(2, 0)
+ self.assertDictEqual({
+ tuple(): 3.0,
+ }, data)
+
+ data = val.get_poly(3, 1)
+ self.assertDictEqual({
+ tuple(): 4.0,
+ }, data)
+