From edb932093b32127c67b358a56ab3afb7c054c968 Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 18 Mar 2024 22:14:28 +0100 Subject: Fix mdpoly.test --- mdpoly/test/__main__.py | 2 +- mdpoly/test/algebra.py | 68 -------------------------------------------- mdpoly/test/expressions.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 69 deletions(-) delete mode 100644 mdpoly/test/algebra.py create mode 100644 mdpoly/test/expressions.py diff --git a/mdpoly/test/__main__.py b/mdpoly/test/__main__.py index b034fb4..722065f 100644 --- a/mdpoly/test/__main__.py +++ b/mdpoly/test/__main__.py @@ -1,6 +1,6 @@ import unittest -from .algebra import TestPolyRingExpr +from .expressions import TestPolyExpressions from .index import TestPolyVarIndex, TestPolyIndex unittest.main() diff --git a/mdpoly/test/algebra.py b/mdpoly/test/algebra.py deleted file mode 100644 index 158b224..0000000 --- a/mdpoly/test/algebra.py +++ /dev/null @@ -1,68 +0,0 @@ -""" Tests for :py:mod:`mdpoly.algebra` """ - -from unittest import TestCase - -from .representations import TestRepr - -from .. import Variable, Constant, Parameter, State -from ..algebra import PolyRingExpr, PolyConst -from ..errors import AlgebraicError -from ..index import MatrixIndex -from ..representations import SparseRepr - - -class TestPolyRingExpr(TestCase): - - def test_basic_algebra_correct(self): - x, y = Variable.from_names("x, y") - k = Constant(2, "k") - p = Parameter("p") - - ops = ( - x + y, x + k, x + p, k + x, p + x, # addition - x - y, x - k, x - p, k - x, p - x, # subtraction - x * y, x * k, x * p, k * x, p * x, # multiplication - x / k, x / p, # division - x ** k, x ** p, # exponentiation - -x, # negation - ) - - for result in ops: - self.assertIsInstance(result, PolyRingExpr) - - def test_basic_algebra_invalid(self): - x, y = Variable.from_names("x, y") - k = Constant(2, "k") - - self.assertRaises(AlgebraicError, lambda: x / y) - self.assertRaises(AlgebraicError, lambda: k / x) - self.assertRaises(AlgebraicError, lambda: x ** y) - - def test_derivative_integral(self): - x, y, z = Variable.from_names("x, y, z") - - dp_auto = (x ** 2 * y + 2 * x * z + y * z * (x + 1) + 3).diff(x) - dp_hand = 2 * x * y + 2 * z + y * z # diff wrt x by hand - - # FIXME: use TestRepr - auto_repr, auto_state = dp_auto.to_repr(SparseRepr, State()) - hand_repr, hand_state = dp_hand.to_repr(SparseRepr, State()) - - # FIXME: do not use internal stuff in test, implement TestRepr - self.assertTrue(auto_repr.data == hand_repr.data) - - - def test_wraps_literals(self): - x = Variable("x") - - ops_right = (x + 2, x * 2, x / 2, x ** 2,) - ops_left = (2 + x, 2 * x,) - - for result in ops_right: - self.assertIsInstance(result.right, PolyConst) - - for result in ops_left: - self.assertIsInstance(result.left, PolyConst) - - - diff --git a/mdpoly/test/expressions.py b/mdpoly/test/expressions.py new file mode 100644 index 0000000..1520f12 --- /dev/null +++ b/mdpoly/test/expressions.py @@ -0,0 +1,70 @@ +""" Tests for :py:mod:`mdpoly.expressions` """ + +from unittest import TestCase + +from .representations import TestRepr + +from .. import Variable, Constant, Parameter, State + +from ..abc import Expr +from ..errors import AlgebraicError +from ..expressions import PolyConst +from ..index import MatrixIndex +from ..representations import SparseRepr + + +class TestPolyExpressions(TestCase): + + def test_scalar_valid(self): + x, y = Variable.from_names("x, y") + k = Constant(2, "k") + p = Parameter("p") + + ops = ( + x + y, x + k, x + p, k + x, p + x, # addition + x - y, x - k, x - p, k - x, p - x, # subtraction + x * y, x * k, x * p, k * x, p * x, # multiplication + x / k, x / p, # division + x ** k, x ** p, # exponentiation + -x, # negation + ) + + for result in ops: + self.assertIsInstance(result, Expr) + + def test_scalar_invalid(self): + x, y = Variable.from_names("x, y") + k = Constant(2, "k") + + self.assertRaises(AlgebraicError, lambda: x / y) + self.assertRaises(AlgebraicError, lambda: k / x) + self.assertRaises(AlgebraicError, lambda: x ** y) + + def test_derivative_integral(self): + x, y, z = Variable.from_names("x, y, z") + + dp_auto = (x ** 2 * y + 2 * x * z + y * z * (x + 1) + 3).diff(x) + dp_hand = 2 * x * y + 2 * z + y * z # diff wrt x by hand + + # FIXME: use TestRepr + auto_repr, auto_state = dp_auto.to_repr(SparseRepr, State()) + hand_repr, hand_state = dp_hand.to_repr(SparseRepr, State()) + + # FIXME: do not use internal stuff in test, implement TestRepr + self.assertTrue(auto_repr.data == hand_repr.data) + + + def test_wraps_literals(self): + x = Variable("x") + + ops_right = (x + 2, x * 2, x / 2, x ** 2,) + ops_left = (2 + x, 2 * x,) + + for result in ops_right: + self.assertIsInstance(result.right, PolyConst) + + for result in ops_left: + self.assertIsInstance(result.left, PolyConst) + + + -- cgit v1.2.1