diff options
Diffstat (limited to 'polymatrix/expression/init')
-rw-r--r-- | polymatrix/expression/init/initaccumulateexpr.py | 25 | ||||
-rw-r--r-- | polymatrix/expression/init/initexpressionstate.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/init/initfromtermsexpr.py | 14 | ||||
-rw-r--r-- | polymatrix/expression/init/inittoquadraticexpr.py | 10 |
4 files changed, 50 insertions, 1 deletions
diff --git a/polymatrix/expression/init/initaccumulateexpr.py b/polymatrix/expression/init/initaccumulateexpr.py new file mode 100644 index 0000000..30297bf --- /dev/null +++ b/polymatrix/expression/init/initaccumulateexpr.py @@ -0,0 +1,25 @@ +import dataclass_abc +import typing + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.accumulateexpr import AccumulateExpr + + +def init_accumulate_expr( + underlying: ExpressionBaseMixin, + acc_func: typing.Callable, + initial: ExpressionBaseMixin = None, +): + + @dataclass_abc.dataclass_abc(frozen=True) + class AccumulateExprImpl(AccumulateExpr): + underlying: ExpressionBaseMixin + initial: ExpressionBaseMixin + + def acc_func(self, acc, v): + return acc_func(acc, v) + + return AccumulateExprImpl( + underlying=underlying, + initial=initial, + ) diff --git a/polymatrix/expression/init/initexpressionstate.py b/polymatrix/expression/init/initexpressionstate.py index a7d3aac..7e8a6fe 100644 --- a/polymatrix/expression/init/initexpressionstate.py +++ b/polymatrix/expression/init/initexpressionstate.py @@ -14,6 +14,6 @@ def init_expression_state( return ExpressionStateImpl( n_param=n_param, offset_dict=offset_dict, - auxillary_terms=tuple(), + auxillary_equations={}, cached_polymatrix={}, ) diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py new file mode 100644 index 0000000..80d5198 --- /dev/null +++ b/polymatrix/expression/init/initfromtermsexpr.py @@ -0,0 +1,14 @@ +from polymatrix.expression.impl.fromtermsexprimpl import FromTermsExprImpl + + +def init_from_terms_expr( + terms: tuple, + shape: tuple[int, int] +): + if isinstance(terms, dict): + terms = tuple((key, tuple(value.items())) for key, value in terms.items()) + + return FromTermsExprImpl( + terms=terms, + shape=shape, + ) diff --git a/polymatrix/expression/init/inittoquadraticexpr.py b/polymatrix/expression/init/inittoquadraticexpr.py new file mode 100644 index 0000000..dfc0567 --- /dev/null +++ b/polymatrix/expression/init/inittoquadraticexpr.py @@ -0,0 +1,10 @@ +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.impl.toquadraticexprimpl import ToQuadraticExprImpl + + +def init_to_quadratic_expr( + underlying: ExpressionBaseMixin, +): + return ToQuadraticExprImpl( + underlying=underlying, +) |