Commit d13b6bf

Test: Add Uint32 tests to dataset_tokenizer
1 parent baeedfc commit d13b6bf

File tree

3 files changed: +205 -17 lines changed

cmd/dataset_tokenizer/dataset_tokenizer_test.go

Lines changed: 189 additions & 15 deletions
@@ -3,7 +3,6 @@ package main
 import (
 	"bufio"
 	"bytes"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -19,6 +18,7 @@ import (
 	"github.com/aws/aws-sdk-go/service/s3"
 	"github.com/stretchr/testify/assert"
 	"github.com/wbrown/gpt_bpe"
+	"github.com/wbrown/gpt_bpe/types"
 )
 
 type SanitizerTest struct {
@@ -66,24 +66,11 @@ var sanitizerTests = SanitizerTests{
 
 const corpusPath = "../../resources/frankenstein.txt"
 
-func TokensFromBin(bin *[]byte) *gpt_bpe.Tokens {
-	tokens := make(gpt_bpe.Tokens, 0)
-	buf := bytes.NewReader(*bin)
-	for {
-		var token gpt_bpe.Token
-		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
-			break
-		}
-		tokens = append(tokens, token)
-	}
-	return &tokens
-}
-
 // DecodeBuffer
 // Decode Tokens from a byte array into a string.
 func DecodeBuffer(encoded *[]byte) (text string) {
 	// First convert our bytearray into a uint32 `Token` array.
-	tokens := TokensFromBin(encoded)
+	tokens := types.TokensFromBin(encoded)
 	// Decode our tokens into a string.
 	var enc *gpt_bpe.GPTEncoder
 	encoderString := "gpt2"
@@ -736,3 +723,190 @@ func TestListObjectsRecursively(t *testing.T) {
 
 	wg.Wait() // Wait for all goroutines to finish
 }
+
+func TestUInt16WithNoEnforce(t *testing.T) {
+	// Test that with Uint32 enforcement disabled,
+	// a Uint16 tokenizer works as intended with no padding.
+
+	textsTokenizer := NewTextsTokenizer()
+	textsTokenizer.ContextSize = 2048
+	textsTokenizer.TokenizerId = "gpt2"
+	textsTokenizer.EndOfText = ""
+
+	// Test data
+	testString := "The quick brown fox jumps over the lazy dog."
+	expectedTokens := types.Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256}
+	// Generate temp directory and test file
+	tempDir := os.TempDir()
+	testFile := tempDir + "/test.txt"
+	f, err := os.Create(testFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Write test string to file
+	_, err = f.WriteString(testString)
+	if err != nil {
+		log.Fatal(err)
+	}
+	f.Close()
+	defer os.Remove(testFile)
+
+	reorderPaths := ""
+	sampling := 100
+	outputFile := "base.chunk"
+	defer os.Remove(outputFile)
+
+	enc, tokErr := textsTokenizer.InitTokenizer()
+	if tokErr != nil {
+		log.Fatal(tokErr)
+	}
+
+	if texts, err := ReadTexts(
+		testFile, false,
+		reorderPaths,
+		1,
+	); err != nil {
+		log.Fatal(err)
+	} else {
+		begin := time.Now()
+		contexts, tokErr := textsTokenizer.TokenizeTexts(
+			texts, "./test", enc,
+		)
+		if tokErr != nil {
+			log.Fatal(tokErr)
+		}
+
+		total, writeErr := WriteContexts(
+			outputFile,
+			contexts,
+			enc,
+			sampling,
+			false,
+			false,
+			false,
+		)
+		if writeErr != nil {
+			log.Fatal(writeErr)
+		}
+		duration := time.Since(begin).Seconds()
+		log.Printf(
+			"%d tokens in %0.2fs, %0.2f tokens/s", total,
+			duration, float64(total)/duration,
+		)
+	}
+	// Read the encoded tokens from the output file
+	binaryData, err := os.ReadFile(outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Convert to Tokens array
+	tokens := types.TokensFromBin(&binaryData)
+
+	if len(*tokens) != len(expectedTokens) {
+		t.Fatalf(
+			"Expected %d tokens, but got %d", len(expectedTokens),
+			len(*tokens),
+		)
+	}
+	if !assert.ObjectsAreEqual(expectedTokens, *tokens) {
+		t.Fatalf("Expected tokens: %v, but got: %v", expectedTokens, *tokens)
+	}
+
+	// Verify the encoded tokens
+	assert.Equal(t, &expectedTokens, tokens)
+}
+
+func TestUInt16WithEnforce(t *testing.T) {
+	// Test that with Uint32 enforcement enabled,
+	// using a Uint16 tokenizer works as intended with padding,
+	// i.e. X, 0, Y, 0, Z, 0.
+
+	textsTokenizer := NewTextsTokenizer()
+	textsTokenizer.ContextSize = 2048
+	textsTokenizer.TokenizerId = "gpt2"
+	textsTokenizer.EndOfText = ""
+
+	// Test data
+	testString := "The quick brown fox jumps over the lazy dog."
+	expectedTokens := types.Tokens{464, 0, 2068, 0, 7586, 0, 21831, 0, 18045, 0, 625, 0, 262, 0, 16931, 0, 3290, 0, 13, 0, 50256, 0}
+	// Generate temp directory and test file
+	tempDir := os.TempDir()
+	testFile := tempDir + "/test.txt"
+	f, err := os.Create(testFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Write test string to file
+	_, err = f.WriteString(testString)
+	if err != nil {
+		log.Fatal(err)
+	}
+	f.Close()
+	defer os.Remove(testFile)
+
+	reorderPaths := ""
+	sampling := 100
+	outputFile := "base.chunk"
+	defer os.Remove(outputFile)
+
+	enc, tokErr := textsTokenizer.InitTokenizer()
+	if tokErr != nil {
+		log.Fatal(tokErr)
+	}
+
+	if texts, err := ReadTexts(
+		testFile, false,
+		reorderPaths,
+		1,
+	); err != nil {
+		log.Fatal(err)
+	} else {
+		begin := time.Now()
+		contexts, tokErr := textsTokenizer.TokenizeTexts(
+			texts, "./test", enc,
+		)
+		if tokErr != nil {
+			log.Fatal(tokErr)
+		}
+
+		total, writeErr := WriteContexts(
+			outputFile,
+			contexts,
+			enc,
+			sampling,
+			false,
+			true,
+			false,
+		)
+		if writeErr != nil {
+			log.Fatal(writeErr)
+		}
+		duration := time.Since(begin).Seconds()
+		log.Printf(
+			"%d tokens in %0.2fs, %0.2f tokens/s", total,
+			duration, float64(total)/duration,
+		)
+	}
+	// Read the encoded tokens from the output file
+	binaryData, err := os.ReadFile(outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Convert to Tokens array
+	tokens := types.TokensFromBin(&binaryData)
+
+	if len(*tokens) != len(expectedTokens) {
+		t.Fatalf(
+			"Expected %d tokens, but got %d", len(expectedTokens),
+			len(*tokens),
+		)
+	}
+	if !assert.ObjectsAreEqual(expectedTokens, *tokens) {
+		t.Fatalf("Expected tokens: %v, but got: %v", expectedTokens, *tokens)
+	}
+
+	// Verify the encoded tokens
+	assert.Equal(t, &expectedTokens, tokens)
+}
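Note on the enforced-Uint32 layout exercised above: when 16-bit GPT-2 tokens are written out as 32-bit little-endian values, reading the buffer back as Uint16 pairs yields each token followed by a zero, which is the X, 0, Y, 0, Z, 0 pattern the second test expects. A minimal standalone sketch of that widening (not the WriteContexts implementation; the padToUint32LE helper name is hypothetical):

// padToUint32LE is a hypothetical helper illustrating the enforced-Uint32
// layout: each 16-bit token is written as 4 little-endian bytes, so the
// high half of every token is zero.
package main

import (
	"encoding/binary"
	"fmt"
)

func padToUint32LE(tokens []uint16) []byte {
	out := make([]byte, 0, len(tokens)*4)
	for _, tok := range tokens {
		var buf [4]byte
		binary.LittleEndian.PutUint32(buf[:], uint32(tok))
		out = append(out, buf[:]...)
	}
	return out
}

func main() {
	bin := padToUint32LE([]uint16{464, 2068, 13})
	// Reinterpreting the buffer as uint16 values shows the interleaved padding.
	for i := 0; i < len(bin); i += 2 {
		fmt.Printf("%d ", binary.LittleEndian.Uint16(bin[i:i+2]))
	}
	fmt.Println() // prints: 464 0 2068 0 13 0
}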

gpt_bpe_test.go

Lines changed: 12 additions & 0 deletions
@@ -539,6 +539,18 @@ func TestGPTEncoder_Encode(t *testing.T) {
 	}
 }
 
+func TestGPTEncode(t *testing.T) {
+	// This test checks that the GPTEncoder encodes text correctly.
+	strin := "The quick brown fox jumps over the lazy dog."
+	expected := Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13}
+	encoded := gpt2Encoder.Encode(&strin)
+	fmt.Printf("Encoded, with commas: ")
+	for _, token := range *encoded {
+		fmt.Printf("%v, ", token)
+	}
+	assert.Equal(t, expected, *encoded)
+}
+
 func TestGPTEncoder_StreamingEncode(t *testing.T) {
 	// This test is to check if the GPTEncoder is able to encode the tokens correctly
 	start := time.Now()
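A natural companion check, sketched here rather than taken from the commit: round-tripping the same sentence through Encode and Decode should reproduce the input. This assumes the package-level gpt2Encoder used by the other tests, and that Decode accepts the *Tokens pointer returned by Encode (as it does in js/js.go below):

// Sketch only, not part of this commit: a round-trip assertion that could
// sit alongside TestGPTEncode in gpt_bpe_test.go.
func TestGPTEncodeRoundTrip(t *testing.T) {
	input := "The quick brown fox jumps over the lazy dog."
	encoded := gpt2Encoder.Encode(&input)
	decoded := gpt2Encoder.Decode(encoded)
	assert.Equal(t, input, decoded)
}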

js/js.go

Lines changed: 4 additions & 2 deletions
@@ -3,9 +3,11 @@ package main
 //go:generate gopherjs build --minify
 
 import (
+	"log"
+
 	"github.com/gopherjs/gopherjs/js"
 	"github.com/wbrown/gpt_bpe"
-	"log"
+	"github.com/wbrown/gpt_bpe/types"
 )
 
 var encoder gpt_bpe.GPTEncoder
@@ -15,7 +17,7 @@ func Tokenize(text string) gpt_bpe.Tokens {
 }
 
 func Decode(arr []byte) string {
-	tokens := gpt_bpe.TokensFromBin(&arr)
+	tokens := types.TokensFromBin(&arr)
 	return encoder.Decode(tokens)
 }
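For context, the types.TokensFromBin that js/js.go and the tests above now call presumably carries the logic this commit removes from dataset_tokenizer_test.go. A sketch of that relocated helper; the Token/Tokens definitions are assumptions based on the new import (with the uint32 width taken from the DecodeBuffer comment), and the real ones live in github.com/wbrown/gpt_bpe/types:

// Sketch of the relocated helper, mirroring the TokensFromBin removed from
// dataset_tokenizer_test.go; the type definitions here are assumptions.
package types

import (
	"bytes"
	"encoding/binary"
)

type Token uint32
type Tokens []Token

// TokensFromBin reads little-endian tokens from a byte buffer until EOF.
func TokensFromBin(bin *[]byte) *Tokens {
	tokens := make(Tokens, 0)
	buf := bytes.NewReader(*bin)
	for {
		var token Token
		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
			break
		}
		tokens = append(tokens, token)
	}
	return &tokens
}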

0 commit comments