diff options
-rw-r--r-- | mdpoly/test/__init__.py | 1 | ||||
-rw-r--r-- | mdpoly/test/__main__.py | 6 | ||||
-rw-r--r-- | mdpoly/test/algebra.py | 68 | ||||
-rw-r--r-- | mdpoly/test/index.py | 38 | ||||
-rw-r--r-- | mdpoly/test/representations.py | 12 |
5 files changed, 125 insertions, 0 deletions
diff --git a/mdpoly/test/__init__.py b/mdpoly/test/__init__.py new file mode 100644 index 0000000..f810045 --- /dev/null +++ b/mdpoly/test/__init__.py @@ -0,0 +1 @@ +""" Unit tests for mdpoly """ diff --git a/mdpoly/test/__main__.py b/mdpoly/test/__main__.py new file mode 100644 index 0000000..b034fb4 --- /dev/null +++ b/mdpoly/test/__main__.py @@ -0,0 +1,6 @@ +import unittest + +from .algebra import TestPolyRingExpr +from .index import TestPolyVarIndex, TestPolyIndex + +unittest.main() diff --git a/mdpoly/test/algebra.py b/mdpoly/test/algebra.py new file mode 100644 index 0000000..158b224 --- /dev/null +++ b/mdpoly/test/algebra.py @@ -0,0 +1,68 @@ +""" Tests for :py:mod:`mdpoly.algebra` """ + +from unittest import TestCase + +from .representations import TestRepr + +from .. import Variable, Constant, Parameter, State +from ..algebra import PolyRingExpr, PolyConst +from ..errors import AlgebraicError +from ..index import MatrixIndex +from ..representations import SparseRepr + + +class TestPolyRingExpr(TestCase): + + def test_basic_algebra_correct(self): + x, y = Variable.from_names("x, y") + k = Constant(2, "k") + p = Parameter("p") + + ops = ( + x + y, x + k, x + p, k + x, p + x, # addition + x - y, x - k, x - p, k - x, p - x, # subtraction + x * y, x * k, x * p, k * x, p * x, # multiplication + x / k, x / p, # division + x ** k, x ** p, # exponentiation + -x, # negation + ) + + for result in ops: + self.assertIsInstance(result, PolyRingExpr) + + def test_basic_algebra_invalid(self): + x, y = Variable.from_names("x, y") + k = Constant(2, "k") + + self.assertRaises(AlgebraicError, lambda: x / y) + self.assertRaises(AlgebraicError, lambda: k / x) + self.assertRaises(AlgebraicError, lambda: x ** y) + + def test_derivative_integral(self): + x, y, z = Variable.from_names("x, y, z") + + dp_auto = (x ** 2 * y + 2 * x * z + y * z * (x + 1) + 3).diff(x) + dp_hand = 2 * x * y + 2 * z + y * z # diff wrt x by hand + + # FIXME: use TestRepr + auto_repr, auto_state = dp_auto.to_repr(SparseRepr, State()) + hand_repr, hand_state = dp_hand.to_repr(SparseRepr, State()) + + # FIXME: do not use internal stuff in test, implement TestRepr + self.assertTrue(auto_repr.data == hand_repr.data) + + + def test_wraps_literals(self): + x = Variable("x") + + ops_right = (x + 2, x * 2, x / 2, x ** 2,) + ops_left = (2 + x, 2 * x,) + + for result in ops_right: + self.assertIsInstance(result.right, PolyConst) + + for result in ops_left: + self.assertIsInstance(result.left, PolyConst) + + + diff --git a/mdpoly/test/index.py b/mdpoly/test/index.py new file mode 100644 index 0000000..05012df --- /dev/null +++ b/mdpoly/test/index.py @@ -0,0 +1,38 @@ +from unittest import TestCase + +from .. import Variable, State +from ..index import PolyVarIndex, PolyIndex + +class TestPolyVarIndex(TestCase): + + def test_constant(self): + const = PolyVarIndex.constant() + self.assertTrue(const.is_constant()) + + def test_from_var(self): + x = Variable("x") + state = State() + ix = state.index(x) + + idx = PolyVarIndex.from_var(x, state) + self.assertTrue(idx.var_idx == ix) + + def test_total_order(self): + ... + + +class TestPolyIndex(TestCase): + + def test_magic_methods(self): + ... + + def test_from_dict(self): + ... + + def test_multi_index(self): + ... + + def test_constant(self): + const = PolyIndex.constant() + self.assertTrue(const.is_constant()) + diff --git a/mdpoly/test/representations.py b/mdpoly/test/representations.py new file mode 100644 index 0000000..1284c46 --- /dev/null +++ b/mdpoly/test/representations.py @@ -0,0 +1,12 @@ +from ..abc import Repr +from ..index import Shape, MatrixIndex, PolyIndex, Number + + +class TestRepr(Repr): + """ Representation used for test cases """ + + def __init__(self, shape: Shape): + ... + + def at(self, entry: MatrixIndex, term: PolyIndex) -> Number: + raise NotImplementedError |