Commit 0747d57d authored by Philipp  Trunschke's avatar Philipp Trunschke
Browse files

remove buffer interface & reintroduce to_ndarray/from_ndarray

parent e722066f
Pipeline #2150 failed with stages
in 4 minutes and 2 seconds
......@@ -23,12 +23,9 @@ Tensor Tensor_from_buffer(buffer& _b) {
return Tensor({}, Tensor::Representation::Dense, Tensor::Initialisation::None);
}
std::vector<size_t> dims(info.shape.begin(), info.shape.end());
std::vector<size_t> strides(info.strides.begin(), info.strides.end());
Tensor result(dims, Tensor::Representation::Dense, Tensor::Initialisation::None);
std::vector<size_t> dimensions(info.shape.begin(), info.shape.end());
Tensor result(dimensions, Tensor::Representation::Dense, Tensor::Initialisation::None);
misc::copy(result.get_unsanitized_dense_data(), static_cast<double*>(info.ptr), result.size);
return result;
}
......@@ -43,16 +40,17 @@ void expose_tensor(module& m) {
;
class_<Tensor>(m, "Tensor", "a non-decomposed Tensor in either sparse or dense representation", buffer_protocol())
.def_buffer([](Tensor& t) -> buffer_info {
return buffer_info(
t.get_dense_data(), /* Pointer to buffer */
sizeof(value_t), /* Size of one scalar */
format_descriptor<value_t>::format(), /* Python struct-style format descriptor */
t.order(), /* Number of dimensions */
t.dimensions, /* Buffer dimensions */
strides_from_dimensions_and_item_size(t.dimensions, sizeof(value_t)) /* Strides (in bytes) for each index */
);
})
// .def_buffer([](Tensor& t) -> buffer_info {
// std::cerr << "buffer" << std::endl;
// return buffer_info(
// t.get_dense_data(), /* Pointer to buffer */
// sizeof(value_t), /* Size of one scalar */
// format_descriptor<value_t>::format(), /* Python struct-style format descriptor */
// t.order(), /* Number of dimensions */
// t.dimensions, /* Buffer dimensions */
// strides_from_dimensions_and_item_size(t.dimensions, sizeof(value_t)) /* Strides (in bytes) for each index */
// );
// })
.def(pickle(
[](const Tensor &_self) { // __getstate__
return bytes(misc::serialize(_self));
......@@ -76,6 +74,15 @@ void expose_tensor(module& m) {
return Tensor(_dim, _f);
})
.def_static("from_buffer", &Tensor_from_buffer)
.def_static("from_ndarray", &Tensor_from_buffer)
.def("to_ndarray", +[](const xerus::Tensor& _self){
static_assert(std::is_same<value_t, double>::value);
Tensor* shcp = new xerus::Tensor(_self); // create a shallow copy (this effectively increases its reference count for the required amount of time)
shcp->ensure_own_data_and_apply_factor();
shcp->use_dense_representation();
capsule cps(shcp, [](void *v) { delete reinterpret_cast<xerus::Tensor*>(v); });
return array(dtype::of<double>(), shcp->dimensions, shcp->get_unsanitized_dense_data(), cps);
})
.def_property_readonly("dimensions", +[](Tensor &_A) {
return _A.dimensions;
})
......
......@@ -3,10 +3,10 @@
void expose_tensorNetwork(module& m) {
class_<TensorNetwork>(m, "TensorNetwork")
.def(pickle(
[](const TensorNetwork &_self) { // __getstate__
+[](const TensorNetwork &_self) { // __getstate__
return bytes(misc::serialize(_self));
},
[](bytes _bytes) { // __setstate__
+[](bytes _bytes) { // __setstate__
return misc::deserialize<TensorNetwork>(_bytes);
}
))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment