Commit 3ad08c42 authored by Ben Huber's avatar Ben Huber

removed all transpositions from svd calculation

parent 1aef9637
Pipeline #1110 passed with stages
in 16 minutes and 26 seconds
......@@ -90,7 +90,7 @@ static misc::UnitTest tensor_svd_rnd512("Tensor", "SVD_Random_512x512", [](){
(res1(i,j,k,o), res2(o,p), res3(p,l,m,n)) = SVD(A(l,i,m,k,j,n));
res4(k,i,m,l,j,n) = res1(i,j,l,o)*res2(o,p)*res3(p,k,m,n);
TEST(approx_equal(res4, A, 1e-14));
MTEST(approx_equal(res4, A, 1e-14), frob_norm(res4-A));
MTEST(frob_norm(res1(i^3, m)*res1(i^3, n) - Tensor::identity(res2.dimensions)(m, n)) < 1e-12, " U not orthogonal");
MTEST(frob_norm(res3(m, i^3)*res3(n, i^3) - Tensor::identity(res2.dimensions)(m, n)) < 1e-12, " Vt not orthogonal");
......
......@@ -221,26 +221,16 @@ namespace xerus {
}
void dgesdd_work(lapack_int m, lapack_int n, double* a,
double* s, double* u, double* vt, double* work,
lapack_int lwork, lapack_int* iwork )
{
void dgesdd_work(lapack_int m, lapack_int n, double* a, double* s, double* u, double* vt, double* work, lapack_int lwork, lapack_int* iwork ) {
REQUIRE(lwork > 0, "");
lapack_int info = 0;
char job = 'S';
lapack_int min = std::min(m,n);
std::unique_ptr<double[]> a_t(new double[size_t(m*n)]);
std::unique_ptr<double[]> u_t(new double[size_t(m*min)]);
std::unique_ptr<double[]> vt_t(new double[size_t(min*n)]);
// Transpose input matrices
low_level_transpose(a_t.get(), a, size_t(m), size_t(n));
LAPACK_dgesdd( &job, &m, &n, a_t.get(), &m, s, u_t.get(), &m, vt_t.get(), &min, work, &lwork, iwork, &info );
// if A = U*S*V^T, then A^T = V^T*S*U^T, so we can simply call lapack without transposing our matrices
// by simply changing the order of U and Vt
LAPACK_dgesdd( &job, &n, &m, a, &n, s, vt, &n, u, &min, work, &lwork, iwork, &info );
REQUIRE(info == 0, "dgesdd failed with info " << info);
// Transpose output matrices
// low_level_transpose(a, a_t.get(), n, m);
low_level_transpose(u, u_t.get(), size_t(min), size_t(m));
low_level_transpose(vt, vt_t.get(), size_t(n), size_t(min));
}
void svd_destructive( double* const _U, double* const _S, double* const _Vt, double* const _A, const size_t _m, const size_t _n) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment