Commit 005a783d authored by Philipp  Trunschke's avatar Philipp Trunschke
Browse files

add tests for pickling of httensors

parent 01529b98
......@@ -23,6 +23,13 @@ except ImportError:
from pickle import dumps, loads
def generate_random_tensors(num_tests, max_order=4, max_dimension=10, random=None):
orders = random.randint(low=1, high=max_order+1, size=num_tests)
dimensions = [random.randint(low=1, high=max_dimension, size=order) for order in orders]
for dim in dimensions:
yield xe.Tensor.random(dim.tolist())
def generate_random_tttensors(num_tests, max_order=10, max_dimension=10, max_rank=4, random=None):
orders = random.randint(low=1, high=max_order+1, size=num_tests)
dimensions = [random.randint(low=1, high=max_dimension, size=order) for order in orders]
......@@ -31,11 +38,16 @@ def generate_random_tttensors(num_tests, max_order=10, max_dimension=10, max_ran
yield xe.TTTensor.random(dim.tolist(), rk.tolist())
def generate_random_tensors(num_tests, max_order=4, max_dimension=10, random=None):
orders = random.randint(low=1, high=max_order+1, size=num_tests)
def generate_random_httensors(num_tests, max_order=10, max_dimension=10, max_rank=4, random=None):
orders = random.randint(low=2, high=max_order+1, size=num_tests) #TODO: change `low=1` back when possible
dimensions = [random.randint(low=1, high=max_dimension, size=order) for order in orders]
for dim in dimensions:
yield xe.Tensor.random(dim.tolist())
# A perfect binary tree with `order` leaves has depth `log2(order)` and `sum(2**k for k in range(log2(order))) = 2*order-1` nodes.
# With exception of the root every node has one link to its parent. This results in `2*order - 2` links.
ranks = [random.randint(low=1, high=max_rank, size=2*(order-1)) for order in orders]
print( xe.HTTensor([6]) )
for dim, rk in zip(dimensions, ranks):
print("5:", dim, rk)
yield xe.HTTensor.random(dim.tolist(), rk.tolist())
list2str = lambda ls: "-".join(map(str, ls))
......@@ -68,6 +80,20 @@ def __test_tttensor(A):
return name, test_pickle
def __test_httensor(A):
name = "test_pickle_{}_{}".format(list2str(A.dimensions), list2str(A.ranks()))
def test_pickle(self):
bytes = dumps(A)
Au = loads(bytes)
self.assertEqual(Au.dimensions, A.dimensions)
self.assertEqual(Au.ranks(), A.ranks())
error = xe.frob_norm(A-Au) / xe.frob_norm(A)
self.assertLessEqual(error, 1e-10)
return name, test_pickle
def __test_tensornetwork(A):
name = "test_pickle_{}_{}".format(list2str(A.dimensions), list2str(A.ranks()))
......@@ -97,6 +123,12 @@ def build_TestPickleTTTensor(seed, num_tests):
return type("TestPickleTTTensor", (unittest.TestCase,), odir)
def build_TestPickleHTTensor(seed, num_tests):
random = np.random.RandomState(seed)
odir = dict(__test_httensor(t) for t in generate_random_httensors(num_tests, random=random))
return type("TestPickleHTTensor", (unittest.TestCase,), odir)
def build_TestPickleTensorNetwork(seed, num_tests):
random = np.random.RandomState(seed)
odir = dict(__test_tensornetwork(t) for t in generate_random_tttensors(num_tests, random=random))
......@@ -105,5 +137,5 @@ def build_TestPickleTensorNetwork(seed, num_tests):
TestPickleTensor = build_TestPickleTensor(0, 20)
TestPickleTTTensor = build_TestPickleTTTensor(0, 20)
# TestPickleTTTensor = build_TestPickleHTTensor(0, 20)
TestPickleHTTensor = build_TestPickleHTTensor(0, 20)
TestPickleTensor = build_TestPickleTensorNetwork(0, 20)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment