Commit 9296cab4 authored by Michael Goette's avatar Michael Goette
Browse files

added ones() for htnetworks@

parent 7427977b
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@
#include "indexedTensorMoveable.h"
#include "indexedTensorList.h"

#include <xerus/misc/internal.h>
//#include <xerus/misc/internal.h>



@@ -232,7 +232,7 @@ namespace xerus {
		 * @brief: Returns a the (rank one) HT-Tensor with all entries equal to one.
		 * @param _dimensions the dimensions of the new tensor.
		 */
//		static HTNetwork XERUS_warn_unused ones(const std::vector<size_t>& _dimensions);
		static HTNetwork XERUS_warn_unused ones(const std::vector<size_t>& _dimensions);
		
		
		/** 
@@ -464,10 +464,10 @@ namespace xerus {
		
		
		/** 
		* @brief Move the core to the left.
		* @brief Move the core to the root.
		* @details Basically calls move_core() with _position = 0
		*/
//		void canonicalize_left();
		void canonicalize_root();
		
		
		/** 
+38 −29
Original line number Diff line number Diff line
@@ -254,29 +254,38 @@ namespace xerus {
//// 	}
//
//
//	template<bool isOperator>
//	HTNetwork<isOperator> HTNetwork<isOperator>::ones(const std::vector<size_t>& _dimensions) {
//		REQUIRE(_dimensions.size()%N == 0, "Illegal number of dimensions for ttOperator");
//		REQUIRE(!misc::contains(_dimensions, size_t(0)), "Trying to construct a TTTensor with dimension 0 is not possible.");
//
//		if(_dimensions.empty()) {
//			return HTNetwork(Tensor::ones({}));
//		}
//
//		HTNetwork result(_dimensions.size());
//		const size_t numNodes = result.numberOfComponents;
//
//		std::vector<size_t> dimensions(isOperator ? 4 : 3, 1);
//		for(size_t i = 0; i < numNodes; ++i) {
//			dimensions[1] = _dimensions[i];
//			if (isOperator) {
//				dimensions[2] = _dimensions[i+numNodes];
//			}
//			result.set_component(i, Tensor::ones(dimensions));
//		}
//		result.canonicalize_left();
//		return result;
//	}
	template<bool isOperator>
	HTNetwork<isOperator> HTNetwork<isOperator>::ones(const std::vector<size_t>& _dimensions) {
		REQUIRE(_dimensions.size()%N == 0, "Illegal number of dimensions for htOperator");
		REQUIRE(!misc::contains(_dimensions, size_t(0)), "Trying to construct a HTTensor with dimension 0 is not possible.");

		const size_t numIntComp = static_cast<size_t>(std::pow(2,std::ceil(std::log2(static_cast<double>(_dimensions.size()/N ))))) - 1;
		const size_t numOfLeaves = _dimensions.size()/N;
		const size_t numComponents = numIntComp + numOfLeaves;

		if(_dimensions.empty()) {
			return HTNetwork(Tensor::ones({}));
		}

		HTNetwork result(_dimensions.size());
		//Leaves
		std::vector<size_t> dimensions(isOperator ? 3 : 2, 1);
		for(size_t i = numIntComp; i < numComponents; ++i) {
			dimensions[1] = _dimensions[i - numIntComp];
			if (isOperator) {
				dimensions[2] = _dimensions[i - numIntComp + numOfLeaves];
			}
			result.set_component(i, Tensor::ones(dimensions));
		}
		//Internal Components
		std::vector<size_t> dimensions2(3, 1);
		for(size_t i = 0; i < numIntComp; ++i) {
			result.set_component(i, Tensor::ones(dimensions2));
		}

		result.canonicalize_root();
		return result;
	}
//
//
//	template<> template<>
@@ -746,12 +755,12 @@ namespace xerus {
		canonicalized = true;
		corePosition = _position;
	}
//
//
//	template<bool isOperator>
//	void TTNetwork<isOperator>::canonicalize_left() {
//		move_core(0);
//	}


	template<bool isOperator>
	void HTNetwork<isOperator>::canonicalize_root() {
		move_core(0);
	}
//
//
//	template<bool isOperator>