Commit ae6f7655 authored by Michael Goette's avatar Michael Goette

updated HT Format for unbalaced trees, tests pass

parent 0b4c8ec3
......@@ -133,8 +133,9 @@ namespace xerus {
template<class distribution=std::normal_distribution<value_t>, class generator=std::mt19937_64>
static HTNetwork XERUS_warn_unused random(std::vector<size_t> _dimensions, const std::vector<size_t> &_ranks, distribution& _dist=xerus::misc::defaultNormalDistribution, generator& _rnd=xerus::misc::randomEngine) {
const size_t numIntComp = static_cast<size_t>(0.5 + std::pow(2,std::ceil(std::log2(static_cast<double>(_dimensions.size()/N ))))) - 1;
const size_t numOfLeaves = _dimensions.size()/N;
const size_t numIntComp = numOfLeaves - 1;
const size_t numComponents = numIntComp + numOfLeaves;
XERUS_REQUIRE(_dimensions.size()%N==0, "Illegal number of dimensions/Leaves for HTOperator.");
......@@ -184,7 +185,7 @@ namespace xerus {
*/
template<class distribution=std::normal_distribution<value_t>, class generator=std::mt19937_64>
static HTNetwork XERUS_warn_unused random(const std::vector<size_t>& _dimensions, const size_t _rank, distribution& _dist=xerus::misc::defaultNormalDistribution, generator& _rnd=xerus::misc::randomEngine) {
return HTNetwork::random(_dimensions, std::vector<size_t>((static_cast<size_t>(std::pow(2,std::ceil(std::log2(static_cast<double>(_dimensions.size()/N ))))) - 2 + _dimensions.size()/N), _rank), _dist, _rnd);
return HTNetwork::random(_dimensions, std::vector<size_t>(2 * _dimensions.size()/N - 2, _rank), _dist, _rnd);
}
......
......@@ -88,8 +88,8 @@ static misc::UnitTest ht_diff("HT", "difference", [](){
static misc::UnitTest ht_real_diff("HT", "real_difference", [](){
HTTensor htA = HTTensor::random({10,10,10,10,10}, {4,4,4,4,4,4,4,4,4,4,4});
HTTensor htB = HTTensor::random({10,10,10,10,10}, {4,4,4,4,4,4,4,4,4,4,4});
HTTensor htA = HTTensor::random({10,10,10,10,10}, {4,4,4,4,4,4,4,4});
HTTensor htB = HTTensor::random({10,10,10,10,10}, {4,4,4,4,4,4,4,4});
HTTensor htC(5);
Index i;
......@@ -103,10 +103,11 @@ static misc::UnitTest ht_real_diff("HT", "real_difference", [](){
MTEST(frob_norm(htC(i&0)) < 1e-9, "3 " << frob_norm(htC(i&0)));
htC(i&0) = (htA(i&0) + htB(i&0));
TEST(htC.ranks() == std::vector<size_t>({ 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 1, 1, 1 }));
TEST(htC.ranks() == std::vector<size_t>({ 8, 8, 8, 8, 8, 8, 8, 8 }));
htC(i&0) = (htB(i&0) + htA(i&0));
TEST(htC.ranks() == std::vector<size_t>({ 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 1, 1, 1 }));
TEST(htC.ranks() == std::vector<size_t>({ 8, 8, 8, 8, 8, 8, 8, 8 }));
htC(i&0) = (htA(i&0) + htB(i&0)) - (htB(i&0) + htA(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-9, "4 " << frob_norm(htC(i&0)));
......@@ -114,7 +115,7 @@ static misc::UnitTest ht_real_diff("HT", "real_difference", [](){
htC(i&0) = (73*htA(i&0) + htB(i&0)) - (htB(i&0) + 73*htA(i&0));
MTEST(frob_norm(htC(i&0)) < 1e-7, "5 " << frob_norm(htC(i&0)));
htA = HTTensor::random({10,10,10,10,10}, {2,5,7,2,4,4,3,6,5,2,4});
htA = HTTensor::random({10,10,10,10,10}, {2,5,7,2,4,4,3,6});
htC(i&0) = htA(i&0) - htA(i&0);
MTEST(frob_norm(htC(i&0)) < 1e-9, "6 " << frob_norm(htC(i&0)));
......@@ -190,8 +191,8 @@ static misc::UnitTest ht_spec_sumdiff("HT", "special_sum_diff", [](){
MTEST(frob_norm(Tensor(htC)(i&0) - C(i&0)) < 5*1e-12, frob_norm(Tensor(htC)(i&0) - C(i&0)));
MTEST(frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)) < 3.1*1e-12, frob_norm(Tensor(htC)(i&0) - Tensor(htB)(i&0)));
Tensor X({10});
Tensor Y = Tensor::random({10});
Tensor X({10,10});
Tensor Y = Tensor::random({10,10});
Tensor Z;
HTTensor htX(X);
HTTensor htY(Y);
......@@ -320,8 +321,8 @@ static misc::UnitTest ht_trans("HT", "transpose", [](){
});
static misc::UnitTest ht_axb("HT", "ax_b", [](){
HTTensor X = HTTensor::random({10,10,10}, {2,2,2,2,2});
HTTensor B = HTTensor::random({10,10,10}, {2,2,2,2,2});
HTTensor X = HTTensor::random({10,10,10}, {2,2,2,2});
HTTensor B = HTTensor::random({10,10,10}, {2,2,2,2});
Tensor I({10,10,10,10,10,10}, [](const std::vector<size_t> &_idx) {
if (_idx[0]==_idx[3] && _idx[1] == _idx[4] && _idx[2] == _idx[5]) {
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment