@@ -75,18 +75,18 @@ class BertPooler(Layer):
     """
     """
 
-    def __init__(self, hidden_size, with_pool):
+    def __init__(self, hidden_size, pool_act="tanh"):
         super(BertPooler, self).__init__()
         self.dense = nn.Linear(hidden_size, hidden_size)
         self.activation = nn.Tanh()
-        self.with_pool = with_pool
+        self.pool_act = pool_act
 
     def forward(self, hidden_states):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
         pooled_output = self.dense(first_token_tensor)
-        if self.with_pool == 'tanh':
+        if self.pool_act == "tanh":
             pooled_output = self.activation(pooled_output)
         return pooled_output
 
@@ -372,7 +372,7 @@ def __init__(self,
                  type_vocab_size=16,
                  initializer_range=0.02,
                  pad_token_id=0,
-                 with_pool='tanh'):
+                 pool_act="tanh"):
         super(BertModel, self).__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
@@ -388,7 +388,7 @@ def __init__(self,
             attn_dropout=attention_probs_dropout_prob,
             act_dropout=0)
         self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
-        self.pooler = BertPooler(hidden_size, with_pool)
+        self.pooler = BertPooler(hidden_size, pool_act)
         self.apply(self.init_weights)
 
     def forward(self,
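
For reference, below is a minimal, self-contained sketch of the pooler as it reads after this diff. It assumes PaddlePaddle's paddle.nn (the surrounding file subclasses Layer and uses nn.Linear/nn.Tanh/nn.TransformerEncoder); the hidden size and tensor shapes in the usage lines are illustrative, not taken from the PR.

import paddle
import paddle.nn as nn

class BertPooler(nn.Layer):
    def __init__(self, hidden_size, pool_act="tanh"):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.pool_act = pool_act

    def forward(self, hidden_states):
        # Pool by taking the hidden state of the first ([CLS]) token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        # Any value other than "tanh" now skips the activation entirely,
        # which is the behavior the renamed `pool_act` argument makes explicit.
        if self.pool_act == "tanh":
            pooled_output = self.activation(pooled_output)
        return pooled_output

pooler = BertPooler(hidden_size=768, pool_act="tanh")
hidden_states = paddle.randn([2, 16, 768])  # [batch, seq_len, hidden]; illustrative shape
pooled = pooler(hidden_states)              # -> shape [2, 768]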