summaryrefslogtreecommitdiffstats
path: root/test_polymatrix/test_tomatrixrepr.py
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-04 13:08:50 +0200
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2022-08-04 13:08:50 +0200
commitfbba836cf1eadaf0f477e04e8e3f2b1cc55eeea5 (patch)
treed712bd413b7a4c545b117180623826c8711bea31 /test_polymatrix/test_tomatrixrepr.py
parentadd polynomial operations for sos optimization (diff)
downloadpolymatrix-fbba836cf1eadaf0f477e04e8e3f2b1cc55eeea5.tar.gz
polymatrix-fbba836cf1eadaf0f477e04e8e3f2b1cc55eeea5.zip
add max_degree, max and filter operator
Diffstat (limited to 'test_polymatrix/test_tomatrixrepr.py')
-rw-r--r--test_polymatrix/test_tomatrixrepr.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/test_polymatrix/test_tomatrixrepr.py b/test_polymatrix/test_tomatrixrepr.py
new file mode 100644
index 0000000..7486c8e
--- /dev/null
+++ b/test_polymatrix/test_tomatrixrepr.py
@@ -0,0 +1,48 @@
+import unittest
+import polymatrix
+
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
+from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
+
+
+class TestLinearIn(unittest.TestCase):
+
+ def test_1(self):
+ underlying_terms = {
+ (0, 0): {
+ tuple(): 1.0,
+ ((1, 1),): 2.0,
+ },
+ (1, 0): {
+ ((0, 1),): 4.0,
+ ((0, 1), (1, 1)): 3.0,
+ ((1, 2),): 5.0,
+ },
+ (2, 0): {
+ ((0, 1), (1, 2)): 3.0,
+ },
+ }
+
+ expr = init_from_terms_expr(terms=underlying_terms, shape=(3, 1))
+
+ state = init_expression_state(n_param=2)
+ state, result = polymatrix.to_matrix_equations((expr,), (0, 1)).apply(state)
+
+ A0 = result.matrix_equations[0][0]
+ A1 = result.matrix_equations[0][1]
+ A2 = result.matrix_equations[0][2]
+ A3 = result.matrix_equations[0][3]
+
+ self.assertEquals(A0[0, 0], 1.0)
+
+ self.assertEquals(A1[0, 1], 2.0)
+ self.assertEquals(A1[1, 0], 4.0)
+
+ self.assertEquals(A2[1, 1], 1.5)
+ self.assertEquals(A2[1, 2], 1.5)
+ self.assertEquals(A2[1, 3], 5.0)
+
+ self.assertEquals(A3[2, 3], 1.0)
+ self.assertEquals(A3[2, 5], 1.0)
+ self.assertEquals(A3[2, 6], 1.0)