summaryrefslogtreecommitdiffstats
path: root/test_polymatrix
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--test_polymatrix/test_polymatrix.py67
1 files changed, 64 insertions, 3 deletions
diff --git a/test_polymatrix/test_polymatrix.py b/test_polymatrix/test_polymatrix.py
index 1f0bb73..3b239bf 100644
--- a/test_polymatrix/test_polymatrix.py
+++ b/test_polymatrix/test_polymatrix.py
@@ -135,12 +135,12 @@ class TestPolyMatrix(unittest.TestCase):
value = 2,
)
- def test_param_matrix_const_vector_skew_symmetric(self):
+ def test_skew_symmetric_param_matrix_const_vector(self):
"""
param = [a11 a21 a31 a41]
"""
- def skew_symmetric(idx1, idx2):
+ def skew_symmetric(degree, idx1, idx2):
if idx1 == idx2:
return idx1, idx2, 0
elif idx2 < idx1:
@@ -173,4 +173,65 @@ class TestPolyMatrix(unittest.TestCase):
eq_idx = 1,
row_idx = offset_dict[(mat, 0)] + 1,
value = -1,
- ) \ No newline at end of file
+ )
+
+ def test_const_matrix_param_gradient_vector(self):
+ """
+ param = [v11 v12 v21 v22]
+ """
+
+ def gradient(degree, v_row, monom):
+ if degree == 1:
+ factor = sum(v_row==e for e in monom) + 1
+
+ if monom[-1] < v_row:
+ n_v_row = monom[-1]
+ n_monom = sorted(monom + (v_row,), reverse=True)
+
+ if v_row <= monom[-1]:
+ n_v_row = v_row
+ n_monom = monom
+
+ return n_v_row, n_monom, factor
+
+ mat = init_poly_matrix(
+ subs={0: {(0, 0): 1, (1, 0): 1, (0, 1): 1, (1, 1): 1}},
+ )
+ vec = init_poly_vector(
+ degrees=(1,),
+ re_index_func_2=gradient,
+ )
+
+ eq = init_equation(
+ terms = [(mat, vec)],
+ n_var = 2,
+ )
+
+ terms, offset_dict = list(eq.create())
+
+ self.assert_term_in_eq(
+ terms = terms,
+ degree = 1,
+ monoms=(0,),
+ eq_idx = 0,
+ row_idx = offset_dict[(vec, 1)],
+ value = 2,
+ )
+
+ self.assert_term_in_eq(
+ terms = terms,
+ degree = 1,
+ monoms=(1,),
+ eq_idx = 0,
+ row_idx = offset_dict[(vec, 1)]+1,
+ value = 1,
+ )
+
+ self.assert_term_in_eq(
+ terms = terms,
+ degree = 1,
+ monoms=(0,),
+ eq_idx = 1,
+ row_idx = offset_dict[(vec, 1)]+1,
+ value = 1,
+ )