aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--mdpoly/test/__init__.py1
-rw-r--r--mdpoly/test/__main__.py6
-rw-r--r--mdpoly/test/algebra.py68
-rw-r--r--mdpoly/test/index.py38
-rw-r--r--mdpoly/test/representations.py12
5 files changed, 125 insertions, 0 deletions
diff --git a/mdpoly/test/__init__.py b/mdpoly/test/__init__.py
new file mode 100644
index 0000000..f810045
--- /dev/null
+++ b/mdpoly/test/__init__.py
@@ -0,0 +1 @@
+""" Unit tests for mdpoly """
diff --git a/mdpoly/test/__main__.py b/mdpoly/test/__main__.py
new file mode 100644
index 0000000..b034fb4
--- /dev/null
+++ b/mdpoly/test/__main__.py
@@ -0,0 +1,6 @@
+import unittest
+
+from .algebra import TestPolyRingExpr
+from .index import TestPolyVarIndex, TestPolyIndex
+
+unittest.main()
diff --git a/mdpoly/test/algebra.py b/mdpoly/test/algebra.py
new file mode 100644
index 0000000..158b224
--- /dev/null
+++ b/mdpoly/test/algebra.py
@@ -0,0 +1,68 @@
+""" Tests for :py:mod:`mdpoly.algebra` """
+
+from unittest import TestCase
+
+from .representations import TestRepr
+
+from .. import Variable, Constant, Parameter, State
+from ..algebra import PolyRingExpr, PolyConst
+from ..errors import AlgebraicError
+from ..index import MatrixIndex
+from ..representations import SparseRepr
+
+
+class TestPolyRingExpr(TestCase):
+
+ def test_basic_algebra_correct(self):
+ x, y = Variable.from_names("x, y")
+ k = Constant(2, "k")
+ p = Parameter("p")
+
+ ops = (
+ x + y, x + k, x + p, k + x, p + x, # addition
+ x - y, x - k, x - p, k - x, p - x, # subtraction
+ x * y, x * k, x * p, k * x, p * x, # multiplication
+ x / k, x / p, # division
+ x ** k, x ** p, # exponentiation
+ -x, # negation
+ )
+
+ for result in ops:
+ self.assertIsInstance(result, PolyRingExpr)
+
+ def test_basic_algebra_invalid(self):
+ x, y = Variable.from_names("x, y")
+ k = Constant(2, "k")
+
+ self.assertRaises(AlgebraicError, lambda: x / y)
+ self.assertRaises(AlgebraicError, lambda: k / x)
+ self.assertRaises(AlgebraicError, lambda: x ** y)
+
+ def test_derivative_integral(self):
+ x, y, z = Variable.from_names("x, y, z")
+
+ dp_auto = (x ** 2 * y + 2 * x * z + y * z * (x + 1) + 3).diff(x)
+ dp_hand = 2 * x * y + 2 * z + y * z # diff wrt x by hand
+
+ # FIXME: use TestRepr
+ auto_repr, auto_state = dp_auto.to_repr(SparseRepr, State())
+ hand_repr, hand_state = dp_hand.to_repr(SparseRepr, State())
+
+ # FIXME: do not use internal stuff in test, implement TestRepr
+ self.assertTrue(auto_repr.data == hand_repr.data)
+
+
+ def test_wraps_literals(self):
+ x = Variable("x")
+
+ ops_right = (x + 2, x * 2, x / 2, x ** 2,)
+ ops_left = (2 + x, 2 * x,)
+
+ for result in ops_right:
+ self.assertIsInstance(result.right, PolyConst)
+
+ for result in ops_left:
+ self.assertIsInstance(result.left, PolyConst)
+
+
+
diff --git a/mdpoly/test/index.py b/mdpoly/test/index.py
new file mode 100644
index 0000000..05012df
--- /dev/null
+++ b/mdpoly/test/index.py
@@ -0,0 +1,38 @@
+from unittest import TestCase
+
+from .. import Variable, State
+from ..index import PolyVarIndex, PolyIndex
+
+class TestPolyVarIndex(TestCase):
+
+ def test_constant(self):
+ const = PolyVarIndex.constant()
+ self.assertTrue(const.is_constant())
+
+ def test_from_var(self):
+ x = Variable("x")
+ state = State()
+ ix = state.index(x)
+
+ idx = PolyVarIndex.from_var(x, state)
+ self.assertTrue(idx.var_idx == ix)
+
+ def test_total_order(self):
+ ...
+
+
+class TestPolyIndex(TestCase):
+
+ def test_magic_methods(self):
+ ...
+
+ def test_from_dict(self):
+ ...
+
+ def test_multi_index(self):
+ ...
+
+ def test_constant(self):
+ const = PolyIndex.constant()
+ self.assertTrue(const.is_constant())
+
diff --git a/mdpoly/test/representations.py b/mdpoly/test/representations.py
new file mode 100644
index 0000000..1284c46
--- /dev/null
+++ b/mdpoly/test/representations.py
@@ -0,0 +1,12 @@
+from ..abc import Repr
+from ..index import Shape, MatrixIndex, PolyIndex, Number
+
+
+class TestRepr(Repr):
+ """ Representation used for test cases """
+
+ def __init__(self, shape: Shape):
+ ...
+
+ def at(self, entry: MatrixIndex, term: PolyIndex) -> Number:
+ raise NotImplementedError