import abc
import collections
import itertools
import math

from polymatrix.polymatrix.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
from polymatrix.polymatrix.polymatrix import PolyMatrix
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.utils.getvariableindices import get_variable_indices
from polymatrix.expression.utils.multiplypolynomial import multiply_polynomial


# Coefficients whose magnitude is within this absolute tolerance of zero are
# dropped from the substituted polynomials.
_ZERO_TOLERANCE = 1e-12


def _evaluate_substitutions(state, substitution_exprs):
    """Evaluate every substitution to a scalar polynomial dict.

    Each entry of ``substitution_exprs`` is either an expression (applied to
    ``state`` and required to be 1x1) or an int/float constant.

    Returns:
        The updated state and a tuple with one polynomial dict per entry,
        in the same order. A zero polynomial is represented by ``{}``.

    Raises:
        AssertionError: if an expression does not evaluate to a 1x1 matrix.
        TypeError: if an entry is neither an expression nor an int/float.
    """
    polynomials = []

    for substitution_expr in substitution_exprs:
        if isinstance(substitution_expr, ExpressionBaseMixin):
            state, substitution = substitution_expr.apply(state)

            assert substitution.shape == (1, 1), f'{substitution=}'

            polynomial = substitution.get_poly(0, 0)

            # get_poly may return None (apply() guards the same case for the
            # underlying matrix); presumably that denotes a zero entry. Use
            # the empty polynomial instead of passing None on to
            # multiply_polynomial, which would fail on it.
            if polynomial is None:
                polynomial = {}

        elif isinstance(substitution_expr, (int, float)):
            polynomial = {tuple(): substitution_expr}

        else:
            raise TypeError(f'{substitution_expr=} not recognized')

        polynomials.append(polynomial)

    return state, tuple(polynomials)


def _substitute_monomial(monomial, coefficient, position_of, substitutions):
    """Return the polynomial obtained from ``coefficient * monomial`` after substitution.

    ``monomial`` is a tuple of ``(variable, count)`` pairs. A variable found
    in ``position_of`` is replaced by ``substitutions[position]`` raised to
    ``count``. Any other variable is kept as the factor ``variable**count``.
    """
    result = {tuple(): coefficient}

    for variable, count in monomial:
        position = position_of.get(variable)

        if position is None:
            # Variable is not substituted: keep the factor unchanged.
            product = {}
            multiply_polynomial(result, {((variable, count),): 1.0}, product)
            result = product
            continue

        substitution = substitutions[position]

        # A zero substitution raised to a positive power annihilates the
        # whole monomial; no need to multiply any further.
        if not substitution and count > 0:
            return {}

        for _ in range(count):
            product = {}
            multiply_polynomial(result, substitution, product)
            result = product

    return result


def _substitute_polynomial(polynomial, position_of, substitutions):
    """Substitute into every term of ``polynomial`` and sum the results.

    Coefficients within ``_ZERO_TOLERANCE`` of zero are removed, so an empty
    dict is returned when the substituted polynomial vanishes.
    """
    summed = collections.defaultdict(float)

    for monomial, coefficient in polynomial.items():
        substituted = _substitute_monomial(monomial, coefficient, position_of, substitutions)

        for new_monomial, new_coefficient in substituted.items():
            summed[new_monomial] += new_coefficient

    return {
        monomial: coefficient
        for monomial, coefficient in summed.items()
        if not math.isclose(coefficient, 0, abs_tol=_ZERO_TOLERANCE)
    }


class SubstituteExprMixin(ExpressionBaseMixin):
    """Expression that replaces selected variables of another expression.

    Each variable index obtained from ``get_variable_indices(state, variables)``
    is replaced by the substitution at the same position in ``substitutions``.
    All other variables are left untouched.
    """

    @property
    @abc.abstractmethod
    def underlying(self) -> ExpressionBaseMixin:
        """Expression whose polynomial entries are rewritten."""
        ...

    @property
    @abc.abstractmethod
    def variables(self) -> tuple:
        """Variables to replace; resolved to indices via ``get_variable_indices``."""
        ...

    @property
    @abc.abstractmethod
    def substitutions(self) -> tuple[ExpressionBaseMixin, ...]:
        """One replacement per variable: a 1x1 expression or an int/float constant."""
        ...

    # overwrites abstract method of `ExpressionBaseMixin`
    def apply(
        self,
        state: ExpressionState,
    ) -> tuple[ExpressionState, PolyMatrix]:
        """Evaluate the substitution.

        Returns:
            The updated state and a polynomial matrix with the shape of
            ``underlying``, in which the selected variables have been
            replaced. Terms with a coefficient within 1e-12 of zero are
            dropped, and entries that become empty are omitted.

        Raises:
            AssertionError: if an expression substitution is not 1x1.
            TypeError: if a substitution is neither an expression nor an
                int/float.
            IndexError: if a substituted variable has no matching entry in
                ``substitutions``.
        """
        state, underlying = self.underlying.apply(state=state)

        state, variable_indices = get_variable_indices(state, self.variables)

        # Position of each variable index. setdefault keeps the first
        # occurrence, matching a ``.index`` lookup, but gives O(1) lookups
        # inside the per-factor loop instead of a linear scan.
        position_of = {}
        for position, variable in enumerate(variable_indices):
            position_of.setdefault(variable, position)

        state, substitutions = _evaluate_substitutions(state, self.substitutions)

        terms = {}

        for row in range(underlying.shape[0]):
            for col in range(underlying.shape[1]):
                polynomial = underlying.get_poly(row, col)
                if polynomial is None:
                    continue

                entry = _substitute_polynomial(polynomial, position_of, substitutions)

                if entry:
                    terms[row, col] = entry

        poly_matrix = init_poly_matrix(
            terms=terms,
            shape=underlying.shape,
        )

        return state, poly_matrix