diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2024-02-27 08:16:57 +0100 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2024-02-27 08:16:57 +0100 |
commit | 96b49a4153200f31114d924ee49e4b58a904cc67 (patch) | |
tree | 135a2d85ad005fb19abdda150a2fcc820dce9d20 | |
parent | improve typing with PolynomialData and MonomialData (diff) | |
download | polymatrix-96b49a4153200f31114d924ee49e4b58a904cc67.tar.gz polymatrix-96b49a4153200f31114d924ee49e4b58a904cc67.zip |
use or instead of
Diffstat (limited to '')
42 files changed, 255 insertions, 433 deletions
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 06fe23c..f8092d2 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -226,7 +226,7 @@ def init_from_terms_expr( if isinstance(terms, PolyMatrixMixin): shape = terms.shape - gen_terms = terms.gen_terms() + gen_terms = terms.gen_data() else: assert shape is not None diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 7dbf1d2..e3580f0 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,9 +1,8 @@ import abc import math -from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl +from polymatrix.expression.utils.broadcastpolymatrix import broadcast_poly_matrix from polymatrix.utils.getstacklines import FrameSummary -from polymatrix.utils.tooperatorexception import to_operator_exception from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState @@ -26,31 +25,6 @@ class AdditionExprMixin(ExpressionBaseMixin): def stack(self) -> tuple[FrameSummary]: ... - @staticmethod - def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]): - # broadcast left - if left.shape == (1, 1) and right.shape != (1, 1): - left = BroadcastPolyMatrixImpl( - polynomial=left.get_poly(0, 0), - shape=right.shape, - ) - - # broadcast right - elif left.shape != (1, 1) and right.shape == (1, 1): - right = BroadcastPolyMatrixImpl( - polynomial=right.get_poly(0, 0), - shape=left.shape, - ) - - else: - if not (left.shape == right.shape): - raise AssertionError(to_operator_exception( - message=f'{left.shape} != {right.shape}', - stack=stack, - )) - - return left, right - # overwrites the abstract method of `ExpressionBaseMixin` def apply( self, @@ -59,44 +33,14 @@ class AdditionExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - # if left.shape == (1, 1): - # left, right = right, left - - # if left.shape != (1, 1) and right.shape == (1, 1): - - # # @dataclassabc.dataclassabc(frozen=True) - # # class BroadCastedPolyMatrix(PolyMatrixMixin): - # # underlying_monomials: tuple[tuple[int], float] - # # shape: tuple[int, int] - - # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: - # # return self.underlying_monomials - - - # right = BroadcastPolyMatrixImpl( - # polynomial=right.get_poly(0, 0), - # shape=left.shape, - # ) - - # # all_underlying = (left, broadcasted_right) - - # else: - # if not (left.shape == right.shape): - # raise AssertionError(to_operator_exception( - # message=f'{left.shape} != {right.shape}', - # stack=self.stack, - # )) - - # # all_underlying = (left, right) - - left, right = self.broadcast(left, right, self.stack) + left, right = broadcast_poly_matrix(left, right, self.stack) - terms = {} + poly_matrix_data = {} for row in range(left.shape[0]): for col in range(left.shape[1]): - terms_row_col = {} + poly_data = {} for underlying in (left, right): @@ -104,26 +48,26 @@ class AdditionExprMixin(ExpressionBaseMixin): if polynomial is None: continue - if len(terms_row_col) == 0: - terms_row_col = dict(polynomial) + if len(poly_data) == 0: + poly_data = dict(polynomial) else: for monomial, value in polynomial.items(): - if monomial not in terms_row_col: - terms_row_col[monomial] = value + if monomial not in poly_data: + poly_data[monomial] = value else: - terms_row_col[monomial] += value + poly_data[monomial] += value - if math.isclose(terms_row_col[monomial], 0): - del terms_row_col[monomial] + if math.isclose(poly_data[monomial], 0): + del poly_data[monomial] - if 0 < len(terms_row_col): - terms[row, col] = terms_row_col + if 0 < len(poly_data): + poly_matrix_data[row, col] = poly_data poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=left.shape, ) diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index e5aef11..22e5a6a 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -29,12 +29,12 @@ class CacheExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state) if isinstance(underlying, PolyMatrixAsDictMixin): - cached_terms = underlying.terms + cached_data = underlying.data else: - cached_terms = dict(underlying.gen_terms()) + cached_data = dict(underlying.gen_data()) poly_matrix = init_poly_matrix( - terms=cached_terms, + data=cached_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py index 9da69af..632efd5 100644 --- a/polymatrix/expression/mixins/combinationsexprmixin.py +++ b/polymatrix/expression/mixins/combinationsexprmixin.py @@ -31,19 +31,6 @@ class CombinationsExprMixin(ExpressionBaseMixin): self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - # if self.degree == 0: - # terms = {(0, 0): {tuple(): 1.0}} - - # poly_matrix = init_poly_matrix( - # terms=terms, - # shape=(1, 1), - # ) - - # elif self.degree == 1: - # state, monomials = self.expression.apply(state=state) - # poly_matrix = monomials - - # else: state, poly_matrix = self.expression.apply(state=state) @@ -55,13 +42,13 @@ class CombinationsExprMixin(ExpressionBaseMixin): indices = tuple(gen_indices()) - terms = {} + poly_matrix_data = {} for row, indexing in enumerate(indices): # x.combinations((0, 1, 2)) produces [1, x, x**2] if len(indexing) == 0: - terms[row, 0] = {tuple(): 1.0} + poly_matrix_data[row, 0] = {tuple(): 1.0} continue def acc_product(left, row): @@ -80,29 +67,11 @@ class CombinationsExprMixin(ExpressionBaseMixin): initial={}, ) - terms[row, 0] = polynomial + poly_matrix_data[row, 0] = polynomial poly_matrix = init_poly_matrix( - terms=terms, - shape=(len(terms), 1), + data=poly_matrix_data, + shape=(len(poly_matrix_data), 1), ) - # indices = filter(lambda v: sum(v) <= self.degree, itertools.product(*(range(self.degree) for _ in range(dim)))) - - # state, monomials = get_monomial_indices(state, self.monomials) - - # combinations = tuple(itertools.combinations_with_replacement(monomials, self.number)) - - # terms = {} - - # for row, combination in enumerate(combinations): - # combination_monomial = merge_monomial_indices(combination) - - # terms[row, 0] = {combination_monomial: 1.0} - - # poly_matrix = init_poly_matrix( - # terms=terms, - # shape=(math.comb(len(monomials) + self.number - 1, self.number), 1), - # ) - return state, poly_matrix diff --git a/polymatrix/expression/mixins/degreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py index 273add2..71b936a 100644 --- a/polymatrix/expression/mixins/degreeexprmixin.py +++ b/polymatrix/expression/mixins/degreeexprmixin.py @@ -26,26 +26,24 @@ class DegreeExprMixin(ExpressionBaseMixin): ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - underlying_terms = underlying.get_poly(row, col) + polynomial = underlying.get_poly(row, col) - if underlying_terms is None or len(underlying_terms) == 0: + if polynomial is None or len(polynomial) == 0: continue def gen_degrees(): - for monomial, _ in underlying_terms.items(): + for monomial, _ in polynomial.items(): yield sum(count for _, count in monomial) - # degrees = tuple(gen_degrees()) - - terms[row, col] = {tuple(): max(gen_degrees())} + poly_matrix_data[row, col] = {tuple(): max(gen_degrees())} poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 38ea9cc..0b6ec6e 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -56,7 +56,7 @@ class DerivativeExprMixin(ExpressionBaseMixin): stack=self.stack, )) - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): @@ -76,10 +76,10 @@ class DerivativeExprMixin(ExpressionBaseMixin): ) if 0 < len(derivation): - terms[row, col] = derivation + poly_matrix_data[row, col] = derivation poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(underlying.shape[0], len(diff_wrt_variables)), ) diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py index 7614669..d848bd2 100644 --- a/polymatrix/expression/mixins/determinantexprmixin.py +++ b/polymatrix/expression/mixins/determinantexprmixin.py @@ -26,7 +26,7 @@ class DeterminantExprMixin(ExpressionBaseMixin): self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - raise Exception('not implemented') + # raise Exception('not implemented') if self in state.cache: return state, state.cache[self] @@ -35,7 +35,7 @@ class DeterminantExprMixin(ExpressionBaseMixin): assert underlying.shape[0] == underlying.shape[1] - inequality_terms = {} + inequality_data = {} auxillary_equations = {} index_start = state.n_param @@ -43,69 +43,69 @@ class DeterminantExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): - current_inequality_terms = collections.defaultdict(float) + polynomial = collections.defaultdict(float) # f in f-v^T@x-r^2 # terms = underlying.get_poly(row, row) try: - underlying_terms = underlying.get_poly(row, row) + underlying_poly = underlying.get_poly(row, row) except KeyError: pass else: - for monomial, value in underlying_terms.items(): - current_inequality_terms[monomial] += value + for monomial, value in underlying_poly.items(): + polynomial[monomial] += value for inner_row in range(row): # -v^T@x in f-v^T@x-r^2 # terms = underlying.get_poly(row, inner_row) try: - underlying_terms = underlying.get_poly(row, inner_row) + underlying_poly = underlying.get_poly(row, inner_row) except KeyError: pass else: - for monomial, value in underlying_terms.items(): + for monomial, value in underlying_poly.items(): new_monomial = monomial + (index_start + rel_index + inner_row,) - current_inequality_terms[new_monomial] -= value + polynomial[new_monomial] -= value # auxillary terms # --------------- - auxillary_term = collections.defaultdict(float) + auxillary_polynomial = collections.defaultdict(float) for inner_col in range(row): # P@x in P@x-v key = tuple(reversed(sorted((inner_row, inner_col)))) try: - underlying_terms = underlying.get_poly(*key) + underlying_poly = underlying.get_poly(*key) except KeyError: pass else: - for monomial, value in underlying_terms.items(): + for monomial, value in underlying_poly.items(): new_monomial = monomial + (index_start + rel_index + inner_col,) - auxillary_term[new_monomial] += value + auxillary_polynomial[new_monomial] += value # -v in P@x-v try: - underlying_terms = underlying.get_poly(row, inner_row) + underlying_poly = underlying.get_poly(row, inner_row) except KeyError: pass else: - for monomial, value in underlying_terms.items(): - auxillary_term[monomial] -= value + for monomial, value in underlying_poly.items(): + auxillary_polynomial[monomial] -= value x_variable = index_start + rel_index + inner_row assert x_variable not in state.auxillary_equations - auxillary_equations[x_variable] = dict(auxillary_term) + auxillary_equations[x_variable] = dict(auxillary_polynomial) rel_index += row - inequality_terms[row, 0] = dict(current_inequality_terms) + inequality_data[row, 0] = dict(polynomial) state = state.register(rel_index) poly_matrix = init_poly_matrix( - terms=inequality_terms, + data=inequality_data, shape=(underlying.shape[0], 1), ) diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 10db5b4..fc12abd 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -55,8 +55,7 @@ class DivergenceExprMixin(ExpressionBaseMixin): monomial_terms[monomial] += value poly_matrix = init_poly_matrix( - # terms={(0, 0): dict(monomial_terms)}, - terms={(0, 0): monomial_terms}, + data={(0, 0): monomial_terms}, shape=(1, 1), ) diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py index 0f2fada..65e6bc3 100644 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -55,7 +55,7 @@ class DivisionExprMixin(ExpressionBaseMixin): state=state, left=left, right=init_poly_matrix( - terms=right_inv, + data=right_inv, shape=(1, 1), ) ) @@ -64,7 +64,7 @@ class DivisionExprMixin(ExpressionBaseMixin): if self in state.cache: return state, state.cache[self] - terms = {} + poly_matrix_data = {} division_variable = state.n_param state = state.register(n_param=1) @@ -80,7 +80,7 @@ class DivisionExprMixin(ExpressionBaseMixin): for monomial, value in underlying_terms.items(): yield monomial + ((division_variable, 1),), value - terms[row, col] = dict(gen_monomial_terms()) + poly_matrix_data[row, col] = dict(gen_monomial_terms()) def gen_auxillary_terms(): for monomial, value in right_poly.items(): @@ -94,7 +94,7 @@ class DivisionExprMixin(ExpressionBaseMixin): auxillary_terms[tuple()] -= 1 poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=left.shape, ) diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index 17be163..bfe6397 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -23,11 +23,6 @@ class ElemMultExprMixin(ExpressionBaseMixin): def right(self) -> ExpressionBaseMixin: ... - # # overwrites the abstract method of `ExpressionBaseMixin` - # @property - # def shape(self) -> tuple[int, int]: - # return self.left.shape - @staticmethod def elem_mult( state: ExpressionState, @@ -53,23 +48,23 @@ class ElemMultExprMixin(ExpressionBaseMixin): shape=left.shape, ) - terms = {} + poly_matrix_data = {} for poly_row in range(left.shape[0]): for poly_col in range(left.shape[1]): - terms_row_col = {} + polynomial = {} - left_terms = left.get_poly(poly_row, poly_col) - if left_terms is None: + left_polynomial = left.get_poly(poly_row, poly_col) + if left_polynomial is None: continue - right_terms = right.get_poly(poly_row, poly_col) - if right_terms is None: + right_polynomial = right.get_poly(poly_row, poly_col) + if right_polynomial is None: continue for (left_monomial, left_value), (right_monomial, right_value) \ - in itertools.product(left_terms.items(), right_terms.items()): + in itertools.product(left_polynomial.items(), right_polynomial.items()): value = left_value * right_value @@ -80,16 +75,16 @@ class ElemMultExprMixin(ExpressionBaseMixin): new_monomial = merge_monomial_indices((left_monomial, right_monomial)) - if new_monomial not in terms_row_col: - terms_row_col[new_monomial] = 0 + if new_monomial not in polynomial: + polynomial[new_monomial] = 0 - terms_row_col[new_monomial] += value + polynomial[new_monomial] += value - if 0 < len(terms_row_col): - terms[poly_row, poly_col] = terms_row_col + if 0 < len(polynomial): + poly_matrix_data[poly_row, poly_col] = polynomial poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=left.shape, ) @@ -105,64 +100,3 @@ class ElemMultExprMixin(ExpressionBaseMixin): state, right = self.right.apply(state=state) return self.elem_mult(state, left, right) - - # if left.shape != right.shape and left.shape == (1, 1): - # left, right = right, left - - # if right.shape == (1, 1): - # right_poly = right.get_poly(0, 0) - - # @dataclassabc.dataclassabc(frozen=True) - # class BroadCastedPolyMatrix(PolyMatrixMixin): - # underlying: tuple[tuple[int], float] - # shape: tuple[int, int] - - # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: - # return self.underlying - - # right = BroadCastedPolyMatrix( - # underlying=right_poly, - # shape=left.shape, - # ) - - # terms = {} - - # for poly_row in range(left.shape[0]): - # for poly_col in range(left.shape[1]): - - # terms_row_col = {} - - # left_terms = left.get_poly(poly_row, poly_col) - # if left_terms is None: - # continue - - # right_terms = right.get_poly(poly_row, poly_col) - # if right_terms is None: - # continue - - # for (left_monomial, left_value), (right_monomial, right_value) \ - # in itertools.product(left_terms.items(), right_terms.items()): - - # value = left_value * right_value - - # # if value == 0: - # # continue - - # # monomial = tuple(sorted(left_monomial + right_monomial)) - - # new_monomial = merge_monomial_indices((left_monomial, right_monomial)) - - # if new_monomial not in terms_row_col: - # terms_row_col[new_monomial] = 0 - - # terms_row_col[new_monomial] += value - - # if 0 < len(terms_row_col): - # terms[poly_row, poly_col] = terms_row_col - - # poly_matrix = init_poly_matrix( - # terms=terms, - # shape=left.shape, - # ) - - # return state, poly_matrix diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index 0c23c1e..b86cafa 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -52,18 +52,18 @@ class EvalExprMixin(ExpressionBaseMixin): initial=(state, tuple(), tuple()) ) - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - underlying_terms = underlying.get_poly(row, col) - if underlying_terms is None: + underlying_polynomial = underlying.get_poly(row, col) + if underlying_polynomial is None: continue - terms_row_col = {} + polynomial = {} - for monomial, value in underlying_terms.items(): + for monomial, value in underlying_polynomial.items(): def acc_monomial(acc, next): new_monomial, value = acc @@ -83,20 +83,20 @@ class EvalExprMixin(ExpressionBaseMixin): initial=(tuple(), value), )) - if new_monomial not in terms_row_col: - terms_row_col[new_monomial] = 0 + if new_monomial not in polynomial: + polynomial[new_monomial] = 0 - terms_row_col[new_monomial] += new_value + polynomial[new_monomial] += new_value # delete zero entries - if math.isclose(terms_row_col[new_monomial], 0, abs_tol=1e-12): - del terms_row_col[new_monomial] + if math.isclose(polynomial[new_monomial], 0, abs_tol=1e-12): + del polynomial[new_monomial] - if 0 < len(terms_row_col): - terms[row, col] = terms_row_col + if 0 < len(polynomial): + poly_matrix_data[row, col] = polynomial poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index 86b5c6a..a240102 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -36,14 +36,14 @@ class FilterExprMixin(ExpressionBaseMixin): assert predicator.shape[1] == 1 assert underlying.shape[0] == predicator.shape[0] - terms = {} + poly_matrix_data = {} row_index = 0 for row in range(underlying.shape[0]): - underlying_terms = underlying.get_poly(row, 0) + underlying_polynomial = underlying.get_poly(row, 0) - if underlying_terms is None: + if underlying_polynomial is None: continue predicator_poly = predicator.get_poly(row, 0) @@ -64,11 +64,11 @@ class FilterExprMixin(ExpressionBaseMixin): predicator_value = 0 if (predicator_value != 0) is not self.inverse: - terms[row_index, 0] = underlying_terms + poly_matrix_data[row_index, 0] = underlying_polynomial row_index += 1 poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(row_index, 1), ) diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index dfca992..562ac61 100644 --- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -29,7 +29,7 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): state, variables = self.variables.apply(state=state) def gen_variable_monomials(): - for _, term in variables.gen_terms(): + for _, term in variables.gen_data(): assert len(term) == 1, f'{term} should have only a single monomial' for monomial in term.keys(): @@ -37,18 +37,18 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): variable_monomials = tuple(gen_variable_monomials()) - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - underlying_terms = underlying.get_poly(row, col) - if underlying_terms is None: + underlying_polynomial = underlying.get_poly(row, col) + if underlying_polynomial is None: continue - monomial_terms = collections.defaultdict(float) + polynomial = collections.defaultdict(float) - for monomial, value in underlying_terms.items(): + for monomial, value in underlying_polynomial.items(): for variable_monomial in variable_monomials: @@ -63,15 +63,14 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): # take the first that matches if all(variable not in remainder for variable in variable_monomial): - monomial_terms[remainder] += value + polynomial[remainder] += value break - terms[row, col] = monomial_terms - # terms[row, col] = dict(monomial_terms) + poly_matrix_data[row, col] = dict(polynomial) poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py index 7d48c0f..c7d90a1 100644 --- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py @@ -28,18 +28,18 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin): assert underlying.shape[0] == underlying.shape[1] - terms = {} + poly_matrix_data = {} var_index = 0 for row in range(underlying.shape[0]): for col in range(row, underlying.shape[1]): - terms[var_index, 0] = underlying.get_poly(row, col) + poly_matrix_data[var_index, 0] = underlying.get_poly(row, col) var_index += 1 poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(var_index, 1), ) diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 8aac10b..22512d5 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -26,10 +26,10 @@ class FromTermsExprMixin(ExpressionBaseMixin): state: ExpressionStateMixin, ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]: - terms = {coord: dict(monomials) for coord, monomials in self.terms} + data = {coord: dict(monomials) for coord, monomials in self.terms} poly_matrix = init_poly_matrix( - terms=terms, + data=data, shape=self.shape, ) diff --git a/polymatrix/expression/mixins/fromtupleexprmixin.py b/polymatrix/expression/mixins/fromtupleexprmixin.py index 19b3ed3..92b965c 100644 --- a/polymatrix/expression/mixins/fromtupleexprmixin.py +++ b/polymatrix/expression/mixins/fromtupleexprmixin.py @@ -100,7 +100,7 @@ class FromTupleExprMixin(ExpressionBaseMixin): polynomials[poly_row, poly_col] = polynomial poly_matrix = init_poly_matrix( - terms=polynomials, + data=polynomials, shape=(len(self.data), len(self.data[0])), ) diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py index 21d16fd..445e776 100644 --- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py +++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py @@ -106,7 +106,7 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin): ) poly_matrix = init_poly_matrix( - terms={(row, 0): {monom: 1} for row, monom in enumerate(monomials)}, + data={(row, 0): {monom: 1} for row, monom in enumerate(monomials)}, shape=(len(monomials), 1), ) diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py index 194e65e..aa3de58 100644 --- a/polymatrix/expression/mixins/legendreseriesmixin.py +++ b/polymatrix/expression/mixins/legendreseriesmixin.py @@ -36,27 +36,27 @@ class LegendreSeriesMixin(ExpressionBaseMixin): else: degrees = self.degrees - terms = {} + poly_matrix_data = {} for degree in degrees: # for degree in self.degree: poly = underlying.get_poly(degree, 0) - terms[degree, 0] = dict(poly) + poly_matrix_data[degree, 0] = dict(poly) if 2 <= degree: poly = underlying.get_poly(degree - 2, 0) factor = - (degree - 1) / (degree + 1) for m, v in poly.items(): - if m in terms[degree, 0]: - terms[degree, 0][m] += v*factor + if m in poly_matrix_data[degree, 0]: + poly_matrix_data[degree, 0][m] += v*factor else: - terms[degree, 0][m] = v*factor + poly_matrix_data[degree, 0][m] = v*factor poly_matrix = init_poly_matrix( - terms=terms, - shape=(len(terms), 1), + data=poly_matrix_data, + shape=(len(poly_matrix_data), 1), ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index ded61ec..f0b57d3 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -87,7 +87,7 @@ class LinearInExprMixin(ExpressionBaseMixin): terms[row, col][p_monomial] = value poly_matrix = init_poly_matrix( - terms=dict(terms), + data=dict(terms), shape=(underlying.shape[0], len(monomials)), ) diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index 4d25a1e..3bb8dfe 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -53,7 +53,7 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin): break poly_matrix = init_poly_matrix( - terms=dict(terms), + data=dict(terms), shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index ed9e79f..c7de7f1 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -62,14 +62,14 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): linear_monomials = sort_monomials(set(gen_linear_monomials())) - def gen_terms(): + def gen_data(): for index, monomial in enumerate(linear_monomials): yield (index, 0), {monomial: 1.0} - terms = dict(gen_terms()) + poly_matrix_data = dict(gen_data()) poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(len(linear_monomials), 1), ) diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index b7ae5ce..22a6c94 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -42,30 +42,30 @@ class MatrixMultExprMixin(ExpressionBaseMixin): stack=self.stack, )) - terms = {} + poly_matrix_data = {} for poly_row in range(left.shape[0]): for poly_col in range(right.shape[1]): - terms_row_col = {} + polynomial = {} for index_k in range(left.shape[1]): - left_terms = left.get_poly(poly_row, index_k) - if left_terms is None: + left_polynomial = left.get_poly(poly_row, index_k) + if left_polynomial is None: continue - right_terms = right.get_poly(index_k, poly_col) - if right_terms is None: + right_polynomial = right.get_poly(index_k, poly_col) + if right_polynomial is None: continue - multiply_polynomial(left_terms, right_terms, terms_row_col) + multiply_polynomial(left_polynomial, right_polynomial, polynomial) - if 0 < len(terms_row_col): - terms[poly_row, poly_col] = terms_row_col + if 0 < len(polynomial): + poly_matrix_data[poly_row, poly_col] = polynomial poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(left.shape[0], right.shape[1]), ) diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index bfabc4d..fa1155e 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -21,7 +21,7 @@ class MaxExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state) - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): @@ -37,10 +37,10 @@ class MaxExprMixin(ExpressionBaseMixin): values = tuple(gen_values()) if 0 < len(values): - terms[row, 0] = {tuple(): max(values)} + poly_matrix_data[row, 0] = {tuple(): max(values)} poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(underlying.shape[0], 1), ) diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index 712ff3f..0160b9a 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -50,14 +50,14 @@ class ParametrizeExprMixin(ExpressionBaseMixin): # assert underlying.shape[1] == 1 - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): var_index = start + row - terms[row, 0] = {((var_index, 1),): 1} + poly_matrix_data[row, 0] = {((var_index, 1),): 1} poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py index 1aa393b..c51bcf0 100644 --- a/polymatrix/expression/mixins/parametrizematrixexprmixin.py +++ b/polymatrix/expression/mixins/parametrizematrixexprmixin.py @@ -33,13 +33,13 @@ class ParametrizeMatrixExprMixin(ExpressionBaseMixin): assert underlying.shape[1] == 1 - terms = {} + poly_matrix_data = {} var_index = 0 for row in range(underlying.shape[0]): for _ in range(row, underlying.shape[0]): - terms[var_index, 0] = {((state.n_param + var_index, 1),): 1.0} + poly_matrix_data[var_index, 0] = {((state.n_param + var_index, 1),): 1.0} var_index += 1 @@ -49,7 +49,7 @@ class ParametrizeMatrixExprMixin(ExpressionBaseMixin): ) poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(var_index, 1), ) diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py index 0168e27..b41865b 100644 --- a/polymatrix/expression/mixins/productexprmixin.py +++ b/polymatrix/expression/mixins/productexprmixin.py @@ -1,6 +1,7 @@ import abc import itertools +from polymatrix.polymatrix.typing import PolynomialData from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial from polymatrix.utils.getstacklines import FrameSummary @@ -31,23 +32,9 @@ class ProductExprMixin(ExpressionBaseMixin): self, state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: - - # if self.number == 0: - # terms = {(0, 0): {tuple(): 1.0}} - - # poly_matrix = init_poly_matrix( - # terms=terms, - # shape=(1, 1), - # ) - - # elif self.number == 1: - # state, monomials = self.monomials.apply(state=state) - # poly_matrix = monomials - - # else: if len(self.underlying) == 0: - terms = {(0,0): {tuple(): 1}} + poly_matrix_data = {(0,0): {tuple(): 1}} else: @@ -65,23 +52,6 @@ class ProductExprMixin(ExpressionBaseMixin): initial=(state, tuple()) ) - # highest_degrees = tuple(e.shape[0] for e in underlying) - - # if self.degrees is None: - # degrees = range(sum(highest_degrees)) - - # else: - # degrees = self.degrees - - # max_degree = max(degrees) - - # for poly_matrix in underlying: - # if not (max_degree <= poly_matrix.shape[0]): - # raise AssertionError(to_operator_exception( - # message=f'{poly_matrix.shape[0]} < {max_degree}', - # stack=self.stack, - # )) - def gen_indices(): product_indices = itertools.product(*(range(e.shape[0]) for e in underlying)) @@ -92,26 +62,16 @@ class ProductExprMixin(ExpressionBaseMixin): yield from filter(lambda v: sum(v) in self.degrees, product_indices) indices = tuple(gen_indices()) - # print(indices) - - # indices = filter(lambda v: sum(v) <= self.degree, itertools.product(*(range(self.degree) for _ in range(dim)))) - terms = {} + poly_matrix_data = {} for row, indexing in enumerate(indices): - # def acc_product(acc, v): - # left_monomials = acc - # polymatrix, row = v - - # right_monomials = polymatrix.get_poly(row, 0).keys() - - # if left_monomials is (None,): - # return right_monomials - - # return tuple(multiply_monomials(left_monomials, right_monomials)) - - def acc_product(left, v): + def acc_product( + left: PolynomialData, + v: tuple[ExpressionBaseMixin, int], + ) -> PolynomialData: + poly_matrix, row = v right = poly_matrix.get_poly(row, 0) @@ -129,11 +89,11 @@ class ProductExprMixin(ExpressionBaseMixin): initial={}, ) - terms[row, 0] = polynomial + poly_matrix_data[row, 0] = polynomial poly_matrix = init_poly_matrix( - terms=terms, - shape=(len(terms), 1), + data=poly_matrix_data, + shape=(len(poly_matrix_data), 1), ) return state, poly_matrix diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index e739364..2cc9e12 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -87,7 +87,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): terms[row, col][p_monomial] += value poly_matrix = init_poly_matrix( - terms=dict((k, dict(v)) for k, v in terms.items()), + data=dict((k, dict(v)) for k, v in terms.items()), shape=2*(len(sos_monomials),), ) diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index 99befc8..3a7de30 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -65,14 +65,14 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): sos_monomials = tuple(sorted(set(gen_sos_monomials()), key=lambda m: (len(m), m))) - def gen_terms(): + def gen_data(): for index, monomial in enumerate(sos_monomials): yield (index, 0), {monomial: 1.0} - terms = dict(gen_terms()) + poly_matrix_data = dict(gen_data()) poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(len(sos_monomials), 1), ) diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py index eed3d71..888300f 100644 --- a/polymatrix/expression/mixins/squeezeexprmixin.py +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -23,7 +23,7 @@ class SqueezeExprMixin(ExpressionBaseMixin): assert underlying.shape[1] == 1 - terms = {} + poly_matrix_data = {} row_index = 0 for row in range(underlying.shape[0]): @@ -32,18 +32,18 @@ class SqueezeExprMixin(ExpressionBaseMixin): if polynomial is None: continue - terms_row_col = {} + polynomial = {} for monomial, value in polynomial.items(): if value != 0.0: - terms_row_col[monomial] = value + polynomial[monomial] = value - if len(terms_row_col): - terms[row_index, 0] = terms_row_col + if len(polynomial): + poly_matrix_data[row_index, 0] = polynomial row_index += 1 poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(row_index, 1), ) diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index b75a2de..2bc36d1 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -67,7 +67,7 @@ class SubstituteExprMixin(ExpressionBaseMixin): else: assert len(variable_indices) == len(substitutions), f'{substitutions=}' - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): @@ -76,11 +76,11 @@ class SubstituteExprMixin(ExpressionBaseMixin): if polynomial is None: continue - terms_row_col = collections.defaultdict(float) + polynomial = collections.defaultdict(float) for monomial, value in polynomial.items(): - terms_monomial = {tuple(): value} + substituted_monomial = {tuple(): value} for variable, count in monomial: if variable in variable_indices: @@ -90,24 +90,24 @@ class SubstituteExprMixin(ExpressionBaseMixin): for _ in range(count): next = {} - multiply_polynomial(terms_monomial, substitution, next) - terms_monomial = next + multiply_polynomial(substituted_monomial, substitution, next) + substituted_monomial = next else: next = {} - multiply_polynomial(terms_monomial, {((variable, count),): 1.0}, next) - terms_monomial = next + multiply_polynomial(substituted_monomial, {((variable, count),): 1.0}, next) + substituted_monomial = next - for monomial, value in terms_monomial.items(): - terms_row_col[monomial] += value + for monomial, value in substituted_monomial.items(): + polynomial[monomial] += value - terms_row_col = {key: val for key, val in terms_row_col.items() if not math.isclose(val, 0, abs_tol=1e-12)} + polynomial = {key: val for key, val in polynomial.items() if not math.isclose(val, 0, abs_tol=1e-12)} - if 0 < len(terms_row_col): - terms[row, col] = terms_row_col + if 0 < len(polynomial): + poly_matrix_data[row, col] = polynomial poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index 9275332..a01b5e7 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -44,10 +44,10 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin): remainders = sort_monomials(set(gen_remainders())) - terms = {(row, 0): {remainder: 1.0} for row, remainder in enumerate(remainders)} + data = {(row, 0): {remainder: 1.0} for row, remainder in enumerate(remainders)} poly_matrix = init_poly_matrix( - terms=terms, + data=data, shape=(len(remainders), 1), ) diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py index 812b260..83f9397 100644 --- a/polymatrix/expression/mixins/sumexprmixin.py +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -29,7 +29,7 @@ class SumExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state) - terms = collections.defaultdict(dict) + poly_matrix_data = collections.defaultdict(dict) for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): @@ -38,16 +38,16 @@ class SumExprMixin(ExpressionBaseMixin): if polynomial is None: continue - term_monomials = terms[row, 0] + polynomial = poly_matrix_data[row, 0] for monomial, value in polynomial.items(): - if monomial in term_monomials: - term_monomials[monomial] += value + if monomial in polynomial: + polynomial[monomial] += value else: - term_monomials[monomial] = value + polynomial[monomial] = value poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(underlying.shape[0], 1), ) diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py index b8b3d64..c3dc4e2 100644 --- a/polymatrix/expression/mixins/toconstantexprmixin.py +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -39,7 +39,7 @@ class ToConstantExprMixin(ExpressionBaseMixin): terms[row, col][tuple()] = polynomial[tuple()] poly_matrix = init_poly_matrix( - terms=dict(terms), + data=dict(terms), shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py index d10fcd9..c6b692b 100644 --- a/polymatrix/expression/mixins/toquadraticexprmixin.py +++ b/polymatrix/expression/mixins/toquadraticexprmixin.py @@ -9,7 +9,6 @@ from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.expressionstate.abc import ExpressionState -# to be deleted? class ToQuadraticExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod @@ -25,17 +24,17 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): state = [state] - terms = {} + poly_matrix_data = {} auxillary_equations_from_quadratic = {} def to_quadratic(monomial_terms): - terms_row_col = collections.defaultdict(float) + polynomial = collections.defaultdict(float) for monomial, value in monomial_terms.items(): if 2 < len(monomial): current_aux = state[0].n_param - terms_row_col[(monomial[0], current_aux)] += value + polynomial[(monomial[0], current_aux)] += value state[0] = state[0].register(n_param=1) for variable in monomial[1:-2]: @@ -52,10 +51,10 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): } else: - terms_row_col[monomial] += value + polynomial[monomial] += value # return dict(terms_row_col) - return terms_row_col + return polynomial for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): @@ -64,43 +63,18 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): if polynomial is None: continue - terms_row_col = to_quadratic( + polynomial = to_quadratic( monomial_terms=polynomial, ) - # terms_row_col = collections.defaultdict(float) - - # for monomial, value in underlying_terms.items(): - - # if 2 < len(monomial): - # current_aux = state.n_param - # terms_row_col[(monomial[0], current_aux)] += value - # state = state.register(n_param=1) - - # for variable in monomial[1:-2]: - # auxillary_equations[current_aux] = { - # (variable, current_aux + 1): 1, - # (current_aux,): -1, - # } - # state = state.register(n_param=1) - # current_aux += 1 - - # auxillary_equations[current_aux] = { - # (monomial[-2], monomial[-1]): 1, - # (current_aux,): -1, - # } - - # else: - # terms_row_col[monomial] += value - - terms[row, col] = terms_row_col + poly_matrix_data[row, col] = polynomial def gen_auxillary_equations(): for key, monomial_terms in state[0].auxillary_equations.items(): - terms_row_col = to_quadratic( + polynomial = to_quadratic( monomial_terms=monomial_terms, ) - yield key, terms_row_col + yield key, polynomial state = dataclasses.replace( state[0], @@ -108,7 +82,7 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): ) poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=underlying.shape, ) diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py index a61a7e9..32e2162 100644 --- a/polymatrix/expression/mixins/tosortedvariablesmixin.py +++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py @@ -31,7 +31,7 @@ class ToSortedVariablesExprMixin(ExpressionBaseMixin): yield (row, 0), {((index, 1),): 1} poly_matrix = init_poly_matrix( - terms=dict(gen_sorted_vector()), + data=dict(gen_sorted_vector()), shape=(len(variable_indices), 1), ) diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py index 8801147..488c991 100644 --- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py +++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py @@ -44,20 +44,20 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin): n_row = invert_binomial_coefficient(underlying.shape[0]) - terms = {} + poly_matrix_data = {} var_index = 0 for row in range(n_row): for col in range(row, n_row): - terms[row, col] = underlying.get_poly(var_index, 0) + poly_matrix_data[row, col] = underlying.get_poly(var_index, 0) if row != col: - terms[col, row] = terms[row, col] + poly_matrix_data[col, row] = poly_matrix_data[row, col] var_index += 1 poly_matrix = init_poly_matrix( - terms=terms, + data=poly_matrix_data, shape=(n_row, n_row), ) diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index fdf3921..2fcb970 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -43,7 +43,7 @@ class TruncateExprMixin(ExpressionBaseMixin): state, variable_indices = get_variable_indices_from_variable(state, self.variables) cond = lambda idx: idx in variable_indices - terms = {} + poly_matrix_data = {} for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): @@ -52,19 +52,19 @@ class TruncateExprMixin(ExpressionBaseMixin): if polynomial is None: continue - terms_row_col = {} + polynomial = {} for monomial, value in polynomial.items(): degree = sum((count for var_idx, count in monomial if cond(var_idx))) if (degree in self.degrees) is not self.inverse: - terms_row_col[monomial] = value + polynomial[monomial] = value - terms[row, col] = terms_row_col + poly_matrix_data[row, col] = polynomial poly_matrix = init_poly_matrix( - terms=dict(terms), + data=dict(poly_matrix_data), shape=underlying.shape, ) diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index b02bb11..cebc75a 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -31,7 +31,7 @@ def to_constant( A = np.zeros(underlying.shape, dtype=np.double) - for (row, col), polynomial in underlying.gen_terms(): + for (row, col), polynomial in underlying.gen_data(): for monomial, value in polynomial.items(): if len(monomial) == 0: A[row, col] = value @@ -86,7 +86,7 @@ def to_sympy( A = np.zeros(underlying.shape, dtype=object) - for (row, col), polynomial in underlying.gen_terms(): + for (row, col), polynomial in underlying.gen_data(): sympy_polynomial = 0 diff --git a/polymatrix/expression/utils/broadcastpolymatrix.py b/polymatrix/expression/utils/broadcastpolymatrix.py new file mode 100644 index 0000000..e8b3ff3 --- /dev/null +++ b/polymatrix/expression/utils/broadcastpolymatrix.py @@ -0,0 +1,34 @@ +from polymatrix.polymatrix.init import init_broadcast_poly_matrix, init_poly_matrix +from polymatrix.utils.getstacklines import FrameSummary +from polymatrix.utils.tooperatorexception import to_operator_exception +from polymatrix.polymatrix.abc import PolyMatrix + + +def broadcast_poly_matrix( + left: PolyMatrix, + right: PolyMatrix, + stack: tuple[FrameSummary], +) -> PolyMatrix: + + # broadcast left + if left.shape == (1, 1) and right.shape != (1, 1): + left = init_broadcast_poly_matrix( + data=left.get_poly(0, 0), + shape=right.shape, + ) + + # broadcast right + elif left.shape != (1, 1) and right.shape == (1, 1): + right = init_broadcast_poly_matrix( + data=right.get_poly(0, 0), + shape=left.shape, + ) + + else: + if not (left.shape == right.shape): + raise AssertionError(to_operator_exception( + message=f'{left.shape} != {right.shape}', + stack=stack, + )) + + return left, right diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py index f957685..04a7872 100644 --- a/polymatrix/polymatrix/impl.py +++ b/polymatrix/polymatrix/impl.py @@ -7,11 +7,11 @@ from polymatrix.polymatrix.typing import PolynomialData @dataclassabc.dataclassabc(frozen=True) class PolyMatrixImpl(PolyMatrix): - terms: dict[tuple[int, int], PolynomialData] + data: dict[tuple[int, int], PolynomialData] shape: tuple[int, int] @dataclassabc.dataclassabc(frozen=True) class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin): - polynomial: tuple[tuple[int], float] + data: tuple[tuple[int], float] shape: tuple[int, int] diff --git a/polymatrix/polymatrix/init.py b/polymatrix/polymatrix/init.py index 2dc2a63..fa7bd20 100644 --- a/polymatrix/polymatrix/init.py +++ b/polymatrix/polymatrix/init.py @@ -1,13 +1,24 @@ -from polymatrix.polymatrix.impl import PolyMatrixImpl +from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl, PolyMatrixImpl from polymatrix.polymatrix.typing import PolynomialData def init_poly_matrix( - terms: dict[tuple[int, int], PolynomialData], + data: dict[tuple[int, int], PolynomialData], shape: tuple[int, int], ): return PolyMatrixImpl( - terms=terms, + data=data, + shape=shape, +) + + +def init_broadcast_poly_matrix( + data: PolynomialData, + shape: tuple[int, int], +): + + return BroadcastPolyMatrixImpl( + data=data, shape=shape, ) diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py index 5046492..0ce4dad 100644 --- a/polymatrix/polymatrix/mixins.py +++ b/polymatrix/polymatrix/mixins.py @@ -10,7 +10,7 @@ class PolyMatrixMixin(abc.ABC): def shape(self) -> tuple[int, int]: ... - def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]: + def gen_data(self) -> typing.Generator[tuple[tuple[int, int], PolynomialData], None, None]: for row in range(self.shape[0]): for col in range(self.shape[1]): polynomial = self.get_poly(row, col) @@ -30,13 +30,13 @@ class PolyMatrixAsDictMixin( ): @property @abc.abstractmethod - def terms(self) -> dict[tuple[int, int], PolynomialData]: + def data(self) -> dict[tuple[int, int], PolynomialData]: ... # overwrites the abstract method of `PolyMatrixMixin` def get_poly(self, row: int, col: int) -> PolynomialData | None: - if (row, col) in self.terms: - return self.terms[row, col] + if (row, col) in self.data: + return self.data[row, col] class BroadcastPolyMatrixMixin( @@ -45,9 +45,9 @@ class BroadcastPolyMatrixMixin( ): @property @abc.abstractmethod - def polynomial(self) -> PolynomialData: + def data(self) -> PolynomialData: ... # overwrites the abstract method of `PolyMatrixMixin` def get_poly(self, col: int, row: int) -> PolynomialData | None: - return self.polynomial + return self.data |