Skip to content

Commit

Permalink
Adding API changes for external source and set_mem_handle
Browse files Browse the repository at this point in the history
  • Loading branch information
SundarRajan98 committed Jan 9, 2024
1 parent 558911b commit 4c0f40d
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 8 deletions.
2 changes: 0 additions & 2 deletions rocAL/include/api/rocal_api_data_loaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,6 @@ extern "C" RocalTensor ROCAL_API_CALL rocalJpegCaffe2LMDBRecordSourcePartialSing
/*! \brief Creates JPEG external source image reader.
* \ingroup group_rocal_data_loaders
* \param [in] rocal_context Rocal context
* \param [in] source_path A NULL terminated char string pointing to the location on the disk
* \param [in] rocal_color_format The color format the images will be decoded to.
* \param [in] is_output Determines if the user wants the loaded images to be part of the output or not.
* \param [in] shuffle Determines if the user wants to shuffle the dataset or not.
Expand All @@ -816,7 +815,6 @@ extern "C" RocalTensor ROCAL_API_CALL rocalJpegCaffe2LMDBRecordSourcePartialSing
* \return Reference to the output tensor
*/
extern "C" RocalTensor ROCAL_API_CALL rocalJpegExternalFileSource(RocalContext p_context,
const char* source_path,
RocalImageColor rocal_color_format,
bool is_output = false,
bool shuffle = false,
Expand Down
8 changes: 7 additions & 1 deletion rocAL/include/pipeline/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,13 @@ class Tensor : public rocalTensor {
void* buffer() { return _mem_handle; }
vx_tensor handle() { return _vx_handle; }
vx_context context() { return _context; }
void set_mem_handle(void* buffer) override { _mem_handle = buffer; }
void set_mem_handle(void* buffer) override {
if (buffer)
_mem_handle = buffer;
else {
THROW("Invalid buffer pointer passed")
}
}
#if ENABLE_OPENCL
unsigned copy_data(cl_command_queue queue, unsigned char* user_buffer, bool sync);
unsigned copy_data(cl_command_queue queue, cl_mem user_buffer, bool sync);
Expand Down
3 changes: 1 addition & 2 deletions rocAL/source/api/rocal_api_data_loaders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2022,7 +2022,6 @@ rocalRawCIFAR10Source(
RocalTensor ROCAL_API_CALL
rocalJpegExternalFileSource(
RocalContext p_context,
const char* source_path,
RocalImageColor rocal_color_format,
bool is_output,
bool shuffle,
Expand Down Expand Up @@ -2062,7 +2061,7 @@ rocalJpegExternalFileSource(

unsigned shard_count = 1; // Hardcoding the shard count to 1 for now.
auto cpu_num_threads = context->master_graph->calculate_cpu_num_threads(shard_count);
context->master_graph->add_node<ImageLoaderNode>({}, {output})->init(shard_count, cpu_num_threads, source_path, "", std::map<std::string, std::string>(), StorageType::EXTERNAL_FILE_SOURCE,
context->master_graph->add_node<ImageLoaderNode>({}, {output})->init(shard_count, cpu_num_threads, "", "", std::map<std::string, std::string>(), StorageType::EXTERNAL_FILE_SOURCE,
decType, shuffle, loop, context->user_batch_size(), context->master_graph->mem_type(), context->master_graph->meta_data_reader(),
decoder_keep_original, "", 0, 0, 0, ExternalSourceFileMode(external_source_mode));
context->master_graph->set_loop(loop);
Expand Down
2 changes: 1 addition & 1 deletion rocAL_pybind/amd/rocal/fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ def external_source(*inputs, source, device=None, color_format=types.RGB, random
Pipeline._current_pipeline._external_source_mode = mode
Pipeline._current_pipeline._external_source_user_given_width = max_width
Pipeline._current_pipeline._external_source_user_given_height = max_height
kwargs_pybind = {"source_path": source.images_dir, "rocal_color_format": color_format, "is_output": False, "shuffle": random_shuffle, "loop": False, "decode_size_policy": types.USER_GIVEN_SIZE,
kwargs_pybind = {"rocal_color_format": color_format, "is_output": False, "shuffle": random_shuffle, "loop": False, "decode_size_policy": types.USER_GIVEN_SIZE,
"max_width": max_width, "max_height": max_height, "dec_type": types.DECODER_TJPEG, "external_source_mode": mode}
external_source_operator = b.externalFileSource(
Pipeline._current_pipeline._handle, *(kwargs_pybind.values()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,12 @@ int main(int argc, const char **argv) {
}
if (max_height != 0 && max_width != 0) {
input1 = rocalJpegExternalFileSource(
handle, folder_path, color_format, false, false, false,
handle, color_format, false, false, false,
ROCAL_USE_USER_GIVEN_SIZE, max_width, max_height,
RocalDecoderType::ROCAL_DECODER_TJPEG, RocalExternalSourceMode(mode));
} else {
input1 = rocalJpegExternalFileSource(
handle, folder_path, color_format, false, false, false,
handle, color_format, false, false, false,
ROCAL_USE_USER_GIVEN_SIZE, decode_width, decode_height,
RocalDecoderType::ROCAL_DECODER_TJPEG, RocalExternalSourceMode(mode));
}
Expand Down

0 comments on commit 4c0f40d

Please sign in to comment.