This repository has been archived by the owner on Jun 5, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 135
/
tensorrt_optimize.py
executable file
·70 lines (60 loc) · 2.45 KB
/
tensorrt_optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# coding=utf-8
"""Given the tensorflow frozen graph, use TensorRT to optimize,
get a new frozen graph."""
from __future__ import print_function
import argparse
import time
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
parser = argparse.ArgumentParser()
parser.add_argument("pbfile")
parser.add_argument("newpbfile")
parser.add_argument("--precision_mode", default="FP32",
help="FP32, FP16, or INT8")
parser.add_argument("--maximum_cached_engines", default=100,
help="Don't know what this does.")
# parameter
# https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html
if __name__ == "__main__":
args = parser.parse_args()
max_batch_size = 1
precision_mode = args.precision_mode
minimum_segment_size = 2 # smaller the faster? 5 -60?
max_workspace_size_bytes = 1 << 32
maximum_cached_engines = args.maximum_cached_engines
output_names = [
"final_boxes",
"final_labels",
"final_probs",
"fpn_box_feat",
]
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
with tf.Graph().as_default() as tf_graph:
with tf.Session(config=tf_config) as tf_sess:
with tf.gfile.GFile(args.pbfile, "rb") as f:
frozen_graph = tf.GraphDef()
frozen_graph.ParseFromString(f.read())
graph_size = len(frozen_graph.SerializeToString())
num_nodes = len(frozen_graph.node)
start_time = time.time()
frozen_graph = trt.create_inference_graph(
input_graph_def=frozen_graph,
outputs=output_names,
max_batch_size=max_batch_size,
max_workspace_size_bytes=max_workspace_size_bytes,
precision_mode=precision_mode,
minimum_segment_size=minimum_segment_size,
is_dynamic_op=True, # this is needed for FPN
maximum_cached_engines=maximum_cached_engines)
end_time = time.time()
print("graph_size(MB)(native_tf): %.1f" % (float(graph_size)/(1<<20)))
print("graph_size(MB)(trt): %.1f" % (
float(len(frozen_graph.SerializeToString()))/(1<<20)))
print("num_nodes(native_tf): %d" % num_nodes)
print("num_nodes(tftrt_total): %d" % len(frozen_graph.node))
print("num_nodes(trt_only): %d" % len(
[1 for n in frozen_graph.node if str(n.op) == "TRTEngineOp"]))
print("time(s) (trt_conversion): %.4f" % (end_time - start_time))
with open(args.newpbfile, "wb") as f:
f.write(frozen_graph.SerializeToString())