Commit d0cb0240 authored by Philipp  Trunschke's avatar Philipp Trunschke
Browse files

test relative errors instead of absoulte & add test for TensorNetwork

parent b8d3b8a8
Pipeline #1442 passed with stages
in 19 minutes and 7 seconds
......@@ -48,7 +48,8 @@ def __test_tensor(A):
bytes = dumps(A)
Au = loads(bytes)
self.assertEqual(Au.dimensions, A.dimensions)
self.assertLessEqual(xe.frob_norm(A-Au), 1e-16)
error = xe.frob_norm(A-Au) / xe.frob_norm(A)
self.assertLessEqual(error, 1e-16)
return name, test_pickle
......@@ -59,11 +60,27 @@ def __test_tttensor(A):
def test_pickle(self):
bytes = dumps(A)
Au = loads(bytes)
print(A.dimensions)
print(Au.dimensions)
self.assertEqual(Au.dimensions, A.dimensions)
self.assertEqual(Au.ranks(), A.ranks())
self.assertLessEqual(xe.frob_norm(A-Au), 1e-10)
error = xe.frob_norm(A-Au) / xe.frob_norm(A)
self.assertLessEqual(error, 1e-10)
return name, test_pickle
def __test_tensornetwork(A):
name = "test_pickle_{}_{}".format(list2str(A.dimensions), list2str(A.ranks()))
A = xe.TensorNetwork(A)
i, = xe.indices(1)
def test_pickle(self):
bytes = dumps(A)
Au = loads(bytes)
normA = xe.frob_norm(A)
normAu = xe.frob_norm(Au)
innerAuA = float(A(i&0) * Au(i&0))
error = np.sqrt(max(normA**2 - 2*innerAuA + normAu**2, 0)) / normA
self.assertLessEqual(error, 1e-7)
return name, test_pickle
......@@ -77,8 +94,16 @@ def build_TestPickleTensor(seed, num_tests):
def build_TestPickleTTTensor(seed, num_tests):
random = np.random.RandomState(seed)
odir = dict(__test_tttensor(t) for t in generate_random_tttensors(num_tests, random=random))
return type("TestPickleTensor", (unittest.TestCase,), odir)
return type("TestPickleTTTensor", (unittest.TestCase,), odir)
def build_TestPickleTensorNetwork(seed, num_tests):
random = np.random.RandomState(seed)
odir = dict(__test_tensornetwork(t) for t in generate_random_tttensors(num_tests, random=random))
return type("TestPickleTensorNetwork", (unittest.TestCase,), odir)
TestPickleTensor = build_TestPickleTensor(0, 20)
TestPickleTTTensor = build_TestPickleTTTensor(0, 20)
# TestPickleTTTensor = build_TestPickleHTTensor(0, 20)
TestPickleTensor = build_TestPickleTensorNetwork(0, 20)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment