Commit 6e5a86ab authored by Michael Goette's avatar Michael Goette

fixes for httensors (more to come)

parent 5f4634a5
Pipeline #1017 passed with stages
in 10 minutes and 20 seconds
......@@ -507,7 +507,7 @@ namespace xerus {
/**
* @brief Transpose the TTOperator
* @brief Transpose the HTOperator
* @details Swaps all external indices to create the transposed operator.
*/
template<bool B = isOperator, typename std::enable_if<B, int>::type = 0>
......
......@@ -876,7 +876,9 @@ namespace xerus {
std::vector<size_t> HTNetwork<isOperator>::ranks() const {
std::vector<size_t> res;
res.reserve(num_ranks());
for (size_t n = 1; n+2 < nodes.size(); ++n) {
const size_t numIntComp = static_cast<size_t>(0.5 + std::pow(2,std::ceil(std::log2(static_cast<double>(degree()/N ))))) - 1;
for (size_t n = 0; n < numIntComp; ++n) {
res.push_back(nodes[n].neighbors.end()[-2].dimension);
res.push_back(nodes[n].neighbors.back().dimension);
}
return res;
......
......@@ -48,7 +48,7 @@ void expose_htnetwork() {
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::HTTensor::random(_dim, _rank);
}).staticmethod("random")
.def("ones", &TTTensor::ones).staticmethod("ones")
.def("ones", &HTTensor::ones).staticmethod("ones")
.def("kronecker", &TTTensor::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTTensor::dirac))
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const size_t)>(&TTTensor::dirac)).staticmethod("dirac")
......@@ -113,7 +113,7 @@ void expose_htnetwork() {
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::HTOperator::random(_dim, _rank);
}).staticmethod("random")
.def("ones", &TTOperator::ones).staticmethod("ones")
.def("ones", &HTOperator::ones).staticmethod("ones")
.def("kronecker", &TTOperator::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTOperator::dirac))
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const size_t)>(&TTOperator::dirac)).staticmethod("dirac")
......@@ -128,11 +128,11 @@ void expose_htnetwork() {
// return boost::python::make_tuple(result.first, result.second);
// }, arg("position"))
//
.def("round", static_cast<void (TTOperator::*)(const std::vector<size_t>&, double)>(&TTOperator::round),
.def("round", static_cast<void (HTOperator::*)(const std::vector<size_t>&, double)>(&HTOperator::round),
(arg("ranks"), arg("epsilon")=EPSILON)
)
.def("round", static_cast<void (TTOperator::*)(double)>(&TTOperator::round))
.def("round", static_cast<void (TTOperator::*)(size_t)>(&TTOperator::round))
.def("round", static_cast<void (HTOperator::*)(double)>(&HTOperator::round))
.def("round", static_cast<void (HTOperator::*)(size_t)>(&HTOperator::round))
.def("soft_threshold", static_cast<void (HTOperator::*)(const double, const bool)>(&HTOperator::soft_threshold),
(arg("tau"), arg("preventZero")=false)
......@@ -158,7 +158,7 @@ void expose_htnetwork() {
//
// // for TTOperator only:
// .def("identity", &TTOperator::identity<>).staticmethod("identity")
.def("transpose", &TTOperator::transpose<>)
.def("transpose", &HTOperator::transpose<>)
;
// def("entrywise_product", static_cast<TTOperator (*)(const TTOperator&, const TTOperator&)>(&entrywise_product));
// def("find_largest_entry", static_cast<size_t (*)(const TTOperator&, value_t, value_t)>(&find_largest_entry));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment