Commit 7427977b authored by Michael Goette's avatar Michael Goette

added move core functionality on binary tree - next steps testing

parent 20dc3048
Pipeline #919 failed with stages
in 2 minutes and 42 seconds
......@@ -173,7 +173,7 @@ namespace xerus {
}
}
}
//result.move_core(0);
result.move_core(0);
return result;
}
......@@ -277,7 +277,7 @@ namespace xerus {
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Internal helper functions - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
// protected:
protected:
///@brief Constructs a TTNetwork in _out by decomposing the given Tensor _A.
// static void construct_train_from_full(TensorNetwork& _out, const Tensor& _A, const double _eps);
......@@ -293,8 +293,26 @@ namespace xerus {
///@brief Return the number of ranks, i.e. 0 for degree zero and degree()/N-1 otherwise.
size_t num_ranks() const;
/**
* @brief returns the path from one node to another in the binary tree
* @details this function is used to shift the core tensor along this path when the core tensor is moved, e.g. in a level 3 hierachical where 0 is the root
* and 3,4,5,6 are the leaves the path from 1 to 6 would be 1 -> 0 -> 2 -> 6
* @param start node
* @param end node
*/
std::vector<size_t> get_path(size_t start, size_t end) const;
/**
* @brief function to recursively find the path from the root to a destination
* @param root starting point for the downward search is 0 for the first call
* @param dest destination node
* @param path path from root to dest
*/
bool get_path_from_root(size_t root, size_t dest, std::vector<size_t>& path ) const;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Miscellaneous - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
// public:
public:
/**
* @brief Reduces the given ranks to the maximal possible.
* @details If a given rank is already smaller or equal it is left unchanged.
......@@ -335,7 +353,7 @@ namespace xerus {
/**
* @brief Read access to a specific component of the TT decomposition.
* @brief Read access to a specific component of the HT decomposition.
* @details This function should be used to access the components, instead of direct access via
* nodes[...], because the implementation does not store the first component in nodes[0] but rather as
* nodes[1] etc. nodes[0] is an order one node with dimension one only used to allow the first component
......@@ -343,7 +361,7 @@ namespace xerus {
* @param _idx index of the component to access.
* @returns a const reference to the requested component.
*/
// const Tensor& get_component(const size_t _idx) const;
const Tensor& get_component(const size_t _idx) const;
/**
......@@ -429,12 +447,12 @@ namespace xerus {
/**
* @brief Move the core to a new position.
* @details The core is moved to @a _position and the nodes between the old and the new position are orthogonalized
* accordingly. If the TTNetwork is not yet canonicalized it will be with @a _position as new corePosition.
* accordingly. If the HTNetwork is not yet canonicalized it will be with @a _position as new corePosition.
* @param _position the new core position.
* @param _keepRank by default a rank revealing QR decomposition is used to move the core and the ranks are reduced
* accordingly. If @a _keepRank is set the rank is not reduced, this is need e.g. in the ALS.
*/
// void move_core(const size_t _position, const bool _keepRank=false);
void move_core(const size_t _position, const bool _keepRank=false);
/**
......@@ -488,6 +506,7 @@ namespace xerus {
// virtual void require_correct_format() const override;
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Basic arithmetics - - - - - - - - - - - - - - - - - - - - - - - - - - */
/**
* @brief Adds a given TTNetwork to this one.
......
......@@ -255,16 +255,16 @@ namespace xerus {
//
//
// template<bool isOperator>
// TTNetwork<isOperator> TTNetwork<isOperator>::ones(const std::vector<size_t>& _dimensions) {
// HTNetwork<isOperator> HTNetwork<isOperator>::ones(const std::vector<size_t>& _dimensions) {
// REQUIRE(_dimensions.size()%N == 0, "Illegal number of dimensions for ttOperator");
// REQUIRE(!misc::contains(_dimensions, size_t(0)), "Trying to construct a TTTensor with dimension 0 is not possible.");
//
// if(_dimensions.empty()) {
// return TTNetwork(Tensor::ones({}));
// return HTNetwork(Tensor::ones({}));
// }
//
// TTNetwork result(_dimensions.size());
// const size_t numNodes = _dimensions.size()/N;
// HTNetwork result(_dimensions.size());
// const size_t numNodes = result.numberOfComponents;
//
// std::vector<size_t> dimensions(isOperator ? 4 : 3, 1);
// for(size_t i = 0; i < numNodes; ++i) {
......@@ -452,6 +452,39 @@ namespace xerus {
return degree() == 0 ? 0 : numberOfComponents - 1;
}
template<bool isOperator>
std::vector<size_t> HTNetwork<isOperator>::get_path(size_t start, size_t end) const {
std::vector<size_t> path_start;
std::vector<size_t> path_end;
std::vector<size_t> result;
REQUIRE(get_path_from_root(0, start, path_start ), "start point is wrong");
REQUIRE(get_path_from_root(0, end, path_end ), "end point is wrong");
while(!path_start.empty()){
size_t tmp = path_start.back();
path_start.pop_back();
auto tmp_found = std::find(path_end.begin(), path_end.end(), tmp);
if (path_end.end() == tmp_found){ result.emplace_back(tmp);}
else{
result.insert(result.end(), tmp_found, path_end.end());
break;
}
}
std::reverse(result.begin(),result.end());
return result;
}
template<bool isOperator>
bool HTNetwork<isOperator>::get_path_from_root(size_t root, size_t dest, std::vector<size_t>& path ) const {
if (root > numberOfComponents) { return false;}
path.emplace_back(root);
if (root == dest) { return true;}
if(get_path_from_root(root*2+1,dest,path) || get_path_from_root(root*2+2,dest,path)) {return true;}
path.pop_back();
return false;
}
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Miscellaneous - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
// template<bool isOperator>
......@@ -547,14 +580,14 @@ namespace xerus {
// return *nodes[degree() == 0 ? 0 : _idx+1].tensorObject;
// }
//
//
// template<bool isOperator>
// const Tensor& TTNetwork<isOperator>::get_component(const size_t _idx) const {
// REQUIRE(_idx == 0 || _idx < degree()/N, "Illegal index " << _idx <<" in TTNetwork::get_component.");
// return *nodes[degree() == 0 ? 0 : _idx+1].tensorObject;
// }
//
//
template<bool isOperator>
const Tensor& HTNetwork<isOperator>::get_component(const size_t _idx) const {
REQUIRE(_idx <= 0 || _idx < numberOfComponents, "Illegal index " << _idx <<" in TTNetwork::get_component.");
return *nodes[degree() == 0 ? 0 : _idx].tensorObject;
}
template<bool isOperator>
void HTNetwork<isOperator>::set_component(const size_t _idx, Tensor _T) {
if(degree() == 0) {
......@@ -671,34 +704,28 @@ namespace xerus {
// }
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::move_core(const size_t _position, const bool _keepRank) {
// const size_t numComponents = degree()/N;
// REQUIRE(_position < numComponents || (_position == 0 && degree() == 0), "Illegal core-position " << _position << " chosen for TTNetwork with " << numComponents << " components");
// require_correct_format();
//
// if(canonicalized) {
// // Move right?
// for (size_t n = corePosition; n < _position; ++n) {
// transfer_core(n+1, n+2, !_keepRank);
// }
//
// // Move left?
// for (size_t n = corePosition; n > _position; --n) {
// transfer_core(n+1, n, !_keepRank);
// }
// } else {
// // Move right?
// for (size_t n = 0; n < _position; ++n) {
// transfer_core(n+1, n+2, !_keepRank);
// }
//
// // Move left?
// for (size_t n = numComponents; n > _position+1; --n) {
// transfer_core(n, n-1, !_keepRank);
// }
// }
//
template<bool isOperator>
void HTNetwork<isOperator>::move_core(const size_t _position, const bool _keepRank) {
const size_t numComponents = numberOfComponents;
REQUIRE(_position < numComponents || (_position == 0 && degree() == 0), "Illegal core-position " << _position << " chosen for TTNetwork with " << numComponents << " components");
require_correct_format();
if (!canonicalized){
for (size_t n = numComponents - 1; n > _position; --n) {
transfer_core(n, (n + 1) / 2 - 1, !_keepRank);
corePosition = 0;
}
}
std::vector<size_t> path = get_path(corePosition, _position);
while (path.size() > 1){
size_t start = path.back();
path.pop_back();
size_t end = path.back();
transfer_core(start, end, !_keepRank);
}
// TODO whz is this here??
// while (exceeds_maximal_ranks()) {
// // Move left from given CorePosition
// for (size_t n = _position; n > 0; --n) {
......@@ -715,10 +742,10 @@ namespace xerus {
// transfer_core(n, n-1, !_keepRank);
// }
// }
//
// canonicalized = true;
// corePosition = _position;
// }
canonicalized = true;
corePosition = _position;
}
//
//
// template<bool isOperator>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment