diff options
-rw-r--r-- | polymatrix/expression/mixins/getitemexprmixin.py | 1 | ||||
-rw-r--r-- | polymatrix/expression/mixins/reshapeexprmixin.py | 42 |
2 files changed, 30 insertions, 13 deletions
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py index 60adac1..4ac9aff 100644 --- a/polymatrix/expression/mixins/getitemexprmixin.py +++ b/polymatrix/expression/mixins/getitemexprmixin.py @@ -1,7 +1,6 @@ import abc import dataclasses -import typing import dataclassabc from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin diff --git a/polymatrix/expression/mixins/reshapeexprmixin.py b/polymatrix/expression/mixins/reshapeexprmixin.py index 683d2fa..3f2e837 100644 --- a/polymatrix/expression/mixins/reshapeexprmixin.py +++ b/polymatrix/expression/mixins/reshapeexprmixin.py @@ -1,5 +1,6 @@ import abc import functools +import itertools import operator import dataclassabc import numpy as np @@ -16,7 +17,7 @@ class ReshapeExprMixin(ExpressionBaseMixin): @property @abc.abstractclassmethod - def new_shape(self) -> tuple[int, int]: + def new_shape(self) -> tuple[int | ExpressionBaseMixin, int | ExpressionBaseMixin]: ... # overwrites abstract method of `ExpressionBaseMixin` @@ -39,34 +40,51 @@ class ReshapeExprMixin(ExpressionBaseMixin): underlying_col = int(index / self.underlying_shape[0]) underlying_row = index - underlying_col * self.underlying_shape[0] - # print(f'{row=}, {col=}') - # print(f'{underlying_row=}, {underlying_col=}') - return self.underlying.get_poly(underlying_row, underlying_col) + # replace expression by their number of rows + def acc_new_shape(acc, index): + state, acc_indices = acc + + # for idx in self.new_shape: + if isinstance(index, int): + pass + + elif isinstance(index, ExpressionBaseMixin): + state, polymatrix = index.apply(state) + index = polymatrix.shape[0] + + else: + raise Exception(f'{index=}') + + return state, acc_indices + (index,) + + *_, (state, new_shape) = itertools.accumulate( + self.new_shape, + acc_new_shape, + initial=(state, tuple()), + ) + # replace '-1' by the remaining number of elements - if -1 in self.new_shape: + if -1 in new_shape: n_total = underlying.shape[0] * underlying.shape[1] - remaining_shape = tuple(e for e in self.new_shape if e != -1) + remaining_shape = tuple(e for e in new_shape if e != -1) - assert len(remaining_shape) + 1 == len(self.new_shape) + assert len(remaining_shape) + 1 == len(new_shape) n_used = functools.reduce(operator.mul, remaining_shape) n_remaining = int(n_total / n_used) def gen_shape(): - for e in self.new_shape: + for e in new_shape: if e == -1: yield n_remaining else: yield e - new_shape = tuple(gen_shape()) - - else: - new_shape = self.new_shape + new_shape = tuple(gen_shape()) return state, ReshapePolyMatrix( underlying=underlying, |