Skip to content

Commit a1f634d

Browse files
authored
映射文档 No.70 (#5860)
* fix docs bugs * fix docs bugs * torch.nn.functional.mish * fix * fix2
1 parent 825d882 commit a1f634d

File tree

10 files changed

+352
-0
lines changed

10 files changed

+352
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
## [ torch 参数更多 ]torch.Tensor.nanmean
2+
3+
### [torch.Tensor.nanmean](https://pytorch.org/docs/1.13/generated/torch.Tensor.nanmean.html?highlight=nanmean#torch.Tensor.nanmean)
4+
5+
```python
6+
torch.Tensor.nanmean(dim=None,
7+
keepdim=False,
8+
dtype=None,
9+
out=None)
10+
```
11+
12+
### [paddle.Tensor.nanmean](暂无对应文档)
13+
14+
```python
15+
paddle.Tensor.nanmean(axis=None,
16+
keepdim=False,
17+
name=None)
18+
```
19+
20+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
21+
### 参数映射
22+
| PyTorch | PaddlePaddle | 备注 |
23+
| ------------- | ------------ | ------------------------------------------------------ |
24+
| dim | axis | 表示进行运算的轴,可选项,仅参数名不一致。 |
25+
| keepdim | keepdim | 表示是否保留计算后的维度,可选项。 |
26+
| dtype | - | 指定输出数据类型,可选项,Pytorch 默认值为 None,Paddle 无此参数,需要转写。 |
27+
| out | - | 表示输出的 Tensor,可选项,Paddle 无此参数,需要进行转写。 |
28+
29+
### 转写示例
30+
31+
#### dytpe:指定数据类型
32+
33+
```python
34+
# Pytorch 写法
35+
x.nanmean(dim=-1, dtype=torch.float32,out=y)
36+
37+
# Paddle 写法
38+
x.astype('float32')
39+
paddle.assign(x.nanmean(dim=-1),y)
40+
```
41+
42+
#### out:指定输出
43+
44+
```python
45+
# Pytorch 写法
46+
x.nanmean(dim=1out=y)
47+
48+
# Paddle 写法
49+
paddle.assign(x.nanmean(dim=1), y)
50+
```
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
## [ torch 参数更多 ]torch.nn.functional.mish
2+
3+
### [torch.nn.functional.mish](https://pytorch.org/docs/1.13/generated/torch.nn.functional.mish.html?highlight=torch+nn+functional+mish#torch.nn.functional.mish)
4+
5+
```python
6+
torch.nn.functional.mish(input,
7+
inplace=False)
8+
```
9+
10+
### [paddle.nn.functional.mish](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/functional/mish_cn.html)
11+
12+
```python
13+
paddle.nn.functional.mish(x,
14+
name=None)
15+
```
16+
17+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
18+
### 参数映射
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------- | ------------ | ------------------------------------------------------ |
21+
| input | x | 表示输入的 Tensor ,仅参数名不一致。 |
22+
| inplace | - | 表示在不更改变量的内存地址的情况下,直接修改变量的值,主要功能为节省显存,一般对网络训练影响不大,Paddle 无此参数,可直接删除。 |
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
## [ torch 参数更多 ]torch.linalg.eig
2+
3+
### [torch.linalg.eig](https://pytorch.org/docs/1.13/generated/torch.linalg.eig.html?highlight=torch+linalg+eig#torch.linalg.eig)
4+
5+
```python
6+
torch.linalg.eig(A,
7+
out=None)
8+
```
9+
10+
### [paddle.linalg.eig](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/linalg/eig_cn.html)
11+
12+
```python
13+
paddle.linalg.eig(x,
14+
name=None)
15+
```
16+
17+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
18+
### 参数映射
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------- | ------------ | ------------------------------------------------------ |
21+
| A | x | 表示输入的 Tensor ,仅参数名不一致。 |
22+
| out | - | 表示输出的 tuple, Paddle 无此参数,需要进行转写。 |
23+
24+
### 转写示例
25+
26+
#### out:指定输出
27+
28+
```python
29+
# Pytorch 写法
30+
torch.linalg.eig(t,out=(L,V))
31+
32+
# Paddle 写法
33+
L,V=paddle.linalg.eig(t)
34+
```
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
## [ torch 参数更多 ]torch.linalg.eigvals
2+
3+
### [torch.linalg.eigvals](https://pytorch.org/docs/1.13/generated/torch.linalg.eigvals.html?highlight=torch+linalg+eigvals#torch.linalg.eigvals)
4+
5+
```python
6+
torch.linalg.eigvals(A,
7+
out=None)
8+
```
9+
10+
### [paddle.linalg.eigvals](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/linalg/eigvals_cn.html)
11+
12+
```python
13+
paddle.linalg.eigvals(x,
14+
name=None)
15+
```
16+
17+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
18+
### 参数映射
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------- | ------------ | ---------------------------------------------------- |
21+
| A | x | 表示输入的 Tensor ,仅参数名不一致。 |
22+
| out | - | 表示输出的 Tensor , Paddle 无此参数,需要进行转写。 |
23+
24+
### 转写示例
25+
26+
#### out:指定输出
27+
28+
```python
29+
# Pytorch 写法
30+
torch.linalg.eigvals(t, out=y)
31+
32+
# Paddle 写法
33+
paddle.assign(paddle.linalg.eigvals(t), y)
34+
```
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
## [ torch 参数更多 ]torch.linalg.multi_dot
2+
3+
### [torch.linalg.multi_dot](https://pytorch.org/docs/1.13/generated/torch.linalg.multi_dot.html?highlight=torch+linalg+multi_dot#torch.linalg.multi_dot)
4+
5+
```python
6+
torch.linalg.multi_dot(tensors,
7+
out=None)
8+
```
9+
10+
### [paddle.linalg.multi_dot](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/linalg/multi_dot_cn.html)
11+
12+
```python
13+
paddle.linalg.multi_dot(x,
14+
name=None)
15+
```
16+
17+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
18+
### 参数映射
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------- | ------------ | ------------------------------------------------------ |
21+
| tensors | x | 表示输入的一个 tensor 列表 ,仅参数名不一致。 |
22+
| out | - | 表示输出的 Tensor , Paddle 无此参数,需要进行转写。 |
23+
24+
### 转写示例
25+
26+
#### out:指定输出
27+
28+
```python
29+
# Pytorch 写法
30+
torch.linalg.multi_dot(x, out=y)
31+
32+
# Paddle 写法
33+
paddle.assign(paddle.linalg.multi_dot(x), y)
34+
```
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
## [ torch 参数更多 ]troch.nn.Mish
2+
3+
### [troch.nn.Mish](https://pytorch.org/docs/1.13/generated/torch.nn.Mish.html?highlight=troch+nn+mish)
4+
5+
```python
6+
troch.nn.Mish(inplace=False)
7+
```
8+
9+
### [paddle.nn.Mish](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Mish_cn.html)
10+
11+
```python
12+
paddle.nn.Mish(name=None)
13+
```
14+
15+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
16+
### 参数映射
17+
| PyTorch | PaddlePaddle | 备注 |
18+
| ------------- | ------------ | ------------------------------------------------------ |
19+
| inplace | - | 表示在不更改变量的内存地址的情况下,直接修改变量的值,主要功能为节省显存,一般对网络训练影响不大,Paddle 无此参数,可直接删除。 |
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
## [ torch 参数更多 ]torch.frexp
2+
3+
### [torch.frexp](https://pytorch.org/docs/1.13/generated/torch.frexp.html?highlight=frexp#torch.frexp)
4+
5+
```python
6+
torch.frexp(input,
7+
out=None)
8+
```
9+
10+
### [paddle.frexp](暂无对应文档)
11+
12+
```python
13+
paddle.frexp(x,
14+
name=None)
15+
```
16+
17+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
18+
### 参数映射
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------- | ------------ | ------------------------------------------------------ |
21+
| input | x | 表示输入的 Tensor,仅参数名不一致。 |
22+
| out | - | 表示输出的 Tensor,可选项,Paddle 无此参数,需要进行转写。 |
23+
24+
### 转写示例
25+
26+
#### out:指定输出
27+
28+
```python
29+
# Pytorch 写法
30+
torch.frexp(x,out=y)
31+
32+
# Paddle 写法
33+
paddle.assign(paddle.frexp(x), y)
34+
```
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
## [ torch 参数更多 ]torch.nanmean
2+
3+
### [torch.nanmean](https://pytorch.org/docs/1.13/generated/torch.nanmean.html?highlight=nanmean#torch.nanmean)
4+
5+
```python
6+
torch.nanmean(input,
7+
dim=None,
8+
keepdim=False,
9+
dtype=None,
10+
out=None)
11+
```
12+
13+
### [paddle.nanmean](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nanmean_cn.html)
14+
15+
```python
16+
paddle.nanmean(x,
17+
axis=None,
18+
keepdim=False,
19+
name=None)
20+
```
21+
22+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
23+
### 参数映射
24+
| PyTorch | PaddlePaddle | 备注 |
25+
| ------------- | ------------ | ------------------------------------------------------ |
26+
| input | x | 表示输入的 Tensor,仅参数名不一致。 |
27+
| dim | axis | 表示进行运算的轴,可选项,仅参数名不一致。 |
28+
| keepdim | keepdim | 表示是否保留计算后的维度,可选项。 |
29+
| dtype | - | 指定输出数据类型,可选项,Pytorch 默认值为 None,Paddle 无此参数,需要转写。 |
30+
| out | - | 表示输出的 Tensor,可选项,Paddle 无此参数,需要进行转写。 |
31+
32+
### 转写示例
33+
34+
#### dytpe:指定数据类型
35+
36+
```python
37+
# Pytorch 写法
38+
torch.nanmean(x, dim=-1, dtype=torch.float32,out=y)
39+
40+
# Paddle 写法
41+
paddle.assign(paddle.nanmean(x.astype('float32'),dim=-1),y)
42+
```
43+
44+
#### out:指定输出
45+
46+
```python
47+
# Pytorch 写法
48+
torch.nanmean(t, dim=1out=y)
49+
50+
# Paddle 写法
51+
paddle.assign(paddle.nanmean(t, dim=1), y)
52+
```
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
## [ torch 参数更多 ]torch.take_along_dim
2+
3+
### [torch.take_along_dim](https://pytorch.org/docs/1.13/generated/torch.take_along_dim.html?highlight=torch+take_along_dim#torch.take_along_dim)
4+
5+
```python
6+
torch.take_along_dim(input,
7+
indices,
8+
dim,
9+
out=None)
10+
```
11+
12+
### [paddle.take_along_axis](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/take_along_axis_cn.html)
13+
14+
```python
15+
paddle.take_along_axis(arr,
16+
indices,
17+
axis)
18+
```
19+
20+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
21+
### 参数映射
22+
| PyTorch | PaddlePaddle | 备注 |
23+
| ------------- | ------------ | ------------------------------------------------------ |
24+
| input | arr | 表示输入的 Tensor ,仅参数名不一致。 |
25+
| indices | indices | 表示索引矩阵 ,仅参数名不一致。 |
26+
| dim | axis | 表示沿着哪个维度获取对应的值,仅参数名不一致。 |
27+
| out | - | 表示输出的 Tensor , Paddle 无此参数,需要进行转写。 |
28+
29+
### 转写示例
30+
31+
#### out:指定输出
32+
33+
```python
34+
# Pytorch 写法
35+
torch.take_along_dim(t, idx, 1out=y)
36+
37+
# Paddle 写法
38+
paddle.assign(paddle.take_along_axis(t, idx, 1), y)
39+
```
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
## [ torch 参数更多 ]torch.special.erf
2+
3+
### [torch.special.erf](https://pytorch.org/docs/1.13/special.html?highlight=torch+special+erf#torch.special.erf)
4+
5+
```python
6+
torch.special.erf(input,
7+
out=None)
8+
```
9+
10+
### [paddle.erf](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/erf_cn.html)
11+
12+
```python
13+
paddle.erf(x,
14+
name=None)
15+
```
16+
17+
Pytorch 相比 Paddle 支持更多其他参数,具体如下:
18+
### 参数映射
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------- | ------------ | ------------------------------------------------------ |
21+
| input | x | 表示输入的 Tensor ,仅参数名不一致。 |
22+
| out | - | 表示输出的 Tensor , Paddle 无此参数,需要进行转写。 |
23+
24+
### 转写示例
25+
26+
#### out:指定输出
27+
28+
```python
29+
# Pytorch 写法
30+
torch.special.erf(t, out=y)
31+
32+
# Paddle 写法
33+
paddle.assign(paddle.erf(t), y)
34+
```

0 commit comments

Comments
 (0)