Commit b316a105 authored by Michael Goette

Merge branch 'development' into iterative_ev_solver

parents a632b2c1 5cfa25a4
......@@ -62,10 +62,6 @@ namespace xerus {
*/
size_t corePosition;
/**
* @brief Number of Components
*/
size_t numberOfComponents;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Constructors - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
/**
......@@ -337,7 +333,10 @@ namespace xerus {
*/
bool is_left_child(size_t _comp) const;
/**
* @brief Returns the number of components of this HTNetwork
*/
size_t get_number_of_components() const;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Miscellaneous - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
public:
......@@ -507,14 +506,15 @@ namespace xerus {
/**
* @brief Transpose the TTOperator
* @brief Transpose the HTOperator
* @details Swaps all external indices to create the transposed operator.
*/
template<bool B = isOperator, typename std::enable_if<B, int>::type = 0>
void transpose() {
const std::vector<size_t> shuffle({0,2,1});
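// descriptive note: the shuffle keeps mode 0 (presumably the link to the
// parent node) and swaps modes 1 and 2, the two external indices of a leaf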
size_t numComp = get_number_of_components();
//only leaves
for (size_t n = numberOfComponents - 1; n >= numberOfComponents - degree()/N; --n) {
for (size_t n = numComp - 1; n >= numComp - degree()/N; --n) {
xerus::reshuffle(component(n), component(n), shuffle);
}
}
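// Usage sketch (illustrative only; the dimensions and ranks are assumptions,
// with the rank-vector length for four leaves taken from the unit tests):
//   Index i, j;
//   HTOperator A = HTOperator::random({2,2,2,2}, std::vector<size_t>(6, 2));
//   HTOperator At(A);
//   At.transpose();   // now A(i/2, j/2) equals At(j/2, i/2)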
......@@ -572,19 +572,19 @@ namespace xerus {
virtual void operator/=(const value_t _divisor) override;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Operator specializations - - - - - - - - - - - - - - - - - - - - - - - - - - */
// static bool specialized_contraction_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
static bool specialized_contraction_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
// static bool specialized_sum_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
static bool specialized_sum_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
// virtual bool specialized_contraction(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
// return specialized_contraction_f(_out, std::move(_me), std::move(_other));
// }
//
// virtual bool specialized_sum(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
// return specialized_sum_f(_out, std::move(_me), std::move(_other));
// }
//
// virtual void specialized_evaluation(internal::IndexedTensorWritable<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) override;
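// NOTE: the virtual overrides below simply forward to the static
// specialized_*_f implementations declared above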
virtual bool specialized_contraction(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
return specialized_contraction_f(_out, std::move(_me), std::move(_other));
}
virtual bool specialized_sum(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
return specialized_sum_f(_out, std::move(_me), std::move(_other));
}
virtual void specialized_evaluation(internal::IndexedTensorWritable<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) override;
};
......
// Xerus - A General Purpose Tensor Library
// Copyright (C) 2014-2018 Benjamin Huber and Sebastian Wolf.
//
// Xerus is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// Xerus is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with Xerus. If not, see <http://www.gnu.org/licenses/>.
//
// For further information on Xerus visit https://libXerus.org
// or contact us at contact@libXerus.org.
/**
* @file
* @brief Header file for the HTStack class.
*/
#pragma once
#include "misc/check.h"
#include "tensorNetwork.h"
#include "indexedTensorMoveable.h"
namespace xerus {
template<bool isOperator> class HTNetwork;
namespace internal {
template<bool isOperator>
///@brief Internal class used to represent stacks of (possibly multiple) applications of HTOperators to either an HTTensor or HTOperator.
class HTStack final : public TensorNetwork {
public:
///@brief The number of external links in each node, i.e. one for HTTensors and two for HTOperators.
static constexpr const size_t N = isOperator?2:1;
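///@brief Flag indicating whether the stack still has to be canonicalized when it is evaluated.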
const bool cannonicalization_required;
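///@brief The core position the evaluated network is expected to have (only meaningful if cannonicalization_required).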
const size_t futureCorePosition;
explicit HTStack(const bool _canno, const size_t _corePos = 0);
HTStack(const HTStack& _other) = default;
HTStack( HTStack&& _other) = default;
virtual TensorNetwork* get_copy() const override;
HTStack& operator= (const HTStack& _other) = delete;
HTStack& operator= ( HTStack&& _other) = delete;
explicit operator HTNetwork<isOperator>();
virtual void operator*=(const value_t _factor) override;
virtual void operator/=(const value_t _divisor) override;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Operator specializations - - - - - - - - - - - - - - - - - - - - - - - - - - */
virtual void specialized_evaluation(IndexedTensorWritable<TensorNetwork>&& , IndexedTensorReadOnly<TensorNetwork>&&) override;
virtual bool specialized_contraction(std::unique_ptr<IndexedTensorMoveable<TensorNetwork>>& _out, IndexedTensorReadOnly<TensorNetwork>&& _me, IndexedTensorReadOnly<TensorNetwork>&& _other) const override;
// virtual bool specialized_sum(std::unique_ptr<IndexedTensorMoveable<TensorNetwork>>& _out, IndexedTensorReadOnly<TensorNetwork>&& _me, IndexedTensorReadOnly<TensorNetwork>&& _other) const override;
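///@brief Contracts the complete stack represented by _me into a single network.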
static void contract_stack(IndexedTensorWritable<TensorNetwork>&& _me);
virtual value_t frob_norm() const override;
};
}
}
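// Illustration (not part of the header): an HTStack is never constructed
// directly; it arises when an indexed HTOperator is applied to an indexed
// HTTensor and is contracted on assignment. A minimal sketch, assuming the
// dimensions used by the unit tests below:
//   Index i, j;
//   HTOperator htO = HTOperator::random({10,10,10,10,10,10,10,10}, {4,4,4,4,4,4});
//   HTTensor   htA = HTTensor::random({10,10,10,10}, {4,4,4,4,4,4});
//   HTTensor   htC;
//   htC(i&0) = htO(i/2, j/2) * htA(j&0);  // rhs is an internal::HTStack until evaluated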
......@@ -300,3 +300,42 @@ static misc::UnitTest tensor_solve_matrix("Tensor", "solve_matrix", [](){
MTEST(frob_norm(residual) < 1e-10, frob_norm(residual));
}
});
static misc::UnitTest tensor_solve_w_extra_degree("Tensor", "solve with extra degrees", [](){
Index ii,jj,kk,ll,mm,nn;
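// the trailing integer argument of xerus::solve below is the number of
// extra (uncontracted) modes of X and B, e.g. solve(X, A, B, 1) solves
// A(i,j) * X(j,k) = B(i,k) for the single extra mode k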
Tensor A = xerus::Tensor::random({2,2});
Tensor B = xerus::Tensor::random({2,2});
Tensor X({2,2});
Tensor tmp({2,2});
//non symmetric
xerus::solve(X, A, B,1);
tmp(ii,kk) = A(ii,jj)*X(jj,kk);
TEST((tmp - B).frob_norm() < 1e-13);
//symmetric
A(ii,jj) = A(ii,jj) + A(jj,ii);
xerus::solve(X, A, B,1);
tmp(ii,kk) = A(ii,jj)*X(jj,kk);
TEST((tmp - B).frob_norm() < 1e-13);
//higher order
Tensor A2 = xerus::Tensor::random({5,5,5,5});
Tensor B2 = xerus::Tensor::random({5,5,5,5});
Tensor X2({5,5,5,5});
Tensor tmp2({5,5,5,5});
xerus::solve(X2, A2, B2,2);
tmp2(ii^2,kk^2) = A2(ii^2,jj^2)*X2(jj^2,kk^2);
TEST((tmp2 - B2).frob_norm() < 1e-13);
Tensor A3 = xerus::Tensor::random({5,5,5,5});
Tensor B3 = xerus::Tensor::random({5,5,5});
Tensor X3({5,5,5});
Tensor tmp3({5,5,5});
xerus::solve(X3, A3, B3,1);
tmp3(ii^2,kk) = A3(ii^2,jj^2)*X3(jj^2,kk);
TEST((tmp3 - B3).frob_norm() < 1e-13);
});
......@@ -94,68 +94,71 @@ static misc::UnitTest ht_real_diff("HT", "real_difference", [](){
Index i;
htC(i&0) = htA(i&0) - htA(i&0);
MTEST(frob_norm(htC(i&0)) < 1e-11, "1 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "1 " << frob_norm(htC(i&0)));
htC(i&0) = htB(i&0) - htB(i&0);
MTEST(frob_norm(htC(i&0)) < 1e-11, "2 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "2 " << frob_norm(htC(i&0)));
htC(i&0) = (htA(i&0) + htB(i&0)) - (htA(i&0) + htB(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-11, "3 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "3 " << frob_norm(htC(i&0)));
// htC(i&0) = (htA(i&0) + htB(i&0));
// TEST(htC.ranks() == std::vector<size_t>({8,8,8,8}));
// ttC(i&0) = (ttB(i&0) + ttA(i&0));
// TEST(ttC.ranks() == std::vector<size_t>({8,8,8,8}));
htC(i&0) = (htA(i&0) + htB(i&0));
TEST(htC.ranks() == std::vector<size_t>({ 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 1, 1, 1 }));
htC(i&0) = (htB(i&0) + htA(i&0));
TEST(htC.ranks() == std::vector<size_t>({ 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 1, 1, 1 }));
htC(i&0) = (htA(i&0) + htB(i&0)) - (htB(i&0) + htA(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-11, "4 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "4 " << frob_norm(htC(i&0)));
htC(i&0) = (73*htA(i&0) + htB(i&0)) - (htB(i&0) + 73*htA(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-9, "5 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-7, "5 " << frob_norm(htC(i&0)));
htA = HTTensor::random({10,10,10,10,10}, {2,5,7,2,4,4,3,6,5,2,4});
htC(i&0) = htA(i&0) - htA(i&0);
MTEST(frob_norm(htC(i&0)) < 1e-11, "6 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "6 " << frob_norm(htC(i&0)));
htC(i&0) = htB(i&0) - htB(i&0);
MTEST(frob_norm(htC(i&0)) < 1e-11, "7 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "7 " << frob_norm(htC(i&0)));
htC(i&0) = (htA(i&0) + htB(i&0)) - (htA(i&0) + htB(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-11, "8 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "8 " << frob_norm(htC(i&0)));
htC(i&0) = (htA(i&0) + htB(i&0)) - (htB(i&0) + htA(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-11, "9 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-9, "9 " << frob_norm(htC(i&0)));
htC(i&0) = (73*htA(i&0) + htB(i&0)) - (htB(i&0) + 73*htA(i&0));
MTEST(frob_norm(htC(i&0)) < 5e-10, "10 " << frob_norm(htC(i&0)));
MTEST(frob_norm(htC(i&0)) < 1e-8, "10 " << frob_norm(htC(i&0)));
});
//
//static misc::UnitTest tt_diff_stacks("TT", "difference_of_TTStacks", [](){
// TTOperator ttO = TTOperator::random({10,10,10,10,10,10,10,10,10,10}, {4,4,4,4});
// TTTensor ttA = TTTensor::random({10,10,10,10,10}, {4,4,4,4});
// TTTensor ttB = TTTensor::random({10,10,10,10,10}, {4,4,4,4});
// TTTensor ttC;
//
// Index i,j,k;
// ttC(i&0) = ttO(i/2, j/2)*ttA(j&0) - ttO(i/2, j/2)*ttA(j&0);
// LOG(unit_tests, "Frob norm 1 " << frob_norm(ttC(i&0)));
// TEST(frob_norm(ttC(i&0)) < 1e-7);
//
// ttC(i&0) = ttO(i/2, j/2)*ttB(j&0) - ttO(i/2, j/2)*ttB(j&0);
// LOG(unit_tests, "Frob norm 2 " << frob_norm(ttC(i&0)));
// TEST(frob_norm(ttC(i&0)) < 1e-7);
//});
//
//static misc::UnitTest tt_stack_norm("TT", "ttStacks_frob_norm", [](){
static misc::UnitTest ht_diff_stacks("HT", "difference_of_HTStacks", [](){
HTOperator htO = HTOperator::random({10,10,10,10,10,10,10,10}, {4,4,4,4,4,4});
HTTensor htA = HTTensor::random({10,10,10,10}, {4,4,4,4,4,4});
HTTensor htB = HTTensor::random({10,10,10,10}, {4,4,4,4,4,4});
HTTensor htC;
Index i,j,k;
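// subtracting two evaluations of the identical operator-times-tensor stack
// must yield (numerically) zero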
htC(i&0) = htO(i/2, j/2)*htA(j&0) - htO(i/2, j/2)*htA(j&0);
LOG(unit_tests, "Frob norm 1 " << frob_norm(htC(i&0)));
TEST(frob_norm(htC(i&0)) < 1e-7);
htC(i&0) = htO(i/2, j/2)*htB(j&0) - htO(i/2, j/2)*htB(j&0);
LOG(unit_tests, "Frob norm 2 " << frob_norm(htC(i&0)));
TEST(frob_norm(htC(i&0)) < 1e-7);
});
//static misc::UnitTest ht_stack_norm("HT", "htStacks_frob_norm", [](){
// const Index i, j, k;
//
// TTOperator ttO1 = TTOperator::identity({10,10,10,10,10,10,10,10,10,10});
// TTOperator ttO2 = TTOperator::identity({10,10,10,10,10,10,10,10,10,10});
// HTOperator htO1 = HTOperator::identity({10,10,10,10,10,10,10,10,10,10});
// HTOperator htO2 = HTOperator::identity({10,10,10,10,10,10,10,10,10,10});
//
// MTEST(misc::approx_equal(frob_norm(ttO1(i&0)*ttO2(i&0)), double(misc::pow(10, 5))), frob_norm(ttO1(i&0)*ttO2(i&0)) << " vs " << misc::pow(10, 5));
// MTEST(misc::approx_equal(frob_norm(htO1(i&0)*htO2(i&0)), double(misc::pow(10, 5))), frob_norm(htO1(i&0)*htO2(i&0)) << " vs " << misc::pow(10, 5));
//
// TEST(misc::approx_equal(frob_norm(ttO1(i/2, j/2)*ttO2(j/2, k/2)), std::sqrt(misc::pow(10, 5))));
// TEST(misc::approx_equal(frob_norm(htO1(i/2, j/2)*htO2(j/2, k/2)), std::sqrt(misc::pow(10, 5))));
//});
//
static misc::UnitTest ht_spec_sumdiff("HT", "special_sum_diff", [](){
Tensor A({10,10,10,10}); // NOTE that this is the 0 tensor
Tensor B = Tensor::random({10,10,10,10});
......@@ -169,23 +172,23 @@ static misc::UnitTest ht_spec_sumdiff("HT", "special_sum_diff", [](){
C(i&0) = A(i&0) + B(i&0);
htC(i&0) = htA(i&0) + htB(i&0);
TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13);
TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12);
TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12);
C(i&0) = B(i&0) + A(i&0);
htC(i&0) = htB(i&0) + htA(i&0);
TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13);
TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12);
TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12);
C(i&0) = A(i&0) - B(i&0);
htC(i&0) = htA(i&0) - htB(i&0);
MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13, frob_norm(Tensor(htC)(i&0) - C(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)) < 3.1*1e-13, frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12, frob_norm(Tensor(htC)(i&0) - C(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)) < 3.1*1e-12, frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)));
C(i&0) = B(i&0) - A(i&0);
htC(i&0) = htB(i&0) - htA(i&0);
MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13, frob_norm(Tensor(htC)(i&0) - C(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-13, frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12, frob_norm(Tensor(htC)(i&0) - C(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12, frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)));
Tensor X({10});
Tensor Y = Tensor::random({10});
......@@ -196,23 +199,23 @@ static misc::UnitTest ht_spec_sumdiff("HT", "special_sum_diff", [](){
Z(i&0) = X(i&0) + Y(i&0);
htZ(i&0) = htX(i&0) + htY(i&0);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-12);
Z(i&0) = Y(i&0) + X(i&0);
htZ(i&0) = htY(i&0) + htX(i&0);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-12);
Z(i&0) = X(i&0) - Y(i&0);
htZ(i&0) = htX(i&0) - htY(i&0);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) + Tensor(htY)(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
TEST(frob_norm(Tensor(htZ)(i&0) + Tensor(htY)(i&0)) < 3.1*1e-12);
Z(i&0) = Y(i&0) - X(i&0);
htZ(i&0) = htY(i&0) - htX(i&0);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-13);
TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-12);
});
static misc::UnitTest ht_prod("HT", "product", [](){
......@@ -237,7 +240,7 @@ static misc::UnitTest ht_prod("HT", "product", [](){
C(i^2,k^2) = A(i^2,j^2) * B(j^2,k^2);
htC(i^2,k^2) = htA(i^2,j^2) * htB(j^2,k^2);
TEST(htC.nodes.size() == 8);
TEST(htC.nodes.size() == 4);
fnorm = frob_norm(Tensor(htC)(i&0) - C(i&0));
LOG(unit_tests, "frob_norm " << fnorm);
TEST(fnorm < 10*10*10*10*1e-15);
......@@ -378,9 +381,9 @@ static misc::UnitTest ht_axb("HT", "ax_b", [](){
});
static misc::UnitTest ht_opt("HT", "operator_times_tensor", [](){
Tensor A = Tensor::random({10,10,10,10});
Tensor B = Tensor::random({10,10,10,10});
Tensor C = Tensor::random({10,10});
Tensor A = Tensor::random({6,6,6,6});
Tensor B = Tensor::random({6,6,6,6});
Tensor C = Tensor::random({6,6});
Tensor D;
Tensor Do;
HTOperator htA(A);
......
......@@ -564,7 +564,7 @@ namespace xerus {
std::unique_ptr<int[]> pivot(new int[_n]);
misc::copy(_x, _b, _n);
misc::copy(_x, _b, _n*_nrhs);
IF_CHECK( int lapackAnswer = ) LAPACKE_dgesv(
LAPACK_ROW_MAJOR,
......@@ -604,7 +604,7 @@ namespace xerus {
LOG(debug, "cholesky");
XERUS_PA_START;
misc::copy(_x, _b, _n);
misc::copy(_x, _b, _n*_nrhs);
lapackAnswer = LAPACKE_dpotrs(
LAPACK_ROW_MAJOR,
......@@ -631,7 +631,7 @@ namespace xerus {
// non-definite diagonal or cholesky failed -> fallback to LDL^T decomposition
XERUS_PA_START;
misc::copy(_x, _b, _n);
misc::copy(_x, _b, _n*_nrhs);
std::unique_ptr<int[]> pivot(new int[_n]);
LAPACKE_dsysv(
......
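The three hunks above fix the same bug: with _nrhs right-hand sides, _b is an _n-by-_nrhs matrix, so copying only _n entries into _x silently dropped all but the first column. A minimal standalone sketch of the multi-RHS convention (dimensions and data are illustrative assumptions):

    #include <lapacke.h>
    #include <vector>
    int main() {
        const int n = 3, nrhs = 2;
        std::vector<double> A = {4,1,0, 1,3,1, 0,1,2};   // n x n, row-major
        std::vector<double> B = {1,0, 0,1, 1,1};         // n x nrhs, row-major
        std::vector<double> X(B);                        // copy all n*nrhs entries, not just n
        std::vector<int> ipiv(n);
        // solves A * X = B for every column of B at once; ldb == nrhs in row-major layout
        return LAPACKE_dgesv(LAPACK_ROW_MAJOR, n, nrhs, A.data(), n, ipiv.data(), X.data(), nrhs);
    }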
......@@ -19,7 +19,7 @@
/**
* @file
* @brief Definition of the TT-Network python bindings.
* @brief Definition of the HT-Network python bindings.
*/
......@@ -40,7 +40,6 @@ void expose_htnetwork() {
.def("set_component", &HTTensor::set_component)
.def_readonly("canonicalized", &HTTensor::canonicalized)
.def_readonly("corePosition", &HTTensor::corePosition)
.def_readonly("numberOfComponents", &HTTensor::numberOfComponents)
.def("ranks", &HTTensor::ranks)
.def("rank", &HTTensor::rank)
// .def("frob_norm", &TTTensor::frob_norm) // NOTE unneccessary because correct call is inherited
......@@ -48,15 +47,15 @@ void expose_htnetwork() {
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::HTTensor::random(_dim, _rank);
}).staticmethod("random")
.def("ones", &TTTensor::ones).staticmethod("ones")
.def("kronecker", &TTTensor::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTTensor::dirac))
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const size_t)>(&TTTensor::dirac)).staticmethod("dirac")
.def("ones", &HTTensor::ones).staticmethod("ones")
//.def("kronecker", &TTTensor::kronecker).staticmethod("kronecker")
//.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTTensor::dirac))
//.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const size_t)>(&TTTensor::dirac)).staticmethod("dirac")
.def("use_dense_representations", &TTTensor::use_dense_representations)
.def("reduce_to_maximal_ranks", &TTTensor::reduce_to_maximal_ranks).staticmethod("reduce_to_maximal_ranks")
//.def("use_dense_representations", &TTTensor::use_dense_representations)
//.def("reduce_to_maximal_ranks", &TTTensor::reduce_to_maximal_ranks).staticmethod("reduce_to_maximal_ranks")
//// .def("degrees_of_freedom", static_cast<size_t (TTTensor::*)()>(&TTTensor::degrees_of_freedom))
.def("degrees_of_freedom", static_cast<size_t (*)(const std::vector<size_t>&, const std::vector<size_t>&)>(&TTTensor::degrees_of_freedom)).staticmethod("degrees_of_freedom")
//.def("degrees_of_freedom", static_cast<size_t (*)(const std::vector<size_t>&, const std::vector<size_t>&)>(&TTTensor::degrees_of_freedom)).staticmethod("degrees_of_freedom")
// .def("chop",
// +[](HTTensor &_this, size_t _pos) {
// const auto result = _this.chop(_pos);
......@@ -105,7 +104,6 @@ void expose_htnetwork() {
.def("set_component", &HTOperator::set_component)
.def_readonly("canonicalized", &HTOperator::canonicalized)
.def_readonly("corePosition", &HTOperator::corePosition)
.def_readonly("numberOfComponents", &HTOperator::numberOfComponents)
.def("ranks", &HTOperator::ranks)
.def("rank", &HTOperator::rank)
// .def("frob_norm", &TTOperator::frob_norm) // NOTE unneccessary because correct call is inherited
......@@ -113,26 +111,26 @@ void expose_htnetwork() {
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::HTOperator::random(_dim, _rank);
}).staticmethod("random")
.def("ones", &TTOperator::ones).staticmethod("ones")
.def("kronecker", &TTOperator::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTOperator::dirac))
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const size_t)>(&TTOperator::dirac)).staticmethod("dirac")
.def("ones", &HTOperator::ones).staticmethod("ones")
//.def("kronecker", &TTOperator::kronecker).staticmethod("kronecker")
//.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTOperator::dirac))
//.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const size_t)>(&TTOperator::dirac)).staticmethod("dirac")
.def("use_dense_representations", &TTOperator::use_dense_representations)
.def("reduce_to_maximal_ranks", &TTOperator::reduce_to_maximal_ranks).staticmethod("reduce_to_maximal_ranks")
//.def("use_dense_representations", &TTOperator::use_dense_representations)
//.def("reduce_to_maximal_ranks", &TTOperator::reduce_to_maximal_ranks).staticmethod("reduce_to_maximal_ranks")
// .def("degrees_of_freedom", static_cast<size_t (TTOperator::*)()>(&TTOperator::degrees_of_freedom))
.def("degrees_of_freedom", static_cast<size_t (*)(const std::vector<size_t>&, const std::vector<size_t>&)>(&TTOperator::degrees_of_freedom)).staticmethod("degrees_of_freedom")
//.def("degrees_of_freedom", static_cast<size_t (*)(const std::vector<size_t>&, const std::vector<size_t>&)>(&TTOperator::degrees_of_freedom)).staticmethod("degrees_of_freedom")
// .def("chop",
// +[](HTOperator &_this, size_t _pos) {
// const auto result = _this.chop(_pos);
// return boost::python::make_tuple(result.first, result.second);
// }, arg("position"))
//
.def("round", static_cast<void (TTOperator::*)(const std::vector<size_t>&, double)>(&TTOperator::round),
.def("round", static_cast<void (HTOperator::*)(const std::vector<size_t>&, double)>(&HTOperator::round),
(arg("ranks"), arg("epsilon")=EPSILON)
)
.def("round", static_cast<void (TTOperator::*)(double)>(&TTOperator::round))
.def("round", static_cast<void (TTOperator::*)(size_t)>(&TTOperator::round))
.def("round", static_cast<void (HTOperator::*)(double)>(&HTOperator::round))
.def("round", static_cast<void (HTOperator::*)(size_t)>(&HTOperator::round))
.def("soft_threshold", static_cast<void (HTOperator::*)(const double, const bool)>(&HTOperator::soft_threshold),
(arg("tau"), arg("preventZero")=false)
......@@ -158,7 +156,7 @@ void expose_htnetwork() {
//
// // for TTOperator only:
// .def("identity", &TTOperator::identity<>).staticmethod("identity")
.def("transpose", &TTOperator::transpose<>)
.def("transpose", &HTOperator::transpose<>)
;
// def("entrywise_product", static_cast<TTOperator (*)(const TTOperator&, const TTOperator&)>(&entrywise_product));
// def("find_largest_entry", static_cast<size_t (*)(const TTOperator&, value_t, value_t)>(&find_largest_entry));
......