
Adding two variations of Japanese-BERT #32


Merged · 4 commits · Apr 11, 2023
Changes shown from 2 of the 4 commits
12 changes: 9 additions & 3 deletions +bert/+internal/convertModelNameToDirectories.m
@@ -2,12 +2,16 @@
% convertModelNameToDirectories Converts the user facing model name to
% the directory name used by support files.

% Copyright 2021 The MathWorks, Inc.
% Copyright 2021-2023 The MathWorks, Inc.
arguments
name (1,1) string
end
modelName = userInputToSupportFileName(name);
dirpath = {"data","networks","bert",modelName};
bertBaseLocation = "bert";
if contains(name,"japanese")
bertBaseLocation = "ja_" + bertBaseLocation;
end
dirpath = {"data","networks",bertBaseLocation,modelName};
end

function supportfileName = userInputToSupportFileName(name)
@@ -26,5 +30,7 @@
"medium", "uncased_L8_H512_A8";
"small", "uncased_L4_H512_A8";
"mini", "uncased_L4_H256_A4";
"tiny", "uncased_L2_H128_A2"];
"tiny", "uncased_L2_H128_A2";
"japanese-base-wwm", "";
"japanese-base", ""];
end
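
The net effect of this file's change is that the two new model names are routed to a "ja_"-prefixed support-file location, while their support-file directory names are left as empty placeholders at this point in the PR. A minimal sketch of the resulting behavior, assuming the internal function is called directly (illustrative, not part of the diff):

    % Hypothetical call showing the new mapping.
    dirpath = bert.internal.convertModelNameToDirectories("japanese-base-wwm");
    % dirpath is {"data","networks","ja_bert",""}; the "ja_" prefix keeps
    % the Japanese models separate from the existing English BERT files.
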
41 changes: 23 additions & 18 deletions +bert/+tokenizer/+internal/BasicTokenizer.m
@@ -1,7 +1,7 @@
classdef BasicTokenizer < bert.tokenizer.internal.Tokenizer
% BasicTokenizer Perform basic tokenization.

% Copyright 2020 The MathWorks, Inc.
% Copyright 2020-2023 The MathWorks, Inc.

properties(SetAccess=private)
IgnoreCase
@@ -28,24 +28,29 @@
function tokens = tokenize(this,text)
arguments
this (1,1) bert.tokenizer.internal.BasicTokenizer
text (1,1) string
text (1,:) string
end
u = textanalytics.unicode.UTF32(text);
u = this.cleanText(u);
u = this.tokenizeCJK(u);
text = u.string();
if this.IgnoreCase
text = lower(text);
text = textanalytics.unicode.nfd(text);
end
u = textanalytics.unicode.UTF32(text);
cats = u.characterCategories('Granularity','detailed');
if this.IgnoreCase
[u,cats] = this.stripAccents(u,cats);
tokens = cell(1,numel(text));
for i = 1:numel(text)
thisText = text(i);
u = textanalytics.unicode.UTF32(thisText);
u = this.cleanText(u);
u = this.tokenizeCJK(u);
thisText = u.string();
if this.IgnoreCase
thisText = lower(thisText);
thisText = textanalytics.unicode.nfd(thisText);
end
u = textanalytics.unicode.UTF32(thisText);
cats = u.characterCategories('Granularity','detailed');
if this.IgnoreCase
[u,cats] = this.stripAccents(u,cats);
end
theseTokens = this.splitOnPunc(u,cats);
theseTokens = join(cat(2,theseTokens{:})," ");
theseTokens = this.whiteSpaceTokenize(theseTokens);
tokens{i} = theseTokens;
end
tokens = this.splitOnPunc(u,cats);
tokens = join(cat(2,tokens{:})," ");
tokens = this.whiteSpaceTokenize(tokens);
end
end

@@ -160,4 +165,4 @@
inRange(udata,123,126);
cats = string(cats);
tf = (tf)|(cats.startsWith("P"));
end
end
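
The rewritten tokenize method is the batching change in this PR: the input widens from a (1,1) to a (1,:) string, each element is cleaned, CJK-split, optionally lowercased and accent-stripped, and punctuation-split independently, and the output becomes a cell array holding one string array of tokens per input element. A minimal usage sketch, assuming the internal class is constructed with its defaults (illustrative, not part of the diff):

    % Hypothetical usage of the batched API.
    tok = bert.tokenizer.internal.BasicTokenizer();
    toks = tok.tokenize(["Hello, World!" "Another document."]);
    % toks is a 1x2 cell array; each cell holds the string array of
    % tokens for the corresponding input element.
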
57 changes: 41 additions & 16 deletions +bert/+tokenizer/+internal/FullTokenizer.m
@@ -5,17 +5,24 @@
% using the vocabulary specified in the newline delimited txt file
% vocabFile.
%
% tokenizer = FullTokenizer(vocabFile,'IgnoreCase',tf) controls if
% the FullTokenizer is case sensitive or not. The default value for
% tf is true.
% tokenizer = FullTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...)
% specifies the optional parameter name/value pairs:
%
% 'BasicTokenizer' - Tokenizer used to split text into words.
% If not specified, a default
% BasicTokenizer is constructed.
%
% 'IgnoreCase' - A logical value to control if the
% FullTokenizer is case sensitive or not.
% The default value is true.
%
% FullTokenizer methods:
% tokenize - tokenize text
% encode - encode tokens
% decode - decode encoded tokens
%
% Example:
% % Save a file named vocab.txt with the text on the next 3 lines:
% % Save a file named fakeVocab.txt with the text on the next 3 lines:
% fake
% vo
% ##cab
@@ -30,7 +37,7 @@
% % This returns the encoded form of the tokens - each token is
% % replaced by its corresponding line number in the fakeVocab.txt

% Copyright 2021 The MathWorks, Inc.
% Copyright 2021-2023 The MathWorks, Inc.

properties(Access=private)
Basic
@@ -46,17 +53,24 @@
% using the vocabulary specified in the newline delimited txt file
% vocabFile.
%
% tokenizer = FullTokenizer(vocabFile,'IgnoreCase',tf) controls if
% the FullTokenizer is case sensitive or not. The default value for
% tf is true.
% tokenizer = FullTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...) specifies
% the optional parameter name/value pairs:
%
% 'BasicTokenizer' - Tokenizer used to split text into words.
% If not specified, a default
% BasicTokenizer is constructed.
%
% 'IgnoreCase' - A logical value to control if the
% FullTokenizer is case sensitive or not.
% The default value is true.
%
% FullTokenizer methods:
% tokenize - tokenize text
% encode - encode tokens
% decode - decode encoded tokens
%
% Example:
% % Save a file named vocab.txt with the text on the next 3 lines:
% % Save a file named fakeVocab.txt with the text on the next 3 lines:
% fake
% vo
% ##cab
@@ -72,9 +86,16 @@
% % replaced by its corresponding line number in the fakeVocab.txt
arguments
vocab
nvp.BasicTokenizer = []
nvp.IgnoreCase = true
end
this.Basic = bert.tokenizer.internal.BasicTokenizer('IgnoreCase',nvp.IgnoreCase);
if isempty(nvp.BasicTokenizer)
% Default case
this.Basic = bert.tokenizer.internal.BasicTokenizer('IgnoreCase',nvp.IgnoreCase);
else
assert(isa(nvp.BasicTokenizer,'bert.tokenizer.internal.Tokenizer'),"BasicTokenizer must be a bert.tokenizer.internal.Tokenizer implementation.");
this.Basic = nvp.BasicTokenizer;
end
this.WordPiece = bert.tokenizer.internal.WordPieceTokenizer(vocab);
this.Encoding = this.WordPiece.Vocab;
end
@@ -85,12 +106,16 @@
% tokens = tokenize(tokenizer,text) tokenizes the input
% string text using the FullTokenizer specified by tokenizer.
basicToks = this.Basic.tokenize(txt);
basicToksUnicode = textanalytics.unicode.UTF32(basicToks);
subToks = cell(numel(basicToks),1);
for i = 1:numel(basicToks)
subToks{i} = this.WordPiece.tokenize(basicToksUnicode(i));
basicToksUnicode = cellfun(@textanalytics.unicode.UTF32,basicToks,UniformOutput=false);
toks = cell(numel(txt),1);
for i = 1:numel(txt)
theseBasicToks = basicToksUnicode{i};
theseSubToks = cell(numel(theseBasicToks),1);
for j = 1:numel(theseBasicToks)
theseSubToks{j} = this.WordPiece.tokenize(theseBasicToks(j));
end
toks{i} = cat(2,theseSubToks{:});
end
toks = cat(2,subToks{:});
end

function idx = encode(this,tokens)
@@ -109,4 +134,4 @@
tokens = this.Encoding.ind2word(x);
end
end
end
end
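
The new 'BasicTokenizer' option makes the word-level stage of FullTokenizer pluggable while the WordPiece stage stays fixed, which is what allows a Japanese word tokenizer to be swapped in. A sketch of the injection point, assuming a fakeVocab.txt like the one in the help text exists on the path (illustrative, not part of the diff):

    % Hypothetical construction with an injected word-level tokenizer.
    basic = bert.tokenizer.internal.TokenizedDocumentTokenizer();
    tok = bert.tokenizer.internal.FullTokenizer("fakeVocab.txt", ...
        'BasicTokenizer',basic);
    toks = tok.tokenize("fake vocab");
    % toks{1} should be ["fake" "vo" "##cab"].
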
36 changes: 36 additions & 0 deletions +bert/+tokenizer/+internal/TokenizedDocumentTokenizer.m
@@ -0,0 +1,36 @@
classdef TokenizedDocumentTokenizer < bert.tokenizer.internal.Tokenizer
% TokenizedDocumentTokenizer Implements a word-level tokenizer using
% tokenizedDocument.

% Copyright 2023 The MathWorks, Inc.

properties
TokenizedDocumentOptions
IgnoreCase
end

methods
function this = TokenizedDocumentTokenizer(varargin,args)
arguments(Repeating)
varargin
end
arguments
args.IgnoreCase (1,1) logical = true
end
this.IgnoreCase = args.IgnoreCase;
this.TokenizedDocumentOptions = varargin;
end

function toks = tokenize(this,txt)
arguments
this
txt (1,:) string
end
if this.IgnoreCase
txt = lower(txt);
end
t = tokenizedDocument(txt,this.TokenizedDocumentOptions{:});
toks = doc2cell(t);
end
end
end
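
TokenizedDocumentTokenizer lower-cases the input when IgnoreCase is true and otherwise delegates entirely to tokenizedDocument, forwarding its stored constructor arguments, so any name/value pair tokenizedDocument accepts can be passed through; doc2cell then returns one cell of word-level tokens per input document. A minimal sketch (requires Text Analytics Toolbox; illustrative, not part of the diff):

    % Hypothetical usage with a tokenizedDocument option passed through.
    tok = bert.tokenizer.internal.TokenizedDocumentTokenizer('Language','ja');
    toks = tok.tokenize("こんにちは世界");
    % toks is a 1x1 cell containing a string array of Japanese word tokens.
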
45 changes: 32 additions & 13 deletions +bert/+tokenizer/BERTTokenizer.m
@@ -9,9 +9,16 @@
% case-insensitive BERTTokenizer using the file vocabFile as
% the vocabulary.
%
% tokenizer = BERTTokenizer(vocabFile,'IgnoreCase',tf)
% Constructs a BERTTokenizer which is case-sensitive or not
% according to the scalar logical tf. The default is true.
% tokenizer = BERTTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...)
% specifies the optional parameter name/value pairs:
%
% 'IgnoreCase' - A logical value to control if the
% BERTTokenizer is case sensitive or not.
% The default value is true.
%
% 'FullTokenizer' - The underlying word-piece tokenizer.
% If not specified, a default
% FullTokenizer is constructed.
%
% BERTTokenizer properties:
% FullTokenizer - The underlying word-piece tokenizer.
Expand All @@ -34,7 +41,7 @@
% tokenizer = bert.tokenizer.BERTTokenizer();
% sequences = tokenizer.encode("Hello World!")

% Copyright 2021 The MathWorks, Inc.
% Copyright 2021-2023 The MathWorks, Inc.

properties(Constant)
PaddingToken = "[PAD]"
@@ -63,9 +70,16 @@
% case-insensitive BERTTokenizer using the file vocabFile as
% the vocabulary.
%
% tokenizer = BERTTokenizer(vocabFile,'IgnoreCase',tf)
% Constructs a BERTTokenizer which is case-sensitive or not
% according to the scalar logical tf. The default is true.
% tokenizer = BERTTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...)
% specifies the optional parameter name/value pairs:
%
% 'IgnoreCase' - A logical value to control if the
% BERTTokenizer is case sensitive or not.
% The default value is true.
%
% 'FullTokenizer' - The underlying word-piece tokenizer.
% If not specified, a default
% FullTokenizer is constructed.
%
% BERTTokenizer properties:
% FullTokenizer - The underlying word-piece tokenizer.
@@ -90,9 +104,15 @@
arguments
vocabFile (1,1) string {mustBeFile} = bert.internal.getSupportFilePath("base","vocab.txt")
nvp.IgnoreCase (1,1) logical = true
nvp.FullTokenizer = []
end
if isempty(nvp.FullTokenizer)
ignoreCase = nvp.IgnoreCase;
this.FullTokenizer = bert.tokenizer.internal.FullTokenizer(vocabFile,'IgnoreCase',ignoreCase);
else
assert(isa(nvp.FullTokenizer,'bert.tokenizer.internal.FullTokenizer'),"FullTokenizer must be a bert.tokenizer.internal.FullTokenizer.");
this.FullTokenizer = nvp.FullTokenizer;
end
ignoreCase = nvp.IgnoreCase;
this.FullTokenizer = bert.tokenizer.internal.FullTokenizer(vocabFile,'IgnoreCase',ignoreCase);
this.PaddingCode = this.FullTokenizer.encode(this.PaddingToken);
this.SeparatorCode = this.FullTokenizer.encode(this.SeparatorToken);
this.StartCode = this.FullTokenizer.encode(this.StartToken);
@@ -131,10 +151,9 @@
inputShape = size(text_a);
text_a = reshape(text_a,[],1);
text_b = reshape(text_b,[],1);
tokenize = @(text) this.FullTokenizer.tokenize(text);
tokens = arrayfun(tokenize,text_a,'UniformOutput',false);
tokens = this.FullTokenizer.tokenize(text_a);
if ~isempty(text_b)
tokens_b = arrayfun(tokenize,text_b,'UniformOutput',false);
tokens_b = this.FullTokenizer.tokenize(text_b);
tokens = cellfun(@(tokens_a,tokens_b) [tokens_a,this.SeparatorToken,tokens_b], tokens, tokens_b, 'UniformOutput', false);
end
tokens = cellfun(@(tokens) [this.StartToken, tokens, this.SeparatorToken], tokens, 'UniformOutput', false);
@@ -218,4 +237,4 @@
text = cellfun(@(x) join(x," "), tokens);
end
end
end
end
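
With the 'FullTokenizer' option, the pieces above can be wired together end to end for the Japanese models. A sketch of the assembly this PR enables, assuming a Japanese vocabulary file ja_vocab.txt exists on the path (the file name and options are illustrative; the PR's support files supply the real ones):

    % Hypothetical end-to-end wiring for a Japanese tokenizer.
    basic = bert.tokenizer.internal.TokenizedDocumentTokenizer('Language','ja');
    full = bert.tokenizer.internal.FullTokenizer("ja_vocab.txt", ...
        'BasicTokenizer',basic);
    tok = bert.tokenizer.BERTTokenizer("ja_vocab.txt",'FullTokenizer',full);
    sequences = tok.encode("こんにちは世界!");
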