-
Notifications
You must be signed in to change notification settings - Fork 0
/
NvInferRuntime.h
2007 lines (1850 loc) · 83.5 KB
/
NvInferRuntime.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NV_INFER_RUNTIME_H
#define NV_INFER_RUNTIME_H
//!
//! \file NvInferRuntime.h
//!
//! This is the top-level API file for TensorRT extended runtime library.
//!
#include "NvInferRuntimeCommon.h"
namespace nvinfer1
{
class IExecutionContext; //!< Forward declaration of IExecutionContext for use by other interfaces.
class ICudaEngine; //!< Forward declaration of ICudaEngine for use by other interfaces.
class IPluginFactory; //!< Forward declaration of IPluginFactory for use by other interfaces.
//!
//! \enum EngineCapability
//!
//! \brief List of supported engine capability flows.
//!
//! The EngineCapability determines the restrictions of a network during build time for what can be executed
//! at runtime. EngineCapability::kDEFAULT does not provide any restrictions on functionality and the
//! resulting serialized engine can be executed with TensorRT's standard runtime APIs in the nvinfer1 namespace.
//! EngineCapability::kSAFE_GPU provides a restricted subset of network operations that are safety certified and
//! the resulting serialized engine can be executed with TensorRT's safe runtime APIs in the nvinfer1::safe namespace.
//! EngineCapability::kSAFE_DLA provides a restricted subset of network operations that are DLA compatible and
//! the resulting serialized engine can be executed using NvMediaDLA's runtime APIs. See sampleNvmedia for an
//! example of integrating NvMediaDLA APIs with TensorRT APIs.
//!
enum class EngineCapability : int32_t
{
    //! Full capability: TensorRT mode without any restrictions, using the TensorRT nvinfer1 APIs.
    kDEFAULT = 0,

    //! Safety restricted capability: TensorRT flow that can only run on GPU devices via the TensorRT
    //! nvinfer1::safe APIs.
    kSAFE_GPU = 1,

    //! Safety restricted capability: TensorRT flow that can only run on DLA devices via the
    //! NvMediaDLA APIs.
    kSAFE_DLA = 2
};
//! Maximum number of elements in EngineCapability enum. \see EngineCapability
template <>
constexpr inline int32_t EnumMax<EngineCapability>()
{
    // kSAFE_DLA (value 2) is the last enumerator, so the element count is its value plus one, i.e. 3.
    return static_cast<int32_t>(EngineCapability::kSAFE_DLA) + 1;
}
//!
//! \class Weights
//!
//! \brief An array of weights used as a layer parameter.
//!
//! When using the DLA, the cumulative size of all Weights used in a network
//! must be less than 512MB in size. If the build option kGPU_FALLBACK is specified,
//! then multiple DLA sub-networks may be generated from the single original network.
//!
//! The weights are held by reference until the engine has been built. Therefore the data referenced
//! by \p values field should be preserved until the build is complete.
//!
class Weights
{
// Plain data holder: no constructors or destructor are declared here, so a default-initialized Weights
// has indeterminate members and nothing is released when the object goes away.
public:
DataType type; //!< The type of the weights.
const void* values; //!< The weight values, in a contiguous array. Referenced, not copied, until the build completes.
int64_t count; //!< The number of weights in the array (an element count; NOTE(review): not a byte count, confirm).
};
//!
//! \class IHostMemory
//!
//! \brief Class to handle library allocated memory that is accessible to the user.
//!
//! The memory allocated via the host memory object is owned by the library and will
//! be de-allocated when the destroy method is called.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
class IHostMemory
{
public:
virtual void* data() const noexcept = 0; //!< A pointer to the raw data that is owned by the library.
virtual std::size_t size() const noexcept = 0; //!< The size in bytes of the data that was allocated.
virtual DataType type() const noexcept = 0; //!< The type of the memory that was allocated.
virtual void destroy() noexcept = 0; //!< Destroy the allocated memory.
protected:
//! The destructor is protected, so code outside the class hierarchy cannot delete the object through an
//! IHostMemory pointer; release it with destroy() instead.
virtual ~IHostMemory() {}
};
//! \class IPlugin
//!
//! \brief Plugin class for user-implemented layers.
//!
//! Plugins are a mechanism for applications to implement custom layers. Each plugin is owned by the application, and its lifetime
//! must span any use of it by TensorRT.
//!
class IPlugin
{
public:
//!
//! \brief Get the number of outputs from the layer.
//!
//! \return The number of outputs.
//!
//! This function is called by the implementations of INetworkDefinition and IBuilder. In particular, it is called
//! prior to any call to initialize().
//!
virtual int32_t getNbOutputs() const TRTNOEXCEPT = 0;
//!
//! \brief Get the dimension of an output tensor.
//!
//! \param index The index of the output tensor.
//! \param inputs The input tensors.
//! \param nbInputDims The number of input tensors.
//!
//! \return The dimensions of the output tensor selected by \p index.
//!
//! This function is called by the implementations of INetworkDefinition and IBuilder. In particular, it is called
//! prior to any call to initialize().
//!
virtual Dims getOutputDimensions(int32_t index, const Dims* inputs, int32_t nbInputDims) TRTNOEXCEPT = 0;
//!
//! \brief Configure the layer.
//!
//! This function is called by the builder prior to initialize(). It provides an opportunity for the layer to make
//! algorithm choices on the basis of its weights, dimensions, and maximum batch size. The type is assumed to be
//! FP32 and format NCHW.
//!
//! \param inputDims The input tensor dimensions.
//! \param nbInputs The number of inputs.
//! \param outputDims The output tensor dimensions.
//! \param nbOutputs The number of outputs.
//! \param maxBatchSize The maximum batch size.
//!
//! The dimensions passed here do not include the outermost batch size (i.e. for 2-D image networks, they will be
//! 3-dimensional CHW dimensions).
//!
//! This method is not called for IPluginExt classes; configureWithFormat is called instead.
//!
virtual void configure(const Dims* inputDims, int32_t nbInputs, const Dims* outputDims, int32_t nbOutputs,
int32_t maxBatchSize) TRTNOEXCEPT = 0;
//!
//! \brief Initialize the layer for execution. This is called when the engine is created.
//!
//! \return 0 for success, else non-zero (which will cause engine termination).
//!
virtual int32_t initialize() TRTNOEXCEPT = 0;
//!
//! \brief Release resources acquired during plugin layer initialization. This is called when the engine is
//! destroyed. \see initialize()
//!
virtual void terminate() TRTNOEXCEPT = 0;
//!
//! \brief Find the workspace size required by the layer.
//!
//! This function is called during engine startup, after initialize(). The workspace size returned should be
//! sufficient for any batch size up to the maximum.
//!
//! \param maxBatchSize The maximum batch size.
//!
//! \return The workspace size.
//!
virtual size_t getWorkspaceSize(int32_t maxBatchSize) const TRTNOEXCEPT = 0;
//!
//! \brief Execute the layer.
//!
//! \param batchSize The number of inputs in the batch.
//! \param inputs The memory for the input tensors.
//! \param outputs The memory for the output tensors.
//! \param workspace Workspace for execution.
//! \param stream The stream in which to execute the kernels.
//!
//! \return 0 for success, else non-zero (which will cause engine termination).
//!
virtual int32_t enqueue(int32_t batchSize, const void* const* inputs, void** outputs, void* workspace,
cudaStream_t stream) TRTNOEXCEPT = 0;
//!
//! \brief Find the size of the serialization buffer required.
//!
//! \return The size of the serialization buffer.
//!
virtual size_t getSerializationSize() TRTNOEXCEPT = 0;
//!
//! \brief Serialize the layer.
//!
//! \param buffer A pointer to a buffer of size at least that returned by getSerializationSize().
//!
//! \see getSerializationSize()
//!
virtual void serialize(void* buffer) TRTNOEXCEPT = 0;
//! Public virtual destructor, so a plugin can be destroyed through an IPlugin pointer.
virtual ~IPlugin() {}
};
//!
//! \class IPluginExt
//!
//! \brief Plugin class for user-implemented layers.
//!
//! Plugins are a mechanism for applications to implement custom layers. Each plugin is owned by the application, and its lifetime
//! must span any use of it by TensorRT.
//!
class IPluginExt : public IPlugin
{
public:
//!
//! \brief Return the API version with which this plugin was built.
//!
//! \return The value of NV_TENSORRT_VERSION at the time the plugin was compiled.
//!
//! Do not override this method as it is used by the TensorRT library to maintain backwards-compatibility with
//! plugins.
//!
virtual int32_t getTensorRTVersion() const TRTNOEXCEPT
{
return NV_TENSORRT_VERSION;
}
//!
//! \brief Check format support.
//!
//! \param type DataType requested.
//! \param format PluginFormat requested.
//! \return true if the plugin supports the type-format combination.
//!
//! This function is called by the implementations of INetworkDefinition, IBuilder, and ICudaEngine.
//! In particular, it is called when creating an engine and when deserializing an engine.
//!
//! \warning DataType::kBOOL not supported.
//!
virtual bool supportsFormat(DataType type, PluginFormat format) const TRTNOEXCEPT = 0;
//!
//! \brief Configure the layer.
//!
//! This function is called by the builder prior to initialize(). It provides an opportunity for the layer to make
//! algorithm choices on the basis of its weights, dimensions, and maximum batch size.
//!
//! \param inputDims The input tensor dimensions.
//! \param nbInputs The number of inputs.
//! \param outputDims The output tensor dimensions.
//! \param nbOutputs The number of outputs.
//! \param type The data type selected for the engine.
//! \param format The format selected for the engine.
//! \param maxBatchSize The maximum batch size.
//!
//! The dimensions passed here do not include the outermost batch size (i.e. for 2-D image networks, they will be
//! 3-dimensional CHW dimensions).
//!
//! \warning DataType::kBOOL not supported.
//!
virtual void configureWithFormat(const Dims* inputDims, int32_t nbInputs, const Dims* outputDims, int32_t nbOutputs,
DataType type, PluginFormat format, int32_t maxBatchSize) TRTNOEXCEPT = 0;
//! Public virtual destructor, so a plugin can be destroyed through an IPluginExt pointer.
virtual ~IPluginExt() {}
protected:
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This overrides IPlugin::configure() with an empty body; all parameters are ignored. Configuration of an
//! IPluginExt goes through configureWithFormat() instead.
//!
void configure(const Dims* /*inputDims*/, int32_t /*nbInputs*/, const Dims* /*outputDims*/, int32_t /*nbOutputs*/,
int32_t /*maxBatchSize*/) _TENSORRT_FINAL TRTNOEXCEPT
{
}
};
//!
//! \enum DimensionOperation
//!
//! \brief An operation on two IDimensionExpr, which represent integer expressions used in dimension computations.
//!
//! For example, given two IDimensionExpr x and y and an IExprBuilder& eb,
//! eb.operation(DimensionOperation::kSUM, x, y) creates a representation of x+y.
//!
//! \see IDimensionExpr, IExprBuilder
//!
enum class DimensionOperation : int32_t
{
    //! Sum of the two operands.
    kSUM = 0,

    //! Product of the two operands.
    kPROD = 1,

    //! Maximum of the two operands.
    kMAX = 2,

    //! Minimum of the two operands.
    kMIN = 3,

    //! Subtract the second operand from the first.
    kSUB = 4,

    //! 1 if the operands are equal, 0 otherwise.
    kEQUAL = 5,

    //! 1 if the first operand is less than the second operand, 0 otherwise.
    kLESS = 6,

    //! Floor division of the first operand by the second.
    kFLOOR_DIV = 7,

    //! Division of the first operand by the second, rounding up.
    kCEIL_DIV = 8
};
//! Maximum number of elements in DimensionOperation enum. \see DimensionOperation
template <>
constexpr inline int32_t EnumMax<DimensionOperation>()
{
    // kCEIL_DIV (value 8) is the last enumerator, so the element count is its value plus one, i.e. 9.
    return static_cast<int32_t>(DimensionOperation::kCEIL_DIV) + 1;
}
//!
//! \class IDimensionExpr
//!
//! An IDimensionExpr represents an integer expression constructed from constants,
//! input dimensions, and binary operations. These expressions can be used
//! in overrides of IPluginV2DynamicExt::getOutputDimensions to define output
//! dimensions in terms of input dimensions.
//!
//! \see DimensionOperation, IPluginV2DynamicExt::getOutputDimensions
//!
class IDimensionExpr
{
public:
//! Return true if the expression is a build-time constant.
virtual bool isConstant() const = 0;
//! If isConstant(), return the value of the constant.
//! If !isConstant(), return std::numeric_limits<int32_t>::min().
virtual int32_t getConstantValue() const = 0;
protected:
//! The destructor is protected, so code outside the class hierarchy cannot delete an expression through an
//! IDimensionExpr pointer.
virtual ~IDimensionExpr() {}
};
//!
//! \class IExprBuilder
//!
//! Object for constructing IDimensionExpr.
//!
//! There is no public way to construct an IExprBuilder. It appears as an argument to
//! method IPluginV2DynamicExt::getOutputDimensions(). Overrides of that method can use
//! that IExprBuilder argument to construct expressions that define output dimensions
//! in terms of input dimensions.
//!
//! Clients should assume that any values constructed by the IExprBuilder are destroyed
//! after IPluginV2DynamicExt::getOutputDimensions() returns.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
//! \see IDimensionExpr
//!
class IExprBuilder
{
public:
//! Return a pointer to an IDimensionExpr for the given constant value.
virtual const IDimensionExpr* constant(int32_t value) = 0;
//! Return a pointer to an IDimensionExpr that represents the given operation applied to first and second.
//! Returns nullptr if op is not a valid DimensionOperation.
virtual const IDimensionExpr* operation(DimensionOperation op, const IDimensionExpr& first, const IDimensionExpr& second) = 0;
protected:
//! The destructor is protected, so code outside the class hierarchy cannot delete the builder through an
//! IExprBuilder pointer.
virtual ~IExprBuilder() {}
};
//!
//! \class DimsExprs
//!
//! Analog of class Dims with expressions instead of constants for the dimensions.
//!
class DimsExprs
{
public:
int32_t nbDims; //!< The number of dimensions.
//! The extent of each dimension, held as a pointer to an expression. The array always has Dims::MAX_DIMS
//! slots; presumably only the first nbDims entries are meaningful (NOTE(review): confirm).
const IDimensionExpr* d[Dims::MAX_DIMS];
};
//!
//! \class DynamicPluginTensorDesc
//!
//! Summarizes tensors that a plugin might see for an input or output.
//!
struct DynamicPluginTensorDesc
{
//! Information required to interpret a pointer to tensor data, except that desc.dims has -1 in place of any
//! runtime dimension.
PluginTensorDesc desc;
//! Lower bounds on the tensor's dimensions.
Dims min;
//! Upper bounds on the tensor's dimensions.
Dims max;
};
//!
//! \class IPluginV2DynamicExt
//!
//! Similar to IPluginV2Ext, but with support for dynamic shapes.
//!
//! Clients should override the public methods, including the following inherited methods:
//!
//!     virtual int32_t getNbOutputs() const TRTNOEXCEPT = 0;
//!     virtual nvinfer1::DataType getOutputDataType(int32_t index, const nvinfer1::DataType* inputTypes,
//!         int32_t nbInputs) const TRTNOEXCEPT = 0;
//!     virtual size_t getSerializationSize() const TRTNOEXCEPT = 0;
//!     virtual void serialize(void* buffer) const TRTNOEXCEPT = 0;
//!     virtual void destroy() TRTNOEXCEPT = 0;
//!     virtual void setPluginNamespace(const char* pluginNamespace) TRTNOEXCEPT = 0;
//!     virtual const char* getPluginNamespace() const TRTNOEXCEPT = 0;
//!
//! For getOutputDataType, the inputTypes will always be DataType::kFLOAT or DataType::kINT32,
//! and the returned type is canonicalized to DataType::kFLOAT if it is DataType::kHALF or DataType::kINT8.
//! Details about the floating-point precision are elicited later by method supportsFormatCombination.
//!
class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext
{
public:
//!
//! \brief Clone the plugin. Narrows the return type of the inherited clone() to IPluginV2DynamicExt*.
//!
IPluginV2DynamicExt* clone() const _TENSORRT_OVERRIDE TRTNOEXCEPT = 0;
//!
//! \brief Get expressions for computing dimensions of an output tensor from dimensions of the input tensors.
//!
//! \param outputIndex The index of the output tensor
//! \param inputs Expressions for dimensions of the input tensors
//! \param nbInputs The number of input tensors
//! \param exprBuilder Object for generating new expressions
//!
//! \return The dimensions of the output tensor selected by \p outputIndex, as expressions.
//!
//! This function is called by the implementations of IBuilder during analysis of the network.
//!
//! Example #1: A plugin has a single output that transposes the last two dimensions of the plugin's single input.
//! The body of the override of getOutputDimensions can be:
//!
//! DimsExprs output(inputs[0]);
//! std::swap(output.d[output.nbDims-1], output.d[output.nbDims-2]);
//! return output;
//!
//! Example #2: A plugin concatenates its two inputs along the first dimension.
//! The body of the override of getOutputDimensions can be:
//!
//! DimsExprs output(inputs[0]);
//! output.d[0] = exprBuilder.operation(DimensionOperation::kSUM, *inputs[0].d[0], *inputs[1].d[0]);
//! return output;
//!
//! \note Unlike the other methods of this class, this declaration carries no TRTNOEXCEPT specifier.
//!
virtual DimsExprs getOutputDimensions(
int32_t outputIndex, const DimsExprs* inputs, int32_t nbInputs, IExprBuilder& exprBuilder)
= 0;
//!
//! Limit on number of format combinations accepted.
//!
static constexpr int32_t kFORMAT_COMBINATION_LIMIT = 100;
//!
//! \brief Return true if plugin supports the format and datatype for the input/output indexed by pos.
//!
//! For this method inputs are numbered 0..(nbInputs-1) and outputs are numbered nbInputs..(nbInputs+nbOutputs-1).
//! Using this numbering, pos is an index into inOut, where 0 <= pos < nbInputs+nbOutputs.
//!
//! TensorRT invokes this method to ask if the input/output indexed by pos supports the format/datatype specified
//! by inOut[pos].format and inOut[pos].type. The override should return true if that format/datatype at inOut[pos]
//! are supported by the plugin. If support is conditional on other input/output formats/datatypes, the plugin can
//! make its result conditional on the formats/datatypes in inOut[0..pos-1], which will be set to values
//! that the plugin supports. The override should not inspect inOut[pos+1..nbInputs+nbOutputs-1],
//! which will have invalid values. In other words, the decision for pos must be based on inOut[0..pos] only.
//!
//! \param pos Index of the input/output being queried, using the numbering described above.
//! \param inOut Array of nbInputs+nbOutputs tensor descriptors; only entries 0..pos hold valid values.
//! \param nbInputs Number of input tensors.
//! \param nbOutputs Number of output tensors.
//!
//! Some examples:
//!
//! * A definition for a plugin that supports only FP16 NCHW:
//!
//! return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kHALF;
//!
//! * A definition for a plugin that supports only FP16 NCHW for its two inputs,
//! and FP32 NCHW for its single output:
//!
//! return inOut[pos].format == TensorFormat::kLINEAR
//! && inOut[pos].type == (pos < 2 ? DataType::kHALF : DataType::kFLOAT);
//!
//! * A definition for a "polymorphic" plugin with two inputs and one output that supports
//! any format or type, but the inputs and output must have the same format and type:
//!
//! return pos == 0 || (inOut[pos].format == inOut[0].format && inOut[pos].type == inOut[0].type);
//!
//! Warning: TensorRT will stop asking for formats once it has found kFORMAT_COMBINATION_LIMIT combinations.
//!
virtual bool supportsFormatCombination(
int32_t pos, const PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) TRTNOEXCEPT = 0;
//!
//! \brief Configure the layer.
//!
//! This function is called by the builder prior to initialize(). It provides an opportunity for the layer to make
//! algorithm choices on the basis of bounds on the input and output tensors, and the target value.
//!
//! This function is also called once when the resource requirements are changed based on the optimization profiles.
//!
//! \param in The input tensors attributes that are used for configuration.
//! \param nbInputs Number of input tensors.
//! \param out The output tensors attributes that are used for configuration.
//! \param nbOutputs Number of output tensors.
//!
virtual void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs,
const DynamicPluginTensorDesc* out, int32_t nbOutputs) TRTNOEXCEPT = 0;
//!
//! \brief Find the workspace size required by the layer.
//!
//! This function is called after the plugin is configured, and possibly during execution.
//! The result should be a sufficient workspace size to deal with inputs and outputs of the given size
//! or any smaller problem.
//!
//! \param inputs Descriptors of the input tensors.
//! \param nbInputs Number of input tensors.
//! \param outputs Descriptors of the output tensors.
//! \param nbOutputs Number of output tensors.
//!
//! \return The workspace size.
//!
virtual size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, const PluginTensorDesc* outputs,
int32_t nbOutputs) const TRTNOEXCEPT = 0;
//!
//! \brief Execute the layer.
//!
//! \param inputDesc how to interpret the memory for the input tensors.
//! \param outputDesc how to interpret the memory for the output tensors.
//! \param inputs The memory for the input tensors.
//! \param outputs The memory for the output tensors.
//! \param workspace Workspace for execution.
//! \param stream The stream in which to execute the kernels.
//!
//! \return 0 for success, else non-zero (which will cause engine termination).
//!
virtual int32_t enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) TRTNOEXCEPT = 0;
protected:
//!
//! \brief Return the packed plugin-interface / TensorRT version value.
//!
//! PluginVersion::kV2_DYNAMICEXT is shifted into the top 8 bits (<< 24) and OR-ed with the low 24 bits of
//! NV_TENSORRT_VERSION. The shift binds tighter than the bitwise OR, so no extra parentheses are needed.
//!
int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE TRTNOEXCEPT
{
return (static_cast<int32_t>(PluginVersion::kV2_DYNAMICEXT) << 24 | (NV_TENSORRT_VERSION & 0xFFFFFF));
}
virtual ~IPluginV2DynamicExt() {}
// Rest of the methods below are obsolete inherited methods, and marked final when using a C++11 compiler.
// Derived classes should not override them.
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! Instead, derived classes should override the overload of getOutputDimensions that returns DimsExprs.
//!
//! This stub ignores its arguments and returns Dims{-1, {}, {}}.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
Dims getOutputDimensions(
int32_t /*index*/, const Dims* /*inputs*/, int32_t /*nbInputDims*/) _TENSORRT_FINAL TRTNOEXCEPT
{
return Dims{-1, {}, {}};
}
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This method is not used because with dynamic shapes there is no implicit batch dimension to broadcast across.
//!
//! This stub always returns false.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
bool isOutputBroadcastAcrossBatch(int32_t /*outputIndex*/, const bool* /*inputIsBroadcasted*/,
int32_t /*nbInputs*/) const _TENSORRT_FINAL TRTNOEXCEPT
{
return false;
}
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This method is not used because with dynamic shapes there is no implicit batch dimension to broadcast across.
//!
//! This stub always returns true.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
bool canBroadcastInputAcrossBatch(int32_t /*inputIndex*/) const _TENSORRT_FINAL TRTNOEXCEPT
{
return true;
}
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This method is not used because it does not allow a plugin to specify mixed formats.
//!
//! Instead, derived classes should override supportsFormatCombination, which allows plugins
//! to express mixed formats.
//!
//! This stub always returns false.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
bool supportsFormat(DataType /*type*/, PluginFormat /*format*/) const _TENSORRT_FINAL TRTNOEXCEPT
{
return false;
}
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This method is not used because tensors with dynamic shapes do not have an implicit batch dimension,
//! input dimensions might be variable, and outputs might have different floating-point formats.
//!
//! Instead, derived classes should override the overload of configurePlugin that takes pointers to
//! DynamicPluginTensorDesc.
//!
//! This stub has an empty body and ignores all of its arguments.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
void configurePlugin(const Dims* /*inputDims*/, int32_t /*nbInputs*/, const Dims* /*outputDims*/,
int32_t /*nbOutputs*/, const DataType* /*inputTypes*/, const DataType* /*outputTypes*/,
const bool* /*inputIsBroadcast*/, const bool* /*outputIsBroadcast*/, PluginFormat /*floatFormat*/,
int32_t /*maxBatchSize*/) _TENSORRT_FINAL TRTNOEXCEPT
{
}
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This method is not used because tensors with dynamic shapes do not have an implicit batch dimension,
//! and the other dimensions might not be build-time constants.
//!
//! Instead, derived classes should override the overload of getWorkspaceSize that takes pointers to
//! PluginTensorDesc. The arguments to that overload provide maximum bounds on all dimensions.
//!
//! This stub always returns 0.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
size_t getWorkspaceSize(int32_t /*maxBatchSize*/) const _TENSORRT_FINAL TRTNOEXCEPT
{
return 0;
}
//!
//! \brief Derived classes should not implement this. In a C++11 API it would be override final.
//!
//! This method is not used because tensors with dynamic shapes can have different sizes in different execution
//! contexts.
//!
//! Instead, derived classes should override the overload of enqueue that takes pointers to PluginTensorDesc.
//!
//! This stub always returns 1, i.e. a non-zero (failure) status under the enqueue return convention.
//!
//! \deprecated Deprecated interface will be removed in TensorRT 8.0.
//!
TRT_DEPRECATED
int32_t enqueue(int32_t /*batchSize*/, const void* const* /*inputs*/, void** /*outputs*/, void* /*workspace*/,
cudaStream_t /*stream*/) _TENSORRT_FINAL TRTNOEXCEPT
{
return 1;
}
};
//!
//! \class IProfiler
//!
//! \brief Application-implemented interface for profiling.
//!
//! When this class is added to an execution context, the profiler will be called once per layer for each invocation of execute().
//! Note that enqueue() does not currently support profiling.
//!
//! The profiler will only be called after execution is complete. It has a small impact on execution time.
//!
class IProfiler
{
public:
//!
//! \brief Layer time reporting callback.
//!
//! \param layerName The name of the layer, set when constructing the network definition.
//! \param ms The time in milliseconds to execute the layer.
//!
virtual void reportLayerTime(const char* layerName, float ms) TRTNOEXCEPT = 0;
//! Public virtual destructor, so a profiler can be destroyed through an IProfiler pointer.
virtual ~IProfiler() {}
};
//!
//! \enum WeightsRole
//! \brief How a layer uses particular Weights.
//!
//! The power weights of an IScaleLayer are omitted. Refitting those is not supported.
//!
enum class WeightsRole : int32_t
{
    //! Kernel for IConvolutionLayer, IDeconvolutionLayer, or IFullyConnectedLayer.
    kKERNEL = 0,

    //! Bias for IConvolutionLayer, IDeconvolutionLayer, or IFullyConnectedLayer.
    kBIAS = 1,

    //! Shift part of IScaleLayer.
    kSHIFT = 2,

    //! Scale part of IScaleLayer.
    kSCALE = 3,

    //! Weights for IConstantLayer.
    kCONSTANT = 4
};
//! Maximum number of elements in WeightsRole enum. \see WeightsRole
template <>
constexpr inline int32_t EnumMax<WeightsRole>()
{
    // kCONSTANT (value 4) is the last enumerator, so the element count is its value plus one, i.e. 5.
    return static_cast<int32_t>(WeightsRole::kCONSTANT) + 1;
}
//!
//! \enum DeviceType
//! \brief The device that this layer/network will execute on.
//!
//!
enum class DeviceType : int32_t
{
    //! GPU Device.
    kGPU = 0,

    //! DLA Core.
    kDLA = 1
};
//! Maximum number of elements in DeviceType enum. \see DeviceType
template <>
constexpr inline int32_t EnumMax<DeviceType>()
{
    // kDLA (value 1) is the last enumerator, so the element count is its value plus one, i.e. 2.
    return static_cast<int32_t>(DeviceType::kDLA) + 1;
}
//!
//! \class IRuntime
//!
//! \brief Allows a serialized functionally unsafe engine to be deserialized.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
class IRuntime
{
public:
//!
//! \brief Deserialize an engine from a stream.
//!
//! \param blob The memory that holds the serialized engine.
//! \param size The size of the memory.
//! \param pluginFactory The plugin factory, if any plugins are used by the network, otherwise nullptr.
//!
//! \return The engine, or nullptr if it could not be deserialized.
//!
virtual nvinfer1::ICudaEngine* deserializeCudaEngine(const void* blob, std::size_t size, IPluginFactory* pluginFactory) noexcept = 0;
//!
//! \brief Set the DLA core that the deserialized engine must execute on.
//! \param dlaCore The DLA core to execute the engine on (0 to N-1, where N is the maximum number of DLA's present
//! on the device). Default value is 0. \see getDLACore()
//!
//! \warning Starting with TensorRT 8, the default value will be -1 if the DLA is not specified or unused.
//!
virtual void setDLACore(int32_t dlaCore) noexcept = 0;
//!
//! \brief Get the DLA core that the engine executes on.
//! \return If setDLACore is called, returns DLA core from 0 to N-1, else returns 0.
//!
//! \warning Starting with TensorRT 8, the default value will be -1 if the DLA is not specified or unused.
//!
virtual int32_t getDLACore() const noexcept = 0;
//!
//! \brief Returns number of DLA hardware cores accessible.
//!
virtual int32_t getNbDLACores() const noexcept = 0;
//!
//! \brief Destroy this object.
//!
virtual void destroy() noexcept = 0;
protected:
//! The destructor is protected, so code outside the class hierarchy cannot delete the runtime through an
//! IRuntime pointer; release it with destroy() instead.
virtual ~IRuntime() {}
public:
//!
//! \brief Set the GPU allocator.
//! \param allocator The GPU allocator to be used by the runtime. All GPU memory acquired will use this
//! allocator. If nullptr is passed, the default allocator will be used.
//!
//! Default: uses cudaMalloc/cudaFree.
//!
virtual void setGpuAllocator(IGpuAllocator* allocator) noexcept = 0;
//!
//! \brief Set the ErrorRecorder for this interface
//!
//! Assigns the ErrorRecorder to this interface. The ErrorRecorder will track all errors during execution.
//! This function will call incRefCount of the registered ErrorRecorder at least once. Setting
//! recorder to nullptr unregisters the recorder with the interface, resulting in a call to decRefCount if
//! a recorder has been registered.
//!
//! \param recorder The error recorder to register with this interface.
//!
//! \see getErrorRecorder
//!
virtual void setErrorRecorder(IErrorRecorder* recorder) noexcept = 0;
//!
//! \brief Get the ErrorRecorder assigned to this interface.
//!
//! Retrieves the assigned error recorder object for the given class. A default error recorder does not exist,
//! so a nullptr will be returned if setErrorRecorder has not been called.
//!
//! \return A pointer to the IErrorRecorder object that has been registered.
//!
//! \see setErrorRecorder
//!
virtual IErrorRecorder* getErrorRecorder() const noexcept = 0;
//!
//! \brief Deserialize an engine from a stream when plugin factory is not used.
//!
//! Non-virtual convenience overload: forwards to the three-argument deserializeCudaEngine() with a nullptr
//! plugin factory.
//!
//! \param blob The memory that holds the serialized engine.
//! \param size The size of the memory.
//!
//! \return The engine, or nullptr if it could not be deserialized.
//!
nvinfer1::ICudaEngine* deserializeCudaEngine(const void* blob, std::size_t size) noexcept
{
return deserializeCudaEngine(blob, size, nullptr);
}
};
//!
//! \class IRefitter
//!
//! \brief Updates weights in an engine.
//!
//! New weights are supplied with setWeights() (and dynamic ranges with setDynamicRange());
//! refitCudaEngine() then updates the associated engine. getMissing() and getAll() describe
//! which weights still need to be, or could be, supplied.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
class IRefitter
{
public:
//!
//! \brief Specify new weights for a layer of given name.
//!
//! \param layerName The name of the layer whose weights are being specified.
//! \param role The role of the weights within that layer.
//! \param weights The new weights.
//!
//! \return True on success, or false if new weights are rejected.
//!
//! Possible reasons for rejection are:
//!
//! * There is no such layer by that name.
//! * The layer does not have weights with the specified role.
//! * The number of weights is inconsistent with the layer's original specification.
//!
//! Modifying the weights before refitCudaEngine() completes will result in undefined behavior.
//!
virtual bool setWeights(const char* layerName, WeightsRole role, Weights weights) TRTNOEXCEPT = 0;
//!
//! \brief Updates associated engine.
//!
//! \return True if successful.
//!
//! Failure occurs if getMissing() != 0 before the call.
//!
virtual bool refitCudaEngine() TRTNOEXCEPT = 0;
//!
//! \brief Get description of missing weights.
//!
//! For example, if some Weights have been set, but the engine was optimized
//! in a way that combines weights, any unsupplied Weights in the combination
//! are considered missing.
//!
//! \param size The number of items that can be safely written to a non-null layerNames or roles.
//! \param layerNames Where to write the layer names.
//! \param roles Where to write the weights roles.
//!
//! \return The number of missing Weights.
//!
//! If layerNames!=nullptr, each written pointer points to a string owned by
//! the engine being refitted, and becomes invalid when the engine is destroyed.
//!
virtual int32_t getMissing(int32_t size, const char** layerNames, WeightsRole* roles) TRTNOEXCEPT = 0;
//!
//! \brief Get description of all weights that could be refit.
//!
//! \param size The number of items that can be safely written to a non-null layerNames or roles.
//! \param layerNames Where to write the layer names.
//! \param roles Where to write the weights roles.
//!
//! \return The number of Weights that could be refit.
//!
//! If layerNames!=nullptr, each written pointer points to a string owned by
//! the engine being refitted, and becomes invalid when the engine is destroyed.
//!
virtual int32_t getAll(int32_t size, const char** layerNames, WeightsRole* roles) TRTNOEXCEPT = 0;
//!
//! \brief Destroy this object.
//!
virtual void destroy() TRTNOEXCEPT = 0;
// The destructor is protected, so code outside the class hierarchy cannot `delete` an
// IRefitter pointer directly; destroy() is the public way to release the object.
protected:
virtual ~IRefitter() {}
// NOTE(review): the members below are declared after the destructor, presumably so that they
// were appended to the vtable without disturbing earlier entries -- confirm before reordering.
public:
//!
//! \brief Update dynamic range for a tensor.
//!
//! \param tensorName The name of an ITensor in the network.
//! \param min The minimum of the dynamic range for the tensor.
//! \param max The maximum of the dynamic range for the tensor.
//!
//! \return True if successful; false otherwise.
//!
//! Returns false if there is no Int8 engine tensor derived from
//! a network tensor of that name. If successful, then getMissing
//! may report that some weights need to be supplied.
//!
virtual bool setDynamicRange(const char* tensorName, float min, float max) TRTNOEXCEPT = 0;
//!
//! \brief Get minimum of dynamic range.
//!
//! \param tensorName The name of the tensor to query.
//!
//! \return Minimum of dynamic range.
//!
//! If the dynamic range was never set, returns the minimum computed during calibration.
//!
virtual float getDynamicRangeMin(const char* tensorName) const TRTNOEXCEPT = 0;
//!
//! \brief Get maximum of dynamic range.
//!
//! \param tensorName The name of the tensor to query.
//!
//! \return Maximum of dynamic range.
//!
//! If the dynamic range was never set, returns the maximum computed during calibration.
//!
virtual float getDynamicRangeMax(const char* tensorName) const TRTNOEXCEPT = 0;
//!
//! \brief Get names of all tensors that have refittable dynamic ranges.
//!
//! \param size The number of items that can be safely written to a non-null tensorNames.
//! \param tensorNames Where to write the tensor names.
//!
//! \return The number of such tensors.
//!
//! If tensorNames!=nullptr, each written pointer points to a string owned by
//! the engine being refitted, and becomes invalid when the engine is destroyed.
//!
virtual int32_t getTensorsWithDynamicRange(int32_t size, const char** tensorNames) const TRTNOEXCEPT = 0;
//!
//! \brief Set the ErrorRecorder for this interface.
//!
//! Assigns the ErrorRecorder to this interface. The ErrorRecorder will track all errors during execution.
//! This function will call incRefCount of the registered ErrorRecorder at least once. Setting
//! recorder to nullptr unregisters the recorder with the interface, resulting in a call to decRefCount if
//! a recorder has been registered.
//!
//! \param recorder The error recorder to register with this interface.
//!
//! \see getErrorRecorder
//!
virtual void setErrorRecorder(IErrorRecorder* recorder) TRTNOEXCEPT = 0;
//!
//! \brief Get the ErrorRecorder assigned to this interface.
//!
//! Retrieves the assigned error recorder object for the given class. A default error recorder does not exist,
//! so a nullptr will be returned if setErrorRecorder has not been called.
//!
//! \return A pointer to the IErrorRecorder object that has been registered.
//!
//! \see setErrorRecorder
//!
virtual IErrorRecorder* getErrorRecorder() const TRTNOEXCEPT = 0;
};
//!
//! \class IPluginFactory
//!
//! \brief Plugin factory for deserialization.
//!
//! This Interface is guaranteed not to change for the same major version of TensorRT.
//!
class IPluginFactory
{
public:
//!
//! \brief Create a plugin from serialized data.
//!
//! Responsibility of destroying this plugin lies with the application.
//! It can be done anytime after consumers of this plugin are destroyed.
//!
//! \param layerName The name of the layer.
//! \param serialData The serialized data.
//! \param serialLength The length of the serialized data.
//!
//! \return The plugin.
//!
//! \see IPlugin::serialize()
//!
virtual IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) TRTNOEXCEPT = 0;
//!
//! \brief Virtual destructor.
//!
//! Public here, unlike the protected destructors of IRuntime and IRefitter; this interface
//! declares no destroy() method.
//!
virtual ~IPluginFactory() {}
};
//!
//! \enum OptProfileSelector
//!
//! \brief When setting or querying optimization profile parameters (such as shape tensor inputs or dynamic dimensions),
//! select whether we are interested in the minimum, optimum, or maximum values for these parameters.
//! The minimum and maximum specify the permitted range that is supported at runtime, while the optimum value
//! is used for the kernel selection. This should be the "typical" value that is expected to occur at runtime.
//!
//! \see IOptimizationProfile::setDimensions(), IOptimizationProfile::setShapeValues()
//!
enum class OptProfileSelector : int32_t
{
kMIN = 0, //!< This is used to set or get the minimum permitted value for dynamic dimensions etc.