Add python wrapper for lstm unit op. #6669
```diff
@@ -5,12 +5,13 @@
 from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
 from ..framework import Variable
+from tensor import concat

 __all__ = [
     'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf',
     'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
     'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
-    'batch_norm', 'beam_search_decode', 'conv2d_transpose'
+    'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'lstm_unit'
 ]
```
```diff
@@ -392,7 +393,7 @@ def chunk_eval(input,
                excluded_chunk_types=None,
                **kwargs):
     """
-    This function computes and outputs the precision, recall and
+    This function computes and outputs the precision, recall and
     F1-score of chunk detection.
     """
     helper = LayerHelper("chunk_eval", **kwargs)
```

(The changed docstring line differs only in trailing whitespace.)
```diff
@@ -789,3 +790,110 @@ def conv2d_transpose(input,
         attrs=op_attr)

     return out


+def lstm_unit(x_t,
+              hidden_t_prev,
+              cell_t_prev,
+              forget_bias=0.0,
+              main_program=None,
+              startup_program=None):
+    r"""Lstm unit layer. The equation of an lstm step is:
+
+    .. math::
+
+        i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i)
+
+        f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f)
+
+        c_t & = f_t c_{t-1} + i_t \tanh(W_{x_c}x_t + W_{h_c}h_{t-1} + b_c)
+
+        o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o)
+
+        h_t & = o_t \tanh(c_t)
+
+    The inputs of the lstm unit are :math:`x_t`, :math:`h_{t-1}` and
+    :math:`c_{t-1}`. The implementation separates the linear transformation
+    from the non-linear transformation. Taking :math:`i_t` as an example,
+    the linear transformation is applied by calling an `fc` layer and the
+    equation is:
+
+    .. math::
+
+        L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i
+
+    The non-linear transformation is applied by calling `lstm_unit_op` and
+    the equation is:
+
+    .. math::
+
+        i_t = \sigma(L_{i_t})
+
+    This layer has two outputs: :math:`c_t` and :math:`h_t`.
+
+    Args:
+        x_t (Variable): The input value of the current step.
+        hidden_t_prev (Variable): The hidden value of the lstm unit.
+        cell_t_prev (Variable): The cell value of the lstm unit.
+        forget_bias (float): The forget bias of the lstm unit.
+        main_program (Program): The main program.
+        startup_program (Program): The startup program.
+
+    Returns:
+        tuple: The cell value and hidden value of the lstm unit.
+
+    Raises:
+        ValueError: Raised if the rank of **x_t**, **hidden_t_prev** or
+            **cell_t_prev** is not 2, or if the first dimensions of **x_t**,
+            **hidden_t_prev** and **cell_t_prev** are not the same.
+
+    Examples:
+
+        .. code-block:: python
+
+            x_t = fluid.layers.fc(input=x_t_data, size=10)
+            prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20)
+            prev_cell = fluid.layers.fc(input=prev_cell_data, size=30)
+            cell_value, hidden_value = fluid.layers.lstm_unit(
+                x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)
+    """
+    helper = LayerHelper('lstm_unit', **locals())
+
+    if len(x_t.shape) != 2:
+        raise ValueError("Rank of x_t must be 2.")
```

Review comments on the shape checks:

- "should be removed, this will be checked by infershape"
- "please remove all the shape checks; if needed, they can be added into the infershape of the operators."
- "To make the exception more accurate, I think shape checking is also necessary here."
```diff
+    if len(hidden_t_prev.shape) != 2:
+        raise ValueError("Rank of hidden_t_prev must be 2.")
+
+    if len(cell_t_prev.shape) != 2:
+        raise ValueError("Rank of cell_t_prev must be 2.")
+
+    if x_t.shape[0] != hidden_t_prev.shape[0] or \
+            x_t.shape[0] != cell_t_prev.shape[0]:
+        raise ValueError("The 1st dimension of x_t, hidden_t_prev and "
+                         "cell_t_prev must be the same.")
+
+    size = cell_t_prev.shape[1]
+    concat_out = concat(
+        input=[x_t, hidden_t_prev],
+        axis=1,
+        main_program=main_program,
+        startup_program=startup_program)
+    fc_out = fc(input=concat_out,
+                size=4 * size,
+                main_program=main_program,
+                startup_program=startup_program)
+    dtype = x_t.dtype
+    c = helper.create_tmp_variable(dtype)
+    h = helper.create_tmp_variable(dtype)
+
+    helper.append_op(
+        type='lstm_unit',
+        inputs={"X": fc_out,
+                "C_prev": cell_t_prev},
+        outputs={"C": c,
+                 "H": h},
+        attrs={"forget_bias": forget_bias})
+
+    return c, h
```
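To make the decomposition in the docstring concrete, here is a minimal NumPy sketch of one full step: the linear part (`fc` over the concatenation of `x_t` and the previous hidden state) followed by the element-wise gate math that the `lstm_unit` op performs. The gate ordering `(i, f, o, g)` and the single shared weight matrix `W` are assumptions for illustration; the authoritative semantics live in the `fc` layer and the C++ `lstm_unit` operator.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step_reference(x_t, h_prev, c_prev, W, b, forget_bias=0.0):
    # Linear part: one fc over concat([x_t, h_prev]) producing 4 * D values,
    # mirroring concat(...) + fc(size=4 * size) in the wrapper above.
    fc_out = np.concatenate([x_t, h_prev], axis=1).dot(W) + b
    # Non-linear part: element-wise gate math, as in the lstm_unit op.
    # Gate ordering (i, f, o, g) is assumed here for illustration only.
    i, f, o, g = np.split(fc_out, 4, axis=1)  # each [batch, D]
    c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return c, h


batch, x_dim, D = 2, 5, 3
x_t = np.random.randn(batch, x_dim)
h_prev = np.random.randn(batch, D)
c_prev = np.random.randn(batch, D)
W = np.random.randn(x_dim + D, 4 * D)  # fc weight: [x_dim + D, 4 * D]
b = np.zeros(4 * D)                    # fc bias
c, h = lstm_step_reference(x_t, h_prev, c_prev, W, b)
assert c.shape == (batch, D) and h.shape == (batch, D)
```

This mirrors the docstring's point that the Python wrapper only wires up `concat` and `fc`; all gate non-linearities are delegated to the operator.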
Review comment on the `fc` call:

- "add ParamAttr for fc's weight and bias"
- Reply: "Done."
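The follow-up commit is not shown in this view; below is a minimal sketch of what the requested change could look like, assuming the inner `fc` layer accepts `param_attr` and `bias_attr` arguments (the parameter names are illustrative, not taken from the PR's actual follow-up):

```python
def lstm_unit(x_t,
              hidden_t_prev,
              cell_t_prev,
              forget_bias=0.0,
              param_attr=None,  # hypothetical: ParamAttr for the fc weight
              bias_attr=None,   # hypothetical: ParamAttr for the fc bias
              main_program=None,
              startup_program=None):
    ...
    # Thread the attributes through so callers control the initialization
    # of the gate projection's weight and bias:
    fc_out = fc(input=concat_out,
                size=4 * size,
                param_attr=param_attr,
                bias_attr=bias_attr,
                main_program=main_program,
                startup_program=startup_program)
    ...
```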