Commit 53ae48ff authored by Sebastian Wolf's avatar Sebastian Wolf

ADF now uses unified Optimization Algorithm interface

parent de4cabd9
Pipeline #1100 failed with stages
in 14 minutes and 34 seconds
This diff is collapsed.
......@@ -25,19 +25,14 @@
#pragma once
#include "optimizationAlgorithm.h"
#include "../ttNetwork.h"
#include "../performanceData.h"
#include "../measurments.h"
#include "../forwardDeclarations.h"
namespace xerus {
class ASDVariant : public OptimizationAlgorithm {
public:
double minRankEps = 1e-4;
// double maxRankEps = 1e-1;
double epsDecay = 1.1;
double controlSetFraction = 0.1;
......@@ -48,8 +43,10 @@ namespace xerus {
ASDVariant(const size_t _maxIterations, const double _targetRelativeResidual, const double _minimalResidualNormDecrease)
: OptimizationAlgorithm(0, _maxIterations, _targetRelativeResidual, _minimalResidualNormDecrease) { }
//TODO: Use OptimizationSolver for internal solver!
/**
* @brief Tries to reconstruct the (low rank) tensor _x from the given measurments.
* @brief Tries to reconstruct the (low rank) tensor @a _x from the given measurments.
* @param[in,out] _x On input: an initial guess of the solution, also defining the ranks. On output: The reconstruction found by the algorithm.
* @param _measurments the available measurments, can be either a SinglePointMeasurementSet or RankOneMeasurementSet.
* @param _perfData optinal performanceData object to be used.
......@@ -58,7 +55,7 @@ namespace xerus {
void operator()(TTTensor& _x, const RankOneMeasurementSet& _measurments, PerformanceData& _perfData) const;
/**
* @brief Tries to reconstruct the (low rank) tensor _x from the given measurments.
* @brief Tries to reconstruct the (low rank) tensor @a _x from the given measurments.
* @param[in,out] _x On input: an initial guess of the solution, may be of smaller rank. On output: The reconstruction found by the algorithm.
* @param _measurments the available measurments, can be either a SinglePointMeasurementSet or RankOneMeasurementSet.
* @param _maxRanks the maximal ranks the algorithm may use to decrease the resdiual.
......
......@@ -24,7 +24,12 @@
#pragma once
#include "../performanceData.h"
#include "../basic.h"
#include "../forwardDeclarations.h"
#include <boost/circular_buffer.hpp>
namespace xerus {
......@@ -36,7 +41,7 @@ namespace xerus {
///@brief Minimal number of iterations.
size_t minIterations;
///@brief Maximal allowed number of iterations.
///@brief Maximal allowed number of iterations. Zero for infinite.
size_t maxIterations;
///@brief The target residual norm at which the algorithm shall stop.
......@@ -48,11 +53,64 @@ namespace xerus {
///@brief Number of iterations used to check for stopping criteria (e.g. residual[iterations] <= residual[iteration-tracking]*pow(minimalResidualNormDecrease, tracking) )
size_t tracking = 10;
///@brief PerformanceData object used to record the performance of the algorithm.
PerformanceData perfData;
protected:
OptimizationAlgorithm(const size_t _minIterations, const size_t _maxIterations, const double _targetRelativeResidual, const double _minimalResidualNormDecrease);
};
} // namespace xerus
namespace internal {
class OptimizationSolver {
protected:
///@brief Minimal number of iterations.
const size_t minIterations;
///@brief Maximal allowed number of iterations. Zero for infinite.
const size_t maxIterations;
///@brief The target residual norm at which the algorithm shall stop.
const double targetRelativeResidual;
///@brief Minimal decrease of the residual norm ( newRes/oldRes ) until either the ranks are increased (if allowed) or the algorithm stops.
const double minimalResidualNormDecrease;
///@brief Number of iterations used to check for stopping criteria (e.g. residual[iterations] <= residual[iteration-tracking]*pow(minimalResidualNormDecrease, tracking) )
const size_t tracking;
///@brief Defined as pow(minimalResidualNormDecrease, tracking).
const double convergenceFactor;
private:
///@brief The current iteration.
size_t iteration = 0;
///@brief The last (tracking) residuals.
boost::circular_buffer<double> lastResiduals;
protected:
///@brief: Reference to the performanceData object (external ownership)
PerformanceData& perfData;
OptimizationSolver(const OptimizationAlgorithm& _optiAlgorithm, PerformanceData& _perfData);
///@brief Increased iteration by one and adds the residual to the circular buffer.
void make_step(const double _residual);
size_t current_iteration() const;
double current_residual() const;
///@brief True if either the maxIterations are reached or the targetRelativeResidual is reached.
bool reached_stopping_criteria() const;
///@brief True if either the minInterations are reached and convegence is reached (i.e. residual[iterations] <= residual[iteration-tracking]*pow(minimalResidualNormDecrease, tracking) ).
bool reached_convergence_criteria() const;
///@brief Resets the convergence buffer with max doubles. In particular at least tracking iterations are then nessecary the reach convergence.
void reset_convergence_buffer();
};
} // End namespace internal
} // End namespace xerus
......@@ -31,5 +31,10 @@ namespace xerus {
template<bool isOperator> class TTNetwork;
typedef TTNetwork<false> TTTensor;
typedef TTNetwork<true> TTOperator;
class SinglePointMeasurementSet;
class RankOneMeasurementSet;
class PerformanceData;
} // End xerus namespace
This diff is collapsed.
......@@ -31,6 +31,8 @@
#include <xerus/misc/internal.h>
#include <xerus/blockTT.h>
#include <xerus/measurments.h>
#include <xerus/performanceData.h>
#ifdef _OPENMP
#include <omp.h>
......
......@@ -24,13 +24,59 @@
#include <xerus/algorithms/optimizationAlgorithm.h>
#include <xerus/misc/math.h>
#include <xerus/performanceData.h>
namespace xerus {
OptimizationAlgorithm::OptimizationAlgorithm(const size_t _minIterations, const size_t _maxIterations, const double _targetRelativeResidual, const double _minimalResidualNormDecrease) :
minIterations(_minIterations),
maxIterations(_maxIterations),
targetRelativeResidual(_targetRelativeResidual),
minimalResidualNormDecrease(_minimalResidualNormDecrease),
perfData(false, false)
minimalResidualNormDecrease(_minimalResidualNormDecrease)
{}
internal::OptimizationSolver::OptimizationSolver(const OptimizationAlgorithm& _optiAlgorithm, PerformanceData& _perfData) :
minIterations(_optiAlgorithm.minIterations),
maxIterations(_optiAlgorithm.maxIterations),
targetRelativeResidual(_optiAlgorithm.targetRelativeResidual),
minimalResidualNormDecrease(_optiAlgorithm.minimalResidualNormDecrease),
tracking(_optiAlgorithm.tracking),
convergenceFactor(misc::pow(minimalResidualNormDecrease, tracking)),
lastResiduals(tracking, std::numeric_limits<double>::max()),
perfData(_perfData)
{ }
void internal::OptimizationSolver::make_step(const double _residual) {
iteration++;
lastResiduals.push_back(_residual);
}
size_t internal::OptimizationSolver::current_iteration() const {
return iteration;
}
double internal::OptimizationSolver::current_residual() const {
return lastResiduals.back();
}
bool internal::OptimizationSolver::reached_stopping_criteria() const {
return (maxIterations > 0 && iteration >= maxIterations) || (iteration >= minIterations && lastResiduals.back() <= targetRelativeResidual);
}
bool internal::OptimizationSolver::reached_convergence_criteria() const {
return iteration >= minIterations && lastResiduals.back() > convergenceFactor*lastResiduals.front();
}
void internal::OptimizationSolver::reset_convergence_buffer() {
lastResiduals = boost::circular_buffer<double>(tracking, std::numeric_limits<double>::max());
}
} // namespace xerus
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment