summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/__init__.py23
-rw-r--r--polymatrix/denserepr/from_.py6
-rw-r--r--polymatrix/expression/expression.py139
-rw-r--r--polymatrix/expression/from_.py33
-rw-r--r--polymatrix/expression/impl.py21
-rw-r--r--polymatrix/expression/init.py99
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py79
-rw-r--r--polymatrix/expression/mixins/blockdiagexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/cacheexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/combinationsexprmixin.py33
-rw-r--r--polymatrix/expression/mixins/degreeexprmixin.py (renamed from polymatrix/expression/mixins/maxdegreeexprmixin.py)15
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/diagexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/divergenceexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/eyeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/filterexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/filterlinearpartexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/fromtermsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/fromtupleexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/legendreseriesmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearmatrixinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/linearmonomialsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/maxexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/parametrizeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/parametrizematrixexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/productexprmixin.py36
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/quadraticmonomialsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/repmatexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/reshapeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/setelementatexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/squeezeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/subtractmonomialsexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/sumexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/symmetricexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/toconstantexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/toquadraticexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/tosortedvariablesmixin.py2
-rw-r--r--polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/transposeexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/truncateexprmixin.py14
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py2
-rw-r--r--polymatrix/expression/op.py30
-rw-r--r--polymatrix/expression/to.py62
-rw-r--r--polymatrix/expression/utils/multiplypolynomial.py18
-rw-r--r--polymatrix/polymatrix/impl.py9
-rw-r--r--polymatrix/polymatrix/mixins.py20
-rw-r--r--polymatrix/polymatrix/typing.py4
-rw-r--r--polymatrix/statemonad/mixins.py3
59 files changed, 346 insertions, 392 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index e513300..6592238 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -6,11 +6,11 @@ from polymatrix.expression.from_ import from_ as internal_from
from polymatrix.expression import v_stack as internal_v_stack
from polymatrix.expression import h_stack as internal_h_stack
from polymatrix.expression import product as internal_product
-from polymatrix.denserepr.from_ import from_polymatrix_expr
-from polymatrix.expression.to import shape as internal_shape
-from polymatrix.expression.to import to_constant_repr as internal_to_constant_repr
-from polymatrix.expression.to import to_degrees as internal_to_degrees
-from polymatrix.expression.to import to_sympy_repr as internal_to_sympy_repr
+from polymatrix.denserepr.from_ import from_polymatrix
+# from polymatrix.expression.to import shape as internal_shape
+from polymatrix.expression.to import to_constant as internal_to_constant
+# from polymatrix.expression.to import to_degrees as internal_to_degrees
+from polymatrix.expression.to import to_sympy as internal_to_sympy
Expression = internal_Expression
ExpressionState = internal_ExpressionState
@@ -21,11 +21,14 @@ from_ = internal_from
v_stack = internal_v_stack
h_stack = internal_h_stack
product = internal_product
-to_shape = internal_shape
-to_constant_repr = internal_to_constant_repr
-to_degrees = internal_to_degrees
-to_sympy_repr = internal_to_sympy_repr
-to_matrix_repr = from_polymatrix_expr
+# to_shape = internal_shape
+to_constant_repr = internal_to_constant
+to_constant = internal_to_constant
+# to_degrees = internal_to_degrees
+to_sympy_repr = internal_to_sympy
+to_sympy = internal_to_sympy
+to_matrix_repr = from_polymatrix
+to_dense = from_polymatrix
# def from_sympy(
# data: tuple[tuple[float]],
diff --git a/polymatrix/denserepr/from_.py b/polymatrix/denserepr/from_.py
index b61af86..a8540e6 100644
--- a/polymatrix/denserepr/from_.py
+++ b/polymatrix/denserepr/from_.py
@@ -1,8 +1,6 @@
-import dataclasses
import itertools
-import typing
import numpy as np
-import scipy.sparse
+
from polymatrix.denserepr.impl import DenseReprBufferImpl, DenseReprImpl
from polymatrix.expression.expression import Expression
@@ -14,7 +12,7 @@ from polymatrix.statemonad.mixins import StateMonadMixin
from polymatrix.expression.utils.monomialtoindex import monomial_to_index
-def from_polymatrix_expr(
+def from_polymatrix(
expressions: Expression | tuple[Expression],
variables: Expression = None,
sorted: bool = None,
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index dc2f144..f5318ae 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -9,7 +9,7 @@ import polymatrix.expression.init
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expression.op import diff, linear_in, linear_monomials, legendre
+from polymatrix.expression.op import diff, linear_in, linear_monomials, legendre, filter_, degree
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
@@ -24,7 +24,7 @@ class Expression(
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ # overwrites the abstract method of `PolyMatrixExprBaseMixin`
def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
return self.underlying.apply(state)
@@ -44,8 +44,7 @@ class Expression(
return attr
def __getitem__(self, key: tuple[int, int]):
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_get_item_expr(
underlying=self.underlying,
index=key,
@@ -118,14 +117,6 @@ class Expression(
underlying=op(left, right, stack),
)
- # def _convert_to_expression(self, other):
- # result = init_from_expr_or_none(other)
-
- # if result is None:
- # return NotImplemented
-
- # return result
-
def cache(self) -> 'Expression':
return self.copy(
underlying=polymatrix.expression.init.init_cache_expr(
@@ -143,18 +134,23 @@ class Expression(
degrees=degrees,
),
)
+
+ def degree(self) -> 'Expression':
+ return self.copy(
+ underlying=degree(
+ underlying=self.underlying,
+ ),
+ )
def determinant(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_determinant_expr(
underlying=self.underlying,
),
)
def diag(self):
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_diag_expr(
underlying=self.underlying,
),
@@ -165,8 +161,7 @@ class Expression(
variables: 'Expression',
introduce_derivatives: bool = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=diff(
expression=self,
variables=variables,
@@ -178,8 +173,7 @@ class Expression(
self,
variables: tuple,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_divergence_expr(
underlying=self.underlying,
variables=variables,
@@ -191,8 +185,7 @@ class Expression(
variable: tuple,
value: tuple[float, ...] = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_eval_expr(
underlying=self.underlying,
variables=variable,
@@ -206,9 +199,8 @@ class Expression(
predicator: 'Expression',
inverse: bool = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
- underlying=polymatrix.expression.init.init_filter_expr(
+ return self.copy(
+ underlying=filter_(
underlying=self.underlying,
predicator=predicator,
inverse=inverse,
@@ -217,8 +209,7 @@ class Expression(
# only applies to symmetric matrix
def from_symmetric_matrix(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_from_symmetric_matrix_expr(
underlying=self.underlying,
),
@@ -230,8 +221,7 @@ class Expression(
variables: 'Expression',
filter: 'Expression | None' = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_half_newton_polytope_expr(
monomials=self.underlying,
variables=variables,
@@ -240,8 +230,7 @@ class Expression(
)
def linear_matrix_in(self, variable: 'Expression') -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_linear_matrix_in_expr(
underlying=self.underlying,
variable=variable,
@@ -253,8 +242,7 @@ class Expression(
variables: 'Expression',
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=linear_monomials(
expression=self.underlying,
variables=variables,
@@ -268,8 +256,7 @@ class Expression(
ignore_unmatched: bool = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=linear_in(
expression=self.underlying,
monomials=monomials,
@@ -282,8 +269,7 @@ class Expression(
self,
degrees: tuple[int, ...] = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=legendre(
expression=self.underlying,
degrees=degrees,
@@ -291,47 +277,27 @@ class Expression(
)
def max(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_max_expr(
underlying=self.underlying,
),
)
- def max_degree(self) -> 'Expression':
- return dataclasses.replace(
- self,
- underlying=polymatrix.expression.init.init_max_degree_expr(
- underlying=self.underlying,
- ),
- )
-
def parametrize(self, name: str = None) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_parametrize_expr(
underlying=self.underlying,
name=name,
),
)
- # def parametrize_matrix(self, name: str = None) -> 'ExpressionMixin':
- # return dataclasses.replace(
- # self,
- # underlying=init_parametrize_matrix_expr(
- # underlying=self.underlying,
- # name=name,
- # ),
- # )
-
def quadratic_in(self, variables: 'Expression', monomials: 'Expression' = None) -> 'Expression':
if monomials is None:
monomials = self.quadratic_monomials(variables)
stack = get_stack_lines()
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_quadratic_in_expr(
underlying=self.underlying,
monomials=monomials,
@@ -344,8 +310,7 @@ class Expression(
self,
variables: 'Expression',
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_quadratic_monomials_expr(
underlying=self.underlying,
variables=variables,
@@ -353,8 +318,7 @@ class Expression(
)
def reshape(self, n: int, m: int) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_reshape_expr(
underlying=self.underlying,
new_shape=(n, m),
@@ -362,8 +326,7 @@ class Expression(
)
def rep_mat(self, n: int, m: int) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_rep_mat_expr(
underlying=self.underlying,
repetition=(n, m),
@@ -381,8 +344,7 @@ class Expression(
else:
value = polymatrix.expression.init.init_from_expr(value)
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_set_element_at_expr(
underlying=self.underlying,
index=(row, col),
@@ -394,8 +356,7 @@ class Expression(
def squeeze(
self,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_squeeze_expr(
underlying=self.underlying,
),
@@ -406,8 +367,7 @@ class Expression(
self,
monomials: 'Expression',
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_subtract_monomials_expr(
underlying=self.underlying,
monomials=monomials,
@@ -419,8 +379,7 @@ class Expression(
variable: tuple,
values: tuple['Expression', ...] = None,
) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_substitute_expr(
underlying=self.underlying,
variables=variable,
@@ -439,24 +398,21 @@ class Expression(
)
def sum(self):
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_sum_expr(
underlying=self.underlying,
),
)
def symmetric(self):
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_symmetric_expr(
underlying=self.underlying,
),
)
def transpose(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_transpose_expr(
underlying=self.underlying,
),
@@ -467,24 +423,14 @@ class Expression(
return self.transpose()
def to_constant(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_to_constant_expr(
underlying=self.underlying,
),
)
- # def to_quadratic(self) -> 'ExpressionMixin':
- # return dataclasses.replace(
- # self,
- # underlying=init_to_quadratic_expr(
- # underlying=self.underlying,
- # ),
- # )
-
def to_symmetric_matrix(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_to_symmetric_matrix_expr(
underlying=self.underlying,
),
@@ -492,8 +438,7 @@ class Expression(
# only applies to variables
def to_sorted_variables(self) -> 'Expression':
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_to_sorted_variables(
underlying=self.underlying,
),
@@ -502,12 +447,11 @@ class Expression(
# also applies to monomials?
def truncate(
self,
- variables: tuple,
degrees: tuple[int],
+ variables: tuple | None = None,
inverse: bool = None,
):
- return dataclasses.replace(
- self,
+ return self.copy(
underlying=polymatrix.expression.init.init_truncate_expr(
underlying=self.underlying,
variables=variables,
@@ -531,6 +475,7 @@ class ExpressionImpl(Expression):
)
+
def init_expression(
underlying: ExpressionBaseMixin,
):
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index c290a67..daa916a 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -5,12 +5,13 @@ import polymatrix.expression.init
from polymatrix.expression.expression import init_expression, Expression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.statemonad.abc import StateMonad
-DATA_TYPE = str | np.ndarray | sympy.Matrix | sympy.Expr | tuple | ExpressionBaseMixin
+FromDataTypes = str | np.ndarray | sympy.Matrix | sympy.Expr | tuple | ExpressionBaseMixin | StateMonad
def from_expr_or_none(
- data: DATA_TYPE,
+ data: FromDataTypes,
) -> Expression | None:
return init_expression(
@@ -20,7 +21,7 @@ def from_expr_or_none(
)
def from_(
- data: DATA_TYPE,
+ data: FromDataTypes,
) -> Expression:
return init_expression(
@@ -28,29 +29,3 @@ def from_(
data=data,
),
)
-
-# def from_expr(
-# data: DATA_TYPE,
-# ) -> Expression:
-# return from_(data=data)
-
-# def from_sympy(
-# data: tuple[tuple[float]],
-# ):
-# return init_expression(
-# polymatrix.expression.init.init_from_sympy_expr(data)
-# )
-
-# def from_state_monad(
-# data: StateMonad,
-# ):
-# return init_expression(
-# data.flat_map(lambda inner_data: polymatrix.expression.init.init_from_sympy_expr(inner_data)),
-# )
-
-# def from_polymatrix(
-# polymatrix: PolyMatrix,
-# ):
-# return init_expression(
-# polymatrix.expression.init.init_from_terms_expr(polymatrix)
-# )
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 3e61941..937676d 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -1,8 +1,9 @@
+import typing
+import dataclassabc
+
from polymatrix.expression.mixins.legendreseriesmixin import LegendreSeriesMixin
from polymatrix.expression.mixins.productexprmixin import ProductExprMixin
from polymatrix.utils.getstacklines import FrameSummary
-import dataclassabc
-
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin
from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin
@@ -35,7 +36,7 @@ from polymatrix.expression.mixins.linearmonomialsexprmixin import \
LinearMonomialsExprMixin
from polymatrix.expression.mixins.matrixmultexprmixin import \
MatrixMultExprMixin
-from polymatrix.expression.mixins.maxdegreeexprmixin import MaxDegreeExprMixin
+from polymatrix.expression.mixins.degreeexprmixin import DegreeExprMixin
from polymatrix.expression.mixins.maxexprmixin import MaxExprMixin
from polymatrix.expression.mixins.parametrizeexprmixin import \
ParametrizeExprMixin
@@ -150,7 +151,7 @@ class EyeExprImpl(EyeExprMixin):
class FilterExprImpl(FilterExprMixin):
underlying: ExpressionBaseMixin
predicator: ExpressionBaseMixin
- inverse: bool == None
+ inverse: bool
@dataclassabc.dataclassabc(frozen=True)
@@ -211,6 +212,9 @@ class LegendreSeriesImpl(LegendreSeriesMixin):
degrees: tuple[int, ...] | None
stack: tuple[FrameSummary]
+ def __repr__(self):
+ return f'{self.__class__.__name__}(underlying={self.underlying}, degrees={self.degrees})'
+
@dataclassabc.dataclassabc(frozen=True)
class MatrixMultExprImpl(MatrixMultExprMixin):
@@ -223,8 +227,12 @@ class MatrixMultExprImpl(MatrixMultExprMixin):
@dataclassabc.dataclassabc(frozen=True)
-class MaxDegreeExprImpl(MaxDegreeExprMixin):
+class DegreeExprImpl(DegreeExprMixin):
underlying: ExpressionBaseMixin
+ stack: tuple[FrameSummary]
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(underlying={self.underlying})'
@dataclassabc.dataclassabc(frozen=True)
@@ -253,6 +261,9 @@ class ProductExprImpl(ProductExprMixin):
degrees: tuple[int, ...] | None
stack: tuple[FrameSummary]
+ def __repr__(self):
+ return f'{self.__class__.__name__}(underlying={self.underlying}, degrees={self.degrees})'
+
@dataclassabc.dataclassabc(frozen=True)
class QuadraticInExprImpl(QuadraticInExprMixin):
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index b7be83a..5e4eac1 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -4,11 +4,12 @@ import sympy
import polymatrix.expression.impl
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.statemonad.abc import StateMonad
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.getstacklines import get_stack_lines
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.utils.formatsubstitutions import format_substitutions
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl import FromTupleExprImpl, AdditionExprImpl
@@ -53,26 +54,6 @@ def init_combinations_expr(
)
-# def init_derivative_expr(
-# underlying: ExpressionBaseMixin,
-# variables: ExpressionBaseMixin,
-# stack: tuple[FrameSummary],
-# introduce_derivatives: bool = None,
-# ):
-
-# assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
-
-# if introduce_derivatives is None:
-# introduce_derivatives = False
-
-# return polymatrix.expression.impl.DerivativeExprImpl(
-# underlying=underlying,
-# variables=variables,
-# introduce_derivatives=introduce_derivatives,
-# stack=stack,
-# )
-
-
def init_determinant_expr(
underlying: ExpressionBaseMixin,
):
@@ -162,20 +143,6 @@ def init_eye_expr(
)
-def init_filter_expr(
- underlying: ExpressionBaseMixin,
- predicator: ExpressionBaseMixin,
- inverse: bool = None,
-):
- if inverse is None:
- inverse = False
-
- return polymatrix.expression.impl.FilterExprImpl(
- underlying=underlying,
- predicator=predicator,
- inverse=inverse,
-)
-
def init_from_symmetric_matrix_expr(
underlying: ExpressionBaseMixin,
@@ -195,6 +162,9 @@ def init_from_expr_or_none(
underlying=init_from_expr_or_none(1),
name=data,
)
+
+ elif isinstance(data, StateMonad):
+ return data.flat_map(lambda inner_data: init_from_expr_or_none(inner_data))
elif isinstance(data, np.ndarray):
assert len(data.shape) <= 2
@@ -310,22 +280,6 @@ def init_half_newton_polytope_expr(
)
-# def init_linear_in_expr(
-# underlying: ExpressionBaseMixin,
-# monomials: ExpressionBaseMixin,
-# variables: ExpressionBaseMixin,
-# ignore_unmatched: bool = None,
-# ):
-# assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
-
-# return polymatrix.expression.impl.LinearInExprImpl(
-# underlying=underlying,
-# monomials=monomials,
-# variables=variables,
-# ignore_unmatched = ignore_unmatched,
-# )
-
-
def init_linear_matrix_in_expr(
underlying: ExpressionBaseMixin,
variable: int,
@@ -336,19 +290,6 @@ def init_linear_matrix_in_expr(
)
-# def init_linear_monomials_expr(
-# underlying: ExpressionBaseMixin,
-# variables: ExpressionBaseMixin,
-# ):
-
-# assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
-
-# return polymatrix.expression.impl.LinearMonomialsExprImpl(
-# underlying=underlying,
-# variables=variables,
-# )
-
-
def init_matrix_mult_expr(
left: ExpressionBaseMixin,
right: ExpressionBaseMixin,
@@ -361,13 +302,6 @@ def init_matrix_mult_expr(
)
-def init_max_degree_expr(
- underlying: ExpressionBaseMixin,
-):
- return polymatrix.expression.impl.MaxDegreeExprImpl(
- underlying=underlying,
-)
-
def init_max_expr(
underlying: ExpressionBaseMixin,
@@ -569,9 +503,9 @@ def init_transpose_expr(
def init_truncate_expr(
underlying: ExpressionBaseMixin,
- variables: ExpressionBaseMixin,
degrees: tuple[int],
- inverse: bool = None,
+ variables: ExpressionBaseMixin | None = None,
+ inverse: bool | None = None,
):
if isinstance(degrees, int):
degrees = (degrees,)
@@ -585,20 +519,3 @@ def init_truncate_expr(
degrees=degrees,
inverse=inverse,
)
-
-
-# def init_v_stack_expr(
-# underlying: tuple,
-# ):
-
-# def gen_underlying():
-
-# for e in underlying:
-# if isinstance(e, ExpressionBaseMixin):
-# yield e
-# else:
-# yield init_from_(e)
-
-# return polymatrix.expression.impl.VStackExprImpl(
-# underlying=tuple(gen_underlying()),
-# )
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 2cbbe1e..7dbf1d2 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -1,12 +1,10 @@
import abc
import math
-import typing
-import dataclassabc
+from polymatrix.polymatrix.impl import BroadcastPolyMatrixImpl
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception
from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -28,7 +26,32 @@ class AdditionExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ @staticmethod
+ def broadcast(left: PolyMatrix, right: PolyMatrix, stack: tuple[FrameSummary]):
+ # broadcast left
+ if left.shape == (1, 1) and right.shape != (1, 1):
+ left = BroadcastPolyMatrixImpl(
+ polynomial=left.get_poly(0, 0),
+ shape=right.shape,
+ )
+
+ # broadcast right
+ elif left.shape != (1, 1) and right.shape == (1, 1):
+ right = BroadcastPolyMatrixImpl(
+ polynomial=right.get_poly(0, 0),
+ shape=left.shape,
+ )
+
+ else:
+ if not (left.shape == right.shape):
+ raise AssertionError(to_operator_exception(
+ message=f'{left.shape} != {right.shape}',
+ stack=stack,
+ ))
+
+ return left, right
+
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -36,38 +59,37 @@ class AdditionExprMixin(ExpressionBaseMixin):
state, left = self.left.apply(state=state)
state, right = self.right.apply(state=state)
- if left.shape == (1, 1):
- left, right = right, left
+ # if left.shape == (1, 1):
+ # left, right = right, left
- if left.shape != (1, 1) and right.shape == (1, 1):
+ # if left.shape != (1, 1) and right.shape == (1, 1):
- @dataclassabc.dataclassabc(frozen=True)
- class BroadCastedPolyMatrix(PolyMatrixMixin):
- underlying_monomials: tuple[tuple[int], float]
- shape: tuple[int, int]
+ # # @dataclassabc.dataclassabc(frozen=True)
+ # # class BroadCastedPolyMatrix(PolyMatrixMixin):
+ # # underlying_monomials: tuple[tuple[int], float]
+ # # shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
- return self.underlying_monomials
+ # # def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ # # return self.underlying_monomials
- polynomial = right.get_poly(0, 0)
- if polynomial is not None:
+ # right = BroadcastPolyMatrixImpl(
+ # polynomial=right.get_poly(0, 0),
+ # shape=left.shape,
+ # )
- broadcasted_right = BroadCastedPolyMatrix(
- underlying_monomials=polynomial,
- shape=left.shape,
- )
+ # # all_underlying = (left, broadcasted_right)
- all_underlying = (left, broadcasted_right)
+ # else:
+ # if not (left.shape == right.shape):
+ # raise AssertionError(to_operator_exception(
+ # message=f'{left.shape} != {right.shape}',
+ # stack=self.stack,
+ # ))
- else:
- if not (left.shape == right.shape):
- raise AssertionError(to_operator_exception(
- message=f'{left.shape} != {right.shape}',
- stack=self.stack,
- ))
+ # # all_underlying = (left, right)
- all_underlying = (left, right)
+ left, right = self.broadcast(left, right, self.stack)
terms = {}
@@ -76,7 +98,7 @@ class AdditionExprMixin(ExpressionBaseMixin):
terms_row_col = {}
- for underlying in all_underlying:
+ for underlying in (left, right):
polynomial = underlying.get_poly(row, col)
if polynomial is None:
@@ -90,6 +112,7 @@ class AdditionExprMixin(ExpressionBaseMixin):
if monomial not in terms_row_col:
terms_row_col[monomial] = value
+
else:
terms_row_col[monomial] += value
diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py
index 3ed5766..4754d53 100644
--- a/polymatrix/expression/mixins/blockdiagexprmixin.py
+++ b/polymatrix/expression/mixins/blockdiagexprmixin.py
@@ -15,7 +15,7 @@ class BlockDiagExprMixin(ExpressionBaseMixin):
def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/cacheexprmixin.py b/polymatrix/expression/mixins/cacheexprmixin.py
index 937b0f3..e5aef11 100644
--- a/polymatrix/expression/mixins/cacheexprmixin.py
+++ b/polymatrix/expression/mixins/cacheexprmixin.py
@@ -17,7 +17,7 @@ class CacheExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/combinationsexprmixin.py b/polymatrix/expression/mixins/combinationsexprmixin.py
index 395f068..da669f4 100644
--- a/polymatrix/expression/mixins/combinationsexprmixin.py
+++ b/polymatrix/expression/mixins/combinationsexprmixin.py
@@ -1,6 +1,7 @@
import abc
import itertools
+from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -26,7 +27,7 @@ class CombinationsExprMixin(ExpressionBaseMixin):
def degrees(self) -> tuple[int, ...]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -59,32 +60,28 @@ class CombinationsExprMixin(ExpressionBaseMixin):
for row, indexing in enumerate(indices):
- # print(indexing)
-
- if indexing is tuple():
+ # x.combinations((0, 1, 2)) produces [1, x, x**2]
+ if len(indexing) == 0:
terms[row, 0] = {tuple(): 1.0}
continue
- def acc_product(acc, v):
- left_monomials = acc
- row = v
-
- right_monomials = poly_matrix.get_poly(row, 0).keys()
- # print(right_monomials)
-
- if left_monomials is (None,):
- return right_monomials
+ def acc_product(left, row):
+ right = poly_matrix.get_poly(row, 0)
- monomials = tuple(multiply_monomials(left_monomials, right_monomials))
- return monomials
+ if len(left) == 0:
+ return right
+
+ result = {}
+ multiply_polynomial(left, right, result)
+ return result
- *_, monomials = itertools.accumulate(
+ *_, polynomial = itertools.accumulate(
indexing,
acc_product,
- initial=(None,),
+ initial={},
)
- terms[row, 0] = {m: 1.0 for m in monomials}
+ terms[row, 0] = polynomial
poly_matrix = init_poly_matrix(
terms=terms,
diff --git a/polymatrix/expression/mixins/maxdegreeexprmixin.py b/polymatrix/expression/mixins/degreeexprmixin.py
index 0094b9b..273add2 100644
--- a/polymatrix/expression/mixins/maxdegreeexprmixin.py
+++ b/polymatrix/expression/mixins/degreeexprmixin.py
@@ -5,15 +5,21 @@ from polymatrix.polymatrix.init import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
from polymatrix.expressionstate.abc import ExpressionState
+from polymatrix.utils.getstacklines import FrameSummary
-class MaxDegreeExprMixin(ExpressionBaseMixin):
+class DegreeExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ @property
+ @abc.abstractmethod
+ def stack(self) -> tuple[FrameSummary]:
+ ...
+
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -26,13 +32,16 @@ class MaxDegreeExprMixin(ExpressionBaseMixin):
for col in range(underlying.shape[1]):
underlying_terms = underlying.get_poly(row, col)
- if underlying_terms is None:
+
+ if underlying_terms is None or len(underlying_terms) == 0:
continue
def gen_degrees():
for monomial, _ in underlying_terms.items():
yield sum(count for _, count in monomial)
+ # degrees = tuple(gen_degrees())
+
terms[row, col] = {tuple(): max(gen_degrees())}
poly_matrix = init_poly_matrix(
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index bab3c91..1728a2d 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -41,7 +41,7 @@ class DerivativeExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
index ef18e91..7614669 100644
--- a/polymatrix/expression/mixins/determinantexprmixin.py
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -16,12 +16,12 @@ class DeterminantExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
+ # # overwrites the abstract method of `ExpressionBaseMixin`
# @property
# def shape(self) -> tuple[int, int]:
# return self.underlying.shape[0], 1
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py
index 1867f8a..046dd5c 100644
--- a/polymatrix/expression/mixins/diagexprmixin.py
+++ b/polymatrix/expression/mixins/diagexprmixin.py
@@ -19,7 +19,7 @@ class DiagExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py
index 4dc61cd..29d1a0a 100644
--- a/polymatrix/expression/mixins/divergenceexprmixin.py
+++ b/polymatrix/expression/mixins/divergenceexprmixin.py
@@ -22,7 +22,7 @@ class DivergenceExprMixin(ExpressionBaseMixin):
def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index 90919a5..0f2fada 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -27,7 +27,7 @@ class DivisionExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 96b793c..e6e64b1 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -23,7 +23,7 @@ class ElemMultExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # # overwrites abstract method of `ExpressionBaseMixin`
+ # # overwrites the abstract method of `ExpressionBaseMixin`
# @property
# def shape(self) -> tuple[int, int]:
# return self.left.shape
@@ -96,7 +96,7 @@ class ElemMultExprMixin(ExpressionBaseMixin):
return state, poly_matrix
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index 30dc178..0c23c1e 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -21,7 +21,7 @@ class EvalExprMixin(ExpressionBaseMixin):
def substitutions(self) -> tuple:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py
index 4a294ac..e7cd928 100644
--- a/polymatrix/expression/mixins/eyeexprmixin.py
+++ b/polymatrix/expression/mixins/eyeexprmixin.py
@@ -15,7 +15,7 @@ class EyeExprMixin(ExpressionBaseMixin):
def variable(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py
index 1cf90d8..86b5c6a 100644
--- a/polymatrix/expression/mixins/filterexprmixin.py
+++ b/polymatrix/expression/mixins/filterexprmixin.py
@@ -24,7 +24,7 @@ class FilterExprMixin(ExpressionBaseMixin):
def inverse(self) -> bool:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -57,6 +57,9 @@ class FilterExprMixin(ExpressionBaseMixin):
if key in predicator_poly:
predicator_value = round(predicator_poly[key])
+ if isinstance(predicator_value, (float, bool)):
+ predicator_value = int(predicator_value)
+
else:
predicator_value = 0
diff --git a/polymatrix/expression/mixins/filterlinearpartexprmixin.py b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
index 8eaf2ef..dfca992 100644
--- a/polymatrix/expression/mixins/filterlinearpartexprmixin.py
+++ b/polymatrix/expression/mixins/filterlinearpartexprmixin.py
@@ -20,7 +20,7 @@ class FilterLinearPartExprMixin(ExpressionBaseMixin):
def variable(self) -> int:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py
index a31655f..7d48c0f 100644
--- a/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py
+++ b/polymatrix/expression/mixins/fromsymmetricmatrixexprmixin.py
@@ -18,7 +18,7 @@ class FromSymmetricMatrixExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py
index 00f1b64..8aac10b 100644
--- a/polymatrix/expression/mixins/fromtermsexprmixin.py
+++ b/polymatrix/expression/mixins/fromtermsexprmixin.py
@@ -20,7 +20,7 @@ class FromTermsExprMixin(ExpressionBaseMixin):
def shape(self) -> tuple[int, int]:
pass
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/fromtupleexprmixin.py b/polymatrix/expression/mixins/fromtupleexprmixin.py
index 302bcc6..19b3ed3 100644
--- a/polymatrix/expression/mixins/fromtupleexprmixin.py
+++ b/polymatrix/expression/mixins/fromtupleexprmixin.py
@@ -25,7 +25,7 @@ class FromTupleExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
@@ -36,6 +36,9 @@ class FromTupleExprMixin(ExpressionBaseMixin):
for poly_row, col_data in enumerate(self.data):
for poly_col, poly_data in enumerate(col_data):
+ if isinstance(poly_data, (bool, np.bool_)):
+ poly_data = int(poly_data)
+
if isinstance(poly_data, (int, float, np.number)):
if math.isclose(poly_data, 0):
polynomial = {}
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
index a3b513f..76a5ee4 100644
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -27,7 +27,7 @@ class GetItemExprMixin(ExpressionBaseMixin):
def index(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py
index c6a470f..21d16fd 100644
--- a/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py
+++ b/polymatrix/expression/mixins/halfnewtonpolytopeexprmixin.py
@@ -28,7 +28,7 @@ class HalfNewtonPolytopeExprMixin(ExpressionBaseMixin):
def filter(self) -> ExpressionBaseMixin | None:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/legendreseriesmixin.py b/polymatrix/expression/mixins/legendreseriesmixin.py
index 98139ae..194e65e 100644
--- a/polymatrix/expression/mixins/legendreseriesmixin.py
+++ b/polymatrix/expression/mixins/legendreseriesmixin.py
@@ -23,7 +23,7 @@ class LegendreSeriesMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index 0bc99e9..ded61ec 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -51,7 +51,7 @@ class LinearInExprMixin(ExpressionBaseMixin):
def ignore_unmatched(self) -> bool:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
index 8e0f997..4d25a1e 100644
--- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py
+++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
@@ -21,7 +21,7 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
index 1d077e2..516a7a4 100644
--- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
@@ -38,7 +38,7 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin):
def variables(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index 7014157..3343a65 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -26,7 +26,7 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -69,4 +69,4 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
shape=(left.shape[0], right.shape[1]),
)
- return state, poly_matrix
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/maxexprmixin.py b/polymatrix/expression/mixins/maxexprmixin.py
index 5366a83..bfabc4d 100644
--- a/polymatrix/expression/mixins/maxexprmixin.py
+++ b/polymatrix/expression/mixins/maxexprmixin.py
@@ -13,7 +13,7 @@ class MaxExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py
index c16fcf4..712ff3f 100644
--- a/polymatrix/expression/mixins/parametrizeexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizeexprmixin.py
@@ -19,7 +19,7 @@ class ParametrizeExprMixin(ExpressionBaseMixin):
def name(self) -> str:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/parametrizematrixexprmixin.py b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
index a4cc43d..1aa393b 100644
--- a/polymatrix/expression/mixins/parametrizematrixexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizematrixexprmixin.py
@@ -19,7 +19,7 @@ class ParametrizeMatrixExprMixin(ExpressionBaseMixin):
def name(self) -> str:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/productexprmixin.py b/polymatrix/expression/mixins/productexprmixin.py
index c60ffc2..45be74b 100644
--- a/polymatrix/expression/mixins/productexprmixin.py
+++ b/polymatrix/expression/mixins/productexprmixin.py
@@ -1,6 +1,7 @@
import abc
import itertools
+from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.tooperatorexception import to_operator_exception
@@ -27,7 +28,7 @@ class ProductExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
@@ -101,29 +102,36 @@ class ProductExprMixin(ExpressionBaseMixin):
for row, indexing in enumerate(indices):
- # print(indexing)
+ # def acc_product(acc, v):
+ # left_monomials = acc
+ # polymatrix, row = v
- def acc_product(acc, v):
- left_monomials = acc
- polymatrix, row = v
+ # right_monomials = polymatrix.get_poly(row, 0).keys()
- right_monomials = polymatrix.get_poly(row, 0).keys()
+ # if left_monomials is (None,):
+ # return right_monomials
- # print(f'{left_monomials=}')
- # print(f'{right_monomials=}')
+ # return tuple(multiply_monomials(left_monomials, right_monomials))
- if left_monomials is (None,):
- return right_monomials
+ def acc_product(left, v):
+ poly_matrix, row = v
- return tuple(multiply_monomials(left_monomials, right_monomials))
+ right = poly_matrix.get_poly(row, 0)
- *_, monomials = itertools.accumulate(
+ if len(left) == 0:
+ return right
+
+ result = {}
+ multiply_polynomial(left, right, result)
+ return result
+
+ *_, polynomial = itertools.accumulate(
zip(underlying, indexing),
acc_product,
- initial=(None,),
+ initial={},
)
- terms[row, 0] = {m: 1.0 for m in monomials}
+ terms[row, 0] = polynomial
poly_matrix = init_poly_matrix(
terms=terms,
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index f567f5e..6fe1d4b 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -34,7 +34,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
def stack(self) -> tuple[FrameSummary]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
index d53613a..ff16275 100644
--- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
@@ -38,7 +38,7 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
def variables(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py
index 4726e88..48748be 100644
--- a/polymatrix/expression/mixins/repmatexprmixin.py
+++ b/polymatrix/expression/mixins/repmatexprmixin.py
@@ -16,7 +16,7 @@ class RepMatExprMixin(ExpressionBaseMixin):
def repetition(self) -> tuple[int, int]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py
index 71042a3..915e750 100644
--- a/polymatrix/expression/mixins/reshapeexprmixin.py
+++ b/polymatrix/expression/mixins/reshapeexprmixin.py
@@ -20,7 +20,7 @@ class ReshapeExprMixin(ExpressionBaseMixin):
def new_shape(self) -> tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py
index 51428cb..9250d3c 100644
--- a/polymatrix/expression/mixins/setelementatexprmixin.py
+++ b/polymatrix/expression/mixins/setelementatexprmixin.py
@@ -33,7 +33,7 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
def value(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/squeezeexprmixin.py b/polymatrix/expression/mixins/squeezeexprmixin.py
index 28fec73..eed3d71 100644
--- a/polymatrix/expression/mixins/squeezeexprmixin.py
+++ b/polymatrix/expression/mixins/squeezeexprmixin.py
@@ -14,7 +14,7 @@ class SqueezeExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index 64670e9..1b30dcb 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -23,7 +23,7 @@ class SubstituteExprMixin(ExpressionBaseMixin):
def substitutions(self) -> tuple[tuple[typing.Any, ExpressionBaseMixin], ...]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
index 67ee6f0..37add0a 100644
--- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py
@@ -22,7 +22,7 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin):
def monomials(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/sumexprmixin.py b/polymatrix/expression/mixins/sumexprmixin.py
index 92be990..812b260 100644
--- a/polymatrix/expression/mixins/sumexprmixin.py
+++ b/polymatrix/expression/mixins/sumexprmixin.py
@@ -21,7 +21,7 @@ class SumExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py
index bcdf316..769e8c6 100644
--- a/polymatrix/expression/mixins/symmetricexprmixin.py
+++ b/polymatrix/expression/mixins/symmetricexprmixin.py
@@ -22,7 +22,7 @@ class SymmetricExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/toconstantexprmixin.py b/polymatrix/expression/mixins/toconstantexprmixin.py
index 40693d3..b8b3d64 100644
--- a/polymatrix/expression/mixins/toconstantexprmixin.py
+++ b/polymatrix/expression/mixins/toconstantexprmixin.py
@@ -19,7 +19,7 @@ class ToConstantExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py
index 6ba2d3a..d10fcd9 100644
--- a/polymatrix/expression/mixins/toquadraticexprmixin.py
+++ b/polymatrix/expression/mixins/toquadraticexprmixin.py
@@ -16,7 +16,7 @@ class ToQuadraticExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/tosortedvariablesmixin.py b/polymatrix/expression/mixins/tosortedvariablesmixin.py
index 9768bcc..a61a7e9 100644
--- a/polymatrix/expression/mixins/tosortedvariablesmixin.py
+++ b/polymatrix/expression/mixins/tosortedvariablesmixin.py
@@ -15,7 +15,7 @@ class ToSortedVariablesExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py
index 164f159..8801147 100644
--- a/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py
+++ b/polymatrix/expression/mixins/tosymmetricmatrixexprmixin.py
@@ -20,7 +20,7 @@ class ToSymmetricMatrixExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionStateMixin,
diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py
index e1a468d..ce435e9 100644
--- a/polymatrix/expression/mixins/transposeexprmixin.py
+++ b/polymatrix/expression/mixins/transposeexprmixin.py
@@ -22,7 +22,7 @@ class TransposeExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ # overwrites the abstract method of `PolyMatrixExprBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py
index 05afd55..fdf3921 100644
--- a/polymatrix/expression/mixins/truncateexprmixin.py
+++ b/polymatrix/expression/mixins/truncateexprmixin.py
@@ -16,7 +16,7 @@ class TruncateExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> ExpressionBaseMixin:
+ def variables(self) -> ExpressionBaseMixin | None:
...
@property
@@ -29,13 +29,19 @@ class TruncateExprMixin(ExpressionBaseMixin):
def inverse(self) -> bool:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices_from_variable(state, self.variables)
+
+ if self.variables is None:
+ cond = lambda idx: True
+
+ else:
+ state, variable_indices = get_variable_indices_from_variable(state, self.variables)
+ cond = lambda idx: idx in variable_indices
terms = {}
@@ -50,7 +56,7 @@ class TruncateExprMixin(ExpressionBaseMixin):
for monomial, value in polynomial.items():
- degree = sum((count for var_idx, count in monomial if var_idx in variable_indices))
+ degree = sum((count for var_idx, count in monomial if cond(var_idx)))
if (degree in self.degrees) is not self.inverse:
terms_row_col[monomial] = value
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
index 104a608..98ab663 100644
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -21,7 +21,7 @@ class VStackExprMixin(ExpressionBaseMixin):
def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
+ # overwrites the abstract method of `ExpressionBaseMixin`
def apply(
self,
state: ExpressionState,
diff --git a/polymatrix/expression/op.py b/polymatrix/expression/op.py
index b62253c..f19d310 100644
--- a/polymatrix/expression/op.py
+++ b/polymatrix/expression/op.py
@@ -24,6 +24,22 @@ def diff(
)
+def filter_(
+ underlying: ExpressionBaseMixin,
+ predicator: ExpressionBaseMixin,
+ inverse: bool = None,
+) -> ExpressionBaseMixin:
+
+ if inverse is None:
+ inverse = False
+
+ return polymatrix.expression.impl.FilterExprImpl(
+ underlying=underlying,
+ predicator=predicator,
+ inverse=inverse,
+ )
+
+
def legendre(
expression: ExpressionBaseMixin,
degrees: tuple[int, ...] = None,
@@ -62,6 +78,14 @@ def linear_monomials(
) -> ExpressionBaseMixin:
return polymatrix.expression.impl.LinearMonomialsExprImpl(
- underlying=expression,
- variables=variables,
- )
+ underlying=expression,
+ variables=variables,
+ )
+
+def degree(
+ underlying: ExpressionBaseMixin,
+):
+ return polymatrix.expression.impl.DegreeExprImpl(
+ underlying=underlying,
+ stack=get_stack_lines(),
+ )
diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py
index 5bec552..b02bb11 100644
--- a/polymatrix/expression/to.py
+++ b/polymatrix/expression/to.py
@@ -10,18 +10,18 @@ from polymatrix.statemonad.init import init_state_monad
from polymatrix.statemonad.mixins import StateMonadMixin
-def shape(
- expr: Expression,
-) -> StateMonadMixin[ExpressionState, tuple[int, ...]]:
- def func(state: ExpressionState):
- state, polymatrix = expr.apply(state)
+# def shape(
+# expr: Expression,
+# ) -> StateMonadMixin[ExpressionState, tuple[int, ...]]:
+# def func(state: ExpressionState):
+# state, polymatrix = expr.apply(state)
- return state, polymatrix.shape
+# return state, polymatrix.shape
- return init_state_monad(func)
+# return init_state_monad(func)
-def to_constant_repr(
+def to_constant(
expr: Expression,
assert_constant: bool = True,
) -> StateMonadMixin[ExpressionState, np.ndarray]:
@@ -44,40 +44,40 @@ def to_constant_repr(
return init_state_monad(func)
-def to_degrees(
- expr: Expression,
- variables: Expression,
-) -> StateMonadMixin[ExpressionState, np.ndarray]:
+# def to_degrees(
+# expr: Expression,
+# variables: Expression,
+# ) -> StateMonadMixin[ExpressionState, np.ndarray]:
- def func(state: ExpressionState):
- state, underlying = expr.apply(state)
- state, variable_indices = get_variable_indices_from_variable(state, variables)
+# def func(state: ExpressionState):
+# state, underlying = expr.apply(state)
+# state, variable_indices = get_variable_indices_from_variable(state, variables)
- def gen_rows():
- for row in range(underlying.shape[0]):
- def gen_cols():
- for col in range(underlying.shape[1]):
+# def gen_rows():
+# for row in range(underlying.shape[0]):
+# def gen_cols():
+# for col in range(underlying.shape[1]):
- def gen_degrees():
- polynomial = underlying.get_poly(row, col)
+# def gen_degrees():
+# polynomial = underlying.get_poly(row, col)
- if polynomial is None:
- yield 0
+# if polynomial is None:
+# yield 0
- else:
- for monomial, _ in polynomial.items():
- yield sum(count for var, count in monomial if var in variable_indices)
+# else:
+# for monomial, _ in polynomial.items():
+# yield sum(count for var, count in monomial if var in variable_indices)
- yield tuple(set(gen_degrees()))
+# yield tuple(set(gen_degrees()))
- yield tuple(gen_cols())
+# yield tuple(gen_cols())
- return state, tuple(gen_rows())
+# return state, tuple(gen_rows())
- return init_state_monad(func)
+# return init_state_monad(func)
-def to_sympy_repr(
+def to_sympy(
expr: Expression,
) -> StateMonadMixin[ExpressionState, sympy.Expr]:
diff --git a/polymatrix/expression/utils/multiplypolynomial.py b/polymatrix/expression/utils/multiplypolynomial.py
index 0b6f5b6..e27a124 100644
--- a/polymatrix/expression/utils/multiplypolynomial.py
+++ b/polymatrix/expression/utils/multiplypolynomial.py
@@ -2,9 +2,13 @@ import itertools
import math
from polymatrix.expression.utils.mergemonomialindices import merge_monomial_indices
+from polymatrix.polymatrix.typing import PolynomialData
-
-def multiply_polynomial(left, right, terms):
+def multiply_polynomial(left: PolynomialData, right: PolynomialData, result: PolynomialData):
+ """
+ Multiplies two polynomials `left` and `right` and adds the result to the mutable polynomial `result`.
+ """
+
for (left_monomial, left_value), (right_monomial, right_value) \
in itertools.product(left.items(), right.items()):
@@ -15,10 +19,10 @@ def multiply_polynomial(left, right, terms):
monomial = merge_monomial_indices((left_monomial, right_monomial))
- if monomial not in terms:
- terms[monomial] = 0
+ if monomial not in result:
+ result[monomial] = 0
- terms[monomial] += value
+ result[monomial] += value
- if math.isclose(terms[monomial], 0, abs_tol=1e-12):
- del terms[monomial]
+ if math.isclose(result[monomial], 0, abs_tol=1e-12):
+ del result[monomial]
diff --git a/polymatrix/polymatrix/impl.py b/polymatrix/polymatrix/impl.py
index f44dc9c..fe5946e 100644
--- a/polymatrix/polymatrix/impl.py
+++ b/polymatrix/polymatrix/impl.py
@@ -1,9 +1,16 @@
import dataclassabc
-from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.mixins import BroadcastPolyMatrixMixin
@dataclassabc.dataclassabc(frozen=True)
class PolyMatrixImpl(PolyMatrix):
terms: dict
shape: tuple[int, ...]
+
+
+@dataclassabc.dataclassabc(frozen=True)
+class BroadcastPolyMatrixImpl(BroadcastPolyMatrixMixin):
+ polynomial: tuple[tuple[int], float]
+ shape: tuple[int, int]
diff --git a/polymatrix/polymatrix/mixins.py b/polymatrix/polymatrix/mixins.py
index a097615..1aa0f19 100644
--- a/polymatrix/polymatrix/mixins.py
+++ b/polymatrix/polymatrix/mixins.py
@@ -18,7 +18,7 @@ class PolyMatrixMixin(abc.ABC):
yield (row, col), polynomial
@abc.abstractclassmethod
- def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None:
...
@@ -31,7 +31,21 @@ class PolyMatrixAsDictMixin(
def terms(self) -> dict[tuple[int, int], dict[tuple[int, ...], float]]:
...
- # overwrites abstract method of `PolyMatrixMixin`
- def get_poly(self, row: int, col: int) -> typing.Optional[dict[tuple[int, ...], float]]:
+ # overwrites the abstract method of `PolyMatrixMixin`
+ def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float] | None:
if (row, col) in self.terms:
return self.terms[row, col]
+
+
+class BroadcastPolyMatrixMixin(
+ PolyMatrixMixin,
+ abc.ABC,
+):
+ @property
+ @abc.abstractmethod
+ def polynomial(self) -> dict[tuple[int, ...], float]:
+ ...
+
+ # overwrites the abstract method of `PolyMatrixMixin`
+ def get_poly(self, col: int, row: int) -> dict[tuple[int, ...], float] | None:
+ return self.polynomial
diff --git a/polymatrix/polymatrix/typing.py b/polymatrix/polymatrix/typing.py
new file mode 100644
index 0000000..f8031e5
--- /dev/null
+++ b/polymatrix/polymatrix/typing.py
@@ -0,0 +1,4 @@
+
+
+PolynomialData = dict[tuple[int, ...], float]
+PolynomialMatrixData = dict[tuple[int, int], dict[tuple[int, ...], float]] \ No newline at end of file
diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad/mixins.py
index 1a13440..dad7662 100644
--- a/polymatrix/statemonad/mixins.py
+++ b/polymatrix/statemonad/mixins.py
@@ -67,3 +67,6 @@ class StateMonadMixin(
def apply(self, state: State) -> Tuple[State, U]:
return self.apply_func(state)
+ def read(self, state: State) -> U:
+ return self.apply_func(state)[1]
+