@@ -16,6 +16,7 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_trt_weights,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -48,89 +49,160 @@ def batch_norm(
     # Save the original output shape for later use
     output_shape = input.shape
 
-    # We name the weight here according to the state_dict name
-    weight = (
-        get_trt_tensor(ctx, 1.0, f"{name}_weight")
-        if weight is None
-        else get_trt_tensor(ctx, weight, f"{name}_weight")
-    )
-    bias = (
-        get_trt_tensor(ctx, 0.0, f"{name}_bias")
-        if bias is None
-        else get_trt_tensor(ctx, bias, f"{name}_bias")
-    )
-    running_mean = (
-        get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
-        if running_mean is None
-        else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-    )
-    running_var = (
-        get_trt_tensor(ctx, 1.0, f"{name}_running_var")
-        if running_var is None
-        else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-    )
+    if all(
+        [
+            isinstance(weight, torch.Tensor),
+            isinstance(bias, torch.Tensor),
+            isinstance(running_mean, torch.Tensor),
+            isinstance(running_var, torch.Tensor),
+        ]
+    ):
+        if weight is None:
+            weight = 1.0
+
+        if bias is None:
+            bias = 0.0
+
+        if running_mean is None:
+            running_mean = 0.0
+
+        if running_var is None:
+            running_var = 1.0
+        adjusted_scale = weight / torch.sqrt(running_var + eps)
+        adjusted_bias = bias - running_mean * adjusted_scale
+        power = torch.ones_like(adjusted_scale)
+        adjusted_scale = to_trt_weights(
+            ctx,
+            adjusted_scale,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SCALE",
+            target=target,
+            source_ir=source_ir,
+        )
+        adjusted_bias = to_trt_weights(
+            ctx,
+            adjusted_bias,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SHIFT",
+            target=target,
+            source_ir=source_ir,
+        )
 
-    # eps_tensor for numerical stability
-    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
+        power = to_trt_weights(
+            ctx,
+            power,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="POWER",
+            target=target,
+            source_ir=source_ir,
+        )
 
-    # adjusted_var = running_var + eps
-    adjusted_var = impl.elementwise.add(
-        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
-    )
+        output_shape = input.shape
+        if len(input.shape) < 4:
 
-    # sqrt_adjusted_var = sqrt(adjusted_var)
-    sqrt_adjusted_var = impl.unary.sqrt(
-        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
-    )
+            new_shape = (
+                (input.shape[0], input.shape[1], 1, 1)
+                if len(input.shape) == 2
+                else (input.shape[0], input.shape[1], input.shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
 
-    # scale = weight / sqrt_adjusted_var
-    scale = impl.elementwise.div(
-        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
-    )
+        layer = ctx.net.add_scale_nd(
+            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        )
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)
 
-    # scaled_running_mean = running_mean * scale
-    scaled_running_mean = impl.elementwise.mul(
-        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
-    )
+    else:
 
-    # bias_adjusted = bias - scaled_running_mean
-    bias_adjusted = impl.elementwise.sub(
-        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
-    )
+        # We name the weight here according to the state_dict name
+        weight = (
+            get_trt_tensor(ctx, 1.0, f"{name}_weight")
+            if weight is None
+            else get_trt_tensor(ctx, weight, f"{name}_weight")
+        )
+        bias = (
+            get_trt_tensor(ctx, 0.0, f"{name}_bias")
+            if bias is None
+            else get_trt_tensor(ctx, bias, f"{name}_bias")
+        )
+        running_mean = (
+            get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+            if running_mean is None
+            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+        )
+        running_var = (
+            get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+            if running_var is None
+            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+        )
 
-    # Reshape scale and bias_adjusted to match input shape for broadcasting
-    expanded_shape = [1] * len(output_shape)
-    expanded_shape[1] = output_shape[1]  # Set channel dimension
+        # eps_tensor for numerical stability
+        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
 
-    scale_reshape = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_scale",
-        scale,
-        tuple(expanded_shape),
-    )
-    bias_adjusted_reshape = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_bias",
-        bias_adjusted,
-        tuple(expanded_shape),
-    )
+        # adjusted_var = running_var + eps
+        adjusted_var = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+        )
 
-    # Apply the scale and bias to the input
-    scaled_input = impl.elementwise.mul(
-        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
-    )
-    output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_output",
-        scaled_input,
-        bias_adjusted_reshape,
-    )
+        # sqrt_adjusted_var = sqrt(adjusted_var)
+        sqrt_adjusted_var = impl.unary.sqrt(
+            ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+        )
+
+        # scale = weight / sqrt_adjusted_var
+        scale = impl.elementwise.div(
+            ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+        )
+
+        # scaled_running_mean = running_mean * scale
+        scaled_running_mean = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+        )
+
+        # bias_adjusted = bias - scaled_running_mean
+        bias_adjusted = impl.elementwise.sub(
+            ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+        )
+
+        # Reshape scale and bias_adjusted to match input shape for broadcasting
+        expanded_shape = [1] * len(output_shape)
+        expanded_shape[1] = output_shape[1]  # Set channel dimension
+
+        scale_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_scale",
+            scale,
+            tuple(expanded_shape),
+        )
+        bias_adjusted_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_bias",
+            bias_adjusted,
+            tuple(expanded_shape),
+        )
+
+        # Apply the scale and bias to the input
+        scaled_input = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+        )
+        output = impl.elementwise.add(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_output",
+            scaled_input,
+            bias_adjusted_reshape,
+        )
 
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
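As a sanity check on the refactor above: when `weight`, `bias`, `running_mean`, and `running_var` are all static `torch.Tensor`s, eval-mode batch norm collapses into a single per-channel affine, `y = x * scale + shift`, with `scale = weight / sqrt(running_var + eps)` and `shift = bias - running_mean * scale`. That is exactly the form TensorRT's `IScaleLayer` (built here via `add_scale_nd` in `ScaleMode.CHANNEL` with an all-ones `power`) evaluates, which is why the new branch can precompute `adjusted_scale` and `adjusted_bias` with plain `torch` ops at conversion time and hand them to TensorRT as constant weights through `to_trt_weights`. The sketch below is a minimal standalone check of that algebra in plain PyTorch (no TensorRT); the shapes and tolerance are illustrative, not taken from the PR.

```python
import torch

torch.manual_seed(0)

# Illustrative NCHW shapes; any sizes work.
N, C, H, W = 4, 8, 16, 16
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)
running_mean = torch.randn(C)
running_var = torch.rand(C) + 0.5  # keep the variance positive
eps = 1e-5

# Reference: eval-mode batch norm driven by the frozen running statistics.
ref = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)

# Folded form, mirroring the diff:
#   adjusted_scale = weight / sqrt(running_var + eps)
#   adjusted_bias  = bias - running_mean * adjusted_scale
adjusted_scale = weight / torch.sqrt(running_var + eps)
adjusted_bias = bias - running_mean * adjusted_scale

# IScaleLayer in CHANNEL mode computes (x * scale + shift) ** power per
# channel; with power all-ones this reduces to the plain affine below.
folded = x * adjusted_scale.view(1, C, 1, 1) + adjusted_bias.view(1, C, 1, 1)

assert torch.allclose(ref, folded, atol=1e-5)
print("folded scale/shift matches F.batch_norm")
```

The `else` branch keeps the previous layer-by-layer elementwise formulation for the case where any of the four parameters arrives as an ITensor rather than a static tensor, since the fold can then no longer be performed at conversion time.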