forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python_list.h
228 lines (187 loc) · 5.37 KB
/
python_list.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#pragma once
#include <ATen/core/Dict.h>
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <pybind11/detail/common.h>
#include <torch/csrc/utils/pybind.h>
#include <cstddef>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
namespace torch::jit {
// Declared here; the definition lives out of line and is not visible in this
// header.
// NOTE(review): by its name this presumably registers the ScriptList /
// ScriptListIterator Python bindings on `module` — confirm in the .cpp file.
void initScriptListBindings(PyObject* module);
/// An iterator over the elements of ScriptList. This is used to support
/// __iter__(), .
class ScriptListIterator final {
public:
ScriptListIterator(
c10::impl::GenericList::iterator iter,
c10::impl::GenericList::iterator end)
: iter_(iter), end_(end) {}
at::IValue next();
bool done() const;
private:
c10::impl::GenericList::iterator iter_;
c10::impl::GenericList::iterator end_;
};
/// A wrapper around c10::List that can be exposed in Python via pybind
/// with an API identical to the Python list class. This allows
/// lists to have reference semantics across the Python/TorchScript
/// boundary.
class ScriptList final {
public:
// TODO: Do these make sense?
using size_type = size_t;
using diff_type = ptrdiff_t;
using ssize_t = Py_ssize_t;
// Constructor for empty lists created during slicing, extending, etc.
ScriptList(const at::TypePtr& type) : list_(at::AnyType::get()) {
auto list_type = type->expect<at::ListType>();
list_ = c10::impl::GenericList(list_type);
}
// Constructor for instances based on existing lists (e.g. a
// Python instance or a list nested inside another).
ScriptList(const at::IValue& data) : list_(at::AnyType::get()) {
TORCH_INTERNAL_ASSERT(data.isList());
list_ = data.toList();
}
at::ListTypePtr type() const {
return at::ListType::create(list_.elementType());
}
// Return a string representation that can be used
// to reconstruct the instance.
std::string repr() const {
std::ostringstream s;
s << '[';
bool f = false;
for (auto const& elem : list_) {
if (f) {
s << ", ";
}
s << at::IValue(elem);
f = true;
}
s << ']';
return s.str();
}
// Return an iterator over the elements of the list.
ScriptListIterator iter() const {
auto begin = list_.begin();
auto end = list_.end();
return ScriptListIterator(begin, end);
}
// Interpret the list as a boolean; empty means false, non-empty means
// true.
bool toBool() const {
return !(list_.empty());
}
// Get the value for the given index.
at::IValue getItem(diff_type idx) {
idx = wrap_index(idx);
return list_.get(idx);
}
// Set the value corresponding to the given index.
void setItem(diff_type idx, const at::IValue& value) {
idx = wrap_index(idx);
return list_.set(idx, value);
}
// Check whether the list contains the given value.
bool contains(const at::IValue& value) {
for (const auto& elem : list_) {
if (elem == value) {
return true;
}
}
return false;
}
// Delete the item at the given index from the list.
void delItem(diff_type idx) {
idx = wrap_index(idx);
auto iter = list_.begin() + idx;
list_.erase(iter);
}
// Get the size of the list.
ssize_t len() const {
return list_.size();
}
// Count the number of times a value appears in the list.
ssize_t count(const at::IValue& value) const {
ssize_t total = 0;
for (const auto& elem : list_) {
if (elem == value) {
++total;
}
}
return total;
}
// Remove the first occurrence of a value from the list.
void remove(const at::IValue& value) {
auto list = list_;
int64_t idx = -1, i = 0;
for (const auto& elem : list) {
if (elem == value) {
idx = i;
break;
}
++i;
}
if (idx == -1) {
throw py::value_error();
}
list.erase(list.begin() + idx);
}
// Append a value to the end of the list.
void append(const at::IValue& value) {
list_.emplace_back(value);
}
// Clear the contents of the list.
void clear() {
list_.clear();
}
// Append the contents of an iterable to the list.
void extend(const at::IValue& iterable) {
list_.append(iterable.toList());
}
// Remove and return the element at the specified index from the list. If no
// index is passed, the last element is removed and returned.
at::IValue pop(std::optional<size_type> idx = std::nullopt) {
at::IValue ret;
if (idx) {
idx = wrap_index(*idx);
ret = list_.get(*idx);
list_.erase(list_.begin() + *idx);
} else {
ret = list_.get(list_.size() - 1);
list_.pop_back();
}
return ret;
}
// Insert a value before the given index.
void insert(const at::IValue& value, diff_type idx) {
// wrap_index cannot be used; idx == len() is allowed
if (idx < 0) {
idx += len();
}
if (idx < 0 || idx > len()) {
throw std::out_of_range("list index out of range");
}
list_.insert(list_.begin() + idx, value);
}
// A c10::List instance that holds the actual data.
c10::impl::GenericList list_;
private:
// Wrap an index so that it can safely be used to access
// the list. For list of size sz, this function can successfully
// wrap indices in the range [-sz, sz-1]
diff_type wrap_index(diff_type idx) {
auto sz = len();
if (idx < 0) {
idx += sz;
}
if (idx < 0 || idx >= sz) {
throw std::out_of_range("list index out of range");
}
return idx;
}
};
} // namespace torch::jit