Commit a8222295 authored by Michael Goette

updates ht network

parent 6e5a86ab
Pipeline #1022 failed with stages in 3 minutes and 7 seconds
@@ -62,10 +62,6 @@ namespace xerus {
*/
size_t corePosition;
- /**
- * @brief Number of Components
- */
- size_t numberOfComponents;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Constructors - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
/**
@@ -337,7 +333,10 @@ namespace xerus {
*/
bool is_left_child(size_t _comp) const;
+ /**
+ * @brief function which returns the number of components in an HTTensor
+ */
+ size_t get_number_of_components() const;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Miscellaneous - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
public:
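The member numberOfComponents removed above is now exposed only through the new accessor. A minimal usage sketch, assuming the HTNetwork branch of xerus with HTTensor constructed from a dense Tensor as in the unit tests further below (dimensions are illustrative):

#include <xerus.h>
#include <iostream>

using namespace xerus;

int main() {
	// Build an order-4 HT tensor from a dense tensor.
	const Tensor A = Tensor::random({6,6,6,6});
	const HTTensor htA(A);
	// Replaces direct reads of the removed numberOfComponents member.
	std::cout << htA.get_number_of_components() << std::endl;
	return 0;
}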
@@ -513,8 +512,9 @@ namespace xerus {
template<bool B = isOperator, typename std::enable_if<B, int>::type = 0>
void transpose() {
const std::vector<size_t> shuffle({0,2,1});
+ size_t numComp = get_number_of_components();
//only leaves
- for (size_t n = numberOfComponents - 1; n >= numberOfComponents - degree()/N; --n) {
+ for (size_t n = numComp - 1; n >= numComp - degree()/N; --n) {
xerus::reshuffle(component(n), component(n), shuffle);
}
}
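For context, transpose() only swaps the two external modes of each leaf component, which corresponds to transposing the full operator. A hedged check against the dense transpose, assuming the HTNetwork branch of xerus and the Tensor/HTOperator conversions used in the unit tests further below (names and dimensions are illustrative):

#include <xerus.h>

using namespace xerus;

int main() {
	Index i, j;
	const Tensor A = Tensor::random({6,6,6,6});
	HTOperator htA(A);
	htA.transpose();                  // reshuffles the two external modes of every leaf
	Tensor At;
	At(i^2, j^2) = A(j^2, i^2);       // dense reference transpose
	const value_t err = frob_norm(Tensor(htA)(i&0) - At(i&0));
	return err < 1e-12 ? 0 : 1;       // expect agreement up to rounding
}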
@@ -572,19 +572,19 @@ namespace xerus {
virtual void operator/=(const value_t _divisor) override;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Operator specializations - - - - - - - - - - - - - - - - - - - - - - - - - - */
- // static bool specialized_contraction_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
+ static bool specialized_contraction_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
// static bool specialized_sum_f(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other);
- // virtual bool specialized_contraction(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
- // return specialized_contraction_f(_out, std::move(_me), std::move(_other));
- // }
+ virtual bool specialized_contraction(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
+ return specialized_contraction_f(_out, std::move(_me), std::move(_other));
+ }
//
// virtual bool specialized_sum(std::unique_ptr<internal::IndexedTensorMoveable<TensorNetwork>>& _out, internal::IndexedTensorReadOnly<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) const override {
// return specialized_sum_f(_out, std::move(_me), std::move(_other));
// }
//
- // virtual void specialized_evaluation(internal::IndexedTensorWritable<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) override;
+ virtual void specialized_evaluation(internal::IndexedTensorWritable<TensorNetwork>&& _me, internal::IndexedTensorReadOnly<TensorNetwork>&& _other) override;
};
......
// Xerus - A General Purpose Tensor Library
// Copyright (C) 2014-2018 Benjamin Huber and Sebastian Wolf.
//
// Xerus is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// Xerus is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with Xerus. If not, see <http://www.gnu.org/licenses/>.
//
// For further information on Xerus visit https://libXerus.org
// or contact us at contact@libXerus.org.
/**
* @file
* @brief Header file for the HTStack class.
*/
#pragma once
#include "misc/check.h"
#include "tensorNetwork.h"
#include "indexedTensorMoveable.h"
namespace xerus {
template<bool isOperator> class HTNetwork;
namespace internal {
template<bool isOperator>
///@brief Internal class used to represent stacks of (possibly multiply) applications of HTOperators to either a HTTensor or HTOperator.
class HTStack final : public TensorNetwork {
public:
///@brief The number of external links in each node, i.e. one for HTTensors and two for HTOperators.
static constexpr const size_t N = isOperator?2:1;
const bool cannonicalization_required;
const size_t futureCorePosition;
explicit HTStack(const bool _canno, const size_t _corePos = 0);
HTStack(const HTStack& _other) = default;
HTStack( HTStack&& _other) = default;
virtual TensorNetwork* get_copy() const override;
HTStack& operator= (const HTStack& _other) = delete;
HTStack& operator= ( HTStack&& _other) = delete;
// explicit operator HTNetwork<isOperator>();
virtual void operator*=(const value_t _factor) override;
virtual void operator/=(const value_t _divisor) override;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Operator specializations - - - - - - - - - - - - - - - - - - - - - - - - - - */
virtual void specialized_evaluation(IndexedTensorWritable<TensorNetwork>&& , IndexedTensorReadOnly<TensorNetwork>&&) override;
virtual bool specialized_contraction(std::unique_ptr<IndexedTensorMoveable<TensorNetwork>>& _out, IndexedTensorReadOnly<TensorNetwork>&& _me, IndexedTensorReadOnly<TensorNetwork>&& _other) const override;
// virtual bool specialized_sum(std::unique_ptr<IndexedTensorMoveable<TensorNetwork>>& _out, IndexedTensorReadOnly<TensorNetwork>&& _me, IndexedTensorReadOnly<TensorNetwork>&& _other) const override;
static void contract_stack(IndexedTensorWritable<TensorNetwork>&& _me);
virtual value_t frob_norm() const override;
};
}
}
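HTStack itself is never constructed by user code; it is the intermediate representation that arises when one or more HTOperators are applied inside an indexed expression, after which specialized_evaluation / contract_stack contract it into the result. A sketch of the kind of expression that produces such a stack, assuming the HTNetwork branch of xerus, the constructors used in the unit tests below, and a default-constructible HTTensor (all dimensions illustrative):

#include <xerus.h>

using namespace xerus;

int main() {
	Index i, j, k;
	const Tensor A = Tensor::random({6,6,6,6});
	const Tensor x = Tensor::random({6,6});
	HTOperator htA(A);
	HTTensor htx(x);
	HTTensor hty;
	// Both applications of htA are first gathered in an internal::HTStack;
	// evaluating the assignment then contracts the stack into hty.
	hty(i^2) = htA(i^2, j^2) * htA(j^2, k^2) * htx(k^2);
	return 0;
}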
@@ -169,23 +169,23 @@ static misc::UnitTest ht_spec_sumdiff("HT", "special_sum_diff", [](){
C(i&0) = A(i&0) + B(i&0);
htC(i&0) = htA(i&0) + htB(i&0);
- TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13);
- TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-13);
+ TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12);
+ TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12);
C(i&0) = B(i&0) + A(i&0);
htC(i&0) = htB(i&0) + htA(i&0);
- TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13);
- TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-13);
+ TEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12);
+ TEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12);
C(i&0) = A(i&0) - B(i&0);
htC(i&0) = htA(i&0) - htB(i&0);
- MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13, frob_norm(Tensor(htC)(i&0) - C(i&0)));
- MTEST(frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)) < 3.1*1e-13, frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)));
+ MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12, frob_norm(Tensor(htC)(i&0) - C(i&0)));
+ MTEST(frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)) < 3.1*1e-12, frob_norm(Tensor(htC)(i&0) + Tensor(htB)(i&0)));
C(i&0) = B(i&0) - A(i&0);
htC(i&0) = htB(i&0) - htA(i&0);
- MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-13, frob_norm(Tensor(htC)(i&0) - C(i&0)));
- MTEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-13, frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)));
+ MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12, frob_norm(Tensor(htC)(i&0) - C(i&0)));
+ MTEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12, frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)));
Tensor X({10});
Tensor Y = Tensor::random({10});
@@ -196,23 +196,23 @@ static misc::UnitTest ht_spec_sumdiff("HT", "special_sum_diff", [](){
Z(i&0) = X(i&0) + Y(i&0);
htZ(i&0) = htX(i&0) + htY(i&0);
- TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
- TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-13);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-12);
Z(i&0) = Y(i&0) + X(i&0);
htZ(i&0) = htY(i&0) + htX(i&0);
- TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
- TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-13);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-12);
Z(i&0) = X(i&0) - Y(i&0);
htZ(i&0) = htX(i&0) - htY(i&0);
- TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
- TEST(frob_norm(Tensor(htZ)(i&0) + Tensor(htY)(i&0)) < 3.1*1e-13);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
+ TEST(frob_norm(Tensor(htZ)(i&0) + Tensor(htY)(i&0)) < 3.1*1e-12);
Z(i&0) = Y(i&0) - X(i&0);
htZ(i&0) = htY(i&0) - htX(i&0);
- TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-13);
- TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-13);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Z(i&0)) < 3.1*1e-12);
+ TEST(frob_norm(Tensor(htZ)(i&0) - Tensor(htY)(i&0)) < 3.1*1e-12);
});
static misc::UnitTest ht_prod("HT", "product", [](){
@@ -237,7 +237,7 @@ static misc::UnitTest ht_prod("HT", "product", [](){
C(i^2,k^2) = A(i^2,j^2) * B(j^2,k^2);
htC(i^2,k^2) = htA(i^2,j^2) * htB(j^2,k^2);
- TEST(htC.nodes.size() == 8);
+ TEST(htC.nodes.size() == 4);
fnorm = frob_norm(Tensor(htC)(i&0) - C(i&0));
LOG(unit_tests, "frob_norm " << fnorm);
TEST(fnorm < 10*10*10*10*1e-15);
@@ -378,9 +378,9 @@ static misc::UnitTest ht_axb("HT", "ax_b", [](){
});
static misc::UnitTest ht_opt("HT", "operator_times_tensor", [](){
- Tensor A = Tensor::random({10,10,10,10});
- Tensor B = Tensor::random({10,10,10,10});
- Tensor C = Tensor::random({10,10});
+ Tensor A = Tensor::random({6,6,6,6});
+ Tensor B = Tensor::random({6,6,6,6});
+ Tensor C = Tensor::random({6,6});
Tensor D;
Tensor Do;
HTOperator htA(A);
......