tonycloud committed
Commit
bda3cf5
1 Parent(s): e7d466f

Add onnx2engine.py

Files changed (2)
  1. README.md +4 -1
  2. onnx2engine.py +96 -0
README.md CHANGED
@@ -3,4 +3,7 @@ license: apache-2.0
 ---
 
 This project contains the onnx and tensorrt model files converted from the chatglm-6b model.
-The infer scripts for onnx and tensorrt will be refined later
+The infer scripts for onnx and tensorrt will be refined later
+
+onnx2engine.py converts the onnx model into a tensorrt engine; the batch size is currently fixed at 1
+and can be changed to a dynamic batch according to your available GPU memory.
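The README note above says the batch size is fixed at 1. As a rough sketch (not part of this commit), the optimization profiles in onnx2engine.py could be widened along the batch dimension to build a dynamic-batch engine, reusing Profile, opt_length and max_length from the script; max_batch_size below is a placeholder you would choose to fit your GPU memory:

    max_batch_size = 4  # placeholder: pick a value that fits your GPU memory
    profiles = [Profile().add(
        "input_ids",
        min=(1, 1),
        opt=(max_batch_size, opt_length),
        max=(max_batch_size, max_length),
    ).add(
        "position_ids",
        min=(1, 2, 1),
        opt=(max_batch_size, 2, opt_length),
        max=(max_batch_size, 2, max_length),
    ).add(
        "attention_mask",
        min=(1, 1, 1, 1),
        opt=(max_batch_size, 1, opt_length, opt_length),
        max=(max_batch_size, 1, max_length, max_length),
    )]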
onnx2engine.py ADDED
@@ -0,0 +1,96 @@
+import tensorrt as trt
+from itertools import tee
+
+from polygraphy.backend.trt import (
+    network_from_onnx_path,
+    engine_from_network,
+    save_engine,
+    Profile,
+)
+
+from polygraphy.backend.trt import CreateConfig
+from tensorrt import PreviewFeature, MemoryPoolType
+
+batch_size = 1  # batch size is fixed at 1; see README for switching to a dynamic batch
+max_length = 2048  # longest sequence length covered by the optimization profiles
+opt_length = max_length // 2
+
+
+profiles = [Profile().add(
+    "input_ids",
+    min=(batch_size, 1),
+    opt=(batch_size, opt_length),  # Optimized based on the inputs.
+    max=(batch_size, max_length),
+).add(
+    "position_ids",
+    min=(batch_size, 2, 1),
+    opt=(batch_size, 2, opt_length),  # Optimized based on the inputs.
+    max=(batch_size, 2, max_length),
+).add(
+    "attention_mask",
+    min=(batch_size, 1, 1, 1),
+    opt=(batch_size, 1, opt_length, opt_length),  # Optimized based on the inputs.
+    max=(batch_size, 1, max_length, max_length),
+)]
+
+
+
+
+
+def get_network_definition(network_definition):
+    def pairwise(iterable):
+        a, b = tee(iterable)
+        next(b, None)
+        return zip(a, b)
+
+    indices = list(range(0, network_definition[1].num_layers))
+    for i, i_next in pairwise(indices):  # walk consecutive layer pairs
+        l = network_definition[1].get_layer(i)
+        l_next = network_definition[1].get_layer(i_next)
+
+        if not all([l.get_output(i).is_execution_tensor for i in range(l.num_outputs)]):
+            continue
+
+        if l.get_output_type(0) != trt.float32:
+            continue
+
+        if l.type == trt.LayerType.ELEMENTWISE and l_next.type == trt.LayerType.REDUCE:
+            l.__class__ = getattr(trt, "IElementWiseLayer")
+            if l.op == trt.ElementWiseOperation.POW:  # keep POW and the following REDUCE in FP32
+                l.precision = trt.float32
+                l.set_output_type(0, trt.float32)
+
+            l_next.precision = trt.float32
+            l_next.set_output_type(0, trt.float32)
+
+    return network_definition
+
+
+input_fpath = "./model6b_onnx_pkv/model.onnx"
+
+
+preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
+
+
+
+trt_inference_config = CreateConfig(
+    fp16=True,
+    memory_pool_limits={MemoryPoolType.WORKSPACE: 2048 * 1024 * 1024},  # 2 GiB workspace
+    profiles=profiles,
+    precision_constraints="obey",
+    preview_features=preview_features,
+)
+
+
+onnx_network = network_from_onnx_path(input_fpath)
+
+
+network_definition = get_network_definition(onnx_network)
+print(network_definition)
+print(trt_inference_config)
+
+trt_engine = engine_from_network(network_definition, trt_inference_config)
+print(trt_engine)
+
+output_fpath = "./model6b_trt_pkv/out.engine"
+save_engine(trt_engine, output_fpath)
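
A quick way to confirm the saved engine deserializes and to inspect its I/O tensors (a minimal sketch using the standard TensorRT runtime API, not part of this commit; the path matches output_fpath above):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    with open("./model6b_trt_pkv/out.engine", "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())

    # List the engine's I/O tensors with their (dynamic) shapes and dtypes.
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        print(name, engine.get_tensor_shape(name), engine.get_tensor_dtype(name))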