Skip to content

Commit b86d5b7

Browse files
authored
change varname,set default value for pool_act (PaddlePaddle#766)
1 parent 243cdd7 commit b86d5b7

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

examples/text_matching/simbert/predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def predict(model, data_loader):
8585
paddle.set_device(args.device)
8686

8787
model = ppnlp.transformers.BertModel.from_pretrained(
88-
'simbert-base-chinese', with_pool='linear')
88+
'simbert-base-chinese', pool_act='linear')
8989
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(
9090
'simbert-base-chinese')
9191

paddlenlp/transformers/bert/modeling.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,18 @@ class BertPooler(Layer):
7575
"""
7676
"""
7777

78-
def __init__(self, hidden_size, with_pool):
78+
def __init__(self, hidden_size, pool_act="tanh"):
7979
super(BertPooler, self).__init__()
8080
self.dense = nn.Linear(hidden_size, hidden_size)
8181
self.activation = nn.Tanh()
82-
self.with_pool = with_pool
82+
self.pool_act = pool_act
8383

8484
def forward(self, hidden_states):
8585
# We "pool" the model by simply taking the hidden state corresponding
8686
# to the first token.
8787
first_token_tensor = hidden_states[:, 0]
8888
pooled_output = self.dense(first_token_tensor)
89-
if self.with_pool == 'tanh':
89+
if self.pool_act == "tanh":
9090
pooled_output = self.activation(pooled_output)
9191
return pooled_output
9292

@@ -372,7 +372,7 @@ def __init__(self,
372372
type_vocab_size=16,
373373
initializer_range=0.02,
374374
pad_token_id=0,
375-
with_pool='tanh'):
375+
pool_act="tanh"):
376376
super(BertModel, self).__init__()
377377
self.pad_token_id = pad_token_id
378378
self.initializer_range = initializer_range
@@ -388,7 +388,7 @@ def __init__(self,
388388
attn_dropout=attention_probs_dropout_prob,
389389
act_dropout=0)
390390
self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
391-
self.pooler = BertPooler(hidden_size, with_pool)
391+
self.pooler = BertPooler(hidden_size, pool_act)
392392
self.apply(self.init_weights)
393393

394394
def forward(self,

0 commit comments

Comments
 (0)