From fd26d918b4723c85ee762c6479fddc175c085acf Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Wed, 6 Mar 2024 12:01:39 +0100 Subject: Rename mdpoly.types to mdpoly.index as it only contains index types --- mdpoly/abc.py | 2 +- mdpoly/algebra.py | 2 +- mdpoly/index.py | 224 +++++++++++++++++++++++++++++++++++++++++++++ mdpoly/leaves.py | 2 +- mdpoly/representations.py | 2 +- mdpoly/state.py | 4 +- mdpoly/types.py | 229 ---------------------------------------------- mdpoly/util.py | 6 ++ 8 files changed, 236 insertions(+), 235 deletions(-) create mode 100644 mdpoly/index.py delete mode 100644 mdpoly/types.py diff --git a/mdpoly/abc.py b/mdpoly/abc.py index 811910b..0a2e783 100644 --- a/mdpoly/abc.py +++ b/mdpoly/abc.py @@ -1,6 +1,6 @@ """ Abstract Base Classes of MDPoly """ -from .types import Number, Shape, MatrixIndex, PolyIndex +from .index import Number, Shape, MatrixIndex, PolyIndex from .constants import NUMERICS_EPS from .util import iszero diff --git a/mdpoly/algebra.py b/mdpoly/algebra.py index abd15d5..a07c4d6 100644 --- a/mdpoly/algebra.py +++ b/mdpoly/algebra.py @@ -4,7 +4,7 @@ from .abc import Leaf, Expr, Repr from .leaves import Nothing, Const, Param, Var from .errors import AlgebraicError, InvalidShape from .state import State -from .types import Shape, MatrixIndex, PolyIndex +from .index import Shape, MatrixIndex, PolyIndex from .representations import HasRepr from typing import Protocol, Self, Any, runtime_checkable diff --git a/mdpoly/index.py b/mdpoly/index.py new file mode 100644 index 0000000..2837a46 --- /dev/null +++ b/mdpoly/index.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from .errors import InvalidShape +from .util import partition, isclose + +from itertools import filterfalse +from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .state import State + from .leaves import Var + +Number = int | float + + +class Shape(NamedTuple): + """ Describes the shape of a mathematical object. """ + rows: int + cols: int + + @property + def infer(self) -> Self: + return self.__class__(-1, -1) + + @classmethod + def scalar(cls) -> Self: + return cls(1, 1) + + @classmethod + def row(cls, n: int) -> Self: + if n <= 0: + raise InvalidShape("Row vector must have dimension >= 1") + return cls(1, n) + + def is_row(self) -> bool: + return self.cols == 1 + + @classmethod + def col(cls, n: int) -> Self: + if n <= 0: + raise InvalidShape("Column vector must have dimension >= 1") + return cls(n, 1) + + def is_col(self) -> bool: + return self.rows == 1 + + # --- aliases / shorthands --- + + @property + def columns(self): + return self.cols + + @classmethod + def column(cls, n: int) -> Self: + return cls.col(n) + + def is_column(self) -> bool: + return self.is_col() + + +class MatrixIndex(NamedTuple): + """ Tuple to index an element of a matrix or vector. """ + row: int + col: int + + @classmethod + def infer(cls, row=-1, col=-1) -> Self: + return cls(row, col) + + @classmethod + def scalar(cls): + """ Shorthand for index of a scalar """ + return cls(row=0, col=0) + + +class PolyVarIndex(NamedTuple): + """ Tuple to represent a variable and its exponent. They are totally + ordered with respect to the variable index stored in the ``state`` object + (see :py:attr:`mdpoly.state.State.variables`). + + Concretely this represents things like :math:`x^2`, :math:`y^4` but not + :math:`xy`. + """ + var_idx: int # Index in State.variables, not variable! + power: Number + + @classmethod + def from_var(cls, variable: Var, state: State, power: Number =1) -> Self: + """ Make an index from a variable object. """ + return cls(var_idx=state.index(variable), power=power) + + def __eq__(self, other): + if type(other) is not PolyVarIndex: + other = PolyVarIndex(other) + + if self.var_idx != other.var_idx: + return False + + if not isclose(self.power, other.power): + return False + + return True + + def __lt__(self, other): + if type(other) is not PolyVarIndex: + other = PolyVarIndex(other) + + return self.var_idx < other.var_idx + + @classmethod + def constant(cls) -> Self: + """ Special index for constants. + + Constants do not have an associated variable, and the field power is + ignored. + """ + return cls(var_idx=-1, power=0) + + @staticmethod + def is_constant(index: PolyVarIndex) -> bool: + """ Check if is index of a constant term. """ + return index.var_idx == -1 + + +class PolyIndex(tuple[PolyVarIndex]): + """ + Tuple to index a coefficient of a monomial in a polynomial. + + For example, suppose there are two variables :math:`x, y` with indices 0, 1 + respectively (in the State object, see :py:attr:`mdpoly.state.State.variables`). + Then the monomal :math:`xy^2` has an index + + .. code:: py + + PolyIndex(PolyVarIndex(var_idx=0, power=1), PolyVarIndex(var_idx=1, power=2) + + Then given the polynomial + + .. math:: + p(x, y) = 3x^2 + 5xy^2 + + The with the ``PolyIndex`` above we can retrieve the coefficient 5 in front + of :math:`xy^2` from the ``Repr`` object (see also :py:meth:`mdpoly.abc.Repr.at`). + """ + + def __repr__(self) -> str: + name = self.__class__.__qualname__ + indices = ", ".join(map(repr, self)) + return f"{name}({indices})" + + @classmethod + def from_dict(cls, d) -> Self: + """ Construct an index froma dictionary, where the keys are the + variable index from the state object and the values are the exponent + (power). """ + return cls(PolyVarIndex(k, v) for k, v in d.items()) + + @classmethod + def constant(cls) -> Self: + """ Index of the constant term. """ + return cls((PolyVarIndex.constant(),)) + + @staticmethod + def is_constant(index: PolyIndex) -> bool: + """ Check if is index of a constant term. """ + if len(index) != 1: + return False + + return PolyVarIndex.is_constant(index[0]) + + @classmethod + def sort(cls, index: tuple | Self) -> Self: + """ Sort a tuple of indices. """ + return cls(sorted(index)) + + @classmethod + def product(cls, left: Self, right: Self) -> Self: + """ Compute index of product. + + For example if left is the index of :math:`xy` and right is the index + of :math:`y^2` this functions returns the index of :math:`xy^3`. + """ + if cls.is_constant(left): + return right + + if cls.is_constant(right): + return left + + result: dict[int, Number] = dict(filterfalse(PolyVarIndex.is_constant, left)) + for idx, power in filterfalse(PolyVarIndex.is_constant, right): + if idx not in result: + result[idx] = power + else: + result[idx] += power + + return cls.sort(cls.from_dict(result)) + + @classmethod + def differentiate(cls, index: Self, wrt: int) -> Optional[Self]: + """ Compute the index after differentiation + + For example if the index is :math:`xy^2` and ``wrt`` is the index of + :math:`y` this function returns the index of :math:`xy`. In particular + this function takes care of the edge cases of differentiating a + constant and differentiating linear terms. Specifically, if the index + is a constant ``None`` is returned. + """ + + if cls.is_constant(index): + return None + + def is_wrt_var(idx: PolyVarIndex) -> bool: + return idx.var_idx == wrt + + others, with_wrt_var = partition(is_wrt_var, index) + with_wrt_var, *_ = with_wrt_var # Take first, as should be only one + + # Check if is linear term + if isclose(with_wrt_var.power, 1.): + return cls.sort(tuple(others) + cls.constant()) + + # Decrease exponent + new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1)) + return cls.sort(tuple(others) + (new_idx,)) diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index dbfd332..af43435 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -1,5 +1,5 @@ from .abc import Leaf, Repr -from .types import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex +from .index import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex from .state import State from .errors import MissingParameters from .representations import HasRepr diff --git a/mdpoly/representations.py b/mdpoly/representations.py index b7943ad..b541c3e 100644 --- a/mdpoly/representations.py +++ b/mdpoly/representations.py @@ -1,5 +1,5 @@ from .abc import Repr -from .types import Number, Shape, MatrixIndex, PolyIndex +from .index import Number, Shape, MatrixIndex, PolyIndex from .state import State from typing import Protocol, Iterable diff --git a/mdpoly/state.py b/mdpoly/state.py index 70fca24..e0a6bc5 100644 --- a/mdpoly/state.py +++ b/mdpoly/state.py @@ -2,10 +2,10 @@ from __future__ import annotations from typing import TYPE_CHECKING from dataclasses import dataclass, field -from .types import PolyVarIndex +from .index import PolyVarIndex if TYPE_CHECKING: - from .types import Number + from .index import Number from .leaves import Var, Param, Const diff --git a/mdpoly/types.py b/mdpoly/types.py deleted file mode 100644 index 9e4ecb7..0000000 --- a/mdpoly/types.py +++ /dev/null @@ -1,229 +0,0 @@ -from __future__ import annotations - -from .errors import InvalidShape -from .util import partition, isclose - -from itertools import filterfalse -from typing import Self, NamedTuple, Iterable, Optional, TYPE_CHECKING - -if TYPE_CHECKING: - from .state import State - from .leaves import Var - -Number = int | float - - -def filtertype(t: type, seq: Iterable): - """ Filter based on type. """ - yield from filter(lambda x: isinstance(x, t), seq) - - -class Shape(NamedTuple): - """ Describes the shape of a mathematical object. """ - rows: int - cols: int - - @property - def infer(self) -> Self: - return self.__class__(-1, -1) - - @classmethod - def scalar(cls) -> Self: - return cls(1, 1) - - @classmethod - def row(cls, n: int) -> Self: - if n <= 0: - raise InvalidShape("Row vector must have dimension >= 1") - return cls(1, n) - - def is_row(self) -> bool: - return self.cols == 1 - - @classmethod - def col(cls, n: int) -> Self: - if n <= 0: - raise InvalidShape("Column vector must have dimension >= 1") - return cls(n, 1) - - def is_col(self) -> bool: - return self.rows == 1 - - # --- aliases / shorthands --- - - @property - def columns(self): - return self.cols - - @classmethod - def column(cls, n: int) -> Self: - return cls.col(n) - - def is_column(self) -> bool: - return self.is_col() - - -class MatrixIndex(NamedTuple): - """ Tuple to index an element of a matrix or vector. """ - row: int - col: int - - @classmethod - def infer(cls, row=-1, col=-1) -> Self: - return cls(row, col) - - @classmethod - def scalar(cls): - """ Shorthand for index of a scalar """ - return cls(row=0, col=0) - - -class PolyVarIndex(NamedTuple): - """ Tuple to represent a variable and its exponent. They are totally - ordered with respect to the variable index stored in the ``state`` object - (see :py:attr:`mdpoly.state.State.variables`). - - Concretely this represents things like :math:`x^2`, :math:`y^4` but not - :math:`xy`. - """ - var_idx: int # Index in State.variables, not variable! - power: Number - - @classmethod - def from_var(cls, variable: Var, state: State, power: Number =1) -> Self: - """ Make an index from a variable object. """ - return cls(var_idx=state.index(variable), power=power) - - def __eq__(self, other): - if type(other) is not PolyVarIndex: - other = PolyVarIndex(other) - - if self.var_idx != other.var_idx: - return False - - if not isclose(self.power, other.power): - return False - - return True - - def __lt__(self, other): - if type(other) is not PolyVarIndex: - other = PolyVarIndex(other) - - return self.var_idx < other.var_idx - - @classmethod - def constant(cls) -> Self: - """ Special index for constants. - - Constants do not have an associated variable, and the field power is - ignored. - """ - return cls(var_idx=-1, power=0) - - @staticmethod - def is_constant(index: PolyVarIndex) -> bool: - """ Check if is index of a constant term. """ - return index.var_idx == -1 - - -class PolyIndex(tuple[PolyVarIndex]): - """ - Tuple to index a coefficient of a monomial in a polynomial. - - For example, suppose there are two variables :math:`x, y` with indices 0, 1 - respectively (in the State object, see :py:attr:`mdpoly.state.State.variables`). - Then the monomal :math:`xy^2` has an index - - .. code:: py - - PolyIndex(PolyVarIndex(var_idx=0, power=1), PolyVarIndex(var_idx=1, power=2) - - Then given the polynomial - - .. math:: - p(x, y) = 3x^2 + 5xy^2 - - The with the ``PolyIndex`` above we can retrieve the coefficient 5 in front - of :math:`xy^2` from the ``Repr`` object (see also :py:meth:`mdpoly.abc.Repr.at`). - """ - - def __repr__(self) -> str: - name = self.__class__.__qualname__ - indices = ", ".join(map(repr, self)) - return f"{name}({indices})" - - @classmethod - def from_dict(cls, d) -> Self: - """ Construct an index froma dictionary, where the keys are the - variable index from the state object and the values are the exponent - (power). """ - return cls(PolyVarIndex(k, v) for k, v in d.items()) - - @classmethod - def constant(cls) -> Self: - """ Index of the constant term. """ - return cls((PolyVarIndex.constant(),)) - - @staticmethod - def is_constant(index: PolyIndex) -> bool: - """ Check if is index of a constant term. """ - if len(index) != 1: - return False - - return PolyVarIndex.is_constant(index[0]) - - @classmethod - def sort(cls, index: tuple | Self) -> Self: - """ Sort a tuple of indices. """ - return cls(sorted(index)) - - @classmethod - def product(cls, left: Self, right: Self) -> Self: - """ Compute index of product. - - For example if left is the index of :math:`xy` and right is the index - of :math:`y^2` this functions returns the index of :math:`xy^3`. - """ - if cls.is_constant(left): - return right - - if cls.is_constant(right): - return left - - result: dict[int, Number] = dict(filterfalse(PolyVarIndex.is_constant, left)) - for idx, power in filterfalse(PolyVarIndex.is_constant, right): - if idx not in result: - result[idx] = power - else: - result[idx] += power - - return cls.sort(cls.from_dict(result)) - - @classmethod - def differentiate(cls, index: Self, wrt: int) -> Optional[Self]: - """ Compute the index after differentiation - - For example if the index is :math:`xy^2` and ``wrt`` is the index of - :math:`y` this function returns the index of :math:`xy`. In particular - this function takes care of the edge cases of differentiating a - constant and differentiating linear terms. Specifically, if the index - is a constant ``None`` is returned. - """ - - if cls.is_constant(index): - return None - - def is_wrt_var(idx: PolyVarIndex) -> bool: - return idx.var_idx == wrt - - others, with_wrt_var = partition(is_wrt_var, index) - with_wrt_var, *_ = with_wrt_var # Take first, as should be only one - - # Check if is linear term - if isclose(with_wrt_var.power, 1.): - return cls.sort(tuple(others) + cls.constant()) - - # Decrease exponent - new_idx = PolyVarIndex(var_idx=wrt, power=(with_wrt_var.power - 1)) - return cls.sort(tuple(others) + (new_idx,)) diff --git a/mdpoly/util.py b/mdpoly/util.py index 937e7d4..c7f3b51 100644 --- a/mdpoly/util.py +++ b/mdpoly/util.py @@ -1,6 +1,7 @@ from .constants import NUMERICS_EPS from itertools import tee, filterfalse +from typing import Iterable def partition(pred, iterable): """Partition entries into false entries and true entries. @@ -30,3 +31,8 @@ def iszero(x: float, eps: float =NUMERICS_EPS) -> bool: ``eps`` or the default value at :py:data:`mdpoly.constants.NUMERICS_EPS`. """ return isclose(x, 0.) + + +def filtertype(t: type, seq: Iterable): + """ Filter based on type. """ + yield from filter(lambda x: isinstance(x, t), seq) -- cgit v1.2.1