Commit 8de067f0 by Ben Huber

### reduced xALS for the examples to less than 100 lines of code

parent 573be0c8
Pipeline #734 passed with stages
in 8 minutes and 29 seconds
 --- layout: post title: "The ALS Algorithm" date: 1000-12-10 topic: "Examples" section: "Examples" --- __tabsInit # The ALS Algorithm Implementing the ALS algorithm for the first time was the most important step for us in understanding the TT format and its intricacies. We still think that it is a good place to start, so we want to provide a simple implementation of the ALS algorithm as an example. Using the `xerus` library, this will be an efficient implementation in fewer than 100 lines of code.
 ... ... @@ -19,7 +19,7 @@ /** * @file * @brief Implementation of the ADF variants. * @brief Implementation of a simple ALS variant. */ #include ... ... @@ -29,9 +29,6 @@ #include #include #ifdef _OPENMP #include #endif namespace xerus { ... ... @@ -52,50 +49,42 @@ namespace xerus { public: InternalSolver(TTTensor& _x, const TTOperator& _A, const TTTensor& _b) : d(_x.degree()), leftAStack(d), rightAStack(d), leftBStack(d), rightBStack(d), x(_x), A(_A), b(_b) { } InternalSolver(TTTensor& _x, const TTOperator& _A, const TTTensor& _b) : d(_x.degree()), x(_x), A(_A), b(_b) { leftAStack.emplace_back(Tensor::ones(std::vector(d, 1ul))); rightAStack.emplace_back(Tensor::ones(std::vector(d, 1ul))); leftBStack.emplace_back(Tensor::ones(std::vector(d, 1ul))); rightBStack.emplace_back(Tensor::ones(std::vector(d, 1ul))); } void calc_left_stack(const size_t _position) { Tensor xi = x.get_component(_position); Tensor Ai = A.get_component(_position); Tensor bi = b.get_component(_position); void push_left_stack(const size_t _position) { const Tensor &xi = x.get_component(_position); const Tensor &Ai = A.get_component(_position); const Tensor &bi = b.get_component(_position); if(_position == 0) { xi.reinterpret_dimensions({xi.dimensions[1], xi.dimensions[2]}); Ai.reinterpret_dimensions({Ai.dimensions[1], Ai.dimensions[2], Ai.dimensions[3]}); bi.reinterpret_dimensions({bi.dimensions[1], bi.dimensions[2]}); leftAStack[_position](i1, i2, i3) = xi(k1, i1)*Ai(k1, k2, i2)*xi(k2, i3); leftBStack[_position](i1, i2) = xi(k1, i1)*bi(k1, i2); } else { leftAStack[_position](i1, i2, i3) = leftAStack[_position-1](j1, j2, j3)*xi(j1, k1, i1)*Ai(j2, k1, k2, i2)*xi(j3, k2, i3); leftBStack[_position](i1, i2) = leftBStack[_position-1](j1, j2)*xi(j1, k1, i1)*bi(j2, k1, i2); } Tensor tmpA, tmpB; tmpA(i1, i2, i3) = leftAStack.back()(j1, j2, j3)*xi(j1, k1, i1)*Ai(j2, k1, k2, i2)*xi(j3, k2, i3); leftAStack.emplace_back(std::move(tmpA)); tmpB(i1, i2) = leftBStack.back()(j1, j2)*xi(j1, k1, i1)*bi(j2, k1, i2); 
leftBStack.emplace_back(std::move(tmpB)); } void calc_right_stack(const size_t _position) { Tensor xi = x.get_component(_position); Tensor Ai = A.get_component(_position); Tensor bi = b.get_component(_position); void push_right_stack(const size_t _position) { const Tensor &xi = x.get_component(_position); const Tensor &Ai = A.get_component(_position); const Tensor &bi = b.get_component(_position); if(_position == d-1) { xi.reinterpret_dimensions({xi.dimensions[0], xi.dimensions[1]}); Ai.reinterpret_dimensions({Ai.dimensions[0], Ai.dimensions[1], Ai.dimensions[2]}); bi.reinterpret_dimensions({bi.dimensions[0], bi.dimensions[1]}); rightAStack[_position](i1, i2, i3) = xi(i1, k1)*Ai(i2, k1, k2)*xi(i3, k2); rightBStack[_position](i1, i2) = xi(i1, k1)*bi(i2, k1); } else { rightAStack[_position](i1, i2, i3) = xi(i1, k1, j1)*Ai(i2, k1, k2, j2)*xi(i3, k2, j3)*rightAStack[_position+1](j1, j2, j3); rightBStack[_position](i1, i2) = xi(i1, k1, j1)*bi(i2, k1, j2)*rightBStack[_position+1](j1, j2); } Tensor tmpA, tmpB; tmpA(i1, i2, i3) = xi(i1, k1, j1)*Ai(i2, k1, k2, j2)*xi(i3, k2, j3)*rightAStack.back()(j1, j2, j3); rightAStack.emplace_back(std::move(tmpA)); tmpB(i1, i2) = xi(i1, k1, j1)*bi(i2, k1, j2)*rightBStack.back()(j1, j2); rightBStack.emplace_back(std::move(tmpB)); } double calc_residual_norm() { const Index i, j; TTTensor tmp; tmp(i&0) = A(i/2, j/2)*x(j&0)-b(i&0); tmp(i1&0) = A(i1/2, j1/2)*x(j1&0)-b(i1&0); return frob_norm(tmp); } ... ... @@ -103,13 +92,12 @@ namespace xerus { void solve() { const double solutionsNorm = frob_norm(b); std::vector residuals(10, 1000.0); const size_t maxIterations = 1; const size_t maxIterations = 1000; // Rebuild right stack x.move_core(0, true); for(size_t corePosition = d-1; corePosition > 0; --corePosition) { calc_right_stack(corePosition); for(size_t pos = d-1; pos > 0; --pos) { push_right_stack(pos); } for(size_t iteration = 0; maxIterations == 0 || iteration < maxIterations; ++iteration) { ... ... 
@@ -127,79 +115,30 @@ namespace xerus { for(size_t corePosition = 0; corePosition < d; ++corePosition) { Tensor op, rhs; Tensor Ai = A.get_component(corePosition); Tensor bi = b.get_component(corePosition); const Tensor &Ai = A.get_component(corePosition); const Tensor &bi = b.get_component(corePosition); if(corePosition == 0) { Ai.reinterpret_dimensions({Ai.dimensions[1], Ai.dimensions[2], Ai.dimensions[3]}); bi.reinterpret_dimensions({bi.dimensions[1], bi.dimensions[2]}); op(i2, i3, j2, j3) = Ai(i2, j2, k2)*rightAStack[corePosition+1](i3, k2, j3); rhs(i2, i3) = bi(i2, k2)*rightBStack[corePosition+1](i3, k2); } else if(corePosition == d-1) { Ai.reinterpret_dimensions({Ai.dimensions[0], Ai.dimensions[1], Ai.dimensions[2]}); bi.reinterpret_dimensions({bi.dimensions[0], bi.dimensions[1]}); op(i1, i2, j1, j2) = leftAStack[corePosition-1](i1, k1, j1)*Ai(k1, i2, j2); rhs(i1, i2) = leftBStack[corePosition-1](i1, k1)*bi(k1, i2); } else { op(i1, i2, i3, j1, j2, j3) = leftAStack[corePosition-1](i1, k1, j1)*Ai(k1, i2, j2, k2)*rightAStack[corePosition+1](i3, k2, j3); rhs(i1, i2, i3) = leftBStack[corePosition-1](i1, k1)*bi(k1, i2, k2)*rightBStack[corePosition+1](i3, k2); } solve_least_squares(x.component(corePosition), op, rhs, 0); op(i1, i2, i3, j1, j2, j3) = leftAStack.back()(i1, k1, j1)*Ai(k1, i2, j2, k2)*rightAStack.back()(i3, k2, j3); rhs(i1, i2, i3) = leftBStack.back()(i1, k1) * bi(k1, i2, k2) * rightBStack.back()(i3, k2); if(corePosition == 0) { x.component(corePosition).reinterpret_dimensions({1, x.component(corePosition).dimensions[0], x.component(corePosition).dimensions[1]}); } else if(corePosition == d-1) { x.component(corePosition).reinterpret_dimensions({x.component(corePosition).dimensions[0], x.component(corePosition).dimensions[1], 1}); } xerus::solve(x.component(corePosition), op, rhs, 0); // If we have not yet reached the end of the sweep we need to take care of the core and update our stacks if(corePosition+1 < d) { x.move_core(corePosition+1, true); 
calc_left_stack(corePosition); push_left_stack(corePosition); rightAStack.pop_back(); rightBStack.pop_back(); } } // Sweep Right -> Left // Sweep Right -> Left : only move core and update stacks x.move_core(0, true); for(size_t corePosition = d-1; corePosition > 0; --corePosition) { // Tensor op, rhs; // // Tensor Ai = A.get_component(corePosition); // Tensor bi = b.get_component(corePosition); // // if(corePosition == 0) { // Ai.reinterpret_dimensions({Ai.dimensions[1], Ai.dimensions[2], Ai.dimensions[3]}); // bi.reinterpret_dimensions({bi.dimensions[1], bi.dimensions[2]}); // // op(i2, i3, j2, j3) = Ai(i2, j2, k2)*rightAStack[corePosition+1](i3, k2, j3); // rhs(i2, i3) = bi(i2, k2)*rightBStack[corePosition+1](i3, k2); // } else if(corePosition == d-1) { // Ai.reinterpret_dimensions({Ai.dimensions[0], Ai.dimensions[1], Ai.dimensions[2]}); // bi.reinterpret_dimensions({bi.dimensions[0], bi.dimensions[1]}); // // op(i1, i2, j1, j2) = leftAStack[corePosition-1](i1, k1, j1)*Ai(k1, i2, j2); // rhs(i1, i2) = leftBStack[corePosition-1](i1, k1)*bi(k1, i2); // } else { // op(i1, i2, i3, j1, j2, j3) = leftAStack[corePosition-1](i1, k1, j1)*Ai(k1, i2, j2, k2)*rightAStack[corePosition+1](i3, k2, j3); // rhs(i1, i2, i3) = leftBStack[corePosition-1](i1, k1)*bi(k1, i2, k2)*rightBStack[corePosition+1](i3, k2); // } // // solve_least_squares(x.component(corePosition), op, rhs, 0); // // if(corePosition == 0) { // x.component(corePosition).reinterpret_dimensions({1, x.component(corePosition).dimensions[0], x.component(corePosition).dimensions[1]}); // } else if(corePosition == d-1) { // x.component(corePosition).reinterpret_dimensions({x.component(corePosition).dimensions[0], x.component(corePosition).dimensions[1], 1}); // } // If we have not yet reached the end of the sweep we need to take care of the core and update our stacks if(corePosition > 0) { x.move_core(corePosition-1, true); calc_right_stack(corePosition); } push_right_stack(corePosition); leftAStack.pop_back(); 
leftBStack.pop_back(); } } ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!