forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
import.h
108 lines (90 loc) · 3.75 KB
/
import.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#pragma once
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <istream>
#include <memory>
#include <caffe2/serialize/file_adapter.h>
namespace torch::jit {
using caffe2::serialize::ReadAdapterInterface;
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
constexpr const char* kArchiveNameBytecode = "bytecode";
constexpr const char* kArchiveNameConstants = "constants";
constexpr const char* kArchiveNameVersion = "version";
// The family of methods below load a serialized Mobile Module
// into a mobile::Module object.
TORCH_API mobile::Module _load_for_mobile(
std::istream& in,
std::optional<at::Device> device,
ExtraFilesMap& extra_file,
uint64_t module_load_options = kDefaultMobileLoadOptions);
TORCH_API mobile::Module _load_for_mobile(
const std::string& filename,
std::optional<at::Device> device,
ExtraFilesMap& extra_files);
TORCH_API mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
std::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options = kDefaultMobileLoadOptions);
TORCH_API mobile::Module _load_for_mobile(
const std::string& filename,
std::optional<at::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options);
TORCH_API mobile::Module _load_for_mobile(
std::istream& in,
std::optional<at::Device> device = std::nullopt);
TORCH_API mobile::Module _load_for_mobile(
const std::string& filename,
std::optional<at::Device> device = std::nullopt);
TORCH_API mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
std::optional<c10::Device> device = std::nullopt);
/**
* Load only the contents of the "extra/" files whose names are
* passed in the map (extra_files). Populate the corresponding values
* with the contents of those files. Do not attempt to load the entire
* model, and stop once the extra files have been extracted.
*
* This API is needed to be able to load GPU models on linux CPU
* machines and extract only the extra files so that we can inspect
* the metadata that was added to the .ptl archive when it was
* generated.
*
*/
void _load_extra_only_for_mobile(
const std::string& filename,
std::optional<at::Device> device,
ExtraFilesMap& extra_files);
// Currently used by both mobile/import.cpp and model_compatibility.cpp.
// Should be removed after model_compatibility.cpp start using simplified
// version type_resolver and obj_loader.
at::TypePtr resolveTypeNameMobile(
const c10::QualifiedName& qn,
const std::shared_ptr<CompilationUnit>& compilation_unit);
c10::StrongTypePtr typeResolverMobile(
const c10::QualifiedName& qn,
const std::shared_ptr<CompilationUnit>& compilation_unit);
c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
const at::StrongTypePtr& type,
const at::IValue& input,
mobile::CompilationUnit& mobile_compilation_unit);
// Given a reader, which has access to a model file,
// return true if there exists tensors in `bytecode` archive
bool isTensorInBytecodeArchive(
caffe2::serialize::PyTorchStreamReader& stream_reader);
namespace mobile {
/**
* Given a torch::jit::mobile::Module, return a set of operator names
* (with overload name) that are used by any method in this mobile
* Mobile. This method runs through the bytecode for all methods
* in the specified model (module), and extracts all the root
* operator names. Root operators are operators that are called
* directly by the model (as opposed to non-root operators, which
* may be called transitively by the root operators).
*
*/
TORCH_API std::set<std::string> _export_operator_list(
torch::jit::mobile::Module& module);
} // namespace mobile
} // namespace torch::jit