Commit b8d3b8a8 authored by Philipp  Trunschke's avatar Philipp Trunschke
Browse files

add pickling to tensor networks; add tests for ttnetworks

parent 90472906
Pipeline #1441 passed with stages
in 19 minutes and 21 seconds
......@@ -23,7 +23,7 @@ except ImportError:
from pickle import dumps, loads
def generate_random_tttensors(num_tests, max_order, max_dimension, max_rank, random):
def generate_random_tttensors(num_tests, max_order=10, max_dimension=10, max_rank=4, random=None):
orders = random.randint(low=1, high=max_order+1, size=num_tests)
dimensions = [random.randint(low=1, high=max_dimension, size=order) for order in orders]
ranks = [random.randint(low=1, high=max_rank, size=(order-1)) for order in orders]
......@@ -38,8 +38,11 @@ def generate_random_tensors(num_tests, max_order=4, max_dimension=10, random=Non
yield xe.Tensor.random(dim.tolist())
list2str = lambda ls: "-".join(map(str, ls))
def __test_tensor(A):
name = "test_pickle_" + "-".join(map(str, A.dimensions))
name = "test_pickle_{}".format(list2str(A.dimensions))
def test_pickle(self):
bytes = dumps(A)
......@@ -50,10 +53,19 @@ def __test_tensor(A):
return name, test_pickle
# def test_tttensor(A):
# name_d = "-".join(map(str, A.dimensions))
# name_r = "-".join(map(str, A.ranks()))
# name = "test_pickle_{}_{}".format(name_d, name_r)
def __test_tttensor(A):
name = "test_pickle_{}_{}".format(list2str(A.dimensions), list2str(A.ranks()))
def test_pickle(self):
bytes = dumps(A)
Au = loads(bytes)
print(A.dimensions)
print(Au.dimensions)
self.assertEqual(Au.dimensions, A.dimensions)
self.assertEqual(Au.ranks(), A.ranks())
self.assertLessEqual(xe.frob_norm(A-Au), 1e-10)
return name, test_pickle
def build_TestPickleTensor(seed, num_tests):
......@@ -61,4 +73,12 @@ def build_TestPickleTensor(seed, num_tests):
odir = dict(__test_tensor(t) for t in generate_random_tensors(num_tests, random=random))
return type("TestPickleTensor", (unittest.TestCase,), odir)
def build_TestPickleTTTensor(seed, num_tests):
random = np.random.RandomState(seed)
odir = dict(__test_tttensor(t) for t in generate_random_tttensors(num_tests, random=random))
return type("TestPickleTensor", (unittest.TestCase,), odir)
TestPickleTensor = build_TestPickleTensor(0, 20)
TestPickleTTTensor = build_TestPickleTTTensor(0, 20)
......@@ -27,6 +27,14 @@
void expose_htnetwork(module& m) {
class_<HTTensor, TensorNetwork>(m,"HTTensor")
.def(pickle(
[](const HTTensor &_self) { // __getstate__
return bytes(misc::serialize(_self));
},
[](bytes _bytes) { // __setstate__
return misc::deserialize<HTTensor>(_bytes);
}
))
.def(init<const HTTensor &>())
.def(init<const Tensor&>())
.def(init<const Tensor&, value_t>())
......
......@@ -2,6 +2,14 @@
void expose_tensorNetwork(module& m) {
class_<TensorNetwork>(m, "TensorNetwork")
.def(pickle(
[](const TensorNetwork &_self) { // __getstate__
return bytes(misc::serialize(_self));
},
[](bytes _bytes) { // __setstate__
return misc::deserialize<TensorNetwork>(_bytes);
}
))
.def(init<>())
.def(init<Tensor>())
.def(init<const TensorNetwork &>())
......
......@@ -2,6 +2,14 @@
void expose_ttnetwork(module& m) {
class_<TTTensor, TensorNetwork>(m, "TTTensor")
.def(pickle(
[](const TTTensor &_self) { // __getstate__
return bytes(misc::serialize(_self));
},
[](bytes _bytes) { // __setstate__
return misc::deserialize<TTTensor>(_bytes);
}
))
.def(init<const TTTensor &>())
.def(init<const Tensor&>())
.def(init<const Tensor&, value_t>())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment