From 9d52691f16f1f079c6a493fb54d8a98f9ba2c7d9 Mon Sep 17 00:00:00 2001
From: Nao Pross <np@0hm.ch>
Date: Sun, 10 Mar 2024 23:43:26 +0100
Subject: Add stubs for testing

To run tests do:

  $ python3 -m mdpoly.test
---
 mdpoly/test/__init__.py        |  1 +
 mdpoly/test/__main__.py        |  6 ++++
 mdpoly/test/algebra.py         | 68 ++++++++++++++++++++++++++++++++++++++++++
 mdpoly/test/index.py           | 38 +++++++++++++++++++++++
 mdpoly/test/representations.py | 12 ++++++++
 5 files changed, 125 insertions(+)
 create mode 100644 mdpoly/test/__init__.py
 create mode 100644 mdpoly/test/__main__.py
 create mode 100644 mdpoly/test/algebra.py
 create mode 100644 mdpoly/test/index.py
 create mode 100644 mdpoly/test/representations.py

diff --git a/mdpoly/test/__init__.py b/mdpoly/test/__init__.py
new file mode 100644
index 0000000..f810045
--- /dev/null
+++ b/mdpoly/test/__init__.py
@@ -0,0 +1 @@
+""" Unit tests for mdpoly """
diff --git a/mdpoly/test/__main__.py b/mdpoly/test/__main__.py
new file mode 100644
index 0000000..b034fb4
--- /dev/null
+++ b/mdpoly/test/__main__.py
@@ -0,0 +1,6 @@
+import unittest
+
+from .algebra import TestPolyRingExpr
+from .index import TestPolyVarIndex, TestPolyIndex
+
+unittest.main()
diff --git a/mdpoly/test/algebra.py b/mdpoly/test/algebra.py
new file mode 100644
index 0000000..158b224
--- /dev/null
+++ b/mdpoly/test/algebra.py
@@ -0,0 +1,68 @@
+""" Tests for :py:mod:`mdpoly.algebra` """
+
+from unittest import TestCase
+
+from .representations import TestRepr
+
+from .. import Variable, Constant, Parameter, State
+from ..algebra import PolyRingExpr, PolyConst
+from ..errors import AlgebraicError
+from ..index import MatrixIndex
+from ..representations import SparseRepr
+
+
+class TestPolyRingExpr(TestCase):
+
+    def test_basic_algebra_correct(self):
+        x, y = Variable.from_names("x, y")
+        k = Constant(2, "k")
+        p = Parameter("p")
+
+        ops = (
+            x + y, x + k, x + p, k + x, p + x, # addition
+            x - y, x - k, x - p, k - x, p - x, # subtraction
+            x * y, x * k, x * p, k * x, p * x, # multiplication
+            x / k, x / p, # division
+            x ** k, x ** p, # exponentiation
+            -x, # negation
+        )
+
+        for result in ops:
+            self.assertIsInstance(result, PolyRingExpr)
+
+    def test_basic_algebra_invalid(self):
+        x, y = Variable.from_names("x, y")
+        k = Constant(2, "k")
+
+        self.assertRaises(AlgebraicError, lambda: x / y)
+        self.assertRaises(AlgebraicError, lambda: k / x)
+        self.assertRaises(AlgebraicError, lambda: x ** y)
+
+    def test_derivative_integral(self):
+        x, y, z = Variable.from_names("x, y, z")
+
+        dp_auto = (x ** 2 * y + 2 * x * z + y * z * (x + 1) + 3).diff(x)
+        dp_hand = 2 * x * y + 2 * z + y * z # diff wrt x by hand
+
+        # FIXME: use TestRepr
+        auto_repr, auto_state = dp_auto.to_repr(SparseRepr, State())
+        hand_repr, hand_state = dp_hand.to_repr(SparseRepr, State())
+
+        # FIXME: do not use internal stuff in test, implement TestRepr
+        self.assertTrue(auto_repr.data == hand_repr.data)
+
+
+    def test_wraps_literals(self):
+        x = Variable("x")
+
+        ops_right = (x + 2, x * 2, x / 2, x ** 2,)
+        ops_left = (2 + x, 2 * x,)
+
+        for result in ops_right:
+            self.assertIsInstance(result.right, PolyConst)
+
+        for result in ops_left:
+            self.assertIsInstance(result.left, PolyConst)
+
+
+
diff --git a/mdpoly/test/index.py b/mdpoly/test/index.py
new file mode 100644
index 0000000..05012df
--- /dev/null
+++ b/mdpoly/test/index.py
@@ -0,0 +1,38 @@
+from unittest import TestCase
+
+from .. import Variable, State
+from ..index import PolyVarIndex, PolyIndex
+
+class TestPolyVarIndex(TestCase):
+
+    def test_constant(self):
+        const = PolyVarIndex.constant()
+        self.assertTrue(const.is_constant())
+
+    def test_from_var(self):
+        x = Variable("x")
+        state = State()
+        ix = state.index(x)
+
+        idx = PolyVarIndex.from_var(x, state)
+        self.assertTrue(idx.var_idx == ix)
+
+    def test_total_order(self):
+        ...
+
+
+class TestPolyIndex(TestCase):
+
+    def test_magic_methods(self):
+        ...
+
+    def test_from_dict(self):
+        ...
+
+    def test_multi_index(self):
+        ...
+
+    def test_constant(self):
+        const = PolyIndex.constant()
+        self.assertTrue(const.is_constant())
+
diff --git a/mdpoly/test/representations.py b/mdpoly/test/representations.py
new file mode 100644
index 0000000..1284c46
--- /dev/null
+++ b/mdpoly/test/representations.py
@@ -0,0 +1,12 @@
+from ..abc import Repr
+from ..index import Shape, MatrixIndex, PolyIndex, Number
+
+
+class TestRepr(Repr):
+    """ Representation used for test cases """
+
+    def __init__(self, shape: Shape):
+        ...
+
+    def at(self, entry: MatrixIndex, term: PolyIndex) -> Number:
+        raise NotImplementedError
-- 
cgit v1.2.1