Commit 2fc8941

Merge pull request #881 from jmduarte/split_pointwise_conv_by_rf_codegen
Pointwise Conv1D with code generation for "Latency" strategy (update of #811)
JanFSchulte authored Dec 4, 2024
2 parents 22878ce + 4a1c25a commit 2fc8941
Showing 14 changed files with 354 additions and 47 deletions.
File renamed without changes.
25 changes: 23 additions & 2 deletions hls4ml/backends/vivado/passes/convolution_templates.py
@@ -60,6 +60,8 @@
typedef {config_t} mult_config;
template<unsigned K, unsigned S, unsigned W>
using scale_index = nnet::{scale_index_type}<K, S, W>;
template<class data_T, class res_T, class CONFIG_T>
using conv_kernel = nnet::{conv_fn}<data_T, res_T, CONFIG_T>;
}};
const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

@@ -93,11 +95,30 @@ def format(self, node):
else:
params['fill_fn'] = 'FillConv1DBuffer'

is_pointwise_parallel_latency = (
node.get_attr('filt_width') == 1
and node.get_attr('strategy').lower() == 'latency'
and node.model.config.get_config_value('IOType') == 'io_parallel'
)
if is_pointwise_parallel_latency:
params['conv_fn'] = f'pointwise_conv_{node.index}'
else:
if node.get_attr('strategy').lower() == 'latency':
params['conv_fn'] = 'Conv1DLatency'
else:
params['conv_fn'] = 'Conv1DResource'

conv_config = self.template.format(**params)

mult_params = self._default_config_params(node)
-mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
-mult_params['n_out'] = node.get_attr('n_filt')
+if is_pointwise_parallel_latency:
+    mult_params['n_in'] = int(
+        node.get_attr('in_width') * node.get_attr('n_chan') * node.get_attr('filt_width') / mult_params['reuse']
+    )
+    mult_params['n_out'] = int(node.get_attr('in_width') * node.get_attr('n_filt') / mult_params['reuse'])
+else:
+    mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
+    mult_params['n_out'] = node.get_attr('n_filt')
mult_params['nzeros'] = node.get_weights('weight').nzeros
mult_params['product_type'] = get_backend('vivado').product_type(
node.get_input_variable().type.precision, node.get_weights('weight').type.precision
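For a quick sanity check of the pointwise-latency sizing above: the multiplier config is dimensioned per reuse slice. Assuming, purely for illustration, a layer with in_width=100, n_chan=4, n_filt=8, filt_width=1 and reuse=2, the formulas work out as follows:

# Hypothetical dimensions, chosen only to illustrate the n_in / n_out formulas above
in_width, n_chan, n_filt, filt_width, reuse = 100, 4, 8, 1, 2

n_in = int(in_width * n_chan * filt_width / reuse)   # 200 multiplier inputs per reuse slice
n_out = int(in_width * n_filt / reuse)               # 400 multiplier outputs per reuse slice
assert (n_in, n_out) == (200, 400)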
84 changes: 84 additions & 0 deletions hls4ml/backends/vivado/passes/pointwise_codegen.py
@@ -0,0 +1,84 @@
from hls4ml.model.layers import Conv1D
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.types import Source


def generate_pointwise_conv1d_fn(layer_idx, reuse_factor=1):
"""Generate a C++ function for a pointwise convolution layer.
Args:
layer_idx (int): Index of layer ('index' attribute).
reuse_factor (int): Number of partitions to divide the input into.
Returns:
str: Generated C++ function
"""

generated_code = (
'template<class data_T, class res_T, typename CONFIG_T>\n'
'class pointwise_conv_{index} : public Conv1DKernel<data_T, res_T, CONFIG_T> {{\n'
' public:\n'
' static void conv(\n'
' data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n'
' res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n'
' typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n'
' typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n'
' data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n' # noqa: E501
' #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n'
' res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n' # noqa: E501
' #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n'
' RFInputLoop:\n'
' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n'
' #pragma HLS UNROLL\n'
' InnerInputLoop:\n'
' for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n'
' #pragma HLS UNROLL\n'
' data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n' # noqa: E501
' }}\n'
' }}\n\n'
).format(index=layer_idx)
indent = ' '
for i in range(reuse_factor):
generated_code += indent
generated_code += (
f'pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[{i}], res_tmp[{i}], weights, biases);\n'
)

generated_code += (
'\n'
' RFOutputLoop:\n'
' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n'
' #pragma HLS UNROLL\n'
' InnerOutputLoop:\n'
' for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n'
' #pragma HLS UNROLL\n'
' res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n' # noqa: E501
' }\n'
' }\n'
' }\n'
'};\n'
)

return generated_code


class GeneratePointwiseConv1D(OptimizerPass):
'''Generates code for pointwise 1D convolution'''

def match(self, node):
return (
isinstance(node, Conv1D)
and node.model.config.get_config_value('IOType') == 'io_parallel'
and node.get_attr('filt_width') == 1
)

def transform(self, model, node):
self._generate_pointwise_conv1d(node)

def _generate_pointwise_conv1d(self, node):
code_str = generate_pointwise_conv1d_fn(
node.get_attr('index'),
node.get_attr('reuse_factor'),
)

node.set_attr('pointwise_conv1d_codegen', Source(code_str))
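For reference, the generated kernel can be inspected on its own by calling the helper directly; the layer index and reuse factor below are made-up values for illustration, not taken from this PR:

# Illustrative only: print the C++ class produced for a hypothetical layer with
# index 3 and reuse_factor 2; the real pass pulls both values from the node's attributes.
from hls4ml.backends.vivado.passes.pointwise_codegen import generate_pointwise_conv1d_fn

code_str = generate_pointwise_conv1d_fn(3, reuse_factor=2)
print(code_str)  # a pointwise_conv_3 class containing two pointwise_conv_1d_latency_cl calls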
2 changes: 1 addition & 1 deletion hls4ml/backends/vivado/vivado_backend.py
@@ -70,7 +70,6 @@ def _register_layer_attributes(self):
cnn_layers = [Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Pooling1D, Pooling2D]
for layer in cnn_layers:
attrs = self.attribute_map.get(layer, [])
# attrs.append(ConfigurableAttribute('conv_implementation', value_type=str, default='LineBuffer'))
attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer'))
self.attribute_map[layer] = attrs

@@ -114,6 +113,7 @@ def _register_flows(self):
'vivado:generate_conv_streaming_instructions',
'vivado:apply_resource_strategy',
'vivado:generate_conv_im2col',
'vivado:generate_pointwise_conv1_d',
'vivado:generate_unrolled_dense_resource',
'vivado:set_pipeline_style',
]
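Taken together with the optimizer pass above, the generated pointwise kernel is used whenever a Conv1D layer has filt_width == 1, the Latency strategy, and io_parallel I/O. A rough end-to-end sketch under those assumptions (the toy model, output directory, and FPGA part below are placeholders, and conversion arguments may differ slightly between hls4ml versions):

# Hypothetical example: a single kernel-size-1 Conv1D converted with the settings
# that trigger the code-generated pointwise path.
import hls4ml
from tensorflow.keras.layers import Conv1D, Input
from tensorflow.keras.models import Model

inp = Input(shape=(100, 4))
out = Conv1D(8, kernel_size=1)(inp)
model = Model(inp, out)

config = hls4ml.utils.config_from_keras_model(model, granularity='model')
config['Model']['Strategy'] = 'Latency'  # required for the generated pointwise kernel
config['Model']['ReuseFactor'] = 2       # number of reuse slices the input is split into

hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    io_type='io_parallel',               # the optimizer pass only matches io_parallel
    backend='Vivado',
    output_dir='pointwise_test_prj',     # placeholder
    part='xcvu13p-flga2577-2-e',         # placeholder
)
hls_model.compile()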
30 changes: 21 additions & 9 deletions hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h
@@ -4,6 +4,7 @@
#include "nnet_common.h"
#include "nnet_conv1d_latency.h"
#include "nnet_conv1d_resource.h"
#include "nnet_function_stubs.h"
#include <cstdlib>

namespace nnet {
@@ -38,11 +39,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
// Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
//#pragma HLS INLINE recursive

-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
}

template <class data_T, class res_T, typename CONFIG_T>
@@ -55,13 +52,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
// Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
//#pragma HLS INLINE recursive

-    // Nothing special to be done for io_parallel implementation
-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
}

+template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        //#pragma HLS INLINE region
+        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        //#pragma HLS INLINE region
+        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};

} // namespace nnet

78 changes: 78 additions & 0 deletions hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h
@@ -85,5 +85,83 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
}
}

template <class data_T, class res_T, typename CONFIG_T>
void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor],
res_T res[CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor],
typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
assert(CONFIG_T::filt_width == 1);

typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor];
typename CONFIG_T::accum_t acc[CONFIG_T::out_width / CONFIG_T::reuse_factor][CONFIG_T::n_filt];

#pragma HLS ARRAY_PARTITION variable=mult complete dim=0
#pragma HLS ARRAY_PARTITION variable=acc complete dim=0

// Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
#pragma HLS function_instantiate variable=weights,biases

// Parallel mode
#pragma HLS PIPELINE II=CONFIG_T::reuse_factor
#pragma HLS ARRAY_PARTITION variable=weights complete dim=0
#pragma HLS ARRAY_PARTITION variable=biases complete dim=0

// Limit multipliers to control parallelization
#pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit

// Convolve, saving all multiplication results to accumulate later
ConvOut:
for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
ConvFilt:
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
ConvChan:
for (int cc = 0; cc < CONFIG_T::n_chan; cc++) {
#pragma HLS UNROLL
int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc;
int index_weight = cc * CONFIG_T::n_filt + ff;
int index_data = (ii * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc;

if ((ii * CONFIG_T::stride_width) < CONFIG_T::pad_left ||
(ii * CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) {
mult[index_mult] = 0;
} else {
mult[index_mult] = CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::weight_t>::product(
data[index_data], weights[index_weight]);
}
} // end channel loop
} // end filter loop
} // end output loop

// Initialize accumulator with input biases
for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
#pragma HLS UNROLL
acc[ii][ff] = biases[ff];
}
}

// Accumulate multiplication result
AccumOut:
for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
AccumFilt:
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
// Do "dot product" sum within filter and sum over channels
AccumChan:
for (int cc = 0; cc < CONFIG_T::n_chan; cc++) {
int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc;
acc[ii][ff] += mult[index_mult];
} // end channel loop
} // end filter loop
} // end output loop

// Cast to "res_t" type
for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
#pragma HLS UNROLL
res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);
}
}
}

} // namespace nnet
#endif
2 changes: 1 addition & 1 deletion hls4ml/templates/vivado/build_prj.tcl
@@ -161,7 +161,7 @@ if {$opt(reset)} {
} else {
open_solution "solution1"
}
-catch {config_array_partition -maximum_size 4096}
+catch {config_array_partition -maximum_size $maximum_size}
config_compile -name_max_length 80
set_part $part
config_schedule -enable_dsp_full_reg=false
11 changes: 11 additions & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h
@@ -1,6 +1,7 @@
#ifndef NNET_INSTR_GEN_H_
#define NNET_INSTR_GEN_H_

#include "nnet_conv1d_latency.h"
#include "nnet_helpers.h"

#include "hls_stream.h"
@@ -10,6 +11,16 @@

namespace nnet {

template <class data_T, class res_T, typename CONFIG_T> class PointwiseConv1D {
public:
static void pointwise_conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
// To be implemented in subclasses
}
};

// hls4ml insert code

} // namespace nnet
1 change: 1 addition & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_common.h
@@ -2,6 +2,7 @@
#define NNET_COMMON_H_

#include "ap_fixed.h"
#include "nnet_helpers.h"

// This is a substitute for "ceil(n/(float)d)".
#define DIV_ROUNDUP(n, d) ((n + d - 1) / d)
30 changes: 21 additions & 9 deletions hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h
@@ -4,6 +4,7 @@
#include "nnet_common.h"
#include "nnet_conv1d_latency.h"
#include "nnet_conv1d_resource.h"
#include "nnet_function_stubs.h"
#include <cstdlib>

namespace nnet {
@@ -37,11 +38,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
#pragma HLS INLINE region

-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
}

template <class data_T, class res_T, typename CONFIG_T>
@@ -53,13 +50,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],

#pragma HLS INLINE region

-    // Nothing special to be done for io_parallel implementation
-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
}

+template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        #pragma HLS INLINE region
+        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        #pragma HLS INLINE region
+        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};

} // namespace nnet
