Commit 49c753f2 authored by Sebastian Wolf

Fix for soft thresholding in TensorNetworks and reworked find_largest_entry

parent acae6058
Pipeline #1095 failed in 13 minutes and 32 seconds
......@@ -38,6 +38,6 @@ namespace xerus {
* @return the position of the entry found.
*/
template<bool isOperator>
size_t find_largest_entry(const TTNetwork<isOperator> &_T, double _accuracy, value_t _lowerBound = 0.0);
size_t find_largest_entry(const TTNetwork<isOperator> &_T, const double _accuracy, const value_t _lowerBound = 0.0);
}
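As an aside, a minimal usage sketch of the declaration above, assuming the umbrella header xerus.h exposes it; the random TT construction and the parameter values are purely illustrative and not part of this commit (depending on the xerus version, TTTensor::random may additionally expect a random engine and distribution):

    #include <iostream>
    #include <xerus.h>

    int main() {
        // Illustrative order-4 TT tensor with mode size 4 and rank 3.
        xerus::TTTensor T = xerus::TTTensor::random({4, 4, 4, 4}, {3, 3, 3});
        // Ask for an entry within a factor 0.9 of the true maximum in magnitude;
        // 0.0 means no a-priori lower bound on that maximum is known.
        const size_t pos = xerus::find_largest_entry(T, 0.9, 0.0);
        std::cout << "entry " << T[pos] << " at flat index " << pos << std::endl;
        return 0;
    }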
......@@ -32,9 +32,7 @@ XERUS_SET_LOGGING(TensorAssignment, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(ALS, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(unit_test, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(unit_tests, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(largestEntry, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(pydebug, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(ADF, xerus::misc::internal::LOGGING_ON_ERROR)
// XERUS_SET_LOGGING(ADFx, xerus::misc::internal::LOGGING_ON_ERROR)
XERUS_SET_LOGGING(UQ, xerus::misc::internal::LOGGING_ON_ERROR)
/* */
......@@ -352,9 +352,9 @@ namespace xerus {
/**
* @brief Reduce all ranks up to a given accuracy and maximal number.
* @param _maxRanks maximal allowed ranks. All current ranks that are larger than the given ones are reduced by truncation.
* @param _eps the accuracy to use for truncation in the individual SVDs.
* @param _eps the maximal relative L2 error of the rounding, unless the @a _maxRanks determine a higher error.
*/
void round(const std::vector<size_t>& _maxRanks, const double _eps = EPSILON);
void round(const std::vector<size_t>& _maxRanks, const double _eps = 0.0);
/**
......@@ -373,7 +373,7 @@ namespace xerus {
/**
* @brief Reduce all ranks up to a given accuracy.
* @param _eps the accuracy to use for truncation in the individual SVDs.
* @param _eps the maximal relative L2 error.
*/
void round(const value_t _eps);
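For reference, a short sketch of how the two rounding overloads documented above combine rank caps and the relative L2 error bound; the tensor and the numbers are illustrative only, with the same caveats on TTTensor::random as in the sketch further up:

    #include <xerus.h>

    int main() {
        xerus::TTTensor A = xerus::TTTensor::random({8, 8, 8, 8}, {10, 10, 10});
        // Cap every rank at 5 and additionally allow a relative L2 error of 1e-8.
        A.round(std::vector<size_t>{5, 5, 5}, 1e-8);
        // Truncate purely by accuracy: keep the relative L2 error below 1e-12.
        A.round(1e-12);
        return 0;
    }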
......@@ -382,14 +382,14 @@ namespace xerus {
* @brief Applies the soft thresholding operation to all ranks.
* @param _tau the soft thresholding parameter to be applied, i.e. all singular values are reduced to max(0, Lambda_ui - _tau).
*/
void soft_threshold(const double _tau, const bool _preventZero = false);
void soft_threshold(const double _tau);
/**
* @brief Applies soft thresholding operations to all ranks.
* @param _taus the soft thresholding parameters to be applied, i.e. all singular values of the j-th matrification are reduced to max(0, Lambda_ui - _tau[j]).
*/
void soft_threshold(const std::vector<double>& _taus, const bool _preventZero = false);
void soft_threshold(const std::vector<double>& _taus);
/**
......
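The entrywise effect described in the documentation above can be seen on a plain list of singular values of one matrification; this is only a scalar illustration with made-up numbers, not the TTNetwork API:

    #include <algorithm>
    #include <vector>

    int main() {
        std::vector<double> sigma = {3.0, 1.2, 0.4, 0.1};        // made-up singular values
        const double tau = 0.5;
        for(double& s : sigma) { s = std::max(0.0, s - tau); }   // -> {2.5, 0.7, 0.0, 0.0}
        // Values that hit zero no longer contribute, so the effective rank drops to 2.
        return 0;
    }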
......@@ -28,12 +28,13 @@ using namespace xerus;
static misc::UnitTest alg_largestEntry("Algorithm", "LargestEntry", [](){
//Random numbers
std::mt19937_64 rnd = xerus::misc::randomEngine;
std::uniform_int_distribution<size_t> dimDist(1,3);
std::uniform_int_distribution<size_t> rankDist(1,4);
std::uniform_real_distribution<double> realDist(0.75, 1.0);
std::uniform_int_distribution<size_t> dimDist(1, 4);
std::uniform_int_distribution<size_t> rankDist(1, 5);
const size_t D = 16;
const size_t D = 15;
for(size_t k = 0; k < 2; ++k) {
for(size_t k = 0; k < 3; ++k) {
std::vector<size_t> stateDims;
stateDims.push_back(dimDist(rnd));
......@@ -49,30 +50,42 @@ static misc::UnitTest alg_largestEntry("Algorithm", "LargestEntry", [](){
Tensor fullX(X);
size_t posA = 0, posB = 0;
size_t posA = 0, posB = 0, posC = 0;
for(size_t i = 1; i < fullX.size; ++i) {
if(std::abs(fullX[i]) >= std::abs(fullX[posA])) {
if (std::abs(fullX[i]) >= std::abs(fullX[posA])) {
posC = posB;
posB = posA;
posA = i;
}
else if(std::abs(fullX[i]) >= std::abs(fullX[posB])) {
} else if (std::abs(fullX[i]) >= std::abs(fullX[posB])) {
posC = posB;
posB = i;
} else if (std::abs(fullX[i]) >= std::abs(fullX[posC])) {
posC = i;
}
}
LOG(largestEntry, "Largest entries are: " << fullX[posA] << " and " << fullX[posB] << " at " << posA << " and " << posB);
// LOG(largestEntry, "Largest entries are: " << fullX[posA] << " , " << fullX[posB] << " and " << fullX[posC] << " at " << posA << " , " << posB << " and " << posC);
double alpha = std::abs(fullX[posB]/fullX[posA]);
double Xn = std::abs(fullX[posA]);
double Xn = realDist(rnd)*std::abs(fullX[posA]);
size_t position = find_largest_entry(X, alpha, Xn);
LOG(largestEntry, "Result: " << fullX[position] << " vs " << fullX[posA] << " at positions " << position << " and " << posA);
// LOG(largestEntry, "Result: " << fullX[position] << " vs " << fullX[posA] << " at positions " << position << " and " << posA);
TEST(position == posA);
if(position != posA) {
LOG(omg, fullX.to_string());
}
alpha = std::abs(fullX[posC]/fullX[posA]);
Xn = realDist(rnd)*std::abs(fullX[posA]);
position = find_largest_entry(X, alpha, Xn);
// LOG(largestEntry, "Result: " << fullX[position] << " vs " << fullX[posA] << " at positions " << position << " and " << posA);
TEST(position == posA || position == posB);
}
}
});
// UNIT_TEST(Algorithm, rankRange,
......
#include <xerus/algorithms/largestEntry.h>
#include <xerus/misc/internal.h>
namespace xerus {
template<bool isOperator>
size_t find_largest_entry(const TTNetwork<isOperator> &_T, const double _accuracy, const value_t _lowerBound) {
_T.require_correct_format();
size_t find_largest_entry_rank_one(const TTNetwork<isOperator> &_T) {
INTERNAL_CHECK(misc::sum(_T.ranks())+1 == _T.degree(), "Ie");
// There is actual work to be done
if(misc::sum(_T.ranks()) >= _T.degree()) {
const double alpha = _accuracy;
TTNetwork<isOperator> X = _T;
X.round(size_t(1));
double Xn = std::max(_T[find_largest_entry(X, 0.0, 0.0)], _lowerBound);
double tau = (1-alpha)*alpha*Xn*Xn/(2.0*double(_T.degree()-1));
X = _T;
while(misc::sum(X.ranks()) >= _T.degree()) {
X = entrywise_product(X, X);
X.soft_threshold(tau, true);
TTNetwork<isOperator> Y = X;
Y.round(1);
const size_t yMaxPos = find_largest_entry(Y, 0.0, 0.0);
Xn = std::max(X[yMaxPos], (1-(1-alpha)*alpha/2.0)*Xn*Xn);
const double fNorm = X.frob_norm();
Xn /= fNorm;
X /= fNorm;
tau = (1-alpha)*alpha*Xn*Xn/(2.0*double(_T.degree()-1));
}
return find_largest_entry(X, 0.0, 0.0);
}
// We are already rank one
const size_t numComponents = _T.degree()/(isOperator?2:1);
size_t position = 0;
size_t factor = misc::product(_T.dimensions);
......@@ -53,6 +25,39 @@ namespace xerus {
return position;
}
template<bool isOperator>
size_t find_largest_entry(const TTNetwork<isOperator> &_T, const double _accuracy, const value_t _lowerBound) {
_T.require_correct_format();
const value_t numSVDs = double(_T.degree()-1);
const value_t gamma = (1-_accuracy)*_accuracy/2.0;
TTNetwork<isOperator> X = _T;
TTNetwork<isOperator> Y = X;
Y.round(1);
double Xn = std::max(_T[find_largest_entry_rank_one(Y)], _lowerBound);
double tau = gamma*Xn*Xn/numSVDs;
while(misc::sum(X.ranks()) >= _T.degree()) {
X = entrywise_product(X, X);
X.soft_threshold(tau);
Y = X; Y.round(1);
Xn = std::max(X[find_largest_entry_rank_one(Y)], (1-gamma)*Xn*Xn);
const double fNorm = X.frob_norm();
Xn /= fNorm;
X /= fNorm;
tau = gamma*Xn*Xn/numSVDs;
}
return find_largest_entry_rank_one(X);
}
template size_t find_largest_entry(const TTNetwork<true> &, double, value_t);
template size_t find_largest_entry(const TTNetwork<false> &, double, value_t);
} // namespace xerus
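A toy scalar analogue of the loop above may help to see why repeated entrywise squaring plus renormalisation isolates the maximal entry; it deliberately ignores ranks and the soft-thresholding step, which only exist to keep the TT representation small:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<double> x = {0.9, -1.0, 0.4};               // made-up tensor entries
        for(int it = 0; it < 5; ++it) {
            double norm = 0.0;
            for(double& e : x) { e *= e; norm += e * e; }       // X = entrywise_product(X, X)
            norm = std::sqrt(norm);
            for(double& e : x) { e /= norm; }                   // X /= fNorm
        }
        std::printf("%g %g %g\n", x[0], x[1], x[2]);            // mass concentrates at the old maximum
        return 0;
    }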
......@@ -764,7 +764,22 @@ namespace xerus {
// ... calculate svd ...
calculate_svd(coreA, S, coreB, X, 1, _maxRank, _eps);
S.modify_diagonal_entries([&](value_t& _d){ _d = std::max(0.0, _d - _softThreshold); });
if(_softThreshold > 0.0) {
// Reduce SVs and find new rank
size_t newRank = 0;
do {
S[{newRank, newRank}] = std::max(0.0, S[{newRank, newRank}]- _softThreshold);
newRank++;
} while (newRank < S.dimensions[0] && S[{newRank, newRank}] > _softThreshold);
// Resize Tensors accordingly
S.resize_mode(0, newRank);
S.resize_mode(1, newRank);
coreA.resize_mode(1, newRank);
coreB.resize_mode(0, newRank);
XERUS_REQUIRE_TEST;
}
// ... contract S to the right ...
xerus::contract(coreB, S, false, coreB, false, 1);
......@@ -785,8 +800,23 @@ namespace xerus {
xerus::contract(X, fromTensor, transFrom, toTensor, transTo, 1);
calculate_svd(fromTensor, S, toTensor, X, fromDegree-1, _maxRank, _eps);
S.modify_diagonal_entries([&](value_t& _d){ _d = std::max(0.0, _d - _softThreshold); });
if(_softThreshold > 0.0) {
// Reduce SVs and find new rank
size_t newRank = 0;
do {
S[{newRank, newRank}] = std::max(0.0, S[{newRank, newRank}]- _softThreshold);
newRank++;
} while (newRank < S.dimensions[0] && S[{newRank, newRank}] > _softThreshold);
// Resize Tensors accordingly
S.resize_mode(0, newRank);
S.resize_mode(1, newRank);
fromTensor.resize_mode(fromDegree-1, newRank);
toTensor.resize_mode(0, newRank);
XERUS_REQUIRE_TEST;
}
if(transTo) {
xerus::contract(toTensor, toTensor, true, S, true, 1);
......
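The rank-reduction logic added in both branches above can be restated on a plain array of singular values; this is an illustration of the loop with made-up values, not the Tensor interface:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Shrink each singular value by tau and return how many remain relevant.
    // Like the do/while loops above, at least one value is always kept.
    std::size_t soft_threshold_rank(std::vector<double>& s, const double tau) {
        std::size_t newRank = 0;
        do {
            s[newRank] = std::max(0.0, s[newRank] - tau);
            ++newRank;
        } while(newRank < s.size() && s[newRank] > tau);
        s.resize(newRank);   // stands in for the resize_mode calls on S, coreA/coreB
        return newRank;
    }

    int main() {
        std::vector<double> s = {3.0, 1.2, 0.4, 0.1};
        return int(soft_threshold_rank(s, 0.5));   // 2
    }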
......@@ -665,11 +665,12 @@ namespace xerus {
const size_t initialCorePosition = corePosition;
canonicalize_right();
auto epsPerSite = _eps / std::sqrt(double(numComponents)-1);
if(numComponents > 1) {
auto epsPerSite = misc::hard_equal(_eps, 0.0) ? EPSILON : _eps / std::sqrt(double(numComponents-1));
for(size_t i = 0; i+1 < numComponents; ++i) {
round_edge(numComponents-i, numComponents-i-1, _maxRanks[numComponents-i-2], epsPerSite, 0.0);
for(size_t i = 0; i+1 < numComponents; ++i) {
round_edge(numComponents-i, numComponents-i-1, _maxRanks[numComponents-i-2], epsPerSite, 0.0);
}
}
assume_core_position(0);
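A brief note on the factor used above, as reasoning only: each of the numComponents-1 edge roundings is given a tolerance of epsPerSite, and since the local relative L2 errors combine roughly in quadrature, the total error stays near sqrt(numComponents-1) * epsPerSite = _eps. The new numComponents > 1 guard also avoids the division by zero the old unconditional formula produced for a single component, and the hard_equal(_eps, 0.0) branch maps the new default of 0.0 to the machine-precision EPSILON instead of demanding an exact truncation.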
......@@ -700,7 +701,7 @@ namespace xerus {
template<bool isOperator>
void TTNetwork<isOperator>::soft_threshold(const std::vector<double> &_taus, const bool /*_preventZero*/) {
void TTNetwork<isOperator>::soft_threshold(const std::vector<double> &_taus) {
const size_t numComponents = degree()/N;
REQUIRE(_taus.size()+1 == numComponents || (_taus.empty() && numComponents == 0), "There must be exactly degree/N-1 taus. Here " << _taus.size() << " instead of " << numComponents-1 << " are given.");
require_correct_format();
......@@ -723,8 +724,8 @@ namespace xerus {
template<bool isOperator>
void TTNetwork<isOperator>::soft_threshold(const double _tau, const bool _preventZero) {
soft_threshold(std::vector<double>(num_ranks(), _tau), _preventZero);
void TTNetwork<isOperator>::soft_threshold(const double _tau) {
soft_threshold(std::vector<double>(num_ranks(), _tau));
}
......
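Finally, a usage sketch of the simplified soft_threshold signatures as they now stand; tensor construction and tau values are illustrative, with the same caveats on TTTensor::random as above:

    #include <xerus.h>

    int main() {
        xerus::TTTensor B = xerus::TTTensor::random({6, 6, 6}, {4, 4});
        B.soft_threshold(0.1);                                // same tau for every rank
        B.soft_threshold(std::vector<double>{0.05, 0.2});     // one tau per edge (degree/N - 1 values)
        return 0;
    }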