summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-24 18:19:32 +0200
committerNao Pross <np@0hm.ch>2024-05-24 18:19:32 +0200
commitab63e0bce3e4fa64f4d713316e65128da1a4b06b (patch)
treeb34db3bdae7ff776e4c02eea983d3b894d9e99ad
parentCreate ConcatenateExprMixin (diff)
downloadpolymatrix-ab63e0bce3e4fa64f4d713316e65128da1a4b06b.tar.gz
polymatrix-ab63e0bce3e4fa64f4d713316e65128da1a4b06b.zip
Create ShapeExprMixin
-rw-r--r--polymatrix/expression/expression.py4
-rw-r--r--polymatrix/expression/impl.py9
-rw-r--r--polymatrix/expression/init.py4
-rw-r--r--polymatrix/expression/mixins/shapeexprmixin.py30
4 files changed, 47 insertions, 0 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 291560f..195a256 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -339,6 +339,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
+ # FIXME: replace with __setitem__?
def set_element_at(
self,
row: int,
@@ -358,6 +359,9 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
+ def shape(self) -> Expression:
+ return self.copy(underlying=polymatrix.expression.init.init_shape_expr(self.underlying))
+
# remove?
def squeeze(
self,
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 176fe39..f4c4df4 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -43,6 +43,7 @@ from polymatrix.expression.mixins.quadraticmonomialsexprmixin import QuadraticMo
from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin
from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin
from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin
+from polymatrix.expression.mixins.shapeexprmixin import ShapeExprMixin
from polymatrix.expression.mixins.sliceexprmixin import SliceExprMixin
from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin
from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin
@@ -364,6 +365,14 @@ class SetElementAtExprImpl(SetElementAtExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class ShapeExprImpl(ShapeExprMixin):
+ underlying: ExpressionBaseMixin
+
+ def __str__(self):
+ return f"shape({self.underlying})"
+
+
+@dataclassabc.dataclassabc(frozen=True)
class SliceExprImpl(SliceExprMixin):
underlying: ExpressionBaseMixin
slice: tuple[int | slice | range, int | slice | range]
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 7c3f5fa..60fa580 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -376,6 +376,10 @@ def init_set_element_at_expr(
)
+def init_shape_expr(underlying: ExpressionBaseMixin):
+ return polymatrix.expression.impl.ShapeExprImpl(underlying)
+
+
def init_slice_expr(
underlying: ExpressionBaseMixin,
slice: tuple[int | slice | range, int | slice | range]
diff --git a/polymatrix/expression/mixins/shapeexprmixin.py b/polymatrix/expression/mixins/shapeexprmixin.py
new file mode 100644
index 0000000..916289f
--- /dev/null
+++ b/polymatrix/expression/mixins/shapeexprmixin.py
@@ -0,0 +1,30 @@
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MatrixIndex, MonomialIndex
+
+class ShapeExprMixin(ExpressionBaseMixin):
+ """
+ Get the shape of a polymatrix.
+ This gives the shape as a row vector [[nrows], [ncols]].
+ """
+ @property
+ @abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ """ The expression for which we compute the shape. """
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ state, u = self.underlying.apply(state)
+ nrows, ncols = u.shape
+ p = PolyMatrixDict({
+ MatrixIndex(0, 0): PolyDict({MonomialIndex.constant(): nrows}),
+ MatrixIndex(1, 0): PolyDict({MonomialIndex.constant(): ncols})
+ })
+
+ return state, init_poly_matrix(p, shape=(2,1))
+