Commit acae6058 authored by Sebastian Wolf's avatar Sebastian Wolf

Calculate SVD considers error by maxRank for eps error bound

parent c3831fa5
Pipeline #1094 passed with stages
in 15 minutes and 7 seconds
......@@ -1422,20 +1422,21 @@ namespace xerus {
_rhs.reset(std::move(newDim), std::move(_rhsData));
}
value_t calculate_svd(Tensor& _U, Tensor& _S, Tensor& _Vt, Tensor _input, const size_t _splitPos, const size_t _maxRank, const value_t _eps) {
value_t calculate_svd(Tensor& _U, Tensor& _S, Tensor& _Vt, Tensor _input, const size_t _splitPos, size_t _maxRank, const value_t _eps) {
REQUIRE(0 <= _eps && _eps < 1, "Epsilon must be fullfill 0 <= _eps < 1.");
size_t lhsSize, rhsSize, rank;
std::tie(lhsSize, rhsSize, rank) = calculate_factorization_sizes(_input, _splitPos);
size_t lhsSize, rhsSize, full_rank;
std::tie(lhsSize, rhsSize, full_rank) = calculate_factorization_sizes(_input, _splitPos);
_maxRank = _maxRank == 0 ? full_rank : _maxRank;
std::unique_ptr<value_t[]> tmpS(new value_t[rank]);
std::unique_ptr<value_t[]> tmpS(new value_t[full_rank]);
// sparse SVD becomes inefficient when the matrix is not sparse enough
// sparse SVD is about equally fast to dense SVD when there are about N = 1.55*(min(m,n)+(max-min)/5) entries set
// will calculate with 2 instead of 1.55 to make sure that we will certainly be slower with sparse
// (note that the algorithm is quadratic in the sparsity)
size_t min = std::min(lhsSize, rhsSize);
size_t max = std::max(lhsSize, rhsSize);
const size_t min = std::min(lhsSize, rhsSize);
const size_t max = std::max(lhsSize, rhsSize);
if (_input.is_sparse() && _input.sparsity() > 2*(min+(max-min)/5)) {
_input.use_dense_representation();
}
......@@ -1449,34 +1450,30 @@ namespace xerus {
// then calculate SVD only of remaining (hopefully small) core
Tensor UPrime, VPrime;
std::tie(lhsSize, rhsSize, rank) = calculate_factorization_sizes(_input, 1);
prepare_factorization_output(UPrime, VPrime, _input, 1, rank, Tensor::Representation::Dense);
std::tie(lhsSize, rhsSize, full_rank) = calculate_factorization_sizes(_input, 1);
prepare_factorization_output(UPrime, VPrime, _input, 1, full_rank, Tensor::Representation::Dense);
blasWrapper::svd(UPrime.override_dense_data(), tmpS.get(), VPrime.override_dense_data(), _input.get_unsanitized_dense_data(), lhsSize, rhsSize);
// contract U*UPrime and VPrime*V to obtain SVD (UU', S, V'V) from orthogonal U and V as wel as the SVD (U', S, V')
contract(_U, _U, UPrime, 1);
contract(_Vt, VPrime, _Vt, 1);
} else {
prepare_factorization_output(_U, _Vt, _input, _splitPos, rank, Tensor::Representation::Dense);
prepare_factorization_output(_U, _Vt, _input, _splitPos, full_rank, Tensor::Representation::Dense);
blasWrapper::svd(_U.override_dense_data(), tmpS.get(), _Vt.override_dense_data(), _input.get_unsanitized_dense_data(), lhsSize, rhsSize);
}
size_t full_rank = rank;
size_t rank = full_rank;
// Account for hard threshold
if(_maxRank != 0) {
rank = std::min(rank, _maxRank);
}
// Find rank due to the Epsilon (NOTE the scaling factor can be ignored, as it does not change the ratios).
// For the total error to be < _eps, the sum of discarded singular value squares must be smaller than _eps times the norm, squared
value_t maxErrorSqr = misc::sqr(blasWrapper::two_norm(tmpS.get(), rank) * _eps);
const value_t maxErrorSqr = misc::sqr(blasWrapper::two_norm(tmpS.get(), rank) * _eps);
value_t error = 0;
while (rank > 1 && error + misc::sqr(tmpS[rank-1]) <= maxErrorSqr) {
while (rank > 1 && (rank > _maxRank || error + misc::sqr(tmpS[rank-1]) <= maxErrorSqr)) {
error += misc::sqr(tmpS[rank-1]);
rank -= 1;
}
// Create tensor from diagonal values
_S.reset(Tensor::DimensionTuple(2, rank), Tensor::Representation::Sparse);
for(size_t i = 0; i < rank; ++i) {
......@@ -1491,11 +1488,7 @@ namespace xerus {
_U.resize_mode(_U.degree()-1, rank);
_Vt.resize_mode(0, rank);
if (rank < full_rank) {
return blasWrapper::two_norm(tmpS.get()+rank, full_rank-rank);
} else {
return 0;
}
return std::sqrt(error);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment