summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py1
-rw-r--r--polymatrix/expression/mixins/reshapeexprmixin.py42
2 files changed, 30 insertions, 13 deletions
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
index 60adac1..4ac9aff 100644
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -1,7 +1,6 @@
import abc
import dataclasses
-import typing
import dataclassabc
from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py
index 683d2fa..3f2e837 100644
--- a/polymatrix/expression/mixins/reshapeexprmixin.py
+++ b/polymatrix/expression/mixins/reshapeexprmixin.py
@@ -1,5 +1,6 @@
import abc
import functools
+import itertools
import operator
import dataclassabc
import numpy as np
@@ -16,7 +17,7 @@ class ReshapeExprMixin(ExpressionBaseMixin):
@property
@abc.abstractclassmethod
- def new_shape(self) -> tuple[int, int]:
+ def new_shape(self) -> tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin]:
...
# overwrites abstract method of `ExpressionBaseMixin`
@@ -39,34 +40,51 @@ class ReshapeExprMixin(ExpressionBaseMixin):
underlying_col = int(index / self.underlying_shape[0])
underlying_row = index - underlying_col * self.underlying_shape[0]
- # print(f'{row=}, {col=}')
- # print(f'{underlying_row=}, {underlying_col=}')
-
return self.underlying.get_poly(underlying_row, underlying_col)
+ # replace expression by their number of rows
+ def acc_new_shape(acc, index):
+ state, acc_indices = acc
+
+ # for idx in self.new_shape:
+ if isinstance(index, int):
+ pass
+
+ elif isinstance(index, ExpressionBaseMixin):
+ state, polymatrix = index.apply(state)
+ index = polymatrix.shape[0]
+
+ else:
+ raise Exception(f'{index=}')
+
+ return state, acc_indices + (index,)
+
+ *_, (state, new_shape) = itertools.accumulate(
+ self.new_shape,
+ acc_new_shape,
+ initial=(state, tuple()),
+ )
+
# replace '-1' by the remaining number of elements
- if -1 in self.new_shape:
+ if -1 in new_shape:
n_total = underlying.shape[0] * underlying.shape[1]
- remaining_shape = tuple(e for e in self.new_shape if e != -1)
+ remaining_shape = tuple(e for e in new_shape if e != -1)
- assert len(remaining_shape) + 1 == len(self.new_shape)
+ assert len(remaining_shape) + 1 == len(new_shape)
n_used = functools.reduce(operator.mul, remaining_shape)
n_remaining = int(n_total / n_used)
def gen_shape():
- for e in self.new_shape:
+ for e in new_shape:
if e == -1:
yield n_remaining
else:
yield e
- new_shape = tuple(gen_shape())
-
- else:
- new_shape = self.new_shape
+ new_shape = tuple(gen_shape())
return state, ReshapePolyMatrix(
underlying=underlying,