diff options
-rw-r--r-- | polymatrix/expression/expression.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 9 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 4 | ||||
-rw-r--r-- | polymatrix/expression/mixins/shapeexprmixin.py | 30 |
4 files changed, 47 insertions, 0 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 291560f..195a256 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -339,6 +339,7 @@ class Expression(ExpressionBaseMixin, ABC): ), ) + # FIXME: replace with __setitem__? def set_element_at( self, row: int, @@ -358,6 +359,9 @@ class Expression(ExpressionBaseMixin, ABC): ), ) + def shape(self) -> Expression: + return self.copy(underlying=polymatrix.expression.init.init_shape_expr(self.underlying)) + # remove? def squeeze( self, diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 176fe39..f4c4df4 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -43,6 +43,7 @@ from polymatrix.expression.mixins.quadraticmonomialsexprmixin import QuadraticMo from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin +from polymatrix.expression.mixins.shapeexprmixin import ShapeExprMixin from polymatrix.expression.mixins.sliceexprmixin import SliceExprMixin from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin @@ -364,6 +365,14 @@ class SetElementAtExprImpl(SetElementAtExprMixin): @dataclassabc.dataclassabc(frozen=True) +class ShapeExprImpl(ShapeExprMixin): + underlying: ExpressionBaseMixin + + def __str__(self): + return f"shape({self.underlying})" + + +@dataclassabc.dataclassabc(frozen=True) class SliceExprImpl(SliceExprMixin): underlying: ExpressionBaseMixin slice: tuple[int | slice | range, int | slice | range] diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 7c3f5fa..60fa580 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -376,6 +376,10 @@ def init_set_element_at_expr( ) +def init_shape_expr(underlying: ExpressionBaseMixin): + return polymatrix.expression.impl.ShapeExprImpl(underlying) + + def init_slice_expr( underlying: ExpressionBaseMixin, slice: tuple[int | slice | range, int | slice | range] diff --git a/polymatrix/expression/mixins/shapeexprmixin.py b/polymatrix/expression/mixins/shapeexprmixin.py new file mode 100644 index 0000000..916289f --- /dev/null +++ b/polymatrix/expression/mixins/shapeexprmixin.py @@ -0,0 +1,30 @@ +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.abc import PolyMatrix +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MatrixIndex, MonomialIndex + +class ShapeExprMixin(ExpressionBaseMixin): + """ + Get the shape of a polymatrix. + This gives the shape as a row vector [[nrows], [ncols]]. + """ + @property + @abstractmethod + def underlying(self) -> ExpressionBaseMixin: + """ The expression for which we compute the shape. """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]: + state, u = self.underlying.apply(state) + nrows, ncols = u.shape + p = PolyMatrixDict({ + MatrixIndex(0, 0): PolyDict({MonomialIndex.constant(): nrows}), + MatrixIndex(1, 0): PolyDict({MonomialIndex.constant(): ncols}) + }) + + return state, init_poly_matrix(p, shape=(2,1)) + |