Commit ec260050 authored by Sebastian Wolf's avatar Sebastian Wolf

Closes #214 Added tensor solve checks and testcases

parent e3da5941
Pipeline #1056 passed with stages
in 7 minutes and 39 seconds
......@@ -218,6 +218,68 @@ static misc::UnitTest tensor_constructors("Tensor", "Constructors", [](){
static misc::UnitTest tensor_solve_failtests("Tensor", "solve_failtests", []() {
xerus::Tensor A = xerus::Tensor::identity({5, 5});
for(size_t k = 0; k < 2; ++k) { // Sparse and dense A
xerus::Tensor x = xerus::Tensor({4});
xerus::Tensor b = xerus::Tensor::random({3});
FAILTEST(solve(x, A, b););
x = xerus::Tensor({5});
b = xerus::Tensor::random({3});
FAILTEST(solve(x, A, b););
x = xerus::Tensor({6});
b = xerus::Tensor::random({6});
FAILTEST(solve(x, A, b););
x = xerus::Tensor({3});
b = xerus::Tensor::random({5});
solve(x, A, b);
TEST(approx_equal(x, b, 1e-10));
x = xerus::Tensor({5});
b = xerus::Tensor::random({5});
solve(x, A, b);
TEST(approx_equal(x, b, 1e-10));
x = xerus::Tensor({4,3});
b = xerus::Tensor::random({3});
FAILTEST(solve(x, A, b););
if(k==1) { // Currently extra orders only works in dense.
x = xerus::Tensor({4,3});
b = xerus::Tensor::random({3,4});
FAILTEST(solve(x, A, b););
x = xerus::Tensor({5,3});
b = xerus::Tensor::random({3,3});
FAILTEST(solve(x, A, b, 1););
x = xerus::Tensor({5,3});
b = xerus::Tensor::random({3,5});
FAILTEST(solve(x, A, b, 1););
x = xerus::Tensor::random({3,7,2});
b = xerus::Tensor::random({5,3});
solve(x, A, b, 1);
TEST(approx_equal(x, b, 1e-10));
x = xerus::Tensor::random({5,5,2});
b = xerus::Tensor::random({5,3,4});
solve(x, A, b, 2);
TEST(approx_equal(x, b, 1e-10));
}
A.use_dense_representation();
}
return 0;
});
static misc::UnitTest tensor_sparse_dense("Tensor", "Sparse_Dense_Conversions", [](){
Tensor n({3,3,3,3});
const size_t dim = 100;
......
......@@ -1591,9 +1591,6 @@ namespace xerus {
const size_t degM = _B.degree() - _extraDegree;
const size_t degN = _A.degree() - degM;
REQUIRE(_A.degree() == degM+degN, "Inconsistent dimensions.");
REQUIRE(_B.degree() == degM+_extraDegree, "Inconsistent dimensions.");
// Make sure X has right dimensions
if( _X.degree() != degN + _extraDegree
|| !std::equal(_X.dimensions.begin(), _X.dimensions.begin() + degN, _A.dimensions.begin() + degM)
......@@ -1605,6 +1602,9 @@ namespace xerus {
_X.reset(std::move(newDimX), Tensor::Representation::Dense, Tensor::Initialisation::None);
}
REQUIRE(std::equal(_A.dimensions.begin(), _A.dimensions.begin() + degM, _B.dimensions.begin()), "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
XERUS_INTERNAL_CHECK(std::equal(_A.dimensions.begin() + degM, _A.dimensions.end(), _X.dimensions.begin()), "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
XERUS_INTERNAL_CHECK(std::equal(_B.dimensions.begin() + degM, _B.dimensions.end(), _X.dimensions.begin()+degN), "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
// Calculate multDimensions
const size_t m = misc::product(_A.dimensions, 0, degM);
......@@ -1662,14 +1662,13 @@ namespace xerus {
solve_least_squares(_X, _A, _B, _extraDegree);
return;
}
REQUIRE(&_X != &_B && &_X != &_A, "Not supportet yet");
REQUIRE(&_X != &_B && &_X != &_A, "x=b and x=a is not supported yet.");
const size_t degM = _B.degree() - _extraDegree;
const size_t degN = _A.degree() - degM;
REQUIRE(_A.degree() == degM+degN, "Inconsistent dimensions.");
REQUIRE(_B.degree() == degM+_extraDegree, "Inconsistent dimensions.");
// Make sure X has right dimensions
// Make sure X has the right dimensions
if( _X.degree() != degN + _extraDegree
|| !std::equal(_X.dimensions.begin(), _X.dimensions.begin() + degN, _A.dimensions.begin() + degM)
|| !std::equal(_X.dimensions.begin()+ degN, _X.dimensions.end(), _B.dimensions.begin() + degM))
......@@ -1680,13 +1679,20 @@ namespace xerus {
_X.reset(std::move(newDimX), Tensor::Representation::Dense, Tensor::Initialisation::None);
}
REQUIRE(std::equal(_A.dimensions.begin(), _A.dimensions.begin() + degM, _B.dimensions.begin()), "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
XERUS_INTERNAL_CHECK(std::equal(_A.dimensions.begin() + degM, _A.dimensions.end(), _X.dimensions.begin()), "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
XERUS_INTERNAL_CHECK(std::equal(_B.dimensions.begin() + degM, _B.dimensions.end(), _X.dimensions.begin()+degN), "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
// Calculate multDimensions
const size_t m = misc::product(_A.dimensions, 0, degM);
const size_t n = misc::product(_A.dimensions, degM, degM+degN);
const size_t p = misc::product(_B.dimensions, degM, degM+_extraDegree);
REQUIRE(_B.size == m*p, "A and b have incompatible dimensions. A: " << _A.dimensions << ", b: " << _B.dimensions);
XERUS_INTERNAL_CHECK(_X.size == n*p, "Invalid dimension of x");
// Note that A isdense here
// Note that A is dense here
if(_B.is_dense()) {
blasWrapper::solve(
_X.override_dense_data(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment