
Commit 6f79126

Merge pull request #6669 from pkuyym/fix-6581
Add python wrapper for lstm unit op.
2 parents cb23c63 + 9573256

File tree: 4 files changed, +142 -12 lines

doc/api/v2/fluid/layers.rst

+7 -6

@@ -188,12 +188,6 @@ beam_search_decode
     :noindex:
 
 
-lstm
----------
-.. autofunction:: paddle.v2.fluid.layers.lstm
-    :noindex:
-
-
 lod_rank_table
 ---------
 .. autofunction:: paddle.v2.fluid.layers.lod_rank_table
@@ -300,7 +294,14 @@ conv2d_transpose
 .. autofunction:: paddle.v2.fluid.layers.conv2d_transpose
     :noindex:
 
+
 sequence_expand
 ---------
 .. autofunction:: paddle.v2.fluid.layers.sequence_expand
     :noindex:
+
+
+lstm_unit
+---------
+.. autofunction:: paddle.v2.fluid.layers.lstm_unit
+    :noindex:

paddle/operators/lstm_unit_op.cc

+4 -1

@@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
   LstmUnitOpMaker(framework::OpProto* proto,
                   framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "FC input before the non-linear activation.");
+    AddInput("X",
+             "Lstm unit only applies non-linear activations, please make sure "
+             "that the linear transformation has already been applied to `X`. "
+             "The linear transformation can be applied by adding a `fc` layer.");
     AddInput(
         "C_prev",
         "The cell state tensor of last time-step in the Lstm Unit operator.");

python/paddle/v2/fluid/layers/nn.py

+113 -4

@@ -5,12 +5,15 @@
 from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
 from ..framework import Variable
+from ..param_attr import ParamAttr
+from tensor import concat
 
 __all__ = [
     'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf',
     'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
     'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
-    'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand'
+    'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
+    'lstm_unit'
 ]
 
 
@@ -761,7 +764,7 @@ def conv2d_transpose(input,
     return out
 
 
-def sequence_expand(x, y, main_program=None, startup_program=None):
+def sequence_expand(x, y):
     """Sequence Expand Layer. This layer will expand the input variable **x**
     according to LoD information of **y**. And the following examples will
     explain how sequence_expand works:
@@ -805,8 +808,6 @@ def sequence_expand(x, y, main_program=None, startup_program=None):
     Args:
         x (Variable): The input variable which is a Tensor or LoDTensor.
         y (Variable): The input variable which is a LoDTensor.
-        main_program (Program): The main program.
-        startup_program (Program): The startup program.
 
     Returns:
         Variable: The expanded variable which is a LoDTensor.
@@ -826,3 +827,111 @@ def sequence_expand(x, y, main_program=None, startup_program=None):
         type='sequence_expand', inputs={'X': x,
                                         'Y': y}, outputs={'Out': tmp})
     return tmp
+
+
+def lstm_unit(x_t,
+              hidden_t_prev,
+              cell_t_prev,
+              forget_bias=0.0,
+              param_attr=None,
+              bias_attr=None):
+    """Lstm unit layer. The equations of an LSTM step are:
+
+        .. math::
+
+            i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i)
+
+            f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f)
+
+            c_t & = f_t c_{t-1} + i_t tanh(W_{x_c}x_t + W_{h_c}h_{t-1} + b_c)
+
+            o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o)
+
+            h_t & = o_t tanh(c_t)
+
+    The inputs of the lstm unit include :math:`x_t`, :math:`h_{t-1}` and
+    :math:`c_{t-1}`. The implementation separates the linear transformation
+    and the non-linear transformation. Here, we take :math:`i_t` as an
+    example. The linear transformation is applied by calling a `fc` layer and
+    the equation is:
+
+        .. math::
+
+            L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i
+
+    The non-linear transformation is applied by calling `lstm_unit_op` and the
+    equation is:
+
+        .. math::
+
+            i_t = \sigma(L_{i_t})
+
+    This layer has two outputs: :math:`h_t` and :math:`c_t`.
+
+    Args:
+        x_t (Variable): The input value of the current step.
+        hidden_t_prev (Variable): The hidden value of the lstm unit.
+        cell_t_prev (Variable): The cell value of the lstm unit.
+        forget_bias (float): The forget bias of the lstm unit.
+        param_attr (ParamAttr): The attributes of parameter weights, used to set
+            initializer, name etc.
+        bias_attr (ParamAttr): The attributes of bias weights. If not False,
+            bias weights will be created and set to the default value.
+
+    Returns:
+        tuple: The hidden value and cell value of the lstm unit.
+
+    Raises:
+        ValueError: If the rank of **x_t**, **hidden_t_prev** or **cell_t_prev**
+            is not 2, or if their 1st dimensions are not the same.
+
+    Examples:
+
+        .. code-block:: python
+
+            x_t = fluid.layers.fc(input=x_t_data, size=10)
+            prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20)
+            prev_cell = fluid.layers.fc(input=prev_cell_data, size=30)
+            hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t,
+                                                       hidden_t_prev=prev_hidden,
+                                                       cell_t_prev=prev_cell)
+    """
+    helper = LayerHelper('lstm_unit', **locals())
+
+    if len(x_t.shape) != 2:
+        raise ValueError("Rank of x_t must be 2.")
+
+    if len(hidden_t_prev.shape) != 2:
+        raise ValueError("Rank of hidden_t_prev must be 2.")
+
+    if len(cell_t_prev.shape) != 2:
+        raise ValueError("Rank of cell_t_prev must be 2.")
+
+    if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[
+            0] != cell_t_prev.shape[0]:
+        raise ValueError("The 1st dimension of x_t, hidden_t_prev and "
+                         "cell_t_prev must be the same.")
+
+    if bias_attr is None:
+        bias_attr = ParamAttr()
+
+    size = cell_t_prev.shape[1]
+    concat_out = concat(input=[x_t, hidden_t_prev], axis=1)
+    fc_out = fc(input=concat_out,
+                size=4 * size,
+                param_attr=param_attr,
+                bias_attr=bias_attr)
+    dtype = x_t.dtype
+    c = helper.create_tmp_variable(dtype)
+    h = helper.create_tmp_variable(dtype)
+
+    helper.append_op(
+        type='lstm_unit',
+        inputs={"X": fc_out,
+                "C_prev": cell_t_prev},
+        outputs={"C": c,
+                 "H": h},
+        attrs={"forget_bias": forget_bias})
+
+    return h, c
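For reference, here is a minimal sketch of how the new wrapper might be wired into a fluid program (it mirrors the unit test below); the layer names, shapes and sizes are illustrative and not part of this commit.

import paddle.v2.fluid as fluid

hidden_dim = 32

# One time step of input plus the previous hidden and cell states. In a real
# model these would typically come from an embedding and an RNN loop rather
# than plain data layers.
x_t = fluid.layers.data(name='x_t', shape=[16], dtype='float32')
prev_hidden = fluid.layers.data(
    name='prev_hidden', shape=[hidden_dim], dtype='float32')
prev_cell = fluid.layers.data(
    name='prev_cell', shape=[hidden_dim], dtype='float32')

# The wrapper concatenates x_t with prev_hidden, applies a single fc layer to
# produce 4 * hidden_dim gate pre-activations, then appends the lstm_unit op.
hidden_t, cell_t = fluid.layers.lstm_unit(
    x_t=x_t,
    hidden_t_prev=prev_hidden,
    cell_t_prev=prev_cell,
    forget_bias=1.0)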

python/paddle/v2/fluid/tests/test_layers.py

+18 -1

@@ -161,7 +161,7 @@ def test_sigmoid_cross_entropy(self):
                 x=dat, label=lbl))
         print(str(program))
 
-    def test_seq_expand(self):
+    def test_sequence_expand(self):
         program = Program()
         with program_guard(program):
             x = layers.data(name='x', shape=[10], dtype='float32')
@@ -170,6 +170,23 @@ def test_seq_expand(self):
             self.assertIsNotNone(layers.sequence_expand(x=x, y=y))
         print(str(program))
 
+    def test_lstm_unit(self):
+        program = Program()
+        with program_guard(program):
+            x_t_data = layers.data(
+                name='x_t_data', shape=[10, 10], dtype='float32')
+            x_t = layers.fc(input=x_t_data, size=10)
+            prev_hidden_data = layers.data(
+                name='prev_hidden_data', shape=[10, 20], dtype='float32')
+            prev_hidden = layers.fc(input=prev_hidden_data, size=20)
+            prev_cell_data = layers.data(
+                name='prev_cell', shape=[10, 30], dtype='float32')
+            prev_cell = layers.fc(input=prev_cell_data, size=30)
+            self.assertIsNotNone(
+                layers.lstm_unit(
+                    x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell))
+        print(str(program))
+
 
 if __name__ == '__main__':
     unittest.main()
