forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
struct.py
103 lines (83 loc) · 3.16 KB
/
struct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod
from polygraphy.common.interface import TypedDict
np = mod.lazy_import("numpy")
class MetadataTuple(object):
    """
    A small (dtype, shape) record describing one tensor.

    Instances unpack like a 2-tuple: ``dtype, shape = meta_tuple``.
    Either field may be ``None`` when that piece of information is unknown.
    """

    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape

    def __iter__(self):
        # Yield the fields in a fixed order so that tuple-style unpacking works.
        yield self.dtype
        yield self.shape

    def __repr__(self):
        return "MetadataTuple({:}, {:})".format(self.dtype, self.shape)

    def __str__(self):
        """
        Returns ``[dtype=<numpy dtype name>, shape=<tuple>]``, leaving out any field that
        is ``None``. If both fields are ``None``, an empty string is returned.
        """
        fields = []
        if self.dtype is not None:
            fields.append("dtype={:}".format(np.dtype(self.dtype).name))
        if self.shape is not None:
            fields.append("shape={:}".format(tuple(self.shape)))
        return "[" + ", ".join(fields) + "]" if fields else ""
@mod.export()
class TensorMetadata(TypedDict(lambda: str, lambda: MetadataTuple)):
    """
    An OrderedDict[str, MetadataTuple] that maps input names to their data types and shapes.

    Shapes may include negative values, ``None``, or strings to indicate dynamic dimensions.

    Example:
    ::

        shape = tensor_meta["input0"].shape
        dtype = tensor_meta["input0"].dtype
    """

    @staticmethod
    def from_feed_dict(feed_dict):
        """
        Constructs a new TensorMetadata using information from the provided feed_dict.

        Args:
            feed_dict (OrderedDict[str, numpy.ndarray]):
                    A mapping of input tensor names to corresponding input NumPy arrays.
                    Only the ``dtype`` and ``shape`` attributes of each array are read.

        Returns:
            TensorMetadata: A new instance with one entry per item of ``feed_dict``.
        """
        meta = TensorMetadata()
        for name, arr in feed_dict.items():
            meta.add(name, arr.dtype, arr.shape)
        return meta

    def add(self, name, dtype, shape):
        """
        Convenience function for adding entries.

        Args:
            name (str): The name of the input.
            dtype (numpy.dtype): The data type of the input.
            shape (Sequence[Union[int, str]]):
                    The shape of the input. Dynamic dimensions may
                    be indicated by negative values, ``None``, or a string.

        Returns:
            TensorMetadata:
                    This object (not the newly added entry), so that calls can be chained,
                    e.g. ``TensorMetadata().add("x", dtype, shape).add("y", dtype, shape)``.
        """
        self[name] = MetadataTuple(dtype, shape)
        return self

    def __repr__(self):
        """
        Returns a chained-constructor representation of this object, e.g.
        ``TensorMetadata().add('input0', float32, (1, 3))``.
        """
        # The name is rendered with ``!r`` so that it is correctly quoted and escaped even
        # when it contains quotes or backslashes. For ordinary names the output is identical
        # to wrapping the name in single quotes. dtype and shape keep their str() form for
        # backward compatibility with the existing output.
        parts = ["TensorMetadata()"]
        for name, (dtype, shape) in self.items():
            parts.append(".add({!r}, {:}, {:})".format(name, dtype, shape))
        return "".join(parts)

    def __str__(self):
        """
        Returns a human-readable summary wrapped in braces, with one
        ``name [dtype=..., shape=...]`` element per entry, separated by ``",\\n "``.
        """
        sep = ",\n "
        # strip() drops the trailing space left behind when an entry's MetadataTuple
        # renders as an empty string.
        elems = ["{:} {:}".format(name, meta_tuple).strip() for name, meta_tuple in self.items()]
        return "{" + sep.join(elems) + "}"