diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-19 15:59:54 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-19 15:59:54 +0200 |
commit | 52b1615022730b3ceb4b3ced0f6426ab6c764328 (patch) | |
tree | 811ff4f94c179e8f865eb5a8127716018f731d9f | |
parent | add function 'get_key_from_offset' (diff) | |
download | polymatrix-52b1615022730b3ceb4b3ced0f6426ab6c764328.tar.gz polymatrix-52b1615022730b3ceb4b3ced0f6426ab6c764328.zip |
'get_poly' returns None if (row, col) entry is empty, instead of raising a KeyError exception
39 files changed, 195 insertions, 220 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 78f27b8..d0673e4 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -112,9 +112,8 @@ def kkt_equality( for eq_idx, nu_variable in enumerate(nu_variables): - try: - underlying_terms = equality_der.get_poly(eq_idx, row) - except KeyError: + underlying_terms = equality_der.get_poly(eq_idx, row) + if underlying_terms is None: continue for monomial, value in underlying_terms.items(): @@ -122,7 +121,8 @@ def kkt_equality( monomial_terms[new_monomial] += value - terms[row, 0] = dict(monomial_terms) + # terms[row, 0] = dict(monomial_terms) + terms[row, 0] = monomial_terms cost_expr = init_expression(init_from_terms_expr( terms=terms, @@ -182,17 +182,17 @@ def kkt_inequality( for inequality_idx, lambda_variable in enumerate(lambda_variables): - try: - underlying_terms = inequality_der.get_poly(inequality_idx, row) - except KeyError: + polynomial = inequality_der.get_poly(inequality_idx, row) + if polynomial is None: continue - for monomial, value in underlying_terms.items(): + for monomial, value in polynomial.items(): new_monomial = monomial + (lambda_variable,) monomial_terms[new_monomial] += value - terms[row, 0] = dict(monomial_terms) + # terms[row, 0] = dict(monomial_terms) + terms[row, 0] = monomial_terms cost_expr = init_expression(init_from_terms_expr( terms=terms, @@ -210,13 +210,12 @@ def kkt_inequality( r_lambda = lambda_variable + 1 r_inequality = lambda_variable + 2 - try: - underlying_terms = inequality.get_poly(inequality_idx, 0) - except KeyError: + polynomial = inequality.get_poly(inequality_idx, 0) + if polynomial is None: continue # f(x) <= -0.01 - inequality_terms[inequality_idx, 0] = underlying_terms | {(r_inequality, r_inequality): 1} + inequality_terms[inequality_idx, 0] = polynomial | {(r_inequality, r_inequality): 1} # dual feasibility, lambda >= 0 feasibility_terms[inequality_idx, 0] = {(lambda_variable,): 1, (r_lambda, r_lambda): -1} @@ -268,12 +267,11 @@ def rows( terms = {} for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial == None: continue - terms[0, col] = underlying_terms + terms[0, col] = polynomial yield init_expression(underlying=init_from_terms_expr( terms=terms, @@ -481,6 +479,8 @@ def to_matrix_repr( state, ordered_variable_index = get_variable_indices(state, variables) + assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables' + variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)} n_param = len(ordered_variable_index) @@ -497,21 +497,20 @@ def to_matrix_repr( ) for row in range(n_row): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: + underlying_terms = underlying.get_poly(row, 0) + if underlying_terms is None: continue for monomial, value in underlying_terms.items(): def gen_new_monomial(): for var, count in monomial: try: - new_variable = variable_index_map[var] + index = variable_index_map[var] except KeyError: raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}') for _ in range(count): - yield new_variable + yield index new_monomial = tuple(gen_new_monomial()) @@ -580,7 +579,7 @@ def to_constant_repr( A = np.zeros(underlying.shape, dtype=np.double) - for (row, col), polynomial in underlying.get_terms(): + for (row, col), polynomial in underlying.gen_terms(): for monomial, value in polynomial.items(): if len(monomial) == 0: A[row, col] = value @@ -589,7 +588,7 @@ def to_constant_repr( return init_state_monad(func) -def to_sympy_expr( +def to_sympy_repr( expr: Expression, ) -> StateMonadMixin[ExpressionState, sympy.Expr]: @@ -598,7 +597,7 @@ def to_sympy_expr( A = np.zeros(underlying.shape, dtype=np.object) - for (row, col), polynomial in underlying.get_terms(): + for (row, col), polynomial in underlying.gen_terms(): sympy_polynomial = 0 diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py index f338c92..8083eca 100644 --- a/polymatrix/expression/init/initfromtermsexpr.py +++ b/polymatrix/expression/init/initfromtermsexpr.py @@ -9,15 +9,15 @@ def init_from_terms_expr( ): if isinstance(terms, PolyMatrixMixin): shape = terms.shape - terms = terms.get_terms() + gen_terms = terms.gen_terms() else: assert shape is not None + gen_terms = terms - if isinstance(terms, dict): - terms = tuple((key, tuple(value.items())) for key, value in terms.items()) + terms_formatted = tuple((key, tuple(monomials.items())) for key, monomials in gen_terms) return FromTermsExprImpl( - terms=terms, + terms=terms_formatted, shape=shape, ) diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py index 63e1123..db2f2b8 100644 --- a/polymatrix/expression/init/initsubstituteexpr.py +++ b/polymatrix/expression/init/initsubstituteexpr.py @@ -1,5 +1,6 @@ import numpy as np +from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin from polymatrix.expression.impl.substituteexprimpl import SubstituteExprImpl @@ -20,8 +21,16 @@ def init_substitute_expr( elif not isinstance(substitutions, tuple): substitutions = (substitutions,) + def gen_substitutions(): + for substitution in substitutions: + match substitution: + case ExpressionBaseMixin(): + yield substitution + case _: + yield init_from_sympy_expr(substitution) + return SubstituteExprImpl( underlying=underlying, variables=variables, - substitutions=substitutions, + substitutions=tuple(gen_substitutions()), ) diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py index 3d2d15b..c4f3113 100644 --- a/polymatrix/expression/mixins/additionexprmixin.py +++ b/polymatrix/expression/mixins/additionexprmixin.py @@ -1,5 +1,6 @@ import abc +import collections import typing import dataclass_abc @@ -29,8 +30,6 @@ class AdditionExprMixin(ExpressionBaseMixin): state, left = self.left.apply(state=state) state, right = self.right.apply(state=state) - terms = {} - if left.shape == (1, 1): left, right = right, left @@ -44,15 +43,12 @@ class AdditionExprMixin(ExpressionBaseMixin): def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: return self.underlying_monomials - try: - underlying_terms = right.get_poly(0, 0) + polynomial = right.get_poly(0, 0) - except KeyError: - pass + if polynomial is not None: - else: broadcasted_right = BroadCastedPolyMatrix( - underlying_monomials=underlying_terms, + underlying_monomials=polynomial, shape=left.shape, ) @@ -63,30 +59,31 @@ class AdditionExprMixin(ExpressionBaseMixin): all_underlying = (left, right) - for underlying in all_underlying: + terms = {} - for row in range(left.shape[0]): - for col in range(left.shape[1]): - - if (row, col) in terms: - terms_row_col = terms[row, col] + for row in range(left.shape[0]): + for col in range(left.shape[1]): - else: - terms_row_col = {} + terms_row_col = {} - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: - continue + for underlying in all_underlying: - for monomial, value in underlying_terms.items(): - - if monomial not in terms_row_col: - terms_row_col[monomial] = 0 + polynomial = underlying.get_poly(row, col) + if polynomial is None: + continue - terms_row_col[monomial] += value + if len(terms_row_col) == 0: + terms_row_col = dict(polynomial) - terms[row, col] = terms_row_col + else: + for monomial, value in polynomial.items(): + + if monomial not in terms_row_col: + terms_row_col[monomial] = value + else: + terms_row_col[monomial] += value + + terms[row, col] = terms_row_col poly_matrix = init_poly_matrix( terms=terms, diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py index 8775fde..eff8de8 100644 --- a/polymatrix/expression/mixins/blockdiagexprmixin.py +++ b/polymatrix/expression/mixins/blockdiagexprmixin.py @@ -45,7 +45,7 @@ class BlockDiagExprMixin(ExpressionBaseMixin): ) else: - raise KeyError() + return None raise Exception(f'row {row} is out of bounds') diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py index ef99644..fe6dbb7 100644 --- a/polymatrix/expression/mixins/cacheexprmixin.py +++ b/polymatrix/expression/mixins/cacheexprmixin.py @@ -1,6 +1,7 @@ import abc import dataclasses +from polymatrix.polymatrix.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -25,7 +26,10 @@ class CacheExprMixin(ExpressionBaseMixin): state, underlying = self.underlying.apply(state) - cached_terms = dict(underlying.get_terms()) + if isinstance(underlying, PolyMatrixAsDictMixin): + cached_terms = underlying.terms + else: + cached_terms = dict(underlying.gen_terms()) poly_matrix = init_poly_matrix( terms=cached_terms, diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py index 87d5f5b..5fce215 100644 --- a/polymatrix/expression/mixins/derivativeexprmixin.py +++ b/polymatrix/expression/mixins/derivativeexprmixin.py @@ -41,9 +41,8 @@ class DerivativeExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: + underlying_terms = underlying.get_poly(row, 0) + if underlying_terms is None: continue # derivate each variable and map result to the corresponding column diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py index 391a221..3002109 100644 --- a/polymatrix/expression/mixins/divergenceexprmixin.py +++ b/polymatrix/expression/mixins/divergenceexprmixin.py @@ -38,9 +38,8 @@ class DivergenceExprMixin(ExpressionBaseMixin): for row, variable in enumerate(variables): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: + underlying_terms = underlying.get_poly(row, 0) + if underlying_terms is None: continue state, derivation_terms = get_derivative_monomials( @@ -55,7 +54,8 @@ class DivergenceExprMixin(ExpressionBaseMixin): monomial_terms[monomial] += value poly_matrix = init_poly_matrix( - terms={(0, 0): dict(monomial_terms)}, + # terms={(0, 0): dict(monomial_terms)}, + terms={(0, 0): monomial_terms}, shape=(1, 1), ) diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py index ee9b2ff..d84c09a 100644 --- a/polymatrix/expression/mixins/divisionexprmixin.py +++ b/polymatrix/expression/mixins/divisionexprmixin.py @@ -42,9 +42,8 @@ class DivisionExprMixin(ExpressionBaseMixin): for row in range(left.shape[0]): for col in range(left.shape[1]): - try: - underlying_terms = left.get_poly(row, col) - except KeyError: + underlying_terms = left.get_poly(row, col) + if underlying_terms is None: continue def gen_monomial_terms(): diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py index 24e91fc..408b8c8 100644 --- a/polymatrix/expression/mixins/elemmultexprmixin.py +++ b/polymatrix/expression/mixins/elemmultexprmixin.py @@ -97,14 +97,12 @@ class ElemMultExprMixin(ExpressionBaseMixin): terms_row_col = {} - try: - left_terms = left.get_poly(poly_row, poly_col) - except KeyError: + left_terms = left.get_poly(poly_row, poly_col) + if left_terms is None: continue - try: - right_terms = right.get_poly(poly_row, poly_col) - except KeyError: + right_terms = right.get_poly(poly_row, poly_col) + if right_terms is None: continue for (left_monomial, left_value), (right_monomial, right_value) \ diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py index ca1cf7d..2358566 100644 --- a/polymatrix/expression/mixins/evalexprmixin.py +++ b/polymatrix/expression/mixins/evalexprmixin.py @@ -47,9 +47,8 @@ class EvalExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + underlying_terms = underlying.get_poly(row, col) + if underlying_terms is None: continue terms_row_col = {} diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py index cf4a674..5602294 100644 --- a/polymatrix/expression/mixins/expressionmixin.py +++ b/polymatrix/expression/mixins/expressionmixin.py @@ -27,6 +27,7 @@ from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr from polymatrix.expression.init.initreshapeexpr import init_reshape_expr from polymatrix.expression.init.initsetelementatexpr import init_set_element_at_expr from polymatrix.expression.init.initquadraticmonomialsexpr import init_quadratic_monomials_expr +from polymatrix.expression.init.initsubstituteexpr import init_substitute_expr from polymatrix.expression.init.initsubtractmonomialsexpr import init_subtract_monomials_expr from polymatrix.expression.init.initsumexpr import init_sum_expr from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr @@ -61,7 +62,7 @@ class ExpressionMixin( match other: case ExpressionBaseMixin(): - right = other.underlying + right = other case _: right = init_from_sympy_expr(other) @@ -97,7 +98,7 @@ class ExpressionMixin( def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin': match other: case ExpressionBaseMixin(): - right = other.underlying + right = other case _: right = init_from_sympy_expr(other) @@ -114,7 +115,7 @@ class ExpressionMixin( match other: case ExpressionBaseMixin(): - right = other.underlying + right = other case _: right = init_from_sympy_expr(other) @@ -146,7 +147,7 @@ class ExpressionMixin( def __truediv__(self, other: ExpressionBaseMixin): match other: case ExpressionBaseMixin(): - right = other.underlying + right = other case _: right = init_from_sympy_expr(other) @@ -350,14 +351,6 @@ class ExpressionMixin( ), ) - # def squeeze(self): - # return dataclasses.replace( - # self, - # underlying=init_squeeze_expr( - # underlying=self.underlying, - # ), - # ) - def set_element_at( self, row: int, @@ -397,13 +390,23 @@ class ExpressionMixin( ) -> 'ExpressionMixin': return dataclasses.replace( self, - underlying=init_eval_expr( + underlying=init_substitute_expr( underlying=self.underlying, variables=variable, substitutions=substitutions, ), ) + def subs( + self, + variable: tuple, + substitutions: tuple['ExpressionMixin', ...] = None, + ) -> 'ExpressionMixin': + return self.substitute( + variable=variable, + substitutions=substitutions, + ) + def sum(self): return dataclasses.replace( self, diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py index f904feb..47e83f9 100644 --- a/polymatrix/expression/mixins/eyeexprmixin.py +++ b/polymatrix/expression/mixins/eyeexprmixin.py @@ -33,15 +33,15 @@ class EyeExprMixin(ExpressionBaseMixin): return {tuple(): 1.0} else: - raise KeyError() + return None else: raise Exception(f'{(row, col)=} is out of bounds') - value = variable.shape[0] + n_row = variable.shape[0] polymatrix = EyePolyMatrix( - shape=(value, value), + shape=(n_row, n_row), ) return state, polymatrix diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py index ad13814..ff2d333 100644 --- a/polymatrix/expression/mixins/filterexprmixin.py +++ b/polymatrix/expression/mixins/filterexprmixin.py @@ -44,9 +44,8 @@ class FilterExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: + underlying_terms = underlying.get_poly(row, 0) + if underlying_terms is None: continue predicator_value = predicator.get_poly(row, 0)[tuple()] diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py index 2a83742..6b80006 100644 --- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py +++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py @@ -29,7 +29,7 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): state, variables = self.variables.apply(state=state) def gen_variable_monomials(): - for _, term in variables.get_terms(): + for _, term in variables.gen_terms(): assert len(term) == 1, f'{term} should have only a single monomial' for monomial in term.keys(): @@ -42,9 +42,8 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + underlying_terms = underlying.get_poly(row, col) + if underlying_terms is None: continue monomial_terms = collections.defaultdict(float) @@ -68,7 +67,8 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin): break - terms[row, col] = dict(monomial_terms) + terms[row, col] = monomial_terms + # terms[row, col] = dict(monomial_terms) poly_matrix = init_poly_matrix( terms=terms, diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index f6953e1..8853022 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -28,12 +28,15 @@ class FromSympyExprMixin(ExpressionBaseMixin): try: poly = sympy.poly(poly_data) - except sympy.polys.polyerrors.GeneratorsNeeded: + except sympy.polys.polyerrors.GeneratorsNeeded: if not math.isclose(poly_data, 0): terms[poly_row, poly_col] = {tuple(): poly_data} continue + except ValueError: + raise ValueError(f'{poly_data=}') + for symbol in poly.gens: state = state.register(key=symbol, n_param=1) # print(f'{symbol}: {state.n_param}') diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py index 8fd9be6..9d1c220 100644 --- a/polymatrix/expression/mixins/fromtermsexprmixin.py +++ b/polymatrix/expression/mixins/fromtermsexprmixin.py @@ -12,7 +12,7 @@ from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin class FromTermsExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def terms(self) -> tuple[tuple[tuple[float], float], ...]: + def terms(self) -> tuple[tuple[tuple[tuple[int, int], ...], float], ...]: pass @property diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py index 1300903..cd94761 100644 --- a/polymatrix/expression/mixins/linearinexprmixin.py +++ b/polymatrix/expression/mixins/linearinexprmixin.py @@ -47,12 +47,11 @@ class LinearInExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: + polynomial = underlying.get_poly(row, 0) + if polynomial is None: continue - for monomial, value in underlying_terms.items(): + for monomial, value in polynomial.items(): x_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx in variable_indices) p_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx not in variable_indices) diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py index 1358e1e..d013722 100644 --- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py +++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py @@ -36,9 +36,8 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + underlying_terms = underlying.get_poly(row, col) + if underlying_terms is None: continue for monomial, value in underlying_terms.items(): diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py index f53669e..b309ccf 100644 --- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py @@ -34,9 +34,8 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - polynomial = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue for monomial in polynomial.keys(): diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py index d023f1e..a02b4f1 100644 --- a/polymatrix/expression/mixins/matrixmultexprmixin.py +++ b/polymatrix/expression/mixins/matrixmultexprmixin.py @@ -41,33 +41,18 @@ class MatrixMultExprMixin(ExpressionBaseMixin): for index_k in range(left.shape[1]): - try: - left_terms = left.get_poly(poly_row, index_k) - right_terms = right.get_poly(index_k, poly_col) - except KeyError: + left_terms = left.get_poly(poly_row, index_k) + if left_terms is None: continue - # for (left_monomial, left_value), (right_monomial, right_value) \ - # in itertools.product(left_terms.items(), right_terms.items()): - - # value = left_value * right_value - - # if value == 0: - # continue - - # # monomial = tuple(sorted(left_monomial + right_monomial)) - # monomial = merge_monomial_indices((left_monomial, right_monomial)) - - # if monomial not in terms_row_col: - # terms_row_col[monomial] = 0 - - # terms_row_col[monomial] += value + right_terms = right.get_poly(index_k, poly_col) + if right_terms is None: + continue multiply_polynomial(left_terms, right_terms, terms_row_col) if 0 < len(terms_row_col): terms[poly_row, poly_col] = terms_row_col - # terms[poly_row, poly_col] = {key: val for key, val in terms_row_col.items() if not math.isclose(val, 0, abs_tol=1e-12)} poly_matrix = init_poly_matrix( terms=terms, diff --git a/polymatrix/expression/mixins/maxdegreeexprmixin.py b/polymatrix/expression/mixins/maxdegreeexprmixin.py index 7502cbb..b21c157 100644 --- a/polymatrix/expression/mixins/maxdegreeexprmixin.py +++ b/polymatrix/expression/mixins/maxdegreeexprmixin.py @@ -25,9 +25,8 @@ class MaxDegreeExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + underlying_terms = underlying.get_poly(row, col) + if underlying_terms is None: continue def gen_degrees(): diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py index 21fd6da..2ab4cad 100644 --- a/polymatrix/expression/mixins/maxexprmixin.py +++ b/polymatrix/expression/mixins/maxexprmixin.py @@ -28,12 +28,11 @@ class MaxExprMixin(ExpressionBaseMixin): def gen_values(): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue - yield underlying_terms[tuple()] + yield polynomial[tuple()] values = tuple(gen_values()) diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py index ea71185..130ab79 100644 --- a/polymatrix/expression/mixins/quadraticinexprmixin.py +++ b/polymatrix/expression/mixins/quadraticinexprmixin.py @@ -24,7 +24,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin): @property @abc.abstractmethod - def variables(self) -> tuple: + def variables(self) -> ExpressionBaseMixin: ... # overwrites abstract method of `ExpressionBaseMixin` @@ -39,11 +39,10 @@ class QuadraticInExprMixin(ExpressionBaseMixin): assert underlying.shape == (1, 1), f'underlying shape is {underlying.shape}' - underlying_terms = underlying.get_poly(0, 0) - terms = collections.defaultdict(dict) + terms = collections.defaultdict(lambda: collections.defaultdict(float)) - for monomial, value in underlying_terms.items(): + for monomial, value in underlying.get_poly(0, 0).items(): x_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx in variable_indices) p_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx not in variable_indices) @@ -60,12 +59,11 @@ class QuadraticInExprMixin(ExpressionBaseMixin): except ValueError: raise ValueError(f'{right=} not in {sos_monomials=}') - monomial_terms = terms[row, col] - - if p_monomial not in monomial_terms: - monomial_terms[p_monomial] = 0 + # monomial_terms = terms[row, col] + # if p_monomial not in monomial_terms: + # monomial_terms[p_monomial] = 0 - monomial_terms[p_monomial] += value + terms[row, col][p_monomial] += value poly_matrix = init_poly_matrix( terms=dict(terms), diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py index c53741b..bddc321 100644 --- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py @@ -34,9 +34,8 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - polynomial = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue for monomial in polynomial.keys(): diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py index 2bb0bfe..a111d20 100644 --- a/polymatrix/expression/mixins/setelementatexprmixin.py +++ b/polymatrix/expression/mixins/setelementatexprmixin.py @@ -39,25 +39,24 @@ class SetElementAtExprMixin(ExpressionBaseMixin): state: ExpressionState, ) -> tuple[ExpressionState, PolyMatrix]: state, underlying = self.underlying.apply(state=state) - state, value_expr = self.value.apply(state=state) + state, polynomial_expr = self.value.apply(state=state) - assert value_expr.shape == (1, 1) + assert polynomial_expr.shape == (1, 1) - try: - value = value_expr.get_poly(0, 0) - except KeyError: - value = 0 + polynomial = polynomial_expr.get_poly(0, 0) + if polynomial is None: + polynomial = 0 @dataclass_abc.dataclass_abc(frozen=True) class SetElementAtPolyMatrix(PolyMatrixMixin): underlying: PolyMatrixMixin shape: tuple[int, int] index: tuple[int, int] - value: dict + polynomial: dict def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: if (row, col) == self.index: - return self.value + return self.polynomial else: return self.underlying.get_poly(row, col) @@ -65,6 +64,6 @@ class SetElementAtExprMixin(ExpressionBaseMixin): underlying=underlying, index=self.index, shape=underlying.shape, - value=value, + polynomial=polynomial, )
\ No newline at end of file diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py index 2f46e23..14cf1f3 100644 --- a/polymatrix/expression/mixins/squeezeexprmixin.py +++ b/polymatrix/expression/mixins/squeezeexprmixin.py @@ -31,14 +31,13 @@ class SqueezeExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): - try: - underlying_terms = underlying.get_poly(row, 0) - except KeyError: + polynomial = underlying.get_poly(row, 0) + if polynomial is None: continue terms_row_col = {} - for monomial, value in underlying_terms.items(): + for monomial, value in polynomial.items(): if value != 0.0: terms_row_col[monomial] = value diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py index 1ba8a1a..9741897 100644 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ b/polymatrix/expression/mixins/substituteexprmixin.py @@ -50,6 +50,9 @@ class SubstituteExprMixin(ExpressionBaseMixin): elif isinstance(substitution_expr, int) or isinstance(substitution_expr, float): polynomial = {tuple(): substitution_expr} + else: + raise Exception(f'{substitution_expr=} not recognized') + return state, result + (polynomial,) *_, (state, substitutions) = tuple(itertools.accumulate( @@ -63,14 +66,13 @@ class SubstituteExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue terms_row_col = collections.defaultdict(float) - for monomial, value in underlying_terms.items(): + for monomial, value in polynomial.items(): terms_monomial = {tuple(): value} diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py index 75124d9..9957f10 100644 --- a/polymatrix/expression/mixins/sumexprmixin.py +++ b/polymatrix/expression/mixins/sumexprmixin.py @@ -28,14 +28,13 @@ class SumExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue term_monomials = terms[row, 0] - for monomial, value in underlying_terms.items(): + for monomial, value in polynomial.items(): if monomial in term_monomials: term_monomials[monomial] += value else: diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py index 989cb0e..e41e889 100644 --- a/polymatrix/expression/mixins/symmetricexprmixin.py +++ b/polymatrix/expression/mixins/symmetricexprmixin.py @@ -38,17 +38,15 @@ class SymmetricExprMixin(ExpressionBaseMixin): def gen_symmetric_monomials(): for i_row, i_col in ((row, col), (col, row)): - try: - monomials = self.underlying.get_poly(i_row, i_col) - except: - pass - else: - yield monomials + polynomial = self.underlying.get_poly(i_row, i_col) + + if polynomial is not None: + yield polynomial all_monomials = tuple(gen_symmetric_monomials()) if len(all_monomials) == 0: - raise KeyError() + return None else: terms = collections.defaultdict(float) @@ -58,7 +56,8 @@ class SymmetricExprMixin(ExpressionBaseMixin): for monomial, value in monomials.items(): terms[monomial] += value / 2 - return dict(terms) + # return dict(terms) + return terms polymatrix = SymmetricPolyMatrix( underlying=underlying, diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py index 547b012..ca928c6 100644 --- a/polymatrix/expression/mixins/toconstantexprmixin.py +++ b/polymatrix/expression/mixins/toconstantexprmixin.py @@ -26,13 +26,12 @@ class ToConstantExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue - if tuple() in underlying_terms: - terms[row, col][tuple()] = underlying_terms[tuple()] + if tuple() in polynomial: + terms[row, col][tuple()] = polynomial[tuple()] poly_matrix = init_poly_matrix( terms=dict(terms), diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py index bbb0cbf..2bcaacb 100644 --- a/polymatrix/expression/mixins/toquadraticexprmixin.py +++ b/polymatrix/expression/mixins/toquadraticexprmixin.py @@ -54,18 +54,18 @@ class ToQuadraticExprMixin(ExpressionBaseMixin): else: terms_row_col[monomial] += value - return dict(terms_row_col) + # return dict(terms_row_col) + return terms_row_col for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue terms_row_col = to_quadratic( - monomial_terms=underlying_terms, + monomial_terms=polynomial, ) # terms_row_col = collections.defaultdict(float) diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py index 42f192e..e144ae5 100644 --- a/polymatrix/expression/mixins/truncateexprmixin.py +++ b/polymatrix/expression/mixins/truncateexprmixin.py @@ -42,14 +42,13 @@ class TruncateExprMixin(ExpressionBaseMixin): for row in range(underlying.shape[0]): for col in range(underlying.shape[1]): - try: - underlying_terms = underlying.get_poly(row, col) - except KeyError: + polynomial = underlying.get_poly(row, col) + if polynomial is None: continue terms_row_col = {} - for monomial, value in underlying_terms.items(): + for monomial, value in polynomial.items(): degree = sum((count for var_idx, count in monomial if var_idx in variable_indices)) diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py index cec2d81..4aa7144 100644 --- a/polymatrix/expression/utils/getderivativemonomials.py +++ b/polymatrix/expression/utils/getderivativemonomials.py @@ -72,8 +72,6 @@ def get_derivative_monomials( for monomial, value in monomial_terms.items(): - # # count powers for each variable - # monomial_cnt = dict(collections.Counter(monomial)) monomial_cnt = dict(monomial) def differentiate_monomial(dependent_variable, derivation_variable=None): @@ -117,4 +115,5 @@ def get_derivative_monomials( # ) # derivation_terms[diff_monomial] += value - return state, dict(derivation_terms)
\ No newline at end of file + # return state, dict(derivation_terms) + return state, derivation_terms
\ No newline at end of file diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index 08b44d3..aaf5b0b 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -35,4 +35,6 @@ def get_variable_indices(state, variables): else: yield global_state[0].offset_dict[variable][0] - return global_state[0], tuple(gen_indices()) + indices = tuple(gen_indices()) + + return global_state[0], indices diff --git a/polymatrix/expression/utils/splitmonomialindices.py b/polymatrix/expression/utils/splitmonomialindices.py index 82e5454..4ffa6ff 100644 --- a/polymatrix/expression/utils/splitmonomialindices.py +++ b/polymatrix/expression/utils/splitmonomialindices.py @@ -7,8 +7,7 @@ def split_monomial_indices(monomial): for idx, count in monomial: count_left = count // 2 - is_uneven = count % 2 - if is_uneven: + if count % 2: if is_left: count_left = count_left + 1 @@ -21,7 +20,5 @@ def split_monomial_indices(monomial): if 0 < count_right: right.append((idx, count - count_left)) - - # print((monomial, tuple(left), tuple(right))) return tuple(left), tuple(right)
\ No newline at end of file diff --git a/polymatrix/expression/mixins/polymatrixasdictmixin.py b/polymatrix/polymatrix/mixins/polymatrixasdictmixin.py index 69c59ac..93b6385 100644 --- a/polymatrix/expression/mixins/polymatrixasdictmixin.py +++ b/polymatrix/polymatrix/mixins/polymatrixasdictmixin.py @@ -1,4 +1,5 @@ import abc +import typing from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin @@ -13,8 +14,6 @@ class PolyMatrixAsDictMixin( ... # overwrites abstract method of `PolyMatrixMixin` - def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]: - try: + def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: + if (row, col) in self.terms: return self.terms[row, col] - except KeyError: - raise KeyError(f'{(row, col)} is not a key of {self.terms}') diff --git a/polymatrix/polymatrix/mixins/polymatrixmixin.py b/polymatrix/polymatrix/mixins/polymatrixmixin.py index bea80de..e67a6fa 100644 --- a/polymatrix/polymatrix/mixins/polymatrixmixin.py +++ b/polymatrix/polymatrix/mixins/polymatrixmixin.py @@ -8,18 +8,14 @@ class PolyMatrixMixin(abc.ABC): def shape(self) -> tuple[int, int]: ... - def get_terms(self) -> tuple[tuple[int, int], dict[tuple[int, ...], float]]: - def gen_terms(): - for row in range(self.shape[0]): - for col in range(self.shape[1]): - try: - monomial_terms = self.get_poly(row, col) - except KeyError: - continue - - yield (row, col), monomial_terms - - return tuple(gen_terms()) + def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], dict[tuple[int, ...], float]], None, None]: + for row in range(self.shape[0]): + for col in range(self.shape[1]): + polynomial = self.get_poly(row, col) + if polynomial is None: + continue + + yield (row, col), polynomial @abc.abstractclassmethod def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]: diff --git a/polymatrix/polymatrix/polymatrix.py b/polymatrix/polymatrix/polymatrix.py index c044081..b8523f6 100644 --- a/polymatrix/polymatrix/polymatrix.py +++ b/polymatrix/polymatrix/polymatrix.py @@ -1,4 +1,4 @@ -from polymatrix.expression.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin +from polymatrix.polymatrix.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin class PolyMatrix(PolyMatrixAsDictMixin): pass |