Commit 09d0ae8c authored by Sebastian Wolf's avatar Sebastian Wolf

Tidy up Tensor solve / operator

parent 578c7a6d
Pipeline #628 failed with stages
in 7 minutes and 56 seconds
......@@ -25,19 +25,14 @@
#include <xerus/misc/internal.h>
#include <xerus/misc/containerSupport.h>
#include <xerus/indexedTensor_tensor_factorisations.h>
#include <xerus/index.h>
#include <xerus/indexedTensor.h>
#include <xerus/indexedTensorMoveable.h>
#include <xerus/tensor.h>
#include <xerus/blasLapackWrapper.h>
#include <xerus/cholmod_wrapper.h>
namespace xerus {
void solve(internal::IndexedTensorWritable<Tensor>&& _x, internal::IndexedTensorReadOnly<Tensor>&& _a, internal::IndexedTensorReadOnly<Tensor>&& _b) {
_a.assign_indices();
internal::IndexedTensorMoveable<Tensor> operator/ (internal::IndexedTensorReadOnly<Tensor>&& _b, internal::IndexedTensorReadOnly<Tensor>&& _A) {
_A.assign_indices();
_b.assign_indices();
size_t extraDims = 0;
......@@ -46,14 +41,11 @@ namespace xerus {
std::vector<Index> orderX;
// If possible we don't want to reorder A, so first divide A into those shared with b and x.
for(const Index& idx : _a.indices) {
for(const Index& idx : _A.indices) {
if(misc::contains(_b.indices, idx)) {
orderA.push_back(idx);
} else {
REQUIRE(misc::contains(_x.indices, idx), "Invalid indices");
orderX.push_back(idx);
for(size_t i = 0; i < idx.span; ++i) {
}
}
}
......@@ -65,47 +57,24 @@ namespace xerus {
// Now complete indices of b and x with those not shared with A ( in order of b as we don't want to reorder b if possible).
for(const Index& idx : _b.indices) {
if(misc::contains(_x.indices, idx)) {
if(!misc::contains(_A.indices, idx)) {
orderB.push_back(idx);
orderX.push_back(idx);
for(size_t i = 0; i < idx.span; ++i) {
extraDims++;
}
} else {
REQUIRE(misc::contains(_a.indices, idx), "Invalid indices");
}
}
// If indices coincide no reordering occours (only shared data pointer is set).
Tensor reorderedA, reorderedB;
reorderedA(orderA) = std::move(_a);
reorderedA(orderA) = std::move(_A);
reorderedB(orderB) = std::move(_b);
solve_least_squares(*_x.tensorObject, reorderedA, reorderedB, extraDims);
(*_x.tensorObject)(_x.indices) = (*_x.tensorObject)(orderX);
}
internal::IndexedTensorMoveable<Tensor> operator/ (internal::IndexedTensorReadOnly<Tensor>&& _b, internal::IndexedTensorReadOnly<Tensor>&& _A) {
_A.assign_indices();
_b.assign_indices();
std::vector<Index> indicesX;
for(const Index& idx : _A.indices) {
if(!misc::contains(_b.indices, idx)) {
indicesX.push_back(idx);
}
}
for(const Index& idx : _b.indices) {
if(!misc::contains(_A.indices, idx)) {
indicesX.push_back(idx);
}
}
internal::IndexedTensorMoveable<Tensor> tmpX(new Tensor(), std::move(orderX));
internal::IndexedTensorMoveable<Tensor> tmpX(new Tensor(), std::move(indicesX));
solve_least_squares(*tmpX.tensorObject, reorderedA, reorderedB, extraDims);
solve(std::move(tmpX), std::move(_A), std::move(_b));
return tmpX;
}
} // namespace xerus
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment