Commit 9e70d862 authored by Michael Goette

added python bindings for HTTensors

parent b8fa7ccb
@@ -407,7 +407,7 @@ namespace xerus {
/**
* @brief Splits the TTNetwork into two parts by removing the node.
* @brief Splits the HTNetwork into two parts by removing the node.
* @param _position index of the component to be removed, thereby also defining the position
* of the split.
* @return a std::pair containing the two remaining parts as TensorNetworks.
@@ -448,21 +448,21 @@ namespace xerus {
* @brief Applies the soft thresholding operation to all ranks.
* @param _tau the soft thresholding parameter to be applied. I.e. all singular values are reduced to max(0, Lambda_ui - _tau).
*/
// void soft_threshold(const double _tau, const bool _preventZero = false);
void soft_threshold(const double _tau, const bool _preventZero = false);
/**
* @brief Applies soft thresholding operations to all ranks.
* @param _taus the soft thresholding parameters to be applied. I.e. all singular values of the j-th matrification are reduced to max(0, Lambda_ui - _tau[j]).
*/
// void soft_threshold(const std::vector<double>& _taus, const bool _preventZero = false);
void soft_threshold(const std::vector<double>& _taus, const bool _preventZero = false);
/**
* @brief Gets the ranks of the HTNetwork.
* @return A vector containing the current ranks.
*/
// std::vector<size_t> ranks() const;
std::vector<size_t> ranks() const;
/**
@@ -470,7 +470,7 @@ namespace xerus {
* @param _i Position of the edge in question.
* @return The current rank of edge _i.
*/
// size_t rank(const size_t _i) const;
size_t rank(const size_t _i) const;
/**
......
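For orientation, the soft-thresholding rule quoted in the doc comments above shrinks every singular value towards zero: each sigma is replaced by max(0, sigma - tau). A minimal, self-contained sketch of that rule in plain Python, independent of xerus (names are purely illustrative):

```python
# Soft thresholding as described in the doc comments above:
# every singular value sigma becomes max(0, sigma - tau).
def soft_threshold_values(sigmas, tau):
    return [max(0.0, s - tau) for s in sigmas]

# tau = 0.5 zeroes out small singular values and shrinks the rest,
# which is how soft_threshold() can reduce ranks.
print(soft_threshold_values([2.0, 0.9, 0.3], 0.5))  # -> [1.5, 0.4, 0.0]
```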
@@ -844,51 +844,50 @@ namespace xerus {
}
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::soft_threshold(const std::vector<double> &_taus, const bool /*_preventZero*/) {
// const size_t numComponents = degree()/N;
// REQUIRE(_taus.size()+1 == numComponents || (_taus.empty() && numComponents == 0), "There must be exactly degree/N-1 taus. Here " << _taus.size() << " instead of " << numComponents-1 << " are given.");
// require_correct_format();
//
// const bool initialCanonicalization = canonicalized;
// const size_t initialCorePosition = corePosition;
//
// canonicalize_right();
//
// for(size_t i = 0; i+1 < numComponents; ++i) {
// round_edge(numComponents-i, numComponents-i-1, std::numeric_limits<size_t>::max(), 0.0, _taus[i]);
// }
//
// assume_core_position(0);
//
// if(initialCanonicalization) {
// move_core(initialCorePosition);
// }
// }
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::soft_threshold(const double _tau, const bool _preventZero) {
// soft_threshold(std::vector<double>(num_ranks(), _tau), _preventZero);
// }
//
//
// template<bool isOperator>
// std::vector<size_t> HTNetwork<isOperator>::ranks() const {
// std::vector<size_t> res;
// res.reserve(num_ranks());
// for (size_t n = 1; n+2 < nodes.size(); ++n) {
// res.push_back(nodes[n].neighbors.back().dimension);
// }
// return res;
// }
template<bool isOperator>
void HTNetwork<isOperator>::soft_threshold(const std::vector<double> &_taus, const bool /*_preventZero*/) {
REQUIRE(_taus.size()+1 == numberOfComponents || (_taus.empty() && numberOfComponents == 0), "There must be exactly " << numberOfComponents-1 << " taus. Here " << _taus.size() << " instead of " << numberOfComponents-1 << " are given.");
require_correct_format();
const bool initialCanonicalization = canonicalized;
const size_t initialCorePosition = corePosition;
canonicalize_root();
for(size_t i = 0; i+1 < numberOfComponents; ++i) {
round_edge(numberOfComponents-i, numberOfComponents-i-1, std::numeric_limits<size_t>::max(), 0.0, _taus[i]);
}
assume_core_position(0);
if(initialCanonicalization) {
move_core(initialCorePosition);
}
}
template<bool isOperator>
void HTNetwork<isOperator>::soft_threshold(const double _tau, const bool _preventZero) {
soft_threshold(std::vector<double>(num_ranks(), _tau), _preventZero);
}
//
//
// template<bool isOperator>
// size_t HTNetwork<isOperator>::rank(const size_t _i) const {
// REQUIRE(_i < numberOfComponents, "Requested illegal rank " << _i);
// return nodes[_i].neighbors.back().dimension;
// }
template<bool isOperator>
std::vector<size_t> HTNetwork<isOperator>::ranks() const {
std::vector<size_t> res;
res.reserve(num_ranks());
for (size_t n = 1; n+2 < nodes.size(); ++n) {
res.push_back(nodes[n].neighbors.back().dimension);
}
return res;
}
template<bool isOperator>
size_t HTNetwork<isOperator>::rank(const size_t _i) const {
REQUIRE(_i < numberOfComponents, "Requested illegal rank " << _i);
return nodes[_i].neighbors.back().dimension;
}
template<bool isOperator>
......
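The uncommented implementations above are what the new Python bindings below call into. A hedged usage sketch of the two soft_threshold overloads, assuming the module builds and imports as xerus and that xerus.Tensor.random is exposed by the existing Tensor bindings (the dimensions are illustrative):

```python
import xerus  # assumes the python module from this commit is built

# Build an HTTensor from a dense random tensor via the constructor exposed below.
A = xerus.HTTensor(xerus.Tensor.random([4, 4, 4, 4]))

# Scalar overload: the same tau is applied to every rank.
A.soft_threshold(0.1)

# Vector overload: as the REQUIRE above enforces, exactly
# numberOfComponents - 1 taus must be given, one per internal edge.
A.soft_threshold([0.1] * (A.numberOfComponents - 1), False)
```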
// Xerus - A General Purpose Tensor Library
// Copyright (C) 2014-2018 Benjamin Huber and Sebastian Wolf.
//
// Xerus is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// Xerus is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with Xerus. If not, see <http://www.gnu.org/licenses/>.
//
// For further information on Xerus visit https://libXerus.org
// or contact us at contact@libXerus.org.
/**
* @file
* @brief Definition of the HT-Network python bindings.
*/
#define NO_IMPORT_ARRAY
#include "misc.h"
void expose_htnetwork() {
VECTOR_TO_PY(HTTensor, "HTTensorVector");
VECTOR_TO_PY(HTOperator, "HTOperatorVector");
class_<HTTensor, bases<TensorNetwork>>("HTTensor")
.def(init<const Tensor&, optional<value_t, size_t>>())
.def(init<const Tensor&, value_t, TensorNetwork::RankTuple>())
.def(init<Tensor::DimensionTuple>())
.def(init<size_t>())
.def(init<const HTTensor &>())
.def("get_component", &HTTensor::get_component, return_value_policy<copy_const_reference>())
.def("set_component", &HTTensor::set_component)
.def_readonly("canonicalized", &HTTensor::canonicalized)
.def_readonly("corePosition", &HTTensor::corePosition)
.def_readonly("numberOfComponents", &HTTensor::numberOfComponents)
.def("ranks", &HTTensor::ranks)
.def("rank", &HTTensor::rank)
// .def("frob_norm", &TTTensor::frob_norm) // NOTE unneccessary because correct call is inherited
.def("random",
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::HTTensor::random(_dim, _rank);
}).staticmethod("random")
.def("ones", &TTTensor::ones).staticmethod("ones")
.def("kronecker", &TTTensor::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTTensor::dirac))
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const size_t)>(&TTTensor::dirac)).staticmethod("dirac")
.def("use_dense_representations", &TTTensor::use_dense_representations)
.def("reduce_to_maximal_ranks", &TTTensor::reduce_to_maximal_ranks).staticmethod("reduce_to_maximal_ranks")
//// .def("degrees_of_freedom", static_cast<size_t (TTTensor::*)()>(&TTTensor::degrees_of_freedom))
.def("degrees_of_freedom", static_cast<size_t (*)(const std::vector<size_t>&, const std::vector<size_t>&)>(&TTTensor::degrees_of_freedom)).staticmethod("degrees_of_freedom")
// .def("chop",
// +[](HTTensor &_this, size_t _pos) {
// const auto result = _this.chop(_pos);
// return boost::python::make_tuple(result.first, result.second);
// }, arg("position"))
//// .def("round", static_cast<void (TTTensor::*)(const std::vector<size_t>&, double)>(&TTTensor::round),
//// (arg("ranks"), arg("epsilon")=EPSILON)
//// )
.def("round", static_cast<void (HTTensor::*)(double)>(&HTTensor::round))
//// .def("round", static_cast<void (TTTensor::*)(size_t)>(&TTTensor::round))
//
.def("soft_threshold", static_cast<void (HTTensor::*)(const double, const bool)>(&HTTensor::soft_threshold),
(arg("tau"), arg("preventZero")=false)
)
.def("soft_threshold", static_cast<void (HTTensor::*)(const std::vector<double>&, const bool)>(&HTTensor::soft_threshold),
(arg("tau"), arg("preventZero")=false)
)
.def("move_core", &HTTensor::move_core,
(arg("position"), arg("keepRank")=false)
)
.def("assume_core_position", &HTTensor::assume_core_position)
.def("canonicalize_root", &HTTensor::canonicalize_root)
.def(self + self)
.def(self - self)
.def(self * other<value_t>())
.def(other<value_t>() * self)
.def(self / other<value_t>())
.def(self += self)
.def(self -= self)
;
// def("entrywise_product", static_cast<TTTensor (*)(const TTTensor&, const TTTensor&)>(&entrywise_product));
// def("find_largest_entry", static_cast<size_t (*)(const TTTensor&, value_t, value_t)>(&find_largest_entry));
// def("dyadic_product", static_cast<TTTensor (*)(const std::vector<TTTensor> &)>(&dyadic_product));
class_<HTOperator, bases<TensorNetwork>>("HTOperator")
.def(init<const Tensor&, optional<value_t, size_t>>())
.def(init<const Tensor&, value_t, TensorNetwork::RankTuple>())
.def(init<Tensor::DimensionTuple>())
.def(init<size_t>())
.def(init<const HTOperator &>())
.def("get_component", &HTOperator::get_component, return_value_policy<copy_const_reference>())
.def("set_component", &HTOperator::set_component)
.def_readonly("canonicalized", &HTOperator::canonicalized)
.def_readonly("corePosition", &HTOperator::corePosition)
.def_readonly("numberOfComponents", &HTOperator::numberOfComponents)
.def("ranks", &HTOperator::ranks)
.def("rank", &HTOperator::rank)
// .def("frob_norm", &TTOperator::frob_norm) // NOTE unneccessary because correct call is inherited
.def("random",
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::HTOperator::random(_dim, _rank);
}).staticmethod("random")
.def("ones", &TTOperator::ones).staticmethod("ones")
.def("kronecker", &TTOperator::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTOperator::dirac))
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const size_t)>(&TTOperator::dirac)).staticmethod("dirac")
.def("use_dense_representations", &TTOperator::use_dense_representations)
.def("reduce_to_maximal_ranks", &TTOperator::reduce_to_maximal_ranks).staticmethod("reduce_to_maximal_ranks")
// .def("degrees_of_freedom", static_cast<size_t (TTOperator::*)()>(&TTOperator::degrees_of_freedom))
.def("degrees_of_freedom", static_cast<size_t (*)(const std::vector<size_t>&, const std::vector<size_t>&)>(&TTOperator::degrees_of_freedom)).staticmethod("degrees_of_freedom")
// .def("chop",
// +[](HTOperator &_this, size_t _pos) {
// const auto result = _this.chop(_pos);
// return boost::python::make_tuple(result.first, result.second);
// }, arg("position"))
//
.def("round", static_cast<void (TTOperator::*)(const std::vector<size_t>&, double)>(&TTOperator::round),
(arg("ranks"), arg("epsilon")=EPSILON)
)
.def("round", static_cast<void (TTOperator::*)(double)>(&TTOperator::round))
.def("round", static_cast<void (TTOperator::*)(size_t)>(&TTOperator::round))
.def("soft_threshold", static_cast<void (HTOperator::*)(const double, const bool)>(&HTOperator::soft_threshold),
(arg("tau"), arg("preventZero")=false)
)
.def("soft_threshold", static_cast<void (HTOperator::*)(const std::vector<double>&, const bool)>(&HTOperator::soft_threshold),
(arg("tau"), arg("preventZero")=false)
)
.def("move_core", &HTOperator::move_core,
(arg("position"), arg("keepRank")=false)
)
.def("assume_core_position", &HTOperator::assume_core_position)
.def("canonicalize_root", &HTOperator::canonicalize_root)
.def(self + self)
.def(self - self)
.def(self += self)
.def(self -= self)
.def(self * other<value_t>())
.def(other<value_t>() * self)
.def(self / other<value_t>())
//
//
// // for TTOperator only:
// .def("identity", &TTOperator::identity<>).staticmethod("identity")
.def("transpose", &TTOperator::transpose<>)
;
// def("entrywise_product", static_cast<TTOperator (*)(const TTOperator&, const TTOperator&)>(&entrywise_product));
// def("find_largest_entry", static_cast<size_t (*)(const TTOperator&, value_t, value_t)>(&find_largest_entry));
// def("dyadic_product", static_cast<TTOperator (*)(const std::vector<TTOperator> &)>(&dyadic_product));
}
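Finally, a short usage sketch of the classes exposed above (HTTensor shown; HTOperator mirrors it). The module name and dimensions are assumptions for illustration:

```python
import xerus

x = xerus.HTTensor(xerus.Tensor.random([2, 3, 4, 5]))

print(x.numberOfComponents)  # read-only property exposed above
print(x.ranks())             # current ranks as a vector
print(x.rank(1))             # rank of a single edge

x.move_core(0)               # exposed with keyword args position/keepRank
x.canonicalize_root()

y = 2.0 * x + x / 3.0        # scaling and addition via the operator defs above
```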
@@ -59,6 +59,7 @@ void expose_factorizations();
void expose_indexedTensors();
void expose_tensorNetwork();
void expose_ttnetwork();
void expose_htnetwork();
void expose_leastSquaresAlgorithms();
void expose_recoveryAlgorithms();
......
@@ -48,6 +48,7 @@ BOOST_PYTHON_MODULE(xerus) {
expose_tensor();
expose_tensorNetwork();
expose_ttnetwork();
expose_htnetwork();
expose_leastSquaresAlgorithms();
expose_recoveryAlgorithms();
......