Commit 5e714c4c authored by Michael Goette's avatar Michael Goette

added constant TT initializer

parent 8186ac7f
Pipeline #949 passed with stages
in 10 minutes and 17 seconds
......@@ -244,6 +244,38 @@ namespace xerus {
*/
static TTNetwork XERUS_warn_unused dirac(std::vector<size_t> _dimensions, const size_t _position);
/**
* @brief constructs a constant TTNetwork with the given dimensions and ranks.
* @param _dimensions the dimensions of the to be created TTNetwork.
* @param _ranks the ranks of the to be created TTNetwork.
* TODO wright Tests
*/
static TTNetwork XERUS_warn_unused constant(value_t constant, std::vector<size_t> _dimensions, const std::vector<size_t> &_ranks) {
const size_t numComponents = _dimensions.size()/N;
XERUS_REQUIRE(_dimensions.size()%N==0, "Illegal number of dimensions for TTOperator.");
XERUS_REQUIRE(_ranks.size()+1 == numComponents,"Non-matching amount of ranks given to TTNetwork::random.");
XERUS_REQUIRE(!misc::contains(_dimensions, size_t(0)), "Trying to construct a TTTensor with dimension 0 is not possible.");
XERUS_REQUIRE(!misc::contains(_ranks, size_t(0)), "Trying to construct random TTTensor with rank 0 is illegal.");
TTNetwork result(_dimensions.size());
const std::vector<size_t> targetRank = reduce_to_maximal_ranks(_ranks, _dimensions);
for(size_t i = 0; i < numComponents; ++i) {
const size_t leftRank = i==0 ? 1 : targetRank[i-1];
const size_t rightRank = (i==numComponents-1) ? 1 : targetRank[i];
if(isOperator) {
const auto constCmp = constant * Tensor::ones({leftRank, _dimensions[i], _dimensions[numComponents+i], rightRank});
result.set_component(i, constCmp);
} else {
const auto constCmp = constant * Tensor::ones({leftRank, _dimensions[i], rightRank});
result.set_component(i, constCmp);
}
}
//result.move_core(0);
return result;
}
/*- - - - - - - - - - - - - - - - - - - - - - - - - - Standard Operators - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -*/
///@brief TTNetworks are default assignable.
TTNetwork& operator=(const TTNetwork& _other) = default;
......
......@@ -47,6 +47,10 @@ void expose_ttnetwork() {
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::TTTensor::random(_dim, _rank);
}).staticmethod("random")
.def("constant",
+[](value_t _constant, std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::TTTensor::constant(_constant, _dim, _rank);
}).staticmethod("constant")
.def("ones", &TTTensor::ones).staticmethod("ones")
.def("kronecker", &TTTensor::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTTensor (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTTensor::dirac))
......@@ -62,12 +66,11 @@ void expose_ttnetwork() {
return boost::python::make_tuple(result.first, result.second);
}, arg("position"))
// .def("round", static_cast<void (TTTensor::*)(const std::vector<size_t>&, double)>(&TTTensor::round),
// (arg("ranks"), arg("epsilon")=EPSILON)
// )
.def("round", static_cast<void (TTTensor::*)(const std::vector<size_t>&, double)>(&TTTensor::round),
(arg("ranks"), arg("epsilon")=EPSILON)
)
.def("round", static_cast<void (TTTensor::*)(double)>(&TTTensor::round))
// .def("round", static_cast<void (TTTensor::*)(size_t)>(&TTTensor::round))
.def("round", static_cast<void (TTTensor::*)(size_t)>(&TTTensor::round))
.def("soft_threshold", static_cast<void (TTTensor::*)(const double, const bool)>(&TTTensor::soft_threshold),
(arg("tau"), arg("preventZero")=false)
)
......@@ -112,6 +115,10 @@ void expose_ttnetwork() {
+[](std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::TTOperator::random(_dim, _rank);
}).staticmethod("random")
.def("constant",
+[](value_t _constant, std::vector<size_t> _dim, std::vector<size_t> _rank) {
return xerus::TTOperator::constant(_constant, _dim, _rank);
}).staticmethod("constant")
.def("ones", &TTOperator::ones).staticmethod("ones")
.def("kronecker", &TTOperator::kronecker).staticmethod("kronecker")
.def("dirac", static_cast<TTOperator (*)(Tensor::DimensionTuple, const Tensor::MultiIndex&)>(&TTOperator::dirac))
......@@ -132,7 +139,6 @@ void expose_ttnetwork() {
)
.def("round", static_cast<void (TTOperator::*)(double)>(&TTOperator::round))
.def("round", static_cast<void (TTOperator::*)(size_t)>(&TTOperator::round))
.def("soft_threshold", static_cast<void (TTOperator::*)(const double, const bool)>(&TTOperator::soft_threshold),
(arg("tau"), arg("preventZero")=false)
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment