summaryrefslogtreecommitdiffstats
path: root/polymatrix/expression/init
diff options
context:
space:
mode:
Diffstat (limited to 'polymatrix/expression/init')
-rw-r--r--polymatrix/expression/init/initaccumulateexpr.py25
-rw-r--r--polymatrix/expression/init/initexpressionstate.py2
-rw-r--r--polymatrix/expression/init/initfromtermsexpr.py14
-rw-r--r--polymatrix/expression/init/inittoquadraticexpr.py10
4 files changed, 50 insertions, 1 deletions
diff --git a/polymatrix/expression/init/initaccumulateexpr.py b/polymatrix/expression/init/initaccumulateexpr.py
new file mode 100644
index 0000000..30297bf
--- /dev/null
+++ b/polymatrix/expression/init/initaccumulateexpr.py
@@ -0,0 +1,25 @@
+import dataclass_abc
+import typing
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.accumulateexpr import AccumulateExpr
+
+
+def init_accumulate_expr(
+ underlying: ExpressionBaseMixin,
+ acc_func: typing.Callable,
+ initial: ExpressionBaseMixin = None,
+):
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class AccumulateExprImpl(AccumulateExpr):
+ underlying: ExpressionBaseMixin
+ initial: ExpressionBaseMixin
+
+ def acc_func(self, acc, v):
+ return acc_func(acc, v)
+
+ return AccumulateExprImpl(
+ underlying=underlying,
+ initial=initial,
+ )
diff --git a/polymatrix/expression/init/initexpressionstate.py b/polymatrix/expression/init/initexpressionstate.py
index a7d3aac..7e8a6fe 100644
--- a/polymatrix/expression/init/initexpressionstate.py
+++ b/polymatrix/expression/init/initexpressionstate.py
@@ -14,6 +14,6 @@ def init_expression_state(
return ExpressionStateImpl(
n_param=n_param,
offset_dict=offset_dict,
- auxillary_terms=tuple(),
+ auxillary_equations={},
cached_polymatrix={},
)
diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py
new file mode 100644
index 0000000..80d5198
--- /dev/null
+++ b/polymatrix/expression/init/initfromtermsexpr.py
@@ -0,0 +1,14 @@
+from polymatrix.expression.impl.fromtermsexprimpl import FromTermsExprImpl
+
+
+def init_from_terms_expr(
+ terms: tuple,
+ shape: tuple[int, int]
+):
+ if isinstance(terms, dict):
+ terms = tuple((key, tuple(value.items())) for key, value in terms.items())
+
+ return FromTermsExprImpl(
+ terms=terms,
+ shape=shape,
+ )
diff --git a/polymatrix/expression/init/inittoquadraticexpr.py b/polymatrix/expression/init/inittoquadraticexpr.py
new file mode 100644
index 0000000..dfc0567
--- /dev/null
+++ b/polymatrix/expression/init/inittoquadraticexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.toquadraticexprimpl import ToQuadraticExprImpl
+
+
+def init_to_quadratic_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return ToQuadraticExprImpl(
+ underlying=underlying,
+)