Skip to content

Commit fccd9c0

Browse files
authored
Merge pull request #19 from coreweave/rwang.fixhardcoding.07152024
Refactor GPT_BPE tokenizer file loading and initial processing
2 parents b0e1976 + d260b08 commit fccd9c0

File tree

3 files changed

+276
-163
lines changed

3 files changed

+276
-163
lines changed

gpt_bpe.go

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ func NewMistralEncoder() GPTEncoder {
147147
// Returns a GPTEncoder with the tokenizer data loaded for that vocabulary
148148
// id.
149149
func NewEncoder(vocabId string) (*GPTEncoder, error) {
150-
hfConfig, resourcesPtr, vocabErr := resources.ResolveVocabId(vocabId,
151-
"")
150+
hfConfig, resourcesPtr, vocabErr := resources.ResolveVocabId(vocabId, "")
151+
152152
if vocabErr != nil {
153153
return nil, vocabErr
154154
}
@@ -176,32 +176,6 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
176176
}
177177
}
178178

179-
tokenizerSpecialConfig := resources.TokenizerSpecialsConfig{
180-
AddBosToken: false,
181-
AddEosToken: false,
182-
PadToken: "",
183-
}
184-
altMistralSpecialsConfig := resources.MistralSpecialsConfig{
185-
AddBosToken: false,
186-
AddEosToken: false,
187-
PadToken: "",
188-
}
189-
if special, ok := (rsrcs)["tokenizer_config.json"]; ok {
190-
if special.Data != nil {
191-
err := json.Unmarshal(*special.Data, &tokenizerSpecialConfig)
192-
if err != nil {
193-
err = json.Unmarshal(*special.Data, &altMistralSpecialsConfig)
194-
if err != nil {
195-
log.Fatal("Error unmarshalling tokenizer_config.json")
196-
}
197-
//populate the tokenizerSpecialConfig from the altMistralSpecialsConfig
198-
tokenizerSpecialConfig.AddBosToken = altMistralSpecialsConfig.AddBosToken
199-
tokenizerSpecialConfig.AddEosToken = altMistralSpecialsConfig.AddEosToken
200-
tokenizerSpecialConfig.PadToken = altMistralSpecialsConfig.PadToken
201-
}
202-
}
203-
}
204-
205179
puncRunes := make([]rune, 0)
206180
if specialConfig.PuncRunes != nil {
207181
for _, r := range specialConfig.PuncRunes {
@@ -364,23 +338,28 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
364338
}
365339

366340
if specialConfig.EncloseEosBos {
367-
tokenizerSpecialConfig.AddBosToken = true
368-
tokenizerSpecialConfig.AddEosToken = true
341+
bosBool := true
342+
eosBool := true
343+
hfConfig.AddBosToken = &bosBool
344+
hfConfig.AddEosToken = &eosBool
369345
}
370346

371347
// Add in default pad token if not already set
372-
padTokenNotFound := (tokenizerSpecialConfig.PadToken == "" && hfConfig.PadTokenStr == nil)
348+
padTokenNotFound := (hfConfig.PadTokenStr == nil)
373349
if padTokenNotFound {
374350
// Inject the pad token into the encoder to uintmax16,
375351
// throw an error if vocab is larger than uintmax16
376-
if len(encoderTokens) >= math.MaxInt16 {
377-
log.Fatalf("Vocab size is larger than uint16 max, default pad token cannot be added." +
378-
"Please specify a pad token in the vocab file.")
352+
if len(encoderTokens) >= math.MaxUint16 {
353+
log.Fatalf("Vocab size of %d is larger than uint16 max of %d. "+
354+
"Please specify a pad token in the vocab file.",
355+
len(encoderTokens), math.MaxUint16)
379356
}
380-
encoderTokens[defaultPadTokenString] = math.MaxUint16
381-
tokenizerSpecialConfig.PadToken = defaultPadTokenString
382-
hfConfig.PadTokenStr = &tokenizerSpecialConfig.PadToken
357+
padToken := defaultPadTokenString
358+
encoderTokens[padToken] = math.MaxUint16
359+
hfConfig.PadTokenStr = &padToken
383360
}
361+
362+
// Create the encoder
384363
encoder := &GPTEncoder{
385364
encoderTokens,
386365
tokensEncoder,
@@ -403,8 +382,8 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
403382
encoderTokens[*hfConfig.EosTokenStr],
404383
encoderTokens[*hfConfig.PadTokenStr],
405384
specialConfig.EncloseEosBos,
406-
tokenizerSpecialConfig.AddBosToken,
407-
tokenizerSpecialConfig.AddEosToken,
385+
*hfConfig.AddBosToken,
386+
*hfConfig.AddEosToken,
408387
specialConfig.PrefixSpace,
409388
specialConfig.LowerCase,
410389
specialConfig.EndOfWord,

gpt_bpe_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,29 @@ func TestReadTokenizerConfig(t *testing.T) {
884884
fmt.Println("All Exists - Looks good.")
885885
}
886886

887+
func TestPythiaRemoteDownloadTokenizer(t *testing.T) {
888+
// Tests the ability to download a tokenizer from a remote model
889+
// and use it to encode and decode strings
890+
modelId := "EleutherAI/pythia-70m"
891+
destPath := "./TestPythiaRemoteDownloadTokenizer"
892+
defer os.RemoveAll(destPath)
893+
encoderPythia, err := NewEncoder(modelId)
894+
if err != nil {
895+
t.Errorf("Error creating encoder: %v", err)
896+
}
897+
898+
// Attempt to tokenize
899+
testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
900+
901+
// Encode the string
902+
encoded := encoderPythia.Encode(&testString)
903+
// Check that the encoded string is the same as the expected - Reference from python's transformers lib
904+
expected := Tokens{510, 30013, 16780, 689, 253, 419, 250, 15, 187, 510, 45993, 310, 7938, 685, 253, 419, 250, 15}
905+
if !assert.Equal(t, expected, *encoded) {
906+
t.Errorf("Expected: %v\nActual: %v", expected, *encoded)
907+
}
908+
}
909+
887910
func TestGPTDecoder_Decode(t *testing.T) {
888911
// TBD
889912
}

0 commit comments

Comments (0)