Commit 6eafbd96 authored by Philipp  Trunschke's avatar Philipp Trunschke

bugfixes in weighted uq_ra_adf

parent 8da40a27
Pipeline #1205 passed with stages
in 21 minutes and 23 seconds
......@@ -140,8 +140,8 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
setNorms[k] = 0.0;
for(const auto j : sets[k]) {
const double sqrNorm = misc::sqr(frob_norm(solutions[j]));
optNorm += sqrNorm;
setNorms[k] += sqrNorm;
optNorm += weights[j] * sqrNorm; // HIER
setNorms[k] += weights[j] * sqrNorm; // HIER
}
setNorms[k] = std::sqrt(setNorms[k]);
}
......@@ -150,7 +150,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
testNorm = 0.0;
for(const auto j : controlSet) {
const double sqrNorm = misc::sqr(frob_norm(solutions[j]));
testNorm += sqrNorm;
testNorm += weights[j] * sqrNorm; // HIER
}
testNorm = std::sqrt(testNorm);
}
......@@ -317,7 +317,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
// Combine with ought part
contract(dyadComp, isPart - leftOughtStack[_corePosition-1][j], dyadicPart, 0);
delta += weights[j] * dyadComp;
delta += weights[j] * dyadComp; // HIER
}
} else { // _corePosition == 0
Tensor shuffledX = x.get_core(_setId);
......@@ -332,7 +332,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
contract(dyadComp, dyadComp - solutions[j], rightStack[_corePosition+1][j], 0);
dyadComp.reinterpret_dimensions({1, dyadComp.dimensions[0], dyadComp.dimensions[1]});
delta += weights[j] * dyadComp;
delta += weights[j] * dyadComp; // HIER
}
}
......@@ -350,7 +350,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
const size_t j = sets[_setId][jIdx];
contract(tmp, _delta, rightStack[1][j], 1);
const double normPart = misc::sqr(frob_norm(tmp));
norm += weights[j] * normPart;
norm += weights[j] * normPart; // HIER
}
} else { // _corePosition > 0
Tensor shuffledDelta = reshuffle(_delta, {1, 0, 2});
......@@ -380,7 +380,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
}
REQUIRE(tmp.size == 1, "IE");
norm += weights[j] * tmp[0];
norm += weights[j] * tmp[0]; // HIER
}
}
......@@ -406,8 +406,8 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
tmp -= solutions[j];
const double resSqr = misc::sqr(frob_norm(tmp));
optResidual += resSqr;
setResiduals[k] += resSqr;
optResidual += weights[j] * resSqr; // HIER
setResiduals[k] += weights[j] * resSqr; // HIER
}
setResiduals[k] = std::sqrt(setResiduals[k])/setNorms[k];
}
......@@ -420,7 +420,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
tmp -= solutions[j];
const double resSqr = misc::sqr(frob_norm(tmp));
testResidual += resSqr;
testResidual += weights[j] * resSqr; // HIER
}
testResidual = std::sqrt(testResidual)/testNorm;
......@@ -430,14 +430,19 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
void update_core(const size_t _corePosition) {
const Index left, right, ext, p;
REQUIRE(x.corePosition == _corePosition, "IE");
for(size_t setId = 0; setId < P; ++setId) {
const auto delta = calculate_delta(_corePosition, setId);
const auto normAProjGrad = calculate_norm_A_projGrad(delta, _corePosition, setId);
const value_t PyR = misc::sqr(frob_norm(delta));
// Actual update
x.component(_corePosition)(left, ext, p, right) = x.component(_corePosition)(left, ext, p, right)-((PyR/misc::sqr(normAProjGrad))*delta)(left, ext, right)*Tensor::dirac({P}, setId)(p);
if (PyR > 0) {
// Actual update
x.component(_corePosition)(left, ext, p, right) = x.component(_corePosition)(left, ext, p, right)-((PyR/misc::sqr(normAProjGrad))*delta)(left, ext, right)*Tensor::dirac({P}, setId)(p);
} else {
LOG(uqADF, "Warning: vanishing gradient on set " << setId);
}
}
}
......@@ -451,7 +456,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
}
}
LOG(ADF, "Residual decrease from " << std::scientific << 0.0 /* TODO */ << " to " << std::scientific << residuals.back() << " in " << _iteration << " iterations.");
LOG(uqADF, "Residual decrease from " << std::scientific << 0.0 /* TODO */ << " to " << std::scientific << residuals.back() << " in " << _iteration << " iterations.");
}
......@@ -480,7 +485,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
}
LOG(ADFx, "Residual " << std::scientific << residuals.back() << " " << setResiduals << ". NonImpCnt: " << nonImprovementCounter << ", Controlset: " << testResidual << ". Ranks: " << x.ranks() << ". DOFs: " << x.dofs());
LOG(uqADF, "Residual " << std::scientific << residuals.back() << " " << setResiduals << ". NonImpCnt: " << nonImprovementCounter << ", Controlset: " << testResidual << ". Ranks: " << x.ranks() << ". DOFs: " << x.dofs());
if(residuals.back() < targetResidual || nonImprovementCounter >= 100) {
finish(iteration);
......@@ -501,7 +506,7 @@ namespace xerus { namespace uq { namespace impl_uqRaAdf {
}
if(!rankMaxed) {
LOG(ADFx, "Reduce rankEps to " << std::max(minRankEps, epsDecay*rankEps));
LOG(uqADF, "Reduce rankEps to " << std::max(minRankEps, epsDecay*rankEps));
rankEps = std::max(minRankEps, epsDecay*rankEps);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment