diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/__init__.py | 2 | ||||
-rw-r--r-- | polymatrix/expression/__init__.py | 17 | ||||
-rw-r--r-- | polymatrix/expression/impl.py | 20 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 11 | ||||
-rw-r--r-- | polymatrix/expression/mixins/arangeexprmixin.py | 78 | ||||
-rw-r--r-- | polymatrix/expression/mixins/nsexprmixin.py | 2 |
6 files changed, 126 insertions, 4 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py index 59637d4..475749f 100644 --- a/polymatrix/__init__.py +++ b/polymatrix/__init__.py @@ -13,6 +13,7 @@ from polymatrix.expression.from_ import ( from polymatrix.expression import ( Expression as internal_Expression, + arange as internal_arange, v_stack as internal_v_stack, h_stack as internal_h_stack, product as internal_product, @@ -35,6 +36,7 @@ ExpressionState = internal_ExpressionState init_expression_state = internal_init_expression_state make_state = init_expression_state +arange = internal_arange v_stack = internal_v_stack h_stack = internal_h_stack product = internal_product diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py index 87ee99d..30be90b 100644 --- a/polymatrix/expression/__init__.py +++ b/polymatrix/expression/__init__.py @@ -9,6 +9,7 @@ from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.expression import init_expression, Expression from polymatrix.expression.init import ( + init_arange_expr, init_block_diag_expr, init_concatenate_expr, init_lower_triangular_expr, @@ -37,6 +38,22 @@ def ones(shape: Expression): """ Make a matrix filled with ones """ return init_expression(init_ns_expr(n=1., shape=shape.underlying)) +@convert_args_to_expression +def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expression | None = None): + # Replicate range()'s behaviour + if stop is None and step is None: + return init_arange_expr(start=None, stop=start_or_stop, step=None) + + elif stop is not None and step is None: + return init_arange_expr(start=start_or_stop, stop=stop, step=None) + + elif stop is not None and step is not None: + return init_arange_expr(start=start_or_stop, stop=stop, step=step) + + else: + # FIXME: error message missing + raise ValueError + @convert_args_to_expression def zeros(shape: Expression): diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index ad5b9b8..65cc2fe 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -8,6 +8,7 @@ import dataclassabc from polymatrix.utils.getstacklines import FrameSummary from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin +from polymatrix.expression.mixins.arangeexprmixin import ARangeExprMixin from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin @@ -78,6 +79,25 @@ class AdditionExprImpl(AdditionExprMixin): @dataclassabc.dataclassabc(frozen=True) +class ARangeExprImpl(ARangeExprMixin): + start: int | ExpressionBaseMixin | None + stop: int | ExpressionBaseMixin + step: int | ExpressionBaseMixin | None + + def __str__(self): + if self.start: + if self.step: + return f"arange({self.start}, {self.stop}, {self.end})" + else: + return f"arange({self.start}, {self.stop})" + else: + if self.step: + return f"arange(0, {self.stop}, {self.step})" + else: + return f"arange({self.stop})" + + +@dataclassabc.dataclassabc(frozen=True) class BlockDiagExprImpl(BlockDiagExprMixin): underlying: tuple[ExpressionBaseMixin] diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index abe0c1b..2740d38 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -14,7 +14,6 @@ from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.getstacklines import get_stack_lines from polymatrix.expression.utils.formatsubstitutions import format_substitutions from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.impl import AdditionExprImpl from polymatrix.expression.expression import VariableExpression @@ -23,13 +22,21 @@ def init_addition_expr( right: ExpressionBaseMixin, stack: tuple[FrameSummary], ): - return AdditionExprImpl( + return polymatrix.expression.impl.AdditionExprImpl( left=left, right=right, stack=stack, ) +def init_arange_expr( + start: int | ExpressionBaseMixin | None, + stop: int | ExpressionBaseMixin, + step: int | ExpressionBaseMixin | None +): + return polymatrix.expression.impl.ARangeExprImpl(start, stop, step) + + def init_block_diag_expr( underlying: tuple, ): diff --git a/polymatrix/expression/mixins/arangeexprmixin.py b/polymatrix/expression/mixins/arangeexprmixin.py new file mode 100644 index 0000000..ecd0fcb --- /dev/null +++ b/polymatrix/expression/mixins/arangeexprmixin.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expressionstate import ExpressionState +from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin + +class ARangeExprMixin(ExpressionBaseMixin): + """ + Create a column vector of evenly spaced integer values in an interval. + Essentially a wrapper around python's `range` built-in function. + """ + + @property + @abstractmethod + def start(self) -> int | ExpressionBaseMixin | None: + """ Start of the range """ + + @property + @abstractmethod + def stop(self) -> int | ExpressionBaseMixin: + """ End of the range, this value is not included in the interval. """ + + @property + @abstractmethod + def step(self) -> int | ExpressionBaseMixin | None: + """ Step of the range """ + + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + if not self.start: + start = 0 + + elif isinstance(self.start, int): + start = self.start + elif isinstance(self.start, ExpressionBaseMixin): + state, pm = self.start.apply(state) + # TODO: check that is actually an integer + start = int(pm.scalar().constant()) + else: + raise TypeError + + if isinstance(self.stop, int): + stop = self.stop + elif isinstance(self.stop, ExpressionBaseMixin): + state, pm = self.stop.apply(state) + # TODO: check that is actually an integer + stop = int(pm.scalar().constant()) + else: + raise TypeError + + if not self.step: + step = 1 + elif isinstance(self.step, int): + step = self.step + elif isinstance(self.step, ExpressionBaseMixin): + state, pm = self.step.apply(state) + # TODO: check that is actually an integer + step = int(pm.scalar().constant()) + else: + raise TypeError + + p = PolyMatrixDict.empty() + values = tuple(range(start, stop, step)) + + for r, v in enumerate(values): + p[r, 0] = PolyDict({ MonomialIndex.constant(): v }) + + return state, init_poly_matrix(p, shape=(len(values), 1)) + + + + diff --git a/polymatrix/expression/mixins/nsexprmixin.py b/polymatrix/expression/mixins/nsexprmixin.py index f97d724..6129093 100644 --- a/polymatrix/expression/mixins/nsexprmixin.py +++ b/polymatrix/expression/mixins/nsexprmixin.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import abstractmethod -from math import sqrt, isclose from typing_extensions import override from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin @@ -31,7 +30,6 @@ class NsExprMixin(ExpressionBaseMixin): """ @override - @override def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: if isinstance(self.n, ExpressionBaseMixin): state, pm = self.n.apply(state) |