diff --git a/include/xerus/htNetwork.h b/include/xerus/htNetwork.h index 5e209bf015ccae779db3bb11d813022980c780f5..3400754ef7f291f4360641be5bc9a88d8df63e45 100644 --- a/include/xerus/htNetwork.h +++ b/include/xerus/htNetwork.h @@ -173,7 +173,7 @@ namespace xerus { } } } - //result.move_core(0); + result.move_core(0); return result; } @@ -277,7 +277,7 @@ namespace xerus { /*- - - - - - - - - - - - - - - - - - - - - - - - - - Internal helper functions - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/ -// protected: + protected: ///@brief Constructs a TTNetwork in _out by decomposing the given Tensor _A. // static void construct_train_from_full(TensorNetwork& _out, const Tensor& _A, const double _eps); @@ -293,8 +293,26 @@ namespace xerus { ///@brief Return the number of ranks, i.e. 0 for degree zero and degree()/N-1 otherwise. size_t num_ranks() const; + /** + * @brief returns the path from one node to another in the binary tree + * @details this function is used to shift the core tensor along this path when the core tensor is moved, e.g. in a level 3 hierachical where 0 is the root + * and 3,4,5,6 are the leaves the path from 1 to 6 would be 1 -> 0 -> 2 -> 6 + * @param start node + * @param end node + */ + std::vector get_path(size_t start, size_t end) const; + + /** + * @brief function to recursively find the path from the root to a destination + * @param root starting point for the downward search is 0 for the first call + * @param dest destination node + * @param path path from root to dest + */ + bool get_path_from_root(size_t root, size_t dest, std::vector& path ) const; + /*- - - - - - - - - - - - - - - - - - - - - - - - - - Miscellaneous - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/ -// public: + public: + /** * @brief Reduces the given ranks to the maximal possible. * @details If a given rank is already smaller or equal it is left unchanged. @@ -335,7 +353,7 @@ namespace xerus { /** - * @brief Read access to a specific component of the TT decomposition. + * @brief Read access to a specific component of the HT decomposition. * @details This function should be used to access the components, instead of direct access via * nodes[...], because the implementation does not store the first component in nodes[0] but rather as * nodes[1] etc. nodes[0] is an order one node with dimension one only used to allow the first component @@ -343,7 +361,7 @@ namespace xerus { * @param _idx index of the component to access. * @returns a const reference to the requested component. */ -// const Tensor& get_component(const size_t _idx) const; + const Tensor& get_component(const size_t _idx) const; /** @@ -429,12 +447,12 @@ namespace xerus { /** * @brief Move the core to a new position. * @details The core is moved to @a _position and the nodes between the old and the new position are orthogonalized - * accordingly. If the TTNetwork is not yet canonicalized it will be with @a _position as new corePosition. + * accordingly. If the HTNetwork is not yet canonicalized it will be with @a _position as new corePosition. * @param _position the new core position. * @param _keepRank by default a rank revealing QR decomposition is used to move the core and the ranks are reduced * accordingly. If @a _keepRank is set the rank is not reduced, this is need e.g. in the ALS. */ -// void move_core(const size_t _position, const bool _keepRank=false); + void move_core(const size_t _position, const bool _keepRank=false); /** @@ -488,6 +506,7 @@ namespace xerus { // virtual void require_correct_format() const override; + /*- - - - - - - - - - - - - - - - - - - - - - - - - - Basic arithmetics - - - - - - - - - - - - - - - - - - - - - - - - - - */ /** * @brief Adds a given TTNetwork to this one. diff --git a/src/xerus/htNetwork.cpp b/src/xerus/htNetwork.cpp index cb7c0200e618a3b27dadfb074c1230449370b203..65b77a1429d0d8ce54cc3876c4507bc35a80c391 100644 --- a/src/xerus/htNetwork.cpp +++ b/src/xerus/htNetwork.cpp @@ -255,16 +255,16 @@ namespace xerus { // // // template -// TTNetwork TTNetwork::ones(const std::vector& _dimensions) { +// HTNetwork HTNetwork::ones(const std::vector& _dimensions) { // REQUIRE(_dimensions.size()%N == 0, "Illegal number of dimensions for ttOperator"); // REQUIRE(!misc::contains(_dimensions, size_t(0)), "Trying to construct a TTTensor with dimension 0 is not possible."); // // if(_dimensions.empty()) { -// return TTNetwork(Tensor::ones({})); +// return HTNetwork(Tensor::ones({})); // } // -// TTNetwork result(_dimensions.size()); -// const size_t numNodes = _dimensions.size()/N; +// HTNetwork result(_dimensions.size()); +// const size_t numNodes = result.numberOfComponents; // // std::vector dimensions(isOperator ? 4 : 3, 1); // for(size_t i = 0; i < numNodes; ++i) { @@ -452,6 +452,39 @@ namespace xerus { return degree() == 0 ? 0 : numberOfComponents - 1; } + template + std::vector HTNetwork::get_path(size_t start, size_t end) const { + std::vector path_start; + std::vector path_end; + std::vector result; + + REQUIRE(get_path_from_root(0, start, path_start ), "start point is wrong"); + REQUIRE(get_path_from_root(0, end, path_end ), "end point is wrong"); + while(!path_start.empty()){ + size_t tmp = path_start.back(); + path_start.pop_back(); + auto tmp_found = std::find(path_end.begin(), path_end.end(), tmp); + if (path_end.end() == tmp_found){ result.emplace_back(tmp);} + else{ + result.insert(result.end(), tmp_found, path_end.end()); + break; + } + } + std::reverse(result.begin(),result.end()); + return result; + } + + template + bool HTNetwork::get_path_from_root(size_t root, size_t dest, std::vector& path ) const { + if (root > numberOfComponents) { return false;} + path.emplace_back(root); + if (root == dest) { return true;} + if(get_path_from_root(root*2+1,dest,path) || get_path_from_root(root*2+2,dest,path)) {return true;} + path.pop_back(); + return false; + + } + /*- - - - - - - - - - - - - - - - - - - - - - - - - - Miscellaneous - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/ // template @@ -547,14 +580,14 @@ namespace xerus { // return *nodes[degree() == 0 ? 0 : _idx+1].tensorObject; // } // -// -// template -// const Tensor& TTNetwork::get_component(const size_t _idx) const { -// REQUIRE(_idx == 0 || _idx < degree()/N, "Illegal index " << _idx <<" in TTNetwork::get_component."); -// return *nodes[degree() == 0 ? 0 : _idx+1].tensorObject; -// } -// // + template + const Tensor& HTNetwork::get_component(const size_t _idx) const { + REQUIRE(_idx <= 0 || _idx < numberOfComponents, "Illegal index " << _idx <<" in TTNetwork::get_component."); + return *nodes[degree() == 0 ? 0 : _idx].tensorObject; + } + + template void HTNetwork::set_component(const size_t _idx, Tensor _T) { if(degree() == 0) { @@ -671,34 +704,28 @@ namespace xerus { // } // // -// template -// void TTNetwork::move_core(const size_t _position, const bool _keepRank) { -// const size_t numComponents = degree()/N; -// REQUIRE(_position < numComponents || (_position == 0 && degree() == 0), "Illegal core-position " << _position << " chosen for TTNetwork with " << numComponents << " components"); -// require_correct_format(); -// -// if(canonicalized) { -// // Move right? -// for (size_t n = corePosition; n < _position; ++n) { -// transfer_core(n+1, n+2, !_keepRank); -// } -// -// // Move left? -// for (size_t n = corePosition; n > _position; --n) { -// transfer_core(n+1, n, !_keepRank); -// } -// } else { -// // Move right? -// for (size_t n = 0; n < _position; ++n) { -// transfer_core(n+1, n+2, !_keepRank); -// } -// -// // Move left? -// for (size_t n = numComponents; n > _position+1; --n) { -// transfer_core(n, n-1, !_keepRank); -// } -// } -// + template + void HTNetwork::move_core(const size_t _position, const bool _keepRank) { + const size_t numComponents = numberOfComponents; + REQUIRE(_position < numComponents || (_position == 0 && degree() == 0), "Illegal core-position " << _position << " chosen for TTNetwork with " << numComponents << " components"); + require_correct_format(); + + if (!canonicalized){ + for (size_t n = numComponents - 1; n > _position; --n) { + transfer_core(n, (n + 1) / 2 - 1, !_keepRank); + corePosition = 0; + } + } + std::vector path = get_path(corePosition, _position); + + while (path.size() > 1){ + size_t start = path.back(); + path.pop_back(); + size_t end = path.back(); + transfer_core(start, end, !_keepRank); + } + +// TODO whz is this here?? // while (exceeds_maximal_ranks()) { // // Move left from given CorePosition // for (size_t n = _position; n > 0; --n) { @@ -715,10 +742,10 @@ namespace xerus { // transfer_core(n, n-1, !_keepRank); // } // } -// -// canonicalized = true; -// corePosition = _position; -// } + + canonicalized = true; + corePosition = _position; + } // // // template