Skip to content

Commit

Permalink
Issue #925: Added noncopy constructor for Tensor and TensorT
Browse files Browse the repository at this point in the history
Added methods reserve() and constructor for Tensor and TensorT

Signed-off-by: Andrea Calabrese <[email protected]>
  • Loading branch information
ThePseudo committed Apr 17, 2024
1 parent 51b05fe commit 5010ac4
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
36 changes: 36 additions & 0 deletions src/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,25 @@ Tensor::Tensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
this->rebuild(data, elementTotalCount, elementMemorySize);
}

Tensor::Tensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
uint32_t elementTotalCount,
uint32_t elementMemorySize,
const TensorDataTypes& dataType,
const TensorTypes& tensorType)
{
KP_LOG_DEBUG("Kompute Tensor constructor data length: {}, and type: {}",
elementTotalCount,
Tensor::toString(tensorType));

this->mPhysicalDevice = physicalDevice;
this->mDevice = device;
this->mDataType = dataType;
this->mTensorType = tensorType;

this->reserve(elementTotalCount, elementMemorySize);
}

Tensor::~Tensor()
{
KP_LOG_DEBUG("Kompute Tensor destructor started. Type: {}",
Expand All @@ -70,6 +89,23 @@ Tensor::~Tensor()
KP_LOG_DEBUG("Kompute Tensor destructor success");
}

void
Tensor::reserve(uint32_t elementTotalCount, uint32_t elementMemorySize)
{
KP_LOG_DEBUG("Reserving {} bytes for memory", elementTotalCount);

this->mSize = elementTotalCount;
this->mDataTypeMemorySize = elementMemorySize;

if (this->mPrimaryBuffer || this->mPrimaryMemory) {
KP_LOG_DEBUG(
"Kompute Tensor destroying existing resources before rebuild");
this->destroy();
}

this->allocateMemoryCreateGPUResources();
}

void
Tensor::rebuild(void* data,
uint32_t elementTotalCount,
Expand Down
41 changes: 40 additions & 1 deletion src/include/kompute/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ class Tensor
const TensorDataTypes& dataType,
const TensorTypes& tensorType = TensorTypes::eDevice);

/**
* Constructor with size provided which would be used to create the
* respective vulkan buffer and memory. Data is not copied.
*
* @param physicalDevice The physical device to use to fetch properties
* @param device The device to use to create the buffer and memory from
* @param elmentTotalCount the number of elements of the array
* @param elementMemorySize the size of the element
* @param tensorTypes Type for the tensor which is of type TensorTypes
*/
Tensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
uint32_t elementTotalCount,
uint32_t elementMemorySize,
const TensorDataTypes& dataType,
const TensorTypes& tensorType = TensorTypes::eDevice);

/**
* Destructor which is in charge of freeing vulkan resources unless they
* have been provided externally.
Expand All @@ -79,6 +96,13 @@ class Tensor
uint32_t elementTotalCount,
uint32_t elementMemorySize);

/**
* @brief Reserve memory on the tensor
*
* @param newSize the new size for reservation
*/
void reserve(uint32_t elementTotalCount, uint32_t elementMemorySize);

/**
* Destroys and frees the GPU resources which include the buffer and memory.
*/
Expand Down Expand Up @@ -301,6 +325,21 @@ class TensorT : public Tensor
{

public:
TensorT(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
const size_t size,
const TensorTypes& tensorType = TensorTypes::eDevice)
: Tensor(physicalDevice,
device,
size,
sizeof(T),
this->dataType(),
tensorType)
{
KP_LOG_DEBUG("Kompute TensorT constructor with data size {}",
data.size());
}

TensorT(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
const std::vector<T>& data,
Expand All @@ -313,7 +352,7 @@ class TensorT : public Tensor
this->dataType(),
tensorType)
{
KP_LOG_DEBUG("Kompute TensorT constructor with data size {}",
KP_LOG_DEBUG("Kompute TensorT filling constructor with data size {}",
data.size());
}

Expand Down

0 comments on commit 5010ac4

Please sign in to comment.