Commit 43371e43 authored by Ben Huber's avatar Ben Huber

svd without LAPACKE

parent 4b2d5559
Pipeline #1107 failed with stages
in 16 minutes and 7 seconds
......@@ -71,6 +71,15 @@ namespace xerus {
// thread_local value_t defaultWorkspace[DEFAULT_WORKSPACE_SIZE]; // NOTE recheck compatibility with eigen (dolfin) when reinserting this!
/// @brief stores in @a _out the transpose of the @a _m x @a _n matrix @a _in
void low_level_transpose(double * _out, double * _in, size_t _m, size_t _n) {
for (size_t i=0; i<_m; ++i) {
for (size_t j=0; j<_n; ++j) {
_out[j*_m + i] = _in[i*_n+j];
}
}
}
//----------------------------------------------- LEVEL I BLAS ----------------------------------------------------------
double one_norm(const double* const _x, const size_t _n) {
......@@ -206,26 +215,54 @@ namespace xerus {
}
lapack_int dgesdd_get_workarray_size(lapack_int m, lapack_int n) {
lapack_int info = 0;
char job = 'S';
double work = 0;
lapack_int lwork = -1;
lapack_int min = std::min(m,n);
LAPACK_dgesdd( &job, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &min, &work, &lwork, nullptr, &info );
REQUIRE(info == 0, "work array size query of dgesdd returned " << info);
return lapack_int(work);
}
void dgesdd_work(lapack_int m, lapack_int n, double* a,
double* s, double* u, double* vt, double* work,
lapack_int lwork, lapack_int* iwork )
{
REQUIRE(lwork > 0, "");
lapack_int info = 0;
char job = 'S';
lapack_int min = std::min(m,n);
std::unique_ptr<double[]> a_t(new double[m*n]);
std::unique_ptr<double[]> u_t(new double[m*min]);
std::unique_ptr<double[]> vt_t(new double[min*n]);
// Transpose input matrices
low_level_transpose(a_t.get(), a, m, n);
LAPACK_dgesdd( &job, &m, &n, a_t.get(), &m, s, u_t.get(), &m, vt_t.get(),
&min, work, &lwork, iwork, &info );
REQUIRE(info == 0, "dgesdd failed with info " << info);
// Transpose output matrices
// low_level_transpose(a, a_t.get(), n, m);
low_level_transpose(u, u_t.get(), min, m);
low_level_transpose(vt, vt_t.get(), n, min);
}
void svd_destructive( double* const _U, double* const _S, double* const _Vt, double* const _A, const size_t _m, const size_t _n) {
REQUIRE(_m <= static_cast<size_t>(std::numeric_limits<int>::max()), "Dimension to large for BLAS/Lapack");
REQUIRE(_n <= static_cast<size_t>(std::numeric_limits<int>::max()), "Dimension to large for BLAS/Lapack");
XERUS_PA_START;
std::unique_ptr<double[]> tmpA(new double[_m*_n]);
misc::copy(tmpA.get(), _A, _m*_n);
lapack_int m = lapack_int(_m);
lapack_int n = lapack_int(_n);
lapack_int lwork = -1;
std::unique_ptr<lapack_int[]> iwork(new lapack_int[std::max(1,8*std::min(m,n))]);
lwork = dgesdd_get_workarray_size(m, n);
std::unique_ptr<double[]> work(new double[lwork]);
int lapackAnswer = LAPACKE_dgesdd(LAPACK_ROW_MAJOR, 'S', static_cast<int>(_m), static_cast<int>(_n), _A, static_cast<int>(_n), _S, _U, static_cast<int>(std::min(_m, _n)), _Vt, static_cast<int>(_n));
CHECK(lapackAnswer == 0, warning, "Lapack failed to compute SVD. Answer is: " << lapackAnswer);
CHECK(lapackAnswer == 0, warning, "Call was: LAPACKE_dgesdd(LAPACK_ROW_MAJOR, 'S', " << static_cast<int>(_m) << ", " << static_cast<int>(_n) << ", " << _A << ", " << static_cast<int>(_n) <<", "
<< _S <<", " << _U << ", " << static_cast<int>(std::min(_m, _n)) << ", " << _Vt << ", " << static_cast<int>(_n) << ");");
if(lapackAnswer != 0) {
LOG(warning, "SVD failed ");
// for(size_t i=0; i < _m; ++i) {
// for(size_t j=0; j < _n; ++j) {
// LOG(warning, tmpA[i*_n+j]);
// }
// }
}
dgesdd_work( m, n, _A, _S, _U, _Vt, work.get(), lwork, iwork.get());
XERUS_PA_END("Dense LAPACK", "Singular Value Decomposition", misc::to_string(_m)+"x"+misc::to_string(_n));
}
......@@ -679,7 +716,7 @@ namespace xerus {
static_cast<int>(_n) // LDVR TODO check size of _x
);
CHECK(lapackAnswer == 0, error, "Unable to solve Ax = lambda*x (DGEEV solver). Lapacke says: " << lapackAnswer);
XERUS_PA_END("Dense LAPACK", "Solve (DGEEV)", misc::to_string(_n)+"x"+misc::to_string(_n)+"x"+misc::to_string(_nrhs));
XERUS_PA_END("Dense LAPACK", "Solve (DGEEV)", misc::to_string(_n)+"x"+misc::to_string(_n));
return;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment