summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--polymatrix/expression/impl.py7
-rw-r--r--polymatrix/expression/init.py32
-rw-r--r--polymatrix/expression/mixins/substituteexprmixin.py120
3 files changed, 0 insertions, 159 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py
index 6c0edc5..e4cae1b 100644
--- a/polymatrix/expression/impl.py
+++ b/polymatrix/expression/impl.py
@@ -46,7 +46,6 @@ from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin
from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin
from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin
from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin
-from polymatrix.expression.mixins.substituteexprmixin import SubstituteExprMixin
from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin
from polymatrix.expression.mixins.sumexprmixin import SumExprMixin
from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin
@@ -362,12 +361,6 @@ class SqueezeExprImpl(SqueezeExprMixin):
@dataclassabc.dataclassabc(frozen=True)
-class SubstituteExprImpl(SubstituteExprMixin):
- underlying: ExpressionBaseMixin
- substitutions: tuple
-
-
-@dataclassabc.dataclassabc(frozen=True)
class SubtractMonomialsExprImpl(SubtractMonomialsExprMixin):
underlying: ExpressionBaseMixin
monomials: ExpressionBaseMixin
diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py
index 32de8c1..5e72524 100644
--- a/polymatrix/expression/init.py
+++ b/polymatrix/expression/init.py
@@ -404,38 +404,6 @@ def init_squeeze_expr(
)
-def init_substitute_expr(
- underlying: ExpressionBaseMixin,
- variables: tuple,
- values: tuple = None,
-):
- substitutions = format_substitutions(
- variables=variables,
- values=values,
- )
-
- def formatted_values(value) -> ExpressionBaseMixin:
- if isinstance(value, ExpressionBaseMixin):
- expr = value
-
- else:
- expr = init_from_expr(value)
-
- return polymatrix.expression.impl.ReshapeExprImpl(
- underlying=expr,
- new_shape=(-1, 1),
- )
-
- substitutions = tuple(
- (variable, formatted_values(value)) for variable, value in substitutions
- )
-
- return polymatrix.expression.impl.SubstituteExprImpl(
- underlying=underlying,
- substitutions=substitutions,
- )
-
-
def init_subtract_monomials_expr(
underlying: ExpressionBaseMixin,
monomials: ExpressionBaseMixin,
diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py
deleted file mode 100644
index a3e76b1..0000000
--- a/polymatrix/expression/mixins/substituteexprmixin.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import abc
-import collections
-import itertools
-import math
-import typing
-
-from polymatrix.polymatrix.init import init_poly_matrix
-from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.polymatrix.abc import PolyMatrix
-from polymatrix.expressionstate import ExpressionState
-from polymatrix.expression.utils.getvariableindices import (
- get_variable_indices_from_variable,
-)
-from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial
-
-
-class SubstituteExprMixin(ExpressionBaseMixin):
- @property
- @abc.abstractmethod
- def underlying(self) -> ExpressionBaseMixin: ...
-
- @property
- @abc.abstractmethod
- def substitutions(self) -> tuple[tuple[typing.Any, ExpressionBaseMixin], ...]: ...
-
- # overwrites the abstract method of `ExpressionBaseMixin`
- def apply(
- self,
- state: ExpressionState,
- ) -> tuple[ExpressionState, PolyMatrix]:
- state, underlying = self.underlying.apply(state=state)
-
- def acc_substitutions(acc, next):
- state, acc_variable, acc_substitution = acc
- variable, expr = next
-
- state, indices = get_variable_indices_from_variable(state, variable)
-
- if indices is None:
- return acc
-
- state, substitution = expr.apply(state)
-
- # state, substitution = ReshapeExprImpl(
- # underlying=expr,
- # new_shape=(-1, 1),
- # ).apply(state)
-
- def gen_polynomials():
- for row in range(substitution.shape[0]):
- yield substitution.get_poly(row, 0)
-
- polynomials = tuple(gen_polynomials())
-
- return state, acc_variable + indices, acc_substitution + polynomials
-
- *_, (state, variable_indices, substitutions) = tuple(
- itertools.accumulate(
- self.substitutions,
- acc_substitutions,
- initial=(state, tuple(), tuple()),
- )
- )
-
- if len(substitutions) == 1:
- substitutions = tuple(substitutions[0] for _ in variable_indices)
-
- else:
- assert len(variable_indices) == len(substitutions), f"{substitutions=}"
-
- poly_matrix_data = {}
-
- for row in range(underlying.shape[0]):
- for col in range(underlying.shape[1]):
- polynomial = underlying.get_poly(row, col)
- if polynomial is None:
- continue
-
- polynomial = collections.defaultdict(float)
-
- for monomial, value in polynomial.items():
- substituted_monomial = {tuple(): value}
-
- for variable, count in monomial:
- if variable in variable_indices:
- index = variable_indices.index(variable)
- substitution = substitutions[index]
-
- for _ in range(count):
- next = {}
- multiply_polynomial(
- substituted_monomial, substitution, next
- )
- substituted_monomial = next
-
- else:
- next = {}
- multiply_polynomial(
- substituted_monomial, {((variable, count),): 1.0}, next
- )
- substituted_monomial = next
-
- for monomial, value in substituted_monomial.items():
- polynomial[monomial] += value
-
- polynomial = {
- key: val
- for key, val in polynomial.items()
- if not math.isclose(val, 0, abs_tol=1e-12)
- }
-
- if 0 < len(polynomial):
- poly_matrix_data[row, col] = polynomial
-
- poly_matrix = init_poly_matrix(
- data=poly_matrix_data,
- shape=underlying.shape,
- )
-
- return state, poly_matrix