9
9
#include " onnxErrorRecorder.hpp"
10
10
#include " onnx/common/stl_backports.h"
11
11
#include < list>
12
+ #include < string>
12
13
#include < unordered_map>
14
+ #include < utility>
13
15
14
16
namespace onnx2trt
15
17
{
@@ -84,8 +86,24 @@ class ImporterContext final : public IImporterContext
84
86
int64_t mSuffixCounter {0 }; // increasing suffix counter used to uniquify layer names.
85
87
std::unordered_set<std::string> mUnsupportedShapeTensors ; // Container to hold output tensor names of layers that produce shape tensor outputs but do not natively support them.
86
88
StringMap<std::string> mLoopTensors ; // Container to map subgraph tensors to their original outer graph names.
87
- std::string mOnnxFileLocation ; // Keep track of the directory of the parsed ONNX file
89
+ std::string mOnnxFileLocation ; // Keep track of the directory of the parsed ONNX file
88
90
std::unique_ptr<ErrorRecorderWrapper> mErrorWrapper ; // error recorder to control TRT errors
91
+ StringMap<nvinfer1::IConstantLayer*> mConstantLayers ;
92
+
93
+ // ! Stack of names defined by nested ONNX graphs, with information about how to
94
+ // ! restore their associated values when popping back to the surrounding scope.
95
+ // !
96
+ // ! The stack is empty when processing the top-level ONNX graph.
97
+ // ! back() corresponds to the innermost ONNX graph being processed.
98
+ // !
99
+ // ! For each entry {name, {bool, TensorOrWeights}}:
100
+ // !
101
+ // ! * If the bool is true, the name was newly introduced by the scope.
102
+ // !
103
+ // ! * If the bool is false, the name shadows a name in a surrounding scope,
104
+ // ! and TensorOrWeights was the name's value before being shadowed.
105
+ // !
106
+ std::vector<StringMap<std::pair<bool , TensorOrWeights>>> mBaseNameScopeStack ;
89
107
90
108
public:
91
109
ImporterContext (nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
@@ -134,52 +152,15 @@ class ImporterContext final : public IImporterContext
134
152
{
135
153
return mOnnxFileLocation ;
136
154
}
137
- // This actually handles weights as well, but is named this way to be consistent with the tensors()
138
- void registerTensor (TensorOrWeights tensor, const std::string& basename) override
139
- {
140
- // TRT requires unique tensor names.
141
- const std::string uniqueName = generateUniqueName (mTensorNames , basename);
142
155
143
- if (tensor)
144
- {
145
- auto * ctx = this ; // To enable logging.
146
- if (tensor.is_tensor ())
147
- {
148
- tensor.tensor ().setName (uniqueName.c_str ());
156
+ void pushBaseNameScope () override ;
149
157
150
- LOG_VERBOSE (" Registering tensor: " << uniqueName << " for ONNX tensor: " << basename);
151
- }
152
- else if (tensor.is_weights ())
153
- {
154
- const auto & weights = tensor.weights ();
155
- if (tensor.weights ().type == ::ONNX_NAMESPACE::TensorProto::INT64)
156
- {
157
- tensor = ShapedWeights{::ONNX_NAMESPACE::TensorProto::INT32,
158
- convertINT64 (reinterpret_cast <int64_t *>(weights.values ), weights.shape , ctx), weights.shape };
159
- }
160
- tensor.weights ().setName (basename.c_str ());
161
- }
158
+ void popBaseNameScope () override ;
162
159
163
- }
164
- // Overwrite previous tensors registered with the same name (this only happens when there are subgraphs,
165
- // and in that case, overwriting is the desired behavior).
166
- this ->tensors ()[basename] = std::move (tensor);
167
- }
168
-
169
- void registerLayer (nvinfer1::ILayer* layer, const std::string& basename) override
170
- {
171
- // No layer will be added for Constant nodes in ONNX.
172
- if (layer)
173
- {
174
- const std::string name = basename.empty () ? layer->getName () : basename;
175
- const std::string uniqueName = generateUniqueName (mLayerNames , name);
176
-
177
- auto * ctx = this ; // To enable logging.
178
- LOG_VERBOSE (" Registering layer: " << uniqueName << " for ONNX node: " << basename);
160
+ // This actually handles weights as well, but is named this way to be consistent with the tensors()
161
+ void registerTensor (TensorOrWeights tensor, std::string const & basename) override ;
179
162
180
- layer->setName (uniqueName.c_str ());
181
- }
182
- }
163
+ void registerLayer (nvinfer1::ILayer* layer, std::string const & basename) override ;
183
164
184
165
nvinfer1::ILogger& logger () override
185
166
{
@@ -188,16 +169,10 @@ class ImporterContext final : public IImporterContext
188
169
189
170
ShapedWeights createTempWeights (ShapedWeights::DataType type, nvinfer1::Dims shape, uint8_t value = 0 ) override
190
171
{
172
+ std::string const & name = generateUniqueName (mTensorNames , " tmp_weight" );
191
173
ShapedWeights weights (type, nullptr , shape);
192
- // Need special logic for handling scalars.
193
- if (shape.nbDims == 0 )
194
- {
195
- mTempBufs .push_back (std::vector<uint8_t >(getDtypeSize (type), value));
196
- }
197
- else
198
- {
199
- mTempBufs .push_back (std::vector<uint8_t >(weights.size_bytes (), value));
200
- }
174
+ weights.setName (name.c_str ());
175
+ mTempBufs .push_back (std::vector<uint8_t >(weights.size_bytes (), value));
201
176
weights.values = mTempBufs .back ().data ();
202
177
return weights;
203
178
}
@@ -256,8 +231,13 @@ class ImporterContext final : public IImporterContext
256
231
{
257
232
return mOpsets .begin ()->second ;
258
233
}
234
+ else if (mOpsets .count (domain))
235
+ {
236
+ return mOpsets .at (domain);
237
+ }
259
238
else
260
239
{
240
+ domain = " ai.onnx" ;
261
241
assert (mOpsets .count (domain));
262
242
return mOpsets .at (domain);
263
243
}
@@ -271,8 +251,22 @@ class ImporterContext final : public IImporterContext
271
251
{
272
252
return mErrorWrapper ? mErrorWrapper ->getErrorRecorder () : nullptr ;
273
253
}
254
+ nvinfer1::IConstantLayer* getConstantLayer (const char * name) const final
255
+ {
256
+ if (name == nullptr )
257
+ {
258
+ return nullptr ;
259
+ }
260
+ auto const iter = mConstantLayers .find (name);
261
+ if (iter == mConstantLayers .end ())
262
+ {
263
+ return nullptr ;
264
+ }
265
+ return iter->second ;
266
+ }
267
+
274
268
private:
275
- std::string generateUniqueName (std::set<std::string>& namesSet, const std::string& basename)
269
+ std::string const & generateUniqueName (std::set<std::string>& namesSet, const std::string& basename)
276
270
{
277
271
std::string candidate = basename;
278
272
@@ -283,8 +277,8 @@ class ImporterContext final : public IImporterContext
283
277
}
284
278
285
279
namesSet.insert (candidate);
286
-
287
- return candidate;
280
+ // Return reference to newly inserted string to avoid any c_str()'s going out of scope
281
+ return *namesSet. find ( candidate) ;
288
282
}
289
283
};
290
284
0 commit comments