Commit 8f69feb1 authored by Sebastian Wolf's avatar Sebastian Wolf

ADF hotfix

parent 82a836d8
......@@ -129,6 +129,8 @@ namespace xerus {
/// @brief Vector containing for each corePosition a vector of the smallest ids of each group of unique backwardStack entries.
std::vector<std::vector<size_t>> backwardUpdates;
/// @brief Norm of each rank-one measurement operator.
std::unique_ptr<double[]> measurmentNorms;
///@brief: Reference to the performanceData object (external ownership)
PerformanceData& perfData;
......@@ -216,6 +218,8 @@ namespace xerus {
backwardStack(backwardStackMem.get()+numMeasurments),
backwardUpdates(degree),
measurmentNorms(new double[numMeasurments]),
perfData(_perfData)
{
_x.require_correct_format();
......
......@@ -427,7 +427,7 @@ namespace xerus {
#pragma omp for schedule(static)
for(size_t i = 0; i < numMeasurments; ++i) {
contract(currentValue, *forwardStack[i + _corePosition*numMeasurments], false, *backwardStack[i + (_corePosition+1)*numMeasurments], false, 1);
partialNormAProjGrad[position_or_zero(measurments, i, _corePosition)] += misc::sqr(currentValue[0]);
partialNormAProjGrad[position_or_zero(measurments, i, _corePosition)] += misc::sqr(currentValue[0]*measurmentNorms[i]); // TODO measurmentNorms
}
// Accumulate the partial components
......@@ -448,7 +448,7 @@ namespace xerus {
#pragma omp for schedule(static)
for(size_t i = 0; i < numMeasurments; ++i) {
contract(currentValue, *forwardStack[i + (_corePosition-1)*numMeasurments], false, *backwardStack[i + _corePosition*numMeasurments], false, 1);
partialNormAProjGrad[position_or_zero(measurments, i, _corePosition)] += misc::sqr(currentValue[0]);
partialNormAProjGrad[position_or_zero(measurments, i, _corePosition)] += misc::sqr(currentValue[0]*measurmentNorms[i]); // TODO measurmentNorms
}
// Accumulate the partial components
......@@ -542,6 +542,26 @@ namespace xerus {
}
/// @brief Computes the norm of each rank-one measurement operator in @a _measurments
///        and writes it to @a _norms (one entry per measurement).
/// @param _norms output array; must hold at least _measurments.size() doubles.
/// @param _measurments the measurement set whose operator norms are requested.
template<class MeasurmentSet>
inline void calc_measurment_norm(double* _norms, const MeasurmentSet& _measurments);

/// @brief Specialization for single point measurements: each operator is a unit
///        vector, so every norm is exactly one.
template<>
inline void calc_measurment_norm<SinglePointMeasurementSet>(double* _norms, const SinglePointMeasurementSet& _measurments) {
	// Idiomatic stdlib fill instead of a hand-rolled loop.
	std::fill_n(_norms, _measurments.size(), 1.0);
}
/// @brief Specialization for rank-one measurements: the norm of measurement i is
///        the product of the Frobenius norms of its component vectors.
template<>
inline void calc_measurment_norm<RankOneMeasurementSet>(double* _norms, const RankOneMeasurementSet& _measurments) {
	const size_t numMeasurments = _measurments.size();
	const size_t measurmentDegree = _measurments.degree();
	for(size_t i = 0; i < numMeasurments; ++i) {
		// Accumulate locally, then store, instead of repeatedly writing to _norms[i].
		double norm = 1.0;
		for(size_t j = 0; j < measurmentDegree; ++j) {
			norm *= _measurments.positions[i][j].frob_norm();
		}
		_norms[i] = norm;
	}
}
template<class MeasurmentSet>
double ADFVariant::InternalSolver<MeasurmentSet>::solve() {
......@@ -555,6 +575,9 @@ namespace xerus {
#pragma omp section
construct_stacks(backwardStackSaveSlots, backwardUpdates, backwardStackMem, false);
}
calc_measurment_norm(measurmentNorms.get(), measurments);
// We need x to be canonicalized in the sense that there is no edge with more than maximal rank (prior to stack resize).
x.cannonicalize_left();
......@@ -565,24 +588,11 @@ namespace xerus {
// If we follow a rank increasing strategy, increase the ranks until we reach the targetResidual, the maxRanks or the maxIterations.
while(residualNorm > targetResidualNorm && x.ranks() != maxRanks && (maxIterations == 0 || iteration < maxIterations)) {
// LOG(xRanKResBefore, measurments.test(x));
// Increase the ranks
x.move_core(0, true);
const auto rndTensor = TTTensor::random(x.dimensions, std::vector<size_t>(x.degree()-1, 1));
const auto oldX = x;
auto diff = (1.0/frob_norm(rndTensor))*rndTensor;
// LOG(bla, frob_norm(diff) << " x " << frob_norm(x) << " b " << normMeasuredValues);
for(size_t i = 0; i < diff.degree(); ++i) {
diff.component(i).apply_factor();
}
// LOG(diff1, measurments.test(diff));
diff *= normMeasuredValues*1e-5;
// LOG(diff2, measurments.test(diff));
const auto diff = (1e-6*frob_norm(x))*rndTensor/frob_norm(rndTensor);
x = x+diff;
// LOG(realDifference, frob_norm(x -oldX) << " x " << frob_norm(x));
// LOG(xRanKResAfter, measurments.test(x));
x.round(maxRanks);
......
......@@ -270,7 +270,7 @@ namespace xerus {
}
res() = 0.5*xAx(r1&0) * localOperatorCache.right.back()(r1&0)
- bx(r1&0) * rhsCache.right.back()(r1&0);
return res[0];
return res.frob_norm();
};
}
} else {
......@@ -541,7 +541,7 @@ namespace xerus {
}
if(check_for_end_of_sweep(data, _numHalfSweeps, _convergenceEpsilon, _perfData)) {
return data.energy;
return data.energy; // TODO residual?
}
data.move_to_next_index();
......
......@@ -470,34 +470,37 @@ namespace xerus {
double RankOneMeasurementSet::test(const TensorNetwork& _solution) const {
REQUIRE(_solution.degree() == degree(), "Degrees of solution and measurements must match!");
const auto cSize = size();
double error = 0.0, norm = 0.0;
std::vector<TensorNetwork> stack(degree()+1);
stack[0] = _solution;
stack[0].reduce_representation();
Index l, k;
for(size_t j = 0; j < cSize; ++j) {
size_t rebuildIndex = 0;
if(j > 0) {
// Find the maximal recyclable stack position
for(; rebuildIndex < degree(); ++rebuildIndex) {
if(!approx_equal(positions[j-1][rebuildIndex], positions[j][rebuildIndex])) {
break;
}
}
}
// Rebuild stack
for(size_t i = rebuildIndex; i < degree(); ++i) {
stack[i+1](k&0) = positions[j][i](l) * stack[i](l, k&1);
stack[i+1].reduce_representation();
}
error += misc::sqr(measuredValues[j] - stack.back()[0]);
norm += misc::sqr(measuredValues[j]);
const Index l, k;
double error = 0.0, norm = 0.0;
#pragma omp parallel reduction(+:error, norm)
{
std::vector<TensorNetwork> stack(degree()+1);
stack[degree()] = _solution;
stack[degree()].reduce_representation();
bool init = true;
#pragma omp for
for(size_t j = 0; j < cSize; ++j) {
size_t unchangedModes = 0;
if(!init) {
// Find the maximal recyclable stack position
for(; unchangedModes < degree(); ++unchangedModes) {
if(!approx_equal(positions[j-1][degree()-1-unchangedModes], positions[j][degree()-1-unchangedModes])) { break; }
}
} else { init = false; }
// Rebuild stack
for(long i = degree()-1-unchangedModes; i >= 0; --i) {
stack[i](k&0) = stack[i+1](k&1, l) * positions[j][i](l);
stack[i].reduce_representation();
}
error += misc::sqr(measuredValues[j] - stack.front()[0]);
norm += misc::sqr(measuredValues[j]);
}
}
return std::sqrt(error/norm);
......
......@@ -1594,12 +1594,23 @@ namespace xerus {
}
} else { // Dense A
REQUIRE(_B.is_dense(), "Not yet implemented");
blasWrapper::solve_least_squares(
_X.override_dense_data(),
_A.get_unsanitized_dense_data(), m, n,
_B.get_unsanitized_dense_data(), p);
if(_B.is_dense()) {
blasWrapper::solve_least_squares(
_X.override_dense_data(),
_A.get_unsanitized_dense_data(), m, n,
_B.get_unsanitized_dense_data(), p);
} else {
LOG(warning, "Sparse RHS not yet implemented (casting to dense)"); //TODO
Tensor Bcpy(_B);
Bcpy.factor = 1.0;
Bcpy.use_dense_representation();
blasWrapper::solve_least_squares(
_X.override_dense_data(),
_A.get_unsanitized_dense_data(), m, n,
Bcpy.get_unsanitized_dense_data(), p);
}
}
// Propagate the constant factor
......
......@@ -785,6 +785,7 @@ namespace xerus {
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Basic arithmetics - - - - - - - - - - - - - - - - - - - - - - - - - - */
// TODO why sparse?
template<bool isOperator>
TTNetwork<isOperator>& TTNetwork<isOperator>::operator+=(const TTNetwork<isOperator>& _other) {
REQUIRE(dimensions == _other.dimensions, "The dimensions in TT sum must coincide. Given " << dimensions << " vs " << _other.dimensions);
......@@ -817,7 +818,7 @@ namespace xerus {
if (isOperator) { nxtDimensions.emplace_back(myComponent.dimensions[2]); }
nxtDimensions.emplace_back(position == numComponents-1 ? 1 : myComponent.dimensions.back()+otherComponent.dimensions.back());
const Tensor::Representation newRep = myComponent.is_sparse() || otherComponent.is_sparse() ? Tensor::Representation::Sparse : Tensor::Representation::Dense;
const Tensor::Representation newRep = myComponent.is_sparse() && otherComponent.is_sparse() ? Tensor::Representation::Sparse : Tensor::Representation::Dense;
std::unique_ptr<Tensor> newComponent(new Tensor(std::move(nxtDimensions), newRep));
newComponent->offset_add(myComponent, isOperator ? std::vector<size_t>({0,0,0,0}) : std::vector<size_t>({0,0,0}));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment