Commit 984c3633 authored by Philipp  Trunschke's avatar Philipp Trunschke
Browse files

add TODO and minor formatting

parent 3d956056
Pipeline #2148 failed with stages
in 4 minutes and 10 seconds
......@@ -1170,6 +1170,7 @@ namespace xerus { namespace uq {
}
void SALSA::run() {
//TODO: let run() return the optimal state and allow initialization with a given state
LOG(debug, "Entering run()");
REQUIRE(omegaFactor > 0.0, "omegaFactor must be positive");
REQUIRE(alphaFactor >= 0.0, "alphaFactor must be non-negative");
......@@ -1181,23 +1182,19 @@ namespace xerus { namespace uq {
initialize();
const auto perform_sweep = [&](const bool _adapt) {
// std::cerr << "perform_sweep(start)" << std::endl;
REQUIRE(x.corePosition == 0, "IE");
LOG(debug, "Sweep: left --> right");
for (size_t m=0; m<M-1; ++m) {
// std::cerr << "perform_sweep(" << m << ")" << std::endl;
REQUIRE(x.corePosition == m, "IE");
solve_local();
move_core_right(_adapt);
}
LOG(debug, "Sweep: right --> left");
for (size_t m=M-1; m>0; --m) {
// std::cerr << "perform_sweep(" << m << ")" << std::endl;
REQUIRE(x.corePosition == m, "IE");
solve_local();
move_core_left(_adapt);
}
// std::cerr << "perform_sweep(end)" << std::endl;
REQUIRE(x.corePosition == 0, "IE");
};
......@@ -1260,19 +1257,18 @@ namespace xerus { namespace uq {
const auto reset = attr(0);
const auto bold = attr(1);
if (improvement) { std::cout << bold; }
std::cout << "[" << iteration << "] Costs:"
<< " LS=" << update_str(bestTrainingResidual , trainingResiduals.back())
<< u8", R\u03B1=" << update_str(bestAlphaCosts, alphaCosts)
<< u8", R\u03C9=" << update_str(bestOmegaCosts, omegaCosts)
std::cout << "[" << iteration << "] Costs:"
<< " LS=" << update_str(bestTrainingResidual , trainingResiduals.back())
<< u8", R\u03B1=" << update_str(bestAlphaCosts, alphaCosts)
<< u8", R\u03C9=" << update_str(bestOmegaCosts, omegaCosts)
<< " | Validation: " << update_str(bestState.validationResidual , validationResiduals.back())
<< " | \u03C9: " << string_format("%.2e", omega)
<< " | Densities: " << print_densities()
<< " | Ranks: " << print_fractional_ranks() << reset << std::endl; // Flush to ensure that the user does not have to wait for other sweeps to complete
<< " | \u03C9: " << string_format("%.2e", omega)
<< " | Densities: " << print_densities()
<< " | Ranks: " << print_fractional_ranks() << reset << std::endl; // Flush to ensure that the user does not have to wait for other sweeps to complete
};
print_update(true);
for (iteration=1; iteration<maxSweeps; ++iteration) {
// std::cerr << iteration << " (0)" << std::endl;
perform_sweep(true);
trainingResiduals.push_back(residual(trainingSet));
......@@ -1280,7 +1276,6 @@ namespace xerus { namespace uq {
omegaCosts = omega_residual();
totalCosts = trainingResiduals.back() + alphaCosts + omegaCosts;
validationResiduals.push_back(residual(validationSet));
// std::cerr << iteration << " (2)" << std::endl;
// check if the validation residual decreased in a meaningful way during the last sweep
const bool sweepImprovement = validationResiduals.back() < (1-minDecrease)*bestState.validationResidual;
......@@ -1294,10 +1289,6 @@ namespace xerus { namespace uq {
bestIteration = iteration;
prev_bestValidationResidual = bestState.validationResidual;
bestState = State{alpha, omega, x, trainingResiduals.back(), validationResiduals.back()};
// bestX = x;
// prev_bestValidationResidual = bestValidationResidual;
// bestValidationResidual = validationResiduals.back();
// bestTrainingResidual = trainingResiduals.back();
}
if (validationResiduals.back() < targetResidual) {
......@@ -1313,8 +1304,6 @@ namespace xerus { namespace uq {
smin = 0.2*std::min(omega, res);
omega *= omegaFactor;
// std::cerr << iteration << " (5)" << std::endl;
// check for stagnation in the training residual (trainingResiduals is a circular buffer)
const bool stagnation = trainingResiduals.size() == trackingPeriodLength && trainingResiduals.back() > (1-minDecrease)*trainingResiduals.front() && omegaMinimal;
if (stagnation) {
......@@ -1333,7 +1322,7 @@ namespace xerus { namespace uq {
break;
}
// adapt alpha
// adapt alpha
if (alphaFactor > 0.0) {
const auto prev_alpha = alpha;
// Select omega and x from the most recent optimal state.
......@@ -1357,8 +1346,8 @@ namespace xerus { namespace uq {
//TODO: use disp_shortest_unequal(self.alpha, alpha)
std::cout << "Reduce \u03B1: " << string_format("%.3f", prev_alpha)
<< std::string(u8" \u2192 ")
<< string_format("%.3f", alpha) << std::endl;
<< std::string(u8" \u2192 ")
<< string_format("%.3f", alpha) << std::endl;
} else { REQUIRE(misc::hard_equal(alpha, 0.0), "IE"); }
// clear buffers to ensure that they contain only values for the current choice of alpha
......@@ -1372,8 +1361,8 @@ namespace xerus { namespace uq {
}
std::cout << "Best validation residual in iteration " << bestIteration << ".\n"
<< std::string(125, '-') << '\n'
<< "Truncating inactive singular values." << std::endl;
<< std::string(125, '-') << '\n'
<< "Truncating inactive singular values." << std::endl;
size_t rank;
for (size_t m=0; m<M-1; ++m) {
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment