summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/expression/expression.py96
-rw-r--r--polymatrix/expression/from_.py2
-rw-r--r--polymatrix/statemonad/mixins.py45
3 files changed, 58 insertions, 85 deletions
diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py
index 52c3da6..abe37aa 100644
--- a/polymatrix/expression/expression.py
+++ b/polymatrix/expression/expression.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import dataclasses
-import typing
import numpy as np
+from numpy.typing import NDArray
from abc import ABC, abstractmethod
from dataclassabc import dataclassabc
from typing_extensions import override
@@ -42,7 +44,7 @@ class Expression(ExpressionBaseMixin, ABC):
def read(self, state: ExpressionState) -> PolyMatrix:
return self.apply(state)[1]
- def __add__(self, other: ExpressionBaseMixin) -> "Expression":
+ def __add__(self, other: ExpressionBaseMixin) -> Expression:
return self._binary(polymatrix.expression.init.init_addition_expr, self, other)
def __getattr__(self, name):
@@ -57,7 +59,8 @@ class Expression(ExpressionBaseMixin, ABC):
else:
return attr
- def __getitem__(self, key: tuple[int, int]):
+ def __getitem__(self, key: tuple[int, int]) -> Expression:
+ # FIXME: typing for key is incorrect, could be a slice
return self.copy(
underlying=polymatrix.expression.init.init_get_item_expr(
underlying=self.underlying,
@@ -65,25 +68,16 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def __matmul__(
- self, other: ExpressionBaseMixin | np.ndarray
- ) -> "Expression":
+ def __matmul__(self, other: ExpressionBaseMixin | np.ndarray) -> Expression:
return self._binary(
polymatrix.expression.init.init_matrix_mult_expr, self, other
)
- def __mul__(self, other) -> "Expression":
+ def __mul__(self, other) -> Expression:
return self._binary(polymatrix.expression.init.init_elem_mult_expr, self, other)
- def __pow__(self, num):
- curr = 1
-
- # FIXME: this only works for positive integral powers, consider raising
- # an error if the power is not a positive integer
- for _ in range(num):
- curr = curr * self
-
- return curr
+ def __pow__(self, exponent: Expression | int | float) -> Expression:
+ return self._binary(polymatrix.expression.init.init_power_expr, self, exponent)
def __neg__(self):
return self * (-1)
@@ -132,7 +126,7 @@ class Expression(ExpressionBaseMixin, ABC):
return right.copy(underlying=op(left, right.underlying, stack))
- def cache(self) -> "Expression":
+ def cache(self) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_cache_expr(
underlying=self.underlying,
@@ -171,11 +165,9 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def diff(
- self,
- variables: "Expression",
- introduce_derivatives: bool = None,
- ) -> "Expression":
+ # FIXME: sometime variables is a tuple, sometimes an expression. make consistent.
+ # FIXME: this function is probably broken
+ def diff(self, variables: Expression, introduce_derivatives: bool | None = None) -> Expression:
return self.copy(
underlying=diff(
expression=self,
@@ -184,10 +176,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def divergence(
- self,
- variables: tuple,
- ) -> "Expression":
+ def divergence(self, variables: tuple) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_divergence_expr(
underlying=self.underlying,
@@ -195,11 +184,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def eval(
- self,
- variable: tuple,
- value: tuple[float, ...] = None,
- ) -> "Expression":
+ def eval(self, variable: tuple, value: tuple[float, ...] | None = None,) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_eval_expr(
underlying=self.underlying,
@@ -211,9 +196,9 @@ class Expression(ExpressionBaseMixin, ABC):
# also applies to monomials (and variables?)
def filter(
self,
- predicator: "Expression",
- inverse: bool = None,
- ) -> "Expression":
+ predicator: Expression,
+ inverse: bool | None = None,
+ ) -> Expression:
return self.copy(
underlying=filter_(
underlying=self.underlying,
@@ -223,7 +208,7 @@ class Expression(ExpressionBaseMixin, ABC):
)
# only applies to symmetric matrix
- def from_symmetric_matrix(self) -> "Expression":
+ def from_symmetric_matrix(self) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_from_symmetric_matrix_expr(
underlying=self.underlying,
@@ -231,11 +216,7 @@ class Expression(ExpressionBaseMixin, ABC):
)
# only applies to monomials
- def half_newton_polytope(
- self,
- variables: "Expression",
- filter: "Expression | None" = None,
- ) -> "Expression":
+ def half_newton_polytope(self, variables: Expression, filter: Expression | None = None,) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_half_newton_polytope_expr(
monomials=self.underlying,
@@ -245,8 +226,11 @@ class Expression(ExpressionBaseMixin, ABC):
)
def integrate(
- self, variables: "Expression", from_: tuple[float, ...], to: tuple[float, ...]
- ) -> "Expression":
+ self,
+ variables: Expression,
+ from_: tuple[float, ...],
+ to: tuple[float, ...]
+ ) -> Expression:
return self.copy(
underlying=integrate(
expression=self,
@@ -256,7 +240,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def linear_matrix_in(self, variable: "Expression") -> "Expression":
+ def linear_matrix_in(self, variable: Expression) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_linear_matrix_in_expr(
underlying=self.underlying,
@@ -264,10 +248,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def linear_monomials(
- self,
- variables: "Expression",
- ) -> "Expression":
+ def linear_monomials(self, variables: Expression) -> Expression:
return self.copy(
underlying=linear_monomials(
expression=self.underlying,
@@ -277,10 +258,10 @@ class Expression(ExpressionBaseMixin, ABC):
def linear_in(
self,
- variables: "Expression",
- monomials: "Expression" = None,
- ignore_unmatched: bool = None,
- ) -> "Expression":
+ variables: Expression,
+ monomials: Expression | None = None,
+ ignore_unmatched: bool | None = None,
+ ) -> Expression:
return self.copy(
underlying=linear_in(
expression=self.underlying,
@@ -290,10 +271,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def legendre(
- self,
- degrees: tuple[int, ...] = None,
- ) -> "Expression":
+ def legendre(self, degrees: tuple[int, ...] | None = None) -> Expression:
return self.copy(
underlying=legendre(
expression=self.underlying,
@@ -301,14 +279,14 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def max(self) -> "Expression":
+ def max(self) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_max_expr(
underlying=self.underlying,
),
)
- def parametrize(self, name: str = None) -> "Expression":
+ def parametrize(self, name: str | None = None) -> Expression:
return self.copy(
underlying=polymatrix.expression.init.init_parametrize_expr(
underlying=self.underlying,
@@ -316,9 +294,7 @@ class Expression(ExpressionBaseMixin, ABC):
),
)
- def quadratic_in(
- self, variables: "Expression", monomials: "Expression" = None
- ) -> "Expression":
+ def quadratic_in(self, variables: Expression, monomials: Expression | None = None) -> Expression:
if monomials is None:
monomials = self.quadratic_monomials(variables)
diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py
index 52f9d06..052ee7a 100644
--- a/polymatrix/expression/from_.py
+++ b/polymatrix/expression/from_.py
@@ -33,7 +33,7 @@ def from_(
)
-def from_statemonad(monad: StateMonad):
+def from_statemonad(monad: StateMonad) -> Expression:
return init_expression(polymatrix.expression.init.init_from_statemonad(monad))
diff --git a/polymatrix/statemonad/mixins.py b/polymatrix/statemonad/mixins.py
index 68d1db6..5703241 100644
--- a/polymatrix/statemonad/mixins.py
+++ b/polymatrix/statemonad/mixins.py
@@ -1,11 +1,14 @@
-import abc
-import dataclasses
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import replace
+from functools import wraps
from typing import Callable, Tuple, TypeVar, Generic
-class StateCacheMixin(abc.ABC):
+class StateCacheMixin(ABC):
@property
- @abc.abstractmethod
+ @abstractmethod
def cache(self) -> dict: ...
@@ -34,18 +37,15 @@ V = TypeVar("V")
# NP: and make a new function (M[U] -> M[V] -> M[W])
#
# NP: TODO: text comparing the above to implementation below
-class StateMonadMixin(
- Generic[State, U],
- abc.ABC,
-):
+class StateMonadMixin(Generic[State, U], ABC):
@property
- @abc.abstractmethod
+ @abstractmethod
def apply_func(self) -> Callable[[State], tuple[State, U]]:
# NP: TODO comment
...
@property
- @abc.abstractmethod
+ @abstractmethod
def arguments(self) -> U | None:
# arguments that were given to the function apply_func.
# this field is optional
@@ -56,50 +56,47 @@ class StateMonadMixin(
# sympy.
...
- # NP: typing, use from __future__ import annotations
- def map(self, fn: Callable[[U], V]) -> 'StateMonadMixin[State, V]':
- # NP: add functools.wrap(fn) decorator to copy docstrings etc.
+ def map(self, fn: Callable[[U], V]) -> StateMonadMixin[State, V]:
+ @wraps(fn)
def internal_map(state: State) -> Tuple[State, U]:
n_state, val = self.apply(state)
return n_state, fn(val)
- return dataclasses.replace(self, apply_func=internal_map)
+ return replace(self, apply_func=internal_map)
- # NP: shouldn't typing be
- # NP: flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V]
- def flat_map(self, fn: Callable[[U], 'StateMonadMixin']) -> 'StateMonadMixin[State, V]':
- # NP: add functools.wrap(fn) decorator
+ def flat_map(self, fn: Callable[[U], StateMonadMixin[State, V]]) -> StateMonadMixin[State, V]:
+ @wraps(fn)
def internal_map(state: State) -> Tuple[State, V]:
n_state, val = self.apply(state)
return fn(val).apply(n_state)
- return dataclasses.replace(self, apply_func=internal_map)
+ return replace(self, apply_func=internal_map)
# FIXME: typing
- def zip(self, other: 'StateMonadMixin') -> 'StateMonadMixin':
+ def zip(self, other: StateMonadMixin) -> StateMonadMixin:
def internal_map(state: State) -> Tuple[State, V]:
state, val1 = self.apply(state)
state, val2 = other.apply(state)
return state, (val1, val2)
- return dataclasses.replace(self, apply_func=internal_map)
+ return replace(self, apply_func=internal_map)
# FIXME: typing
- def cache(self) -> 'StateMonadMixin':
+ def cache(self) -> StateMonadMixin:
def internal_map(state: State) -> Tuple[State, V]:
if self in state.cache:
return state, state.cache[self]
state, val = self.apply(state)
- state = dataclasses.replace(
+ state = replace(
state,
cache=state.cache | {self: val},
)
return state, val
- return dataclasses.replace(self, apply_func=internal_map)
+ return replace(self, apply_func=internal_map)
# NP: Need to find consistent naming and explain somewhere naming convention
# NP: of monad operations (is this from scala conventions? I have never used scala)