summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/init.py2
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py84
-rw-r--r--polymatrix/expression/mixins/cacheexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/combinationsexprmixin.py41
-rw-r--r--polymatrix/expression/mixins/degreeexprmixin.py14
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py38
-rw-r--r--polymatrix/expression/mixins/divergenceexprmixin.py3
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py8
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py92
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py26
-rw-r--r--polymatrix/expression/mixins/filterexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/filterlinearpartexprmixin.py19
-rw-r--r--polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/fromtermsexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/fromtupleexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/legendreseriesmixin.py14
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearmatrixinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearmonomialsexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py20
-rw-r--r--polymatrix/expression/mixins/maxexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/parametrizeexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/parametrizematrixexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/productexprmixin.py62
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/quadraticmonomialsexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/squeezeexprmixin.py12
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py26
-rw-r--r--polymatrix/expression/mixins/subtractmonomialsexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/sumexprmixin.py12
-rw-r--r--polymatrix/expression/mixins/toconstantexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/toquadraticexprmixin.py46
-rw-r--r--polymatrix/expression/mixins/tosortedvariablesmixin.py2
-rw-r--r--polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py8
-rw-r--r--polymatrix/expression/mixins/truncateexprmixin.py10
-rw-r--r--polymatrix/expression/to.py4
-rw-r--r--polymatrix/expression/utils/broadcastpolymatrix.py34
-rw-r--r--polymatrix/polymatrix/impl.py4
-rw-r--r--polymatrix/polymatrix/init.py17
-rw-r--r--polymatrix/polymatrix/mixins.py12
42 files changed, 255 insertions, 433 deletions
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 06fe23c..f8092d2 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -226,7 +226,7 @@ def init_from_terms_expr(
if isinstance(terms, PolyMatrixMixin):
shape = terms.shape
- gen_terms = terms.gen_terms()
+ gen_terms = terms.gen_data()
else:
assert shape is not None
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 7dbf1d2..e3580f0 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,9 +1,8 @@
import abc
import math
-from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl
+from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix
from polymatrix.utils.getstacklines import FrameSummary
-from polymatrix.utils.tooperatorexception import to_operator_exception
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
@@ -26,31 +25,6 @@ class AdditionExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- @staticmethod
- def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]):
- # broadcast left
- if left.shape == (1, 1) and right.shape != (1, 1):
- left = BroadcastPolyMatrixImpl(
- polynomial=left.get_poly(0, 0),
- shape=right.shape,
- )
-
- # broadcast right
- elif left.shape != (1, 1) and right.shape == (1, 1):
- right = BroadcastPolyMatrixImpl(
- polynomial=right.get_poly(0, 0),
- shape=left.shape,
- )
-
- else:
- if not (left.shape == right.shape):
- raise AssertionError(to_operator_exception(
- message=f'{left.shape} != {right.shape}',
- stack=stack,
- ))
-
- return left, right
-
# overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
@@ -59,44 +33,14 @@ class AdditionExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- # if left.shape == (1, 1):
- # left, right = right, left
-
- # if left.shape != (1, 1) and right.shape == (1, 1):
-
- # # @dataclassabc.dataclassabc(frozen=True)
- # # class BroadCastedPolyMatrix(PolyMatrixMixin):
- # # underlying_monomials: tuple[tuple[int], float]
- # # shape: tuple[int, int]
-
- # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
- # # return self.underlying_monomials
-
-
- # right = BroadcastPolyMatrixImpl(
- # polynomial=right.get_poly(0, 0),
- # shape=left.shape,
- # )
-
- # # all_underlying = (left, broadcasted_right)
-
- # else:
- # if not (left.shape == right.shape):
- # raise AssertionError(to_operator_exception(
- # message=f'{left.shape} != {right.shape}',
- # stack=self.stack,
- # ))
-
- # # all_underlying = (left, right)
-
- left, right = self.broadcast(left, right, self.stack)
+ left, right = broadcast_poly_matrix(left, right, self.stack)
- terms = {}
+ poly_matrix_data = {}
for row in range(left.shape[0]):
for col in range(left.shape[1]):
- terms_row_col = {}
+ poly_data = {}
for underlying in (left, right):
@@ -104,26 +48,26 @@ class AdditionExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- if len(terms_row_col) == 0:
- terms_row_col = dict(polynomial)
+ if len(poly_data) == 0:
+ poly_data = dict(polynomial)
else:
for monomial, value in polynomial.items():
- if monomial not in terms_row_col:
- terms_row_col[monomial] = value
+ if monomial not in poly_data:
+ poly_data[monomial] = value
else:
- terms_row_col[monomial] += value
+ poly_data[monomial] += value
- if math.isclose(terms_row_col[monomial], 0):
- del terms_row_col[monomial]
+ if math.isclose(poly_data[monomial], 0):
+ del poly_data[monomial]
- if 0 < len(terms_row_col):
- terms[row, col] = terms_row_col
+ if 0 < len(poly_data):
+ poly_matrix_data[row, col] = poly_data
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=left.shape,
)
diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py
index e5aef11..22e5a6a 100644
--- a/polymatrix/expression/mixins/cacheexprmixin.py
+++ b/polymatrix/expression/mixins/cacheexprmixin.py
@@ -29,12 +29,12 @@ class CacheExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
if isinstance(underlying, PolyMatrixAsDictMixin):
- cached_terms = underlying.terms
+ cached_data = underlying.data
else:
- cached_terms = dict(underlying.gen_terms())
+ cached_data = dict(underlying.gen_data())
poly_matrix = init_poly_matrix(
- terms=cached_terms,
+ data=cached_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py
index 9da69af..632efd5 100644
--- a/polymatrix/expression/mixins/combinationsexprmixin.py
+++ b/polymatrix/expression/mixins/combinationsexprmixin.py
@@ -31,19 +31,6 @@ class CombinationsExprMixin(ExpressionBaseMixin):
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- # if self.degree == 0:
- # terms = {(0, 0): {tuple(): 1.0}}
-
- # poly_matrix = init_poly_matrix(
- # terms=terms,
- # shape=(1, 1),
- # )
-
- # elif self.degree == 1:
- # state, monomials = self.expression.apply(state=state)
- # poly_matrix = monomials
-
- # else:
state, poly_matrix = self.expression.apply(state=state)
@@ -55,13 +42,13 @@ class CombinationsExprMixin(ExpressionBaseMixin):
indices = tuple(gen_indices())
- terms = {}
+ poly_matrix_data = {}
for row, indexing in enumerate(indices):
# x.combinations((0, 1, 2)) produces [1, x, x**2]
if len(indexing) == 0:
- terms[row, 0] = {tuple(): 1.0}
+ poly_matrix_data[row, 0] = {tuple(): 1.0}
continue
def acc_product(left, row):
@@ -80,29 +67,11 @@ class CombinationsExprMixin(ExpressionBaseMixin):
initial={},
)
- terms[row, 0] = polynomial
+ poly_matrix_data[row, 0] = polynomial
poly_matrix = init_poly_matrix(
- terms=terms,
- shape=(len(terms), 1),
+ data=poly_matrix_data,
+ shape=(len(poly_matrix_data), 1),
)
- # indices = filter(lambda v: sum(v) <= self.degree, itertools.product(*(range(self.degree) for _ in range(dim))))
-
- # state, monomials = get_monomial_indices(state, self.monomials)
-
- # combinations = tuple(itertools.combinations_with_replacement(monomials, self.number))
-
- # terms = {}
-
- # for row, combination in enumerate(combinations):
- # combination_monomial = merge_monomial_indices(combination)
-
- # terms[row, 0] = {combination_monomial: 1.0}
-
- # poly_matrix = init_poly_matrix(
- # terms=terms,
- # shape=(math.comb(len(monomials) + self.number - 1, self.number), 1),
- # )
-
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/degreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py
index 273add2..71b936a 100644
--- a/polymatrix/expression/mixins/degreeexprmixin.py
+++ b/polymatrix/expression/mixins/degreeexprmixin.py
@@ -26,26 +26,24 @@ class DegreeExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- underlying_terms = underlying.get_poly(row, col)
+ polynomial = underlying.get_poly(row, col)
- if underlying_terms is None or len(underlying_terms) == 0:
+ if polynomial is None or len(polynomial) == 0:
continue
def gen_degrees():
- for monomial, _ in underlying_terms.items():
+ for monomial, _ in polynomial.items():
yield sum(count for _, count in monomial)
- # degrees = tuple(gen_degrees())
-
- terms[row, col] = {tuple(): max(gen_degrees())}
+ poly_matrix_data[row, col] = {tuple(): max(gen_degrees())}
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 38ea9cc..0b6ec6e 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -56,7 +56,7 @@ class DerivativeExprMixin(ExpressionBaseMixin):
stack=self.stack,
))
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
@@ -76,10 +76,10 @@ class DerivativeExprMixin(ExpressionBaseMixin):
)
if 0 < len(derivation):
- terms[row, col] = derivation
+ poly_matrix_data[row, col] = derivation
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(underlying.shape[0], len(diff_wrt_variables)),
)
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
index 7614669..d848bd2 100644
--- a/polymatrix/expression/mixins/determinantexprmixin.py
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -26,7 +26,7 @@ class DeterminantExprMixin(ExpressionBaseMixin):
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
- raise Exception('not implemented')
+ # raise Exception('not implemented')
if self in state.cache:
return state, state.cache[self]
@@ -35,7 +35,7 @@ class DeterminantExprMixin(ExpressionBaseMixin):
assert underlying.shape[0] == underlying.shape[1]
- inequality_terms = {}
+ inequality_data = {}
auxillary_equations = {}
index_start = state.n_param
@@ -43,69 +43,69 @@ class DeterminantExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
- current_inequality_terms = collections.defaultdict(float)
+ polynomial = collections.defaultdict(float)
# f in f-v^T@x-r^2
# terms = underlying.get_poly(row, row)
try:
- underlying_terms = underlying.get_poly(row, row)
+ underlying_poly = underlying.get_poly(row, row)
except KeyError:
pass
else:
- for monomial, value in underlying_terms.items():
- current_inequality_terms[monomial] += value
+ for monomial, value in underlying_poly.items():
+ polynomial[monomial] += value
for inner_row in range(row):
# -v^T@x in f-v^T@x-r^2
# terms = underlying.get_poly(row, inner_row)
try:
- underlying_terms = underlying.get_poly(row, inner_row)
+ underlying_poly = underlying.get_poly(row, inner_row)
except KeyError:
pass
else:
- for monomial, value in underlying_terms.items():
+ for monomial, value in underlying_poly.items():
new_monomial = monomial + (index_start + rel_index + inner_row,)
- current_inequality_terms[new_monomial] -= value
+ polynomial[new_monomial] -= value
# auxillary terms
# ---------------
- auxillary_term = collections.defaultdict(float)
+ auxillary_polynomial = collections.defaultdict(float)
for inner_col in range(row):
# P@x in P@x-v
key = tuple(reversed(sorted((inner_row, inner_col))))
try:
- underlying_terms = underlying.get_poly(*key)
+ underlying_poly = underlying.get_poly(*key)
except KeyError:
pass
else:
- for monomial, value in underlying_terms.items():
+ for monomial, value in underlying_poly.items():
new_monomial = monomial + (index_start + rel_index + inner_col,)
- auxillary_term[new_monomial] += value
+ auxillary_polynomial[new_monomial] += value
# -v in P@x-v
try:
- underlying_terms = underlying.get_poly(row, inner_row)
+ underlying_poly = underlying.get_poly(row, inner_row)
except KeyError:
pass
else:
- for monomial, value in underlying_terms.items():
- auxillary_term[monomial] -= value
+ for monomial, value in underlying_poly.items():
+ auxillary_polynomial[monomial] -= value
x_variable = index_start + rel_index + inner_row
assert x_variable not in state.auxillary_equations
- auxillary_equations[x_variable] = dict(auxillary_term)
+ auxillary_equations[x_variable] = dict(auxillary_polynomial)
rel_index += row
- inequality_terms[row, 0] = dict(current_inequality_terms)
+ inequality_data[row, 0] = dict(polynomial)
state = state.register(rel_index)
poly_matrix = init_poly_matrix(
- terms=inequality_terms,
+ data=inequality_data,
shape=(underlying.shape[0], 1),
)
diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py
index 10db5b4..fc12abd 100644
--- a/polymatrix/expression/mixins/divergenceexprmixin.py
+++ b/polymatrix/expression/mixins/divergenceexprmixin.py
@@ -55,8 +55,7 @@ class DivergenceExprMixin(ExpressionBaseMixin):
monomial_terms[monomial] += value
poly_matrix = init_poly_matrix(
- # terms={(0, 0): dict(monomial_terms)},
- terms={(0, 0): monomial_terms},
+ data={(0, 0): monomial_terms},
shape=(1, 1),
)
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index 0f2fada..65e6bc3 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -55,7 +55,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
state=state,
left=left,
right=init_poly_matrix(
- terms=right_inv,
+ data=right_inv,
shape=(1, 1),
)
)
@@ -64,7 +64,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
if self in state.cache:
return state, state.cache[self]
- terms = {}
+ poly_matrix_data = {}
division_variable = state.n_param
state = state.register(n_param=1)
@@ -80,7 +80,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
for monomial, value in underlying_terms.items():
yield monomial + ((division_variable, 1),), value
- terms[row, col] = dict(gen_monomial_terms())
+ poly_matrix_data[row, col] = dict(gen_monomial_terms())
def gen_auxillary_terms():
for monomial, value in right_poly.items():
@@ -94,7 +94,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
auxillary_terms[tuple()] -= 1
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=left.shape,
)
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 17be163..bfe6397 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -23,11 +23,6 @@ class ElemMultExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # # overwrites the abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.left.shape
-
@staticmethod
def elem_mult(
state: ExpressionState,
@@ -53,23 +48,23 @@ class ElemMultExprMixin(ExpressionBaseMixin):
shape=left.shape,
)
- terms = {}
+ poly_matrix_data = {}
for poly_row in range(left.shape[0]):
for poly_col in range(left.shape[1]):
- terms_row_col = {}
+ polynomial = {}
- left_terms = left.get_poly(poly_row, poly_col)
- if left_terms is None:
+ left_polynomial = left.get_poly(poly_row, poly_col)
+ if left_polynomial is None:
continue
- right_terms = right.get_poly(poly_row, poly_col)
- if right_terms is None:
+ right_polynomial = right.get_poly(poly_row, poly_col)
+ if right_polynomial is None:
continue
for (left_monomial, left_value), (right_monomial, right_value) \
- in itertools.product(left_terms.items(), right_terms.items()):
+ in itertools.product(left_polynomial.items(), right_polynomial.items()):
value = left_value * right_value
@@ -80,16 +75,16 @@ class ElemMultExprMixin(ExpressionBaseMixin):
new_monomial = merge_monomial_indices((left_monomial, right_monomial))
- if new_monomial not in terms_row_col:
- terms_row_col[new_monomial] = 0
+ if new_monomial not in polynomial:
+ polynomial[new_monomial] = 0
- terms_row_col[new_monomial] += value
+ polynomial[new_monomial] += value
- if 0 < len(terms_row_col):
- terms[poly_row, poly_col] = terms_row_col
+ if 0 < len(polynomial):
+ poly_matrix_data[poly_row, poly_col] = polynomial
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=left.shape,
)
@@ -105,64 +100,3 @@ class ElemMultExprMixin(ExpressionBaseMixin):
state, right = self.right.apply(state=state)
return self.elem_mult(state, left, right)
-
- # if left.shape != right.shape and left.shape == (1, 1):
- # left, right = right, left
-
- # if right.shape == (1, 1):
- # right_poly = right.get_poly(0, 0)
-
- # @dataclassabc.dataclassabc(frozen=True)
- # class BroadCastedPolyMatrix(PolyMatrixMixin):
- # underlying: tuple[tuple[int], float]
- # shape: tuple[int, int]
-
- # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
- # return self.underlying
-
- # right = BroadCastedPolyMatrix(
- # underlying=right_poly,
- # shape=left.shape,
- # )
-
- # terms = {}
-
- # for poly_row in range(left.shape[0]):
- # for poly_col in range(left.shape[1]):
-
- # terms_row_col = {}
-
- # left_terms = left.get_poly(poly_row, poly_col)
- # if left_terms is None:
- # continue
-
- # right_terms = right.get_poly(poly_row, poly_col)
- # if right_terms is None:
- # continue
-
- # for (left_monomial, left_value), (right_monomial, right_value) \
- # in itertools.product(left_terms.items(), right_terms.items()):
-
- # value = left_value * right_value
-
- # # if value == 0:
- # # continue
-
- # # monomial = tuple(sorted(left_monomial + right_monomial))
-
- # new_monomial = merge_monomial_indices((left_monomial, right_monomial))
-
- # if new_monomial not in terms_row_col:
- # terms_row_col[new_monomial] = 0
-
- # terms_row_col[new_monomial] += value
-
- # if 0 < len(terms_row_col):
- # terms[poly_row, poly_col] = terms_row_col
-
- # poly_matrix = init_poly_matrix(
- # terms=terms,
- # shape=left.shape,
- # )
-
- # return state, poly_matrix
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index 0c23c1e..b86cafa 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -52,18 +52,18 @@ class EvalExprMixin(ExpressionBaseMixin):
initial=(state, tuple(), tuple())
)
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- underlying_terms = underlying.get_poly(row, col)
- if underlying_terms is None:
+ underlying_polynomial = underlying.get_poly(row, col)
+ if underlying_polynomial is None:
continue
- terms_row_col = {}
+ polynomial = {}
- for monomial, value in underlying_terms.items():
+ for monomial, value in underlying_polynomial.items():
def acc_monomial(acc, next):
new_monomial, value = acc
@@ -83,20 +83,20 @@ class EvalExprMixin(ExpressionBaseMixin):
initial=(tuple(), value),
))
- if new_monomial not in terms_row_col:
- terms_row_col[new_monomial] = 0
+ if new_monomial not in polynomial:
+ polynomial[new_monomial] = 0
- terms_row_col[new_monomial] += new_value
+ polynomial[new_monomial] += new_value
# delete zero entries
- if math.isclose(terms_row_col[new_monomial], 0, abs_tol=1e-12):
- del terms_row_col[new_monomial]
+ if math.isclose(polynomial[new_monomial], 0, abs_tol=1e-12):
+ del polynomial[new_monomial]
- if 0 < len(terms_row_col):
- terms[row, col] = terms_row_col
+ if 0 < len(polynomial):
+ poly_matrix_data[row, col] = polynomial
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py
index 86b5c6a..a240102 100644
--- a/polymatrix/expression/mixins/filterexprmixin.py
+++ b/polymatrix/expression/mixins/filterexprmixin.py
@@ -36,14 +36,14 @@ class FilterExprMixin(ExpressionBaseMixin):
assert predicator.shape[1] == 1
assert underlying.shape[0] == predicator.shape[0]
- terms = {}
+ poly_matrix_data = {}
row_index = 0
for row in range(underlying.shape[0]):
- underlying_terms = underlying.get_poly(row, 0)
+ underlying_polynomial = underlying.get_poly(row, 0)
- if underlying_terms is None:
+ if underlying_polynomial is None:
continue
predicator_poly = predicator.get_poly(row, 0)
@@ -64,11 +64,11 @@ class FilterExprMixin(ExpressionBaseMixin):
predicator_value = 0
if (predicator_value != 0) is not self.inverse:
- terms[row_index, 0] = underlying_terms
+ poly_matrix_data[row_index, 0] = underlying_polynomial
row_index += 1
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(row_index, 1),
)
diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
index dfca992..562ac61 100644
--- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py
+++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
@@ -29,7 +29,7 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
state, variables = self.variables.apply(state=state)
def gen_variable_monomials():
- for _, term in variables.gen_terms():
+ for _, term in variables.gen_data():
assert len(term) == 1, f'{term} should have only a single monomial'
for monomial in term.keys():
@@ -37,18 +37,18 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
variable_monomials = tuple(gen_variable_monomials())
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- underlying_terms = underlying.get_poly(row, col)
- if underlying_terms is None:
+ underlying_polynomial = underlying.get_poly(row, col)
+ if underlying_polynomial is None:
continue
- monomial_terms = collections.defaultdict(float)
+ polynomial = collections.defaultdict(float)
- for monomial, value in underlying_terms.items():
+ for monomial, value in underlying_polynomial.items():
for variable_monomial in variable_monomials:
@@ -63,15 +63,14 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
# take the first that matches
if all(variable not in remainder for variable in variable_monomial):
- monomial_terms[remainder] += value
+ polynomial[remainder] += value
break
- terms[row, col] = monomial_terms
- # terms[row, col] = dict(monomial_terms)
+ poly_matrix_data[row, col] = dict(polynomial)
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py
index 7d48c0f..c7d90a1 100644
--- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py
+++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py
@@ -28,18 +28,18 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin):
assert underlying.shape[0] == underlying.shape[1]
- terms = {}
+ poly_matrix_data = {}
var_index = 0
for row in range(underlying.shape[0]):
for col in range(row, underlying.shape[1]):
- terms[var_index, 0] = underlying.get_poly(row, col)
+ poly_matrix_data[var_index, 0] = underlying.get_poly(row, col)
var_index += 1
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(var_index, 1),
)
diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py
index 8aac10b..22512d5 100644
--- a/polymatrix/expression/mixins/fromtermsexprmixin.py
+++ b/polymatrix/expression/mixins/fromtermsexprmixin.py
@@ -26,10 +26,10 @@ class FromTermsExprMixin(ExpressionBaseMixin):
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- terms = {coord: dict(monomials) for coord, monomials in self.terms}
+ data = {coord: dict(monomials) for coord, monomials in self.terms}
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=data,
shape=self.shape,
)
diff --git a/polymatrix/expression/mixins/fromtupleexprmixin.py b/polymatrix/expression/mixins/fromtupleexprmixin.py
index 19b3ed3..92b965c 100644
--- a/polymatrix/expression/mixins/fromtupleexprmixin.py
+++ b/polymatrix/expression/mixins/fromtupleexprmixin.py
@@ -100,7 +100,7 @@ class FromTupleExprMixin(ExpressionBaseMixin):
polynomials[poly_row, poly_col] = polynomial
poly_matrix = init_poly_matrix(
- terms=polynomials,
+ data=polynomials,
shape=(len(self.data), len(self.data[0])),
)
diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py
index 21d16fd..445e776 100644
--- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py
+++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py
@@ -106,7 +106,7 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin):
)
poly_matrix = init_poly_matrix(
- terms={(row, 0): {monom: 1} for row, monom in enumerate(monomials)},
+ data={(row, 0): {monom: 1} for row, monom in enumerate(monomials)},
shape=(len(monomials), 1),
)
diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py
index 194e65e..aa3de58 100644
--- a/polymatrix/expression/mixins/legendreseriesmixin.py
+++ b/polymatrix/expression/mixins/legendreseriesmixin.py
@@ -36,27 +36,27 @@ class LegendreSeriesMixin(ExpressionBaseMixin):
else:
degrees = self.degrees
- terms = {}
+ poly_matrix_data = {}
for degree in degrees:
# for degree in self.degree:
poly = underlying.get_poly(degree, 0)
- terms[degree, 0] = dict(poly)
+ poly_matrix_data[degree, 0] = dict(poly)
if 2 <= degree:
poly = underlying.get_poly(degree - 2, 0)
factor = - (degree - 1) / (degree + 1)
for m, v in poly.items():
- if m in terms[degree, 0]:
- terms[degree, 0][m] += v*factor
+ if m in poly_matrix_data[degree, 0]:
+ poly_matrix_data[degree, 0][m] += v*factor
else:
- terms[degree, 0][m] = v*factor
+ poly_matrix_data[degree, 0][m] = v*factor
poly_matrix = init_poly_matrix(
- terms=terms,
- shape=(len(terms), 1),
+ data=poly_matrix_data,
+ shape=(len(poly_matrix_data), 1),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index ded61ec..f0b57d3 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -87,7 +87,7 @@ class LinearInExprMixin(ExpressionBaseMixin):
terms[row, col][p_monomial] = value
poly_matrix = init_poly_matrix(
- terms=dict(terms),
+ data=dict(terms),
shape=(underlying.shape[0], len(monomials)),
)
diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
index 4d25a1e..3bb8dfe 100644
--- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py
+++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
@@ -53,7 +53,7 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin):
break
poly_matrix = init_poly_matrix(
- terms=dict(terms),
+ data=dict(terms),
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
index ed9e79f..c7de7f1 100644
--- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
@@ -62,14 +62,14 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin):
linear_monomials = sort_monomials(set(gen_linear_monomials()))
- def gen_terms():
+ def gen_data():
for index, monomial in enumerate(linear_monomials):
yield (index, 0), {monomial: 1.0}
- terms = dict(gen_terms())
+ poly_matrix_data = dict(gen_data())
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(len(linear_monomials), 1),
)
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index b7ae5ce..22a6c94 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -42,30 +42,30 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
stack=self.stack,
))
- terms = {}
+ poly_matrix_data = {}
for poly_row in range(left.shape[0]):
for poly_col in range(right.shape[1]):
- terms_row_col = {}
+ polynomial = {}
for index_k in range(left.shape[1]):
- left_terms = left.get_poly(poly_row, index_k)
- if left_terms is None:
+ left_polynomial = left.get_poly(poly_row, index_k)
+ if left_polynomial is None:
continue
- right_terms = right.get_poly(index_k, poly_col)
- if right_terms is None:
+ right_polynomial = right.get_poly(index_k, poly_col)
+ if right_polynomial is None:
continue
- multiply_polynomial(left_terms, right_terms, terms_row_col)
+ multiply_polynomial(left_polynomial, right_polynomial, polynomial)
- if 0 < len(terms_row_col):
- terms[poly_row, poly_col] = terms_row_col
+ if 0 < len(polynomial):
+ poly_matrix_data[poly_row, poly_col] = polynomial
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(left.shape[0], right.shape[1]),
)
diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py
index bfabc4d..fa1155e 100644
--- a/polymatrix/expression/mixins/maxexprmixin.py
+++ b/polymatrix/expression/mixins/maxexprmixin.py
@@ -21,7 +21,7 @@ class MaxExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
@@ -37,10 +37,10 @@ class MaxExprMixin(ExpressionBaseMixin):
values = tuple(gen_values())
if 0 < len(values):
- terms[row, 0] = {tuple(): max(values)}
+ poly_matrix_data[row, 0] = {tuple(): max(values)}
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(underlying.shape[0], 1),
)
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py
index 712ff3f..0160b9a 100644
--- a/polymatrix/expression/mixins/parametrizeexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizeexprmixin.py
@@ -50,14 +50,14 @@ class ParametrizeExprMixin(ExpressionBaseMixin):
# assert underlying.shape[1] == 1
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
var_index = start + row
- terms[row, 0] = {((var_index, 1),): 1}
+ poly_matrix_data[row, 0] = {((var_index, 1),): 1}
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
index 1aa393b..c51bcf0 100644
--- a/polymatrix/expression/mixins/parametrizematrixexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
@@ -33,13 +33,13 @@ class ParametrizeMatrixExprMixin(ExpressionBaseMixin):
assert underlying.shape[1] == 1
- terms = {}
+ poly_matrix_data = {}
var_index = 0
for row in range(underlying.shape[0]):
for _ in range(row, underlying.shape[0]):
- terms[var_index, 0] = {((state.n_param + var_index, 1),): 1.0}
+ poly_matrix_data[var_index, 0] = {((state.n_param + var_index, 1),): 1.0}
var_index += 1
@@ -49,7 +49,7 @@ class ParametrizeMatrixExprMixin(ExpressionBaseMixin):
)
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(var_index, 1),
)
diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py
index 0168e27..b41865b 100644
--- a/polymatrix/expression/mixins/productexprmixin.py
+++ b/polymatrix/expression/mixins/productexprmixin.py
@@ -1,6 +1,7 @@
import abc
import itertools
+from polymatrix.polymatrix.typing import PolynomialData
from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
from polymatrix.utils.getstacklines import FrameSummary
@@ -31,23 +32,9 @@ class ProductExprMixin(ExpressionBaseMixin):
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
-
- # if self.number == 0:
- # terms = {(0, 0): {tuple(): 1.0}}
-
- # poly_matrix = init_poly_matrix(
- # terms=terms,
- # shape=(1, 1),
- # )
-
- # elif self.number == 1:
- # state, monomials = self.monomials.apply(state=state)
- # poly_matrix = monomials
-
- # else:
if len(self.underlying) == 0:
- terms = {(0,0): {tuple(): 1}}
+ poly_matrix_data = {(0,0): {tuple(): 1}}
else:
@@ -65,23 +52,6 @@ class ProductExprMixin(ExpressionBaseMixin):
initial=(state, tuple())
)
- # highest_degrees = tuple(e.shape[0] for e in underlying)
-
- # if self.degrees is None:
- # degrees = range(sum(highest_degrees))
-
- # else:
- # degrees = self.degrees
-
- # max_degree = max(degrees)
-
- # for poly_matrix in underlying:
- # if not (max_degree <= poly_matrix.shape[0]):
- # raise AssertionError(to_operator_exception(
- # message=f'{poly_matrix.shape[0]} < {max_degree}',
- # stack=self.stack,
- # ))
-
def gen_indices():
product_indices = itertools.product(*(range(e.shape[0]) for e in underlying))
@@ -92,26 +62,16 @@ class ProductExprMixin(ExpressionBaseMixin):
yield from filter(lambda v: sum(v) in self.degrees, product_indices)
indices = tuple(gen_indices())
- # print(indices)
-
- # indices = filter(lambda v: sum(v) <= self.degree, itertools.product(*(range(self.degree) for _ in range(dim))))
- terms = {}
+ poly_matrix_data = {}
for row, indexing in enumerate(indices):
- # def acc_product(acc, v):
- # left_monomials = acc
- # polymatrix, row = v
-
- # right_monomials = polymatrix.get_poly(row, 0).keys()
-
- # if left_monomials is (None,):
- # return right_monomials
-
- # return tuple(multiply_monomials(left_monomials, right_monomials))
-
- def acc_product(left, v):
+ def acc_product(
+ left: PolynomialData,
+ v: tuple[ExpressionBaseMixin, int],
+ ) -> PolynomialData:
+
poly_matrix, row = v
right = poly_matrix.get_poly(row, 0)
@@ -129,11 +89,11 @@ class ProductExprMixin(ExpressionBaseMixin):
initial={},
)
- terms[row, 0] = polynomial
+ poly_matrix_data[row, 0] = polynomial
poly_matrix = init_poly_matrix(
- terms=terms,
- shape=(len(terms), 1),
+ data=poly_matrix_data,
+ shape=(len(poly_matrix_data), 1),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index e739364..2cc9e12 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -87,7 +87,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
terms[row, col][p_monomial] += value
poly_matrix = init_poly_matrix(
- terms=dict((k, dict(v)) for k, v in terms.items()),
+ data=dict((k, dict(v)) for k, v in terms.items()),
shape=2*(len(sos_monomials),),
)
diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
index 99befc8..3a7de30 100644
--- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
@@ -65,14 +65,14 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
sos_monomials = tuple(sorted(set(gen_sos_monomials()), key=lambda m: (len(m), m)))
- def gen_terms():
+ def gen_data():
for index, monomial in enumerate(sos_monomials):
yield (index, 0), {monomial: 1.0}
- terms = dict(gen_terms())
+ poly_matrix_data = dict(gen_data())
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(len(sos_monomials), 1),
)
diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py
index eed3d71..888300f 100644
--- a/polymatrix/expression/mixins/squeezeexprmixin.py
+++ b/polymatrix/expression/mixins/squeezeexprmixin.py
@@ -23,7 +23,7 @@ class SqueezeExprMixin(ExpressionBaseMixin):
assert underlying.shape[1] == 1
- terms = {}
+ poly_matrix_data = {}
row_index = 0
for row in range(underlying.shape[0]):
@@ -32,18 +32,18 @@ class SqueezeExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- terms_row_col = {}
+ polynomial = {}
for monomial, value in polynomial.items():
if value != 0.0:
- terms_row_col[monomial] = value
+ polynomial[monomial] = value
- if len(terms_row_col):
- terms[row_index, 0] = terms_row_col
+ if len(polynomial):
+ poly_matrix_data[row_index, 0] = polynomial
row_index += 1
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(row_index, 1),
)
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index b75a2de..2bc36d1 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -67,7 +67,7 @@ class SubstituteExprMixin(ExpressionBaseMixin):
else:
assert len(variable_indices) == len(substitutions), f'{substitutions=}'
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
@@ -76,11 +76,11 @@ class SubstituteExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- terms_row_col = collections.defaultdict(float)
+ polynomial = collections.defaultdict(float)
for monomial, value in polynomial.items():
- terms_monomial = {tuple(): value}
+ substituted_monomial = {tuple(): value}
for variable, count in monomial:
if variable in variable_indices:
@@ -90,24 +90,24 @@ class SubstituteExprMixin(ExpressionBaseMixin):
for _ in range(count):
next = {}
- multiply_polynomial(terms_monomial, substitution, next)
- terms_monomial = next
+ multiply_polynomial(substituted_monomial, substitution, next)
+ substituted_monomial = next
else:
next = {}
- multiply_polynomial(terms_monomial, {((variable, count),): 1.0}, next)
- terms_monomial = next
+ multiply_polynomial(substituted_monomial, {((variable, count),): 1.0}, next)
+ substituted_monomial = next
- for monomial, value in terms_monomial.items():
- terms_row_col[monomial] += value
+ for monomial, value in substituted_monomial.items():
+ polynomial[monomial] += value
- terms_row_col = {key: val for key, val in terms_row_col.items() if not math.isclose(val, 0, abs_tol=1e-12)}
+ polynomial = {key: val for key, val in polynomial.items() if not math.isclose(val, 0, abs_tol=1e-12)}
- if 0 < len(terms_row_col):
- terms[row, col] = terms_row_col
+ if 0 < len(polynomial):
+ poly_matrix_data[row, col] = polynomial
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
index 9275332..a01b5e7 100644
--- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
@@ -44,10 +44,10 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin):
remainders = sort_monomials(set(gen_remainders()))
- terms = {(row, 0): {remainder: 1.0} for row, remainder in enumerate(remainders)}
+ data = {(row, 0): {remainder: 1.0} for row, remainder in enumerate(remainders)}
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=data,
shape=(len(remainders), 1),
)
diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py
index 812b260..83f9397 100644
--- a/polymatrix/expression/mixins/sumexprmixin.py
+++ b/polymatrix/expression/mixins/sumexprmixin.py
@@ -29,7 +29,7 @@ class SumExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
- terms = collections.defaultdict(dict)
+ poly_matrix_data = collections.defaultdict(dict)
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
@@ -38,16 +38,16 @@ class SumExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- term_monomials = terms[row, 0]
+ polynomial = poly_matrix_data[row, 0]
for monomial, value in polynomial.items():
- if monomial in term_monomials:
- term_monomials[monomial] += value
+ if monomial in polynomial:
+ polynomial[monomial] += value
else:
- term_monomials[monomial] = value
+ polynomial[monomial] = value
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(underlying.shape[0], 1),
)
diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py
index b8b3d64..c3dc4e2 100644
--- a/polymatrix/expression/mixins/toconstantexprmixin.py
+++ b/polymatrix/expression/mixins/toconstantexprmixin.py
@@ -39,7 +39,7 @@ class ToConstantExprMixin(ExpressionBaseMixin):
terms[row, col][tuple()] = polynomial[tuple()]
poly_matrix = init_poly_matrix(
- terms=dict(terms),
+ data=dict(terms),
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py
index d10fcd9..c6b692b 100644
--- a/polymatrix/expression/mixins/toquadraticexprmixin.py
+++ b/polymatrix/expression/mixins/toquadraticexprmixin.py
@@ -9,7 +9,6 @@ from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
-# to be deleted?
class ToQuadraticExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
@@ -25,17 +24,17 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
state = [state]
- terms = {}
+ poly_matrix_data = {}
auxillary_equations_from_quadratic = {}
def to_quadratic(monomial_terms):
- terms_row_col = collections.defaultdict(float)
+ polynomial = collections.defaultdict(float)
for monomial, value in monomial_terms.items():
if 2 < len(monomial):
current_aux = state[0].n_param
- terms_row_col[(monomial[0], current_aux)] += value
+ polynomial[(monomial[0], current_aux)] += value
state[0] = state[0].register(n_param=1)
for variable in monomial[1:-2]:
@@ -52,10 +51,10 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
}
else:
- terms_row_col[monomial] += value
+ polynomial[monomial] += value
# return dict(terms_row_col)
- return terms_row_col
+ return polynomial
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
@@ -64,43 +63,18 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- terms_row_col = to_quadratic(
+ polynomial = to_quadratic(
monomial_terms=polynomial,
)
- # terms_row_col = collections.defaultdict(float)
-
- # for monomial, value in underlying_terms.items():
-
- # if 2 < len(monomial):
- # current_aux = state.n_param
- # terms_row_col[(monomial[0], current_aux)] += value
- # state = state.register(n_param=1)
-
- # for variable in monomial[1:-2]:
- # auxillary_equations[current_aux] = {
- # (variable, current_aux + 1): 1,
- # (current_aux,): -1,
- # }
- # state = state.register(n_param=1)
- # current_aux += 1
-
- # auxillary_equations[current_aux] = {
- # (monomial[-2], monomial[-1]): 1,
- # (current_aux,): -1,
- # }
-
- # else:
- # terms_row_col[monomial] += value
-
- terms[row, col] = terms_row_col
+ poly_matrix_data[row, col] = polynomial
def gen_auxillary_equations():
for key, monomial_terms in state[0].auxillary_equations.items():
- terms_row_col = to_quadratic(
+ polynomial = to_quadratic(
monomial_terms=monomial_terms,
)
- yield key, terms_row_col
+ yield key, polynomial
state = dataclasses.replace(
state[0],
@@ -108,7 +82,7 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
)
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=underlying.shape,
)
diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py
index a61a7e9..32e2162 100644
--- a/polymatrix/expression/mixins/tosortedvariablesmixin.py
+++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py
@@ -31,7 +31,7 @@ class ToSortedVariablesExprMixin(ExpressionBaseMixin):
yield (row, 0), {((index, 1),): 1}
poly_matrix = init_poly_matrix(
- terms=dict(gen_sorted_vector()),
+ data=dict(gen_sorted_vector()),
shape=(len(variable_indices), 1),
)
diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py
index 8801147..488c991 100644
--- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py
+++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py
@@ -44,20 +44,20 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin):
n_row = invert_binomial_coefficient(underlying.shape[0])
- terms = {}
+ poly_matrix_data = {}
var_index = 0
for row in range(n_row):
for col in range(row, n_row):
- terms[row, col] = underlying.get_poly(var_index, 0)
+ poly_matrix_data[row, col] = underlying.get_poly(var_index, 0)
if row != col:
- terms[col, row] = terms[row, col]
+ poly_matrix_data[col, row] = poly_matrix_data[row, col]
var_index += 1
poly_matrix = init_poly_matrix(
- terms=terms,
+ data=poly_matrix_data,
shape=(n_row, n_row),
)
diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py
index fdf3921..2fcb970 100644
--- a/polymatrix/expression/mixins/truncateexprmixin.py
+++ b/polymatrix/expression/mixins/truncateexprmixin.py
@@ -43,7 +43,7 @@ class TruncateExprMixin(ExpressionBaseMixin):
state, variable_indices = get_variable_indices_from_variable(state, self.variables)
cond = lambda idx: idx in variable_indices
- terms = {}
+ poly_matrix_data = {}
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
@@ -52,19 +52,19 @@ class TruncateExprMixin(ExpressionBaseMixin):
if polynomial is None:
continue
- terms_row_col = {}
+ polynomial = {}
for monomial, value in polynomial.items():
degree = sum((count for var_idx, count in monomial if cond(var_idx)))
if (degree in self.degrees) is not self.inverse:
- terms_row_col[monomial] = value
+ polynomial[monomial] = value
- terms[row, col] = terms_row_col
+ poly_matrix_data[row, col] = polynomial
poly_matrix = init_poly_matrix(
- terms=dict(terms),
+ data=dict(poly_matrix_data),
shape=underlying.shape,
)
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py
index b02bb11..cebc75a 100644
--- a/polymatrix/expression/to.py
+++ b/polymatrix/expression/to.py
@@ -31,7 +31,7 @@ def to_constant(
A = np.zeros(underlying.shape, dtype=np.double)
- for (row, col), polynomial in underlying.gen_terms():
+ for (row, col), polynomial in underlying.gen_data():
for monomial, value in polynomial.items():
if len(monomial) == 0:
A[row, col] = value
@@ -86,7 +86,7 @@ def to_sympy(
A = np.zeros(underlying.shape, dtype=object)
- for (row, col), polynomial in underlying.gen_terms():
+ for (row, col), polynomial in underlying.gen_data():
sympy_polynomial = 0
diff --git a/polymatrix/expression/utils/broadcastpolymatrix.py b/polymatrix/expression/utils/broadcastpolymatrix.py
new file mode 100644
index 0000000..e8b3ff3
--- /dev/null
+++ b/polymatrix/expression/utils/broadcastpolymatrix.py
@@ -0,0 +1,34 @@
+from polymatrix.polymatrix.init import init_broadcast_poly_matrix, init_poly_matrix
+from polymatrix.utils.getstacklines import FrameSummary
+from polymatrix.utils.tooperatorexception import to_operator_exception
+from polymatrix.polymatrix.abc import PolyMatrix
+
+
+def broadcast_poly_matrix(
+ left: PolyMatrix,
+ right: PolyMatrix,
+ stack: tuple[FrameSummary],
+) -> PolyMatrix:
+
+ # broadcast left
+ if left.shape == (1, 1) and right.shape != (1, 1):
+ left = init_broadcast_poly_matrix(
+ data=left.get_poly(0, 0),
+ shape=right.shape,
+ )
+
+ # broadcast right
+ elif left.shape != (1, 1) and right.shape == (1, 1):
+ right = init_broadcast_poly_matrix(
+ data=right.get_poly(0, 0),
+ shape=left.shape,
+ )
+
+ else:
+ if not (left.shape == right.shape):
+ raise AssertionError(to_operator_exception(
+ message=f'{left.shape} != {right.shape}',
+ stack=stack,
+ ))
+
+ return left, right
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py
index f957685..04a7872 100644
--- a/polymatrix/polymatrix/impl.py
+++ b/polymatrix/polymatrix/impl.py
@@ -7,11 +7,11 @@ from polymatrix.polymatrix.typing import PolynomialData
@dataclassabc.dataclassabc(frozen=True)
class PolyMatrixImpl(PolyMatrix):
- terms: dict[tuple[int, int], PolynomialData]
+ data: dict[tuple[int, int], PolynomialData]
shape: tuple[int, int]
@dataclassabc.dataclassabc(frozen=True)
class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin):
- polynomial: tuple[tuple[int], float]
+ data: tuple[tuple[int], float]
shape: tuple[int, int]
diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py
index 2dc2a63..fa7bd20 100644
--- a/polymatrix/polymatrix/init.py
+++ b/polymatrix/polymatrix/init.py
@@ -1,13 +1,24 @@
-from polymatrix.polymatrix.impl import PolyMatrixImpl
+from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl, PolyMatrixImpl
from polymatrix.polymatrix.typing import PolynomialData
def init_poly_matrix(
- terms: dict[tuple[int, int], PolynomialData],
+ data: dict[tuple[int, int], PolynomialData],
shape: tuple[int, int],
):
return PolyMatrixImpl(
- terms=terms,
+ data=data,
+ shape=shape,
+)
+
+
+def init_broadcast_poly_matrix(
+ data: PolynomialData,
+ shape: tuple[int, int],
+):
+
+ return BroadcastPolyMatrixImpl(
+ data=data,
shape=shape,
)
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index 5046492..0ce4dad 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -10,7 +10,7 @@ class PolyMatrixMixin(abc.ABC):
def shape(self) -> tuple[int, int]:
...
- def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]:
+ def gen_data(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]:
for row in range(self.shape[0]):
for col in range(self.shape[1]):
polynomial = self.get_poly(row, col)
@@ -30,13 +30,13 @@ class PolyMatrixAsDictMixin(
):
@property
@abc.abstractmethod
- def terms(self) -> dict[tuple[int, int], PolynomialData]:
+ def data(self) -> dict[tuple[int, int], PolynomialData]:
...
# overwrites the abstract method of `PolyMatrixMixin`
def get_poly(self, row: int, col: int) -> PolynomialData | None:
- if (row, col) in self.terms:
- return self.terms[row, col]
+ if (row, col) in self.data:
+ return self.data[row, col]
class BroadcastPolyMatrixMixin(
@@ -45,9 +45,9 @@ class BroadcastPolyMatrixMixin(
):
@property
@abc.abstractmethod
- def polynomial(self) -> PolynomialData:
+ def data(self) -> PolynomialData:
...
# overwrites the abstract method of `PolyMatrixMixin`
def get_poly(self, col: int, row: int) -> PolynomialData | None:
- return self.polynomial
+ return self.data