summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py38
-rw-r--r--polymatrix/expression/combinationsexpr.py4
-rw-r--r--polymatrix/expression/eyeexpr.py4
-rw-r--r--polymatrix/expression/impl/combinationsexprimpl.py9
-rw-r--r--polymatrix/expression/impl/eyeexprimpl.py8
-rw-r--r--polymatrix/expression/impl/getitemexprimpl.py2
-rw-r--r--polymatrix/expression/impl/linearin2exprimpl.py9
-rw-r--r--polymatrix/expression/impl/squeezeexprimpl.py8
-rw-r--r--polymatrix/expression/impl/sumexprimpl.py8
-rw-r--r--polymatrix/expression/impl/symmetricexprimpl.py8
-rw-r--r--polymatrix/expression/impl/toconstantexprimpl.py8
-rw-r--r--polymatrix/expression/init/initcombinationsexpr.py12
-rw-r--r--polymatrix/expression/init/initevalexpr.py2
-rw-r--r--polymatrix/expression/init/initeyeexpr.py10
-rw-r--r--polymatrix/expression/init/initgetitemexpr.py14
-rw-r--r--polymatrix/expression/init/initlinearin2expr.py12
-rw-r--r--polymatrix/expression/init/initpolymatrix.py4
-rw-r--r--polymatrix/expression/init/initsqueezeexpr.py10
-rw-r--r--polymatrix/expression/init/initsumexpr.py10
-rw-r--r--polymatrix/expression/init/initsymmetricexpr.py10
-rw-r--r--polymatrix/expression/init/inittoconstantexpr.py10
-rw-r--r--polymatrix/expression/linearin2expr.py4
-rw-r--r--polymatrix/expression/linearinexpr.py4
-rw-r--r--polymatrix/expression/linearmatrixinexpr.py4
-rw-r--r--polymatrix/expression/mixins/accumulateexprmixin.py3
-rw-r--r--polymatrix/expression/mixins/addauxequationsexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py44
-rw-r--r--polymatrix/expression/mixins/blockdiagexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/cacheexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/combinationsexprmixin.py75
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/expressionbasemixin.py8
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py117
-rw-r--r--polymatrix/expression/mixins/eyeexprmixin.py47
-rw-r--r--polymatrix/expression/mixins/filterlinearpartexprmixin.py (renamed from polymatrix/expression/mixins/linearmatrixinexprmixin.py)35
-rw-r--r--polymatrix/expression/mixins/fromarrayexprmixin.py3
-rw-r--r--polymatrix/expression/mixins/fromtermsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py62
-rw-r--r--polymatrix/expression/mixins/kktexprmixin.py3
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py54
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/oldlinearexprmixin.py131
-rw-r--r--polymatrix/expression/mixins/parametrizetermsexprmixin.py16
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py68
-rw-r--r--polymatrix/expression/mixins/quadraticreprexprmixin.py87
-rw-r--r--polymatrix/expression/mixins/repmatexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/reshapeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/squeezeexprmixin.py53
-rw-r--r--polymatrix/expression/mixins/sumexprmixin.py49
-rw-r--r--polymatrix/expression/mixins/symmetricexprmixin.py68
-rw-r--r--polymatrix/expression/mixins/toconstantexprmixin.py44
-rw-r--r--polymatrix/expression/mixins/toquadraticexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/transposeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py5
-rw-r--r--polymatrix/expression/quadraticinexpr.py4
-rw-r--r--polymatrix/expression/squeezeexpr.py4
-rw-r--r--polymatrix/expression/sumexpr.py4
-rw-r--r--polymatrix/expression/symmetricexpr.py4
-rw-r--r--polymatrix/expression/toconstantexpr.py4
-rw-r--r--polymatrix/expression/utils/getvariableindices.py6
-rw-r--r--polymatrix/sympyutils.py2
65 files changed, 1042 insertions, 206 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 20b05a9..85def18 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -11,10 +11,9 @@ from polymatrix.expression.expressionstate import ExpressionState
from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr
from polymatrix.expression.init.initexpression import init_expression
+from polymatrix.expression.init.initeyeexpr import init_eye_expr
from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
-from polymatrix.expression.init.initkktexpr import init_kkt_expr
-from polymatrix.expression.init.initlinearmatrixinexpr import init_linear_matrix_in_expr
from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.utils.getvariableindices import get_variable_indices
@@ -89,6 +88,14 @@ def block_diag(
)
+def eye(
+ variable: tuple[Expression],
+):
+ return init_expression(
+ init_eye_expr(variable=variable)
+ )
+
+
# def kkt(
# cost: Expression,
# variables: Expression,
@@ -317,6 +324,17 @@ def kkt_inequality(
# return StateMonad.init(func)
+def shape(
+ expr: Expression,
+) -> StateMonadMixin[ExpressionState, tuple[int, ...]]:
+ def func(state: ExpressionState):
+ state, polymatrix = expr.apply(state)
+
+ return state, polymatrix.shape
+
+ return init_state_monad(func)
+
+
@dataclasses.dataclass
class MatrixEquations:
matrix_equations: tuple[tuple[np.ndarray, ...], ...]
@@ -340,12 +358,16 @@ class MatrixEquations:
return (matrix_1, matrix_2, matrix_3)
def get_value(self, variable, value):
- if isinstance(variable, Expression):
- variable = variable.underlying
-
- offset = self.state.offset_dict[variable]
- offset_idx = list(self.variable_index.index(idx) for idx in range(*offset))
- return value[offset_idx]
+ variable_indices = get_variable_indices(self.state, variable)[1]
+ value_index = list(self.variable_index.index(variable_index) for variable_index in variable_indices)
+ return value[value_index]
+
+ def set_value(self, variable, value):
+ variable_indices = get_variable_indices(self.state, variable)[1]
+ value_index = list(self.variable_index.index(variable_index) for variable_index in variable_indices)
+ vec = np.zeros(len(self.variable_index))
+ vec[value_index] = value
+ return vec
def to_matrix_equations(
diff --git a/polymatrix/expression/combinationsexpr.py b/polymatrix/expression/combinationsexpr.py
new file mode 100644
index 0000000..3373365
--- /dev/null
+++ b/polymatrix/expression/combinationsexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin
+
+class CombinationsExpr(CombinationsExprMixin):
+ pass
diff --git a/polymatrix/expression/eyeexpr.py b/polymatrix/expression/eyeexpr.py
new file mode 100644
index 0000000..4e64569
--- /dev/null
+++ b/polymatrix/expression/eyeexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.eyeexprmixin import EyeExprMixin
+
+class EyeExpr(EyeExprMixin):
+ pass
diff --git a/polymatrix/expression/impl/combinationsexprimpl.py b/polymatrix/expression/impl/combinationsexprimpl.py
new file mode 100644
index 0000000..2e809a2
--- /dev/null
+++ b/polymatrix/expression/impl/combinationsexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.combinationsexpr import CombinationsExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class CombinationsExprImpl(CombinationsExpr):
+ underlying: ExpressionBaseMixin
+ number: int
diff --git a/polymatrix/expression/impl/eyeexprimpl.py b/polymatrix/expression/impl/eyeexprimpl.py
new file mode 100644
index 0000000..9c50f8a
--- /dev/null
+++ b/polymatrix/expression/impl/eyeexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.eyeexpr import EyeExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class EyeExprImpl(EyeExpr):
+ variable: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/getitemexprimpl.py b/polymatrix/expression/impl/getitemexprimpl.py
index b8972a9..578d4c7 100644
--- a/polymatrix/expression/impl/getitemexprimpl.py
+++ b/polymatrix/expression/impl/getitemexprimpl.py
@@ -6,4 +6,4 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@dataclass_abc.dataclass_abc(frozen=True)
class GetItemExprImpl(GetItemExpr):
underlying: ExpressionBaseMixin
- index: tuple
+ index: tuple[tuple[int, ...], tuple[int, ...]]
diff --git a/polymatrix/expression/impl/linearin2exprimpl.py b/polymatrix/expression/impl/linearin2exprimpl.py
new file mode 100644
index 0000000..34b42bf
--- /dev/null
+++ b/polymatrix/expression/impl/linearin2exprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.linearin2expr import LinearIn2Expr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class LinearIn2ExprImpl(LinearIn2Expr):
+ underlying: ExpressionBaseMixin
+ expression: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/squeezeexprimpl.py b/polymatrix/expression/impl/squeezeexprimpl.py
new file mode 100644
index 0000000..b60af68
--- /dev/null
+++ b/polymatrix/expression/impl/squeezeexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.squeezeexpr import SqueezeExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class SqueezeExprImpl(SqueezeExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/sumexprimpl.py b/polymatrix/expression/impl/sumexprimpl.py
new file mode 100644
index 0000000..b148b43
--- /dev/null
+++ b/polymatrix/expression/impl/sumexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.sumexpr import SumExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class SumExprImpl(SumExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/symmetricexprimpl.py b/polymatrix/expression/impl/symmetricexprimpl.py
new file mode 100644
index 0000000..b15b8a0
--- /dev/null
+++ b/polymatrix/expression/impl/symmetricexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.symmetricexpr import SymmetricExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class SymmetricExprImpl(SymmetricExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/toconstantexprimpl.py b/polymatrix/expression/impl/toconstantexprimpl.py
new file mode 100644
index 0000000..9e3d13d
--- /dev/null
+++ b/polymatrix/expression/impl/toconstantexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.toconstantexpr import ToConstantExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ToConstantExprImpl(ToConstantExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init/initcombinationsexpr.py b/polymatrix/expression/init/initcombinationsexpr.py
new file mode 100644
index 0000000..7a1df2f
--- /dev/null
+++ b/polymatrix/expression/init/initcombinationsexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.combinationsexprimpl import CombinationsExprImpl
+
+
+def init_combinations_expr(
+ underlying: ExpressionBaseMixin,
+ number: int,
+):
+ return CombinationsExprImpl(
+ underlying=underlying,
+ number=number,
+)
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index 525a697..49bb0a3 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -15,7 +15,7 @@ def init_eval_expr(
variables, values = tuple(zip(*variables))
elif isinstance(values, np.ndarray):
- values = tuple(values)
+ values = tuple(values.reshape(-1))
elif not isinstance(values, tuple):
values = (values,)
diff --git a/polymatrix/expression/init/initeyeexpr.py b/polymatrix/expression/init/initeyeexpr.py
new file mode 100644
index 0000000..e691b03
--- /dev/null
+++ b/polymatrix/expression/init/initeyeexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.eyeexprimpl import EyeExprImpl
+
+
+def init_eye_expr(
+ variable: ExpressionBaseMixin,
+):
+ return EyeExprImpl(
+ variable=variable,
+)
diff --git a/polymatrix/expression/init/initgetitemexpr.py b/polymatrix/expression/init/initgetitemexpr.py
index 140fa3a..5fea7a5 100644
--- a/polymatrix/expression/init/initgetitemexpr.py
+++ b/polymatrix/expression/init/initgetitemexpr.py
@@ -1,12 +1,22 @@
+from numpy import isin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.getitemexprimpl import GetItemExprImpl
def init_get_item_expr(
underlying: ExpressionBaseMixin,
- index: tuple,
+ index: tuple[tuple[int, ...], tuple[int, ...]],
):
+
+ def get_hashable_slice(index):
+ if isinstance(index, slice):
+ return GetItemExprImpl.Slice(start=index.start, stop=index.stop, step=index.step)
+ else:
+ return index
+
+ proper_index = (get_hashable_slice(index[0]), get_hashable_slice(index[1]))
+
return GetItemExprImpl(
underlying=underlying,
- index=index,
+ index=proper_index,
)
diff --git a/polymatrix/expression/init/initlinearin2expr.py b/polymatrix/expression/init/initlinearin2expr.py
new file mode 100644
index 0000000..7225510
--- /dev/null
+++ b/polymatrix/expression/init/initlinearin2expr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.linearin2exprimpl import LinearIn2ExprImpl
+
+
+def init_linear_in2_expr(
+ underlying: ExpressionBaseMixin,
+ expression: ExpressionBaseMixin,
+):
+ return LinearIn2ExprImpl(
+ underlying=underlying,
+ expression=expression,
+)
diff --git a/polymatrix/expression/init/initpolymatrix.py b/polymatrix/expression/init/initpolymatrix.py
index e6a6cde..d0c9577 100644
--- a/polymatrix/expression/init/initpolymatrix.py
+++ b/polymatrix/expression/init/initpolymatrix.py
@@ -4,13 +4,9 @@ from polymatrix.expression.impl.polymatriximpl import PolyMatrixImpl
def init_poly_matrix(
terms: dict,
shape: tuple,
- aux_terms: tuple[dict[tuple[int, ...], float]] = None,
):
- if aux_terms is None:
- aux_terms = tuple()
return PolyMatrixImpl(
terms=terms,
shape=shape,
- # aux_terms=aux_terms,
)
diff --git a/polymatrix/expression/init/initsqueezeexpr.py b/polymatrix/expression/init/initsqueezeexpr.py
new file mode 100644
index 0000000..d71a9b5
--- /dev/null
+++ b/polymatrix/expression/init/initsqueezeexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.squeezeexprimpl import SqueezeExprImpl
+
+
+def init_squeeze_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return SqueezeExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initsumexpr.py b/polymatrix/expression/init/initsumexpr.py
new file mode 100644
index 0000000..606d86b
--- /dev/null
+++ b/polymatrix/expression/init/initsumexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.sumexprimpl import SumExprImpl
+
+
+def init_sum_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return SumExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initsymmetricexpr.py b/polymatrix/expression/init/initsymmetricexpr.py
new file mode 100644
index 0000000..87be3ed
--- /dev/null
+++ b/polymatrix/expression/init/initsymmetricexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.symmetricexprimpl import SymmetricExprImpl
+
+
+def init_symmetric_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return SymmetricExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/inittoconstantexpr.py b/polymatrix/expression/init/inittoconstantexpr.py
new file mode 100644
index 0000000..ead35f8
--- /dev/null
+++ b/polymatrix/expression/init/inittoconstantexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.toconstantexprimpl import ToConstantExprImpl
+
+
+def init_to_constant_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return ToConstantExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/linearin2expr.py b/polymatrix/expression/linearin2expr.py
new file mode 100644
index 0000000..2c40dc8
--- /dev/null
+++ b/polymatrix/expression/linearin2expr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin
+
+class LinearIn2Expr(LinearInExprMixin):
+ pass
diff --git a/polymatrix/expression/linearinexpr.py b/polymatrix/expression/linearinexpr.py
index 4edf8b3..9054937 100644
--- a/polymatrix/expression/linearinexpr.py
+++ b/polymatrix/expression/linearinexpr.py
@@ -1,4 +1,4 @@
-from polymatrix.expression.mixins.linearinexprmixin import LinearInExprMixin
+from polymatrix.expression.mixins.oldlinearexprmixin import OldLinearExprMixin
-class LinearInExpr(LinearInExprMixin):
+class LinearInExpr(OldLinearExprMixin):
pass
diff --git a/polymatrix/expression/linearmatrixinexpr.py b/polymatrix/expression/linearmatrixinexpr.py
index 2bce2e7..296e902 100644
--- a/polymatrix/expression/linearmatrixinexpr.py
+++ b/polymatrix/expression/linearmatrixinexpr.py
@@ -1,4 +1,4 @@
-from polymatrix.expression.mixins.linearmatrixinexprmixin import LinearMatrixInExprMixin
+from polymatrix.expression.mixins.filterlinearpartexprmixin import FilterLinearPartExprMixin
-class LinearMatrixInExpr(LinearMatrixInExprMixin):
+class LinearMatrixInExpr(FilterLinearPartExprMixin):
pass
diff --git a/polymatrix/expression/mixins/accumulateexprmixin.py b/polymatrix/expression/mixins/accumulateexprmixin.py
index 76e1717..47df335 100644
--- a/polymatrix/expression/mixins/accumulateexprmixin.py
+++ b/polymatrix/expression/mixins/accumulateexprmixin.py
@@ -8,6 +8,7 @@ from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
+# todo: is this needed?
class AccumulateExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
@@ -24,7 +25,7 @@ class AccumulateExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/addauxequationsexprmixin.py b/polymatrix/expression/mixins/addauxequationsexprmixin.py
index 1dae4f7..e5353fe 100644
--- a/polymatrix/expression/mixins/addauxequationsexprmixin.py
+++ b/polymatrix/expression/mixins/addauxequationsexprmixin.py
@@ -10,7 +10,7 @@ from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
-# is this really needed?
+# todo: is this really needed?
class AddAuxEquationsExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
@@ -18,7 +18,7 @@ class AddAuxEquationsExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index d274f2a..2c7a652 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -5,6 +5,7 @@ import dataclass_abc
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
@@ -20,24 +21,49 @@ class AdditionExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
- # @property
- # def shape(self) -> tuple[int, int]:
- # return self.left.shape
-
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- assert left.shape == right.shape, f'{left.shape} != {right.shape}'
-
terms = {}
- for underlying in (left, right):
+ if left.shape == (1, 1):
+ left, right = right, left
+
+ if left.shape != (1, 1) and right.shape == (1, 1):
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class BroadCastedPolyMatrix(PolyMatrixMixin):
+ underlying_monomials: tuple[tuple[int], float]
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ return self.underlying_monomials
+
+ try:
+ underlying_terms = right.get_poly(0, 0)
+
+ except KeyError:
+ pass
+
+ else:
+ broadcasted_right = BroadCastedPolyMatrix(
+ underlying_monomials=underlying_terms,
+ shape=left.shape,
+ )
+
+ all_underlying = (left, broadcasted_right)
+
+ else:
+ assert left.shape == right.shape, f'{left.shape} != {right.shape}'
+
+ all_underlying = (left, right)
+
+ for underlying in all_underlying:
for row in range(left.shape[0]):
for col in range(left.shape[1]):
diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py
index 6bdbce1..df4bd80 100644
--- a/polymatrix/expression/mixins/blockdiagexprmixin.py
+++ b/polymatrix/expression/mixins/blockdiagexprmixin.py
@@ -16,7 +16,7 @@ class BlockDiagExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py
index 8779ba1..ba9a612 100644
--- a/polymatrix/expression/mixins/cacheexprmixin.py
+++ b/polymatrix/expression/mixins/cacheexprmixin.py
@@ -15,7 +15,7 @@ class CacheExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py
new file mode 100644
index 0000000..6d0f6eb
--- /dev/null
+++ b/polymatrix/expression/mixins/combinationsexprmixin.py
@@ -0,0 +1,75 @@
+
+import abc
+import itertools
+import math
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+
+
+class CombinationsExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def number(self) -> int:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape[1] == 1
+
+ def gen_monomials():
+ for _, monomial_terms in underlying.get_terms():
+ assert len(monomial_terms) == 1, f'{monomial_terms} has more than 1 element'
+
+ for monomial, values in monomial_terms.items():
+ yield monomial, values
+
+ monomial_terms = tuple(gen_monomials())
+
+ assert underlying.shape[0] == len(monomial_terms)
+
+ combinations = tuple(itertools.combinations_with_replacement(range(underlying.shape[0]), self.number))
+
+ # print(combinations)
+ # print(math.comb(underlying.shape[0]+self.number-1, self.number))
+
+ terms = {}
+
+ for row, combination in enumerate(combinations):
+ def gen_combination_monomials():
+ for index in combination:
+ monomial, _ = monomial_terms[index]
+ yield from monomial
+
+ combination_monomial = tuple(gen_combination_monomials())
+
+ def acc_combination_value(acc, index):
+ _, value = monomial_terms[index]
+ return acc * value
+
+ *_, combination_value = itertools.accumulate(
+ combination,
+ acc_combination_value,
+ initial=1.0,
+ )
+
+ terms[row, 0] = {combination_monomial: combination_value}
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=(math.comb(underlying.shape[0]+self.number-1, self.number), 1),
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index a551c1f..b403873 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -30,7 +30,7 @@ class DerivativeExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
index 5436ed6..f002f0b 100644
--- a/polymatrix/expression/mixins/determinantexprmixin.py
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -22,7 +22,7 @@ class DeterminantExprMixin(ExpressionBaseMixin):
# return self.underlying.shape[0], 1
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index cfb6bea..dc5f616 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -25,7 +25,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
# return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
@@ -76,7 +76,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
state = dataclasses.replace(
state,
auxillary_equations=state.auxillary_equations | {division_variable: auxillary_terms},
- cached_polymatrix=state.cache | {self: poly_matrix},
+ cache=state.cache | {self: poly_matrix},
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index ca5d921..931e022 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -25,7 +25,7 @@ class ElemMultExprMixin(ExpressionBaseMixin):
# return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index d4fdbea..3089bbb 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -26,7 +26,7 @@ class EvalExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
@@ -38,7 +38,7 @@ class EvalExprMixin(ExpressionBaseMixin):
values = tuple(self.values[0] for _ in variable_indices)
else:
- assert len(variable_indices) == len(self.values)
+ assert len(variable_indices) == len(self.values), f'length of {variable_indices} does not match length of {self.values}'
values = self.values
diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py
index 3825ed1..a4d7dc2 100644
--- a/polymatrix/expression/mixins/expressionbasemixin.py
+++ b/polymatrix/expression/mixins/expressionbasemixin.py
@@ -8,10 +8,10 @@ class ExpressionBaseMixin(
):
@abc.abstractmethod
- def _apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
...
- def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}'
+ # def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ # assert isinstance(state, ExpressionStateMixin), f'{state} is not of type {ExpressionStateMixin.__name__}'
- return self._apply(state)
+ # return self._apply(state)
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index 4cb4a50..f8b5ef7 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -6,11 +6,13 @@ from sympy import re
from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
from polymatrix.expression.init.initadditionexpr import init_addition_expr
from polymatrix.expression.init.initcacheexpr import init_cache_expr
+from polymatrix.expression.init.initcombinationsexpr import init_combinations_expr
from polymatrix.expression.init.initderivativeexpr import init_derivative_expr
from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr
from polymatrix.expression.init.initdivisionexpr import init_division_expr
from polymatrix.expression.init.initelemmultexpr import init_elem_mult_expr
from polymatrix.expression.init.initevalexpr import init_eval_expr
+from polymatrix.expression.init.initlinearin2expr import init_linear_in2_expr
from polymatrix.expression.init.initlinearinexpr import init_linear_in_expr
from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
from polymatrix.expression.init.initgetitemexpr import init_get_item_expr
@@ -20,6 +22,10 @@ from polymatrix.expression.init.initparametrizetermsexpr import init_parametrize
from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr
from polymatrix.expression.init.initreshapeexpr import init_reshape_expr
+from polymatrix.expression.init.initsqueezeexpr import init_squeeze_expr
+from polymatrix.expression.init.initsumexpr import init_sum_expr
+from polymatrix.expression.init.initsymmetricexpr import init_symmetric_expr
+from polymatrix.expression.init.inittoconstantexpr import init_to_constant_expr
from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr
from polymatrix.expression.init.inittransposeexpr import init_transpose_expr
@@ -39,7 +45,7 @@ class ExpressionMixin(
...
# overwrites abstract method of `PolyMatrixExprBaseMixin`
- def _apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
return self.underlying.apply(state)
# # overwrites abstract method of `PolyMatrixExprBaseMixin`
@@ -131,11 +137,17 @@ class ExpressionMixin(
return other @ self
def __truediv__(self, other: ExpressionBaseMixin):
+ match other:
+ case ExpressionBaseMixin():
+ right = other.underlying
+ case _:
+ right = init_from_array_expr(other)
+
return dataclasses.replace(
self,
underlying=init_division_expr(
left=self.underlying,
- right=other,
+ right=right,
),
)
@@ -165,31 +177,20 @@ class ExpressionMixin(
),
)
- def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin':
+ def combinations(self, number: int):
return dataclasses.replace(
self,
- underlying=init_parametrize_terms_expr(
- name=name,
+ underlying=init_combinations_expr(
underlying=self.underlying,
- variables=variables,
+ number=number,
),
)
- def reshape(self, n: int, m: int) -> 'ExpressionMixin':
- return dataclasses.replace(
- self,
- underlying=init_reshape_expr(
- underlying=self.underlying,
- new_shape=(n, m),
- ),
- )
-
- def rep_mat(self, n: int, m: int) -> 'ExpressionMixin':
+ def determinant(self) -> 'ExpressionMixin':
return dataclasses.replace(
self,
- underlying=init_rep_mat_expr(
+ underlying=init_determinant_expr(
underlying=self.underlying,
- repetition=(n, m),
),
)
@@ -207,16 +208,21 @@ class ExpressionMixin(
),
)
- def linear_in(self, variables: tuple) -> 'ExpressionMixin':
+ def eval(
+ self,
+ variable: tuple,
+ value: tuple[float, ...] = None,
+ ) -> 'ExpressionMixin':
return dataclasses.replace(
self,
- underlying=init_linear_in_expr(
+ underlying=init_eval_expr(
underlying=self.underlying,
- variables=variables,
+ variables=variable,
+ values=value,
),
)
- def linear_matrix_in(self, variable) -> 'ExpressionMixin':
+ def filter_linear_part(self, variable) -> 'ExpressionMixin':
return dataclasses.replace(
self,
underlying=init_linear_matrix_in_expr(
@@ -225,6 +231,25 @@ class ExpressionMixin(
),
)
+ def linear_in(self, variables: tuple) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_linear_in2_expr(
+ underlying=self.underlying,
+ expression=variables,
+ ),
+ )
+
+ def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_parametrize_terms_expr(
+ name=name,
+ underlying=self.underlying,
+ variables=variables,
+ ),
+ )
+
def quadratic_in(self, variables: tuple) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -234,25 +259,53 @@ class ExpressionMixin(
),
)
- def determinant(self) -> 'ExpressionMixin':
+ def reshape(self, n: int, m: int) -> 'ExpressionMixin':
return dataclasses.replace(
self,
- underlying=init_determinant_expr(
+ underlying=init_reshape_expr(
underlying=self.underlying,
+ new_shape=(n, m),
),
)
- def eval(
- self,
- variable: tuple,
- value: tuple[float, ...] = None,
- ) -> 'ExpressionMixin':
+ def rep_mat(self, n: int, m: int) -> 'ExpressionMixin':
return dataclasses.replace(
self,
- underlying=init_eval_expr(
+ underlying=init_rep_mat_expr(
+ underlying=self.underlying,
+ repetition=(n, m),
+ ),
+ )
+
+ def squeeze(self):
+ return dataclasses.replace(
+ self,
+ underlying=init_squeeze_expr(
+ underlying=self.underlying,
+ ),
+ )
+
+ def sum(self):
+ return dataclasses.replace(
+ self,
+ underlying=init_sum_expr(
+ underlying=self.underlying,
+ ),
+ )
+
+ def symmetric(self):
+ return dataclasses.replace(
+ self,
+ underlying=init_symmetric_expr(
+ underlying=self.underlying,
+ ),
+ )
+
+ def to_constant(self) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_to_constant_expr(
underlying=self.underlying,
- variables=variable,
- values=value,
),
)
diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py
new file mode 100644
index 0000000..fde2017
--- /dev/null
+++ b/polymatrix/expression/mixins/eyeexprmixin.py
@@ -0,0 +1,47 @@
+
+import abc
+import itertools
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+
+
+class EyeExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def variable(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+
+ state, variable = self.variable.apply(state)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class EyePolyMatrix(PolyMatrixMixin):
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ if max(row, col) <= self.shape[0]:
+ if row == col:
+ return {tuple(): 1.0}
+
+ else:
+ raise KeyError()
+
+ else:
+ raise Exception(f'{(row, col)=} is out of bounds')
+
+ value = variable.shape[0]
+
+ polymatrix = EyePolyMatrix(
+ shape=(value, value),
+ )
+
+ return state, polymatrix
diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
index 321d955..c8e66ef 100644
--- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py
+++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
@@ -10,7 +10,8 @@ from polymatrix.expression.expressionstate import ExpressionState
from polymatrix.expression.utils.getvariableindices import get_variable_indices
-class LinearMatrixInExprMixin(ExpressionBaseMixin):
+# is this class needed?
+class FilterLinearPartExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -22,15 +23,22 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(state, self.variable)
+ state, variables = self.variables.apply(state=state)
- assert len(variable_indices) == 1
+ def gen_variable_monomials():
+ for _, term in variables.get_terms():
+ assert len(term) == 1, f'{term} should have only a single monomial'
+
+ for monomial in term.keys():
+ yield set(monomial)
+
+ variable_monomials = tuple(gen_variable_monomials())
terms = {}
@@ -46,13 +54,22 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin):
for monomial, value in underlying_terms.items():
- x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ for variable_monomial in variable_monomials:
+
+ remainder = list(monomial)
+
+ try:
+ for variable in variable_monomial:
+ remainder.remove(variable)
+
+ except ValueError:
+ continue
- # only take linear terms
- if len(x_monomial) == 1:
- p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+ # take the first that matches
+ if all(variable not in remainder for variable in variable_monomial):
+ monomial_terms[remainder] += value
- monomial_terms[p_monomial] += value
+ break
terms[row, col] = dict(monomial_terms)
diff --git a/polymatrix/expression/mixins/fromarrayexprmixin.py b/polymatrix/expression/mixins/fromarrayexprmixin.py
index e79810c..179cec4 100644
--- a/polymatrix/expression/mixins/fromarrayexprmixin.py
+++ b/polymatrix/expression/mixins/fromarrayexprmixin.py
@@ -16,7 +16,7 @@ class FromArrayExprMixin(ExpressionBaseMixin):
pass
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
@@ -34,6 +34,7 @@ class FromArrayExprMixin(ExpressionBaseMixin):
for symbol in poly.gens:
state = state.register(key=symbol, n_param=1)
+ # print(f'{symbol}: {state.n_param}')
terms_row_col = {}
diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py
index 1af3f11..4ada8ec 100644
--- a/polymatrix/expression/mixins/fromtermsexprmixin.py
+++ b/polymatrix/expression/mixins/fromtermsexprmixin.py
@@ -21,7 +21,7 @@ class FromTermsExprMixin(ExpressionBaseMixin):
pass
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
index c8ac02e..fa7edb1 100644
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -12,6 +12,12 @@ from polymatrix.expression.expressionstate import ExpressionState
class GetItemExprMixin(ExpressionBaseMixin):
+ @dataclasses.dataclass(frozen=True)
+ class Slice:
+ start: int
+ step: int
+ stop: int
+
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -19,32 +25,70 @@ class GetItemExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def index(self) -> tuple[int, int]:
+ def index(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
+ def get_proper_index(index, shape):
+ if isinstance(index, tuple):
+ return index
+
+ elif isinstance(index, GetItemExprMixin.Slice):
+ if index.start is None:
+ start = 0
+ else:
+ start = index.start
+
+ if index.stop is None:
+ stop = shape
+ else:
+ stop = index.stop
+
+ if index.step is None:
+ step = 1
+ else:
+ step = index.step
+
+ return tuple(range(start, stop, step))
+
+ else:
+ return (index,)
+
+ proper_index = (
+ get_proper_index(self.index[0], underlying.shape[0]),
+ get_proper_index(self.index[1], underlying.shape[1]),
+ )
+
@dataclass_abc.dataclass_abc(frozen=True)
class GetItemPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
index: tuple[int, int]
- shape: tuple[int, int]
- # aux_terms: tuple
+
+ @property
+ def shape(self) -> tuple[int, int]:
+ return (len(self.index[0]), len(self.index[1]))
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- assert row == 0 and col == 0
+ try:
+ n_row = self.index[0][row]
+ except IndexError:
+ raise IndexError(f'tuple index {row} out of range given {self.index[0]}')
- return self.underlying.get_poly(self.index[0], self.index[1])
+ try:
+ n_col = self.index[1][col]
+ except IndexError:
+ raise IndexError(f'tuple index {col} out of range given {self.index[1]}')
+
+ return self.underlying.get_poly(n_row, n_col)
return state, GetItemPolyMatrix(
underlying=underlying,
- shape=(1, 1),
- index=self.index,
- # aux_terms=underlying.aux_terms,
+ index=proper_index,
)
\ No newline at end of file
diff --git a/polymatrix/expression/mixins/kktexprmixin.py b/polymatrix/expression/mixins/kktexprmixin.py
index 20949a1..d4a3d97 100644
--- a/polymatrix/expression/mixins/kktexprmixin.py
+++ b/polymatrix/expression/mixins/kktexprmixin.py
@@ -12,6 +12,7 @@ from polymatrix.expression.polymatrix import PolyMatrix
from polymatrix.expression.expressionstate import ExpressionState
+# todo: remove this?
class KKTExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
@@ -39,7 +40,7 @@ class KKTExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index 9ea4e92..b690891 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -18,22 +18,32 @@ class LinearInExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
+ def expression(self) -> ExpressionBaseMixin:
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- # todo: uncomment this
- state, variable_indices = get_variable_indices(state, self.variables)
- # variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+ state, variables = self.expression.apply(state=state)
- terms = {}
- idx_row = 0
+ def gen_variable_monomials():
+ for _, term in variables.get_terms():
+ assert len(term) == 1, f'{term} should have only a single monomial'
+
+ for monomial in term.keys():
+ yield monomial
+
+ variable_monomials = tuple(gen_variable_monomials())
+ all_variables = set(variable for monomial in variable_monomials for variable in monomial)
+
+ # print(variables)
+ # print(variable_monomials)
+
+ terms = collections.defaultdict(dict)
for row in range(underlying.shape[0]):
for col in range(underlying.shape[1]):
@@ -43,25 +53,33 @@ class LinearInExprMixin(ExpressionBaseMixin):
except KeyError:
continue
- x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float))
-
for monomial, value in underlying_terms.items():
- x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
- p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+ for variable_index, variable_monomial in enumerate(variable_monomials):
+
+ remainder = list(monomial)
+
+ try:
+ for variable in variable_monomial:
+ remainder.remove(variable)
+
+ except ValueError:
+ continue
- assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
+ # print(remainder)
+ # print(variable_monomial)
+ # print(all(variable not in remainder for variable in all_variables))
- x_monomial_terms[x_monomial][p_monomial] += value
+ # take the first that matches
+ if all(variable not in remainder for variable in all_variables):
- for data in x_monomial_terms.values():
- terms[idx_row, 0] = dict(data)
- idx_row += 1
+ terms[row, variable_index][tuple(remainder)] = value
+ break
poly_matrix = init_poly_matrix(
- terms=terms,
- shape=(idx_row, 1),
+ terms=dict(terms),
+ shape=(underlying.shape[0], len(variable_monomials)),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index f1e1720..1f09fdc 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -20,14 +20,14 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- assert left.shape[1] == right.shape[0]
+ assert left.shape[1] == right.shape[0], f'{left.shape[1]} is not equal to {right.shape[0]}'
terms = {}
diff --git a/polymatrix/expression/mixins/oldlinearexprmixin.py b/polymatrix/expression/mixins/oldlinearexprmixin.py
new file mode 100644
index 0000000..1bdebd9
--- /dev/null
+++ b/polymatrix/expression/mixins/oldlinearexprmixin.py
@@ -0,0 +1,131 @@
+
+import abc
+import collections
+from numpy import var
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
+
+
+class OldLinearExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ state, variable_indices = get_variable_indices(state, self.variables)
+
+ underlying_terms = underlying.get_terms()
+
+ def gen_variable_terms():
+ for _, monomial_term in underlying_terms:
+ for monomial in monomial_term.keys():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ yield x_monomial
+
+ variable_terms = tuple(set(gen_variable_terms()))
+
+ terms = {}
+
+ for (row, _), monomial_term in underlying_terms:
+
+ x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float))
+
+ for monomial, value in monomial_term.items():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+
+ assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
+
+ x_monomial_terms[x_monomial][p_monomial] += value
+
+ for x_monomial, data in x_monomial_terms.items():
+ terms[row, variable_terms.index(x_monomial)] = dict(data)
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=(underlying.shape[0], len(variable_terms)),
+ )
+
+ return state, poly_matrix
+
+
+ # for row in range(underlying.shape[0]):
+ # for col in range(underlying.shape[1]):
+
+ # try:
+ # underlying_terms = underlying.get_poly(row, col)
+ # except KeyError:
+ # continue
+
+ # x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float))
+
+ # for monomial, value in underlying_terms.items():
+
+ # x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ # p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+
+ # assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
+
+ # x_monomial_terms[x_monomial][p_monomial] += value
+
+ # for data in x_monomial_terms.values():
+ # terms[idx_row, 0] = dict(data)
+ # idx_row += 1
+
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=(idx_row, 1),
+ # )
+
+ # return state, poly_matrix
+
+ # terms = {}
+ # idx_row = 0
+
+ # for row in range(underlying.shape[0]):
+ # for col in range(underlying.shape[1]):
+
+ # try:
+ # underlying_terms = underlying.get_poly(row, col)
+ # except KeyError:
+ # continue
+
+ # x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float))
+
+ # for monomial, value in underlying_terms.items():
+
+ # x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ # p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+
+ # assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
+
+ # x_monomial_terms[x_monomial][p_monomial] += value
+
+ # for data in x_monomial_terms.values():
+ # terms[idx_row, 0] = dict(data)
+ # idx_row += 1
+
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=(idx_row, 1),
+ # )
+
+ # return state, poly_matrix
diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
index 255a5da..d678db5 100644
--- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
@@ -1,15 +1,12 @@
import abc
import dataclasses
-import functools
-import dataclass_abc
-from polymatrix.expression.init.initexpressionstate import init_expression_state
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
-from polymatrix.expression.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin
from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
class ParametrizeTermsExprMixin(ExpressionBaseMixin):
@@ -29,7 +26,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
@@ -37,13 +34,14 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
if self in state.cache:
return state, state.cache[self]
- # if not hasattr(self, '_terms'):
state, underlying = self.underlying.apply(state)
- for variable in self.variables:
- state = state.register(key=variable, n_param=1)
+ # for variable in self.variables:
+ # state = state.register(key=variable, n_param=1)
- variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+ # variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+
+ state, variable_indices = get_variable_indices(state, self.variables)
start_index = state.n_param
terms = {}
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
deleted file mode 100644
index 1fff79b..0000000
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ /dev/null
@@ -1,68 +0,0 @@
-
-import abc
-
-from polymatrix.expression.init.initpolymatrix import init_poly_matrix
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expression.polymatrix import PolyMatrix
-from polymatrix.expression.expressionstate import ExpressionState
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
-
-
-class QuadraticInExprMixin(ExpressionBaseMixin):
- @property
- @abc.abstractmethod
- def underlying(self) -> ExpressionBaseMixin:
- ...
-
- @property
- @abc.abstractmethod
- def variables(self) -> tuple:
- ...
-
- # overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]:
- state, underlying = self.underlying.apply(state=state)
-
- assert underlying.shape == (1, 1)
-
- state, variable_indices = get_variable_indices(state, self.variables)
-
- terms = {}
-
- for row in range(underlying.shape[0]):
- for col in range(underlying.shape[1]):
-
- try:
- underlying_terms = underlying.get_poly(row, col)
- except KeyError:
- continue
-
- for monomial, value in underlying_terms.items():
-
- x_monomial = tuple(variable_indices.index(var_idx) for var_idx in monomial if var_idx in variable_indices)
- p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
-
- assert len(x_monomial) == 2, f'{x_monomial} should be of length 2'
- assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
-
- key = tuple(reversed(x_monomial))
-
- if key not in terms:
- terms[key] = {}
-
- monomial_terms = terms[key]
-
- if p_monomial not in monomial_terms:
- monomial_terms[p_monomial] = 0
-
- monomial_terms[p_monomial] += value
-
- poly_matrix = init_poly_matrix(
- terms=terms,
- shape=2*(len(self.variables),),
- )
-
- return state, poly_matrix
diff --git a/polymatrix/expression/mixins/quadraticreprexprmixin.py b/polymatrix/expression/mixins/quadraticreprexprmixin.py
new file mode 100644
index 0000000..29ef950
--- /dev/null
+++ b/polymatrix/expression/mixins/quadraticreprexprmixin.py
@@ -0,0 +1,87 @@
+
+import abc
+import collections
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
+
+
+class QuadraticReprExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> tuple:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape == (1, 1)
+
+ state, variable_indices = get_variable_indices(state, self.variables)
+
+ # state, variables = self.expression.apply(state=state)
+
+ def gen_variable_monomials():
+ _, monomials = underlying.get_terms()[0]
+
+ for monomial, value in monomials.items():
+ x_monomial = tuple(variable for variable in monomial if variable in variable_indices)
+ p_monomial = tuple(variable for variable in monomial if variable not in variable_indices)
+
+ assert len(x_monomial) % 2 == 0, f'length of {x_monomial} should be a multiple of 2'
+
+ x_monomial_sorted = tuple(sorted(x_monomial))
+
+ idx_half = int(len(x_monomial)/2)
+
+ yield x_monomial_sorted[:idx_half], x_monomial_sorted[idx_half:], p_monomial, value
+
+ underlying_monomials = tuple(gen_variable_monomials())
+
+ def gen_sos_monomials():
+ for left_monomial, right_monomial, _, _ in underlying_monomials:
+ yield left_monomial
+ yield right_monomial
+
+ sos_monomials = tuple(sorted(set(gen_sos_monomials()), key=lambda m: (len(m), m)))
+
+ # print(f'{underlying_monomials=}')
+ # print(f'{sos_monomials=}')
+
+ terms = collections.defaultdict(dict)
+
+ for left_monomial, right_monomial, p_monomial, value in underlying_monomials:
+
+ col = sos_monomials.index(left_monomial)
+ row = sos_monomials.index(right_monomial)
+ key = (row, col)
+
+ # print(key)
+
+ monomial_terms = terms[key]
+
+ if p_monomial not in monomial_terms:
+ monomial_terms[p_monomial] = 0
+
+ monomial_terms[p_monomial] += value
+
+ # print(terms)
+
+ poly_matrix = init_poly_matrix(
+ terms=dict(terms),
+ shape=2*(len(sos_monomials),),
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py
index 5d5c00f..90c6a66 100644
--- a/polymatrix/expression/mixins/repmatexprmixin.py
+++ b/polymatrix/expression/mixins/repmatexprmixin.py
@@ -17,7 +17,7 @@ class RepMatExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py
index 886f074..ad2a5a6 100644
--- a/polymatrix/expression/mixins/reshapeexprmixin.py
+++ b/polymatrix/expression/mixins/reshapeexprmixin.py
@@ -20,7 +20,7 @@ class ReshapeExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionStateMixin,
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py
new file mode 100644
index 0000000..47f2fcf
--- /dev/null
+++ b/polymatrix/expression/mixins/squeezeexprmixin.py
@@ -0,0 +1,53 @@
+
+import abc
+import collections
+import typing
+import dataclass_abc
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+
+
+class SqueezeExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape[1] == 1
+
+ terms = {}
+ row_index = 0
+
+ for row in range(underlying.shape[0]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ terms_row_col = {}
+
+ for monomial, value in underlying_terms.items():
+ if value != 0.0:
+ terms_row_col[monomial] = value
+
+ if len(terms_row_col):
+ terms[row_index, 0] = terms_row_col
+ row_index += 1
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=(row_index, 1),
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py
new file mode 100644
index 0000000..96bf3da
--- /dev/null
+++ b/polymatrix/expression/mixins/sumexprmixin.py
@@ -0,0 +1,49 @@
+
+import abc
+import collections
+import dataclasses
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class SumExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, underlying = self.underlying.apply(state)
+
+ terms = collections.defaultdict(dict)
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ term_monomials = terms[row, 0]
+
+ for monomial, coeff in underlying_terms.items():
+ if monomial in term_monomials:
+ term_monomials[monomial] += coeff
+ else:
+ term_monomials[monomial] = coeff
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=(underlying.shape[0], 1),
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py
new file mode 100644
index 0000000..e914895
--- /dev/null
+++ b/polymatrix/expression/mixins/symmetricexprmixin.py
@@ -0,0 +1,68 @@
+
+import abc
+import itertools
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+
+
+class SymmetricExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape[0] == underlying.shape[1]
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class SymmetricPolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+
+ def gen_symmetric_monomials():
+ for i_row, i_col in ((row, col), (col, row)):
+ try:
+ monomials = self.underlying.get_poly(i_row, i_col)
+ except:
+ pass
+ else:
+ yield monomials
+
+ all_monomials = tuple(gen_symmetric_monomials())
+
+ if len(all_monomials) == 0:
+ raise KeyError()
+
+ else:
+ terms = {}
+
+ # merge monomials
+ for monomials in all_monomials:
+ for monomial, value in monomials.items():
+ if monomial in terms:
+ terms[monomial] = (terms[monomial] + value) / 2
+ else:
+ terms[monomial] = value
+ return terms
+
+ polymat = SymmetricPolyMatrix(
+ underlying=underlying,
+ )
+
+ return state, polymat
diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py
new file mode 100644
index 0000000..1276251
--- /dev/null
+++ b/polymatrix/expression/mixins/toconstantexprmixin.py
@@ -0,0 +1,44 @@
+
+import abc
+import collections
+from numpy import var
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.expression.expressionstate import ExpressionState
+from polymatrix.expression.utils.getvariableindices import get_variable_indices
+
+
+class ToConstantExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionState,
+ ) -> tuple[ExpressionState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ terms = collections.defaultdict(dict)
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ if tuple() in underlying_terms:
+ terms[row, col][tuple()] = underlying_terms[tuple()]
+
+ poly_matrix = init_poly_matrix(
+ terms=dict(terms),
+ shape=underlying.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py
index 0aac30a..2b62a89 100644
--- a/polymatrix/expression/mixins/toquadraticexprmixin.py
+++ b/polymatrix/expression/mixins/toquadraticexprmixin.py
@@ -16,7 +16,7 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py
index fa074d7..da1c288 100644
--- a/polymatrix/expression/mixins/transposeexprmixin.py
+++ b/polymatrix/expression/mixins/transposeexprmixin.py
@@ -18,7 +18,7 @@ class TransposeExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `PolyMatrixExprBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
index 90cfb20..b27828e 100644
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -16,7 +16,7 @@ class VStackExprMixin(ExpressionBaseMixin):
...
# overwrites abstract method of `ExpressionBaseMixin`
- def _apply(
+ def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
@@ -26,7 +26,8 @@ class VStackExprMixin(ExpressionBaseMixin):
state, polymat = expr.apply(state=state)
all_underlying.append(polymat)
- assert all(underlying.shape[1] == all_underlying[0].shape[1] for underlying in all_underlying)
+ for underlying in all_underlying:
+ assert underlying.shape[1] == all_underlying[0].shape[1], f'{underlying.shape[1]} not equal {all_underlying[0].shape[1]}'
@dataclass_abc.dataclass_abc(frozen=True)
class VStackPolyMatrix(PolyMatrixMixin):
diff --git a/polymatrix/expression/quadraticinexpr.py b/polymatrix/expression/quadraticinexpr.py
index bab76f6..59f452b 100644
--- a/polymatrix/expression/quadraticinexpr.py
+++ b/polymatrix/expression/quadraticinexpr.py
@@ -1,4 +1,4 @@
-from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin
+from polymatrix.expression.mixins.quadraticreprexprmixin import QuadraticReprExprMixin
-class QuadraticInExpr(QuadraticInExprMixin):
+class QuadraticInExpr(QuadraticReprExprMixin):
pass
diff --git a/polymatrix/expression/squeezeexpr.py b/polymatrix/expression/squeezeexpr.py
new file mode 100644
index 0000000..5472764
--- /dev/null
+++ b/polymatrix/expression/squeezeexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin
+
+class SqueezeExpr(SqueezeExprMixin):
+ pass
diff --git a/polymatrix/expression/sumexpr.py b/polymatrix/expression/sumexpr.py
new file mode 100644
index 0000000..7b62e59
--- /dev/null
+++ b/polymatrix/expression/sumexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.sumexprmixin import SumExprMixin
+
+class SumExpr(SumExprMixin):
+ pass
diff --git a/polymatrix/expression/symmetricexpr.py b/polymatrix/expression/symmetricexpr.py
new file mode 100644
index 0000000..ffbaa90
--- /dev/null
+++ b/polymatrix/expression/symmetricexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin
+
+class SymmetricExpr(SymmetricExprMixin):
+ pass
diff --git a/polymatrix/expression/toconstantexpr.py b/polymatrix/expression/toconstantexpr.py
new file mode 100644
index 0000000..d1fb2a9
--- /dev/null
+++ b/polymatrix/expression/toconstantexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.toconstantexprmixin import ToConstantExprMixin
+
+class ToConstantExpr(ToConstantExprMixin):
+ pass
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index 46bb51b..71a91f6 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -14,16 +14,18 @@ def get_variable_indices(state, variables):
for row in range(variables.shape[0]):
row_terms = variables.get_poly(row, 0)
- assert len(row_terms) == 1
+ assert len(row_terms) == 1, f'{row_terms} contains more than one term'
for monomial in row_terms.keys():
- assert len(monomial) == 1
+ assert len(monomial) == 1, f'{monomial} contains more than one variable'
yield monomial[0]
return state, tuple(gen_indices())
else:
+ raise Exception('not supported anymore')
+
if not isinstance(variables, tuple):
variables = (variables,)
diff --git a/polymatrix/sympyutils.py b/polymatrix/sympyutils.py
index ac91b8f..3e727ed 100644
--- a/polymatrix/sympyutils.py
+++ b/polymatrix/sympyutils.py
@@ -18,7 +18,7 @@ def poly_to_data_coord(poly_list, x, degree = None):
sympy_poly_list = tuple(tuple(sympy.poly(p, x) for p in inner_poly_list) for inner_poly_list in poly_list)
if degree is None:
- degree = max(degree for inner_poly_list in sympy_poly_list for poly in inner_poly_list for degree in poly.degree_list())
+ degree = max(sum(monom) for inner_poly_list in sympy_poly_list for poly in inner_poly_list for monom in poly.monoms())
n = len(x)