Commit de928700 authored by Michael Goette's avatar Michael Goette

added more tests and simple functionality to HTNetwork

parent 98c57a34
Pipeline #935 passed with stages
in 8 minutes and 41 seconds
......@@ -137,7 +137,7 @@ namespace xerus {
template<class distribution=std::normal_distribution<value_t>, class generator=std::mt19937_64>
static HTNetwork XERUS_warn_unused random(std::vector<size_t> _dimensions, const std::vector<size_t> &_ranks, distribution& _dist=xerus::misc::defaultNormalDistribution, generator& _rnd=xerus::misc::randomEngine) {
const size_t numIntComp = static_cast<size_t>(std::pow(2,std::ceil(std::log2(static_cast<double>(_dimensions.size()/N ))))) - 1;
const size_t numIntComp = static_cast<size_t>(0.5 + std::pow(2,std::ceil(std::log2(static_cast<double>(_dimensions.size()/N ))))) - 1;
const size_t numOfLeaves = _dimensions.size()/N;
const size_t numComponents = numIntComp + numOfLeaves;
......@@ -420,28 +420,28 @@ namespace xerus {
* @param _maxRanks maximal allowed ranks. All current ranks that are larger than the given ones are reduced by truncation.
* @param _eps the accuracy to use for truncation in the individual SVDs.
*/
// void round(const std::vector<size_t>& _maxRanks, const double _eps = EPSILON);
void round(const std::vector<size_t>& _maxRanks, const double _eps = EPSILON);
/**
* @brief Reduce all ranks to the given number.
* @param _maxRank maximal allowed rank. All current ranks that are larger than this are reduced by truncation.
*/
// void round(const size_t _maxRank);
void round(const size_t _maxRank);
/**
* @brief Reduce all ranks to the given number.
* @param _maxRank maximal allowed rank. All current ranks that are larger than this are reduced by truncation.
*/
// void round(const int _maxRank);
void round(const int _maxRank);
/**
* @brief Reduce all ranks up to a given accuracy.
* @param _eps the accuracy to use for truncation in the individual SVDs.
*/
// void round(const value_t _eps);
void round(const value_t _eps);
/**
......@@ -510,13 +510,14 @@ namespace xerus {
* @brief Transpose the TTOperator
* @details Swaps all external indices to create the transposed operator.
*/
// template<bool B = isOperator, typename std::enable_if<B, int>::type = 0>
// void transpose() {
// const std::vector<size_t> shuffle({0,2,1,3});
// for (size_t n = 0; n < degree()/N; ++n) {
// xerus::reshuffle(component(n), component(n), shuffle);
// }
// }
template<bool B = isOperator, typename std::enable_if<B, int>::type = 0>
void transpose() {
const std::vector<size_t> shuffle({0,2,1});
//only leaves
for (size_t n = numberOfComponents - 1; n >= numberOfComponents - degree()/N; --n) {
xerus::reshuffle(component(n), component(n), shuffle);
}
}
virtual TensorNetwork* get_copy() const override;
......
This diff is collapsed.
......@@ -746,8 +746,8 @@ namespace xerus {
REQUIRE(_position < numComponents || (_position == 0 && degree() == 0), "Illegal core-position " << _position << " chosen for TTNetwork with " << numComponents << " components");
require_correct_format();
if (!canonicalized){
for (size_t n = numComponents - 1; n > _position; --n) {
if (!canonicalized){//canonicalize to 0
for (size_t n = numComponents - 1; n > 0; --n) {
transfer_core(n, (n + 1) / 2 - 1, !_keepRank);
corePosition = 0;
}
......@@ -758,7 +758,7 @@ namespace xerus {
size_t start = path.back();
path.pop_back();
size_t end = path.back();
transfer_core(start, end, !_keepRank);
transfer_core(start, end, !_keepRank);
}
// TODO whz is this here??
......@@ -796,48 +796,50 @@ namespace xerus {
// }
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::round(const std::vector<size_t>& _maxRanks, const double _eps) {
// require_correct_format();
// const size_t numComponents = degree()/N;
// REQUIRE(_eps < 1, "_eps must be smaller than one. " << _eps << " was given.");
// REQUIRE(_maxRanks.size()+1 == numComponents || (_maxRanks.empty() && numComponents == 0), "There must be exactly degree/N-1 maxRanks. Here " << _maxRanks.size() << " instead of " << numComponents-1 << " are given.");
// REQUIRE(!misc::contains(_maxRanks, size_t(0)), "Trying to round a TTTensor to rank 0 is not possible.");
//
// const bool initialCanonicalization = canonicalized;
// const size_t initialCorePosition = corePosition;
//
// canonicalize_right();
//
// for(size_t i = 0; i+1 < numComponents; ++i) {
// round_edge(numComponents-i, numComponents-i-1, _maxRanks[numComponents-i-2], _eps, 0.0);
// }
//
// assume_core_position(0);
//
// if(initialCanonicalization) {
// move_core(initialCorePosition);
// }
// }
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::round(const size_t _maxRank) {
// round(std::vector<size_t>(num_ranks(), _maxRank), EPSILON);
// }
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::round(const int _maxRank) {
// REQUIRE( _maxRank > 0, "MaxRank must be positive");
// round(size_t(_maxRank));
// }
//
//
// template<bool isOperator>
// void TTNetwork<isOperator>::round(const value_t _eps) {
// round(std::vector<size_t>(num_ranks(), std::numeric_limits<size_t>::max()), _eps);
// }
template<bool isOperator>
void HTNetwork<isOperator>::round(const std::vector<size_t>& _maxRanks, const double _eps) {
require_correct_format();
const size_t numIntComp = static_cast<size_t>(0.5 + std::pow(2,std::ceil(std::log2(static_cast<double>(degree()/N ))))) - 1;
const size_t numOfLeaves = degree()/N;
const size_t numComponents = numIntComp + numOfLeaves; REQUIRE(_eps < 1, "_eps must be smaller than one. " << _eps << " was given.");
REQUIRE(_maxRanks.size()+1 == numComponents || (_maxRanks.empty() && numComponents == 0) ,"There must be exactly degree/N-1 maxRanks. Here " << _maxRanks.size() << " instead of " << numComponents-1 << " are given.");
REQUIRE(!misc::contains(_maxRanks, size_t(0)), "Trying to round a HTTensor to rank 0 is not possible.");
const bool initialCanonicalization = canonicalized;
const size_t initialCorePosition = corePosition;
canonicalize_root();
for (size_t n = numComponents - 1; n > 0; --n) {
round_edge(n, (n + 1) / 2 - 1, _maxRanks[n], _eps, 0.0);
}
assume_core_position(0);
if(initialCanonicalization) {
move_core(initialCorePosition);
}
}
template<bool isOperator>
void HTNetwork<isOperator>::round(const size_t _maxRank) {
round(std::vector<size_t>(num_ranks(), _maxRank), EPSILON);
}
template<bool isOperator>
void HTNetwork<isOperator>::round(const int _maxRank) {
REQUIRE( _maxRank > 0, "MaxRank must be positive");
round(size_t(_maxRank));
}
template<bool isOperator>
void HTNetwork<isOperator>::round(const value_t _eps) {
round(std::vector<size_t>(num_ranks(), std::numeric_limits<size_t>::max()), _eps);
}
//
//
// template<bool isOperator>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment