Commit fa1beb95 authored by Sebastian Wolf's avatar Sebastian Wolf

Limit RandomSVD ranks

parent b0dfeaa2
Pipeline #881 passed with stages
in 8 minutes and 40 seconds
......@@ -29,22 +29,22 @@
namespace xerus {
TTTensor randomTTSVD(const Tensor& _x, const std::vector<size_t>& _ranks) {
REQUIRE(_x.degree() == _ranks.size()+1, "Inconsitend degree vs number of provided ranks.");
std::normal_distribution<double> dist(0, 1);
const size_t d = _x.degree();
REQUIRE(d == _ranks.size()+1, "Inconsitend degree vs number of provided ranks.");
TTTensor u(d);
Tensor a;
Tensor b = _x;
for(size_t j = d-1; j > 0; --j) {
const size_t s = _ranks[j-1];
const size_t contractSize = misc::product(b.dimensions, 0, j);
const size_t staySize = misc::product(b.dimensions, j, b.dimensions.size());
const size_t s = std::min(_ranks[j-1], std::min(contractSize, staySize)+1);
if(b.is_sparse()) {
const size_t staySize = misc::product(b.dimensions, j, b.dimensions.size());
std::map<size_t, std::vector<value_t>> usedG;
std::vector<size_t> outDims({s});
......@@ -60,7 +60,7 @@ TTTensor randomTTSVD(const Tensor& _x, const std::vector<size_t>& _ranks) {
if(gEntry.empty()) {
gEntry.reserve(s);
for(size_t k = 0; k < s; ++k) {
gEntry.push_back(dist(xerus::misc::randomEngine));
gEntry.push_back(misc::defaultNormalDistribution(misc::randomEngine));
}
}
......@@ -72,13 +72,13 @@ TTTensor randomTTSVD(const Tensor& _x, const std::vector<size_t>& _ranks) {
} else {
std::vector<size_t> gDims({s});
gDims.insert(gDims.end(), b.dimensions.cbegin(), b.dimensions.cbegin()+j);
const Tensor g = Tensor::random(gDims, dist, xerus::misc::randomEngine);
const Tensor g = Tensor::random(gDims, misc::defaultNormalDistribution, misc::randomEngine);
contract(a, g, false, b, false, j);
}
Tensor R, Q;
calculate_rq(R, Q, a, 1);
calculate_cq(R, Q, a, 1);
if(j == d-1) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment