summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-26 19:24:24 +0200
committerNao Pross <np@0hm.ch>2024-05-27 00:02:18 +0200
commit4de26b33a476c43be5e12e7fd342672745fcd071 (patch)
tree2eb359a51935c4bb79c858641584f3cdd2edb117
parentCreate pretty printing functions for some polymatrix types (diff)
downloadpolymatrix-4de26b33a476c43be5e12e7fd342672745fcd071.tar.gz
polymatrix-4de26b33a476c43be5e12e7fd342672745fcd071.zip
Create poly.arange and ARangeExprMixin
-rw-r--r--polymatrix/__init__.py2
-rw-r--r--polymatrix/expression/__init__.py17
-rw-r--r--polymatrix/expression/impl.py20
-rw-r--r--polymatrix/expression/init.py11
-rw-r--r--polymatrix/expression/mixins/arangeexprmixin.py78
-rw-r--r--polymatrix/expression/mixins/nsexprmixin.py2
6 files changed, 126 insertions, 4 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 59637d4..475749f 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -13,6 +13,7 @@ from polymatrix.expression.from_ import (
from polymatrix.expression import (
Expression as internal_Expression,
+ arange as internal_arange,
v_stack as internal_v_stack,
h_stack as internal_h_stack,
product as internal_product,
@@ -35,6 +36,7 @@ ExpressionState = internal_ExpressionState
init_expression_state = internal_init_expression_state
make_state = init_expression_state
+arange = internal_arange
v_stack = internal_v_stack
h_stack = internal_h_stack
product = internal_product
diff --git a/polymatrix/expression/__init__.py b/polymatrix/expression/__init__.py
index 87ee99d..30be90b 100644
--- a/polymatrix/expression/__init__.py
+++ b/polymatrix/expression/__init__.py
@@ -9,6 +9,7 @@ from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.expression import init_expression, Expression
from polymatrix.expression.init import (
+ init_arange_expr,
init_block_diag_expr,
init_concatenate_expr,
init_lower_triangular_expr,
@@ -37,6 +38,22 @@ def ones(shape: Expression):
""" Make a matrix filled with ones """
return init_expression(init_ns_expr(n=1., shape=shape.underlying))
+@convert_args_to_expression
+def arange(start_or_stop: Expression, stop: Expression | None = None, step: Expression | None = None):
+ # Replicate range()'s behaviour
+ if stop is None and step is None:
+ return init_arange_expr(start=None, stop=start_or_stop, step=None)
+
+ elif stop is not None and step is None:
+ return init_arange_expr(start=start_or_stop, stop=stop, step=None)
+
+ elif stop is not None and step is not None:
+ return init_arange_expr(start=start_or_stop, stop=stop, step=step)
+
+ else:
+ # FIXME: error message missing
+ raise ValueError
+
@convert_args_to_expression
def zeros(shape: Expression):
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index ad5b9b8..65cc2fe 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -8,6 +8,7 @@ import dataclassabc
from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin
+from polymatrix.expression.mixins.arangeexprmixin import ARangeExprMixin
from polymatrix.expression.mixins.blockdiagexprmixin import BlockDiagExprMixin
from polymatrix.expression.mixins.cacheexprmixin import CacheExprMixin
from polymatrix.expression.mixins.combinationsexprmixin import CombinationsExprMixin
@@ -78,6 +79,25 @@ class AdditionExprImpl(AdditionExprMixin):
@dataclassabc.dataclassabc(frozen=True)
+class ARangeExprImpl(ARangeExprMixin):
+ start: int | ExpressionBaseMixin | None
+ stop: int | ExpressionBaseMixin
+ step: int | ExpressionBaseMixin | None
+
+ def __str__(self):
+ if self.start:
+ if self.step:
+ return f"arange({self.start}, {self.stop}, {self.end})"
+ else:
+ return f"arange({self.start}, {self.stop})"
+ else:
+ if self.step:
+ return f"arange(0, {self.stop}, {self.step})"
+ else:
+ return f"arange({self.stop})"
+
+
+@dataclassabc.dataclassabc(frozen=True)
class BlockDiagExprImpl(BlockDiagExprMixin):
underlying: tuple[ExpressionBaseMixin]
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index abe0c1b..2740d38 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -14,7 +14,6 @@ from polymatrix.utils.getstacklines import FrameSummary
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.expression.utils.formatsubstitutions import format_substitutions
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expression.impl import AdditionExprImpl
from polymatrix.expression.expression import VariableExpression
@@ -23,13 +22,21 @@ def init_addition_expr(
right: ExpressionBaseMixin,
stack: tuple[FrameSummary],
):
- return AdditionExprImpl(
+ return polymatrix.expression.impl.AdditionExprImpl(
left=left,
right=right,
stack=stack,
)
+def init_arange_expr(
+ start: int | ExpressionBaseMixin | None,
+ stop: int | ExpressionBaseMixin,
+ step: int | ExpressionBaseMixin | None
+):
+ return polymatrix.expression.impl.ARangeExprImpl(start, stop, step)
+
+
def init_block_diag_expr(
underlying: tuple,
):
diff --git a/polymatrix/expression/mixins/arangeexprmixin.py b/polymatrix/expression/mixins/arangeexprmixin.py
new file mode 100644
index 0000000..ecd0fcb
--- /dev/null
+++ b/polymatrix/expression/mixins/arangeexprmixin.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing_extensions import override
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expressionstate import ExpressionState
+from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.mixins import PolyMatrixMixin
+
+class ARangeExprMixin(ExpressionBaseMixin):
+ """
+ Create a column vector of evenly spaced integer values in an interval.
+ Essentially a wrapper around python's `range` built-in function.
+ """
+
+ @property
+ @abstractmethod
+ def start(self) -> int | ExpressionBaseMixin | None:
+ """ Start of the range """
+
+ @property
+ @abstractmethod
+ def stop(self) -> int | ExpressionBaseMixin:
+ """ End of the range, this value is not included in the interval. """
+
+ @property
+ @abstractmethod
+ def step(self) -> int | ExpressionBaseMixin | None:
+ """ Step of the range """
+
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
+ if not self.start:
+ start = 0
+
+ elif isinstance(self.start, int):
+ start = self.start
+ elif isinstance(self.start, ExpressionBaseMixin):
+ state, pm = self.start.apply(state)
+ # TODO: check that is actually an integer
+ start = int(pm.scalar().constant())
+ else:
+ raise TypeError
+
+ if isinstance(self.stop, int):
+ stop = self.stop
+ elif isinstance(self.stop, ExpressionBaseMixin):
+ state, pm = self.stop.apply(state)
+ # TODO: check that is actually an integer
+ stop = int(pm.scalar().constant())
+ else:
+ raise TypeError
+
+ if not self.step:
+ step = 1
+ elif isinstance(self.step, int):
+ step = self.step
+ elif isinstance(self.step, ExpressionBaseMixin):
+ state, pm = self.step.apply(state)
+ # TODO: check that is actually an integer
+ step = int(pm.scalar().constant())
+ else:
+ raise TypeError
+
+ p = PolyMatrixDict.empty()
+ values = tuple(range(start, stop, step))
+
+ for r, v in enumerate(values):
+ p[r, 0] = PolyDict({ MonomialIndex.constant(): v })
+
+ return state, init_poly_matrix(p, shape=(len(values), 1))
+
+
+
+
diff --git a/polymatrix/expression/mixins/nsexprmixin.py b/polymatrix/expression/mixins/nsexprmixin.py
index f97d724..6129093 100644
--- a/polymatrix/expression/mixins/nsexprmixin.py
+++ b/polymatrix/expression/mixins/nsexprmixin.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from abc import abstractmethod
-from math import sqrt, isclose
from typing_extensions import override
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -31,7 +30,6 @@ class NsExprMixin(ExpressionBaseMixin):
"""
@override
- @override
def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]:
if isinstance(self.n, ExpressionBaseMixin):
state, pm = self.n.apply(state)