Commit 6ddda9ec authored by Philipp  Trunschke's avatar Philipp Trunschke
Browse files

fix segfault in test_pickle

parent 2af7a541
Pipeline #1437 passed with stages
in 19 minutes and 39 seconds
import unittest import unittest
import numpy as np import numpy as np
import xerus as xe import xerus as xe
import pickle try:
# python2
# Here cPickle has to be used instead of pickle and a protocol version newer 2.0 has to be used.
# From the documentation:
#
# Note that only the cPickle module is supported on Python 2.7.
# The second argument to dumps is also crucial: it selects the pickle protocol version 2,
# since the older version 1 is not supported. Newer versions are also fine - for instance,
# specify -1 to always use the latest available version.
#
# Beware: failure to follow these instructions will cause important pybind11 memory
# allocation routines to be skipped during unpickling, which will likely lead to memory
# corruption and/or segmentation faults.
import cPickle
assert '2.0' in cPickle.compatible_formats
dumps = lambda o: cPickle.dumps(o, protocol=-1)
loads = cPickle.loads
except ImportError:
# python3
from pickle import dumps, loads
def generate_random_tttensors(num_tests, max_order, max_dimension, max_rank, random): def generate_random_tttensors(num_tests, max_order, max_dimension, max_rank, random):
...@@ -23,8 +42,8 @@ def __test_tensor(A): ...@@ -23,8 +42,8 @@ def __test_tensor(A):
name = "test_pickle_" + "-".join(map(str, A.dimensions)) name = "test_pickle_" + "-".join(map(str, A.dimensions))
def test_pickle(self): def test_pickle(self):
bytes = pickle.dumps(A) bytes = dumps(A)
Au = pickle.loads(bytes) Au = loads(bytes)
self.assertEqual(Au.dimensions, A.dimensions) self.assertEqual(Au.dimensions, A.dimensions)
self.assertLessEqual(xe.frob_norm(A-Au), 1e-16) self.assertLessEqual(xe.frob_norm(A-Au), 1e-16)
......
...@@ -40,6 +40,8 @@ class TestTensor(unittest.TestCase): ...@@ -40,6 +40,8 @@ class TestTensor(unittest.TestCase):
a = xe.Tensor() a = xe.Tensor()
xe.Tensor(a) xe.Tensor(a)
xe.Tensor(xe.TTTensor([2]))
# xe.Tensor(tensor_network) # xe.Tensor(tensor_network)
def test_create_tensors(self): def test_create_tensors(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment