diff --git a/go.mod b/go.mod index bae524c..1f8e537 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/yudai/gojsondiff v1.0.0 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect github.com/yudai/pp v2.0.1+incompatible // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d google.golang.org/appengine v1.6.6 // indirect gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect diff --git a/go.sum b/go.sum index 59aa50c..9fbc013 100644 --- a/go.sum +++ b/go.sum @@ -126,6 +126,7 @@ github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDf github.com/yudai/pp v2.0.1+incompatible h1:Q4//iY4pNF6yPLZIigmvcl7k/bPgrcTPIFIcmawg5bI= github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/store.go b/store.go index a30c438..21d184c 100644 --- a/store.go +++ b/store.go @@ -1,12 +1,22 @@ package oauth2 -import "context" +import ( + "context" +) type ( // ClientStore the client information storage interface ClientStore interface { - // according to the ID for the client information + // get client information by ID + GetByID(ctx context.Context, id string) (ClientInfo, error) + } + + // SavingClientStore can save client information and retrieve it by ID + SavingClientStore interface { + // get client information by ID GetByID(ctx context.Context, id string) (ClientInfo, error) + // store client information + Save(ctx context.Context, info ClientInfo) error } // TokenStore the token information storage interface diff --git a/store/client.go b/store/client.go index 0001bb2..538ec88 100755 --- a/store/client.go +++ b/store/client.go @@ -40,3 +40,8 @@ func (cs *ClientStore) Set(id string, cli oauth2.ClientInfo) (err error) { cs.data[id] = cli return } + +// Save stores client information, implements the oauth2.SavingClientStore interface +func (cs *ClientStore) Save(_ context.Context, cli oauth2.ClientInfo) (err error) { + return cs.Set(cli.GetID(), cli) +} diff --git a/store/hash.go b/store/hash.go new file mode 100644 index 0000000..3486515 --- /dev/null +++ b/store/hash.go @@ -0,0 +1,152 @@ +package store + +import ( + "context" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/models" + "golang.org/x/crypto/bcrypt" +) + +// Hasher is an interface for hashing and verifying client secrets. +type Hasher interface { + // Hash hashes the given secret and returns the hashed value. + Hash(secret string) (string, error) + // Verify checks if the hashed secret matches the given secret. + Verify(hashedPassword, secret string) error +} + +// BcryptHasher is a Hasher implementation using bcrypt for hashing and verifying secrets. +type BcryptHasher struct{} + +func (b *BcryptHasher) Hash(secret string) (string, error) { + hashed, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return string(hashed), nil +} + +func (b *BcryptHasher) Verify(hashed, secret string) error { + return bcrypt.CompareHashAndPassword([]byte(hashed), []byte(secret)) +} + +// ClientInfoWithHash wraps an oauth2.ClientInfo and provides secret verification using a Hasher. +type ClientInfoWithHash struct { + wrapped oauth2.ClientInfo + hasher Hasher +} + +// NewClientInfoWithHash creates a new instance of client info supporting hashed secret verification. +func NewClientInfoWithHash( + info oauth2.ClientInfo, + hasher Hasher, +) *ClientInfoWithHash { + if info == nil { + return nil + } + return &ClientInfoWithHash{ + wrapped: info, + hasher: hasher, + } +} + +// VerifyPassword verifies the given plain secret against the hashed secret. +// It implements the oauth2.ClientPasswordVerifier interface. +func (v *ClientInfoWithHash) VerifyPassword(secret string) bool { + if secret == "" { + return false + } + err := v.hasher.Verify(v.GetSecret(), secret) + return err == nil +} + +// GetID returns the client ID. +func (v *ClientInfoWithHash) GetID() string { + return v.wrapped.GetID() +} + +// GetSecret returns the hashed client secret. +func (v *ClientInfoWithHash) GetSecret() string { + return v.wrapped.GetSecret() +} + +// GetDomain returns the client domain. +func (v *ClientInfoWithHash) GetDomain() string { + return v.wrapped.GetDomain() +} + +// GetUserID returns the user ID associated with the client. +func (v *ClientInfoWithHash) GetUserID() string { + return v.wrapped.GetUserID() +} + +// IsPublic returns true if the client is public. +func (v *ClientInfoWithHash) IsPublic() bool { + return v.wrapped.IsPublic() +} + +// ClientStoreWithHash is a wrapper around oauth2.SavingClientStore that hashes client secrets. +type ClientStoreWithHash struct { + underlying oauth2.SavingClientStore + hasher Hasher +} + +// NewClientStoreWithBcrypt creates a new ClientStoreWithHash using bcrypt for hashing. +// +// It is a convenience function for creating a store with the default bcrypt hasher. +// The store will hash client secrets using bcrypt before saving them and would +// return secret information supporting secret verification against the hashed secret. +func NewClientStoreWithBcrypt(store oauth2.SavingClientStore) *ClientStoreWithHash { + return NewClientStoreWithHash(store, &BcryptHasher{}) +} + +func NewClientStoreWithHash(underlying oauth2.SavingClientStore, hasher Hasher) *ClientStoreWithHash { + if hasher == nil { + hasher = &BcryptHasher{} + } + return &ClientStoreWithHash{ + underlying: underlying, + hasher: hasher, + } +} + +// GetByID retrieves client information by ID and returns a ClientInfoWithHash instance. +func (w *ClientStoreWithHash) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { + info, err := w.underlying.GetByID(ctx, id) + if err != nil { + return nil, err + } + rval := NewClientInfoWithHash(info, w.hasher) + if rval == nil { + return nil, errors.ErrInvalidClient + } + return rval, nil +} + +// Save hashes the client secret before saving it to the underlying store. +func (w *ClientStoreWithHash) Save( + ctx context.Context, + info oauth2.ClientInfo, +) error { + if info == nil { + return errors.ErrInvalidClient + } + if info.GetSecret() == "" { + return errors.ErrInvalidClient + } + + hashed, err := w.hasher.Hash(info.GetSecret()) + if err != nil { + return err + } + hashedInfo := models.Client{ + ID: info.GetID(), + Secret: hashed, + Domain: info.GetDomain(), + UserID: info.GetUserID(), + Public: info.IsPublic(), + } + return w.underlying.Save(ctx, &hashedInfo) +} diff --git a/store/hash_test.go b/store/hash_test.go new file mode 100644 index 0000000..c7a12ca --- /dev/null +++ b/store/hash_test.go @@ -0,0 +1,54 @@ +package store_test + +import ( + "context" + "testing" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/models" + "github.com/go-oauth2/oauth2/v4/store" + . "github.com/smartystreets/goconvey/convey" +) + +func TestClientStoreWithHash(t *testing.T) { + Convey("Test client store with hash - save", t, func() { + hasher := &store.BcryptHasher{} + memory := store.NewClientStore() + store := store.NewClientStoreWithHash(memory, hasher) + secret := "123456" + err := store.Save(context.Background(), &models.Client{ + ID: "123", + Secret: secret, + Domain: "http://localhost", + Public: false, + UserID: "123", + }) + So(err, ShouldBeNil) + + Convey("get by id", func() { + storedClient, err := store.GetByID(context.Background(), "123") + + So(err, ShouldBeNil) + So(storedClient.GetID(), ShouldEqual, "123") + So(storedClient.GetSecret(), ShouldNotEqual, secret) + + verifier := storedClient.(oauth2.ClientPasswordVerifier) + + Convey("verify correct password - success", func() { + So(verifier.VerifyPassword(secret), ShouldBeTrue) + }) + + Convey("verify incorrect password - fail", func() { + So(verifier.VerifyPassword("wrong"), ShouldBeFalse) + }) + }) + }) +} + +// check interfaces + +var _ = (oauth2.ClientStore)((*store.ClientStoreWithHash)(nil)) +var _ = (oauth2.SavingClientStore)((*store.ClientStoreWithHash)(nil)) + +var _ = (oauth2.ClientInfo)((*store.ClientInfoWithHash)(nil)) +var _ = (oauth2.ClientPasswordVerifier)((*store.ClientInfoWithHash)(nil))