From 5c5268d2adfa3dfb6fb1426ac3d59d08c9e36d2b Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Tue, 4 Jun 2024 17:53:02 +0200 Subject: Separate symbols from variables to avoid shape problems Variables are now expression that have a shape and contain a symbol. The symbol in its own is not a variable, as the shape is unknown, but this information can is stored in the state object. So it is possible to reconstruct a variable by doing init_variable_expr(sym, state.get_shape(sym)) --- polymatrix/expression/expression.py | 12 +- polymatrix/expression/from_.py | 28 +--- polymatrix/expression/impl.py | 9 +- polymatrix/expression/init.py | 16 +- polymatrix/expression/mixins/fromsympyexprmixin.py | 6 +- .../expression/mixins/parametrizeexprmixin.py | 9 +- polymatrix/expression/mixins/variableexprmixin.py | 61 +++++++ polymatrix/expression/mixins/variablemixin.py | 59 ------- polymatrix/expression/to.py | 2 +- polymatrix/expressionstate.py | 178 ++++++++++++--------- polymatrix/symbol.py | 2 + polymatrix/variable.py | 23 --- 12 files changed, 193 insertions(+), 212 deletions(-) create mode 100644 polymatrix/expression/mixins/variableexprmixin.py delete mode 100644 polymatrix/expression/mixins/variablemixin.py create mode 100644 polymatrix/symbol.py delete mode 100644 polymatrix/variable.py diff --git a/polymatrix/expression/expression.py b/polymatrix/expression/expression.py index 3222034..fc4d3b6 100644 --- a/polymatrix/expression/expression.py +++ b/polymatrix/expression/expression.py @@ -11,11 +11,10 @@ from typing_extensions import override import polymatrix.expression.init as init from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expression.mixins.variableexprmixin import VariableExprMixin from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.utils.getstacklines import get_stack_lines -from polymatrix.variable import Variable from polymatrix.expression.op import ( diff, @@ -469,7 +468,7 @@ def init_expression( ) -class VariableExpression(Expression, Variable): +class VariableExpression(Expression): """ Expression that is a polynomial variable, i.e. an expression that cannot be reduced further. @@ -478,11 +477,6 @@ class VariableExpression(Expression, Variable): def var(self): return self.underlying - @override - @property - def name(self): - return self.underlyng.name - @override @property def shape(self): @@ -500,6 +494,6 @@ class VariableExpressionImpl(VariableExpression): return init_expression(underlying) -def init_variable_expression(underlying: VariableMixin): +def init_variable_expression(underlying: VariableExprMixin): return VariableExpressionImpl(underlying) diff --git a/polymatrix/expression/from_.py b/polymatrix/expression/from_.py index a3ee56e..7198e0c 100644 --- a/polymatrix/expression/from_.py +++ b/polymatrix/expression/from_.py @@ -9,12 +9,13 @@ import polymatrix.expression.init as init from polymatrix.expression.init import init_variable_expr from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.mixins.variableexprmixin import VariableExprMixin from polymatrix.expression.typing import FromSupportedTypes from polymatrix.expressionstate import ExpressionState from polymatrix.statemonad import StateMonad from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.utils.deprecation import deprecated -from polymatrix.variable import Variable +from polymatrix.symbol import Symbol from polymatrix.expression.expression import ( init_expression, Expression, @@ -45,29 +46,6 @@ def from_any_or(value: FromSupportedTypes, value_if_not_supported: Any) -> Expre elif isinstance(value, ExpressionBaseMixin): return init_expression(value) - elif isinstance(value, Variable) and not isinstance(value, ExpressionBaseMixin): - # This happens when a variable is constructed somewhere using - # polymatrix.variable.init_variable. What should happen here? - - # This is problematic because if there is already another variable in - # the state with the same name, but different type (eg a VariableExpr) - # it will throw an error. - - # Also we need to consider the case when a variable was created by - # another package, e.g. is an optimization variable. What happens then? - - # Should polymatrix.variable.VariableImpl even exist in the first - # place? It was part of the proposed architecture but I think it causes - # more problems than it is solving. - - # Commented code below won't work - - # value = init_variable_expr(value.name, shape=value.shape) - # return init_variable_expression(value) - - raise NotImplementedError("You have encountered a convoluted edge case. " - "I'm very sorry it doesn't work yet.") - elif isinstance(value, int | float): return from_number(value) @@ -111,7 +89,7 @@ def from_name(name: str, shape: tuple[int, int] | ExpressionBaseMixin = (1,1)) - if isinstance(shape, Expression): shape = shape.underlying - return init_variable_expression(init.init_variable_expr(name, shape)) + return init_variable_expression(init.init_variable_expr(Symbol(name), shape)) def from_number(num: int | float) -> Expression: diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 47829e0..eccf99a 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -6,6 +6,7 @@ from polymatrix.statemonad import StateMonad import dataclassabc from polymatrix.utils.getstacklines import FrameSummary +from polymatrix.symbol import Symbol from polymatrix.expression.mixins.additionexprmixin import AdditionExprMixin from polymatrix.expression.mixins.arangeexprmixin import ARangeExprMixin @@ -59,7 +60,7 @@ from polymatrix.expression.mixins.tosortedvariablesmixin import ToSortedVariable from polymatrix.expression.mixins.tosymmetricmatrixexprmixin import ToSymmetricMatrixExprMixin from polymatrix.expression.mixins.transposeexprmixin import TransposeExprMixin from polymatrix.expression.mixins.truncateexprmixin import TruncateExprMixin -from polymatrix.expression.mixins.variablemixin import VariableMixin +from polymatrix.expression.mixins.variableexprmixin import VariableExprMixin from polymatrix.expression.mixins.fromtermsexprmixin import ( FromPolynomialDataExprMixin, @@ -507,9 +508,9 @@ class TruncateExprImpl(TruncateExprMixin): @dataclassabc.dataclassabc(frozen=True) -class VariableImpl(VariableMixin): - name: str +class VariableExprImpl(VariableExprMixin): + symbol: Symbol shape: tuple[int, int] | ExpressionBaseMixin def __str__(self): - return self.name + return self.symbol diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 31434f7..4b8d182 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -7,14 +7,15 @@ from polymatrix.expression.typing import FromSupportedTypes import polymatrix.expression.impl -from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.expression.expression import VariableExpression +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.expression.utils.formatsubstitutions import format_substitutions from polymatrix.polymatrix.index import PolynomialMatrixData +from polymatrix.polymatrix.mixins import PolyMatrixMixin from polymatrix.statemonad import StateMonad +from polymatrix.symbol import Symbol from polymatrix.utils.getstacklines import FrameSummary from polymatrix.utils.getstacklines import get_stack_lines -from polymatrix.expression.utils.formatsubstitutions import format_substitutions -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.expression.expression import VariableExpression def init_addition_expr( @@ -495,8 +496,5 @@ def init_truncate_expr( ) -def init_variable_expr( - name: str, - shape: tuple[int, int] | ExpressionBaseMixin -): - return polymatrix.expression.impl.VariableImpl(name, shape) +def init_variable_expr(sym: Symbol, shape: tuple[int, int] | ExpressionBaseMixin): + return polymatrix.expression.impl.VariableExprImpl(sym, shape) diff --git a/polymatrix/expression/mixins/fromsympyexprmixin.py b/polymatrix/expression/mixins/fromsympyexprmixin.py index ec28e09..c686972 100644 --- a/polymatrix/expression/mixins/fromsympyexprmixin.py +++ b/polymatrix/expression/mixins/fromsympyexprmixin.py @@ -10,7 +10,7 @@ from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.abc import PolyMatrix from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex -from polymatrix.variable import init_variable +from polymatrix.symbol import Symbol class FromSympyExprMixin(ExpressionBaseMixin): @@ -72,7 +72,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): # Convert sympy variables to our variables sympy_to_var = { - sympy_idx: init_variable(var.name, shape=(1,1)) + sympy_idx: Symbol(var.name) for sympy_idx, var in enumerate(sympy_poly.gens) } @@ -85,7 +85,7 @@ class FromSympyExprMixin(ExpressionBaseMixin): m: list[VariableIndex] = [] for i, exponent in enumerate(monom): var = sympy_to_var[i] - state, idx = state.index(var) + state, idx = state.index(var, shape=(1,1)) # idx.start because var is a scalar m.append(VariableIndex(idx.start, exponent)) diff --git a/polymatrix/expression/mixins/parametrizeexprmixin.py b/polymatrix/expression/mixins/parametrizeexprmixin.py index 644198e..583ac1f 100644 --- a/polymatrix/expression/mixins/parametrizeexprmixin.py +++ b/polymatrix/expression/mixins/parametrizeexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.expressionstate import ExpressionState from polymatrix.polymatrix.index import PolyMatrixDict, MatrixIndex, PolyDict, MonomialIndex, VariableIndex from polymatrix.polymatrix.init import init_poly_matrix from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.variable import init_variable +from polymatrix.symbol import Symbol class ParametrizeExprMixin(ExpressionBaseMixin): @@ -42,16 +42,17 @@ class ParametrizeExprMixin(ExpressionBaseMixin): nrows, ncols = underlying.shape # FIXME: not sure this behaviour is intuitive, discuss - if v := state.get_variable_from_name_or(self.name, if_not_present=False): + if v := state.get_symbol_from_name_or(self.name, if_not_present=False): start, end = state.offset_dict[v] + # FIXME: This is not a good check for shapes, this condition could + # be false even if the shapes do not match if (nrows * ncols) != (end - start): raise ValueError("Cannot parametrize {self.underlying} with variable {v} " "found in state object, because its shape {(nrow, ncols)} " "does not match ({self.underlying.shape}). ") else: - v = init_variable(self.name, shape=(nrows, ncols)) - state = state.register(v) + state = state.register(Symbol(self.name), shape=(nrows, ncols)) p = PolyMatrixDict.empty() indices = state.get_indices(v) diff --git a/polymatrix/expression/mixins/variableexprmixin.py b/polymatrix/expression/mixins/variableexprmixin.py new file mode 100644 index 0000000..83bc66e --- /dev/null +++ b/polymatrix/expression/mixins/variableexprmixin.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import itertools +from abc import abstractmethod +from typing_extensions import override + +from polymatrix.expressionstate import ExpressionState +from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin +from polymatrix.polymatrix.init import init_poly_matrix +from polymatrix.polymatrix.mixins import PolyMatrixMixin +from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex +from polymatrix.symbol import Symbol + + +class VariableExprMixin(ExpressionBaseMixin): + """ Underlying object for VariableExpression """ + + @property + @abstractmethod + def shape(self) -> tuple[int, int] | ExpressionBaseMixin: + """ Shape of the variable expression. """ + + @property + @abstractmethod + def symbol(self) -> Symbol: + """ The symbol representing the variable. """ + + @override + def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: + if isinstance(self.shape, ExpressionBaseMixin): + state, shape_pm = self.shape.apply(state) + if shape_pm.shape != (2, 1): + raise ValueError("If shape is an expression it must evaluate to a 2d row vector, " + f"but here it has shape {shape_pm.shape}") + + # FIXME: should check that they are actually integers + nrows = int(shape_pm.at(0, 0).constant()) + ncols = int(shape_pm.at(1, 0).constant()) + + # Replace shape field with computed shape + state = state.register(self.symbol, shape=(nrows, ncols)) + + elif isinstance(self.shape, tuple): + nrows, ncols = self.shape # for for loop below + state = state.register(self.symbol, self.shape) + + else: + raise ValueError("Shape must be a tuple or expression that " + f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}") + + indices = state.get_indices(self.symbol) + + p = PolyMatrixDict() + for (row, col), index in zip(itertools.product(range(nrows), range(ncols)), indices): + p[row, col] = PolyDict({ + # Create monomial with variable to the first power + # with coefficient of one + MonomialIndex((VariableIndex(index, power=1),)): 1. + }) + + return state, init_poly_matrix(p, shape=(nrows, ncols)) diff --git a/polymatrix/expression/mixins/variablemixin.py b/polymatrix/expression/mixins/variablemixin.py deleted file mode 100644 index 355b6b1..0000000 --- a/polymatrix/expression/mixins/variablemixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import itertools -from abc import abstractmethod -from dataclasses import replace -from typing_extensions import override - -from polymatrix.expressionstate import ExpressionState -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.polymatrix.mixins import PolyMatrixMixin -from polymatrix.polymatrix.index import PolyMatrixDict, PolyDict, MonomialIndex, VariableIndex -from polymatrix.variable import Variable - - -class VariableMixin(ExpressionBaseMixin, Variable): - """ Underlying object for VariableExpression """ - - @override - @property - @abstractmethod - def shape(self) -> tuple[int, int] | ExpressionBaseMixin: - """ Shape of the variable expression. """ - - @override - def apply(self, state: ExpressionState) -> tuple[ExpressionState, PolyMatrixMixin]: - if isinstance(self.shape, ExpressionBaseMixin): - state, shape_pm = self.shape.apply(state) - if shape_pm.shape != (2, 1): - raise ValueError("If shape is an expression it must evaluate to a 2d row vector, " - f"but here it has shape {shape_pm.shape}") - - # FIXME: should check that they are actually integers - nrows = int(shape_pm.at(0, 0).constant()) - ncols = int(shape_pm.at(1, 0).constant()) - - # Replace shape field with computed shape - v = replace(self, shape=(nrows, ncols)) - state = state.register(v) - indices = state.get_indices(v) - - elif isinstance(self.shape, tuple): - nrows, ncols = self.shape - state = state.register(self) - indices = state.get_indices(self) - - else: - raise ValueError("Shape must be a tuple or expression that " - f"evaluates to a 2d row vector, cannot be of type {type(self.shape)}") - - p = PolyMatrixDict() - for (row, col), index in zip(itertools.product(range(nrows), range(ncols)), indices): - p[row, col] = PolyDict({ - # Create monomial with variable to the first power - # with coefficient of one - MonomialIndex((VariableIndex(index, power=1),)): 1. - }) - - return state, init_poly_matrix(p, shape=(nrows, ncols)) diff --git a/polymatrix/expression/to.py b/polymatrix/expression/to.py index 7f9bac7..c7f9af4 100644 --- a/polymatrix/expression/to.py +++ b/polymatrix/expression/to.py @@ -54,7 +54,7 @@ def to_sympy( sympy_poly_terms = [] for monomial, coeff in poly.terms(): sympy_monomial = math.prod( - sympy.Symbol(state.get_name(variable.index)) ** variable.power + sympy.Symbol(state.get_symbol(variable.index)) ** variable.power for variable in monomial) if math.isclose(coeff, 1.): diff --git a/polymatrix/expressionstate.py b/polymatrix/expressionstate.py index e330337..562f07a 100644 --- a/polymatrix/expressionstate.py +++ b/polymatrix/expressionstate.py @@ -1,12 +1,13 @@ from __future__ import annotations -from abc import abstractmethod +import math + from typing import Any, NamedTuple, Iterable from math import prod from dataclassabc import dataclassabc from dataclasses import replace -from polymatrix.variable import Variable +from polymatrix.symbol import Symbol from polymatrix.utils.deprecation import deprecated from polymatrix.polymatrix.index import MonomialIndex, VariableIndex @@ -15,7 +16,17 @@ from polymatrix.statemonad import StateCacheMixin # TODO: move to typing submodule class IndexRange(NamedTuple): start: int - end: int + """ Start of the indices """ + stop: int + """ End of the indices, this value is not included """ + + ncols: int + """ + Number of columns of the indexed symbol. + + This information is kept here because if a symbol represents a matrix, the + shape of the matrix is lost. + """ def __lt__(self, other): return self.start < other.start @@ -24,75 +35,86 @@ class IndexRange(NamedTuple): @dataclassabc(frozen=True) class ExpressionState(StateCacheMixin): n_variables: int - """ Number of polynomial variables """ + """ Number of polynomial variables. + + What is the difference between variables and symbols? + Suppose M is a 2x2 matrix, then "M" is a symbol, however being a matrix it + contains 4 variables. The symbol with its shape is a variable. + """ - indices: dict[Variable, IndexRange] - """ Map from variable objects to their indices. """ + indices: dict[Symbol, IndexRange] + """ Map from symbols representing variables to their indices. """ cache: dict """ Cache for StateCacheMixin """ - def index(self, var: Variable) -> tuple[ExpressionState, IndexRange]: + # --- indexing --- + + def index(self, sym: Symbol, shape: tuple[int, int]) -> tuple[ExpressionState, IndexRange]: """ Index a variable and get its index range. """ - if not isinstance(var, Variable): - raise ValueError("State can only index object of type Variable!") + if not isinstance(sym, Symbol): + raise ValueError("State can only index symbols!") - for v, irange in self.indices.items(): + for s, irange in self.indices.items(): # Check if already in there - if v == var: - return self, irange - - # Check that there is not another variable with the same name - if v.name == var.name: - raise ValueError("Variable must have unique names! " - f"There is already a variable named {var.name} " - f"with shape {v.shape}") + if s == sym: + if irange.ncols == shape[1]: + return self, irange + else: + nrows = (irange.end - irange.start) // irange.ncols + raise ValueError(f"Symbols must be unique names! Cannot index symbol " + f"{sym} with shape {shape} because there is already a symbol " + f"with the same name with shape {(nrows, irange.ncols)}") # If not save new index - size = prod(var.shape) - index = IndexRange(start=self.n_variables, end=self.n_variables + size) + size = prod(shape) + index = IndexRange(start=self.n_variables, + stop=self.n_variables + size, + ncols=shape[1]) return replace( self, n_variables=self.n_variables + size, - indices=self.indices | {var: index} + indices=self.indices | {sym: index} ), index - def register(self, var: Variable) -> ExpressionStateMixin: + def register(self, sym: Symbol, shape: tuple[int, int]) -> ExpressionState: """ Create an index for a variable, but does not return the index. If you want the index range use :py:meth:`index` """ - state, _ = self.index(var) + state, _ = self.index(sym, shape) return state - def get_indices(self, var: Variable) -> Iterable[int]: + # --- retrieval of indices --- + + def get_indices(self, sym: Symbol) -> Iterable[int]: """ - Get all indices associated to a variable. + Get all indices associated to a symbol. - When a variable is not a scalar multiple indices will be associated to - the variables, one for each entry. + When a symbol is not a scalar multiple indices will be associated to + the symbol, one for each entry. - See also :py:meth:`get_variable_indices`, :py:meth:`get_monomial_indices`. + See also :py:meth:`get_symbol_indices`, :py:meth:`get_monomial_indices`. """ - if var not in self.indices: - raise IndexError(f"There is no variable {var} in this state object.") + if sym not in self.indices: + raise IndexError(f"There is no symbol {sym} in this state object.") - yield from range(*self.indices[var]) + yield from range(self.indices[sym].start, self.indices[sym].stop) - def get_indices_as_variable_index(self, var: Variable) -> Iterable[VariableIndex]: + def get_indices_as_variable_index(self, sym: Symbol) -> Iterable[VariableIndex]: """ - Get all indices associated to a variable, wrapped in a `VariableIndex`. + Get all indices associated to a symbol, wrapped in a `VariableIndex`. See also :py:meth:`get_indices`, :py:class:`polymatrix.polymatrix.index.VariableIndex`. """ yield from (VariableIndex(index=i, power=1) - for i in self.get_indices(var)) + for i in self.get_indices(sym)) - def get_indices_as_monomial_index(self, var: Variable) -> Iterable[MonomialIndex]: + def get_indices_as_monomial_index(self, var: Symbol) -> Iterable[MonomialIndex]: """ - Get all indices associated to a variable, wrapped in a `MonomialIndex`. + Get all indices associated to a symbol, wrapped in a `MonomialIndex`. See also :py:meth:`get_indices`, :py:class:`polymatrix.polymatrix.index.MonomialIndex`. @@ -100,56 +122,62 @@ class ExpressionState(StateCacheMixin): yield from (MonomialIndex((v,)) for v in self.get_indices_as_variable_index(var)) - def get_variable(self, index: int) -> Variable: - """ Get the variable object from its index. """ - for variable, (start, end) in self.indices.items(): - if start <= index < end: - return variable + # --- retrieval of shapes --- + + def get_shape(self, sym: Symbol) -> tuple[int, int]: + if sym not in self.indices: + raise IndexError(f"There is no symbol {sym} in this state object.") - raise IndexError(f"There is no variable with index {index}.") + idx = self.indices[sym] + nrows = (idx.stop - idx.start) / idx.ncols + # FIXME: error message + assert nrows > 0 and math.isclose(int(nrows), nrows), ( + "State has inconsistent indices, this is an internal " + "problem. Something went wrong.") + return (int(nrows), idx.ncols) - def get_variable_from_variable_index(self, var: VariableIndex) -> Variable: - """ Get the variable object from the index contained in a `VariableIndex` """ - return self.get_variable(var.index) + # --- retrieval of symbols --- - def get_variables_from_monomial_index(self, monomial: MonomialIndex) -> Iterable[Variable]: - """ Get all variable objects from the indices contained in a `MonomialIndex` """ + def get_symbol(self, index: int) -> Symbol: + """ Get the symbol that contains the the given index. """ + for symbol, (start, stop, _) in self.indices.items(): + if start <= index < stop: + return symbol + + raise IndexError(f"There is no symbol with index {index}.") + + def get_symbol_from_variable_index(self, var: VariableIndex) -> Symbol: + """ Get the symbol that contains the index inside of a `VariableIndex` """ + return self.get_symbol(var.index) + + def get_symbols_from_monomial_index(self, monomial: MonomialIndex) -> set[Symbol]: + """ Get all symbols that contain the indices inside of a `MonomialIndex` """ + symbols = set() for v in monomial: - # FIXME: non-scalar variable will be yielded multiple times - yield self.get_variable_from_variable_index(v) + symbols.add(self.get_symbol_from_variable_index(v)) + + return symbols - def get_variable_from_name_or(self, name: str, if_not_present: Any) -> Variable | Any: + def get_symbol_from_name_or(self, name: str | Symbol, if_not_present: Any) -> Symbol | Any: """ - Get a variable object given its name, or if there is no variable with - the given name return what is passed in the `if_not_present` argument. + Get a symbol given its name, or if there is no symbol with the given + name return what is passed in the `if_not_present` argument. """ - for v in self.indices.keys(): - if v.name == name: - return v + for s in self.indices.keys(): + if s == name: + return s return if_not_present - def get_variable_from_name(self, name: str) -> Variable: + def get_symbol_from_name(self, name: str) -> Symbol: """ - Get a variable object given its name, raises KeyError if there is no - variable with the given name. + Get a symbol given its name, raises KeyError if there is no symbol with + the given name. """ - if v := self.get_variable_from_name_or(name, False): + if v := self.get_symbol_from_name_or(name, False): return v - raise KeyError(f"There is no variable named {name}") - - def get_name(self, index: int) -> str: - """ Get the name of a variable given its index. """ - for variable, (start, end) in self.indices.items(): - if start <= index < end: - # Variable is not scalar - if end - start > 1: - return f"{variable.name}_{index - start}" - - return variable.name - - raise IndexError(f"There is no variable with index {index}.") + raise KeyError(f"There is no symbol named {name}") # -- Old API --- @@ -169,11 +197,11 @@ class ExpressionState(StateCacheMixin): return {} @deprecated("replaced by get_variable") - def get_key_from_offset(self, index: int) -> Variable: - return self.get_variable(index) + def get_key_from_offset(self, index: int) -> Symbol: + return self.get_symbol(index) -def init_expression_state(n_variables: int = 0, indices: dict[Variable, IndexRange] = {}): +def init_expression_state(n_variables: int = 0, indices: dict[Symbol, IndexRange] = {}): return ExpressionState( n_variables=n_variables, indices=indices, diff --git a/polymatrix/symbol.py b/polymatrix/symbol.py new file mode 100644 index 0000000..7e53e3a --- /dev/null +++ b/polymatrix/symbol.py @@ -0,0 +1,2 @@ +class Symbol(str): + """ Symbol (name) of a variable. """ diff --git a/polymatrix/variable.py b/polymatrix/variable.py deleted file mode 100644 index e320313..0000000 --- a/polymatrix/variable.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import ABC, abstractmethod -from dataclassabc import dataclassabc - -class Variable(ABC): - @property - @abstractmethod - def name(self) -> str: - """ Name of the variable. """ - - @property - @abstractmethod - def shape(self) -> tuple[int, int]: - """ Shape of the variable. """ - -# FIXME: there is a class with the same name in expression.impl -@dataclassabc(frozen=True) -class VariableImpl(Variable): - name: str - shape: tuple[int, int] - - -def init_variable(name: str, shape: tuple[int, int]) -> Variable: - return VariableImpl(name, shape) -- cgit v1.2.1