diff options
Diffstat (limited to '')
-rw-r--r-- | polymatrix/expression/impl.py | 7 | ||||
-rw-r--r-- | polymatrix/expression/init.py | 32 | ||||
-rw-r--r-- | polymatrix/expression/mixins/substituteexprmixin.py | 120 |
3 files changed, 0 insertions, 159 deletions
diff --git a/polymatrix/expression/impl.py b/polymatrix/expression/impl.py index 6c0edc5..e4cae1b 100644 --- a/polymatrix/expression/impl.py +++ b/polymatrix/expression/impl.py @@ -46,7 +46,6 @@ from polymatrix.expression.mixins.repmatexprmixin import RepMatExprMixin from polymatrix.expression.mixins.reshapeexprmixin import ReshapeExprMixin from polymatrix.expression.mixins.setelementatexprmixin import SetElementAtExprMixin from polymatrix.expression.mixins.squeezeexprmixin import SqueezeExprMixin -from polymatrix.expression.mixins.substituteexprmixin import SubstituteExprMixin from polymatrix.expression.mixins.subtractmonomialsexprmixin import SubtractMonomialsExprMixin from polymatrix.expression.mixins.sumexprmixin import SumExprMixin from polymatrix.expression.mixins.symmetricexprmixin import SymmetricExprMixin @@ -362,12 +361,6 @@ class SqueezeExprImpl(SqueezeExprMixin): @dataclassabc.dataclassabc(frozen=True) -class SubstituteExprImpl(SubstituteExprMixin): - underlying: ExpressionBaseMixin - substitutions: tuple - - -@dataclassabc.dataclassabc(frozen=True) class SubtractMonomialsExprImpl(SubtractMonomialsExprMixin): underlying: ExpressionBaseMixin monomials: ExpressionBaseMixin diff --git a/polymatrix/expression/init.py b/polymatrix/expression/init.py index 32de8c1..5e72524 100644 --- a/polymatrix/expression/init.py +++ b/polymatrix/expression/init.py @@ -404,38 +404,6 @@ def init_squeeze_expr( ) -def init_substitute_expr( - underlying: ExpressionBaseMixin, - variables: tuple, - values: tuple = None, -): - substitutions = format_substitutions( - variables=variables, - values=values, - ) - - def formatted_values(value) -> ExpressionBaseMixin: - if isinstance(value, ExpressionBaseMixin): - expr = value - - else: - expr = init_from_expr(value) - - return polymatrix.expression.impl.ReshapeExprImpl( - underlying=expr, - new_shape=(-1, 1), - ) - - substitutions = tuple( - (variable, formatted_values(value)) for variable, value in substitutions - ) - - return polymatrix.expression.impl.SubstituteExprImpl( - underlying=underlying, - substitutions=substitutions, - ) - - def init_subtract_monomials_expr( underlying: ExpressionBaseMixin, monomials: ExpressionBaseMixin, diff --git a/polymatrix/expression/mixins/substituteexprmixin.py b/polymatrix/expression/mixins/substituteexprmixin.py deleted file mode 100644 index a3e76b1..0000000 --- a/polymatrix/expression/mixins/substituteexprmixin.py +++ /dev/null @@ -1,120 +0,0 @@ -import abc -import collections -import itertools -import math -import typing - -from polymatrix.polymatrix.init import init_poly_matrix -from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin -from polymatrix.polymatrix.abc import PolyMatrix -from polymatrix.expressionstate import ExpressionState -from polymatrix.expression.utils.getvariableindices import ( - get_variable_indices_from_variable, -) -from polymatrix.polymatrix.utils.multiplypolynomial import multiply_polynomial - - -class SubstituteExprMixin(ExpressionBaseMixin): - @property - @abc.abstractmethod - def underlying(self) -> ExpressionBaseMixin: ... - - @property - @abc.abstractmethod - def substitutions(self) -> tuple[tuple[typing.Any, ExpressionBaseMixin], ...]: ... - - # overwrites the abstract method of `ExpressionBaseMixin` - def apply( - self, - state: ExpressionState, - ) -> tuple[ExpressionState, PolyMatrix]: - state, underlying = self.underlying.apply(state=state) - - def acc_substitutions(acc, next): - state, acc_variable, acc_substitution = acc - variable, expr = next - - state, indices = get_variable_indices_from_variable(state, variable) - - if indices is None: - return acc - - state, substitution = expr.apply(state) - - # state, substitution = ReshapeExprImpl( - # underlying=expr, - # new_shape=(-1, 1), - # ).apply(state) - - def gen_polynomials(): - for row in range(substitution.shape[0]): - yield substitution.get_poly(row, 0) - - polynomials = tuple(gen_polynomials()) - - return state, acc_variable + indices, acc_substitution + polynomials - - *_, (state, variable_indices, substitutions) = tuple( - itertools.accumulate( - self.substitutions, - acc_substitutions, - initial=(state, tuple(), tuple()), - ) - ) - - if len(substitutions) == 1: - substitutions = tuple(substitutions[0] for _ in variable_indices) - - else: - assert len(variable_indices) == len(substitutions), f"{substitutions=}" - - poly_matrix_data = {} - - for row in range(underlying.shape[0]): - for col in range(underlying.shape[1]): - polynomial = underlying.get_poly(row, col) - if polynomial is None: - continue - - polynomial = collections.defaultdict(float) - - for monomial, value in polynomial.items(): - substituted_monomial = {tuple(): value} - - for variable, count in monomial: - if variable in variable_indices: - index = variable_indices.index(variable) - substitution = substitutions[index] - - for _ in range(count): - next = {} - multiply_polynomial( - substituted_monomial, substitution, next - ) - substituted_monomial = next - - else: - next = {} - multiply_polynomial( - substituted_monomial, {((variable, count),): 1.0}, next - ) - substituted_monomial = next - - for monomial, value in substituted_monomial.items(): - polynomial[monomial] += value - - polynomial = { - key: val - for key, val in polynomial.items() - if not math.isclose(val, 0, abs_tol=1e-12) - } - - if 0 < len(polynomial): - poly_matrix_data[row, col] = polynomial - - poly_matrix = init_poly_matrix( - data=poly_matrix_data, - shape=underlying.shape, - ) - - return state, poly_matrix |