diff options
author | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-09 13:57:17 +0200 |
---|---|---|
committer | Michael Schneeberger <michael.schneeberger@fhnw.ch> | 2022-08-09 13:57:17 +0200 |
commit | 7d3f1d7142d4301f70e4ee248df3ac9f4b7326bb (patch) | |
tree | d434214f51bae97e0e75bd13fa550067d5178cc6 | |
parent | add 'to_sympy_expr' function (diff) | |
download | polymatrix-7d3f1d7142d4301f70e4ee248df3ac9f4b7326bb.tar.gz polymatrix-7d3f1d7142d4301f70e4ee248df3ac9f4b7326bb.zip |
bug fix: monomial returned by subtract_monomial cannot have negative powers
Diffstat (limited to '')
5 files changed, 22 insertions, 10 deletions
diff --git a/polymatrix/expression/init/initevalexpr.py b/polymatrix/expression/init/initevalexpr.py index dad9664..5bbd404 100644 --- a/polymatrix/expression/init/initevalexpr.py +++ b/polymatrix/expression/init/initevalexpr.py @@ -41,15 +41,15 @@ def init_eval_expr( yield from value[variable] elif isinstance(value, int) or isinstance(value, float): - # else: yield value else: yield float(value) - # raise Exception(f'{value=}, {type(value)=}') + + values = tuple(gen_formatted_values()) return EvalExprImpl( underlying=underlying, variables=variables, - values=tuple(gen_formatted_values()), + values=values, ) diff --git a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py index b55f412..d5091c6 100644 --- a/polymatrix/expression/mixins/subtractmonomialsexprmixin.py +++ b/polymatrix/expression/mixins/subtractmonomialsexprmixin.py @@ -8,7 +8,7 @@ from polymatrix.expressionstate.mixins.expressionstatemixin import ExpressionSta from polymatrix.polymatrix.mixins.polymatrixmixin import PolyMatrixMixin from polymatrix.expression.utils.getmonomialindices import get_monomial_indices from polymatrix.expression.utils.sortmonomials import sort_monomials -from polymatrix.expression.utils.subtractmonomialindices import subtract_monomial_indices +from polymatrix.expression.utils.subtractmonomialindices import SubtractError, subtract_monomial_indices class SubtractMonomialsExprMixin(ExpressionBaseMixin): @@ -36,7 +36,7 @@ class SubtractMonomialsExprMixin(ExpressionBaseMixin): for m2 in sub_monomials: try: remainder = subtract_monomial_indices(m1, m2) - except KeyError: + except SubtractError: continue yield remainder diff --git a/polymatrix/expression/utils/getvariableindices.py b/polymatrix/expression/utils/getvariableindices.py index c72bb02..7010d23 100644 --- a/polymatrix/expression/utils/getvariableindices.py +++ b/polymatrix/expression/utils/getvariableindices.py @@ -33,6 +33,6 @@ def get_variable_indices(state, variables): yield variable else: - state.offset_dict[variable][0] + yield state.offset_dict[variable][0] return state, tuple(gen_indices()) diff --git a/polymatrix/expression/utils/subtractmonomialindices.py b/polymatrix/expression/utils/subtractmonomialindices.py index 766997a..236ba9c 100644 --- a/polymatrix/expression/utils/subtractmonomialindices.py +++ b/polymatrix/expression/utils/subtractmonomialindices.py @@ -1,13 +1,23 @@ from polymatrix.expression.utils.sortmonomialindices import sort_monomial_indices +class SubtractError(Exception): + pass + + def subtract_monomial_indices(m1, m2): m1_dict = dict(m1) for index, count in m2: + if index not in m1_dict: + raise SubtractError() + m1_dict[index] -= count if m1_dict[index] == 0: del m1_dict[index] + elif m1_dict[index] < 0: + raise SubtractError() + return sort_monomial_indices(m1_dict.items()) diff --git a/polymatrix/expressionstate/mixins/expressionstatemixin.py b/polymatrix/expressionstate/mixins/expressionstatemixin.py index 523d2b7..6d2ded3 100644 --- a/polymatrix/expressionstate/mixins/expressionstatemixin.py +++ b/polymatrix/expressionstate/mixins/expressionstatemixin.py @@ -31,10 +31,12 @@ class ExpressionStateMixin( def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]: ... - # @property - # @abc.abstractmethod - # def cache(self) -> dict: - # ... + def get_variable_from_offset(self, offset: int): + for variable, (start, end) in self.offset_dict.items(): + if offset == start: + assert end - start == 1, f'{start=}, {end=}, {variable=}' + + return variable def register( self, |