summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py49
-rw-r--r--polymatrix/expression/init/initfromtermsexpr.py8
-rw-r--r--polymatrix/expression/init/initsubstituteexpr.py11
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py49
-rw-r--r--polymatrix/expression/mixins/blockdiagexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/cacheexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/divergenceexprmixin.py8
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py29
-rw-r--r--polymatrix/expression/mixins/eyeexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/filterexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/filterlinearpartexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/fromsympyexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/fromtermsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/linearmatrixinexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/linearmonomialsexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py25
-rw-r--r--polymatrix/expression/mixins/maxdegreeexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/maxexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py16
-rw-r--r--polymatrix/expression/mixins/quadraticmonomialsexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/setelementatexprmixin.py17
-rw-r--r--polymatrix/expression/mixins/squeezeexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/sumexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/symmetricexprmixin.py15
-rw-r--r--polymatrix/expression/mixins/toconstantexprmixin.py9
-rw-r--r--polymatrix/expression/mixins/toquadraticexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/truncateexprmixin.py7
-rw-r--r--polymatrix/expression/utils/getderivativemonomials.py5
-rw-r--r--polymatrix/expression/utils/getvariableindices.py4
-rw-r--r--polymatrix/expression/utils/splitmonomialindices.py5
-rw-r--r--polymatrix/polymatrix/mixins/polymatrixasdictmixin.py (renamed from polymatrix/expression/mixins/polymatrixasdictmixin.py)7
-rw-r--r--polymatrix/polymatrix/mixins/polymatrixmixin.py20
-rw-r--r--polymatrix/polymatrix/polymatrix.py2
39 files changed, 195 insertions, 220 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 78f27b8..d0673e4 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -112,9 +112,8 @@ def kkt_equality(
for eq_idx, nu_variable in enumerate(nu_variables):
- try:
- underlying_terms = equality_der.get_poly(eq_idx, row)
- except KeyError:
+ underlying_terms = equality_der.get_poly(eq_idx, row)
+ if underlying_terms is None:
continue
for monomial, value in underlying_terms.items():
@@ -122,7 +121,8 @@ def kkt_equality(
monomial_terms[new_monomial] += value
- terms[row, 0] = dict(monomial_terms)
+ # terms[row, 0] = dict(monomial_terms)
+ terms[row, 0] = monomial_terms
cost_expr = init_expression(init_from_terms_expr(
terms=terms,
@@ -182,17 +182,17 @@ def kkt_inequality(
for inequality_idx, lambda_variable in enumerate(lambda_variables):
- try:
- underlying_terms = inequality_der.get_poly(inequality_idx, row)
- except KeyError:
+ polynomial = inequality_der.get_poly(inequality_idx, row)
+ if polynomial is None:
continue
- for monomial, value in underlying_terms.items():
+ for monomial, value in polynomial.items():
new_monomial = monomial + (lambda_variable,)
monomial_terms[new_monomial] += value
- terms[row, 0] = dict(monomial_terms)
+ # terms[row, 0] = dict(monomial_terms)
+ terms[row, 0] = monomial_terms
cost_expr = init_expression(init_from_terms_expr(
terms=terms,
@@ -210,13 +210,12 @@ def kkt_inequality(
r_lambda = lambda_variable + 1
r_inequality = lambda_variable + 2
- try:
- underlying_terms = inequality.get_poly(inequality_idx, 0)
- except KeyError:
+ polynomial = inequality.get_poly(inequality_idx, 0)
+ if polynomial is None:
continue
# f(x) <= -0.01
- inequality_terms[inequality_idx, 0] = underlying_terms | {(r_inequality, r_inequality): 1}
+ inequality_terms[inequality_idx, 0] = polynomial | {(r_inequality, r_inequality): 1}
# dual feasibility, lambda >= 0
feasibility_terms[inequality_idx, 0] = {(lambda_variable,): 1, (r_lambda, r_lambda): -1}
@@ -268,12 +267,11 @@ def rows(
terms = {}
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial == None:
continue
- terms[0, col] = underlying_terms
+ terms[0, col] = polynomial
yield init_expression(underlying=init_from_terms_expr(
terms=terms,
@@ -481,6 +479,8 @@ def to_matrix_repr(
state, ordered_variable_index = get_variable_indices(state, variables)
+ assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
+
variable_index_map = {old: new for new, old in enumerate(ordered_variable_index)}
n_param = len(ordered_variable_index)
@@ -497,21 +497,20 @@ def to_matrix_repr(
)
for row in range(n_row):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, 0)
+ if underlying_terms is None:
continue
for monomial, value in underlying_terms.items():
def gen_new_monomial():
for var, count in monomial:
try:
- new_variable = variable_index_map[var]
+ index = variable_index_map[var]
except KeyError:
raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}')
for _ in range(count):
- yield new_variable
+ yield index
new_monomial = tuple(gen_new_monomial())
@@ -580,7 +579,7 @@ def to_constant_repr(
A = np.zeros(underlying.shape, dtype=np.double)
- for (row, col), polynomial in underlying.get_terms():
+ for (row, col), polynomial in underlying.gen_terms():
for monomial, value in polynomial.items():
if len(monomial) == 0:
A[row, col] = value
@@ -589,7 +588,7 @@ def to_constant_repr(
return init_state_monad(func)
-def to_sympy_expr(
+def to_sympy_repr(
expr: Expression,
) -> StateMonadMixin[ExpressionState, sympy.Expr]:
@@ -598,7 +597,7 @@ def to_sympy_expr(
A = np.zeros(underlying.shape, dtype=np.object)
- for (row, col), polynomial in underlying.get_terms():
+ for (row, col), polynomial in underlying.gen_terms():
sympy_polynomial = 0
diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py
index f338c92..8083eca 100644
--- a/polymatrix/expression/init/initfromtermsexpr.py
+++ b/polymatrix/expression/init/initfromtermsexpr.py
@@ -9,15 +9,15 @@ def init_from_terms_expr(
):
if isinstance(terms, PolyMatrixMixin):
shape = terms.shape
- terms = terms.get_terms()
+ gen_terms = terms.gen_terms()
else:
assert shape is not None
+ gen_terms = terms
- if isinstance(terms, dict):
- terms = tuple((key, tuple(value.items())) for key, value in terms.items())
+ terms_formatted = tuple((key, tuple(monomials.items())) for key, monomials in gen_terms)
return FromTermsExprImpl(
- terms=terms,
+ terms=terms_formatted,
shape=shape,
)
diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py
index 63e1123..db2f2b8 100644
--- a/polymatrix/expression/init/initsubstituteexpr.py
+++ b/polymatrix/expression/init/initsubstituteexpr.py
@@ -1,5 +1,6 @@
import numpy as np
+from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.substituteexprimpl import SubstituteExprImpl
@@ -20,8 +21,16 @@ def init_substitute_expr(
elif not isinstance(substitutions, tuple):
substitutions = (substitutions,)
+ def gen_substitutions():
+ for substitution in substitutions:
+ match substitution:
+ case ExpressionBaseMixin():
+ yield substitution
+ case _:
+ yield init_from_sympy_expr(substitution)
+
return SubstituteExprImpl(
underlying=underlying,
variables=variables,
- substitutions=substitutions,
+ substitutions=tuple(gen_substitutions()),
)
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 3d2d15b..c4f3113 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,5 +1,6 @@
import abc
+import collections
import typing
import dataclass_abc
@@ -29,8 +30,6 @@ class AdditionExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- terms = {}
-
if left.shape == (1, 1):
left, right = right, left
@@ -44,15 +43,12 @@ class AdditionExprMixin(ExpressionBaseMixin):
def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
return self.underlying_monomials
- try:
- underlying_terms = right.get_poly(0, 0)
+ polynomial = right.get_poly(0, 0)
- except KeyError:
- pass
+ if polynomial is not None:
- else:
broadcasted_right = BroadCastedPolyMatrix(
- underlying_monomials=underlying_terms,
+ underlying_monomials=polynomial,
shape=left.shape,
)
@@ -63,30 +59,31 @@ class AdditionExprMixin(ExpressionBaseMixin):
all_underlying = (left, right)
- for underlying in all_underlying:
+ terms = {}
- for row in range(left.shape[0]):
- for col in range(left.shape[1]):
-
- if (row, col) in terms:
- terms_row_col = terms[row, col]
+ for row in range(left.shape[0]):
+ for col in range(left.shape[1]):
- else:
- terms_row_col = {}
+ terms_row_col = {}
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
- continue
+ for underlying in all_underlying:
- for monomial, value in underlying_terms.items():
-
- if monomial not in terms_row_col:
- terms_row_col[monomial] = 0
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
+ continue
- terms_row_col[monomial] += value
+ if len(terms_row_col) == 0:
+ terms_row_col = dict(polynomial)
- terms[row, col] = terms_row_col
+ else:
+ for monomial, value in polynomial.items():
+
+ if monomial not in terms_row_col:
+ terms_row_col[monomial] = value
+ else:
+ terms_row_col[monomial] += value
+
+ terms[row, col] = terms_row_col
poly_matrix = init_poly_matrix(
terms=terms,
diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py
index 8775fde..eff8de8 100644
--- a/polymatrix/expression/mixins/blockdiagexprmixin.py
+++ b/polymatrix/expression/mixins/blockdiagexprmixin.py
@@ -45,7 +45,7 @@ class BlockDiagExprMixin(ExpressionBaseMixin):
)
else:
- raise KeyError()
+ return None
raise Exception(f'row {row} is out of bounds')
diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py
index ef99644..fe6dbb7 100644
--- a/polymatrix/expression/mixins/cacheexprmixin.py
+++ b/polymatrix/expression/mixins/cacheexprmixin.py
@@ -1,6 +1,7 @@
import abc
import dataclasses
+from polymatrix.polymatrix.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin
from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -25,7 +26,10 @@ class CacheExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state)
- cached_terms = dict(underlying.get_terms())
+ if isinstance(underlying, PolyMatrixAsDictMixin):
+ cached_terms = underlying.terms
+ else:
+ cached_terms = dict(underlying.gen_terms())
poly_matrix = init_poly_matrix(
terms=cached_terms,
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 87d5f5b..5fce215 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -41,9 +41,8 @@ class DerivativeExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, 0)
+ if underlying_terms is None:
continue
# derivate each variable and map result to the corresponding column
diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py
index 391a221..3002109 100644
--- a/polymatrix/expression/mixins/divergenceexprmixin.py
+++ b/polymatrix/expression/mixins/divergenceexprmixin.py
@@ -38,9 +38,8 @@ class DivergenceExprMixin(ExpressionBaseMixin):
for row, variable in enumerate(variables):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, 0)
+ if underlying_terms is None:
continue
state, derivation_terms = get_derivative_monomials(
@@ -55,7 +54,8 @@ class DivergenceExprMixin(ExpressionBaseMixin):
monomial_terms[monomial] += value
poly_matrix = init_poly_matrix(
- terms={(0, 0): dict(monomial_terms)},
+ # terms={(0, 0): dict(monomial_terms)},
+ terms={(0, 0): monomial_terms},
shape=(1, 1),
)
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index ee9b2ff..d84c09a 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -42,9 +42,8 @@ class DivisionExprMixin(ExpressionBaseMixin):
for row in range(left.shape[0]):
for col in range(left.shape[1]):
- try:
- underlying_terms = left.get_poly(row, col)
- except KeyError:
+ underlying_terms = left.get_poly(row, col)
+ if underlying_terms is None:
continue
def gen_monomial_terms():
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 24e91fc..408b8c8 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -97,14 +97,12 @@ class ElemMultExprMixin(ExpressionBaseMixin):
terms_row_col = {}
- try:
- left_terms = left.get_poly(poly_row, poly_col)
- except KeyError:
+ left_terms = left.get_poly(poly_row, poly_col)
+ if left_terms is None:
continue
- try:
- right_terms = right.get_poly(poly_row, poly_col)
- except KeyError:
+ right_terms = right.get_poly(poly_row, poly_col)
+ if right_terms is None:
continue
for (left_monomial, left_value), (right_monomial, right_value) \
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index ca1cf7d..2358566 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -47,9 +47,8 @@ class EvalExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, col)
+ if underlying_terms is None:
continue
terms_row_col = {}
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index cf4a674..5602294 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -27,6 +27,7 @@ from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr
from polymatrix.expression.init.initreshapeexpr import init_reshape_expr
from polymatrix.expression.init.initsetelementatexpr import init_set_element_at_expr
from polymatrix.expression.init.initquadraticmonomialsexpr import init_quadratic_monomials_expr
+from polymatrix.expression.init.initsubstituteexpr import init_substitute_expr
from polymatrix.expression.init.initsubtractmonomialsexpr import init_subtract_monomials_expr
from polymatrix.expression.init.initsumexpr import init_sum_expr
from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr
@@ -61,7 +62,7 @@ class ExpressionMixin(
match other:
case ExpressionBaseMixin():
- right = other.underlying
+ right = other
case _:
right = init_from_sympy_expr(other)
@@ -97,7 +98,7 @@ class ExpressionMixin(
def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin':
match other:
case ExpressionBaseMixin():
- right = other.underlying
+ right = other
case _:
right = init_from_sympy_expr(other)
@@ -114,7 +115,7 @@ class ExpressionMixin(
match other:
case ExpressionBaseMixin():
- right = other.underlying
+ right = other
case _:
right = init_from_sympy_expr(other)
@@ -146,7 +147,7 @@ class ExpressionMixin(
def __truediv__(self, other: ExpressionBaseMixin):
match other:
case ExpressionBaseMixin():
- right = other.underlying
+ right = other
case _:
right = init_from_sympy_expr(other)
@@ -350,14 +351,6 @@ class ExpressionMixin(
),
)
- # def squeeze(self):
- # return dataclasses.replace(
- # self,
- # underlying=init_squeeze_expr(
- # underlying=self.underlying,
- # ),
- # )
-
def set_element_at(
self,
row: int,
@@ -397,13 +390,23 @@ class ExpressionMixin(
) -> 'ExpressionMixin':
return dataclasses.replace(
self,
- underlying=init_eval_expr(
+ underlying=init_substitute_expr(
underlying=self.underlying,
variables=variable,
substitutions=substitutions,
),
)
+ def subs(
+ self,
+ variable: tuple,
+ substitutions: tuple['ExpressionMixin', ...] = None,
+ ) -> 'ExpressionMixin':
+ return self.substitute(
+ variable=variable,
+ substitutions=substitutions,
+ )
+
def sum(self):
return dataclasses.replace(
self,
diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py
index f904feb..47e83f9 100644
--- a/polymatrix/expression/mixins/eyeexprmixin.py
+++ b/polymatrix/expression/mixins/eyeexprmixin.py
@@ -33,15 +33,15 @@ class EyeExprMixin(ExpressionBaseMixin):
return {tuple(): 1.0}
else:
- raise KeyError()
+ return None
else:
raise Exception(f'{(row, col)=} is out of bounds')
- value = variable.shape[0]
+ n_row = variable.shape[0]
polymatrix = EyePolyMatrix(
- shape=(value, value),
+ shape=(n_row, n_row),
)
return state, polymatrix
diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py
index ad13814..ff2d333 100644
--- a/polymatrix/expression/mixins/filterexprmixin.py
+++ b/polymatrix/expression/mixins/filterexprmixin.py
@@ -44,9 +44,8 @@ class FilterExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, 0)
+ if underlying_terms is None:
continue
predicator_value = predicator.get_poly(row, 0)[tuple()]
diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
index 2a83742..6b80006 100644
--- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py
+++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
@@ -29,7 +29,7 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
state, variables = self.variables.apply(state=state)
def gen_variable_monomials():
- for _, term in variables.get_terms():
+ for _, term in variables.gen_terms():
assert len(term) == 1, f'{term} should have only a single monomial'
for monomial in term.keys():
@@ -42,9 +42,8 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, col)
+ if underlying_terms is None:
continue
monomial_terms = collections.defaultdict(float)
@@ -68,7 +67,8 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
break
- terms[row, col] = dict(monomial_terms)
+ terms[row, col] = monomial_terms
+ # terms[row, col] = dict(monomial_terms)
poly_matrix = init_poly_matrix(
terms=terms,
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py
index f6953e1..8853022 100644
--- a/polymatrix/expression/mixins/fromsympyexprmixin.py
+++ b/polymatrix/expression/mixins/fromsympyexprmixin.py
@@ -28,12 +28,15 @@ class FromSympyExprMixin(ExpressionBaseMixin):
try:
poly = sympy.poly(poly_data)
- except sympy.polys.polyerrors.GeneratorsNeeded:
+ except sympy.polys.polyerrors.GeneratorsNeeded:
if not math.isclose(poly_data, 0):
terms[poly_row, poly_col] = {tuple(): poly_data}
continue
+ except ValueError:
+ raise ValueError(f'{poly_data=}')
+
for symbol in poly.gens:
state = state.register(key=symbol, n_param=1)
# print(f'{symbol}: {state.n_param}')
diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py
index 8fd9be6..9d1c220 100644
--- a/polymatrix/expression/mixins/fromtermsexprmixin.py
+++ b/polymatrix/expression/mixins/fromtermsexprmixin.py
@@ -12,7 +12,7 @@ from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
class FromTermsExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def terms(self) -> tuple[tuple[tuple[float], float], ...]:
+ def terms(self) -> tuple[tuple[tuple[tuple[int, int], ...], float], ...]:
pass
@property
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index 1300903..cd94761 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -47,12 +47,11 @@ class LinearInExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
+ polynomial = underlying.get_poly(row, 0)
+ if polynomial is None:
continue
- for monomial, value in underlying_terms.items():
+ for monomial, value in polynomial.items():
x_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx in variable_indices)
p_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx not in variable_indices)
diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
index 1358e1e..d013722 100644
--- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py
+++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
@@ -36,9 +36,8 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, col)
+ if underlying_terms is None:
continue
for monomial, value in underlying_terms.items():
diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
index f53669e..b309ccf 100644
--- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
@@ -34,9 +34,8 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- polynomial = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
for monomial in polynomial.keys():
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index d023f1e..a02b4f1 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -41,33 +41,18 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
for index_k in range(left.shape[1]):
- try:
- left_terms = left.get_poly(poly_row, index_k)
- right_terms = right.get_poly(index_k, poly_col)
- except KeyError:
+ left_terms = left.get_poly(poly_row, index_k)
+ if left_terms is None:
continue
- # for (left_monomial, left_value), (right_monomial, right_value) \
- # in itertools.product(left_terms.items(), right_terms.items()):
-
- # value = left_value * right_value
-
- # if value == 0:
- # continue
-
- # # monomial = tuple(sorted(left_monomial + right_monomial))
- # monomial = merge_monomial_indices((left_monomial, right_monomial))
-
- # if monomial not in terms_row_col:
- # terms_row_col[monomial] = 0
-
- # terms_row_col[monomial] += value
+ right_terms = right.get_poly(index_k, poly_col)
+ if right_terms is None:
+ continue
multiply_polynomial(left_terms, right_terms, terms_row_col)
if 0 < len(terms_row_col):
terms[poly_row, poly_col] = terms_row_col
- # terms[poly_row, poly_col] = {key: val for key, val in terms_row_col.items() if not math.isclose(val, 0, abs_tol=1e-12)}
poly_matrix = init_poly_matrix(
terms=terms,
diff --git a/polymatrix/expression/mixins/maxdegreeexprmixin.py b/polymatrix/expression/mixins/maxdegreeexprmixin.py
index 7502cbb..b21c157 100644
--- a/polymatrix/expression/mixins/maxdegreeexprmixin.py
+++ b/polymatrix/expression/mixins/maxdegreeexprmixin.py
@@ -25,9 +25,8 @@ class MaxDegreeExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ underlying_terms = underlying.get_poly(row, col)
+ if underlying_terms is None:
continue
def gen_degrees():
diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py
index 21fd6da..2ab4cad 100644
--- a/polymatrix/expression/mixins/maxexprmixin.py
+++ b/polymatrix/expression/mixins/maxexprmixin.py
@@ -28,12 +28,11 @@ class MaxExprMixin(ExpressionBaseMixin):
def gen_values():
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
- yield underlying_terms[tuple()]
+ yield polynomial[tuple()]
values = tuple(gen_values())
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index ea71185..130ab79 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -24,7 +24,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
+ def variables(self) -> ExpressionBaseMixin:
...
# overwrites abstract method of `ExpressionBaseMixin`
@@ -39,11 +39,10 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
assert underlying.shape == (1, 1), f'underlying shape is {underlying.shape}'
- underlying_terms = underlying.get_poly(0, 0)
-
terms = collections.defaultdict(dict)
+ terms = collections.defaultdict(lambda: collections.defaultdict(float))
- for monomial, value in underlying_terms.items():
+ for monomial, value in underlying.get_poly(0, 0).items():
x_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx in variable_indices)
p_monomial = tuple((var_idx, count) for var_idx, count in monomial if var_idx not in variable_indices)
@@ -60,12 +59,11 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
except ValueError:
raise ValueError(f'{right=} not in {sos_monomials=}')
- monomial_terms = terms[row, col]
-
- if p_monomial not in monomial_terms:
- monomial_terms[p_monomial] = 0
+ # monomial_terms = terms[row, col]
+ # if p_monomial not in monomial_terms:
+ # monomial_terms[p_monomial] = 0
- monomial_terms[p_monomial] += value
+ terms[row, col][p_monomial] += value
poly_matrix = init_poly_matrix(
terms=dict(terms),
diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
index c53741b..bddc321 100644
--- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
@@ -34,9 +34,8 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- polynomial = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
for monomial in polynomial.keys():
diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py
index 2bb0bfe..a111d20 100644
--- a/polymatrix/expression/mixins/setelementatexprmixin.py
+++ b/polymatrix/expression/mixins/setelementatexprmixin.py
@@ -39,25 +39,24 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, value_expr = self.value.apply(state=state)
+ state, polynomial_expr = self.value.apply(state=state)
- assert value_expr.shape == (1, 1)
+ assert polynomial_expr.shape == (1, 1)
- try:
- value = value_expr.get_poly(0, 0)
- except KeyError:
- value = 0
+ polynomial = polynomial_expr.get_poly(0, 0)
+ if polynomial is None:
+ polynomial = 0
@dataclass_abc.dataclass_abc(frozen=True)
class SetElementAtPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
index: tuple[int, int]
- value: dict
+ polynomial: dict
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
if (row, col) == self.index:
- return self.value
+ return self.polynomial
else:
return self.underlying.get_poly(row, col)
@@ -65,6 +64,6 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
underlying=underlying,
index=self.index,
shape=underlying.shape,
- value=value,
+ polynomial=polynomial,
)
\ No newline at end of file
diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py
index 2f46e23..14cf1f3 100644
--- a/polymatrix/expression/mixins/squeezeexprmixin.py
+++ b/polymatrix/expression/mixins/squeezeexprmixin.py
@@ -31,14 +31,13 @@ class SqueezeExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
+ polynomial = underlying.get_poly(row, 0)
+ if polynomial is None:
continue
terms_row_col = {}
- for monomial, value in underlying_terms.items():
+ for monomial, value in polynomial.items():
if value != 0.0:
terms_row_col[monomial] = value
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index 1ba8a1a..9741897 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -50,6 +50,9 @@ class SubstituteExprMixin(ExpressionBaseMixin):
elif isinstance(substitution_expr, int) or isinstance(substitution_expr, float):
polynomial = {tuple(): substitution_expr}
+ else:
+ raise Exception(f'{substitution_expr=} not recognized')
+
return state, result + (polynomial,)
*_, (state, substitutions) = tuple(itertools.accumulate(
@@ -63,14 +66,13 @@ class SubstituteExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
terms_row_col = collections.defaultdict(float)
- for monomial, value in underlying_terms.items():
+ for monomial, value in polynomial.items():
terms_monomial = {tuple(): value}
diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py
index 75124d9..9957f10 100644
--- a/polymatrix/expression/mixins/sumexprmixin.py
+++ b/polymatrix/expression/mixins/sumexprmixin.py
@@ -28,14 +28,13 @@ class SumExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
term_monomials = terms[row, 0]
- for monomial, value in underlying_terms.items():
+ for monomial, value in polynomial.items():
if monomial in term_monomials:
term_monomials[monomial] += value
else:
diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py
index 989cb0e..e41e889 100644
--- a/polymatrix/expression/mixins/symmetricexprmixin.py
+++ b/polymatrix/expression/mixins/symmetricexprmixin.py
@@ -38,17 +38,15 @@ class SymmetricExprMixin(ExpressionBaseMixin):
def gen_symmetric_monomials():
for i_row, i_col in ((row, col), (col, row)):
- try:
- monomials = self.underlying.get_poly(i_row, i_col)
- except:
- pass
- else:
- yield monomials
+ polynomial = self.underlying.get_poly(i_row, i_col)
+
+ if polynomial is not None:
+ yield polynomial
all_monomials = tuple(gen_symmetric_monomials())
if len(all_monomials) == 0:
- raise KeyError()
+ return None
else:
terms = collections.defaultdict(float)
@@ -58,7 +56,8 @@ class SymmetricExprMixin(ExpressionBaseMixin):
for monomial, value in monomials.items():
terms[monomial] += value / 2
- return dict(terms)
+ # return dict(terms)
+ return terms
polymatrix = SymmetricPolyMatrix(
underlying=underlying,
diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py
index 547b012..ca928c6 100644
--- a/polymatrix/expression/mixins/toconstantexprmixin.py
+++ b/polymatrix/expression/mixins/toconstantexprmixin.py
@@ -26,13 +26,12 @@ class ToConstantExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
- if tuple() in underlying_terms:
- terms[row, col][tuple()] = underlying_terms[tuple()]
+ if tuple() in polynomial:
+ terms[row, col][tuple()] = polynomial[tuple()]
poly_matrix = init_poly_matrix(
terms=dict(terms),
diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py
index bbb0cbf..2bcaacb 100644
--- a/polymatrix/expression/mixins/toquadraticexprmixin.py
+++ b/polymatrix/expression/mixins/toquadraticexprmixin.py
@@ -54,18 +54,18 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
else:
terms_row_col[monomial] += value
- return dict(terms_row_col)
+ # return dict(terms_row_col)
+ return terms_row_col
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
terms_row_col = to_quadratic(
- monomial_terms=underlying_terms,
+ monomial_terms=polynomial,
)
# terms_row_col = collections.defaultdict(float)
diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py
index 42f192e..e144ae5 100644
--- a/polymatrix/expression/mixins/truncateexprmixin.py
+++ b/polymatrix/expression/mixins/truncateexprmixin.py
@@ -42,14 +42,13 @@ class TruncateExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
+ polynomial = underlying.get_poly(row, col)
+ if polynomial is None:
continue
terms_row_col = {}
- for monomial, value in underlying_terms.items():
+ for monomial, value in polynomial.items():
degree = sum((count for var_idx, count in monomial if var_idx in variable_indices))
diff --git a/polymatrix/expression/utils/getderivativemonomials.py b/polymatrix/expression/utils/getderivativemonomials.py
index cec2d81..4aa7144 100644
--- a/polymatrix/expression/utils/getderivativemonomials.py
+++ b/polymatrix/expression/utils/getderivativemonomials.py
@@ -72,8 +72,6 @@ def get_derivative_monomials(
for monomial, value in monomial_terms.items():
- # # count powers for each variable
- # monomial_cnt = dict(collections.Counter(monomial))
monomial_cnt = dict(monomial)
def differentiate_monomial(dependent_variable, derivation_variable=None):
@@ -117,4 +115,5 @@ def get_derivative_monomials(
# )
# derivation_terms[diff_monomial] += value
- return state, dict(derivation_terms) \ No newline at end of file
+ # return state, dict(derivation_terms)
+ return state, derivation_terms \ No newline at end of file
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 08b44d3..aaf5b0b 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -35,4 +35,6 @@ def get_variable_indices(state, variables):
else:
yield global_state[0].offset_dict[variable][0]
- return global_state[0], tuple(gen_indices())
+ indices = tuple(gen_indices())
+
+ return global_state[0], indices
diff --git a/polymatrix/expression/utils/splitmonomialindices.py b/polymatrix/expression/utils/splitmonomialindices.py
index 82e5454..4ffa6ff 100644
--- a/polymatrix/expression/utils/splitmonomialindices.py
+++ b/polymatrix/expression/utils/splitmonomialindices.py
@@ -7,8 +7,7 @@ def split_monomial_indices(monomial):
for idx, count in monomial:
count_left = count // 2
- is_uneven = count % 2
- if is_uneven:
+ if count % 2:
if is_left:
count_left = count_left + 1
@@ -21,7 +20,5 @@ def split_monomial_indices(monomial):
if 0 < count_right:
right.append((idx, count - count_left))
-
- # print((monomial, tuple(left), tuple(right)))
return tuple(left), tuple(right) \ No newline at end of file
diff --git a/polymatrix/expression/mixins/polymatrixasdictmixin.py b/polymatrix/polymatrix/mixins/polymatrixasdictmixin.py
index 69c59ac..93b6385 100644
--- a/polymatrix/expression/mixins/polymatrixasdictmixin.py
+++ b/polymatrix/polymatrix/mixins/polymatrixasdictmixin.py
@@ -1,4 +1,5 @@
import abc
+import typing
from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
@@ -13,8 +14,6 @@ class PolyMatrixAsDictMixin(
...
# overwrites abstract method of `PolyMatrixMixin`
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- try:
+ def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ if (row, col) in self.terms:
return self.terms[row, col]
- except KeyError:
- raise KeyError(f'{(row, col)} is not a key of {self.terms}')
diff --git a/polymatrix/polymatrix/mixins/polymatrixmixin.py b/polymatrix/polymatrix/mixins/polymatrixmixin.py
index bea80de..e67a6fa 100644
--- a/polymatrix/polymatrix/mixins/polymatrixmixin.py
+++ b/polymatrix/polymatrix/mixins/polymatrixmixin.py
@@ -8,18 +8,14 @@ class PolyMatrixMixin(abc.ABC):
def shape(self) -> tuple[int, int]:
...
- def get_terms(self) -> tuple[tuple[int, int], dict[tuple[int, ...], float]]:
- def gen_terms():
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
- try:
- monomial_terms = self.get_poly(row, col)
- except KeyError:
- continue
-
- yield (row, col), monomial_terms
-
- return tuple(gen_terms())
+ def gen_terms(self) -> typing.Generator[tuple[tuple[int, int], dict[tuple[int, ...], float]], None, None]:
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+ polynomial = self.get_poly(row, col)
+ if polynomial is None:
+ continue
+
+ yield (row, col), polynomial
@abc.abstractclassmethod
def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
diff --git a/polymatrix/polymatrix/polymatrix.py b/polymatrix/polymatrix/polymatrix.py
index c044081..b8523f6 100644
--- a/polymatrix/polymatrix/polymatrix.py
+++ b/polymatrix/polymatrix/polymatrix.py
@@ -1,4 +1,4 @@
-from polymatrix.expression.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin
+from polymatrix.polymatrix.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin
class PolyMatrix(PolyMatrixAsDictMixin):
pass