summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/__init__.py54
-rw-r--r--polymatrix/addexpr.py4
-rw-r--r--polymatrix/exprcontainer.py10
-rw-r--r--polymatrix/expression/__init__.py0
-rw-r--r--polymatrix/expression/derivativeexpr.py4
-rw-r--r--polymatrix/expression/derivativekey.py4
-rw-r--r--polymatrix/expression/determinantexpr.py4
-rw-r--r--polymatrix/expression/divisionexpr.py4
-rw-r--r--polymatrix/expression/elemmultexpr.py4
-rw-r--r--polymatrix/expression/evalexpr.py4
-rw-r--r--polymatrix/expression/expression.py4
-rw-r--r--polymatrix/expression/expressionstate.py4
-rw-r--r--polymatrix/expression/forallexpr.py4
-rw-r--r--polymatrix/expression/fromarrayexpr.py4
-rw-r--r--polymatrix/expression/getitemexpr.py4
-rw-r--r--polymatrix/expression/impl/derivativeexprimpl.py10
-rw-r--r--polymatrix/expression/impl/derivativekeyimpl.py8
-rw-r--r--polymatrix/expression/impl/determinantexprimpl.py8
-rw-r--r--polymatrix/expression/impl/divisionexprimpl.py9
-rw-r--r--polymatrix/expression/impl/elemmultexprimpl.py9
-rw-r--r--polymatrix/expression/impl/evalexprimpl.py10
-rw-r--r--polymatrix/expression/impl/expressionimpl.py8
-rw-r--r--polymatrix/expression/impl/expressionstateimpl.py11
-rw-r--r--polymatrix/expression/impl/forallexprimpl.py9
-rw-r--r--polymatrix/expression/impl/fromarrayexprimpl.py8
-rw-r--r--polymatrix/expression/impl/getitemexprimpl.py9
-rw-r--r--polymatrix/expression/impl/kktexprimpl.py11
-rw-r--r--polymatrix/expression/impl/parametrizesymbolsexprimpl.py10
-rw-r--r--polymatrix/expression/impl/polymatriximpl.py9
-rw-r--r--polymatrix/expression/impl/quadraticinexprimpl.py9
-rw-r--r--polymatrix/expression/impl/repmatexprimpl.py9
-rw-r--r--polymatrix/expression/impl/transposeexprimpl.py8
-rw-r--r--polymatrix/expression/impl/vstackexprimpl.py7
-rw-r--r--polymatrix/expression/init/initderivativeexpr.py17
-rw-r--r--polymatrix/expression/init/initderivativekey.py11
-rw-r--r--polymatrix/expression/init/initdeterminantexpr.py10
-rw-r--r--polymatrix/expression/init/initdivisionexpr.py12
-rw-r--r--polymatrix/expression/init/initelemmultexpr.py12
-rw-r--r--polymatrix/expression/init/initevalexpr.py21
-rw-r--r--polymatrix/expression/init/initexpression.py10
-rw-r--r--polymatrix/expression/init/initexpressionstate.py19
-rw-r--r--polymatrix/expression/init/initforallexpr.py12
-rw-r--r--polymatrix/expression/init/initfromarrayexpr.py31
-rw-r--r--polymatrix/expression/init/initgetitemexpr.py12
-rw-r--r--polymatrix/expression/init/initkktexpr.py15
-rw-r--r--polymatrix/expression/init/initparametrizesymbolsexpr.py14
-rw-r--r--polymatrix/expression/init/initpolymatrix.py16
-rw-r--r--polymatrix/expression/init/initquadraticinexpr.py12
-rw-r--r--polymatrix/expression/init/initrepmatexpr.py12
-rw-r--r--polymatrix/expression/init/inittransposeexpr.py10
-rw-r--r--polymatrix/expression/init/initvstackexpr.py9
-rw-r--r--polymatrix/expression/kktexpr.py4
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py70
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py155
-rw-r--r--polymatrix/expression/mixins/derivativekeymixin.py12
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py115
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py81
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py74
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py93
-rw-r--r--polymatrix/expression/mixins/expressionbasemixin.py17
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py213
-rw-r--r--polymatrix/expression/mixins/expressionstatemixin.py54
-rw-r--r--polymatrix/expression/mixins/forallexprmixin.py72
-rw-r--r--polymatrix/expression/mixins/fromarrayexprmixin.py76
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py55
-rw-r--r--polymatrix/expression/mixins/kktexprmixin.py126
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py75
-rw-r--r--polymatrix/expression/mixins/parametrizetermsexprmixin.py175
-rw-r--r--polymatrix/expression/mixins/polymatrixasdictmixin.py28
-rw-r--r--polymatrix/expression/mixins/polymatrixmixin.py109
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py72
-rw-r--r--polymatrix/expression/mixins/repmatexprmixin.py49
-rw-r--r--polymatrix/expression/mixins/transposeexprmixin.py43
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py66
-rw-r--r--polymatrix/expression/parametrizesymbolsexpr.py4
-rw-r--r--polymatrix/expression/polymatrix.py4
-rw-r--r--polymatrix/expression/quadraticinexpr.py4
-rw-r--r--polymatrix/expression/repmatexpr.py4
-rw-r--r--polymatrix/expression/transposeexpr.py4
-rw-r--r--polymatrix/expression/vstackexpr.py4
-rw-r--r--polymatrix/impl/addexprimpl.py8
-rw-r--r--polymatrix/impl/exprcontainerimpl.py8
-rw-r--r--polymatrix/impl/multexprimpl.py8
-rw-r--r--polymatrix/impl/oldpolymatriximpl.py (renamed from polymatrix/impl/polymatriximpl.py)4
-rw-r--r--polymatrix/impl/optimizationimpl.py4
-rw-r--r--polymatrix/impl/optimizationstateimpl.py8
-rw-r--r--polymatrix/impl/polymatexprimpl.py7
-rw-r--r--polymatrix/impl/polymatrixaddexprimpl.py9
-rw-r--r--polymatrix/impl/polymatrixarrayexprimpl.py8
-rw-r--r--polymatrix/impl/polymatrixexprimpl.py8
-rw-r--r--polymatrix/impl/polymatrixexprstateimpl.py12
-rw-r--r--polymatrix/impl/polymatrixmultexprimpl.py9
-rw-r--r--polymatrix/impl/polymatrixparamexprimpl.py9
-rw-r--r--polymatrix/impl/polymatrixvalueimpl.py7
-rw-r--r--polymatrix/impl/scalarmultexprimpl.py8
-rw-r--r--polymatrix/init/initaddexpr.py11
-rw-r--r--polymatrix/init/initexprcontainer.py9
-rw-r--r--polymatrix/init/initmultexpr.py11
-rw-r--r--polymatrix/init/initoptimization.py2
-rw-r--r--polymatrix/init/initoptimizationstate.py4
-rw-r--r--polymatrix/init/initpolymatexpr.py10
-rw-r--r--polymatrix/init/initpolymatrixaddexpr.py12
-rw-r--r--polymatrix/init/initpolymatrixarrayexpr.py10
-rw-r--r--polymatrix/init/initpolymatrixexpr.py10
-rw-r--r--polymatrix/init/initpolymatrixexprstate.py26
-rw-r--r--polymatrix/init/initpolymatrixmultexpr.py12
-rw-r--r--polymatrix/init/initpolymatrixparamexpr.py33
-rw-r--r--polymatrix/init/initpolymatrixvalue.py10
-rw-r--r--polymatrix/init/initscalarmultexpr.py11
-rw-r--r--polymatrix/init/oldinitpolymatrix.py (renamed from polymatrix/init/initpolymatrix.py)7
-rw-r--r--polymatrix/mixins/addexprmixin.py16
-rw-r--r--polymatrix/mixins/exprcontainermixin.py82
-rw-r--r--polymatrix/mixins/multexprmixin.py15
-rw-r--r--polymatrix/mixins/oldpolymatrixexprmixin.py2
-rw-r--r--polymatrix/mixins/oldpolymatrixexprstatemixin.py (renamed from polymatrix/mixins/optimizationstatemixin.py)10
-rw-r--r--polymatrix/mixins/oldpolymatrixmixin.py (renamed from polymatrix/mixins/polymatrixmixin.py)2
-rw-r--r--polymatrix/mixins/optimizationmixin.py288
-rw-r--r--polymatrix/mixins/optimizationpipeopmixin.py12
-rw-r--r--polymatrix/mixins/scalarmultexprmixin.py14
-rw-r--r--polymatrix/mixins/statemonadmixin.py18
-rw-r--r--polymatrix/multexpr.py4
-rw-r--r--polymatrix/oldpolymatrix.py6
-rw-r--r--polymatrix/oldpolymatrixexprstate.py5
-rw-r--r--polymatrix/optimizationstate.py5
-rw-r--r--polymatrix/polymatrix.py5
-rw-r--r--polymatrix/polymatrixaddexpr.py4
-rw-r--r--polymatrix/polymatrixarrayexpr.py4
-rw-r--r--polymatrix/polymatrixexpr.py4
-rw-r--r--polymatrix/polymatrixexprstate.py4
-rw-r--r--polymatrix/polymatrixmultexpr.py4
-rw-r--r--polymatrix/polymatrixparamexpr.py4
-rw-r--r--polymatrix/polysolver.py4
-rw-r--r--polymatrix/polystruct.py480
-rw-r--r--polymatrix/scalarmultexpr.py4
-rw-r--r--polymatrix/statemonad.py2
135 files changed, 2970 insertions, 710 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index e69de29..0f9bd81 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -0,0 +1,54 @@
+# import typing
+# from polymatrix.init.initpolymatrixexpr import init_poly_matrix_expr
+# from polymatrix.init.initpolymatrixparamexpr import init_poly_matrix_param_expr
+
+
+from ast import Expression
+from polymatrix.expression.init.initexpression import init_expression
+from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
+from polymatrix.expression.init.initkktexpr import init_kkt_expr
+from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
+
+
+def from_(
+ data: tuple[tuple[float]],
+ # name: str,
+ # shape: tuple,
+ # degrees: tuple,
+ # re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None,
+):
+ return init_expression(
+ init_from_array_expr(data)
+ )
+
+ # return init_poly_matrix_expr(
+ # underlying=init_poly_matrix_param_expr(
+ # name=name,
+ # degrees=degrees,
+ # shape=shape,
+ # re_index=re_index,
+ # )
+ # )
+
+def v_stack(
+ expressions: tuple[Expression],
+):
+ return init_expression(
+ init_v_stack_expr(expressions)
+ )
+
+
+def kkt(
+ cost: Expression,
+ equality: Expression,
+ variables: Expression
+):
+ return init_expression(
+ init_kkt_expr(
+ cost=cost,
+ equality=equality,
+ variables=variables,
+ )
+ )
+
+
diff --git a/polymatrix/addexpr.py b/polymatrix/addexpr.py
new file mode 100644
index 0000000..983645e
--- /dev/null
+++ b/polymatrix/addexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.mixins.addexprmixin import AddExprMixin
+
+class AddExpr(AddExprMixin):
+ pass \ No newline at end of file
diff --git a/polymatrix/exprcontainer.py b/polymatrix/exprcontainer.py
new file mode 100644
index 0000000..e5766f1
--- /dev/null
+++ b/polymatrix/exprcontainer.py
@@ -0,0 +1,10 @@
+import abc
+import typing
+from polymatrix.mixins.exprcontainermixin import ExprContainerMixin, ExprType
+
+class ExprContainer(
+ ExprContainerMixin[ExprType],
+ typing.Generic[ExprType],
+ abc.ABC,
+):
+ pass
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/polymatrix/expression/__init__.py
diff --git a/polymatrix/expression/derivativeexpr.py b/polymatrix/expression/derivativeexpr.py
new file mode 100644
index 0000000..6805e4b
--- /dev/null
+++ b/polymatrix/expression/derivativeexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.derivativeexprmixin import DerivativeExprMixin
+
+class DerivativeExpr(DerivativeExprMixin):
+ pass
diff --git a/polymatrix/expression/derivativekey.py b/polymatrix/expression/derivativekey.py
new file mode 100644
index 0000000..6d1fff8
--- /dev/null
+++ b/polymatrix/expression/derivativekey.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.derivativekeymixin import DerivativeKeyMixin
+
+class DerivativeKey(DerivativeKeyMixin):
+ pass
diff --git a/polymatrix/expression/determinantexpr.py b/polymatrix/expression/determinantexpr.py
new file mode 100644
index 0000000..6e584f9
--- /dev/null
+++ b/polymatrix/expression/determinantexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.determinantexprmixin import DeterminantExprMixin
+
+class DeterminantExpr(DeterminantExprMixin):
+ pass
diff --git a/polymatrix/expression/divisionexpr.py b/polymatrix/expression/divisionexpr.py
new file mode 100644
index 0000000..63c95d9
--- /dev/null
+++ b/polymatrix/expression/divisionexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.divisionexprmixin import DivisionExprMixin
+
+class DivisionExpr(DivisionExprMixin):
+ pass
diff --git a/polymatrix/expression/elemmultexpr.py b/polymatrix/expression/elemmultexpr.py
new file mode 100644
index 0000000..1c8ef58
--- /dev/null
+++ b/polymatrix/expression/elemmultexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.elemmultexprmixin import ElemMultExprMixin
+
+class ElemMultExpr(ElemMultExprMixin):
+ pass
diff --git a/polymatrix/expression/evalexpr.py b/polymatrix/expression/evalexpr.py
new file mode 100644
index 0000000..0af72d1
--- /dev/null
+++ b/polymatrix/expression/evalexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.evalexprmixin import EvalExprMixin
+
+class EvalExpr(EvalExprMixin):
+ pass
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
new file mode 100644
index 0000000..7f34637
--- /dev/null
+++ b/polymatrix/expression/expression.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
+
+class Expression(ExpressionMixin):
+ pass
diff --git a/polymatrix/expression/expressionstate.py b/polymatrix/expression/expressionstate.py
new file mode 100644
index 0000000..e4b97aa
--- /dev/null
+++ b/polymatrix/expression/expressionstate.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+
+class ExpressionState(ExpressionStateMixin):
+ pass
diff --git a/polymatrix/expression/forallexpr.py b/polymatrix/expression/forallexpr.py
new file mode 100644
index 0000000..972d553
--- /dev/null
+++ b/polymatrix/expression/forallexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.forallexprmixin import ForAllExprMixin
+
+class ForAllExpr(ForAllExprMixin):
+ pass
diff --git a/polymatrix/expression/fromarrayexpr.py b/polymatrix/expression/fromarrayexpr.py
new file mode 100644
index 0000000..b22792b
--- /dev/null
+++ b/polymatrix/expression/fromarrayexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.fromarrayexprmixin import FromArrayExprMixin
+
+class FromArrayExpr(FromArrayExprMixin):
+ pass
diff --git a/polymatrix/expression/getitemexpr.py b/polymatrix/expression/getitemexpr.py
new file mode 100644
index 0000000..e7d32c2
--- /dev/null
+++ b/polymatrix/expression/getitemexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.getitemexprmixin import GetItemExprMixin
+
+class GetItemExpr(GetItemExprMixin):
+ pass
diff --git a/polymatrix/expression/impl/derivativeexprimpl.py b/polymatrix/expression/impl/derivativeexprimpl.py
new file mode 100644
index 0000000..87884ab
--- /dev/null
+++ b/polymatrix/expression/impl/derivativeexprimpl.py
@@ -0,0 +1,10 @@
+import dataclass_abc
+from polymatrix.expression.derivativeexpr import DerivativeExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class DerivativeExprImpl(DerivativeExpr):
+ underlying: ExpressionBaseMixin
+ variables: tuple
+ introduce_derivatives: bool
diff --git a/polymatrix/expression/impl/derivativekeyimpl.py b/polymatrix/expression/impl/derivativekeyimpl.py
new file mode 100644
index 0000000..675503a
--- /dev/null
+++ b/polymatrix/expression/impl/derivativekeyimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.derivativekey import DerivativeKey
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class DerivativeKeyImpl(DerivativeKey):
+ variable: int
+ with_respect_to: int
diff --git a/polymatrix/expression/impl/determinantexprimpl.py b/polymatrix/expression/impl/determinantexprimpl.py
new file mode 100644
index 0000000..3dc4ace
--- /dev/null
+++ b/polymatrix/expression/impl/determinantexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.determinantexpr import DeterminantExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class DeterminantExprImpl(DeterminantExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/divisionexprimpl.py b/polymatrix/expression/impl/divisionexprimpl.py
new file mode 100644
index 0000000..c2c8966
--- /dev/null
+++ b/polymatrix/expression/impl/divisionexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.divisionexpr import DivisionExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class DivisionExprImpl(DivisionExpr):
+ left: ExpressionBaseMixin
+ right: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/elemmultexprimpl.py b/polymatrix/expression/impl/elemmultexprimpl.py
new file mode 100644
index 0000000..a9990e4
--- /dev/null
+++ b/polymatrix/expression/impl/elemmultexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.elemmultexpr import ElemMultExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ElemMultExprImpl(ElemMultExpr):
+ left: ExpressionBaseMixin
+ right: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/evalexprimpl.py b/polymatrix/expression/impl/evalexprimpl.py
new file mode 100644
index 0000000..1730170
--- /dev/null
+++ b/polymatrix/expression/impl/evalexprimpl.py
@@ -0,0 +1,10 @@
+import dataclass_abc
+from polymatrix.expression.evalexpr import EvalExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class EvalExprImpl(EvalExpr):
+ underlying: ExpressionBaseMixin
+ variables: tuple
+ eval_values: tuple
diff --git a/polymatrix/expression/impl/expressionimpl.py b/polymatrix/expression/impl/expressionimpl.py
new file mode 100644
index 0000000..ab1da5d
--- /dev/null
+++ b/polymatrix/expression/impl/expressionimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.expression import Expression
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ExpressionImpl(Expression):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/expressionstateimpl.py b/polymatrix/expression/impl/expressionstateimpl.py
new file mode 100644
index 0000000..89afb56
--- /dev/null
+++ b/polymatrix/expression/impl/expressionstateimpl.py
@@ -0,0 +1,11 @@
+import dataclass_abc
+from polymatrix.expression.expressionstate import ExpressionState
+
+from typing import Optional
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ExpressionStateImpl(ExpressionState):
+ n_param: int
+ offset_dict: dict
+ auxillary_terms: tuple[dict[tuple[int], float]]
+ cached_polymatrix: dict
diff --git a/polymatrix/expression/impl/forallexprimpl.py b/polymatrix/expression/impl/forallexprimpl.py
new file mode 100644
index 0000000..0960479
--- /dev/null
+++ b/polymatrix/expression/impl/forallexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.forallexpr import ForAllExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ForAllExprImpl(ForAllExpr):
+ underlying: ExpressionBaseMixin
+ variables: tuple
diff --git a/polymatrix/expression/impl/fromarrayexprimpl.py b/polymatrix/expression/impl/fromarrayexprimpl.py
new file mode 100644
index 0000000..7d7c45e
--- /dev/null
+++ b/polymatrix/expression/impl/fromarrayexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.fromarrayexpr import FromArrayExpr
+
+from numpy import array
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class FromArrayExprImpl(FromArrayExpr):
+ data: tuple[tuple[float]]
diff --git a/polymatrix/expression/impl/getitemexprimpl.py b/polymatrix/expression/impl/getitemexprimpl.py
new file mode 100644
index 0000000..b8972a9
--- /dev/null
+++ b/polymatrix/expression/impl/getitemexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.getitemexpr import GetItemExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class GetItemExprImpl(GetItemExpr):
+ underlying: ExpressionBaseMixin
+ index: tuple
diff --git a/polymatrix/expression/impl/kktexprimpl.py b/polymatrix/expression/impl/kktexprimpl.py
new file mode 100644
index 0000000..e7df9bf
--- /dev/null
+++ b/polymatrix/expression/impl/kktexprimpl.py
@@ -0,0 +1,11 @@
+import dataclass_abc
+from polymatrix.expression.kktexpr import KKTExpr
+
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class KKTExprImpl(KKTExpr):
+ cost: ExpressionMixin
+ equality: ExpressionMixin
+ variables: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/parametrizesymbolsexprimpl.py b/polymatrix/expression/impl/parametrizesymbolsexprimpl.py
new file mode 100644
index 0000000..af1fe3e
--- /dev/null
+++ b/polymatrix/expression/impl/parametrizesymbolsexprimpl.py
@@ -0,0 +1,10 @@
+import dataclass_abc
+from polymatrix.expression.parametrizesymbolsexpr import ParametrizeSymbolsExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ParametrizeSymbolsExprImpl(ParametrizeSymbolsExpr):
+ name: str
+ underlying: ExpressionBaseMixin
+ variables: tuple
diff --git a/polymatrix/expression/impl/polymatriximpl.py b/polymatrix/expression/impl/polymatriximpl.py
new file mode 100644
index 0000000..8dd6fca
--- /dev/null
+++ b/polymatrix/expression/impl/polymatriximpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.polymatrix import PolyMatrix
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixImpl(PolyMatrix):
+ terms: dict
+ shape: tuple[int, ...]
+ # aux_terms: tuple[dict[tuple[int, ...], float]]
diff --git a/polymatrix/expression/impl/quadraticinexprimpl.py b/polymatrix/expression/impl/quadraticinexprimpl.py
new file mode 100644
index 0000000..8c7ac17
--- /dev/null
+++ b/polymatrix/expression/impl/quadraticinexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.quadraticinexpr import QuadraticInExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class QuadraticInExprImpl(QuadraticInExpr):
+ underlying: ExpressionBaseMixin
+ variables: tuple
diff --git a/polymatrix/expression/impl/repmatexprimpl.py b/polymatrix/expression/impl/repmatexprimpl.py
new file mode 100644
index 0000000..f8cee79
--- /dev/null
+++ b/polymatrix/expression/impl/repmatexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.expression.repmatexpr import RepMatExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class RepMatExprImpl(RepMatExpr):
+ underlying: ExpressionBaseMixin
+ repetition: tuple
diff --git a/polymatrix/expression/impl/transposeexprimpl.py b/polymatrix/expression/impl/transposeexprimpl.py
new file mode 100644
index 0000000..59b1a53
--- /dev/null
+++ b/polymatrix/expression/impl/transposeexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.transposeexpr import TransposeExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class TransposeExprImpl(TransposeExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/impl/vstackexprimpl.py b/polymatrix/expression/impl/vstackexprimpl.py
new file mode 100644
index 0000000..9ee582d
--- /dev/null
+++ b/polymatrix/expression/impl/vstackexprimpl.py
@@ -0,0 +1,7 @@
+import dataclass_abc
+from polymatrix.expression.vstackexpr import VStackExpr
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class VStackExprImpl(VStackExpr):
+ underlying: tuple
diff --git a/polymatrix/expression/init/initderivativeexpr.py b/polymatrix/expression/init/initderivativeexpr.py
new file mode 100644
index 0000000..c640f47
--- /dev/null
+++ b/polymatrix/expression/init/initderivativeexpr.py
@@ -0,0 +1,17 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.derivativeexprimpl import DerivativeExprImpl
+
+
+def init_derivative_expr(
+ underlying: ExpressionBaseMixin,
+ variables: tuple,
+ introduce_derivatives: bool = None,
+):
+ if introduce_derivatives is None:
+ introduce_derivatives = False
+
+ return DerivativeExprImpl(
+ underlying=underlying,
+ variables=variables,
+ introduce_derivatives=introduce_derivatives,
+)
diff --git a/polymatrix/expression/init/initderivativekey.py b/polymatrix/expression/init/initderivativekey.py
new file mode 100644
index 0000000..db249ec
--- /dev/null
+++ b/polymatrix/expression/init/initderivativekey.py
@@ -0,0 +1,11 @@
+from polymatrix.expression.impl.derivativekeyimpl import DerivativeKeyImpl
+
+
+def init_derivative_key(
+ variable: int,
+ with_respect_to: int
+):
+ return DerivativeKeyImpl(
+ variable=variable,
+ with_respect_to=with_respect_to,
+)
diff --git a/polymatrix/expression/init/initdeterminantexpr.py b/polymatrix/expression/init/initdeterminantexpr.py
new file mode 100644
index 0000000..f0d1bb5
--- /dev/null
+++ b/polymatrix/expression/init/initdeterminantexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.determinantexprimpl import DeterminantExprImpl
+
+
+def init_determinant_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return DeterminantExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initdivisionexpr.py b/polymatrix/expression/init/initdivisionexpr.py
new file mode 100644
index 0000000..2a701a9
--- /dev/null
+++ b/polymatrix/expression/init/initdivisionexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.divisionexprimpl import DivisionExprImpl
+
+
+def init_division_expr(
+ left: ExpressionBaseMixin,
+ right: ExpressionBaseMixin,
+):
+ return DivisionExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/expression/init/initelemmultexpr.py b/polymatrix/expression/init/initelemmultexpr.py
new file mode 100644
index 0000000..ae0e85f
--- /dev/null
+++ b/polymatrix/expression/init/initelemmultexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.elemmultexprimpl import ElemMultExprImpl
+
+
+def init_elem_mult_expr(
+ left: ExpressionBaseMixin,
+ right: ExpressionBaseMixin,
+):
+ return ElemMultExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
new file mode 100644
index 0000000..deeb4e9
--- /dev/null
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -0,0 +1,21 @@
+from tkinter import Variable
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.evalexprimpl import EvalExprImpl
+
+
+def init_eval_expr(
+ underlying: ExpressionBaseMixin,
+ variables: tuple,
+ eval_values: tuple,
+):
+ if not isinstance(variables, tuple):
+ variables = (variables,)
+
+ if not isinstance(eval_values, tuple):
+ eval_values = (eval_values,)
+
+ return EvalExprImpl(
+ underlying=underlying,
+ variables=variables,
+ eval_values=eval_values,
+)
diff --git a/polymatrix/expression/init/initexpression.py b/polymatrix/expression/init/initexpression.py
new file mode 100644
index 0000000..7262e47
--- /dev/null
+++ b/polymatrix/expression/init/initexpression.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.expressionimpl import ExpressionImpl
+
+
+def init_expression(
+ underlying: ExpressionBaseMixin,
+):
+ return ExpressionImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initexpressionstate.py b/polymatrix/expression/init/initexpressionstate.py
new file mode 100644
index 0000000..a7d3aac
--- /dev/null
+++ b/polymatrix/expression/init/initexpressionstate.py
@@ -0,0 +1,19 @@
+from polymatrix.expression.impl.expressionstateimpl import ExpressionStateImpl
+
+
+def init_expression_state(
+ n_param: int = None,
+ offset_dict: dict = None,
+):
+ if n_param is None:
+ n_param = 0
+
+ if offset_dict is None:
+ offset_dict = {}
+
+ return ExpressionStateImpl(
+ n_param=n_param,
+ offset_dict=offset_dict,
+ auxillary_terms=tuple(),
+ cached_polymatrix={},
+)
diff --git a/polymatrix/expression/init/initforallexpr.py b/polymatrix/expression/init/initforallexpr.py
new file mode 100644
index 0000000..84388d2
--- /dev/null
+++ b/polymatrix/expression/init/initforallexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.forallexprimpl import ForAllExprImpl
+
+
+def init_for_all_expr(
+ underlying: ExpressionBaseMixin,
+ variables: tuple,
+):
+ return ForAllExprImpl(
+ underlying=underlying,
+ variables=variables,
+)
diff --git a/polymatrix/expression/init/initfromarrayexpr.py b/polymatrix/expression/init/initfromarrayexpr.py
new file mode 100644
index 0000000..6aab26c
--- /dev/null
+++ b/polymatrix/expression/init/initfromarrayexpr.py
@@ -0,0 +1,31 @@
+import typing
+import numpy as np
+
+from polymatrix.expression.impl.fromarrayexprimpl import FromArrayExprImpl
+
+
+def init_from_array_expr(
+ data: typing.Union[np.ndarray, tuple[tuple[float]]],
+):
+
+ match data:
+ case np.ndarray():
+ data = tuple(tuple(i for i in row) for row in data)
+
+ case tuple():
+
+ match data[0]:
+
+ case tuple():
+ n_col = len(data[0])
+ assert all(len(col) == n_col for col in data)
+
+ case _:
+ data = (data,)
+
+ case _:
+ data = ((data,),)
+
+ return FromArrayExprImpl(
+ data=data,
+ )
diff --git a/polymatrix/expression/init/initgetitemexpr.py b/polymatrix/expression/init/initgetitemexpr.py
new file mode 100644
index 0000000..140fa3a
--- /dev/null
+++ b/polymatrix/expression/init/initgetitemexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.getitemexprimpl import GetItemExprImpl
+
+
+def init_get_item_expr(
+ underlying: ExpressionBaseMixin,
+ index: tuple,
+):
+ return GetItemExprImpl(
+ underlying=underlying,
+ index=index,
+)
diff --git a/polymatrix/expression/init/initkktexpr.py b/polymatrix/expression/init/initkktexpr.py
new file mode 100644
index 0000000..d8bfba8
--- /dev/null
+++ b/polymatrix/expression/init/initkktexpr.py
@@ -0,0 +1,15 @@
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.kktexprimpl import KKTExprImpl
+
+
+def init_kkt_expr(
+ cost: ExpressionMixin,
+ equality: ExpressionMixin,
+ variables: ExpressionBaseMixin,
+):
+ return KKTExprImpl(
+ cost=cost,
+ equality=equality,
+ variables=variables,
+)
diff --git a/polymatrix/expression/init/initparametrizesymbolsexpr.py b/polymatrix/expression/init/initparametrizesymbolsexpr.py
new file mode 100644
index 0000000..8d7db55
--- /dev/null
+++ b/polymatrix/expression/init/initparametrizesymbolsexpr.py
@@ -0,0 +1,14 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.parametrizesymbolsexprimpl import ParametrizeSymbolsExprImpl
+
+
+def init_parametrize_symbols_expr(
+ name: str,
+ underlying: ExpressionBaseMixin,
+ variables: tuple,
+):
+ return ParametrizeSymbolsExprImpl(
+ name=name,
+ underlying=underlying,
+ variables=variables,
+)
diff --git a/polymatrix/expression/init/initpolymatrix.py b/polymatrix/expression/init/initpolymatrix.py
new file mode 100644
index 0000000..e6a6cde
--- /dev/null
+++ b/polymatrix/expression/init/initpolymatrix.py
@@ -0,0 +1,16 @@
+from polymatrix.expression.impl.polymatriximpl import PolyMatrixImpl
+
+
+def init_poly_matrix(
+ terms: dict,
+ shape: tuple,
+ aux_terms: tuple[dict[tuple[int, ...], float]] = None,
+):
+ if aux_terms is None:
+ aux_terms = tuple()
+
+ return PolyMatrixImpl(
+ terms=terms,
+ shape=shape,
+ # aux_terms=aux_terms,
+)
diff --git a/polymatrix/expression/init/initquadraticinexpr.py b/polymatrix/expression/init/initquadraticinexpr.py
new file mode 100644
index 0000000..1e73745
--- /dev/null
+++ b/polymatrix/expression/init/initquadraticinexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.quadraticinexprimpl import QuadraticInExprImpl
+
+
+def init_quadratic_in_expr(
+ underlying: ExpressionBaseMixin,
+ variables: tuple,
+):
+ return QuadraticInExprImpl(
+ underlying=underlying,
+ variables=variables,
+)
diff --git a/polymatrix/expression/init/initrepmatexpr.py b/polymatrix/expression/init/initrepmatexpr.py
new file mode 100644
index 0000000..2f7fb7c
--- /dev/null
+++ b/polymatrix/expression/init/initrepmatexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.repmatexprimpl import RepMatExprImpl
+
+
+def init_rep_mat_expr(
+ underlying: ExpressionBaseMixin,
+ repetition: tuple,
+):
+ return RepMatExprImpl(
+ underlying=underlying,
+ repetition=repetition,
+)
diff --git a/polymatrix/expression/init/inittransposeexpr.py b/polymatrix/expression/init/inittransposeexpr.py
new file mode 100644
index 0000000..9ca9d91
--- /dev/null
+++ b/polymatrix/expression/init/inittransposeexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.transposeexprimpl import TransposeExprImpl
+
+
+def init_transpose_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return TransposeExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/init/initvstackexpr.py b/polymatrix/expression/init/initvstackexpr.py
new file mode 100644
index 0000000..1d7834e
--- /dev/null
+++ b/polymatrix/expression/init/initvstackexpr.py
@@ -0,0 +1,9 @@
+from polymatrix.expression.impl.vstackexprimpl import VStackExprImpl
+
+
+def init_v_stack_expr(
+ underlying: tuple,
+):
+ return VStackExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/kktexpr.py b/polymatrix/expression/kktexpr.py
new file mode 100644
index 0000000..aa5be52
--- /dev/null
+++ b/polymatrix/expression/kktexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.kktexprmixin import KKTExprMixin
+
+class KKTExpr(KKTExprMixin):
+ pass
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
new file mode 100644
index 0000000..776e952
--- /dev/null
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -0,0 +1,70 @@
+
+import abc
+import typing
+import dataclass_abc
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class AddExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def left(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.left.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ assert left.shape == right.shape
+
+ terms = {}
+
+ for underlying in (left, right):
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ if (row, col) in terms:
+ terms_row_col = terms[row, col]
+
+ else:
+ terms_row_col = {}
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ for monomial, value in underlying_terms.items():
+
+ if monomial not in terms_row_col:
+ terms_row_col[monomial] = 0
+
+ terms_row_col[monomial] += value
+
+ terms[row, col] = terms_row_col
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
new file mode 100644
index 0000000..516aa38
--- /dev/null
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -0,0 +1,155 @@
+
+import abc
+import collections
+import typing
+import dataclass_abc
+from numpy import var
+from polymatrix.expression.init.initderivativekey import init_derivative_key
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class DerivativeExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def introduce_derivatives(self) -> bool:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ match self.variables:
+ case ExpressionBaseMixin():
+ n_cols = self.variables.shape[0]
+
+ case _:
+ n_cols = len(self.variables)
+
+ return self.underlying.shape[0], n_cols
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state = [state]
+
+ state[0], underlying = self.underlying.apply(state=state[0])
+
+ match self.variables:
+ case ExpressionBaseMixin():
+ assert self.variables.shape[1] == 1
+
+ state[0], variables = self.variables.apply(state[0])
+
+ def gen_indices():
+ for row in range(variables.shape[0]):
+ for monomial in variables.get_poly(row, 0).keys():
+ yield monomial[0]
+
+ variable_indices = tuple(sorted(gen_indices()))
+ # print(f'{variable_indices=}')
+
+ case _:
+ def gen_indices():
+ for variable in self.variable:
+ if variable in state[0].offset_dict:
+ yield state[0].offset_dict[variable][0]
+
+ variable_indices = tuple(sorted(gen_indices()))
+
+ terms = {}
+ # aux_terms = []
+
+ for var_idx, variable in enumerate(variable_indices):
+
+ def get_derivative_terms(monomial_terms):
+
+ terms_row_col = {}
+
+ for monomial, value in monomial_terms.items():
+
+ # count powers for each variable
+ monomial_cnt = dict(collections.Counter(monomial))
+
+ if variable not in monomial_cnt:
+ continue
+
+ if self.introduce_derivatives:
+ variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices)
+ else:
+ variable_candidates = (variable,)
+
+ for variable_candidate in variable_candidates:
+
+ def generate_monomial():
+ for current_variable, current_count in monomial_cnt.items():
+
+ if current_variable is variable_candidate:
+ sel_counter = current_count - 1
+
+ else:
+ sel_counter = current_count
+
+ for _ in range(sel_counter):
+ yield current_variable
+
+ if variable_candidate is not variable:
+ key = init_derivative_key(
+ variable=variable_candidate,
+ with_respect_to=var_idx,
+ )
+ state[0] = state[0].register(key=key, n_param=1)
+
+ yield state[0].offset_dict[key][0]
+
+ col_monomial = tuple(generate_monomial())
+
+ if col_monomial not in terms_row_col:
+ terms_row_col[col_monomial] = 0
+
+ terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+
+ return terms_row_col
+
+ for row in range(self.shape[0]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ derivative_terms = get_derivative_terms(underlying_terms)
+
+ if 0 < len(derivative_terms):
+ terms[row, var_idx] = derivative_terms
+
+ # if self.introduce_derivatives:
+ # for aux_monomial in underlying.aux_terms:
+
+ # derivative_terms = get_derivative_terms(aux_monomial)
+
+ # # todo: is this correct?
+ # if 1 < len(derivative_terms):
+ # aux_terms.append(derivative_terms)
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ # aux_terms=underlying.aux_terms + tuple(aux_terms),
+ )
+
+ return state[0], poly_matrix
diff --git a/polymatrix/expression/mixins/derivativekeymixin.py b/polymatrix/expression/mixins/derivativekeymixin.py
new file mode 100644
index 0000000..e0adddc
--- /dev/null
+++ b/polymatrix/expression/mixins/derivativekeymixin.py
@@ -0,0 +1,12 @@
+import abc
+
+class DerivativeKeyMixin(abc.ABC):
+ @property
+ @abc.abstractmethod
+ def variable(self) -> int:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def with_respect_to(self) -> int:
+ ...
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
new file mode 100644
index 0000000..8fe2536
--- /dev/null
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -0,0 +1,115 @@
+
+import abc
+import collections
+import dataclasses
+from numpy import var
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class DeterminantExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape[0], 1
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ if self in state.cached_polymatrix:
+ return state, state.cached_polymatrix[self]
+
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape[0] == underlying.shape[1]
+
+ inequality_terms = {}
+ auxillary_terms = []
+
+ index_start = state.n_param
+ rel_index = 0
+
+ for row in range(self.shape[0]):
+
+ current_inequality_terms = collections.defaultdict(float)
+
+ # f in f-v^T@x-r^2
+ # terms = underlying.get_poly(row, row)
+ try:
+ underlying_terms = underlying.get_poly(row, row)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ current_inequality_terms[monomial] += value
+
+ for inner_row in range(row):
+
+ # -v^T@x in f-v^T@x-r^2
+ # terms = underlying.get_poly(row, inner_row)
+ try:
+ underlying_terms = underlying.get_poly(row, inner_row)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (index_start + rel_index + inner_row,)
+ current_inequality_terms[new_monomial] -= value
+
+ auxillary_term = collections.defaultdict(float)
+
+ for inner_col in range(row):
+
+ # P@x in P@x-v
+ key = tuple(reversed(sorted((inner_row, inner_col))))
+ # terms = underlying.get_poly(*key)
+ try:
+ underlying_terms = underlying.get_poly(*key)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (index_start + rel_index + inner_col,)
+ auxillary_term[new_monomial] += value
+
+ # -v in P@x-v
+ # terms = underlying.get_poly(row, inner_row)
+ try:
+ underlying_terms = underlying.get_poly(row, inner_row)
+ except KeyError:
+ pass
+ else:
+ for monomial, value in underlying_terms.items():
+ auxillary_term[monomial] -= value
+
+ auxillary_terms.append(dict(auxillary_term))
+
+ rel_index += row
+ inequality_terms[row, 0] = dict(current_inequality_terms)
+
+ state = state.register(rel_index)
+
+ # print(f'{auxillary_terms=}')
+
+ poly_matrix = init_poly_matrix(
+ terms=inequality_terms,
+ shape=self.shape,
+ )
+
+ state = dataclasses.replace(
+ state,
+ auxillary_terms=state.auxillary_terms + tuple(auxillary_terms),
+ cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
new file mode 100644
index 0000000..01d3505
--- /dev/null
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -0,0 +1,81 @@
+
+import abc
+import dataclasses
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class DivisionExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def left(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.left.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ if self in state.cached_polymatrix:
+ return state, state.cached_polymatrix[self]
+
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ assert right.shape == (1, 1)
+
+ terms = {}
+
+ division_variable = state.n_param
+ state = state.register(n_param=1)
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = left.get_poly(row, col)
+ except KeyError:
+ continue
+
+ def gen_monomial_terms():
+ for monomial, value in underlying_terms.items():
+ yield monomial + (division_variable,), value
+
+ terms[row, col] = dict(gen_monomial_terms())
+
+ def gen_auxillary_terms():
+ for monomial, value in right.get_poly(0, 0).items():
+ yield monomial + (division_variable,), value
+
+ auxillary_terms = dict(gen_auxillary_terms())
+
+ if tuple() not in auxillary_terms:
+ auxillary_terms[tuple()] = 0
+
+ auxillary_terms[tuple()] -= 1
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ state = dataclasses.replace(
+ state,
+ auxillary_terms=state.auxillary_terms + (auxillary_terms,),
+ cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
new file mode 100644
index 0000000..9684827
--- /dev/null
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -0,0 +1,74 @@
+
+import abc
+import itertools
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class ElemMultExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def left(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.left.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ assert right.shape == (1, 1)
+
+ right_terms = right.get_poly(0, 0)
+
+ terms = {}
+
+ for poly_row in range(left.shape[0]):
+ for poly_col in range(left.shape[1]):
+
+ terms_row_col = {}
+
+ try:
+ left_terms = left.get_poly(poly_row, poly_col)
+ except KeyError:
+ continue
+
+ for (left_monomial, left_value), (right_monomial, right_value) \
+ in itertools.product(left_terms.items(), right_terms.items()):
+
+ value = left_value * right_value
+
+ if value == 0:
+ continue
+
+ monomial = tuple(sorted(left_monomial + right_monomial))
+
+ if monomial not in terms_row_col:
+ terms_row_col[monomial] = 0
+
+ terms_row_col[monomial] += value
+
+ if 0 < len(terms_row_col):
+ terms[poly_row, poly_col] = terms_row_col
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
new file mode 100644
index 0000000..b945064
--- /dev/null
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -0,0 +1,93 @@
+
+import abc
+import itertools
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class EvalExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> tuple:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def eval_values(self) -> tuple[float, ...]:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ assert len(self.variables) == len(self.eval_values)
+
+ state, underlying = self.underlying.apply(state=state)
+
+ def gen_indices():
+ for variable in self.variables:
+ if variable in state.offset_dict:
+ yield state.offset_dict[variable][0]
+
+ variable_indices = tuple(sorted(gen_indices()))
+
+ terms = {}
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ terms_row_col = {}
+
+ for monomial, value in underlying_terms.items():
+
+ def acc_monomial(acc, variable):
+ monomial, value = acc
+
+ if variable in variable_indices:
+ index = variable_indices.index(variable)
+ new_value = value * self.eval_values[index]
+ return monomial, new_value
+
+ else:
+ return monomial + (variable,), value
+
+ *_, (new_monomial, new_value) = tuple(itertools.accumulate(
+ monomial,
+ acc_monomial,
+ initial=(tuple(), value),
+ ))
+
+ # print(f'{new_monomial=}')
+
+ if new_monomial not in terms_row_col:
+ terms_row_col[new_monomial] = 0
+
+ terms_row_col[new_monomial] += new_value
+
+ terms[row, col] = terms_row_col
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py
new file mode 100644
index 0000000..d5109c0
--- /dev/null
+++ b/polymatrix/expression/mixins/expressionbasemixin.py
@@ -0,0 +1,17 @@
+import abc
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class ExpressionBaseMixin(
+ abc.ABC,
+):
+
+ @property
+ @abc.abstractclassmethod
+ def shape(self) -> tuple[int, int]:
+ ...
+
+ @abc.abstractmethod
+ def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+ ...
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
new file mode 100644
index 0000000..e50f660
--- /dev/null
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -0,0 +1,213 @@
+import abc
+import dataclasses
+import typing
+import numpy as np
+from sympy import re
+from polymatrix.expression.init.initderivativeexpr import init_derivative_expr
+from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr
+from polymatrix.expression.init.initdivisionexpr import init_division_expr
+from polymatrix.expression.init.initelemmultexpr import init_elem_mult_expr
+from polymatrix.expression.init.initevalexpr import init_eval_expr
+from polymatrix.expression.init.initforallexpr import init_for_all_expr
+from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
+from polymatrix.expression.init.initgetitemexpr import init_get_item_expr
+from polymatrix.expression.init.initparametrizesymbolsexpr import init_parametrize_symbols_expr
+from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
+from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr
+from polymatrix.expression.init.inittransposeexpr import init_transpose_expr
+
+from polymatrix.init.initpolymatrixaddexpr import init_poly_matrix_add_expr
+from polymatrix.init.initpolymatrixmultexpr import init_poly_matrix_mult_expr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class ExpressionMixin(
+ ExpressionBaseMixin,
+ abc.ABC,
+):
+
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ def apply(self, state: PolyMatrixExprState) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ return self.underlying.apply(state)
+
+ # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ @property
+ def degrees(self) -> set[int]:
+ return self.underlying.degrees
+
+ # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape
+
+ def __iter__(self):
+ for row in range(self.shape[0]):
+ yield self[row, 0]
+
+ def __add__(self, other: ExpressionBaseMixin) -> 'ExpressionMixin':
+ # assert self.underlying.shape == other.shape, f'shapes {(self.shape, other.shape)} of polynomial matrix do not match!'
+
+ if other is None:
+ return self
+
+ return dataclasses.replace(
+ self,
+ underlying=init_poly_matrix_add_expr(
+ left=self.underlying,
+ right=other.underlying,
+ ),
+ )
+
+ def __radd__(self, other):
+ return self + other
+
+ def __getattr__(self, name):
+ attr = getattr(self.underlying, name)
+
+ if isinstance(attr, ExpressionBaseMixin):
+ return dataclasses.replace(
+ self,
+ underlying=attr,
+ )
+
+ else:
+ return attr
+
+ def __mul__(self, other) -> 'ExpressionMixin':
+ # assert isinstance(other, float)
+
+ right = init_from_array_expr(other)
+
+ return dataclasses.replace(
+ self,
+ underlying=init_elem_mult_expr(
+ left=self.underlying,
+ right=right,
+ ),
+ )
+
+ def __rmul__(self, other):
+ return self * other
+
+ def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin':
+ match other:
+ case ExpressionBaseMixin():
+ right = other.underlying
+ case _:
+ right = init_from_array_expr(other)
+
+ return dataclasses.replace(
+ self,
+ underlying=init_poly_matrix_mult_expr(
+ left=self.underlying,
+ right=right,
+ ),
+ )
+
+ def __truediv__(self, other: ExpressionBaseMixin):
+ return dataclasses.replace(
+ self,
+ underlying=init_division_expr(
+ left=self.underlying,
+ right=other,
+ ),
+ )
+
+ def __getitem__(self, key: tuple[int, int]):
+ return dataclasses.replace(
+ self,
+ underlying=init_get_item_expr(
+ underlying=self.underlying,
+ index=key,
+ ),
+ )
+
+ @property
+ def T(self) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_transpose_expr(
+ underlying=self.underlying,
+ ),
+ )
+
+ def parametrize(self, name: str, variables: tuple) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_parametrize_symbols_expr(
+ name=name,
+ underlying=self.underlying,
+ variables=variables,
+ ),
+ )
+
+ def rep_mat(self, n: int, m: int) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_rep_mat_expr(
+ underlying=self.underlying,
+ repetition=(n, m),
+ ),
+ )
+
+ def diff(
+ self,
+ variables: tuple,
+ introduce_derivatives: bool = None,
+ ) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_derivative_expr(
+ underlying=self.underlying,
+ variables=variables,
+ introduce_derivatives=introduce_derivatives,
+ ),
+ )
+
+ def for_all(self, variables: tuple) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_for_all_expr(
+ underlying=self.underlying,
+ variables=variables,
+ ),
+ )
+
+ def quadratic_in(self, variables: tuple) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_quadratic_in_expr(
+ underlying=self.underlying,
+ variables=variables,
+ ),
+ )
+
+ def determinant(self) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_determinant_expr(
+ underlying=self.underlying,
+ ),
+ )
+
+ def eval(
+ self,
+ variable: tuple,
+ value: tuple[float, ...],
+ ) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_eval_expr(
+ underlying=self.underlying,
+ variables=variable,
+ eval_values=value,
+ ),
+ )
diff --git a/polymatrix/expression/mixins/expressionstatemixin.py b/polymatrix/expression/mixins/expressionstatemixin.py
new file mode 100644
index 0000000..67085d4
--- /dev/null
+++ b/polymatrix/expression/mixins/expressionstatemixin.py
@@ -0,0 +1,54 @@
+import abc
+import itertools
+import dataclasses
+import typing
+
+import sympy
+
+
+class ExpressionStateMixin(abc.ABC):
+
+ @property
+ @abc.abstractmethod
+ def n_param(self) -> int:
+ """
+ current number of parameters used in polynomial matrix expressions
+ """
+
+ ...
+
+ @property
+ @abc.abstractmethod
+ def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def auxillary_terms(self) -> tuple[dict[tuple[int], float]]:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def cached_polymatrix(self) -> dict:
+ ...
+
+ def register(
+ self,
+ n_param: int,
+ key: typing.Any = None,
+ ) -> 'ExpressionStateMixin':
+
+ if key is None:
+ updated_state = dataclasses.replace(self, n_param=self.n_param + n_param)
+
+ elif key not in self.offset_dict:
+ updated_state = dataclasses.replace(
+ self,
+ offset_dict=self.offset_dict | {key: (self.n_param, self.n_param + n_param)},
+ n_param=self.n_param + n_param,
+ )
+
+ else:
+ updated_state = self
+
+ return updated_state
diff --git a/polymatrix/expression/mixins/forallexprmixin.py b/polymatrix/expression/mixins/forallexprmixin.py
new file mode 100644
index 0000000..4f5577c
--- /dev/null
+++ b/polymatrix/expression/mixins/forallexprmixin.py
@@ -0,0 +1,72 @@
+
+import abc
+from numpy import var
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class ForAllExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> tuple:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+
+ terms = {}
+ idx_row = 0
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ x_monomial_terms = {}
+
+ for monomial, value in underlying_terms.items():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+
+ if x_monomial not in x_monomial_terms:
+ x_monomial_terms[x_monomial] = {}
+
+ if p_monomial not in x_monomial_terms:
+ x_monomial_terms[x_monomial][p_monomial] = 0
+
+ x_monomial_terms[x_monomial][p_monomial] += value
+
+ for data in x_monomial_terms.values():
+ terms[idx_row, 0] = data
+ idx_row += 1
+
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/fromarrayexprmixin.py b/polymatrix/expression/mixins/fromarrayexprmixin.py
new file mode 100644
index 0000000..0ae210d
--- /dev/null
+++ b/polymatrix/expression/mixins/fromarrayexprmixin.py
@@ -0,0 +1,76 @@
+
+import abc
+import collections
+import typing
+import numpy as np
+import dataclass_abc
+from numpy import poly
+import sympy
+import functools
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class FromArrayExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def data(self) -> tuple[tuple[float]]:
+ pass
+
+ # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return len(self.data), len(self.data[0])
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ terms = {}
+
+ for poly_row, col_data in enumerate(self.data):
+ for poly_col, poly_data in enumerate(col_data):
+
+ try:
+ poly = sympy.poly(poly_data)
+ except sympy.polys.polyerrors.GeneratorsNeeded:
+ terms[poly_row, poly_col] = {tuple(): poly_data}
+ continue
+
+ for symbol in poly.gens:
+ state = state.register(key=symbol, n_param=1)
+
+ terms_row_col = {}
+
+ # a5 x1 x3**2 -> c=a5, m_cnt=(1, 0, 2)
+ for value, monomial_count in zip(poly.coeffs(), poly.monoms()):
+
+ if value == 0.0:
+ continue
+
+ # m_cnt=(1, 0, 2) -> m=(0, 2, 2)
+ def gen_monomial():
+ for rel_idx, p in enumerate(monomial_count):
+
+ idx, _ = state.offset_dict[poly.gens[rel_idx]]
+
+ for _ in range(p):
+ yield idx
+
+ monomial = tuple(gen_monomial())
+
+ terms_row_col[monomial] = value
+
+ terms[poly_row, poly_col] = terms_row_col
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
new file mode 100644
index 0000000..9a521cb
--- /dev/null
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -0,0 +1,55 @@
+
+import abc
+import dataclasses
+import typing
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class GetItemExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def index(self) -> tuple[int, int]:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return 1, 1
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class GetItemPolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ index: tuple[int, int]
+ shape: tuple[int, int]
+ # aux_terms: tuple
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ assert row == 0 and col == 0
+
+ return self.underlying.get_poly(self.index[0], self.index[1])
+
+ return state, GetItemPolyMatrix(
+ underlying=underlying,
+ shape=self.shape,
+ index=self.index,
+ # aux_terms=underlying.aux_terms,
+ )
+ \ No newline at end of file
diff --git a/polymatrix/expression/mixins/kktexprmixin.py b/polymatrix/expression/mixins/kktexprmixin.py
new file mode 100644
index 0000000..f6066d8
--- /dev/null
+++ b/polymatrix/expression/mixins/kktexprmixin.py
@@ -0,0 +1,126 @@
+
+import abc
+import itertools
+from this import d
+import typing
+import dataclass_abc
+from polymatrix.expression.derivativekey import DerivativeKey
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class KKTExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def cost(self) -> ExpressionMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def equality(self) -> ExpressionMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def equality(self) -> ExpressionMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.cost.shape[0] + self.equality.shape[0], 1
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+
+ assert self.cost.shape == self.variables.shape
+ assert self.cost.shape[1] == 1
+
+ state, cost = self.cost.apply(state=state)
+ state, equality = self.equality.apply(state=state)
+
+ state, equality_der = self.equality.diff(
+ self.variables,
+ introduce_derivatives=True,
+ ).apply(state)
+
+ def acc_nu_variables(acc, v):
+ state, nu_variables = acc
+
+ nu_variable = state.n_param
+ state = state.register(n_param=1)
+
+ return state, nu_variables + [nu_variable]
+
+ *_, (state, nu_variables) = tuple(itertools.accumulate(
+ range(self.equality.shape[0]),
+ acc_nu_variables,
+ initial=(state, []),
+ ))
+
+ terms = {}
+
+ n_row = cost.shape[0]
+
+ for row in range(n_row):
+ try:
+ monomial_terms = cost.get_poly(row, 0)
+ except KeyError:
+ monomial_terms = {}
+
+ for eq_idx, nu_variable in enumerate(nu_variables):
+
+ try:
+ underlying_terms = equality_der.get_poly(eq_idx, row)
+ except KeyError:
+ continue
+
+ for monomial, value in underlying_terms.items():
+ new_monomial = monomial + (nu_variable,)
+
+ if new_monomial not in monomial_terms:
+ monomial_terms[new_monomial] = 0
+
+ monomial_terms[new_monomial] += value
+
+ terms[row, 0] = monomial_terms
+
+ idx_start = n_row
+
+ for row in range(equality.shape[0]):
+
+ try:
+ underlying_terms = equality.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ terms[idx_start + row, 0] = underlying_terms
+
+ idx_start += equality.shape[0]
+
+ # for row, aux_term in enumerate(state.auxillary_terms):
+ # terms[idx_start + row, 0] = aux_term
+
+ # idx_start += len(state.auxillary_terms)
+
+ # derivatives = tuple(key for key in state.offset_dict.keys() if isinstance(key, DerivativeKey))
+ # print(f'{derivatives=}')
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=(idx_start, 1),
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
new file mode 100644
index 0000000..29eb4ac
--- /dev/null
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -0,0 +1,75 @@
+
+import abc
+import itertools
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class MatrixMultExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def left(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return (self.left.shape[0], self.right.shape[1])
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, left = self.left.apply(state=state)
+ state, right = self.right.apply(state=state)
+
+ assert left.shape[1] == right.shape[0]
+
+ terms = {}
+
+ for poly_row in range(left.shape[0]):
+ for poly_col in range(right.shape[1]):
+
+ terms_row_col = {}
+
+ for index_k in range(left.shape[1]):
+
+ try:
+ left_terms = left.get_poly(poly_row, index_k)
+ right_terms = right.get_poly(index_k, poly_col)
+ except KeyError:
+ continue
+
+ for (left_monomial, left_value), (right_monomial, right_value) \
+ in itertools.product(left_terms.items(), right_terms.items()):
+
+ value = left_value * right_value
+
+ if value == 0:
+ continue
+
+ monomial = tuple(sorted(left_monomial + right_monomial))
+
+ if monomial not in terms_row_col:
+ terms_row_col[monomial] = 0
+
+ terms_row_col[monomial] += value
+
+ if 0 < len(terms_row_col):
+ terms[poly_row, poly_col] = terms_row_col
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
new file mode 100644
index 0000000..fbee57c
--- /dev/null
+++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
@@ -0,0 +1,175 @@
+
+import abc
+import functools
+import dataclass_abc
+from polymatrix.expression.init.initexpressionstate import init_expression_state
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class ParametrizeTermsExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def name(self) -> str:
+ ...
+
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractclassmethod
+ def variables(self) -> tuple:
+ ...
+
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape
+
+ def _internal_apply(self, state: ExpressionStateMixin):
+ if not hasattr(self, '_terms'):
+ state, underlying = self.underlying.apply(state)
+
+ for variable in self.variables:
+ state = state.register(key=variable, n_param=1)
+
+ variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+
+ idx_start = state.n_param
+ # print(f'{idx_start=}')
+ n_param = 0
+ terms = {}
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ terms_row_col = {}
+ collected_terms = []
+
+ for monomial, value in underlying_terms.items():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+
+ if x_monomial not in collected_terms:
+ collected_terms.append(x_monomial)
+
+ # print(f'{x_monomial=}')
+ # print(f'{collected_terms=}')
+
+ idx = idx_start + n_param + collected_terms.index(x_monomial)
+
+ new_monomial = monomial + (idx,)
+
+ terms_row_col[new_monomial] = value
+
+ n_param += len(collected_terms)
+ terms[row, col] = terms_row_col
+
+ state = state.register(key=self, n_param=n_param)
+
+ self._terms = terms
+ self._start_index = idx_start
+ # self._n_param = n_param
+
+ return state, self._terms
+
+ @property
+ def param(self) -> tuple[int, int]:
+ outer_self = self
+
+ # precalculate number of parameters (used for `shape` attribute)
+ # ---------------------------------
+ # not pretty
+
+ dummy_state = init_expression_state()
+
+ dummy_state, underlying = self.underlying.apply(dummy_state)
+
+ for variable in self.variables:
+ dummy_state = dummy_state.register(key=variable, n_param=1)
+
+ variable_indices = tuple(dummy_state.offset_dict[variable][0] for variable in self.variables if variable in dummy_state.offset_dict)
+
+ n_param = 0
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ underlying_terms = underlying.get_poly(row, col)
+
+ collected_terms = []
+
+ for monomial in underlying_terms.keys():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+
+ if x_monomial not in collected_terms:
+ collected_terms.append(x_monomial)
+
+ n_param += len(collected_terms)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class ParameterExprMixin(ExpressionBaseMixin):
+ n_param: int
+
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.n_param, 1
+
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, _ = outer_self._internal_apply(state)
+
+ def gen_monomials():
+ for rel_index in range(self.n_param):
+ yield {(outer_self._start_index + rel_index,): 1}
+
+ terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())}
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
+
+ return ParameterExprMixin(
+ n_param=n_param,
+ )
+
+ # # overwrite this method to customize indexing
+ # def re_index(
+ # self,
+ # degree: int,
+ # poly_col: int,
+ # poly_row: int,
+ # x_monomial: tuple[int, ...],
+ # ) -> tuple[int, int, tuple[int, ...], float]:
+ # return poly_col, poly_row, x_monomial, 1.0
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, terms = self._internal_apply(state)
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/polymatrixasdictmixin.py b/polymatrix/expression/mixins/polymatrixasdictmixin.py
new file mode 100644
index 0000000..6aea8ed
--- /dev/null
+++ b/polymatrix/expression/mixins/polymatrixasdictmixin.py
@@ -0,0 +1,28 @@
+import abc
+
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class PolyMatrixAsDictMixin(
+ PolyMatrixMixin,
+ abc.ABC,
+):
+ @property
+ @abc.abstractmethod
+ def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], float]]:
+ ...
+
+ # overwrites abstract method of `PolyMatrixMixin`
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ # key = (row, col)
+
+ # if key in self.terms:
+ # return self.terms[key]
+
+ # else:
+ # return None
+
+ try:
+ return self.terms[row, col]
+ except KeyError:
+ raise KeyError(f'{(row, col)} is not a key of {self.terms}')
diff --git a/polymatrix/expression/mixins/polymatrixmixin.py b/polymatrix/expression/mixins/polymatrixmixin.py
new file mode 100644
index 0000000..e5b803e
--- /dev/null
+++ b/polymatrix/expression/mixins/polymatrixmixin.py
@@ -0,0 +1,109 @@
+import abc
+import dataclasses
+import itertools
+import collections
+import typing
+
+
+class PolyMatrixMixin(abc.ABC):
+ @property
+ @abc.abstractclassmethod
+ def shape(self) -> tuple[int, int]:
+ ...
+
+ @abc.abstractclassmethod
+ def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ ...
+
+ def reduce(self):
+ def gen_used_variables():
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = self.get_poly(row, col)
+ except KeyError:
+ continue
+
+ for monomial in underlying_terms.keys():
+ yield from monomial
+
+ used_variables = set(gen_used_variables())
+
+ variable_index_reverse_map = tuple(sorted(used_variables))
+ variable_index_map = {old: new for new, old in enumerate(variable_index_reverse_map)}
+
+
+ terms = {}
+
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+
+ try:
+ underlying_terms = self.get_poly(row, col)
+ except KeyError:
+ continue
+
+ def gen_updated_monomials():
+ for monomial, value in underlying_terms.items():
+ new_monomial = tuple(variable_index_map[var] for var in monomial)
+ yield new_monomial, value
+
+ terms[row, col] = dict(gen_updated_monomials())
+
+ poly_matrix = dataclasses.replace(
+ self,
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return poly_matrix, variable_index_reverse_map
+
+
+ # @property
+ # @abc.abstractmethod
+ # def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], dict[tuple[int, ...], float]]]:
+ # ...
+
+ # @property
+ # @abc.abstractclassmethod
+ # def aux_terms(self) -> tuple[dict[tuple[int, ...], float]]:
+ # ...
+
+ # def get_equality_equations(
+ # self,
+ # offset: int = None
+ # ) -> dict[int, list[tuple[list[int], float]]]:
+ # if offset is None:
+ # offset = 0
+
+ # to_eq_index_dict = {}
+ # equality_constraints = collections.defaultdict(list)
+
+ # # def gen_key_equation_mapping():
+ # for (poly_row, poly_col), terms_x_monomial in self.terms.items():
+ # for x_monomial, terms_p_monomial in terms_x_monomial.items():
+ # key = (poly_row, poly_col, x_monomial)
+
+ # if key not in to_eq_index_dict:
+ # to_eq_index_dict[key] = offset
+ # offset += 1
+
+ # for p_monomial, value in terms_p_monomial.items():
+ # equality_constraints[to_eq_index_dict[key]] += ((p_monomial, value),)
+
+ # return equality_constraints
+
+ # @property
+ # @abc.abstractmethod
+ # def degrees(self) -> set[int]:
+ # ...
+
+ # @abc.abstractmethod
+ # def get_term(
+ # self,
+ # poly_col: int,
+ # poly_row: int,
+ # x_monomial: tuple[int, ...]
+ # ) -> PolyMatrixTerm:
+ # ...
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
new file mode 100644
index 0000000..8f5555e
--- /dev/null
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -0,0 +1,72 @@
+
+import abc
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class QuadraticInExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def variables(self) -> tuple:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return 2*(len(self.variables),)
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ assert underlying.shape == (1, 1)
+
+ variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
+
+ terms = {}
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ for monomial, value in underlying_terms.items():
+
+ x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
+ p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
+
+ assert len(x_monomial) == 2
+ assert tuple(sorted(x_monomial)) == x_monomial, f'{x_monomial} is not sorted'
+
+ key = tuple(reversed(x_monomial))
+
+ if key not in terms:
+ terms[key] = {}
+
+ monomial_terms = terms[key]
+
+ if p_monomial not in monomial_terms:
+ monomial_terms[p_monomial] = 0
+
+ monomial_terms[p_monomial] += value
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py
new file mode 100644
index 0000000..aec0d48
--- /dev/null
+++ b/polymatrix/expression/mixins/repmatexprmixin.py
@@ -0,0 +1,49 @@
+import abc
+import dataclass_abc
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+class RepMatExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractclassmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @property
+ @abc.abstractclassmethod
+ def repetition(self) -> tuple[int, int]:
+ ...
+
+ @property
+ def shape(self) -> tuple[int, int]:
+ return tuple(s*r for s, r in zip(self.underlying.shape, self.repetition))
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ state, underlying = self.underlying.apply(state)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class RepMatPolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+ # aux_terms: tuple
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ n_row, n_col = underlying.shape
+
+ rel_row = row % n_row
+ rel_col = col % n_col
+
+ return self.underlying.get_poly(rel_row, rel_col)
+
+ return state, RepMatPolyMatrix(
+ underlying=underlying,
+ shape=self.shape,
+ # aux_terms=underlying.aux_terms,
+ ) \ No newline at end of file
diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py
new file mode 100644
index 0000000..63e6a3f
--- /dev/null
+++ b/polymatrix/expression/mixins/transposeexprmixin.py
@@ -0,0 +1,43 @@
+
+import abc
+import dataclasses
+import typing
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class TransposeExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ return self.underlying.shape[1], self.underlying.shape[0]
+
+ # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class TransposePolyMatrix(PolyMatrixMixin):
+ underlying: PolyMatrixMixin
+ shape: tuple[int, int]
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ return self.underlying.get_poly(col, row)
+
+ return state, TransposePolyMatrix(
+ underlying=underlying,
+ shape=self.shape,
+ ) \ No newline at end of file
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
new file mode 100644
index 0000000..381659c
--- /dev/null
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -0,0 +1,66 @@
+
+import abc
+import itertools
+import dataclass_abc
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class VStackExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ def shape(self) -> tuple[int, int]:
+ n_row = sum(expr.shape[0] for expr in self.underlying)
+ return n_row, self.underlying[0].shape[1]
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ assert all(expr.shape[1] == self.underlying[0].shape[1] for expr in self.underlying)
+
+ # todo: rename
+ underlying = []
+ for expr in self.underlying:
+ state, polymat = expr.apply(state=state)
+ underlying.append(polymat)
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class VStackPolyMatrix(PolyMatrixMixin):
+ underlying: tuple[PolyMatrixMixin]
+ underlying_row_range: tuple[tuple[int, int], ...]
+ shape: tuple[int, int]
+ # aux_terms: tuple
+
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ for polymat, (row_start, row_end) in zip(self.underlying, self.underlying_row_range):
+ if row_start <= row < row_end:
+ return polymat.get_poly(
+ row=row - row_start,
+ col=col,
+ )
+
+ raise Exception(f'row {row} is out of bounds')
+
+ underlying_row_range = tuple(itertools.pairwise(
+ itertools.accumulate(
+ (expr.shape[0] for expr in self.underlying),
+ initial=0)
+ ))
+
+ return state, VStackPolyMatrix(
+ underlying=underlying,
+ shape=self.shape,
+ underlying_row_range=underlying_row_range,
+ # aux_terms=tuple(aux_term for expr in underlying for aux_term in expr.aux_terms)
+ )
+ \ No newline at end of file
diff --git a/polymatrix/expression/parametrizesymbolsexpr.py b/polymatrix/expression/parametrizesymbolsexpr.py
new file mode 100644
index 0000000..a81ec23
--- /dev/null
+++ b/polymatrix/expression/parametrizesymbolsexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.parametrizetermsexprmixin import ParametrizeTermsExprMixin
+
+class ParametrizeSymbolsExpr(ParametrizeTermsExprMixin):
+ pass
diff --git a/polymatrix/expression/polymatrix.py b/polymatrix/expression/polymatrix.py
new file mode 100644
index 0000000..c044081
--- /dev/null
+++ b/polymatrix/expression/polymatrix.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.polymatrixasdictmixin import PolyMatrixAsDictMixin
+
+class PolyMatrix(PolyMatrixAsDictMixin):
+ pass
diff --git a/polymatrix/expression/quadraticinexpr.py b/polymatrix/expression/quadraticinexpr.py
new file mode 100644
index 0000000..bab76f6
--- /dev/null
+++ b/polymatrix/expression/quadraticinexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.quadraticinexprmixin import QuadraticInExprMixin
+
+class QuadraticInExpr(QuadraticInExprMixin):
+ pass
diff --git a/polymatrix/expression/repmatexpr.py b/polymatrix/expression/repmatexpr.py
new file mode 100644
index 0000000..bb983be
--- /dev/null
+++ b/polymatrix/expression/repmatexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin
+
+class RepMatExpr(RepMatExprMixin):
+ pass
diff --git a/polymatrix/expression/transposeexpr.py b/polymatrix/expression/transposeexpr.py
new file mode 100644
index 0000000..908a9e4
--- /dev/null
+++ b/polymatrix/expression/transposeexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
+
+class TransposeExpr(TransposeExprMixin):
+ pass
diff --git a/polymatrix/expression/vstackexpr.py b/polymatrix/expression/vstackexpr.py
new file mode 100644
index 0000000..7e9a1d7
--- /dev/null
+++ b/polymatrix/expression/vstackexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
+
+class VStackExpr(VStackExprMixin):
+ pass
diff --git a/polymatrix/impl/addexprimpl.py b/polymatrix/impl/addexprimpl.py
new file mode 100644
index 0000000..a3d9e1f
--- /dev/null
+++ b/polymatrix/impl/addexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.addexpr import AddExpr
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class AddExprImpl(AddExpr):
+ left: OldPolyMatrixMixin
+ right: OldPolyMatrixMixin
diff --git a/polymatrix/impl/exprcontainerimpl.py b/polymatrix/impl/exprcontainerimpl.py
new file mode 100644
index 0000000..0706150
--- /dev/null
+++ b/polymatrix/impl/exprcontainerimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.exprcontainer import ExprContainer
+from polymatrix.mixins.exprcontainermixin import ExprType
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ExprContainerImpl(ExprContainer[ExprType]):
+ expr: ExpressionBaseMixin
diff --git a/polymatrix/impl/multexprimpl.py b/polymatrix/impl/multexprimpl.py
new file mode 100644
index 0000000..7401ffa
--- /dev/null
+++ b/polymatrix/impl/multexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
+from polymatrix.multexpr import MultExpr
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class MultExprImpl(MultExpr):
+ left: OldPolyMatrixMixin
+ right: OldPolyMatrixMixin
diff --git a/polymatrix/impl/polymatriximpl.py b/polymatrix/impl/oldpolymatriximpl.py
index cc435e3..29590db 100644
--- a/polymatrix/impl/polymatriximpl.py
+++ b/polymatrix/impl/oldpolymatriximpl.py
@@ -1,11 +1,11 @@
import typing
import dataclass_abc
-from polymatrix.polymatrix import PolyMatrix
+from polymatrix.oldpolymatrix import OldPolyMatrix
@dataclass_abc.dataclass_abc(frozen=True, eq=False)
-class PolyMatrixImpl(PolyMatrix):
+class PolyMatrixImpl(OldPolyMatrix):
name: str
degrees: list[int]
subs: dict[int, dict[tuple[int, int], float]]
diff --git a/polymatrix/impl/optimizationimpl.py b/polymatrix/impl/optimizationimpl.py
index ac78ea0..3243175 100644
--- a/polymatrix/impl/optimizationimpl.py
+++ b/polymatrix/impl/optimizationimpl.py
@@ -1,13 +1,13 @@
import dataclass_abc
from polymatrix.optimization import Optimization
-from polymatrix.optimizationstate import OptimizationState
+from polymatrix.oldpolymatrixexprstate import OldPolyMatrixExprState
@dataclass_abc.dataclass_abc(frozen=True, eq=False)
class OptimizationImpl(Optimization):
# n_var: int
- state: OptimizationState
+ state: OldPolyMatrixExprState
equality_constraints: dict[int, dict[tuple[int, int], float]]
inequality_constraints: dict[int, dict[tuple[int, int], float]]
auxillary_equality: dict[int, dict[tuple[int, int], float]]
diff --git a/polymatrix/impl/optimizationstateimpl.py b/polymatrix/impl/optimizationstateimpl.py
index 5ce8d11..acfaddb 100644
--- a/polymatrix/impl/optimizationstateimpl.py
+++ b/polymatrix/impl/optimizationstateimpl.py
@@ -1,11 +1,11 @@
import dataclass_abc
-from polymatrix.optimizationstate import OptimizationState
-from polymatrix.polymatrix import PolyMatrix
+from polymatrix.oldpolymatrixexprstate import OldPolyMatrixExprStateMixin
+from polymatrix.oldpolymatrix import OldPolyMatrix
@dataclass_abc.dataclass_abc(frozen=True, eq=False)
-class OptimizationStateImpl(OptimizationState):
+class OptimizationStateImpl(OldPolyMatrixExprStateMixin):
n_var: int
n_param: int
- offset_dict: dict[tuple[PolyMatrix, int], int]
+ offset_dict: dict[tuple[OldPolyMatrix, int], int]
diff --git a/polymatrix/impl/polymatexprimpl.py b/polymatrix/impl/polymatexprimpl.py
new file mode 100644
index 0000000..c0d6abb
--- /dev/null
+++ b/polymatrix/impl/polymatexprimpl.py
@@ -0,0 +1,7 @@
+import dataclass_abc
+from polymatrix.expression.polymatrix import PolyMatrix
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatExprImpl(PolyMatrix):
+ terms: dict
+ shape: tuple
diff --git a/polymatrix/impl/polymatrixaddexprimpl.py b/polymatrix/impl/polymatrixaddexprimpl.py
new file mode 100644
index 0000000..f30495b
--- /dev/null
+++ b/polymatrix/impl/polymatrixaddexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.polymatrixaddexpr import PolyMatrixAddExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixAddExprImpl(PolyMatrixAddExpr):
+ left: ExpressionBaseMixin
+ right: ExpressionBaseMixin
diff --git a/polymatrix/impl/polymatrixarrayexprimpl.py b/polymatrix/impl/polymatrixarrayexprimpl.py
new file mode 100644
index 0000000..a005129
--- /dev/null
+++ b/polymatrix/impl/polymatrixarrayexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.polymatrixarrayexpr import PolyMatrixArrayExpr
+
+from numpy import array
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixArrayExprImpl(PolyMatrixArrayExpr):
+ data: array
diff --git a/polymatrix/impl/polymatrixexprimpl.py b/polymatrix/impl/polymatrixexprimpl.py
new file mode 100644
index 0000000..a5c52f0
--- /dev/null
+++ b/polymatrix/impl/polymatrixexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.polymatrixexpr import PolyMatrixExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixExprImpl(PolyMatrixExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/impl/polymatrixexprstateimpl.py b/polymatrix/impl/polymatrixexprstateimpl.py
new file mode 100644
index 0000000..25f5c52
--- /dev/null
+++ b/polymatrix/impl/polymatrixexprstateimpl.py
@@ -0,0 +1,12 @@
+import typing
+import dataclass_abc
+import sympy
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixExprStateImpl(PolyMatrixExprState):
+ n_var: int
+ n_param: int
+ offset_dict: dict
+ symbols: typing.Optional[sympy.Symbol]
diff --git a/polymatrix/impl/polymatrixmultexprimpl.py b/polymatrix/impl/polymatrixmultexprimpl.py
new file mode 100644
index 0000000..7f8375e
--- /dev/null
+++ b/polymatrix/impl/polymatrixmultexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.polymatrixmultexpr import PolyMatrixMultExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixMultExprImpl(PolyMatrixMultExpr):
+ left: ExpressionBaseMixin
+ right: ExpressionBaseMixin
diff --git a/polymatrix/impl/polymatrixparamexprimpl.py b/polymatrix/impl/polymatrixparamexprimpl.py
new file mode 100644
index 0000000..afb5388
--- /dev/null
+++ b/polymatrix/impl/polymatrixparamexprimpl.py
@@ -0,0 +1,9 @@
+import dataclass_abc
+from polymatrix.polymatrixparamexpr import PolyMatrixParamExpr
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixParamExprImpl(PolyMatrixParamExpr):
+ name: str
+ shape: tuple
+ degrees: tuple
diff --git a/polymatrix/impl/polymatrixvalueimpl.py b/polymatrix/impl/polymatrixvalueimpl.py
new file mode 100644
index 0000000..c0c9921
--- /dev/null
+++ b/polymatrix/impl/polymatrixvalueimpl.py
@@ -0,0 +1,7 @@
+import dataclass_abc
+from polymatrix.polymatrixvalue import PolyMatrixTerm
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class PolyMatrixValueImpl(PolyMatrixTerm):
+ p_monomial: tuple
+ value: float
diff --git a/polymatrix/impl/scalarmultexprimpl.py b/polymatrix/impl/scalarmultexprimpl.py
new file mode 100644
index 0000000..42bec55
--- /dev/null
+++ b/polymatrix/impl/scalarmultexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.scalarmultexpr import ScalarMultExpr
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ScalarMultExprImpl(ScalarMultExpr):
+ left: float
+ right: ExpressionBaseMixin
diff --git a/polymatrix/init/initaddexpr.py b/polymatrix/init/initaddexpr.py
new file mode 100644
index 0000000..e660ed1
--- /dev/null
+++ b/polymatrix/init/initaddexpr.py
@@ -0,0 +1,11 @@
+from polymatrix.impl.addexprimpl import AddExprImpl
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
+
+def init_add_expr(
+ left: OldPolyMatrixMixin,
+ right: OldPolyMatrixMixin,
+):
+ return AddExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/init/initexprcontainer.py b/polymatrix/init/initexprcontainer.py
new file mode 100644
index 0000000..f790705
--- /dev/null
+++ b/polymatrix/init/initexprcontainer.py
@@ -0,0 +1,9 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.impl.exprcontainerimpl import ExprContainerImpl
+
+def init_expr_container(
+ expr: ExpressionBaseMixin,
+):
+ return ExprContainerImpl(
+ expr=expr,
+)
diff --git a/polymatrix/init/initmultexpr.py b/polymatrix/init/initmultexpr.py
new file mode 100644
index 0000000..dfffd97
--- /dev/null
+++ b/polymatrix/init/initmultexpr.py
@@ -0,0 +1,11 @@
+from polymatrix.impl.multexprimpl import MultExprImpl
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
+
+def init_mult_expr(
+ left: OldPolyMatrixMixin,
+ right: OldPolyMatrixMixin,
+):
+ return MultExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/init/initoptimization.py b/polymatrix/init/initoptimization.py
index 653d248..2a340eb 100644
--- a/polymatrix/init/initoptimization.py
+++ b/polymatrix/init/initoptimization.py
@@ -1,6 +1,6 @@
from polymatrix.impl.optimizationimpl import OptimizationImpl
from polymatrix.init.initoptimizationstate import init_optimization_state
-from polymatrix.optimizationstate import OptimizationState
+from polymatrix.oldpolymatrixexprstate import OldPolyMatrixExprState
def init_optimization(
diff --git a/polymatrix/init/initoptimizationstate.py b/polymatrix/init/initoptimizationstate.py
index 372e225..edfa33e 100644
--- a/polymatrix/init/initoptimizationstate.py
+++ b/polymatrix/init/initoptimizationstate.py
@@ -1,11 +1,11 @@
from polymatrix.impl.optimizationstateimpl import OptimizationStateImpl
-from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
def init_optimization_state(
n_var: int,
n_param: int = None,
- offset_dict: dict[tuple[PolyMatrixMixin, int], int] = None,
+ offset_dict: dict[tuple[OldPolyMatrixMixin, int], int] = None,
):
if n_param is None:
n_param = 0
diff --git a/polymatrix/init/initpolymatexpr.py b/polymatrix/init/initpolymatexpr.py
new file mode 100644
index 0000000..9f6cd20
--- /dev/null
+++ b/polymatrix/init/initpolymatexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.impl.polymatexprimpl import PolyMatExprImpl
+
+def init_poly_mat_expr(
+ terms: dict,
+ shape: tuple,
+):
+ return PolyMatExprImpl(
+ terms=terms,
+ shape=shape,
+)
diff --git a/polymatrix/init/initpolymatrixaddexpr.py b/polymatrix/init/initpolymatrixaddexpr.py
new file mode 100644
index 0000000..43c2d1f
--- /dev/null
+++ b/polymatrix/init/initpolymatrixaddexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.impl.polymatrixaddexprimpl import PolyMatrixAddExprImpl
+
+
+def init_poly_matrix_add_expr(
+ left: ExpressionBaseMixin,
+ right: ExpressionBaseMixin,
+):
+ return PolyMatrixAddExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/init/initpolymatrixarrayexpr.py b/polymatrix/init/initpolymatrixarrayexpr.py
new file mode 100644
index 0000000..8f7dbed
--- /dev/null
+++ b/polymatrix/init/initpolymatrixarrayexpr.py
@@ -0,0 +1,10 @@
+from numpy import array
+from polymatrix.impl.polymatrixarrayexprimpl import PolyMatrixArrayExprImpl
+
+
+def init_poly_matrix_array_expr(
+ data: array,
+):
+ return PolyMatrixArrayExprImpl(
+ data=data,
+)
diff --git a/polymatrix/init/initpolymatrixexpr.py b/polymatrix/init/initpolymatrixexpr.py
new file mode 100644
index 0000000..2749bd3
--- /dev/null
+++ b/polymatrix/init/initpolymatrixexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.impl.polymatrixexprimpl import PolyMatrixExprImpl
+
+
+def init_poly_matrix_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return PolyMatrixExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/init/initpolymatrixexprstate.py b/polymatrix/init/initpolymatrixexprstate.py
new file mode 100644
index 0000000..03649c8
--- /dev/null
+++ b/polymatrix/init/initpolymatrixexprstate.py
@@ -0,0 +1,26 @@
+import sympy
+from polymatrix.impl.polymatrixexprstateimpl import PolyMatrixExprStateImpl
+
+def init_poly_matrix_expr_state(
+ n_var: int = None,
+ n_param: int = None,
+ offset_dict: dict = None,
+ symbols: sympy.Symbol = None,
+):
+ if n_var is None:
+ n_var = len(symbols)
+ elif symbols is not None:
+ assert n_var == len(symbols), f'{n_var} is not equal {len(symbols)}'
+
+ if n_param is None:
+ n_param = 0
+
+ if offset_dict is None:
+ offset_dict = {}
+
+ return PolyMatrixExprStateImpl(
+ n_var=n_var,
+ n_param=n_param,
+ offset_dict=offset_dict,
+ symbols=symbols,
+)
diff --git a/polymatrix/init/initpolymatrixmultexpr.py b/polymatrix/init/initpolymatrixmultexpr.py
new file mode 100644
index 0000000..2c423c9
--- /dev/null
+++ b/polymatrix/init/initpolymatrixmultexpr.py
@@ -0,0 +1,12 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.impl.polymatrixmultexprimpl import PolyMatrixMultExprImpl
+
+
+def init_poly_matrix_mult_expr(
+ left: ExpressionBaseMixin,
+ right: ExpressionBaseMixin,
+):
+ return PolyMatrixMultExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/init/initpolymatrixparamexpr.py b/polymatrix/init/initpolymatrixparamexpr.py
new file mode 100644
index 0000000..7e1b6d7
--- /dev/null
+++ b/polymatrix/init/initpolymatrixparamexpr.py
@@ -0,0 +1,33 @@
+import typing
+from polymatrix.impl.polymatrixparamexprimpl import PolyMatrixParamExprImpl
+
+def init_poly_matrix_param_expr(
+ name: str,
+ shape: tuple,
+ degrees: tuple,
+ re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None,
+):
+
+ if re_index is None:
+ return PolyMatrixParamExprImpl(
+ name=name,
+ shape=shape,
+ degrees=degrees,
+ )
+
+ else:
+ class ExtendedPolyMatrixParamExprImpl(PolyMatrixParamExprImpl):
+ def re_index(
+ self,
+ degree: int,
+ poly_col: int,
+ poly_row: int,
+ x_monomial: tuple[int, ...],
+ ) -> tuple[int, int, tuple[int, ...], float]:
+ return re_index(degree, poly_col, poly_row, x_monomial)
+
+ return ExtendedPolyMatrixParamExprImpl(
+ name=name,
+ shape=shape,
+ degrees=degrees,
+ )
diff --git a/polymatrix/init/initpolymatrixvalue.py b/polymatrix/init/initpolymatrixvalue.py
new file mode 100644
index 0000000..570f399
--- /dev/null
+++ b/polymatrix/init/initpolymatrixvalue.py
@@ -0,0 +1,10 @@
+from polymatrix.impl.polymatrixvalueimpl import PolyMatrixValueImpl
+
+def init_poly_matrix_value(
+ p_monomial: tuple,
+ value: float,
+):
+ return PolyMatrixValueImpl(
+ p_monomial=p_monomial,
+ value=value,
+)
diff --git a/polymatrix/init/initscalarmultexpr.py b/polymatrix/init/initscalarmultexpr.py
new file mode 100644
index 0000000..a401758
--- /dev/null
+++ b/polymatrix/init/initscalarmultexpr.py
@@ -0,0 +1,11 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.impl.scalarmultexprimpl import ScalarMultExprImpl
+
+def init_scalar_mult_expr(
+ left: float,
+ right: ExpressionBaseMixin,
+):
+ return ScalarMultExprImpl(
+ left=left,
+ right=right,
+)
diff --git a/polymatrix/init/initpolymatrix.py b/polymatrix/init/oldinitpolymatrix.py
index edfe195..f1b2efd 100644
--- a/polymatrix/init/initpolymatrix.py
+++ b/polymatrix/init/oldinitpolymatrix.py
@@ -1,6 +1,7 @@
import typing
-from polymatrix.impl.polymatriximpl import PolyMatrixImpl
+from polymatrix.impl.oldpolymatriximpl import PolyMatrixImpl
+from polymatrix.init.initexprcontainer import init_expr_container
def init_poly_matrix(
@@ -25,7 +26,7 @@ def init_poly_matrix(
if is_negated is None:
is_negated = False
- return PolyMatrixImpl(
+ expr = PolyMatrixImpl(
name=name,
degrees=degrees,
subs=subs,
@@ -34,3 +35,5 @@ def init_poly_matrix(
shape=shape,
is_negated=is_negated,
)
+
+ return init_expr_container(expr=expr)
diff --git a/polymatrix/mixins/addexprmixin.py b/polymatrix/mixins/addexprmixin.py
new file mode 100644
index 0000000..4156829
--- /dev/null
+++ b/polymatrix/mixins/addexprmixin.py
@@ -0,0 +1,16 @@
+import abc
+
+from polymatrix.mixins.oldpolymatrixexprmixin import OldPolyMatrixExprMixin
+
+
+class AddExprMixin(OldPolyMatrixExprMixin, abc.ABC):
+ @property
+ @abc.abstractmethod
+ def left(self) -> OldPolyMatrixExprMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> OldPolyMatrixExprMixin:
+ ...
+
diff --git a/polymatrix/mixins/exprcontainermixin.py b/polymatrix/mixins/exprcontainermixin.py
new file mode 100644
index 0000000..b5dff91
--- /dev/null
+++ b/polymatrix/mixins/exprcontainermixin.py
@@ -0,0 +1,82 @@
+import abc
+import dataclasses
+import typing
+from polymatrix.addexpr import AddExpr
+from polymatrix.init.initaddexpr import init_add_expr
+from polymatrix.init.initmultexpr import init_mult_expr
+from polymatrix.init.initscalarmultexpr import init_scalar_mult_expr
+from polymatrix.mixins.oldpolymatrixexprmixin import OldPolyMatrixExprMixin
+
+from polymatrix.multexpr import MultExpr
+from polymatrix.scalarmultexpr import ScalarMultExpr
+
+
+ExprType = typing.TypeVar('ExprType', bound=OldPolyMatrixExprMixin)
+
+class ExprContainerMixin(
+ typing.Generic[ExprType],
+ abc.ABC,
+):
+ @property
+ @abc.abstractmethod
+ def expr(self) -> ExprType:
+ ...
+
+ def __add__(self, other: 'ExprContainerMixin[ExprType]'):
+ return dataclasses.replace(
+ self,
+ expr=init_add_expr(
+ left=self.expr,
+ right=other.expr,
+ ),
+ )
+
+ def __sub__(self, other: 'ExprContainerMixin[ExprType]'):
+ return self + (-1.0) * other
+
+ def __mul__(self, other: 'ExprContainerMixin[ExprType]'):
+ return dataclasses.replace(
+ self,
+ expr=init_mult_expr(
+ left=self.expr,
+ right=other.expr,
+ ),
+ )
+
+ def __rmul__(self, val: float):
+ return dataclasses.replace(
+ self,
+ expr=init_scalar_mult_expr(
+ left=val,
+ right=self.expr,
+ ),
+ )
+
+ def to_list(self) -> tuple[tuple[float, tuple[OldPolyMatrixExprMixin, ...]]]:
+ def get_mult_expr_list(expr: MultExpr) -> tuple[OldPolyMatrixExprMixin, ...]:
+ match expr:
+ case MultExpr():
+ left = get_mult_expr_list(expr.left)
+ right = get_mult_expr_list(expr.right)
+ return left + right
+ case _:
+ return (expr,)
+
+ def get_add_expr_list(expr: AddExpr) -> tuple[tuple[float, tuple[OldPolyMatrixExprMixin, ...]]]:
+ match expr:
+ case AddExpr():
+ left = get_add_expr_list(expr.left)
+ right = get_add_expr_list(expr.right)
+ return left + right
+ case ScalarMultExpr():
+ right = get_mult_expr_list(expr.right)
+ return ((expr.left, right),)
+ case MultExpr():
+ left = get_mult_expr_list(expr.left)
+ right = get_mult_expr_list(expr.right)
+ return ((1.0, left + right),)
+ case _:
+ return ((1.0, (expr,)),)
+
+ return get_add_expr_list(self.expr)
+ \ No newline at end of file
diff --git a/polymatrix/mixins/multexprmixin.py b/polymatrix/mixins/multexprmixin.py
new file mode 100644
index 0000000..c9c45c2
--- /dev/null
+++ b/polymatrix/mixins/multexprmixin.py
@@ -0,0 +1,15 @@
+import abc
+
+from polymatrix.mixins.oldpolymatrixexprmixin import OldPolyMatrixExprMixin
+
+
+class MultExprMixin(OldPolyMatrixExprMixin, abc.ABC):
+ @property
+ @abc.abstractmethod
+ def left(self) -> OldPolyMatrixExprMixin:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> OldPolyMatrixExprMixin:
+ ...
diff --git a/polymatrix/mixins/oldpolymatrixexprmixin.py b/polymatrix/mixins/oldpolymatrixexprmixin.py
new file mode 100644
index 0000000..6d5a029
--- /dev/null
+++ b/polymatrix/mixins/oldpolymatrixexprmixin.py
@@ -0,0 +1,2 @@
+class OldPolyMatrixExprMixin:
+ pass
diff --git a/polymatrix/mixins/optimizationstatemixin.py b/polymatrix/mixins/oldpolymatrixexprstatemixin.py
index 731477c..0b95092 100644
--- a/polymatrix/mixins/optimizationstatemixin.py
+++ b/polymatrix/mixins/oldpolymatrixexprstatemixin.py
@@ -2,11 +2,11 @@ import abc
import itertools
import dataclasses
-from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
from polymatrix.utils import monomial_to_index
-class OptimizationStateMixin(abc.ABC):
+class OldPolyMatrixExprStateMixin(abc.ABC):
@property
@abc.abstractmethod
def n_var(self) -> int:
@@ -27,7 +27,7 @@ class OptimizationStateMixin(abc.ABC):
@property
@abc.abstractmethod
- def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], tuple[int, int]]:
+ def offset_dict(self) -> dict[tuple[OldPolyMatrixMixin, int], tuple[int, int]]:
...
# @property
@@ -35,7 +35,7 @@ class OptimizationStateMixin(abc.ABC):
# def local_index_dict(self) -> dict[tuple[PolyMatrixMixin, int], dict[int, int]]:
# ...
- def get_polymat(self, p_index: int) -> tuple[PolyMatrixMixin, int, int]:
+ def get_polymat(self, p_index: int) -> tuple[OldPolyMatrixMixin, int, int]:
for (polymat, degree), (start, end) in self.offset_dict.items():
if start <= p_index < end:
return polymat, degree, p_index - start
@@ -46,7 +46,7 @@ class OptimizationStateMixin(abc.ABC):
# start_idx, end_idx = self.offset_dict[(polymat, degree)]
# return end_idx - start_idx
- def update_offsets(self, polymats: tuple[PolyMatrixMixin]) -> 'OptimizationStateMixin':
+ def update_offsets(self, polymats: tuple[OldPolyMatrixMixin]) -> 'OldPolyMatrixExprStateMixin':
registered_polymats = set(polymat for polymat, _ in self.offset_dict.keys())
parametric_polymats = set(p for p in polymats if not p.is_constant and p not in registered_polymats)
diff --git a/polymatrix/mixins/polymatrixmixin.py b/polymatrix/mixins/oldpolymatrixmixin.py
index c569a22..dc33449 100644
--- a/polymatrix/mixins/polymatrixmixin.py
+++ b/polymatrix/mixins/oldpolymatrixmixin.py
@@ -5,7 +5,7 @@ import typing
from matplotlib.pyplot import streamplot
-class PolyMatrixMixin(abc.ABC):
+class OldPolyMatrixMixin(abc.ABC):
@property
@abc.abstractmethod
def name(self) -> str:
diff --git a/polymatrix/mixins/optimizationmixin.py b/polymatrix/mixins/optimizationmixin.py
index e44a36f..1a0febe 100644
--- a/polymatrix/mixins/optimizationmixin.py
+++ b/polymatrix/mixins/optimizationmixin.py
@@ -6,18 +6,17 @@ import itertools
import more_itertools
from sympy import Equivalent
-from polymatrix.mixins.optimizationpipeopmixin import OptimizationPipeOpMixin
-from polymatrix.mixins.optimizationstatemixin import OptimizationStateMixin
+from polymatrix.exprcontainer import ExprContainer
+from polymatrix.mixins.oldpolymatrixexprstatemixin import OldPolyMatrixExprStateMixin
-from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
-from polymatrix.polystruct import DegreeType
+from polymatrix.oldpolymatrix import OldPolyMatrix
from polymatrix.utils import monomial_to_index
class OptimizationMixin(abc.ABC):
@property
@abc.abstractmethod
- def state(self) -> OptimizationStateMixin:
+ def state(self) -> OldPolyMatrixExprStateMixin:
...
@property
@@ -39,45 +38,28 @@ class OptimizationMixin(abc.ABC):
def n_var(self) -> int:
return self.state.n_var
- # @property
- # def n_equality_constraints(self) -> int:
- # if len(self.equality_constraints) == 0:
- # return 0
-
- # return max(eq for eq, _ in self.equality_constraints.items()) + 1
-
- # @property
- # def n_inequality_constraints(self) -> int:
- # if len(self.inequality_constraints) == 0:
- # return 0
-
- # return max(eq for eq, _ in self.inequality_constraints.items()) + 1
- # # return max(key[0] for _, d_data in self.inequality_constraints.items() for key, _ in d_data.items()) + 1
-
- # @property
- # def n_auxillary_equality(self) -> int:
- # if len(self.auxillary_equality) == 0:
- # return 0
-
- # return max(eq for eq, _ in self.auxillary_equality.items()) + 1
- # # return max(key[0] for _, d_data in self.auxillary_equality.items() for key, _ in d_data.items()) + 1
-
# def pipe(self, *operators: OptimizationPipeOpMixin):
# return functools.reduce(lambda obs, op: op(obs), operators, self)
- def add_equality_constraints(self, expr: tuple[tuple[PolyMatrixMixin, ...], ...]):
- for term in expr:
+ @staticmethod
+ def _get_equality_terms_from_expression(
+ expr: ExprContainer[OldPolyMatrix],
+ state: OldPolyMatrixExprStateMixin,
+ ) -> dict[tuple[int, tuple[int, ...], tuple[int, ...]], float]:
+ expr_list = expr.to_list()
+
+ for _, term in expr_list:
for left, right in itertools.pairwise(term):
assert left.shape[1] == right.shape[0], f'{left} and {right} do not match'
- all_polymats = tuple(polymat for term in expr for polymat in term)
+ all_polymats = tuple(polymat for _, term in expr_list for polymat in term)
# update offsets with unseen polymats
- state = self.state.update_offsets(all_polymats)
+ state = state.update_offsets(all_polymats)
equality_constraint = collections.defaultdict(float)
- for term in expr:
+ for _, term in expr_list:
for degrees in itertools.product(*(polymat.degrees for polymat in term)):
@@ -89,7 +71,7 @@ class OptimizationMixin(abc.ABC):
# n_var = 2, degree = 3
# cominations = [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)]
- for combination in itertools.combinations_with_replacement(range(self.n_var), total_degree):
+ for combination in itertools.combinations_with_replacement(range(state.n_var), total_degree):
for x_monomial in more_itertools.distinct_permutations(combination):
def acc_func(acc, v):
@@ -108,7 +90,7 @@ class OptimizationMixin(abc.ABC):
# (1,0) -> x2*x1 instead of (0,1)->x1*x2
if all(non_increasing(rel_x_monomial) for rel_x_monomial in rel_x_monomials):
- col_defaults = [monomial_to_index(self.n_var, monom) for monom in rel_x_monomials]
+ col_defaults = [monomial_to_index(state.n_var, monom) for monom in rel_x_monomials]
n_rows_col = itertools.chain((m.shape[0] for m in term), (term[-1].shape[1],))
@@ -131,7 +113,7 @@ class OptimizationMixin(abc.ABC):
else:
new_poly_row, new_poly_col, new_monom, factor = re_index
- col = monomial_to_index(self.n_var, new_monom)
+ col = monomial_to_index(state.n_var, new_monom)
if subs is not None:
try:
@@ -168,19 +150,38 @@ class OptimizationMixin(abc.ABC):
continue
poly_row = data[0][0]
+ poly_col = data[-1][1]
p_monomial = tuple(sorted(d[2] for d in data if d[4] is None))
x_monomial_sorted = tuple(sorted(x_monomial))
- equality_constraint[poly_row, x_monomial_sorted, p_monomial] += value
+ equality_constraint[poly_row, poly_col, x_monomial_sorted, p_monomial] += value
+
+ return equality_constraint, state
+
+ @staticmethod
+ def _get_equality_equations_from_terms(
+ equality_terms: dict[tuple[int, int, tuple[int, ...], tuple[int, ...]], float],
+ ) -> dict[int, list[tuple[list[int], float]]]:
+ offset = 0
+ equality_constraints = collections.defaultdict(list)
- # eq_constr = collections.defaultdict(list)
+ for _, data in equality_terms.items():
+ equation_set = set((poly_row, poly_col, x_monomial) for (poly_row, poly_col, x_monomial, _) in data.keys())
+ to_eq_index_dict = {eq: offset + idx for idx, eq in enumerate(equation_set)}
+ offset += len(equation_set)
- # for degree, d_data in eq_constr_buffer.items():
- # for (eq_idx, perm, p_monoms), val in d_data.items():
- # row = eq_to_rows[(eq_idx, perm)]
- # eq_constr[row] += ((p_monoms, val),)
+ for (poly_row, poly_col, x_monomial, p_monomial), val in data.items():
+ eq_index = to_eq_index_dict[(poly_row, poly_col, x_monomial)]
+ equality_constraints[eq_index] += ((p_monomial, val),)
- # eq_data = dict(gen_eq_data())
+ return equality_constraints
+
+ def add_equality_constraints(self, expr: ExprContainer[OldPolyMatrix]):
+
+ equality_constraint, state = self._get_equality_terms_from_expression(
+ expr=expr,
+ state=self.state
+ )
if len(self.equality_constraints) == 0:
constraint_index = 0
@@ -193,9 +194,10 @@ class OptimizationMixin(abc.ABC):
equality_constraints=self.equality_constraints | {constraint_index: equality_constraint},
)
- def add_inequality_constraints(self, expr):
+ def add_inequality_constraints(self, expr: ExprContainer[OldPolyMatrix]):
+
# all_polymats = tuple(polymat for term in expr for polymat in term)
- polymat = expr[0][0]
+ polymat = expr.expr
all_polymats = (polymat,)
# update offsets with unseen polymats
@@ -275,22 +277,11 @@ class OptimizationMixin(abc.ABC):
return dataclasses.replace(self, inequality_constraints=ineq_constr_buffer, auxillary_equality=aux_eq_buffer, state=state)
- def get_equality_constraints(self):
- offset = 0
- equality_constraints = collections.defaultdict(list)
-
- for _, data in self.equality_constraints.items():
- equation_set = set((poly_row, x_monomial) for (poly_row, x_monomial, _) in data.keys())
- to_eq_index_dict = {eq: offset + idx for idx, eq in enumerate(equation_set)}
- offset += len(equation_set)
-
- for (poly_row, x_monomial, p_monomial), val in data.items():
- eq_index = to_eq_index_dict[(poly_row, x_monomial)]
- equality_constraints[eq_index] += ((p_monomial, val),)
-
- return equality_constraints
-
- def minimize(self, cost_func=None):
+ def minimize(
+ self,
+ cost: tuple[ExprContainer[OldPolyMatrix], ...],
+ t: float,
+ ):
"""
- assume sum of squares cost function on variables x
- introduce nu/lambda for each equality/inequality
@@ -303,10 +294,20 @@ class OptimizationMixin(abc.ABC):
> r1 * r2
> P @ x = v
"""
+
+ state = [self.state]
- state = self.state
+ def gen_terms():
+ for idx, expr in enumerate(cost):
+ sum_of_squares, state[0] = self._get_equality_terms_from_expression(
+ expr=expr,
+ state=state[0]
+ )
+ yield idx, sum_of_squares
+ # print(dict(gen_terms()))
+ sum_of_squares = self._get_equality_equations_from_terms(dict(gen_terms()))
- equality_constraints = self.get_equality_constraints()
+ equality_constraints = self._get_equality_equations_from_terms(self.equality_constraints)
n_equality_constraints = len(equality_constraints)
@@ -323,7 +324,7 @@ class OptimizationMixin(abc.ABC):
used_param_indices = set(param_idx for monomial_value_pairs in itertools.chain(equality_constraints.values(), self.inequality_constraints.values(), self.auxillary_equality.values()) for monomial, _ in monomial_value_pairs for param_idx in monomial)
param_update = {monom: idx for idx, monom in enumerate(sorted(used_param_indices))}
- assert max(used_param_indices) == state.n_param - 1, f'{max(used_param_indices)=} is not {state.n_param - 1=}'
+ assert max(used_param_indices) == state[0].n_param - 1, f'{max(used_param_indices)=} is not {state[0].n_param - 1=}'
# param_indices = tuple(start + idx for start, end in state.offset_dict.values() for idx in range(end - start))
# print(tuple(m for m in monom_update_reverse if m not in param_indices))
@@ -331,11 +332,12 @@ class OptimizationMixin(abc.ABC):
# the variables x are assumed to be the parameters of all registered polynomial matrices
# todo: this would later come from a specific polynomial matrices
- param_indices = tuple(param_update[start + idx] for start, end in state.offset_dict.values() for idx in range(end - start) if start + idx in used_param_indices)
+ param_indices = tuple(param_update[start + idx] for start, end in state[0].offset_dict.values() for idx in range(end - start) if start + idx in used_param_indices)
equality_constraints = tuple((eq_index, tuple((tuple(param_update[m] for m in monom), val) for monom, val in monomial_value_pair)) for eq_index, monomial_value_pair in equality_constraints.items())
inequality_constraints = tuple((key, tuple((tuple(param_update[m] for m in monom), val) for monom, val in monom_val)) for key, monom_val in self.inequality_constraints.items())
auxillary_equality = tuple((key, tuple((tuple(param_update[m] for m in monom), val) for monom, val in monom_val)) for key, monom_val in self.auxillary_equality.items())
+ sum_of_squares = tuple((key, tuple((tuple(param_update[m] for m in monom), val) for monom, val in monom_val)) for key, monom_val in sum_of_squares.items())
# introduce variable nu for each equality/inequality
idx_nu = len(used_param_indices)
@@ -357,7 +359,7 @@ class OptimizationMixin(abc.ABC):
# > inequality constraints - r1^2
for eq, monom_val_list in inequality_constraints:
- eq_buffer[current_eq_offset + eq] += monom_val_list + ((2 * (idx_r1 + eq,), -1),)
+ eq_buffer[current_eq_offset + eq] += monom_val_list + ((2 * (idx_r1 + eq,), -1),) # (tuple(), -0.1))
current_eq_offset += n_inequality_constraints
# assert max(eq_buffer.keys()) < current_eq_offset
@@ -372,25 +374,16 @@ class OptimizationMixin(abc.ABC):
aux_variables = set(var for p_monomial, _ in aux_eq_data for var in p_monomial if var not in param_indices)
variables = set(var for p_monomial, _ in aux_eq_data for var in p_monomial if var in param_indices)
- # def gen_equation():
# eq1 = (x, a d_a x)
# eq2 = (-1, a d_b x)
for var in variables:
-
- # print(f'{var=}')
# a x - b
# ((a, x), (-b,))
for p_monomial, val in aux_eq_data:
-
- # print(f'{p_monomial=}')
p_monomial_grp = dict(collections.Counter(p_monomial))
- # d_a, d_b
- # for var, counter in p_monomial_grp.items():
- # if var in param_indices:
-
if var in p_monomial_grp:
def generate_monom():
# for each variable in the monomial
@@ -432,7 +425,6 @@ class OptimizationMixin(abc.ABC):
yield dx_dict[var][aux_var]
der_monomial_2 = tuple(generate_monom_2())
- # print(f'{der_monomial_2=}')
if len(der_monomial_2):
eq_buffer[current_eq_offset] += ((der_monomial_2, val * p_monomial_grp[aux_var]),)
@@ -440,25 +432,14 @@ class OptimizationMixin(abc.ABC):
# print(dx_dict)
- # @dataclasses.dataclass
- # class VariableIndexingState:
- # index: int
-
- # def increment(self):
- # return dataclasses.replace(self, index=self.index + 1), self.index
-
def differentiate(data, var, dx_dict):
eq_buffer = tuple()
-
- # print(f'{var=}')
for p_monomial, val in data:
# count powers for each variable
p_monomial_grp = dict(collections.Counter(p_monomial))
- # print(f'{p_monomial=}')
-
# for var, counter in p_monomial_grp.items():
if var in p_monomial_grp:
def generate_monom():
@@ -473,7 +454,7 @@ class OptimizationMixin(abc.ABC):
eq_buffer += ((
tuple(generate_monom()),
- -val * p_monomial_grp[var],
+ val * p_monomial_grp[var],
),)
if var in dx_dict:
@@ -490,24 +471,12 @@ class OptimizationMixin(abc.ABC):
yield i_var
yield dx_dict[var][aux_var]
-
- # der_monomial = tuple(generate_monom())
-
- # if 1 < len(der_monomial):
- # aux_eq_buffer += ((
- # (der_monomial, 1), ((index_buffer[0],), -1)
- # ),)
- # variable_indices = (index_buffer[0],)
- # else:
- # variable_indices = der_monomial
eq_buffer += ((
tuple(generate_monom()),
- -val * p_monomial_grp[aux_var],
+ val * p_monomial_grp[aux_var],
),)
- # index_buffer[0] += 1
-
return eq_buffer
def reduce_to_quadratic_equation(eq, index):
@@ -518,9 +487,6 @@ class OptimizationMixin(abc.ABC):
for monomial, val in eq:
n_aux = len(monomial) - 2
- # print(f'{monomial=}')
- # print(f'{n_aux=}')
-
if 0 < n_aux:
eq_buffer += ((monomial[0:1] + (index_buffer[0],), val),)
@@ -544,62 +510,43 @@ class OptimizationMixin(abc.ABC):
# > x + nu * equality constraint + lambda * inequality constraint
# ---------------------------------------------------------------
+ # print(sum_of_squares)
+
for index, param in enumerate(param_indices):
- eq_buffer[current_eq_offset + index] += (((param,), 1),)
+ # print(param)
+ def gen_cost_terms():
+ for _, data in sum_of_squares:
+ equation = differentiate(data, param, {})
+ for d_monomials, d_val in equation:
+ for monomials, val in data:
+ yield monomials + d_monomials, val * d_val
+
+ eq_buffer[current_eq_offset + index] += tuple(gen_cost_terms())
+
+ # eq_buffer[current_eq_offset + index] += (((param,), 1),)
# differentiate equality constraints for each variable x
for eq_constr_idx, eq_data in equality_constraints:
for var in param_indices:
- equation = differentiate(eq_data, var, dx_dict)
+ equation = differentiate(eq_data, var, dx_dict)
equation = tuple((monomials + (idx_nu + eq_constr_idx,), val) for monomials, val in equation)
-
equation, aux_equations, idx_dx[0] = reduce_to_quadratic_equation(equation, idx_dx[0])
if len(equation):
eq_idx = param_indices.index(var)
- # print(f'{equation=}')
- eq_buffer[current_eq_offset + eq_idx] += equation
-
- # # print(f'{aux_equations=}')
- # for aux_eq in aux_equations:
- # eq_buffer[next_current_eq_offset] = aux_eq
- # next_current_eq_offset += 1
-
- # for p_monomial, val in ineq_data:
-
- # p_monomial_grp = dict(collections.Counter(p_monomial))
-
- # for var, counter in p_monomial_grp.items():
- # if var in param_indices:
- # def generate_monom():
- # for i_m, i_counter in p_monomial_grp.items():
- # if var is i_m:
- # sel_counter = i_counter - 1
- # else:
- # sel_counter = i_counter
-
- # for _ in range(sel_counter):
- # yield i_m
-
- # eq_idx = param_indices.index(var)
- # der_monomial = tuple(generate_monom()) + (idx_nu + eq_constr_idx,)
- # eq_buffer[current_eq_offset + eq_idx] += ((der_monomial, val * counter),)
+ eq_buffer[current_eq_offset + eq_idx] += equation
next_current_eq_offset = current_eq_offset + len(param_indices)
# differentiate inequality constraints for each variable x
for ineq_constr_idx, ineq_data in inequality_constraints:
- # aux_variables = set(var for p_monomial, _ in ineq_data for var in p_monomial if var not in param_indices)
- # variables = set(var for p_monomial, _ in ineq_data for var in p_monomial if var in param_indices)
for var in param_indices:
- equation = differentiate(ineq_data, var, dx_dict)
-
- equation = tuple((monomials + (idx_lambda + ineq_constr_idx,), val) for monomials, val in equation)
+ equation = differentiate(ineq_data, var, dx_dict)
+ equation = tuple((monomials + (idx_lambda + ineq_constr_idx,), -val) for monomials, val in equation)
equation, aux_equations, idx_dx[0] = reduce_to_quadratic_equation(equation, idx_dx[0])
- # idx_dx[0] = index
if len(equation):
eq_idx = param_indices.index(var)
@@ -611,50 +558,6 @@ class OptimizationMixin(abc.ABC):
eq_buffer[next_current_eq_offset] = aux_eq
next_current_eq_offset += 1
- # for p_monomial, val in ineq_data:
-
- # # count powers for each variable
- # p_monomial_grp = dict(collections.Counter(p_monomial))
-
- # # for var, counter in p_monomial_grp.items():
- # if var in p_monomial_grp:
- # def generate_monom():
- # for i_var, i_counter in p_monomial_grp.items():
- # if var is i_var:
- # sel_counter = i_counter - 1
- # else:
- # sel_counter = i_counter
-
- # for _ in range(sel_counter):
- # yield i_var
-
- # eq_idx = param_indices.index(var)
- # der_monomial = tuple(generate_monom()) + (idx_lambda + ineq_constr_idx,)
- # eq_buffer[current_eq_offset + eq_idx] += ((der_monomial, -val * counter),)
-
- # if var in dx_dict:
- # for aux_var in dx_dict[var]:
- # if aux_var in p_monomial_grp:
- # def generate_monom():
- # for i_var, i_counter in p_monomial_grp.items():
- # if i_var is aux_var:
- # sel_counter = i_counter - 1
- # else:
- # sel_counter = i_counter
-
- # for _ in range(sel_counter):
- # yield i_var
-
- # yield dx_dict[var][aux_var]
-
- # eq_idx = param_indices.index(var)
- # der_monomial_2 = tuple(generate_monom())
- # eq_buffer[next_current_eq_offset] = ((der_monomial_2, 1), ((idx_dx[0],), -1))
- # next_current_eq_offset += 1
-
- # eq_buffer[current_eq_offset + eq_idx] += (((idx_lambda + ineq_constr_idx, idx_dx[0]), -val * p_monomial_grp[aux_var]),)
- # idx_dx[0] += 1
-
# current_eq_offset += len(param_indices)
current_eq_offset = next_current_eq_offset
@@ -669,11 +572,22 @@ class OptimizationMixin(abc.ABC):
# assert max(eq_buffer.keys()) < current_eq_offset
assert max(eq_buffer.keys()) == current_eq_offset - 1
- # > r1 * r2
+ next_current_eq_offset = current_eq_offset + n_inequality_constraints
+ # # > r1 * r2
+ # for idx in range(n_inequality_constraints):
+ # eq_buffer[current_eq_offset + idx] += (((idx_r1 + idx, idx_r2 + idx), 1),)
+
+ # > r1^2 * lambda
for idx in range(n_inequality_constraints):
- eq_buffer[current_eq_offset + idx] += (((idx_r1 + idx, idx_r2 + idx), 1),)
+ eq_buffer[next_current_eq_offset] += (((idx_dx[0],), 1), (2 * (idx_r1 + idx,), -1))
+ next_current_eq_offset += 1
+
+ eq_buffer[current_eq_offset + idx] += (((idx_lambda + idx, idx_dx[0]), 1), (tuple(), -1/t))
+ idx_dx[0] += 1
+
+ # current_eq_offset += n_inequality_constraints
+ current_eq_offset = next_current_eq_offset
- current_eq_offset += n_inequality_constraints
# assert max(eq_buffer.keys()) < current_eq_offset
assert max(eq_buffer.keys()) == current_eq_offset - 1
diff --git a/polymatrix/mixins/optimizationpipeopmixin.py b/polymatrix/mixins/optimizationpipeopmixin.py
deleted file mode 100644
index a90c193..0000000
--- a/polymatrix/mixins/optimizationpipeopmixin.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import abc
-import typing
-
-
-class OptimizationPipeOpMixin(abc.ABC):
- @property
- @abc.abstractmethod
- def func(self) -> typing.Callable[[typing.Any], typing.Any]:
- ...
-
- def __call__(self, source: typing.Any):
- return self.func(source)
diff --git a/polymatrix/mixins/scalarmultexprmixin.py b/polymatrix/mixins/scalarmultexprmixin.py
new file mode 100644
index 0000000..ec98722
--- /dev/null
+++ b/polymatrix/mixins/scalarmultexprmixin.py
@@ -0,0 +1,14 @@
+import abc
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+
+class ScalarMultExprMixin(ExpressionBaseMixin, abc.ABC):
+ @property
+ @abc.abstractmethod
+ def left(self) -> float:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def right(self) -> ExpressionBaseMixin:
+ ...
diff --git a/polymatrix/mixins/statemonadmixin.py b/polymatrix/mixins/statemonadmixin.py
new file mode 100644
index 0000000..c90ebc8
--- /dev/null
+++ b/polymatrix/mixins/statemonadmixin.py
@@ -0,0 +1,18 @@
+import abc
+import typing
+
+
+StateType = typing.TypeVar('StateType')
+ValueType = typing.TypeVar('ValueType')
+
+class StateMonadMixin(
+ # typing.Generic[StateType, ValueType],
+ abc.ABC,
+):
+ @abc.abstractmethod
+ def apply(self, state: StateType) -> tuple[StateType, ValueType]:
+ ...
+
+ # def __apply__(self, state: StateType) -> ValueType:
+ # _, value = self.state_func(state)
+ # return value
diff --git a/polymatrix/multexpr.py b/polymatrix/multexpr.py
new file mode 100644
index 0000000..568bbf8
--- /dev/null
+++ b/polymatrix/multexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.mixins.multexprmixin import MultExprMixin
+
+class MultExpr(MultExprMixin):
+ pass
diff --git a/polymatrix/oldpolymatrix.py b/polymatrix/oldpolymatrix.py
new file mode 100644
index 0000000..409680d
--- /dev/null
+++ b/polymatrix/oldpolymatrix.py
@@ -0,0 +1,6 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.mixins.oldpolymatrixmixin import OldPolyMatrixMixin
+
+
+class OldPolyMatrix(OldPolyMatrixMixin):
+ pass
diff --git a/polymatrix/oldpolymatrixexprstate.py b/polymatrix/oldpolymatrixexprstate.py
new file mode 100644
index 0000000..1397f9e
--- /dev/null
+++ b/polymatrix/oldpolymatrixexprstate.py
@@ -0,0 +1,5 @@
+from polymatrix.mixins.oldpolymatrixexprstatemixin import OldPolyMatrixExprStateMixin
+
+
+class OldPolyMatrixExprState(OldPolyMatrixExprStateMixin):
+ pass
diff --git a/polymatrix/optimizationstate.py b/polymatrix/optimizationstate.py
deleted file mode 100644
index 775e976..0000000
--- a/polymatrix/optimizationstate.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from polymatrix.mixins.optimizationstatemixin import OptimizationStateMixin
-
-
-class OptimizationState(OptimizationStateMixin):
- pass
diff --git a/polymatrix/polymatrix.py b/polymatrix/polymatrix.py
deleted file mode 100644
index ff43af5..0000000
--- a/polymatrix/polymatrix.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
-
-
-class PolyMatrix(PolyMatrixMixin):
- pass
diff --git a/polymatrix/polymatrixaddexpr.py b/polymatrix/polymatrixaddexpr.py
new file mode 100644
index 0000000..5d4f567
--- /dev/null
+++ b/polymatrix/polymatrixaddexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.additionexprmixin import AddExprMixin
+
+class PolyMatrixAddExpr(AddExprMixin):
+ pass
diff --git a/polymatrix/polymatrixarrayexpr.py b/polymatrix/polymatrixarrayexpr.py
new file mode 100644
index 0000000..c9b572d
--- /dev/null
+++ b/polymatrix/polymatrixarrayexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.fromarrayexprmixin import FromArrayExprMixin
+
+class PolyMatrixArrayExpr(FromArrayExprMixin):
+ pass
diff --git a/polymatrix/polymatrixexpr.py b/polymatrix/polymatrixexpr.py
new file mode 100644
index 0000000..74e8757
--- /dev/null
+++ b/polymatrix/polymatrixexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
+
+class PolyMatrixExpr(ExpressionMixin):
+ pass
diff --git a/polymatrix/polymatrixexprstate.py b/polymatrix/polymatrixexprstate.py
new file mode 100644
index 0000000..11419cb
--- /dev/null
+++ b/polymatrix/polymatrixexprstate.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+
+class PolyMatrixExprState(ExpressionStateMixin):
+ pass
diff --git a/polymatrix/polymatrixmultexpr.py b/polymatrix/polymatrixmultexpr.py
new file mode 100644
index 0000000..6b7f997
--- /dev/null
+++ b/polymatrix/polymatrixmultexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.matrixmultexprmixin import MatrixMultExprMixin
+
+class PolyMatrixMultExpr(MatrixMultExprMixin):
+ pass
diff --git a/polymatrix/polymatrixparamexpr.py b/polymatrix/polymatrixparamexpr.py
new file mode 100644
index 0000000..f396bc9
--- /dev/null
+++ b/polymatrix/polymatrixparamexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.parametrizetermsexprmixin import ParametrizeTermsExprMixin
+
+class PolyMatrixParamExpr(ParametrizeTermsExprMixin):
+ pass
diff --git a/polymatrix/polysolver.py b/polymatrix/polysolver.py
index 1f7bc8e..2cb9b7b 100644
--- a/polymatrix/polysolver.py
+++ b/polymatrix/polysolver.py
@@ -276,7 +276,9 @@ def outer_smart_solve(data, n_iter=10, a_max=1.0, irange=None, idegree=None):
subs_data = substitude_x_add_a(data, a[-1])
sol = solve_poly_system(subs_data, 6)
- error_index = np.max(np.abs(eval_solution(subs_data, sol))) + max(a[-1])
+
+ error_index = np.max(np.abs(eval_solution(data, a[-1] + sol))) + max(a[-1])
+ # error_index = np.max(np.abs(eval_solution(subs_data, sol))) + max(a[-1])
# error_index = np.max(np.abs(eval_solution(subs_data, sol)))
except:
diff --git a/polymatrix/polystruct.py b/polymatrix/polystruct.py
deleted file mode 100644
index 808b7ab..0000000
--- a/polymatrix/polystruct.py
+++ /dev/null
@@ -1,480 +0,0 @@
-import abc
-import collections
-import typing
-import numpy as np
-import dataclasses
-import dataclass_abc
-import scipy.sparse
-import itertools
-import functools
-import more_itertools
-
-import polymatrix.utils
-
-DegreeType = int
-# CoordType = tuple[int, int]
-# SubsDictType = dict[DegreeType, dict[CoordType, float]]
-
-########################################
-# Mixins
-########################################
-
-class PolyMatrixMixin(abc.ABC):
- @property
- @abc.abstractmethod
- def degrees(self) -> list[int]:
- ...
-
- @property
- @abc.abstractmethod
- def subs(self) -> dict[DegreeType, dict[int, int, float]]:
- ...
-
- @property
- @abc.abstractmethod
- def re_index(self) -> typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]]:
- ...
-
- @property
- @abc.abstractmethod
- def is_constant(self) -> int:
- ...
-
- @property
- @abc.abstractmethod
- def shape(self) -> tuple[int, int]:
- ...
-
-
-class State(abc.ABC):
- @property
- @abc.abstractmethod
- def n_param(self) -> int:
- ...
-
- @property
- @abc.abstractmethod
- def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]:
- ...
-
- @property
- @abc.abstractmethod
- def local_index_dict(self) -> dict[tuple[PolyMatrixMixin, int], dict[int, int]]:
- ...
-
- @property
- @abc.abstractmethod
- def eq_constraints(self) -> dict[int, dict[tuple[int, int], float]]:
- ...
-
- @property
- @abc.abstractmethod
- def ineq_constraints(self) -> dict[int, dict[tuple[int, int], float]]:
- ...
-
-
-
-class PolyExpressionMixin(abc.ABC):
- @property
- @abc.abstractmethod
- def data(self) -> dict[int, dict[tuple[int, int], float]]:
- ...
-
- @property
- @abc.abstractmethod
- def n_var(self) -> int:
- ...
-
- @property
- @abc.abstractmethod
- def n_eq(self):
- ...
-
- # @property
- # @abc.abstractmethod
- # def n_param(self) -> int:
- # ...
-
- # @property
- # @abc.abstractmethod
- # def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]:
- # ...
-
-
-class OptimizationMixin(abc.ABC):
- @property
- @abc.abstractmethod
- def cost_function(self) -> dict[int, dict[int, float]]:
- ...
-
- @property
- @abc.abstractmethod
- def eq_constraints(self) -> dict[int, dict[tuple[int, int], float]]:
- ...
-
- @property
- @abc.abstractmethod
- def ineq_constraints(self) -> dict[int, dict[tuple[int, int], float]]:
- ...
-
- @property
- @abc.abstractmethod
- def n_var(self) -> int:
- ...
-
- # @property
- # @abc.abstractmethod
- # def n_param(self) -> int:
- # ...
-
- @property
- @abc.abstractmethod
- def n_eq_constraints(self):
- ...
-
- @property
- @abc.abstractmethod
- def n_ineq_constraints(self):
- ...
-
- # @property
- # @abc.abstractmethod
- # def offset_dict(self) -> dict[tuple[PolyMatrixMixin, int], int]:
- # ...
-
- def create(self) -> PolyExpressionMixin:
- """
- - adds nu and lambda for each equality and inequality
- - lambda - r_lambda = 0 for each inequality
- - lambda r_ineq = 0 for each inequality
- - add vanishing gradient
- """
-
- pass
-
- def add_positive_definiteness_condition(self, key):
- pass
-
-
-class PolyMatrixEquationMixin(abc.ABC):
- @property
- @abc.abstractmethod
- def terms(self) -> list[tuple[PolyMatrixMixin, PolyMatrixMixin]]:
- """
- the terms the polynomial matrix equation consists of
- """
-
- ...
-
- @property
- @abc.abstractmethod
- def n_var(self) -> int:
- """
- number of variables defining the polynomials
-
- for example n_var=3: x1, x2 and x3
- """
-
- ...
-
- @property
- @abc.abstractmethod
- def monom_to_index(self) -> typing.Callable[[int, tuple[int, ...]], int]:
- ...
-
- def create(
- self,
- subs: dict[PolyMatrixMixin, dict[DegreeType, dict[int, int, float]]] = None,
- ) -> PolyExpressionMixin:
- if subs is None:
- added_subs = {}
- else:
- added_subs = subs
-
- # create parameter offset
- all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term)
-
- def gen_n_param_per_struct():
-
- for struct in all_structs:
-
- if struct.is_constant:
- continue
-
- for degree in struct.degrees:
- number_of_terms = int(struct.shape[0] * struct.shape[1] * (self.monom_to_index(self.n_var, degree*(self.n_var-1,)) + 1))
- yield (struct, degree), number_of_terms
-
- # param_list = list(gen_n_param_per_struct())
- param_key, param_value = list(zip(*gen_n_param_per_struct()))
- cum_sum = list(itertools.accumulate(param_value))
- offset_dict = dict(zip(param_key, [0] + cum_sum[:-1]))
-
- # # create parameter offset
- # all_structs = set(indexed_poly_mat for term in self.terms for indexed_poly_mat in term)
-
- def gen_substitutions():
- for struct in all_structs:
- added_struct_subs = added_subs.get(struct, None)
-
- if added_struct_subs is not None:
- if struct.subs is not None:
- def gen_merged_subs():
- for degree in struct.degrees:
- all_subs = struct.subs.get(degree, {}) | added_struct_subs.get(degree, {})
- yield degree, all_subs
-
- all_subs = dict(gen_merged_subs())
-
- else:
- all_subs = added_struct_subs
-
- else:
- if struct.subs is not None:
- all_subs = struct.subs
-
- else:
- all_subs = None
-
- yield struct, all_subs
-
- subs_dict = dict(gen_substitutions())
-
- eq_constr_buffer = collections.defaultdict(lambda: collections.defaultdict(float))
-
- def gen_re_indexing(term, degrees, curr_rows_col, monoms, col_defaults, d_subs, d_offsets, perm):
- for m, degree, (poly_row, poly_col), monom, col_default, subs, offset in zip(
- term, degrees, itertools.pairwise(curr_rows_col), monoms, col_defaults, d_subs, d_offsets):
-
- if m.re_index is not None:
- re_index = m.re_index(degree, poly_row, poly_col, monom)
- else:
- re_index = None
-
- if re_index is None:
- col = col_default
- new_poly_row = poly_row
- new_poly_col = poly_col
- factor = 1.0
-
- else:
- new_poly_row, new_poly_col, new_monom, factor = re_index
- col = self.monom_to_index(self.n_var, new_monom)
-
- if subs is not None:
- try:
- subs_val = subs[(new_poly_row, new_poly_col, col)]
- except KeyError:
- subs_val = None
- else:
- subs_val = None
-
- if subs_val is None:
- # the coefficient is selected after the reindexing
-
- row = new_poly_row + new_poly_col * m.shape[0]
-
- # linearize parameter matrix
- param_idx = int(offset + row + col * m.shape[0] * m.shape[1])
-
- else:
- param_idx = None
-
- yield poly_row, poly_col, param_idx, factor, subs_val
-
- for term in self.terms:
-
- term_subs = [subs_dict[m] for m in term]
- # n_matrices = len(term)
-
- for degrees in itertools.product(*(m.degrees for m in term)):
- total_degree = sum(degrees)
- d_subs = [subs[degree] if subs is not None and degree in subs else None for degree, subs in zip(degrees, term_subs)]
- d_offsets = [offset_dict.get((m, degree), 0) for m, degree in zip(term, degrees)]
-
- for combination in itertools.combinations_with_replacement(range(self.n_var), total_degree):
- # n_var = 2, degree = 3
- # cominations = [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)]
-
- for monom in more_itertools.distinct_permutations(combination):
-
- def acc_func(acc, v):
- last, _ = acc
- new = last + v
-
- return new, monom[last:new]
-
- # monom=(1, 0, 1) -> monom1=(x2, x1), monom2=(x2)
- monoms = list(monom for _, monom in itertools.accumulate(degrees, acc_func, initial=(0, None)))[1:]
-
- def non_increasing(seq):
- return all(y <= x for x, y in zip(seq, seq[1:]))
-
- # (1,0) -> x2*x1 instead of (0,1)->x1*x2
- if all(non_increasing(monom) for monom in monoms):
-
- col_defaults = [self.monom_to_index(self.n_var, monom) for monom in monoms]
-
- n_rows_col = itertools.chain((m.shape[0] for m in term), (term[-1].shape[1],))
-
- for curr_rows_col in itertools.product(*[range(e) for e in n_rows_col]):
-
- data = tuple(gen_re_indexing(term, degrees, curr_rows_col, monoms, col_defaults, d_subs, d_offsets, monom))
-
- total_factor = functools.reduce(lambda x, y: x*y, (d[3] for d in data))
-
- if total_factor == 0:
- continue
-
- value = functools.reduce(lambda x, y: x*y, (d[4] for d in data if d[4] is not None), 1) * total_factor
-
- if value == 0:
- continue
-
- poly_row = data[0][0]
- param_idx = tuple(d[2] for d in data if d[4] is None)
- degree = len(param_idx)
-
- eq_constr_buffer[degree][poly_row, monom, param_idx] += value
-
- # assign equations
- rows_perm_set = set((eq_idx, perm) for eq_tuple_degree in eq_constr_buffer.values() for (eq_idx, perm, _) in eq_tuple_degree.keys())
- eq_to_rows = {eq: idx for idx, eq in enumerate(rows_perm_set)}
-
- # # calculate offset
- # monom_update_reverse = sorted(set(m for eq_tuple_degree in eq_constr_buffer.values() for (_, _, monoms) in eq_tuple_degree.keys() for m in monoms))
- # monom_update = {monom: idx for idx, monom in enumerate(monom_update_reverse)}
-
- # def gen_n_cum_sum():
- # groups = itertools.groupby(monom_update_reverse, lambda v: next((i for i, g in enumerate(cum_sum) if v < g)))
- # for _, group in groups:
- # yield sum(1 for _ in group)
-
- # n_cum_sum = list(itertools.accumulate(gen_n_cum_sum()))
- # n_offset_dict = dict(zip(param_key, [0] + n_cum_sum[:-1]))
- # n_param = n_cum_sum[-1]
-
- # def gen_eq_data():
- # for degree, d_data in eq_constr_buffer.items():
- # def gen_eq_degree_data():
- # for (eq_idx, perm, monoms), val in d_data.items():
- # row = eq_to_rows[(eq_idx, perm)]
- # monoms_updated = tuple(monom_update[m] for m in monoms)
- # # print(f'{n_param=}')
- # col = self.monom_to_index(n_param, monoms_updated)
- # yield (row, col), val
-
- # yield degree, dict(gen_eq_degree_data())
-
- # eq_data = dict(gen_eq_data())
-
- def gen_eq_data():
- for degree, d_data in eq_constr_buffer.items():
- def gen_eq_degree_data():
- for (eq_idx, perm, monoms), val in d_data.items():
- row = eq_to_rows[(eq_idx, perm)]
- # monoms_updated = tuple(monom_update[m] for m in monoms)
- # col = self.monom_to_index(n_param, monoms_updated)
- yield (row, monoms), val
-
- yield degree, dict(gen_eq_degree_data())
-
- eq_data = dict(gen_eq_data())
-
- return PolyEquationImpl(
- data=eq_data,
- n_param=cum_sum[-1],
- n_eq=len(eq_to_rows),
- n_var=self.n_var,
- offset_dict=offset_dict,
- )
-
-########################################
-# Classes
-########################################
-
-class PolyMatrix(PolyMatrixMixin):
- pass
-
-
-class PolyEquation(PolyExpressionMixin):
- pass
-
-
-class PolyMatrixEquation(PolyMatrixEquationMixin):
- pass
-
-
-########################################
-# Implementations
-########################################
-
-@dataclass_abc.dataclass_abc(frozen=True, eq=False)
-class PolyMatrixImpl(PolyMatrix):
- degrees: list[int]
- subs: dict[int, dict[tuple[int, int], float]]
- re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]]
- is_constant: bool
- shape: tuple[int, int]
-
-
-@dataclass_abc.dataclass_abc(frozen=True)
-class PolyEquationImpl(PolyEquation):
- data: dict[int, dict[tuple[int, tuple, int], float]]
- offset_dict: dict[tuple[PolyMatrixMixin, int], int]
- n_param: int
- n_eq: int
- n_var: int
-
-
-@dataclass_abc.dataclass_abc(frozen=True)
-class PolyMatrixEquationImpl(PolyMatrixEquation):
- terms: list[tuple[PolyMatrix, PolyMatrix]]
- monom_to_index: typing.Callable[[int, tuple[int, ...]], int]
- n_var: int
-
-########################################
-# init functions
-########################################
-
-def init_poly_matrix(
- shape: tuple[int, int],
- degrees: list[int] = None,
- subs: dict[int, dict[tuple[int, int], float]] = None,
- re_index: typing.Callable[[int, int, int, tuple[int, ...]], tuple[int, int, int, tuple[int, ...], float]] = None,
- is_constant: bool = None,
-):
- if degrees is None:
- assert isinstance(subs, dict)
- degrees = list(subs.keys())
-
- if is_constant is None:
- if subs is None:
- is_constant = False
- else:
- is_constant = True
-
- return PolyMatrixImpl(
- degrees=degrees,
- subs=subs,
- re_index=re_index,
- is_constant=is_constant,
- shape=shape,
- )
-
-
-def init_equation(
- n_var: int,
- terms: list[tuple[PolyMatrix, PolyMatrix]],
- monom_to_index: typing.Callable[[int, tuple[int, ...]], int] = None,
-):
- # assert all(not left.is_vector and right.is_vector for left, right in terms)
-
- if monom_to_index is None:
- monom_to_index = polymatrix.utils.variable_to_index
-
- return PolyMatrixEquationImpl(
- n_var=n_var,
- terms=terms,
- monom_to_index=monom_to_index,
- )
diff --git a/polymatrix/scalarmultexpr.py b/polymatrix/scalarmultexpr.py
new file mode 100644
index 0000000..3d9a075
--- /dev/null
+++ b/polymatrix/scalarmultexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.mixins.scalarmultexprmixin import ScalarMultExprMixin
+
+class ScalarMultExpr(ScalarMultExprMixin):
+ pass
diff --git a/polymatrix/statemonad.py b/polymatrix/statemonad.py
index 9be3b61..3054160 100644
--- a/polymatrix/statemonad.py
+++ b/polymatrix/statemonad.py
@@ -14,6 +14,7 @@ class StateMonad(Generic[U, State]):
return cls(lambda state: (value, state))
def map(self, fn: Callable[[U], V]) -> 'StateMonad[V, State]':
+
def internal_map(state: State) -> Tuple[U, State]:
val, n_state = self._fn(state)
return fn(val), n_state
@@ -24,7 +25,6 @@ class StateMonad(Generic[U, State]):
def internal_map(state: State) -> Tuple[V, State]:
val, n_state = self._fn(state)
-
return fn(val).run(n_state)
return StateMonad(internal_map)