
Commit 59abf30

Merge pull request #3 from matlab-deep-learning/batchModel
Batched model
2 parents 144ec6b + 2ed3b1c commit 59abf30

10 files changed: +163 / -104 lines


+gpt2/model.m

Lines changed: 24 additions & 19 deletions
@@ -2,15 +2,16 @@
 % model   A GPT-2 model
 %
 % [logits, presents] = model(X, pasts, parameters) performs prediction
-% with a GPT-2 model on the input X. X is a 1-by-numInputSubwords array
-% of tokenized text, and the model returns an array logits that is
-% 50257-by-numInputSubwords. This array can be used to predict the next
-% subword. See below for more details of inputs and outputs.
+% with a GPT-2 model on the input X. X is a
+% 1-by-numInputSubwords-by-numObs array of tokenized text, and the model
+% returns an array logits that is 50257-by-numInputSubwords-by-numObs.
+% This array can be used to predict the next subword. See below for more
+% details of inputs and outputs.
 %
 % Inputs:
-%   X        - A 1-by-numInputSubwords array. This array is a
-%              tokenized sentence. It should be created using
-%              the tokenizer for GPT-2.
+%   X        - A 1-by-numInputSubwords-by-numObs array. This
+%              array is a tokenized sentence. It should be
+%              created using the tokenizer for GPT-2.
 %   pasts    - A numLayers-by-1 cell array containing "keys"
 %              and "values" for the attention layers. These
 %              come from the previous subwords in the text we
@@ -40,14 +41,14 @@
 %              normalization.
 %
 % Outputs:
-%   logits   - A 50257-by-numInputSubwords array of logits
-%              (pre-softmax outputs). If we apply softmax to
-%              this array, we get the probabilities for the
-%              next subword. However, we usually want to do
-%              more pre-processing before doing this (like
-%              taking the top-K entries). 50257 is the number
-%              of subwords in the vocabulary for GPT-2's
-%              tokenizer.
+%   logits   - A 50257-by-numInputSubwords-by-numObs array of
+%              logits (pre-softmax outputs). If we apply
+%              softmax to this array, we get the probabilities
+%              for the next subword. However, we usually want
+%              to do more pre-processing before doing this
+%              (like taking the top-K entries). 50257 is the
+%              number of subwords in the vocabulary for
+%              GPT-2's tokenizer.
 %   presents - A numLayers-by-1 cell array containing "keys"
 %              and "values" from the attention blocks. We feed
 %              these back in as the 'pasts' input.
@@ -57,9 +58,13 @@

 % Apply the embedding. If there are inputs for the "past", we need to
 % offset the position embedding to account for this.
+% Word embedding
+seqLen = size(X, 2);
+h = weights.wte_0(:, X);
+h = reshape(h, size(h,1), seqLen, []);
+% Positional embedding
 positionOffset = size(pasts{1},2);
-h = weights.wte_0( :,X ) + ...
-    weights.wpe_0( :, positionOffset + (1:length(X)) );
+h = h + weights.wpe_0(:, positionOffset + (1:seqLen) );

 % Run the layers
 presents = cell(hyperparameters.NumLayers,1);
@@ -74,7 +79,7 @@
     weights.ln_f_g_0, ...
     weights.ln_f_b_0 );

-% Calculate logits (50257-by-numInputSubwords)
-logits = weights.wte_0'*h;
+% Calculate logits (50257-by-numInputSubwords-by-numObs)
+logits = dlmtimes(weights.wte_0', h);

 end
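
The embedding change above splits the old one-line lookup into an explicit token-embedding step, a reshape that restores the observation dimension, and a shared position embedding. A minimal sketch of that indexing pattern, using small made-up sizes in place of the real wte_0/wpe_0 weights:

% Illustrative only: toy sizes, plain arrays instead of dlarray.
embedDim  = 8;                      % hidden size
vocabSize = 20;                     % 50257 in GPT-2
maxPos    = 16;                     % maximum sequence position
wte = rand(embedDim, vocabSize);    % token embeddings
wpe = rand(embedDim, maxPos);       % position embeddings

X = randi(vocabSize, [1 5 3]);      % 1-by-seqLen-by-numObs token indices
seqLen = size(X, 2);

% Token embedding: column indexing flattens the index array, so a reshape
% restores the embedDim-by-seqLen-by-numObs layout.
h = wte(:, X);
h = reshape(h, embedDim, seqLen, []);

% The position embedding is shared by all observations and broadcasts
% over the third (observation) dimension.
positionOffset = 0;                 % size(pasts{1},2) in the real model
h = h + wpe(:, positionOffset + (1:seqLen));
size(h)                             % embedDim-by-seqLen-by-numObs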

+transformer/+layer/attention.m

Lines changed: 23 additions & 24 deletions
@@ -6,14 +6,14 @@
 % 2 in [1]. See below for details of inputs and outputs.
 %
 % Inputs:
-%   X       - A (numFeatures*numHeads)-by-numInputSubwords
+%   X       - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
 %             input array.
-%   past    - A numFeatures-by-numPastSubwords-by-numHeads-by-2
+%   past    - A numFeatures-by-numPastSubwords-by-numHeads-by-numObs-by-2
 %             array. This contains the 'keys' and 'values' for
 %             past subwords. These are needed to predict future
 %             outputs in an autoregressive manner. 'keys' are
-%             stored in past(:,:,:,1) and 'values' are stored
-%             in past(:,:,:,2).
+%             stored in past(:,:,:,:,1) and 'values' are stored
+%             in past(:,:,:,:,2).
 %   weights - The weights for the full multi-head attention
 %             block stored in a struct. This includes:
 %             - attn_c_attn_w_0: A weight matrix for the
@@ -28,15 +28,15 @@
 %             hyper-parameter.
 %
 % Outputs:
-%   Z       - A (numFeatures*numHeads)-by-numInputSubwords
+%   Z       - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
 %             output array.
-%   present - A numFeatures-by-numAllSubwords-by-numHeads-by-2
+%   present - A numFeatures-by-numAllSubwords-by-numHeads-by-numObs-by-2
 %             array. This contains the 'keys' and 'values' that
 %             are created from inputs. These need to passed
 %             back in as the 'past' input if we want to predict
 %             future outputs in an autoregressive manner. 'keys'
-%             are stored in present(:,:,:,1) and 'values' are
-%             stored in present(:,:,:,2).
+%             are stored in present(:,:,:,:,1) and 'values' are
+%             stored in present(:,:,:,:,2).
 %
 % References:
 %
@@ -52,9 +52,9 @@

 % Split the results into Q (Query), K (Keys) and V (Values).
 splitSize = size(C,1)/3;
-Q = C(1:splitSize,:);
-K = C((splitSize+1):(2*splitSize),:);
-V = C((2*splitSize+1):(3*splitSize),:);
+Q = C(1:splitSize,:,:);
+K = C((splitSize+1):(2*splitSize),:,:);
+V = C((2*splitSize+1):(3*splitSize),:,:);

 % Split heads
 Q = iSplitHeads(Q, splitSize, hyperParameters.NumHeads);
@@ -63,16 +63,16 @@

 % Use the past
 if ~isempty(past)
-    PK = past(:,:,:,1);
-    PV = past(:,:,:,2);
+    PK = past(:,:,:,:,1);
+    PV = past(:,:,:,:,2);
     K = cat(2,PK,K);
     V = cat(2,PV,V);
 end

 % Set present. Note that this is done differently from the original
 % implementation which sets the value of present before the previous if
-% statement.
-present = cat(4,K,V);
+% statement
+present = cat(5,K,V);

 A = transformer.layer.multiheadAttention(Q,K,V);

@@ -81,23 +81,22 @@
 A = transformer.layer.convolution1d( A, ...
     weights.attn_c_proj_w_0, ...
     weights.attn_c_proj_b_0 );
-
 end

 function Z = iSplitHeads(X, splitSize, numHeads)
 % We permute the data to put the dimension for the heads last, so that we
 % can use batched matrix multiplication to compute attention for all of the
 % heads at once.
 %
-% X - A (numFeatures*numHeads)-by-numSubwords array.
-% Z - A numFeatures-by-numSubwords-by-numHeads array.
-X = reshape(X, splitSize/numHeads, numHeads, []);
-Z = permute(X,[1 3 2]);
+% X - A (numFeatures*numHeads)-by-numSubwords-by-numObs array.
+% Z - A numFeatures-by-numSubwords-by-numHeads-by-numObs array.
+X = reshape(X, splitSize/numHeads, numHeads, [], size(X,3));
+Z = permute(X,[1 3 2 4]);
 end

 function Z = iMergeHeads(X)
-% X - A numFeatures-by-numSubwords-by-numHeads array.
-% Z - A (numFeatures*numHeads)-by-numSubwords array.
-X = permute(X, [1 3 2]);
-Z = reshape(X, size(X,1)*size(X,2), []);
+% X - A numFeatures-by-numSubwords-by-numHeads-by-numObs array.
+% Z - A (numFeatures*numHeads)-by-numSubwords-by-numObs array.
+X = permute(X, [1 3 2 4]);
+Z = reshape(X, size(X,1)*size(X,2), [], size(X,4));
 end
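
The new trailing dimension in iSplitHeads and iMergeHeads is what carries observations through the batched attention. A self-contained round-trip sketch of that reshape/permute pattern on toy sizes (plain double arrays rather than dlarray):

numFeatures = 4; numHeads = 3; numSubwords = 5; numObs = 2;
X = rand(numFeatures*numHeads, numSubwords, numObs);

% Split: (numFeatures*numHeads)-by-S-by-N -> numFeatures-by-S-by-numHeads-by-N
Xh = reshape(X, numFeatures, numHeads, [], numObs);
Xh = permute(Xh, [1 3 2 4]);

% Merge: numFeatures-by-S-by-numHeads-by-N -> (numFeatures*numHeads)-by-S-by-N
Y = permute(Xh, [1 3 2 4]);
Y = reshape(Y, numFeatures*numHeads, [], numObs);

% Splitting then merging reproduces the input exactly.
assert(isequal(X, Y));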

+transformer/+layer/convolution1d.m

Lines changed: 1 addition & 1 deletion
@@ -13,6 +13,6 @@
 % Output:
 %   Z - A numOutputFeatures-by-numInputSubwords array.

-Z = W*X + b;
+Z = dlmtimes(W,X) + b;

 end
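
This one-line change is the heart of batching: plain * only multiplies 2-D matrices, so it errors once X carries a third (observation) dimension, whereas dlmtimes applies the same weight matrix to every page of the input. A rough sketch of that behaviour on ordinary arrays, using a loop and, as a stand-in, pagemtimes (R2020b+), which is assumed here to match what the repo's helper does for this case:

numOut = 3; numIn = 4; seqLen = 5; numObs = 2;
W = rand(numOut, numIn);
X = rand(numIn, seqLen, numObs);
b = rand(numOut, 1);

% Loop form: apply the same weights and bias to each observation.
Z = zeros(numOut, seqLen, numObs);
for k = 1:numObs
    Z(:,:,k) = W * X(:,:,k) + b;
end

% Page-wise form: one call, the bias broadcasts over trailing dimensions.
Z2 = pagemtimes(W, X) + b;
assert(max(abs(Z(:) - Z2(:))) < 1e-12);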

+transformer/+layer/multiheadAttention.m

Lines changed: 6 additions & 8 deletions
@@ -10,14 +10,12 @@
 % below for details.
 %
 % Inputs:
-%   Q - A numFeatures-by-numInputSubWords-by-numHeads array of
-%       queries.
-%   K - A numFeatures-by-numAllSubWords-by-numHeads array of keys.
-%   V - A numFeatures-by-numAllSubWords-by-numHeads array of values.
+%   Q - numFeatures-by-numInputSubWords-by-numHeads-by-numObs array of queries.
+%   K - numFeatures-by-numAllSubWords-by-numHeads-by-numObs array of keys.
+%   V - numFeatures-by-numAllSubWords-by-numHeads-by-numObs array of values.
 %
 % Outputs:
-%   A - A numFeatures-by-numInputSubWords-by-numHeads array of
-%       attention matrices.
+%   A - numFeatures-by-numInputSubWords-by-numHeads-by-numObs array of attention matrices.
 %
 % References:
 %
@@ -29,7 +27,7 @@
 % matrices. W is numAllSubWords-by-numInputSubWords-by-numHeads. Each
 % element of W is the dot product of a query vector from Q and a key vector
 % from K.
-W = dlmtimes(permute(K, [2 1 3]), Q);
+W = dlmtimes(permute(K, [2 1 3 4]), Q);

 % Divide by square root of d
 W = W./sqrt(size(Q,1));
@@ -38,7 +36,7 @@
 W = transformer.layer.maskAttentionWeights(W);

 % Apply softmax
-W = softmax(W, 'DataFormat', 'CTB');
+W = softmax(W, 'DataFormat', 'CTUB');

 % We compute the attention by taking products between the attention weights
 % W and V. A is numFeatures-by-numInputSubWords-by-numHeads. One
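
With the observation dimension added, the attention weights are still K'*Q per head (and now per observation), scaled by sqrt(d) and normalized over the key dimension; 'CTUB' simply labels the four dimensions of W (channel, time, unspecified, batch) so that softmax keeps normalizing over the first one. A rough, self-contained sketch of the same computation on plain arrays, with no causal mask and pagemtimes standing in for dlmtimes:

d = 4; numQuery = 3; numKey = 3; numHeads = 2; numObs = 2;
Q = rand(d, numQuery, numHeads, numObs);
K = rand(d, numKey,   numHeads, numObs);
V = rand(d, numKey,   numHeads, numObs);

% Attention weights: one numKey-by-numQuery matrix per head and observation.
W = pagemtimes(permute(K, [2 1 3 4]), Q) ./ sqrt(d);

% Softmax over the key dimension (the 'C' dimension in 'CTUB').
W = exp(W - max(W, [], 1));
W = W ./ sum(W, 1);

% Weighted sum of values: d-by-numQuery-by-numHeads-by-numObs.
A = pagemtimes(V, W);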

+transformer/+layer/normalization.m

Lines changed: 2 additions & 2 deletions
@@ -5,12 +5,12 @@
 % Layer normzalization is described in [1].
 %
 % Inputs:
-%   X - A numFeatures-by-numInputSubwords input array.
+%   X - A numFeatures-by-numInputSubwords-by-numObs input array.
 %   g - A numFeatures-by-1 weight vector.
 %   b - A numFeatures-by-1 bias vector.
 %
 % Outputs:
-%   Z - A numFeatures-by-numInputSubwords output array.
+%   Z - A numFeatures-by-numInputSubwords-by-numObs output array.
 %
 % References:
 %
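
Only the documented shapes change here: the normalization already works per feature vector, so statistics taken over dimension 1 cover the batched case unchanged. A minimal sketch of layer normalization in that layout (plain arrays, and the epsilon value is an assumption, not taken from the repo):

numFeatures = 6; seqLen = 4; numObs = 3;
X = rand(numFeatures, seqLen, numObs);
g = ones(numFeatures, 1);       % learned per-feature scale
b = zeros(numFeatures, 1);      % learned per-feature offset
epsilon = 1e-5;                 % assumed stabilizer

mu     = mean(X, 1);
sigma2 = var(X, 1, 1);          % population variance over the feature dim
Z = g .* (X - mu) ./ sqrt(sigma2 + epsilon) + b;
size(Z)                         % numFeatures-by-seqLen-by-numObs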

test/gpt2/layer/tblock.m

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ function outputHasInputSizeWithPasts(test,Input)
 % Provide a fake past of sequence length 1
 K_fake = dlarray(rand(C,1));
 V_fake = dlarray(rand(C,1));
-past = cat(4,K_fake,V_fake);
+past = cat(5,K_fake,V_fake);
 [y,present] = test.block(x,past,weights,hyperParameters);
 test.verifySize(y,size(x));
 % The size of presents is the size of past except the sequence
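
The test change mirrors the new storage layout: keys and values now sit along dimension 5 so that dimension 4 is free to hold observations. A small sketch of building and unpacking such a past array on toy sizes (plain arrays rather than dlarray):

numFeatures = 4; numPastSubwords = 1; numHeads = 3; numObs = 2;
K = rand(numFeatures, numPastSubwords, numHeads, numObs);
V = rand(numFeatures, numPastSubwords, numHeads, numObs);

% Stack keys and values along dimension 5 ...
past = cat(5, K, V);

% ... and recover them the way attention.m now does.
PK = past(:,:,:,:,1);
PV = past(:,:,:,:,2);
assert(isequal(PK, K) && isequal(PV, V));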

test/gpt2/tmodel.m

Lines changed: 42 additions & 10 deletions
@@ -7,25 +7,49 @@
         model = @gpt2.model
     end

+    properties(TestParameter)
+        InputData = iGetInputData();
+    end
+
     methods(Test)
-        function canUseModel(test)
-            inputs = test.prepareInputs();
-            test.verifyWarningFree(@() test.model(inputs{:}));
+        function canUseModel(test, InputData)
+            X = InputData;
+            [pasts, parameters] = test.prepareInputs();
+            test.verifyWarningFree(@() test.model(X, pasts, parameters));
+        end
+
+        function canAcceptBatches(test)
+            % gpt2.model should be able to accept multiple observations
+            % with the same sequence length
+
+            % Create inputs
+            [pasts, parameters] = test.prepareInputs();
+            numObs = 4;
+            seqLen = 5;
+            vocabSize = size( parameters.Weights.wte_0, 2 );
+            X = randi(vocabSize, [1 seqLen numObs]);
+
+            % Get batch results
+            Ybatch = test.model(X, pasts, parameters);
+
+            % Iterate over batch
+            YperObs = dlarray(zeros([vocabSize seqLen numObs], 'single'));
+            for i = 1:numObs
+                YperObs(:, :, i) = test.model(X(:, :, i), pasts, parameters);
+            end
+
+            % Verify the results are within a relative tolerance for single
+            % precision data
+            test.verifyEqual(extractdata(Ybatch), extractdata(YperObs), 'RelTol', single(1e-5));
         end
     end

     methods(Access=private)
-        function inputs = prepareInputs(test)
+        function [pasts, parameters] = prepareInputs(test)
            % Convenience method to setup inputs for
            % transformer.model
-            X = test.prepareX();
            parameters = test.prepareParameters();
            pasts = test.preparePasts(parameters.Hyperparameters.NumLayers);
-            inputs = {X,pasts,parameters};
-        end
-
-        function X = prepareX(~)
-            X = dlarray(1);
         end

        function pasts = preparePasts(~,numLayers)
@@ -37,4 +61,12 @@ function canUseModel(test)
            parameters = gpt2.load(parametersFile);
        end
     end
+end
+
+function s = iGetInputData()
+s = struct( ...
+    'SingleToken', dlarray(1), ...
+    'MultiSeqLen', dlarray([1 7 2 9]), ...
+    'MultiSeqLenAndObs', dlarray( permute([1 7 2 9; 7 2 1 9], [3 2 1]) ) ...
+    );
 end
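
The rewritten test class uses a TestParameter property, so canUseModel now runs once per field of the struct returned by iGetInputData: a single token, a longer sequence, and a batch of sequences. A stripped-down sketch of that parameterization mechanism in the matlab.unittest framework, with a hypothetical class name and a trivial stand-in for gpt2.model:

% tExample.m (hypothetical file, for illustration only)
classdef tExample < matlab.unittest.TestCase
    properties (TestParameter)
        % One run of each parameterized test per struct field.
        InputData = struct( ...
            'SingleToken',       1, ...
            'MultiSeqLen',       [1 7 2 9], ...
            'MultiSeqLenAndObs', permute([1 7 2 9; 7 2 1 9], [3 2 1]));
    end

    methods (Test)
        function inputShapeIsPreserved(test, InputData)
            % Stand-in for the model call: an elementwise op keeps the size.
            Y = 2 * InputData;
            test.verifySize(Y, size(InputData));
        end
    end
end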
