summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-01-30 16:19:24 +0100
committerMichael Schneeberger <michael.schneeberger@fhnw.ch>2023-01-30 16:19:24 +0100
commit09175e1f03dc260f743de28d47a16d2f5a97bf38 (patch)
tree09320967339e60467c3f6f2a2fb6d999f112e8b5
parentupdate README (diff)
downloadpolymatrix-09175e1f03dc260f743de28d47a16d2f5a97bf38.tar.gz
polymatrix-09175e1f03dc260f743de28d47a16d2f5a97bf38.zip
bugfix in eval/substitution operator
-rw-r--r--polymatrix/__init__.py52
-rw-r--r--polymatrix/expression/impl/__init__.py0
-rw-r--r--polymatrix/expression/impl/evalexprimpl.py4
-rw-r--r--polymatrix/expression/impl/substituteexprimpl.py2
-rw-r--r--polymatrix/expression/init/initderivativeexpr.py5
-rw-r--r--polymatrix/expression/init/initevalexpr.py182
-rw-r--r--polymatrix/expression/init/initfromsympyexpr.py6
-rw-r--r--polymatrix/expression/init/initlinearinexpr.py2
-rw-r--r--polymatrix/expression/init/initlinearmonomialsexpr.py7
-rw-r--r--polymatrix/expression/init/initquadraticinexpr.py7
-rw-r--r--polymatrix/expression/init/initquadraticmonomialsexpr.py7
-rw-r--r--polymatrix/expression/init/initsubstituteexpr.py107
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/divergenceexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py60
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py70
-rw-r--r--polymatrix/expression/mixins/filterexprmixin.py6
-rw-r--r--polymatrix/expression/mixins/linearinexprmixin.py26
-rw-r--r--polymatrix/expression/mixins/linearmatrixinexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/linearmonomialsexprmixin.py23
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py4
-rw-r--r--polymatrix/expression/mixins/quadraticmonomialsexprmixin.py23
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py48
-rw-r--r--polymatrix/expression/mixins/truncateexprmixin.py4
-rw-r--r--polymatrix/expression/utils/getvariableindices.py108
-rw-r--r--polymatrix/expression/utils/mergemonomialindices.py5
-rw-r--r--polymatrix/expressionstate/mixins/expressionstatemixin.py9
27 files changed, 559 insertions, 222 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 0c15ae3..e6cf4a8 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -17,7 +17,7 @@ from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr
from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
from polymatrix.polymatrix.polymatrix import PolyMatrix
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices, get_variable_indices_from_variable
from polymatrix.statemonad.init.initstatemonad import init_state_monad
from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
from polymatrix.expression.utils.monomialtoindex import monomial_to_index
@@ -363,7 +363,7 @@ class MatrixRepresentations:
return dict(gen_matrices())
def get_value(self, variable, value):
- variable_indices = get_variable_indices(self.state, variable)[1]
+ variable_indices = get_variable_indices_from_variable(self.state, variable)[1]
def gen_value_index():
for variable_index in variable_indices:
@@ -377,7 +377,7 @@ class MatrixRepresentations:
return value[value_index]
def set_value(self, variable, value):
- variable_indices = get_variable_indices(self.state, variable)[1]
+ variable_indices = get_variable_indices_from_variable(self.state, variable)[1]
value_index = list(self.variable_mapping.index(variable_index) for variable_index in variable_indices)
vec = np.zeros(len(self.variable_mapping))
vec[value_index] = value
@@ -430,13 +430,15 @@ class MatrixRepresentations:
return func
def to_matrix_repr(
- expressions: tuple[Expression],
+ expressions: Expression | tuple[Expression],
variables: Expression,
) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
if isinstance(expressions, Expression):
expressions = (expressions,)
+ assert isinstance(variables, Expression), f'{variables=}'
+
def func(state: ExpressionState):
def acc_underlying_application(acc, v):
@@ -454,39 +456,7 @@ def to_matrix_repr(
initial=(state, tuple()),
))
- # if variables is None:
-
- # def gen_used_variables():
- # def gen_used_auxillary_variables(considered):
- # monomial_terms = state.auxillary_equations[considered[-1]]
- # for monomial in monomial_terms.keys():
- # for variable in monomial:
- # yield variable
-
- # if variable not in considered and variable in state.auxillary_equations:
- # yield from gen_used_auxillary_variables(considered + (variable,))
-
- # for underlying in underlying_list:
- # for row in range(underlying.shape[0]):
- # for col in range(underlying.shape[1]):
-
- # try:
- # underlying_terms = underlying.get_poly(row, col)
- # except KeyError:
- # continue
-
- # for monomial in underlying_terms.keys():
- # for variable, _ in monomial:
- # yield variable
-
- # if variable in state.auxillary_equations:
- # yield from gen_used_auxillary_variables((variable,))
-
- # ordered_variable_index = tuple(sorted(set(gen_used_variables())))
-
- # else:
-
- state, ordered_variable_index = get_variable_indices(state, variables)
+ state, ordered_variable_index = get_variable_indices_from_variable(state, variables)
assert len(ordered_variable_index) == len(set(ordered_variable_index)), f'{ordered_variable_index=} contains repeated variables'
@@ -534,8 +504,6 @@ def to_matrix_repr(
underlying_matrices = tuple(gen_underlying_matrices())
- # current_row = underlying.shape[0]
-
def gen_auxillary_equations():
for key, monomial_terms in state.auxillary_equations.items():
if key in ordered_variable_index:
@@ -581,6 +549,7 @@ def to_matrix_repr(
def to_constant_repr(
expr: Expression,
+ assert_constant: bool = True,
) -> StateMonadMixin[ExpressionState, np.ndarray]:
def func(state: ExpressionState):
@@ -592,6 +561,9 @@ def to_constant_repr(
for monomial, value in polynomial.items():
if len(monomial) == 0:
A[row, col] = value
+
+ elif assert_constant:
+ raise Exception(f'non-constant term {monomial=}')
return state, A
@@ -605,7 +577,7 @@ def degrees(
def func(state: ExpressionState):
state, underlying = expr.apply(state)
- state, variable_indices = get_variable_indices(state, variables)
+ state, variable_indices = get_variable_indices_from_variable(state, variables)
def gen_rows():
for row in range(underlying.shape[0]):
diff --git a/polymatrix/expression/impl/__init__.py b/polymatrix/expression/impl/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/polymatrix/expression/impl/__init__.py
diff --git a/polymatrix/expression/impl/evalexprimpl.py b/polymatrix/expression/impl/evalexprimpl.py
index cd97155..7f13a78 100644
--- a/polymatrix/expression/impl/evalexprimpl.py
+++ b/polymatrix/expression/impl/evalexprimpl.py
@@ -6,5 +6,5 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@dataclass_abc.dataclass_abc(frozen=True)
class EvalExprImpl(EvalExpr):
underlying: ExpressionBaseMixin
- variables: tuple
- values: tuple
+ substitutions: tuple
+ # values: tuple
diff --git a/polymatrix/expression/impl/substituteexprimpl.py b/polymatrix/expression/impl/substituteexprimpl.py
index 1eae940..086f17e 100644
--- a/polymatrix/expression/impl/substituteexprimpl.py
+++ b/polymatrix/expression/impl/substituteexprimpl.py
@@ -6,5 +6,5 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@dataclass_abc.dataclass_abc(frozen=True)
class SubstituteExprImpl(SubstituteExpr):
underlying: ExpressionBaseMixin
- variables: tuple
+ # variables: tuple
substitutions: tuple
diff --git a/polymatrix/expression/init/initderivativeexpr.py b/polymatrix/expression/init/initderivativeexpr.py
index c640f47..a6ca06c 100644
--- a/polymatrix/expression/init/initderivativeexpr.py
+++ b/polymatrix/expression/init/initderivativeexpr.py
@@ -4,9 +4,12 @@ from polymatrix.expression.impl.derivativeexprimpl import DerivativeExprImpl
def init_derivative_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
introduce_derivatives: bool = None,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
if introduce_derivatives is None:
introduce_derivatives = False
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py
index 5bbd404..0fad8e5 100644
--- a/polymatrix/expression/init/initevalexpr.py
+++ b/polymatrix/expression/init/initevalexpr.py
@@ -1,55 +1,171 @@
+import typing
import numpy as np
+from polymatrix.expression.init.initsubstituteexpr import format_substitutions
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.evalexprimpl import EvalExprImpl
def init_eval_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
- values: tuple = None,
+ variables: typing.Union[typing.Any, tuple, dict],
+ values: typing.Union[float, tuple] = None,
):
- if values is None:
- if isinstance(variables, tuple):
- variables, values = tuple(zip(*variables))
+ substitutions = format_substitutions(
+ variables=variables,
+ values=values,
+ )
+
+ def formatted_values(value):
+ if isinstance(value, np.ndarray):
+ return tuple(value.reshape(-1))
+
+ elif isinstance(value, tuple):
+ return value
- elif isinstance(variables, dict):
- variables, values = tuple(zip(*variables.items()))
+ elif isinstance(value, int) or isinstance(value, float):
+ return (value,)
else:
- raise Exception(f'unsupported case {variables=}')
+ return (float(value),)
+
+ substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions)
+
+ return EvalExprImpl(
+ underlying=underlying,
+ substitutions=substitutions,
+ )
- elif isinstance(values, np.ndarray):
- values = tuple(values.reshape(-1))
+ # if values is not None:
+ # if isinstance(variables, tuple):
+ # if isinstance(values, tuple):
+ # assert len(variables) == len(values), f'{variables=}, {values=}'
- elif not isinstance(values, tuple):
- values = (values,)
+ # else:
+ # values = tuple(values for _ in variables)
- if not isinstance(variables, tuple):
- variables = (variables,)
+ # else:
+ # variables = (variables,)
+ # values = (values,)
- def gen_formatted_values():
- for value in values:
- if isinstance(value, np.ndarray):
- yield from value.reshape(-1)
+ # subs = zip(variables, values)
+
+ # elif isinstance(variables, dict):
+ # subs = variables.items()
- elif isinstance(value, tuple):
- yield from value
+ # elif isinstance(variables, tuple):
+ # subs = variables
- elif isinstance(value, dict):
- for variable in variables:
- yield from value[variable]
+ # else:
+ # raise Exception(f'{variables=}')
- elif isinstance(value, int) or isinstance(value, float):
- yield value
+ # def formatted_values(value):
+ # if isinstance(value, np.ndarray):
+ # return tuple(value.reshape(-1))
- else:
- yield float(value)
+ # elif isinstance(value, tuple):
+ # return value
- values = tuple(gen_formatted_values())
+ # elif isinstance(value, int) or isinstance(value, float):
+ # return (value,)
- return EvalExprImpl(
- underlying=underlying,
- variables=variables,
- values=values,
-)
+ # else:
+ # return (float(value),)
+
+ # subs = tuple((var, formatted_values(val)) for var, val in subs)
+
+ # def formatted_values(value):
+ # # def gen_formatted_values():
+ # # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield tuple(value.reshape(-1))
+
+ # elif isinstance(value, tuple):
+ # yield value
+
+ # # elif isinstance(value, dict):
+ # # for variable in variables:
+ # # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield (value,)
+
+ # else:
+ # yield (float(value),)
+ # return tuple(gen_formatted_values())
+
+
+ # if values is None:
+ # if isinstance(variables, tuple):
+ # variables, values = tuple(zip(*variables))
+
+ # elif isinstance(variables, dict):
+ # variables, values = tuple(zip(*variables.items()))
+
+ # else:
+ # raise Exception(f'unsupported case {variables=}')
+
+ # elif isinstance(values, np.ndarray):
+ # values = tuple(values.reshape(-1))
+
+ # elif not isinstance(values, tuple):
+ # values = (values,)
+
+ # if not isinstance(variables, tuple):
+ # variables = (variables,)
+
+ # def gen_formatted_values():
+ # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield tuple(value.reshape(-1))
+
+ # elif isinstance(value, tuple):
+ # yield value
+
+ # elif isinstance(value, dict):
+ # raise Exception('is this right?')
+
+ # for variable in variables:
+ # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield (value,)
+
+ # else:
+ # yield (float(value),)
+
+ # values = tuple(gen_formatted_values())
+
+ # if len(values) == 1:
+ # values = tuple((values[0],) for _ in variables)
+
+ # else:
+ # assert len(variables) == len(values), f'length of {variables} does not match length of {values}'
+
+ # def gen_flattened_values():
+ # for value in values:
+ # if isinstance(value, np.ndarray):
+ # yield from value.reshape(-1)
+
+ # elif isinstance(value, tuple):
+ # yield from value
+
+ # elif isinstance(value, dict):
+ # raise Exception('is this right?')
+
+ # for variable in variables:
+ # yield from value[variable]
+
+ # elif isinstance(value, int) or isinstance(value, float):
+ # yield value
+
+ # else:
+ # yield float(value)
+
+ # values = tuple(gen_flattened_values())
+
+# return EvalExprImpl(
+# underlying=underlying,
+# variables=variables,
+# values=values,
+# )
diff --git a/polymatrix/expression/init/initfromsympyexpr.py b/polymatrix/expression/init/initfromsympyexpr.py
index bb37f1d..3fb52f7 100644
--- a/polymatrix/expression/init/initfromsympyexpr.py
+++ b/polymatrix/expression/init/initfromsympyexpr.py
@@ -36,7 +36,13 @@ def init_from_sympy_expr(
case _:
data = tuple((e,) for e in data)
+ case np.number:
+ data = ((float(data),),)
+
case _:
+ if not isinstance(data, (float, int, sympy.Expr)):
+ raise Exception(f'{data=}, {type(data)=}')
+
data = ((data,),)
return FromSympyExprImpl(
diff --git a/polymatrix/expression/init/initlinearinexpr.py b/polymatrix/expression/init/initlinearinexpr.py
index 5cd172c..b869aee 100644
--- a/polymatrix/expression/init/initlinearinexpr.py
+++ b/polymatrix/expression/init/initlinearinexpr.py
@@ -8,6 +8,8 @@ def init_linear_in_expr(
variables: ExpressionBaseMixin,
ignore_unmatched: bool = None,
):
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return LinearInExprImpl(
underlying=underlying,
monomials=monomials,
diff --git a/polymatrix/expression/init/initlinearmonomialsexpr.py b/polymatrix/expression/init/initlinearmonomialsexpr.py
index f116562..8083715 100644
--- a/polymatrix/expression/init/initlinearmonomialsexpr.py
+++ b/polymatrix/expression/init/initlinearmonomialsexpr.py
@@ -4,9 +4,12 @@ from polymatrix.expression.impl.linearmonomialsexprimpl import LinearMonomialsEx
def init_linear_monomials_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return LinearMonomialsExprImpl(
underlying=underlying,
variables=variables,
-)
+ )
diff --git a/polymatrix/expression/init/initquadraticinexpr.py b/polymatrix/expression/init/initquadraticinexpr.py
index 5aa40a5..6555b4b 100644
--- a/polymatrix/expression/init/initquadraticinexpr.py
+++ b/polymatrix/expression/init/initquadraticinexpr.py
@@ -5,10 +5,13 @@ from polymatrix.expression.impl.quadraticinexprimpl import QuadraticInExprImpl
def init_quadratic_in_expr(
underlying: ExpressionBaseMixin,
monomials: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return QuadraticInExprImpl(
underlying=underlying,
monomials=monomials,
variables=variables,
-)
+ )
diff --git a/polymatrix/expression/init/initquadraticmonomialsexpr.py b/polymatrix/expression/init/initquadraticmonomialsexpr.py
index 8e46c62..190f7df 100644
--- a/polymatrix/expression/init/initquadraticmonomialsexpr.py
+++ b/polymatrix/expression/init/initquadraticmonomialsexpr.py
@@ -4,9 +4,12 @@ from polymatrix.expression.impl.quadraticmonomialsexprimpl import QuadraticMonom
def init_quadratic_monomials_expr(
underlying: ExpressionBaseMixin,
- variables: tuple,
+ variables: ExpressionBaseMixin,
):
+
+ assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+
return QuadraticMonomialsExprImpl(
underlying=underlying,
variables=variables,
-)
+ )
diff --git a/polymatrix/expression/init/initsubstituteexpr.py b/polymatrix/expression/init/initsubstituteexpr.py
index 403b169..50cbee0 100644
--- a/polymatrix/expression/init/initsubstituteexpr.py
+++ b/polymatrix/expression/init/initsubstituteexpr.py
@@ -1,3 +1,4 @@
+import typing
import numpy as np
from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr
@@ -5,35 +6,97 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expression.impl.substituteexprimpl import SubstituteExprImpl
+def format_substitutions(
+ variables: typing.Union[typing.Any, tuple, dict],
+ values: typing.Union[float, tuple] = None,
+):
+ """
+ (variables = x, values = 1.0) # ok
+ (variables = x, values = np.array(1.0)) # ok
+ (variables = (x, y, z), values = 1.0) # ok
+ (variables = (x, y, z), values = (1.0, 2.0, 3.0)) # ok
+ (variables = {x: 1.0, y: 2.0, z: 3.0}) # ok
+ (variables = ((x, 1.0), (y, 2.0), (z, 3.0))) # ok
+
+ (variables = v, values = (1.0, 2.0)) # ok
+ (variables = (v1, v2), values = ((1.0, 2.0), (3.0,))) # ok
+ (variables = (v1, v2), values = (1.0, 2.0, 3.0)) # not ok
+ """
+
+ if values is not None:
+ if isinstance(variables, tuple):
+ if isinstance(values, tuple):
+ assert len(variables) == len(values), f'{variables=}, {values=}'
+
+ else:
+ values = tuple(values for _ in variables)
+
+ else:
+ variables = (variables,)
+ values = (values,)
+
+ substitutions = zip(variables, values)
+
+ elif isinstance(variables, dict):
+ substitutions = variables.items()
+
+ elif isinstance(variables, tuple):
+ substitutions = variables
+
+ else:
+ raise Exception(f'{variables=}')
+
+ return substitutions
+
+
def init_substitute_expr(
underlying: ExpressionBaseMixin,
variables: tuple,
- substitutions: tuple = None,
+ values: tuple = None,
):
- if substitutions is None:
- assert isinstance(variables, tuple)
-
- if len(variables) == 0:
- return underlying
-
- variables, substitutions = tuple(zip(*variables))
- elif isinstance(substitutions, np.ndarray):
- substitutions = tuple(substitutions.reshape(-1))
+ substitutions = format_substitutions(
+ variables=variables,
+ values=values,
+ )
- elif not isinstance(substitutions, tuple):
- substitutions = (substitutions,)
+ def formatted_values(value) -> ExpressionBaseMixin:
+ if isinstance(value, ExpressionBaseMixin):
+ return value
+ else:
+ return init_from_sympy_expr(value)
- def gen_substitutions():
- for substitution in substitutions:
- match substitution:
- case ExpressionBaseMixin():
- yield substitution
- case _:
- yield init_from_sympy_expr(substitution)
+ substitutions = tuple((variable, formatted_values(value)) for variable, value in substitutions)
return SubstituteExprImpl(
underlying=underlying,
- variables=variables,
- substitutions=tuple(gen_substitutions()),
-)
+ substitutions=substitutions,
+ )
+
+ # if values is None:
+ # assert isinstance(variables, tuple)
+
+ # if len(variables) == 0:
+ # return underlying
+
+ # variables, values = tuple(zip(*variables))
+
+ # elif isinstance(values, np.ndarray):
+ # values = tuple(values.reshape(-1))
+
+ # elif not isinstance(values, tuple):
+ # values = (values,)
+
+ # def gen_substitutions():
+ # for substitution in values:
+ # match substitution:
+ # case ExpressionBaseMixin():
+ # yield substitution
+ # case _:
+ # yield init_from_sympy_expr(substitution)
+
+ # return SubstituteExprImpl(
+ # underlying=underlying,
+ # variables=variables,
+ # substitutions=tuple(gen_substitutions()),
+ # )
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 5fce215..183608b 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
class DerivativeExprMixin(ExpressionBaseMixin):
@@ -18,7 +18,7 @@ class DerivativeExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> typing.Union[tuple, ExpressionBaseMixin]:
+ def variables(self) -> ExpressionBaseMixin:
...
@property
@@ -33,7 +33,7 @@ class DerivativeExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, diff_wrt_variables = get_variable_indices(state, self.variables)
+ state, diff_wrt_variables = get_variable_indices_from_variable(state, self.variables)
assert underlying.shape[1] == 1, f'{underlying.shape=}'
diff --git a/polymatrix/expression/mixins/divergenceexprmixin.py b/polymatrix/expression/mixins/divergenceexprmixin.py
index 3002109..8ada607 100644
--- a/polymatrix/expression/mixins/divergenceexprmixin.py
+++ b/polymatrix/expression/mixins/divergenceexprmixin.py
@@ -8,7 +8,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getderivativemonomials import get_derivative_monomials
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
class DivergenceExprMixin(ExpressionBaseMixin):
@@ -29,7 +29,7 @@ class DivergenceExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variables = get_variable_indices(state, self.variables)
+ state, variables = get_variable_indices_from_variable(state, self.variables)
assert underlying.shape[1] == 1, f'{underlying.shape=}'
assert len(variables) == underlying.shape[0], f'{variables=}, {underlying.shape=}'
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index 2358566..cbf4ce8 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
class EvalExprMixin(ExpressionBaseMixin):
@@ -18,12 +18,7 @@ class EvalExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
- ...
-
- @property
- @abc.abstractmethod
- def values(self) -> tuple[float, ...]:
+ def substitutions(self) -> tuple:
...
# overwrites abstract method of `ExpressionBaseMixin`
@@ -31,17 +26,32 @@ class EvalExprMixin(ExpressionBaseMixin):
self,
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
+
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(state, self.variables)
- if len(self.values) == 1:
- values = tuple(self.values[0] for _ in variable_indices)
+ def acc_variable_indices_and_values(acc, next):
+ state, acc_indices, acc_values = acc
+ variable, values = next
- else:
- assert len(variable_indices) == len(self.values), f'length of {variable_indices} does not match length of {self.values}'
+ state, indices = get_variable_indices_from_variable(state, variable)
- values = self.values
+ if indices is None:
+ return acc
+ if len(values) == 1:
+ values = tuple(values[0] for _ in indices)
+
+ else:
+ assert len(indices) == len(values), f'{variable=}, {indices=} ({len(indices)}), {values=} ({len(values)})'
+
+ return state, acc_indices + indices, acc_values + values
+
+ *_, (state, variable_indices, values) = itertools.accumulate(
+ self.substitutions,
+ acc_variable_indices_and_values,
+ initial=(state, tuple(), tuple())
+ )
+
terms = {}
for row in range(underlying.shape[0]):
@@ -91,3 +101,27 @@ class EvalExprMixin(ExpressionBaseMixin):
)
return state, poly_matrix
+
+
+ # if len(self.values) == 1:
+ # values = tuple(self.values[0] for _ in self.variables)
+
+ # else:
+ # values = self.values
+
+ # def filter_valid_variables():
+ # for var, val in zip(self.variables, self.values):
+ # if isinstance(var, ExpressionBaseMixin) or isinstance(var, int) or (var in state.offset_dict):
+ # yield var, val
+
+ # variables, values = zip(*filter_valid_variables())
+
+ # state, variable_indices = get_variable_indices(state, self.variables)
+
+ # if len(self.values) == 1:
+ # values = tuple(self.values[0] for _ in variable_indices)
+
+ # else:
+ # assert len(variable_indices) == len(self.values), f'length of {variable_indices} does not match length of {self.values}'
+
+ # values = self.values
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index 5602294..b828d6a 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -2,6 +2,7 @@ import abc
import dataclasses
import typing
import numpy as np
+import sympy
from polymatrix.expression.init.initadditionexpr import init_addition_expr
from polymatrix.expression.init.initcacheexpr import init_cache_expr
@@ -27,6 +28,7 @@ from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr
from polymatrix.expression.init.initreshapeexpr import init_reshape_expr
from polymatrix.expression.init.initsetelementatexpr import init_set_element_at_expr
from polymatrix.expression.init.initquadraticmonomialsexpr import init_quadratic_monomials_expr
+from polymatrix.expression.init.initsqueezeexpr import init_squeeze_expr
from polymatrix.expression.init.initsubstituteexpr import init_substitute_expr
from polymatrix.expression.init.initsubtractmonomialsexpr import init_subtract_monomials_expr
from polymatrix.expression.init.initsumexpr import init_sum_expr
@@ -60,11 +62,7 @@ class ExpressionMixin(
if other is None:
return self
- match other:
- case ExpressionBaseMixin():
- right = other
- case _:
- right = init_from_sympy_expr(other)
+ right = self._convert_to_expression(other)
return dataclasses.replace(
self,
@@ -96,11 +94,7 @@ class ExpressionMixin(
)
def __matmul__(self, other: typing.Union[ExpressionBaseMixin, np.ndarray]) -> 'ExpressionMixin':
- match other:
- case ExpressionBaseMixin():
- right = other
- case _:
- right = init_from_sympy_expr(other)
+ right = self._convert_to_expression(other)
return dataclasses.replace(
self,
@@ -111,13 +105,7 @@ class ExpressionMixin(
)
def __mul__(self, other) -> 'ExpressionMixin':
- # assert isinstance(other, float)
-
- match other:
- case ExpressionBaseMixin():
- right = other
- case _:
- right = init_from_sympy_expr(other)
+ right = self._convert_to_expression(other)
return dataclasses.replace(
self,
@@ -127,6 +115,14 @@ class ExpressionMixin(
),
)
+ def __pow__(self, num):
+ curr = 1
+
+ for _ in range(num):
+ curr = curr * self
+
+ return curr
+
def __neg__(self):
return self * (-1)
@@ -134,9 +130,9 @@ class ExpressionMixin(
return self + other
def __rmatmul__(self, other):
- other = init_from_sympy_expr(other)
+ left = self._convert_to_expression(other)
- return other @ self
+ return left @ self
def __rmul__(self, other):
return self * other
@@ -145,11 +141,7 @@ class ExpressionMixin(
return self + other * (-1)
def __truediv__(self, other: ExpressionBaseMixin):
- match other:
- case ExpressionBaseMixin():
- right = other
- case _:
- right = init_from_sympy_expr(other)
+ right = self._convert_to_expression(other)
return dataclasses.replace(
self,
@@ -159,6 +151,18 @@ class ExpressionMixin(
),
)
+ def _convert_to_expression(self, other):
+ if isinstance(other, ExpressionBaseMixin):
+ return other
+
+ # can_convert = isinstance(other, (float, int, sympy.Expr, np.ndarray))
+
+ # if not can_convert:
+ # raise Exception(f'{other} cannot be converted to an Expression')
+ # else:
+
+ return init_from_sympy_expr(other)
+
def cache(self) -> 'ExpressionMixin':
return dataclasses.replace(
self,
@@ -371,6 +375,16 @@ class ExpressionMixin(
),
)
+ def squeeze(
+ self,
+ ) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_squeeze_expr(
+ underlying=self.underlying,
+ ),
+ )
+
def subtract_monomials(
self,
monomials: 'ExpressionMixin',
@@ -386,25 +400,25 @@ class ExpressionMixin(
def substitute(
self,
variable: tuple,
- substitutions: tuple['ExpressionMixin', ...] = None,
+ values: tuple['ExpressionMixin', ...] = None,
) -> 'ExpressionMixin':
return dataclasses.replace(
self,
underlying=init_substitute_expr(
underlying=self.underlying,
variables=variable,
- substitutions=substitutions,
+ values=values,
),
)
def subs(
self,
variable: tuple,
- substitutions: tuple['ExpressionMixin', ...] = None,
+ values: tuple['ExpressionMixin', ...] = None,
) -> 'ExpressionMixin':
return self.substitute(
variable=variable,
- substitutions=substitutions,
+ values=values,
)
def sum(self):
diff --git a/polymatrix/expression/mixins/filterexprmixin.py b/polymatrix/expression/mixins/filterexprmixin.py
index 5083073..ec549a8 100644
--- a/polymatrix/expression/mixins/filterexprmixin.py
+++ b/polymatrix/expression/mixins/filterexprmixin.py
@@ -1,9 +1,5 @@
import abc
-import collections
-import math
-import typing
-import dataclass_abc
from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -12,6 +8,7 @@ from polymatrix.expressionstate.expressionstate import ExpressionState
class FilterExprMixin(ExpressionBaseMixin):
+
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -45,6 +42,7 @@ class FilterExprMixin(ExpressionBaseMixin):
for row in range(underlying.shape[0]):
underlying_terms = underlying.get_poly(row, 0)
+
if underlying_terms is None:
continue
diff --git a/polymatrix/expression/mixins/linearinexprmixin.py b/polymatrix/expression/mixins/linearinexprmixin.py
index cd94761..9bcabbe 100644
--- a/polymatrix/expression/mixins/linearinexprmixin.py
+++ b/polymatrix/expression/mixins/linearinexprmixin.py
@@ -7,10 +7,30 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
class LinearInExprMixin(ExpressionBaseMixin):
+ """
+ Maps a polynomial column vector
+
+ underlying = [
+ [1 + a x],
+ [x^2 ],
+ ]
+
+ into a polynomial matrix
+
+ output = [
+ [1, a, 0],
+ [0, 0, 1],
+ ],
+
+ where each column corresponds to a monomial defined by
+
+ monomials = [1, x, x^2].
+ """
+
@property
@abc.abstractmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -23,7 +43,7 @@ class LinearInExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
+ def variables(self) -> ExpressionBaseMixin:
...
@property
@@ -39,7 +59,7 @@ class LinearInExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state=state)
state, monomials = get_monomial_indices(state, self.monomials)
- state, variable_indices = get_variable_indices(state, self.variables)
+ state, variable_indices = get_variable_indices_from_variable(state, self.variables)
assert underlying.shape[1] == 1
diff --git a/polymatrix/expression/mixins/linearmatrixinexprmixin.py b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
index d013722..f5bd18f 100644
--- a/polymatrix/expression/mixins/linearmatrixinexprmixin.py
+++ b/polymatrix/expression/mixins/linearmatrixinexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
class LinearMatrixInExprMixin(ExpressionBaseMixin):
@@ -27,7 +27,7 @@ class LinearMatrixInExprMixin(ExpressionBaseMixin):
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variable_index = get_variable_indices(state, variables=self.variable)
+ state, variable_index = get_variable_indices_from_variable(state, variables=self.variable)
assert len(variable_index) == 1
diff --git a/polymatrix/expression/mixins/linearmonomialsexprmixin.py b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
index b309ccf..350ac6d 100644
--- a/polymatrix/expression/mixins/linearmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/linearmonomialsexprmixin.py
@@ -6,11 +6,28 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin
from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.expression.utils.sortmonomials import sort_monomials
class LinearMonomialsExprMixin(ExpressionBaseMixin):
+ """
+ Maps a polynomial matrix
+
+ underlying = [
+ [1, a x ],
+ [x^2, x + x^2],
+ ]
+
+ into a vector of monomials
+
+ output = [1, x, x^2]
+
+ in variable
+
+ variables = [x].
+ """
+
@property
@abc.abstractclassmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -18,7 +35,7 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
+ def variables(self) -> ExpressionBaseMixin:
...
# overwrites abstract method of `ExpressionBaseMixin`
@@ -28,7 +45,7 @@ class LinearMonomialsExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(state, self.variables)
+ state, variable_indices = get_variable_indices_from_variable(state, self.variables)
def gen_linear_monomials():
for row in range(underlying.shape[0]):
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index 130ab79..5c32f45 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -7,7 +7,7 @@ from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getmonomialindices import get_monomial_indices
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices
@@ -35,7 +35,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
state, underlying = self.underlying.apply(state=state)
state, sos_monomials = get_monomial_indices(state, self.monomials)
- state, variable_indices = get_variable_indices(state, self.variables)
+ state, variable_indices = get_variable_indices_from_variable(state, self.variables)
assert underlying.shape == (1, 1), f'underlying shape is {underlying.shape}'
diff --git a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
index bddc321..81fbcb9 100644
--- a/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticmonomialsexprmixin.py
@@ -6,11 +6,28 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionStateMixin
from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.expression.utils.splitmonomialindices import split_monomial_indices
class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
+ """
+ Maps a polynomial matrix
+
+ underlying = [
+ [x y ],
+ [x + x^2],
+ ]
+
+ into a vector of monomials
+
+ output = [1, x, y]
+
+ in variable
+
+ variables = [x, y].
+ """
+
@property
@abc.abstractclassmethod
def underlying(self) -> ExpressionBaseMixin:
@@ -18,7 +35,7 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
+ def variables(self) -> ExpressionBaseMixin:
...
# overwrites abstract method of `ExpressionBaseMixin`
@@ -28,7 +45,7 @@ class QuadraticMonomialsExprMixin(ExpressionBaseMixin):
) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(state, self.variables)
+ state, variable_indices = get_variable_indices_from_variable(state, self.variables)
def gen_sos_monomials():
for row in range(underlying.shape[0]):
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
index 9741897..6c0beae 100644
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ b/polymatrix/expression/mixins/substituteexprmixin.py
@@ -3,12 +3,13 @@ import abc
import collections
import itertools
import math
+import typing
from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial
@@ -20,12 +21,7 @@ class SubstituteExprMixin(ExpressionBaseMixin):
@property
@abc.abstractmethod
- def variables(self) -> tuple:
- ...
-
- @property
- @abc.abstractmethod
- def substitutions(self) -> tuple[ExpressionBaseMixin, ...]:
+ def substitutions(self) -> tuple[tuple[typing.Any, ExpressionBaseMixin], ...]:
...
# overwrites abstract method of `ExpressionBaseMixin`
@@ -34,33 +30,40 @@ class SubstituteExprMixin(ExpressionBaseMixin):
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(state, self.variables)
- def acc_substitutions(acc, substitution_expr):
- state, result = acc
+ def acc_substitutions(acc, next):
+ state, acc_variable, acc_substitution = acc
+ variable, expr = next
- # for expr in self.expressions:
- if isinstance(substitution_expr, ExpressionBaseMixin):
- state, substitution = substitution_expr.apply(state)
+ state, indices = get_variable_indices_from_variable(state, variable)
- assert substitution.shape == (1, 1), f'{substitution=}'
+ if indices is None:
+ return acc
- polynomial = substitution.get_poly(0, 0)
+ state, substitution = expr.apply(state)
- elif isinstance(substitution_expr, int) or isinstance(substitution_expr, float):
- polynomial = {tuple(): substitution_expr}
+ assert substitution.shape[1] == 1, f'{substitution=}'
- else:
- raise Exception(f'{substitution_expr=} not recognized')
+ def gen_polynomials():
+ for row in range(substitution.shape[0]):
+ yield substitution.get_poly(row, 0)
- return state, result + (polynomial,)
+ polynomials = tuple(gen_polynomials())
- *_, (state, substitutions) = tuple(itertools.accumulate(
+ return state, acc_variable + indices, acc_substitution + polynomials
+
+ *_, (state, variable_indices, substitutions) = tuple(itertools.accumulate(
self.substitutions,
acc_substitutions,
- initial=(state, tuple()),
+ initial=(state, tuple(), tuple()),
))
+ if len(substitutions) == 1:
+ substitutions = tuple(substitutions[0] for _ in variable_indices)
+
+ else:
+ assert len(variable_indices) == len(substitutions), f'{substitutions=}'
+
terms = {}
for row in range(underlying.shape[0]):
@@ -81,7 +84,6 @@ class SubstituteExprMixin(ExpressionBaseMixin):
index = variable_indices.index(variable)
substitution = substitutions[index]
-
for _ in range(count):
next = {}
diff --git a/polymatrix/expression/mixins/truncateexprmixin.py b/polymatrix/expression/mixins/truncateexprmixin.py
index e144ae5..cadc856 100644
--- a/polymatrix/expression/mixins/truncateexprmixin.py
+++ b/polymatrix/expression/mixins/truncateexprmixin.py
@@ -5,7 +5,7 @@ from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
-from polymatrix.expression.utils.getvariableindices import get_variable_indices
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
class TruncateExprMixin(ExpressionBaseMixin):
@@ -35,7 +35,7 @@ class TruncateExprMixin(ExpressionBaseMixin):
state: ExpressionState,
) -> tuple[ExpressionState, PolyMatrix]:
state, underlying = self.underlying.apply(state=state)
- state, variable_indices = get_variable_indices(state, self.variables)
+ state, variable_indices = get_variable_indices_from_variable(state, self.variables)
terms = {}
diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py
index aaf5b0b..371894e 100644
--- a/polymatrix/expression/utils/getvariableindices.py
+++ b/polymatrix/expression/utils/getvariableindices.py
@@ -1,40 +1,104 @@
+import itertools
+import typing
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-def get_variable_indices(state, variables):
+def get_variable_indices_from_variable(state, variable) -> typing.Optional[tuple[int]]:
+
+ if isinstance(variable, ExpressionBaseMixin):
+ state, variable_polynomial = variable.apply(state)
+
+ assert variable_polynomial.shape[1] == 1
+
+ def gen_variables_indices():
+
+ for row in range(variable_polynomial.shape[0]):
+ row_terms = variable_polynomial.get_poly(row, 0)
+
+ assert len(row_terms) == 1, f'{row_terms} contains more than one term'
+
+ for monomial in row_terms.keys():
+ assert len(monomial) <= 1, f'{monomial=} contains more than one variable'
+
+ if len(monomial) == 0:
+ continue
+
+ assert monomial[0][1] == 1, f'{monomial[0]=}'
+ yield monomial[0][0]
+
+ variable_indices = tuple(gen_variables_indices())
+
+ elif isinstance(variable, int):
+ variable_indices = (variable,)
+
+ elif variable in state.offset_dict:
+ variable_indices = (state.offset_dict[variable][0],)
- global_state = [state]
+ else:
+ variable_indices = None
+ # raise Exception(f'variable index not found for {variable=}, {state.offset_dict=}')
+
+ return state, variable_indices
+
+
+def get_variable_indices(state, variables):
if not isinstance(variables, tuple):
variables = (variables,)
- def gen_indices():
- for variable in variables:
- if isinstance(variable, ExpressionBaseMixin):
- global_state[0], variable_polynomial = variable.apply(global_state[0])
+ # assert isinstance(variables, tuple), f'{variables=}'
+
+ def acc_variable_indices(acc, variable):
+ state, indices = acc
+
+ state, new_indices = get_variable_indices_from_variable(state, variable)
+
+ return state, indices + new_indices
+
+ *_, (state, indices) = itertools.accumulate(
+ variables,
+ acc_variable_indices,
+ initial=(state, tuple()),
+ )
+
+ return state, indices
+
+
+ # global_state = [state]
+
+ # assert isinstance(variables, tuple), f'{variables=}'
+
+ # # if not isinstance(variables, tuple):
+ # # variables = (variables,)
+
+ # def gen_indices():
+ # for variable in variables:
+ # if isinstance(variable, ExpressionBaseMixin):
+ # global_state[0], variable_polynomial = variable.apply(global_state[0])
- assert variable_polynomial.shape[1] == 1
+ # assert variable_polynomial.shape[1] == 1
- for row in range(variable_polynomial.shape[0]):
- row_terms = variable_polynomial.get_poly(row, 0)
+ # for row in range(variable_polynomial.shape[0]):
+ # row_terms = variable_polynomial.get_poly(row, 0)
- assert len(row_terms) == 1, f'{row_terms} contains more than one term'
+ # assert len(row_terms) == 1, f'{row_terms} contains more than one term'
- for monomial in row_terms.keys():
- assert len(monomial) <= 1, f'{monomial=} contains more than one variable'
+ # for monomial in row_terms.keys():
+ # assert len(monomial) <= 1, f'{monomial=} contains more than one variable'
- if len(monomial) == 0:
- continue
+ # if len(monomial) == 0:
+ # continue
- assert monomial[0][1] == 1, f'{monomial[0]=}'
- yield monomial[0][0]
+ # assert monomial[0][1] == 1, f'{monomial[0]=}'
+ # yield monomial[0][0]
- elif isinstance(variable, int):
- yield variable
+ # elif isinstance(variable, int):
+ # yield variable
- else:
- yield global_state[0].offset_dict[variable][0]
+ # # else:
+ # elif variable in global_state[0].offset_dict:
+ # yield global_state[0].offset_dict[variable][0]
- indices = tuple(gen_indices())
+ # indices = tuple(gen_indices())
- return global_state[0], indices
+ # return global_state[0], indices
diff --git a/polymatrix/expression/utils/mergemonomialindices.py b/polymatrix/expression/utils/mergemonomialindices.py
index 7f8ba15..c572852 100644
--- a/polymatrix/expression/utils/mergemonomialindices.py
+++ b/polymatrix/expression/utils/mergemonomialindices.py
@@ -19,9 +19,4 @@ def merge_monomial_indices(monomials):
else:
m1_dict[index] = count
- # return tuple(sorted(
- # m1_dict.items(),
- # key=lambda m: m[0],
- # ))
-
return sort_monomial_indices(m1_dict.items())
diff --git a/polymatrix/expressionstate/mixins/expressionstatemixin.py b/polymatrix/expressionstate/mixins/expressionstatemixin.py
index e08e1eb..18dba05 100644
--- a/polymatrix/expressionstate/mixins/expressionstatemixin.py
+++ b/polymatrix/expressionstate/mixins/expressionstatemixin.py
@@ -16,7 +16,7 @@ class ExpressionStateMixin(
@abc.abstractmethod
def n_param(self) -> int:
"""
- current number of parameters used in polynomial matrix expressions
+ number of parameters used in polynomial matrix expressions
"""
...
@@ -24,11 +24,16 @@ class ExpressionStateMixin(
@property
@abc.abstractmethod
def offset_dict(self) -> dict[tuple[typing.Any], tuple[int, int]]:
+ """
+ a variable consists of one or more parameters indexed by a start
+ and an end index
+ """
+
...
@property
@abc.abstractmethod
- def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
+ def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
...
def get_key_from_offset(self, offset: int):