@@ -3,7 +3,6 @@ package main
 import (
 	"bufio"
 	"bytes"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -19,6 +18,7 @@ import (
 	"github.com/aws/aws-sdk-go/service/s3"
 	"github.com/stretchr/testify/assert"
 	"github.com/wbrown/gpt_bpe"
+	"github.com/wbrown/gpt_bpe/types"
 )

 type SanitizerTest struct {
@@ -66,24 +66,11 @@ var sanitizerTests = SanitizerTests{

 const corpusPath = "../../resources/frankenstein.txt"

-func TokensFromBin(bin *[]byte) *gpt_bpe.Tokens {
-	tokens := make(gpt_bpe.Tokens, 0)
-	buf := bytes.NewReader(*bin)
-	for {
-		var token gpt_bpe.Token
-		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
-			break
-		}
-		tokens = append(tokens, token)
-	}
-	return &tokens
-}
-
 // DecodeBuffer
 // Decode Tokens from a byte array into a string.
 func DecodeBuffer(encoded *[]byte) (text string) {
 	// First convert our bytearray into a uint32 `Token` array.
-	tokens := TokensFromBin(encoded)
+	tokens := types.TokensFromBin(encoded)
 	// Decode our tokens into a string.
 	var enc *gpt_bpe.GPTEncoder
 	encoderString := "gpt2"
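
The decoding helper now lives in the shared gpt_bpe/types package instead of being redefined in this test file, so the CLI and its tests use one implementation. A minimal round-trip sketch of the relocated helper follows; it assumes types.TokensFromBin keeps the removed function's behavior (a little-endian binary.Read loop over the buffer) and that tokens are stored as 32-bit words, per the comment in DecodeBuffer:

package main

import (
	"fmt"

	"github.com/wbrown/gpt_bpe/types"
)

func main() {
	// Tokens 464 and 13 packed as little-endian 32-bit words.
	bin := []byte{0xd0, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00}
	tokens := types.TokensFromBin(&bin)
	fmt.Println(*tokens) // [464 13]
}
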
@@ -736,3 +723,184 @@ func TestListObjectsRecursively(t *testing.T) {

 	wg.Wait() // Wait for all goroutines to finish
 }
+
+func TestUInt16WithNoEnforce(t *testing.T) {
+	// Test that, with Uint32 enforcement disabled, a Uint16 tokenizer
+	// works as intended and emits no padding.
+
+	textsTokenizer := NewTextsTokenizer()
+	textsTokenizer.ContextSize = 2048
+	textsTokenizer.TokenizerId = "gpt2"
+	textsTokenizer.EndOfText = ""
+
+	// Test data
+	testString := "The quick brown fox jumps over the lazy dog."
+	expectedTokens := types.Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256}
+	// Create a test file in the temp directory
+	tempDir := os.TempDir()
+	testFile := tempDir + "/test.txt"
+	f, err := os.Create(testFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Write the test string to the file
+	_, err = f.WriteString(testString)
+	if err != nil {
+		log.Fatal(err)
+	}
+	f.Close()
+	defer os.Remove(testFile)
+
+	reorderPaths := ""
+	sampling := 100
+	outputFile := "base.chunk"
+	defer os.Remove(outputFile)
+
+	enc, tokErr := textsTokenizer.InitTokenizer()
+	if tokErr != nil {
+		log.Fatal(tokErr)
+	}
+
+	if texts, err := ReadTexts(
+		testFile, false,
+		reorderPaths,
+		1,
+	); err != nil {
+		log.Fatal(err)
+	} else {
+		begin := time.Now()
+		contexts, tokErr := textsTokenizer.TokenizeTexts(
+			texts, "./test", enc,
+		)
+		if tokErr != nil {
+			log.Fatal(tokErr)
+		}
+
+		total, writeErr := WriteContexts(
+			outputFile,
+			contexts,
+			enc,
+			sampling,
+			false,
+			false,
+			false,
+		)
+		if writeErr != nil {
+			log.Fatal(writeErr)
+		}
+		duration := time.Since(begin).Seconds()
+		log.Printf(
+			"%d tokens in %0.2fs, %0.2f tokens/s", total,
+			duration, float64(total)/duration,
+		)
+	}
+	// Read the encoded tokens from the output file
+	binaryData, err := os.ReadFile(outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Convert the raw bytes back into a Tokens slice
+	tokens := types.TokensFromBin(&binaryData)
+
+	if len(*tokens) != len(expectedTokens) {
+		t.Fatalf(
+			"Expected %d tokens, but got %d", len(expectedTokens),
+			len(*tokens),
+		)
+	}
+
+	// Verify the encoded tokens; assert.Equal compares the values
+	// the pointers refer to, token by token.
+	assert.Equal(t, &expectedTokens, tokens)
+}
+
+func TestUInt16WithEnforce(t *testing.T) {
+	// Test that, with Uint32 enforcement enabled, a Uint16 tokenizer
+	// works as intended, padding each token out to 32 bits,
+	// i.e. X, 0, Y, 0, Z, 0.
+
+	textsTokenizer := NewTextsTokenizer()
+	textsTokenizer.ContextSize = 2048
+	textsTokenizer.TokenizerId = "gpt2"
+	textsTokenizer.EndOfText = ""
+
+	// Test data
+	testString := "The quick brown fox jumps over the lazy dog."
+	expectedTokens := types.Tokens{464, 0, 2068, 0, 7586, 0, 21831, 0, 18045, 0, 625, 0, 262, 0, 16931, 0, 3290, 0, 13, 0, 50256, 0}
+	// Create a test file in the temp directory
+	tempDir := os.TempDir()
+	testFile := tempDir + "/test.txt"
+	f, err := os.Create(testFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Write the test string to the file
+	_, err = f.WriteString(testString)
+	if err != nil {
+		log.Fatal(err)
+	}
+	f.Close()
+	defer os.Remove(testFile)
+
+	reorderPaths := ""
+	sampling := 100
+	outputFile := "base.chunk"
+	defer os.Remove(outputFile)
+
+	enc, tokErr := textsTokenizer.InitTokenizer()
+	if tokErr != nil {
+		log.Fatal(tokErr)
+	}
+
+	if texts, err := ReadTexts(
+		testFile, false,
+		reorderPaths,
+		1,
+	); err != nil {
+		log.Fatal(err)
+	} else {
+		begin := time.Now()
+		contexts, tokErr := textsTokenizer.TokenizeTexts(
+			texts, "./test", enc,
+		)
+		if tokErr != nil {
+			log.Fatal(tokErr)
+		}
+
+		total, writeErr := WriteContexts(
+			outputFile,
+			contexts,
+			enc,
+			sampling,
+			false,
+			true,
+			false,
+		)
+		if writeErr != nil {
+			log.Fatal(writeErr)
+		}
+		duration := time.Since(begin).Seconds()
+		log.Printf(
+			"%d tokens in %0.2fs, %0.2f tokens/s", total,
+			duration, float64(total)/duration,
+		)
+	}
+	// Read the encoded tokens from the output file
+	binaryData, err := os.ReadFile(outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Convert the raw bytes back into a Tokens slice
+	tokens := types.TokensFromBin(&binaryData)
+
+	if len(*tokens) != len(expectedTokens) {
+		t.Fatalf(
+			"Expected %d tokens, but got %d", len(expectedTokens),
+			len(*tokens),
+		)
+	}
+
+	// Verify the encoded tokens; assert.Equal compares the values
+	// the pointers refer to, token by token.
+	assert.Equal(t, &expectedTokens, tokens)
+}
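
The expected token lists in these two tests fall directly out of the little-endian byte layout. With enforcement off, uint16 tokens are written back-to-back; with enforcement on, each token is promoted to a 32-bit word whose high half is zero, which reads back as X, 0, Y, 0, Z, 0 when the buffer is scanned as uint16s. A standalone sketch of both framings using only encoding/binary — an illustration of the layout, not the tool's actual write path:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	tokens := []uint16{464, 2068, 13}

	// No-enforce layout: uint16 tokens written back-to-back.
	var packed bytes.Buffer
	binary.Write(&packed, binary.LittleEndian, tokens)
	fmt.Printf("packed: % x\n", packed.Bytes()) // d0 01 14 08 0d 00

	// Enforced layout: each token promoted to a uint32 word, leaving
	// the high half of every word zero.
	var padded bytes.Buffer
	for _, tok := range tokens {
		binary.Write(&padded, binary.LittleEndian, uint32(tok))
	}
	fmt.Printf("padded: % x\n", padded.Bytes()) // d0 01 00 00 14 08 00 00 0d 00 00 00

	// Scanning the padded buffer as uint16s exposes the interleaved zeros.
	out := make([]uint16, padded.Len()/2)
	binary.Read(&padded, binary.LittleEndian, out)
	fmt.Println(out) // [464 0 2068 0 13 0]
}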