summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-04 23:07:04 +0200
committerNao Pross <np@0hm.ch>2024-05-04 23:07:04 +0200
commite8341a6367db6e98ed3ebd6a732544fc52c006e8 (patch)
tree2deea4795361d096661ca1b7e80a8f1946dfa34d
parentFix FromNumbersExpr.__str__ (diff)
downloadpolymatrix-e8341a6367db6e98ed3ebd6a732544fc52c006e8.tar.gz
polymatrix-e8341a6367db6e98ed3ebd6a732544fc52c006e8.zip
Adapt nested polymatrix classes to use new API
-rw-r--r--polymatrix/expression/mixins/blockdiagexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/diagexprmixin.py32
-rw-r--r--polymatrix/expression/mixins/eyeexprmixin.py17
-rw-r--r--polymatrix/expression/mixins/fromsympyexprmixin.py2
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/repmatexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/reshapeexprmixin.py5
-rw-r--r--polymatrix/expression/mixins/setelementatexprmixin.py8
-rw-r--r--polymatrix/expression/mixins/symmetricexprmixin.py1
-rw-r--r--polymatrix/expression/mixins/transposeexprmixin.py7
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py7
11 files changed, 58 insertions, 39 deletions
diff --git a/polymatrix/expression/mixins/blockdiagexprmixin.py b/polymatrix/expression/mixins/blockdiagexprmixin.py
index 77edf2b..420372f 100644
--- a/polymatrix/expression/mixins/blockdiagexprmixin.py
+++ b/polymatrix/expression/mixins/blockdiagexprmixin.py
@@ -10,6 +10,7 @@ if typing.TYPE_CHECKING:
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.typing import PolyDict
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -34,13 +35,16 @@ class BlockDiagExprMixin(ExpressionBaseMixin):
state, polymat = expr.apply(state=state)
all_underlying.append(polymat)
- # NP: this is a very weird place to put a class
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class BlockDiagPolyMatrix(PolyMatrixMixin):
all_underlying: tuple[PolyMatrixMixin]
underlying_row_col_range: tuple[tuple[int, int], ...]
shape: tuple[int, int]
+ def at(self, row: int, col: int) -> PolyDict:
+ return self.get_poly(row, col) or PolyDict.empty()
+
# FIXME: typing problems
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
for polymatrix, ((row_start, col_start), (row_end, col_end)) in zip(
diff --git a/polymatrix/expression/mixins/diagexprmixin.py b/polymatrix/expression/mixins/diagexprmixin.py
index 66d5a6c..9d019e4 100644
--- a/polymatrix/expression/mixins/diagexprmixin.py
+++ b/polymatrix/expression/mixins/diagexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.polymatrix.typing import PolyDict
class DiagExprMixin(ExpressionBaseMixin):
@@ -32,40 +33,41 @@ class DiagExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
state, underlying = self.underlying.apply(state)
+ # Vector to diagonal matrix
if underlying.shape[1] == 1:
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
- class DiagPolyMatrix(PolyMatrixMixin):
+ class DiagFromVecPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- if row == col:
- return self.underlying.get_poly(row, 0)
- else:
- # FIXME: should return none according to base class
- # NP: Though returning zero makes more sense
- return {tuple(): 0.0}
+ def at(self, row: int, col: int) -> PolyDict:
+ if row != col:
+ return PolyDict.empty()
- return state, DiagPolyMatrix(
+ return self.underlying.get_poly(row, 0)
+
+ return state, DiagFromVecPolyMatrix(
underlying=underlying,
shape=(underlying.shape[0], underlying.shape[0]),
)
+ # Diagonal matrix to vector
else:
- # NP: replace assertions with meaningful exception
+ # FIXME: replace assertions with meaningful error message
assert underlying.shape[0] == underlying.shape[1], f"{underlying.shape=}"
- # NP: why is this called Trace?
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
- class TracePolyMatrix(PolyMatrixMixin):
+ class VecFromDiagPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
- def get_poly(self, row: int, _) -> dict[tuple[int, ...], float]:
- return self.underlying.get_poly(row, row)
+ def at(self, row: int, _col: int) -> PolyDict:
+ return self.underlying.at(row, row)
- return state, TracePolyMatrix(
+ return state, VecFromDiagPolyMatrix(
underlying=underlying,
shape=(underlying.shape[0], 1),
)
diff --git a/polymatrix/expression/mixins/eyeexprmixin.py b/polymatrix/expression/mixins/eyeexprmixin.py
index b18c89b..8b61018 100644
--- a/polymatrix/expression/mixins/eyeexprmixin.py
+++ b/polymatrix/expression/mixins/eyeexprmixin.py
@@ -10,6 +10,7 @@ if typing.TYPE_CHECKING:
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.typing import PolyDict, MonomialIndex
class EyeExprMixin(ExpressionBaseMixin):
@@ -24,20 +25,20 @@ class EyeExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionState, PolyMatrix]:
state, variable = self.variable.apply(state)
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class EyePolyMatrix(PolyMatrixMixin):
shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- if max(row, col) <= self.shape[0]:
- if row == col:
- return {tuple(): 1.0}
+ def at(self, row: int, col: int) -> PolyDict:
+ size, _ = self.shape
+ if max(row, col) > size:
+ raise IndexError(f"Identity matrix has size {size}, {row, col} is out of bounds.")
- else:
- return None
+ if row != col:
+ return PolyDict.empty()
- else:
- raise Exception(f"{(row, col)=} is out of bounds")
+ return PolyDict({MonomialIndex.constant(): 1.})
# FIXME: this behaviour is counterintuitive, eye should take just a number for the dimension
n_row = variable.shape[0]
diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py
index 784ba00..3a896e1 100644
--- a/polymatrix/expression/mixins/fromsympyexprmixin.py
+++ b/polymatrix/expression/mixins/fromsympyexprmixin.py
@@ -75,7 +75,7 @@ class FromSympyExprMixin(ExpressionBaseMixin):
raise ValueError(f"Cannot convert sympy expression {entry} "
"into a polynomial, are you sure it is a polynomial?") from e
- # Convert sympy variables to our variables, i.e VariableMixin
+ # Convert sympy variables to our variables
sympy_to_var = {
sympy_idx: init_variable(var.name, shape=(1,1))
for sympy_idx, var in enumerate(sympy_poly.gens)
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
index f8fece4..81cb439 100644
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -11,6 +11,7 @@ if typing.TYPE_CHECKING:
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.typing import PolyDict
class GetItemExprMixin(ExpressionBaseMixin):
@@ -65,6 +66,7 @@ class GetItemExprMixin(ExpressionBaseMixin):
get_proper_index(self.index[1], underlying.shape[1]),
)
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class GetItemPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
@@ -74,6 +76,10 @@ class GetItemExprMixin(ExpressionBaseMixin):
def shape(self) -> tuple[int, int]:
return (len(self.index[0]), len(self.index[1]))
+ def at(self, row: int, col: int) -> PolyDict:
+ # FIXME: this is a quick fix
+ return self.get_poly(self, row, col) or PolyDict.empty()
+
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
try:
n_row = self.index[0][row]
diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py
index 324a298..29e0553 100644
--- a/polymatrix/expression/mixins/repmatexprmixin.py
+++ b/polymatrix/expression/mixins/repmatexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.polymatrix.typing import PolyDict
class RepMatExprMixin(ExpressionBaseMixin):
@@ -27,18 +28,19 @@ class RepMatExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
state, underlying = self.underlying.apply(state)
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class RepMatPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ def at(self, row: int, col: int) -> PolyDict:
n_row, n_col = underlying.shape
rel_row = row % n_row
rel_col = col % n_col
- return self.underlying.get_poly(rel_row, rel_col)
+ return self.underlying.at(rel_row, rel_col)
return state, RepMatPolyMatrix(
underlying=underlying,
diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py
index 84c2581..4c1140e 100644
--- a/polymatrix/expression/mixins/reshapeexprmixin.py
+++ b/polymatrix/expression/mixins/reshapeexprmixin.py
@@ -12,6 +12,7 @@ if typing.TYPE_CHECKING:
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.mixins import PolyMatrixMixin
+from polymatrix.polymatrix.typing import PolyDict
class ReshapeExprMixin(ExpressionBaseMixin):
@@ -38,13 +39,13 @@ class ReshapeExprMixin(ExpressionBaseMixin):
shape: tuple[int, int]
underlying_shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ def at(self, row: int, col: int) -> PolyDict:
index = row + self.shape[0] * col
underlying_col = int(index / self.underlying_shape[0])
underlying_row = index - underlying_col * self.underlying_shape[0]
- return self.underlying.get_poly(underlying_row, underlying_col)
+ return self.underlying.at(underlying_row, underlying_col)
# replace expression by their number of rows
def acc_new_shape(acc, index):
diff --git a/polymatrix/expression/mixins/setelementatexprmixin.py b/polymatrix/expression/mixins/setelementatexprmixin.py
index 9757a28..35568b4 100644
--- a/polymatrix/expression/mixins/setelementatexprmixin.py
+++ b/polymatrix/expression/mixins/setelementatexprmixin.py
@@ -12,6 +12,7 @@ from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.typing import PolyDict
class SetElementAtExprMixin(ExpressionBaseMixin):
@@ -47,18 +48,19 @@ class SetElementAtExprMixin(ExpressionBaseMixin):
if polynomial is None:
polynomial = 0
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class SetElementAtPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
index: tuple[int, int]
- polynomial: dict
+ polynomial: PolyDict
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
+ def at(self, row: int, col: int) -> PolyDict:
if (row, col) == self.index:
return self.polynomial
else:
- return self.underlying.get_poly(row, col)
+ return self.underlying.at(row, col)
return state, SetElementAtPolyMatrix(
underlying=underlying,
diff --git a/polymatrix/expression/mixins/symmetricexprmixin.py b/polymatrix/expression/mixins/symmetricexprmixin.py
index f585240..f608139 100644
--- a/polymatrix/expression/mixins/symmetricexprmixin.py
+++ b/polymatrix/expression/mixins/symmetricexprmixin.py
@@ -34,6 +34,7 @@ class SymmetricExprMixin(ExpressionBaseMixin):
assert underlying.shape[0] == underlying.shape[1]
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class SymmetricPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py
index fbd5f30..0e076d3 100644
--- a/polymatrix/expression/mixins/transposeexprmixin.py
+++ b/polymatrix/expression/mixins/transposeexprmixin.py
@@ -9,6 +9,7 @@ if typing.TYPE_CHECKING:
from polymatrix.polymatrix.mixins import PolyMatrixMixin
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.polymatrix.typing import PolyDict
from polymatrix.polymatrix.abc import PolyMatrix
@@ -30,13 +31,15 @@ class TransposeExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class TransposePolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- return self.underlying.get_poly(col, row)
+ def at(self, row: int, col: int) -> PolyDict:
+ return self.underlying.at(col, row)
+
return state, TransposePolyMatrix(
underlying=underlying,
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
index 2091f88..dc9339a 100644
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -40,6 +40,7 @@ class VStackExprMixin(ExpressionBaseMixin):
underlying.shape[1] == all_underlying[0].shape[1]
), f"{underlying.shape[1]} not equal {all_underlying[0].shape[1]}"
+ # FIXME: move to polymatrix module
@dataclassabc.dataclassabc(frozen=True)
class VStackPolyMatrix(PolyMatrixMixin):
all_underlying: tuple[PolyMatrixMixin]
@@ -47,15 +48,11 @@ class VStackExprMixin(ExpressionBaseMixin):
shape: tuple[int, int]
def at(self, row: int, col: int) -> PolyDict:
- # FIXME: this is a quick workaround
- return self.get_poly(row, col) or PolyDict.empty()
-
- def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
for polymatrix, (row_start, row_end) in zip(
self.all_underlying, self.underlying_row_range
):
if row_start <= row < row_end:
- return polymatrix.get_poly(
+ return polymatrix.at(
row=row - row_start,
col=col,
)