@@ -210,13 +210,13 @@ def grid_sample(
210
210
+ ws * d_e * d_n + es * d_w * d_n
211
211
212
212
Args:
213
- x(Tensor): The input tensor, which is a 4-d tensor with shape
214
- [N, C, H, W] or a 5-d tensor with shape [N, C, D, H, W],
213
+ x(Tensor): The input tensor, which is a 4-D tensor with shape
214
+ [N, C, H, W] or a 5-D tensor with shape [N, C, D, H, W],
215
215
N is the batch size, C is the channel number,
216
216
D, H and W is the feature depth, height and width.
217
217
The data type is float32 or float64.
218
- grid(Tensor): Input grid tensor, which is a 4-d tensor with shape [N, grid_H,
219
- grid_W, 2] or a 5-d tensor with shape [N, grid_D, grid_H,
218
+ grid(Tensor): Input grid tensor, which is a 4-D tensor with shape [N, grid_H,
219
+ grid_W, 2] or a 5-D tensor with shape [N, grid_D, grid_H,
220
220
grid_W, 3]. The data type is float32 or float64.
221
221
mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
222
222
Default: 'bilinear'.
0 commit comments