summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNao Pross <np@0hm.ch>2024-05-01 22:29:33 +0200
committerNao Pross <np@0hm.ch>2024-05-01 22:39:16 +0200
commit7074ab2e4056ef70dcdbcf8ca97d483e0f106c3c (patch)
tree89dede6e589ef1b74d0288272fa4056242815cd8
parentMinor changes to PolyMatrixAsAffineExpression (diff)
downloadpolymatrix-7074ab2e4056ef70dcdbcf8ca97d483e0f106c3c.tar.gz
polymatrix-7074ab2e4056ef70dcdbcf8ca97d483e0f106c3c.zip
Replace VariableMixin with Variable & VariableExpression
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/expression.py66
-rw-r--r--polymatrix/expression/from_.py10
-rw-r--r--polymatrix/expression/impl.py7
-rw-r--r--polymatrix/expression/init.py4
-rw-r--r--polymatrix/expression/mixins/variablemixin.py45
-rw-r--r--polymatrix/expressionstate/mixins.py16
-rw-r--r--polymatrix/variable/__init__.py0
-rw-r--r--polymatrix/variable/abc.py13
-rw-r--r--polymatrix/variable/impl.py8
-rw-r--r--polymatrix/variable/init.py5
10 files changed, 92 insertions, 82 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 4ef955d..0e15f44 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -1,19 +1,25 @@
from __future__ import annotations
-import abc
import dataclasses
-import dataclassabc
import typing
+import itertools
import numpy as np
+from abc import ABC, abstractmethod
+from dataclassabc import dataclassabc
+from typing_extensions import override
+
if typing.TYPE_CHECKING:
from polymatrix.expressionstate.abc import ExpressionState
+from polymatrix.variable.abc import Variable
import polymatrix.expression.init
from polymatrix.utils.getstacklines import get_stack_lines
from polymatrix.polymatrix.abc import PolyMatrix
+from polymatrix.polymatrix.init import init_poly_matrix
+from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.op import (
diff,
@@ -25,12 +31,9 @@ from polymatrix.expression.op import (
degree,
)
-class Expression(
- ExpressionBaseMixin,
- abc.ABC,
-):
+class Expression(ExpressionBaseMixin, ABC):
@property
- @abc.abstractmethod
+ @abstractmethod
def underlying(self) -> ExpressionBaseMixin:
# NP: I know what it is but it was very confusing in the beginning.
# FIXME: provide documentation on how underlying works / what it means
@@ -67,7 +70,7 @@ class Expression(
)
def __matmul__(
- self, other: typing.Union[ExpressionBaseMixin, np.ndarray]
+ self, other: ExpressionBaseMixin | np.ndarray
) -> "Expression":
return self._binary(
polymatrix.expression.init.init_matrix_mult_expr, self, other
@@ -109,7 +112,7 @@ class Expression(
def __truediv__(self, other: ExpressionBaseMixin):
return self._binary(polymatrix.expression.init.init_division_expr, self, other)
- @abc.abstractmethod
+ @abstractmethod
def copy(self, underlying: ExpressionBaseMixin) -> "Expression": ...
@staticmethod
@@ -500,7 +503,7 @@ class Expression(
# This is here and not in impl.py because of circular imports
# FIXME: this is not ideal
-@dataclassabc.dataclassabc(frozen=True, repr=False)
+@dataclassabc(frozen=True, repr=False)
class ExpressionImpl(Expression):
underlying: ExpressionBaseMixin
@@ -523,3 +526,46 @@ def init_expression(
return ExpressionImpl(
underlying=underlying,
)
+
+
+class VariableExpression(Expression, Variable):
+ """
+ Expression that is a polynomial variable, i.e. an expression that cannot be
+ reduced further.
+ """
+ @property
+ @override
+ def underlying(self):
+ return None
+
+ @override
+ def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrix]:
+ # Since there is no underlying this class directly does the job of
+ # creating a polymatrix
+
+ state, indices = state.index(self)
+ p = PolyMatrixDict()
+
+ rows, cols = self.shape
+ for (row, col), index in zip(itertools.product(range(rows), range(cols)), indices):
+ p[row, col] = PolyDict({
+ # Create monomial with variable to the first power
+ # with coefficient of one
+ MonomialIndex((VariableIndex(index, power=1),)): 1.
+ })
+
+ return state, init_poly_matrix(p, self.shape)
+
+
+@dataclassabc(frozen=True)
+class VariableExpressionImpl(VariableExpression):
+ name: str
+ shape: tuple[int, int]
+
+ @override
+ def copy(self, underlying: ExpressionBaseMixin) -> "Expression":
+ return self
+
+
+def init_variable_expression(name: str, shape: tuple[int, int] = (1, 1)):
+ return VariableExpressionImpl(name, shape)
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index af11d49..378a4b7 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -5,10 +5,8 @@ from polymatrix.expression.typing import FromDataTypes
import polymatrix.expression.init
-from polymatrix.expression.expression import init_expression, Expression
+from polymatrix.expression.expression import init_expression, Expression, init_variable_expression, VariableExpression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expression.mixins.variablemixin import VariableMixin
-from polymatrix.expression.init import init_variable_expr
from polymatrix.statemonad.abc import StateMonad
@@ -37,10 +35,10 @@ def from_(
)
-def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableMixin] | VariableMixin:
+def from_names(names: str, shape: tuple[int, int] = (1,1)) -> tuple[VariableExpression] | VariableExpression:
""" Construct one or multiple variables from comma separated a list of names. """
- variables = tuple(init_expression(underlying=init_variable_expr(name.strip(), shape))
- for name in names.split(","))
+ variables = tuple(init_variable_expression(name.strip(), shape)
+ for name in names.split(","))
if len(variables) == 1:
return variables[0]
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index d4f5429..7e23329 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -68,7 +68,6 @@ from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import (
)
from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin
from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin
-from polymatrix.expression.mixins.variablemixin import VariableMixin
from polymatrix.expression.mixins.vstackexprmixin import VStackExprMixin
from polymatrix.expression.mixins.tosortedvariablesmixin import (
ToSortedVariablesExprMixin,
@@ -400,11 +399,5 @@ class TruncateExprImpl(TruncateExprMixin):
@dataclassabc.dataclassabc(frozen=True)
-class VariableImpl(VariableMixin):
- name: str
- shape: tuple[int, int]
-
-
-@dataclassabc.dataclassabc(frozen=True)
class VStackExprImpl(VStackExprMixin):
underlying: tuple
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 83ee26f..0e6faf9 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -529,7 +529,3 @@ def init_truncate_expr(
degrees=degrees,
inverse=inverse,
)
-
-
-def init_variable_expr(name: str, shape: tuple[int, int] = (1, 1)):
- return polymatrix.expression.impl.VariableImpl(name, shape)
diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py
deleted file mode 100644
index da72b65..0000000
--- a/polymatrix/expression/mixins/variablemixin.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from __future__ import annotations
-
-from abc import abstractmethod
-from itertools import product
-from typing_extensions import override
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from polymatrix.expressionstate.mixins import ExpressionStateMixin
-
-from polymatrix.polymatrix.mixins import PolyMatrixMixin
-from polymatrix.polymatrix.typing import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex
-from polymatrix.polymatrix.init import init_poly_matrix
-
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-
-
-class VariableMixin(ExpressionBaseMixin):
- """ Mixin to describe a polynomial variable, i.e. an expression that cannot
- be reduced further. """
-
- @property
- @abstractmethod
- def name(self) -> str:
- """ Name of the variable. """
-
- @property
- @abstractmethod
- def shape(self) -> tuple[int, int]:
- """ Shape of the variable. """
-
- @override
- def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
- state, indices = state.index(self)
- p = PolyMatrixDict()
-
- rows, cols = self.shape
- for (row, col), index in zip(product(range(rows), range(cols)), indices):
- p[row, col] = PolyDict({
- # Create monomial with variable to the first power
- # with coefficient of one
- MonomialIndex((VariableIndex(index, power=1),)): 1.
- })
-
- return state, init_poly_matrix(p, self.shape)
diff --git a/polymatrix/expressionstate/mixins.py b/polymatrix/expressionstate/mixins.py
index f937137..836032e 100644
--- a/polymatrix/expressionstate/mixins.py
+++ b/polymatrix/expressionstate/mixins.py
@@ -7,7 +7,7 @@ from math import prod
import dataclasses
from polymatrix.expression.expression import Expression
-from polymatrix.expression.mixins.variablemixin import VariableMixin
+from polymatrix.variable.abc import Variable
from polymatrix.statemonad.mixins import StateCacheMixin
@@ -30,18 +30,14 @@ class ExpressionStateMixin(
@property
@abstractmethod
- def indices(self) -> dict[VariableMixin, IndexRange]:
+ def indices(self) -> dict[Variable, IndexRange]:
""" Map from variable objects to their indices. """
- def index(self, var: VariableMixin) -> tuple[ExpressionStateMixin, IndexRange]:
+ def index(self, var: Variable) -> tuple[ExpressionStateMixin, IndexRange]:
""" Get the index of a variable. """
# Unwrap if wrapped in expression object
- if isinstance(var, Expression):
- var = var.underlying
-
- if not isinstance(var, VariableMixin):
- raise ValueError("State can only index object of type VariableMixin or expressions "
- "that contain a single VariableMixin object!")
+ if not isinstance(var, Variable):
+ raise ValueError("State can only index object of type Variable!")
# Check if already in there
if var in self.indices.keys():
@@ -62,7 +58,7 @@ class ExpressionStateMixin(
indices={**self.indices, var: index}
), index
- def register(self, var: VariableMixin) -> ExpressionStateMixin:
+ def register(self, var: Variable) -> ExpressionStateMixin:
"""
Create an index for a variable, but does not return the index. If you
want the index use
diff --git a/polymatrix/variable/__init__.py b/polymatrix/variable/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/polymatrix/variable/__init__.py
diff --git a/polymatrix/variable/abc.py b/polymatrix/variable/abc.py
new file mode 100644
index 0000000..70ad8de
--- /dev/null
+++ b/polymatrix/variable/abc.py
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+
+class Variable(ABC):
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """ Name of the variable. """
+
+ @property
+ @abstractmethod
+ def shape(self) -> tuple[int, int]:
+ """ Shape of the variable. """
+
diff --git a/polymatrix/variable/impl.py b/polymatrix/variable/impl.py
new file mode 100644
index 0000000..47e2041
--- /dev/null
+++ b/polymatrix/variable/impl.py
@@ -0,0 +1,8 @@
+from dataclassabc import dataclassabc
+
+from .abc import Variable
+
+@dataclassabc(frozen=True)
+class VariableImpl(Variable):
+ name: str
+ shape: tuple[int, int]
diff --git a/polymatrix/variable/init.py b/polymatrix/variable/init.py
new file mode 100644
index 0000000..de217bd
--- /dev/null
+++ b/polymatrix/variable/init.py
@@ -0,0 +1,5 @@
+from .abc import Variable
+from .impl import VariableImpl
+
+def init_variable(name: str, shape: tuple[int, int]) -> Variable:
+ return VariableImpl(str, shape)