aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--mdpoly/leaves.py44
-rw-r--r--mdpoly/types.py37
2 files changed, 72 insertions, 9 deletions
diff --git a/mdpoly/leaves.py b/mdpoly/leaves.py
index 4dd487d..4032ead 100644
--- a/mdpoly/leaves.py
+++ b/mdpoly/leaves.py
@@ -1,8 +1,13 @@
-from .abc import Leaf
-from .types import Number, Shape
+from .abc import Leaf, Repr
+from .types import Number, Shape, MatrixIndex, PolyVarIndex, PolyIndex
+from .state import State
+from .representations import HasRepr
+from typing import TypeVar
from dataclasses import dataclass
+T = TypeVar("T", bound=Repr)
+
@dataclass(frozen=True)
class Nothing(Leaf):
@@ -14,22 +19,42 @@ class Nothing(Leaf):
@dataclass(frozen=True)
-class Const(Leaf):
+class Const(Leaf, HasRepr):
"""
- A constant
+ A scalar constant
"""
value: Number
name: str = ""
- shape: Shape = Shape.scalar()
+ shape: Shape = Shape.scalar
+
+ def to_repr(self, repr_type: type[T], state: State) -> T:
+ r = repr_type()
+ r.set(MatrixIndex.scalar, PolyIndex.constant, self.value)
+ return r, state
+
+ def __repr__(self) -> str:
+ if self.name:
+ return self.name
+
+ return repr(self.value)
@dataclass(frozen=True)
-class Var(Leaf):
+class Var(Leaf, HasRepr):
"""
Polynomial variable
"""
name: str
- shape: Shape = Shape.scalar()
+ shape: Shape = Shape.scalar
+
+ def to_repr(self, repr_type: type[T], state: State) -> T:
+ r = repr_type()
+ idx = PolyVarIndex(varindex=state.index(self), power=1)
+ r.set(MatrixIndex.scalar, PolyIndex(idx), 1)
+ return r, state
+
+ def __repr__(self) -> str:
+ return self.name
@dataclass(frozen=True)
@@ -38,4 +63,7 @@ class Param(Leaf):
A parameter
"""
name: str
- shape: Shape = Shape.scalar()
+ shape: Shape = Shape.scalar
+
+ def __repr__(self) -> str:
+ return self.name
diff --git a/mdpoly/types.py b/mdpoly/types.py
index 29e7b83..a0ea671 100644
--- a/mdpoly/types.py
+++ b/mdpoly/types.py
@@ -1,6 +1,12 @@
+from __future__ import annotations
+
from .errors import InvalidShape
-from typing import Self, NamedTuple, Iterable
+from typing import Self, NamedTuple, Iterable, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .state import State
+ from .leaves import Var
Number = int | float
@@ -17,10 +23,12 @@ class Shape(NamedTuple):
cols: int
@classmethod
+ @property
def infer(cls) -> Self:
return cls(-1, -1)
@classmethod
+ @property
def scalar(cls) -> Self:
return cls(1, 1)
@@ -49,6 +57,12 @@ class MatrixIndex(NamedTuple):
def infer(cls, row=-1, col=-1) -> Self:
return cls(row, col)
+ @classmethod
+ @property
+ def scalar(cls):
+ """ Shorthand for index of a scalar """
+ return cls(row=0, col=0)
+
class PolyVarIndex(NamedTuple):
"""
@@ -59,6 +73,16 @@ class PolyVarIndex(NamedTuple):
varindex: int
power: Number
+ @classmethod
+ @property
+ def constant(cls):
+ """ Special index for constants, which have no variable """
+ return cls(varindex=-1, power=0)
+
+ @classmethod
+ def of_var(cls, variable: Var, state: State, power: Number =1) -> Self:
+ return cls(varindex=state.index(variable), power=power)
+
class PolyIndex(tuple[PolyVarIndex]):
"""
@@ -78,3 +102,14 @@ class PolyIndex(tuple[PolyVarIndex]):
The with the PolyIndex above we can retrieve the coefficient 5 in front
of xy^2.
"""
+
+ def __repr__(self) -> str:
+ name = self.__class__.__qualname__
+ indices = ", ".join(map(repr, self))
+ return f"{name}({indices})"
+
+ @classmethod
+ @property
+ def constant(cls):
+ """ Index of the constant term """
+ return cls(PolyVarIndex.constant)