diff options
-rw-r--r-- | mdpoly/leaves.py | 44 | ||||
-rw-r--r-- | mdpoly/types.py | 37 |
2 files changed, 72 insertions, 9 deletions
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py index 4dd487d..4032ead 100644 --- a/mdpoly/leaves.py +++ b/mdpoly/leaves.py @@ -1,8 +1,13 @@ -from .abc import Leaf -from .types import Number, Shape +from .abc import Leaf, Repr +from .types import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex +from .state import State +from .representations import HasRepr +from typing import TypeVar from dataclasses import dataclass +T = TypeVar("T", bound=Repr) + @dataclass(frozen=True) class Nothing(Leaf): @@ -14,22 +19,42 @@ class Nothing(Leaf): @dataclass(frozen=True) -class Const(Leaf): +class Const(Leaf, HasRepr): """ - A constant + A scalar constant """ value: Number name: str = "" - shape: Shape = Shape.scalar() + shape: Shape = Shape.scalar + + def to_repr(self, repr_type: type[T], state: State) -> T: + r = repr_type() + r.set(MatrixIndex.scalar, PolyIndex.constant, self.value) + return r, state + + def __repr__(self) -> str: + if self.name: + return self.name + + return repr(self.value) @dataclass(frozen=True) -class Var(Leaf): +class Var(Leaf, HasRepr): """ Polynomial variable """ name: str - shape: Shape = Shape.scalar() + shape: Shape = Shape.scalar + + def to_repr(self, repr_type: type[T], state: State) -> T: + r = repr_type() + idx = PolyVarIndex(varindex=state.index(self), power=1) + r.set(MatrixIndex.scalar, PolyIndex(idx), 1) + return r, state + + def __repr__(self) -> str: + return self.name @dataclass(frozen=True) @@ -38,4 +63,7 @@ class Param(Leaf): A parameter """ name: str - shape: Shape = Shape.scalar() + shape: Shape = Shape.scalar + + def __repr__(self) -> str: + return self.name diff --git a/mdpoly/types.py b/mdpoly/types.py index 29e7b83..a0ea671 100644 --- a/mdpoly/types.py +++ b/mdpoly/types.py @@ -1,6 +1,12 @@ +from __future__ import annotations + from .errors import InvalidShape -from typing import Self, NamedTuple, Iterable +from typing import Self, NamedTuple, Iterable, TYPE_CHECKING + +if TYPE_CHECKING: + from .state import State + from .leaves import Var Number = int | float @@ -17,10 +23,12 @@ class Shape(NamedTuple): cols: int @classmethod + @property def infer(cls) -> Self: return cls(-1, -1) @classmethod + @property def scalar(cls) -> Self: return cls(1, 1) @@ -49,6 +57,12 @@ class MatrixIndex(NamedTuple): def infer(cls, row=-1, col=-1) -> Self: return cls(row, col) + @classmethod + @property + def scalar(cls): + """ Shorthand for index of a scalar """ + return cls(row=0, col=0) + class PolyVarIndex(NamedTuple): """ @@ -59,6 +73,16 @@ class PolyVarIndex(NamedTuple): varindex: int power: Number + @classmethod + @property + def constant(cls): + """ Special index for constants, which have no variable """ + return cls(varindex=-1, power=0) + + @classmethod + def of_var(cls, variable: Var, state: State, power: Number =1) -> Self: + return cls(varindex=state.index(variable), power=power) + class PolyIndex(tuple[PolyVarIndex]): """ @@ -78,3 +102,14 @@ class PolyIndex(tuple[PolyVarIndex]): The with the PolyIndex above we can retrieve the coefficient 5 in front of xy^2. """ + + def __repr__(self) -> str: + name = self.__class__.__qualname__ + indices = ", ".join(map(repr, self)) + return f"{name}({indices})" + + @classmethod + @property + def constant(cls): + """ Index of the constant term """ + return cls(PolyVarIndex.constant) |