Commit 96eb515

style: clean up some print code
1 parent 741198f commit 96eb515

File tree

1 file changed: 39 additions, 50 deletions

gpt_bpe_test.go

Lines changed: 39 additions & 50 deletions
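Most of the hunks below collapse the t.Log(fmt.Sprintf(...)) / b.Log(fmt.Sprintf(...)) pattern into a single t.Logf / b.Logf call, which formats and logs in one step. The remaining changes move a defer corpusHandle.Close() below its error check, fold the rsrcType and hfApiToken declarations into one short-variable assignment, and drop a redundant matches != nil guard (len of a nil slice is 0, so len(matches) > 2 already covers it). A minimal sketch of the logging change, using hypothetical helper names rather than code from this file:

package example

import (
    "fmt"
    "testing"
)

// logOld mirrors the old pattern: build the string with fmt.Sprintf,
// then hand the finished string to Log.
func logOld(t *testing.T, byteCt, tokenCt int) {
    t.Log(fmt.Sprintf("%v bytes into %v tokens", byteCt, tokenCt))
}

// logNew mirrors the new pattern: Logf formats and logs in one call.
func logNew(t *testing.T, byteCt, tokenCt int) {
    t.Logf("%v bytes into %v tokens", byteCt, tokenCt)
}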
@@ -273,10 +273,10 @@ func TestGPTEncoder_Split(t *testing.T) {
 func BenchmarkGPTEncoder_WordSplitterChan(b *testing.B) {
     b.StopTimer()
     corpusHandle, err := os.Open(largeCorpusPath)
-    defer corpusHandle.Close()
     if err != nil {
         b.Error(err)
     }
+    defer corpusHandle.Close()
     gpt2Encoder.SplitterThreads = 8
     nextWord := gpt2Encoder.WordSplitter(bufio.NewReaderSize(corpusHandle,
         8*1024*1024))
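The hunk above also reorders the cleanup: defer corpusHandle.Close() is now registered only after the error from os.Open has been checked, so the deferred call never targets a handle that failed to open. A minimal sketch of that ordering, with a hypothetical helper (not from this repository):

package example

import "os"

// readCorpus opens a file and defers Close only once Open is known to
// have succeeded; on error nothing is deferred and the nil handle is
// never touched.
func readCorpus(path string) error {
    f, err := os.Open(path)
    if err != nil {
        return err // no defer registered for a failed Open
    }
    defer f.Close()

    // ... read from f ...
    return nil
}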
@@ -413,8 +413,8 @@ func BenchmarkGPTEncoder_Decode(b *testing.B) {
     start := time.Now()
     tokenNumBytes := len(gpt2Encoder.Decode(gpt2Encoded))
     duration := time.Since(start)
-    b.Log(fmt.Sprintf("%v tokens into %v bytes over %v",
-        len(*gpt2Encoded), tokenNumBytes, duration))
+    b.Logf("%v tokens into %v bytes over %v",
+        len(*gpt2Encoded), tokenNumBytes, duration)
 }

 type EncoderTest struct {
@@ -453,25 +453,25 @@ func BenchmarkGPTEncoder_Encode(b *testing.B) {
     start := time.Now()
     tokenCt := len(*gpt2Encoder.Encode(&corpus))
     duration := time.Since(start)
-    b.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    b.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
 }

 func BenchmarkGPTEncoder_EncodeBuffer(b *testing.B) {
     corpusBytes := []byte(corpus)
     start := time.Now()
     tokenCt := len(*gpt2Encoder.EncodeBuffer(&corpusBytes)) / 2
     duration := time.Since(start)
-    b.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    b.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
 }

 func TestGPTEncoder_Encode(t *testing.T) {
     start := time.Now()
     tokenCt := len(*gpt2Encoder.Encode(&corpus))
     duration := time.Since(start)
-    t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    t.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
     for testIdx := range GPTEncoderTests {
         tokensPtr := *gpt2Encoder.Encode(
             &(GPTEncoderTests[testIdx].Input))
@@ -492,19 +492,18 @@ func TestGPTEncoder_StreamingEncode(t *testing.T) {
         tokenCt += len(*tokens)
     }
     duration := time.Since(start)
-    t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    t.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
 }

 func TestCLIPEncoder_Encode(t *testing.T) {
     start := time.Now()
     tokenCt := len(*clipEncoder.Encode(&corpus))
     duration := time.Since(start)
-    t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    t.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
     for testIdx := range GPTEncoderTests {
-        testStr := fmt.Sprintf("%s",
-            GPTEncoderTests[testIdx].Input)
+        testStr := GPTEncoderTests[testIdx].Input
         tokensPtr := *clipEncoder.Encode(&testStr)
         assert.Equal(t, GPTEncoderTests[testIdx].CLIPExpected, tokensPtr)
     }
@@ -514,8 +513,8 @@ func TestPileEncoder_Encode(t *testing.T) {
     start := time.Now()
     tokenCt := len(*pileEncoder.Encode(&corpus))
     duration := time.Since(start)
-    t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    t.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
     for testIdx := range GPTEncoderTests {
         tokensPtr := *pileEncoder.Encode(
             &(GPTEncoderTests[testIdx].Input))
@@ -527,8 +526,8 @@ func TestNerdstashEncoder_Encode(t *testing.T) {
     start := time.Now()
     tokenCt := len(*nerdstashV2Encoder.Encode(&corpus))
     duration := time.Since(start)
-    t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    t.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
     for testIdx := range GPTEncoderTests {
         tokensPtr := *nerdstashV2Encoder.Encode(
             &(GPTEncoderTests[testIdx].Input))
@@ -576,7 +575,7 @@ func TestNerdstashEncoder_Encode2(t *testing.T) {
         encoded := nerdstashV2Encoder.Encode(&inputStr)
         // check that the encoded string is the same as the expected
         if !assert.Equal(t, expected, *encoded) {
-            t.Log(fmt.Sprintf("failure on input: `%v`", inputStr))
+            t.Logf("failure on input: `%v`", inputStr)
             expectedRepr := []string{}
             for _, token := range expected {
                 expectedRepr = append(expectedRepr,
@@ -587,14 +586,14 @@ func TestNerdstashEncoder_Encode2(t *testing.T) {
                 actualRepr = append(actualRepr,
                     string(nerdstashV2Encoder.Decoder[token]))
             }
-            t.Log(fmt.Sprintf("expected: |%s", strings.Join(expectedRepr, "|")))
-            t.Log(fmt.Sprintf("actual: |%s", strings.Join(actualRepr, "|")))
+            t.Logf("expected: |%s", strings.Join(expectedRepr, "|"))
+            t.Logf("actual: |%s", strings.Join(actualRepr, "|"))
             failCt += 1
         } else {
             passCt += 1
         }
     }
-    t.Log(fmt.Sprintf("pass: %v, fail: %v", passCt, failCt))
+    t.Logf("pass: %v, fail: %v", passCt, failCt)
 }

 func TestNerdstashEncoder_Decode(t *testing.T) {
@@ -613,6 +612,9 @@ func TestGPTEncoder_Decode2(t *testing.T) {
     } else {
         tokens := TokensFromBin(&binTokens)
         tokens, err = gpt2Encoder.TrimIncompleteSentence(tokens)
+        if err != nil {
+            t.Error(err)
+        }
         assert.Equal(t, gpt2Encoder.Decode(tokens), decodedCorpus)
     }
 }
@@ -626,8 +628,8 @@ func TestGPTEncoder_Decode(t *testing.T) {
     decoded := gpt2Encoder.Decode(gpt2Encoded)
     duration := time.Since(start)
     tokenNumBytes := len(decoded)
-    t.Log(fmt.Sprintf("%v tokens into %v bytes over %v\n",
-        len(*gpt2Encoded), tokenNumBytes, duration))
+    t.Logf("%v tokens into %v bytes over %v\n",
+        len(*gpt2Encoded), tokenNumBytes, duration)
     assert.Equal(t, corpus, decoded)
 }

@@ -647,8 +649,7 @@ func TestCLIPEncoder_Decode(t *testing.T) {
     duration := time.Since(start)
     tokenNumBytes := len(decoded)
     idxToStop := 229550
-    t.Log(fmt.Sprintf("%v tokens into %v bytes over %v\n",
-        len(*clipEncoded), tokenNumBytes, duration))
+    t.Logf("%v tokens into %v bytes over %v\n", len(*clipEncoded), tokenNumBytes, duration)
     for idx := range clipCorpus {
         if idx > idxToStop {
             break
@@ -672,8 +673,8 @@ func TestPileEncoder_Decode(t *testing.T) {
     decoded := pileEncoder.Decode(pileEncoded)
     duration := time.Since(start)
     tokenNumBytes := len(decoded)
-    t.Log(fmt.Sprintf("%v tokens into %v bytes over %v\n",
-        len(*pileEncoded), tokenNumBytes, duration))
+    t.Logf("%v tokens into %v bytes over %v\n",
+        len(*pileEncoded), tokenNumBytes, duration)
     range_data := corpus
     if len(corpus) > len(decoded) {
         range_data = decoded
@@ -779,8 +780,8 @@ func TestLlamaEncoder_Encode(t *testing.T) {
     start := time.Now()
     tokenCt := len(*gpt2Encoder.Encode(&corpus))
     duration := time.Since(start)
-    t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
-        len(corpus), tokenCt, duration))
+    t.Logf("%v bytes into %v tokens over %v",
+        len(corpus), tokenCt, duration)
     for testIdx := range GPTEncoderTests {
         tokensPtr := *gpt2Encoder.Encode(
             &(GPTEncoderTests[testIdx].Input))
@@ -852,9 +853,7 @@ func TestReadTokenizerConfig(t *testing.T) {
     destPath := "./TestReadTokenizerConfig"
     destPathPTR := &destPath
     defer os.RemoveAll(destPath)
-    var rsrcType resources.ResourceType
-    rsrcType = resources.RESOURCETYPE_TRANSFORMERS
-    hfApiToken := os.Getenv("HF_API_TOKEN")
+    rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
     os.MkdirAll(destPath, 0755)
     _, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
         resources.RESOURCE_MODEL, rsrcType, hfApiToken)
@@ -898,9 +897,7 @@ func TestModelDownload(t *testing.T) {
     destPath := "./TestModelDownload"
     destPathPTR := &destPath

-    var rsrcType resources.ResourceType
-    rsrcType = resources.RESOURCETYPE_TRANSFORMERS
-    hfApiToken := os.Getenv("HF_API_TOKEN")
+    rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
     os.MkdirAll(destPath, 0755)
     defer os.RemoveAll(destPath)
     _, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
@@ -975,9 +972,7 @@ func TestModelDownloadPythia(t *testing.T) {
     destPath := "./TestModelDownloadPythia"
     destPathPTR := &destPath

-    var rsrcType resources.ResourceType
-    rsrcType = resources.RESOURCETYPE_TRANSFORMERS
-    hfApiToken := os.Getenv("HF_API_TOKEN")
+    rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
     os.MkdirAll(destPath, 0755)
     defer os.RemoveAll(destPath)
     _, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
@@ -1051,9 +1046,7 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
     destPath := "./TestModelDownloadPythiaSharded"
     destPathPTR := &destPath

-    var rsrcType resources.ResourceType
-    rsrcType = resources.RESOURCETYPE_TRANSFORMERS
-    hfApiToken := os.Getenv("HF_API_TOKEN")
+    rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
     os.MkdirAll(destPath, 0755)
     defer os.RemoveAll(destPath)
     _, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
@@ -1118,9 +1111,7 @@ func TestModelDownloadLlama(t *testing.T) {
     destPathPTR := &destPath
     defer os.RemoveAll(destPath)

-    var rsrcType resources.ResourceType
-    rsrcType = resources.RESOURCETYPE_TRANSFORMERS
-    hfApiToken := os.Getenv("HF_API_TOKEN")
+    rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
     os.MkdirAll(destPath, 0755)
     _, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
         resources.RESOURCE_MODEL, rsrcType, hfApiToken)
@@ -1166,7 +1157,7 @@ func TestModelDownloadLlama(t *testing.T) {
         }

         matches := re.FindStringSubmatch(file.Name())
-        if matches != nil && len(matches) > 2 {
+        if len(matches) > 2 {
             if strings.Compare(matches[1], matches[2]) == 0 {
                 found = true
                 break
@@ -1212,9 +1203,7 @@ func TestModelDownloadFairseq(t *testing.T) {
     destPath := "./TestModelDownloadFairseq"
     destPathPTR := &destPath

-    var rsrcType resources.ResourceType
-    rsrcType = resources.RESOURCETYPE_TRANSFORMERS
-    hfApiToken := os.Getenv("HF_API_TOKEN")
+    rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
     os.MkdirAll(destPath, 0755)
     defer os.RemoveAll(destPath)
     _, rsrcErr := resources.ResolveResources(modelId, destPathPTR,

0 commit comments
