Skip to content

Commit 119dc22

Browse files
committed
extract function newCipherBlock
1 parent fc195ab commit 119dc22

File tree

2 files changed

+71
-29
lines changed

2 files changed

+71
-29
lines changed

apikey.go

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@ import (
88
"encoding/base64"
99
"encoding/hex"
1010
"encoding/json"
11+
"errors"
1112
"fmt"
1213
"io"
1314
"net/http"
1415
"time"
1516

16-
"github.com/pkg/errors"
17-
1817
"golang.org/x/crypto/bcrypt"
1918
)
2019

@@ -72,16 +71,9 @@ func (k *ApiKey) IsCorrect(given string) error {
7271

7372
// EncryptData uses the Secret to AES encrypt an arbitrary data block. It does not encrypt the key itself.
7473
func (k *ApiKey) EncryptData(plaintext []byte) ([]byte, error) {
75-
var sec []byte
76-
var err error
77-
sec, err = base64.StdEncoding.DecodeString(k.Secret)
78-
if err != nil {
79-
sec = []byte(k.Secret)
80-
}
81-
// create cipher block with api secret as aes key
82-
block, err := aes.NewCipher(sec)
74+
block, err := newCipherBlock(k.Secret)
8375
if err != nil {
84-
return []byte{}, err
76+
return nil, err
8577
}
8678

8779
// byte array to hold encrypted content
@@ -103,16 +95,9 @@ func (k *ApiKey) EncryptData(plaintext []byte) ([]byte, error) {
10395

10496
// DecryptData uses the Secret to AES decrypt an arbitrary data block. It does not decrypt the key itself.
10597
func (k *ApiKey) DecryptData(ciphertext []byte) ([]byte, error) {
106-
var sec []byte
107-
var err error
108-
sec, err = base64.StdEncoding.DecodeString(k.Secret)
109-
if err != nil {
110-
sec = []byte(k.Secret)
111-
}
112-
113-
block, err := aes.NewCipher(sec)
98+
block, err := newCipherBlock(k.Secret)
11499
if err != nil {
115-
return []byte{}, errors.Wrap(err, "failed to create new cipher")
100+
return nil, err
116101
}
117102

118103
// plaintext must be as long as ciphertext minus the length of the IV, which is the same as the AES block size
@@ -131,16 +116,9 @@ func (k *ApiKey) DecryptData(ciphertext []byte) ([]byte, error) {
131116
// DecryptLegacy uses the Secret to AES decrypt an arbitrary data block. This is intended only for legacy data such
132117
// as U2F keys.
133118
func (k *ApiKey) DecryptLegacy(ciphertext []byte) ([]byte, error) {
134-
var sec []byte
135-
var err error
136-
sec, err = base64.StdEncoding.DecodeString(k.Secret)
137-
if err != nil {
138-
sec = []byte(k.Secret)
139-
}
140-
141-
block, err := aes.NewCipher(sec)
119+
block, err := newCipherBlock(k.Secret)
142120
if err != nil {
143-
return []byte{}, errors.Wrap(err, "failed to create new cipher")
121+
return nil, err
144122
}
145123

146124
// data was encrypted, then base64 encoded, then joined with a :, need to split
@@ -302,3 +280,20 @@ func NewApiKey(email string) (ApiKey, error) {
302280
}
303281
return key, nil
304282
}
283+
284+
// newCipherBlock creates a new cipher.Block from a base64-encoded AES key. If the string is not valid base64 data, it
285+
// will be interpreted as binary data.
286+
func newCipherBlock(key string) (cipher.Block, error) {
287+
var sec []byte
288+
var err error
289+
sec, err = base64.StdEncoding.DecodeString(key)
290+
if err != nil {
291+
sec = []byte(key)
292+
}
293+
294+
block, err := aes.NewCipher(sec)
295+
if err != nil {
296+
return nil, fmt.Errorf("failed to create new cipher: %w", err)
297+
}
298+
return block, nil
299+
}

apikey_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ package mfa
22

33
import (
44
"bytes"
5+
"crypto/aes"
6+
"crypto/rand"
7+
"encoding/base64"
58
"encoding/json"
69
"fmt"
10+
"io"
711
"net/http"
812
"regexp"
913
"testing"
@@ -309,3 +313,46 @@ func (ms *MfaSuite) TestNewApiKey() {
309313
ms.NoError(err)
310314
ms.Regexp(regexp.MustCompile("[a-f0-9]{40}"), got)
311315
}
316+
317+
func (ms *MfaSuite) TestNewCipherBlock() {
318+
random := make([]byte, 32)
319+
_, err := io.ReadFull(rand.Reader, random)
320+
ms.NoError(err)
321+
322+
tests := []struct {
323+
name string
324+
key string
325+
wantErr bool
326+
}{
327+
{
328+
name: "key too short",
329+
key: "0123456789012345678901234567890",
330+
wantErr: true,
331+
},
332+
{
333+
name: "key too long",
334+
key: "012345678901234567890123456789012",
335+
wantErr: true,
336+
},
337+
{
338+
name: "raw",
339+
key: string(random),
340+
},
341+
{
342+
name: "base64",
343+
key: base64.StdEncoding.EncodeToString(random),
344+
},
345+
}
346+
for _, tt := range tests {
347+
ms.Run(tt.name, func() {
348+
got, err := newCipherBlock(tt.key)
349+
if tt.wantErr {
350+
ms.Error(err)
351+
return
352+
}
353+
354+
ms.NoError(err)
355+
ms.Equal(aes.BlockSize, got.BlockSize())
356+
})
357+
}
358+
}

0 commit comments

Comments
 (0)