Skip to content

Commit 000e63d

Browse files
Merge pull request #497 from okta/OKTA-733548
update oauth2 method, adding id to client assertion
2 parents 964f066 + 3a618dd commit 000e63d

File tree

5 files changed

+108
-244
lines changed

5 files changed

+108
-244
lines changed

.generator/templates/cache_test.go

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
package okta
22

33
import (
4-
"fmt"
54
"io"
65
"io/ioutil"
76
"net/http"
87
"net/http/httptest"
98
"testing"
109

11-
"github.com/jarcoal/httpmock"
1210
"github.com/stretchr/testify/assert"
13-
"github.com/stretchr/testify/require"
1411
)
1512

1613
func Test_Create_Cache_Key(t *testing.T) {
@@ -74,86 +71,3 @@ func Test_Cache_Cleared_Successful(t *testing.T) {
7471
found = myCache.Has(cacheKey)
7572
assert.False(t, found, "cache was not cleared")
7673
}
77-
78-
func TestOAuthTokensAlwaysCached(t *testing.T) {
79-
httpmock.Activate()
80-
defer httpmock.DeactivateAndReset()
81-
WithCache(false)
82-
cfg, err := NewConfiguration(
83-
WithCache(false),
84-
WithOrgUrl("https://testing.oktapreview.com"),
85-
WithAuthorizationMode("PrivateKey"),
86-
WithClientId("abc"),
87-
WithPrivateKey(`
88-
-----BEGIN RSA PRIVATE KEY-----
89-
MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu
90-
KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm
91-
o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k
92-
TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7
93-
9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy
94-
v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs
95-
/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00
96-
-----END RSA PRIVATE KEY-----
97-
`),
98-
WithScopes(([]string{"okta.users.read"})),
99-
)
100-
require.NoError(t, err, "Creating a new config should not error")
101-
102-
client := NewAPIClient(cfg)
103-
104-
accessToken := RequestAccessToken{
105-
TokenType: "Bearer",
106-
ExpiresIn: 3600,
107-
AccessToken: "xyz",
108-
Scope: "okta.users.read",
109-
}
110-
httpmockTokenURLRegex := `=~^https://testing\.oktapreview\.com/oauth2/v1/token\?client_assertion=.*\z`
111-
jsonResp, err := httpmock.NewJsonResponder(200, accessToken)
112-
require.NoError(t, err)
113-
httpmock.RegisterResponder("POST", httpmockTokenURLRegex, jsonResp)
114-
115-
adminConsole := Application{}
116-
adminConsole.SetId("abc123")
117-
adminConsole.SetStatus("ACTIVE")
118-
adminConsole.SetLabel("Okta Admin Console")
119-
apps1 := []*Application{
120-
&adminConsole,
121-
}
122-
jsonResp, err = httpmock.NewJsonResponder(200, apps1)
123-
require.NoError(t, err)
124-
httpmockAdminConsoleRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Admin\+Console.*\z`
125-
httpmock.RegisterResponder("GET", httpmockAdminConsoleRegex, jsonResp)
126-
127-
dashboard := Application{}
128-
adminConsole.SetId("def456")
129-
adminConsole.SetStatus("ACTIVE")
130-
adminConsole.SetLabel("Okta Dashboard")
131-
apps2 := []*Application{
132-
&dashboard,
133-
}
134-
jsonResp, err = httpmock.NewJsonResponder(200, apps2)
135-
require.NoError(t, err)
136-
httpmockDashboardRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Dashboard.*\z`
137-
httpmock.RegisterResponder("GET", httpmockDashboardRegex, jsonResp)
138-
139-
_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute()
140-
require.NoError(t, err)
141-
_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute()
142-
require.NoError(t, err)
143-
144-
_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute()
145-
require.NoError(t, err)
146-
_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute()
147-
require.NoError(t, err)
148-
149-
info := httpmock.GetCallCountInfo()
150-
totalCalls := httpmock.GetTotalCallCount()
151-
152-
assert.Equal(t, 5, totalCalls, fmt.Sprintf("there should only be 5 API calls in this test but there were %d calls", totalCalls))
153-
// Tokens from requests should be cached.
154-
require.True(t, info[fmt.Sprintf("POST %s", httpmockTokenURLRegex)] == 1, "tokens endpoint should only be called once")
155-
156-
// But all other requests should not be cached.
157-
require.True(t, info[fmt.Sprintf("GET %s", httpmockAdminConsoleRegex)] == 2)
158-
require.True(t, info[fmt.Sprintf("GET %s", httpmockDashboardRegex)] == 2)
159-
}

.generator/templates/client.mustache

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ var (
5050
)
5151

5252
const (
53-
VERSION = "{{{packageVersion}}}"
53+
VERSION = "{{{packageVersion}}}"
5454
AccessTokenCacheKey = "OKTA_ACCESS_TOKEN"
5555
DpopAccessTokenNonce = "DPOP_OKTA_ACCESS_TOKEN_NONCE"
5656
DpopAccessTokenPrivateKey = "DPOP_OKTA_ACCESS_TOKEN_PRIVATE_KEY"
@@ -59,9 +59,9 @@ const (
5959
// APIClient manages communication with the {{appName}} API v{{version}}
6060
// In most cases there should be only one, shared, APIClient.
6161
type APIClient struct {
62-
cfg *Configuration
63-
common service // Reuse a single struct instead of allocating one for each service on the heap.
64-
cache Cache
62+
cfg *Configuration
63+
common service // Reuse a single struct instead of allocating one for each service on the heap.
64+
cache Cache
6565
tokenCache *goCache.Cache
6666
freshcache bool
6767
@@ -196,7 +196,7 @@ func (a *PrivateKeyAuth) Authorize(method, URL string) error {
196196
return err
197197
}
198198

199-
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff)
199+
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, a.clientId, a.privateKeySigner)
200200
if err != nil {
201201
return err
202202
}
@@ -287,7 +287,7 @@ func (a *JWTAuth) Authorize(method, URL string) error {
287287
}
288288
}
289289
} else {
290-
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff)
290+
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil)
291291
if err != nil {
292292
return err
293293
}
@@ -408,7 +408,7 @@ func (a *JWKAuth) Authorize(method, URL string) error {
408408
return err
409409
}
410410

411-
accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff)
411+
accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil)
412412
if err != nil {
413413
return err
414414
}
@@ -446,16 +446,16 @@ func convertJWKToPrivateKey(jwks, encryptionType string) (string, error) {
446446
pair := it.Pair()
447447
key := pair.Value.(jwk.Key)
448448
var rawkey interface{} // This is the raw key, like *rsa.PrivateKey or *ecdsa.PrivateKey
449-
err := key.Raw(&rawkey);
449+
err := key.Raw(&rawkey)
450450
if err != nil {
451-
return "",err
451+
return "", err
452452
}
453453

454454
switch encryptionType {
455455
case "RSA":
456456
rsaPrivateKey, ok := rawkey.(*rsa.PrivateKey)
457457
if !ok {
458-
return "",fmt.Errorf("expected rsa key, got %T", rawkey)
458+
return "", fmt.Errorf("expected rsa key, got %T", rawkey)
459459
}
460460
return string(privateKeyToBytes(rsaPrivateKey)), nil
461461
default:
@@ -514,26 +514,25 @@ func createClientAssertion(orgURL, clientID string, privateKeySinger jose.Signer
514514
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))),
515515
Issuer: clientID,
516516
Audience: orgURL + "/oauth2/v1/token",
517+
ID: uuid.New().String(),
517518
}
518519
jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims)
519520
return jwtBuilder.CompactSerialize()
520521
}
521522

522-
func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
523-
var tokenRequestBuff io.ReadWriter
523+
func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
524524
query := url.Values{}
525525
tokenRequestURL := orgURL + "/oauth2/v1/token"
526526

527527
query.Add("grant_type", "client_credentials")
528528
query.Add("scope", strings.Join(scopes, " "))
529529
query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
530530
query.Add("client_assertion", clientAssertion)
531-
tokenRequestURL += "?" + query.Encode()
532-
tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff)
531+
532+
tokenRequest, err := http.NewRequest("POST", tokenRequestURL, strings.NewReader(query.Encode()))
533533
if err != nil {
534534
return nil, "", nil, err
535535
}
536-
537536
tokenRequest.Header.Add("Accept", "application/json")
538537
tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
539538
tokenRequest.Header.Add("User-Agent", userAgent)
@@ -552,14 +551,20 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio
552551
if err != nil {
553552
return nil, "", nil, err
554553
}
554+
555555
respBody, err := io.ReadAll(tokenResponse.Body)
556556
origResp := io.NopCloser(bytes.NewBuffer(respBody))
557557
tokenResponse.Body = origResp
558558
var accessToken *RequestAccessToken
559559

560+
newClientAssertion, err := createClientAssertion(orgURL, clientID, signer)
561+
if err != nil {
562+
return nil, "", nil, err
563+
}
564+
560565
if tokenResponse.StatusCode >= 300 {
561566
if strings.Contains(string(respBody), "invalid_dpop_proof") {
562-
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff)
567+
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff, newClientAssertion, strings.Join(scopes, " "), clientID, signer)
563568
} else {
564569
return nil, "", nil, err
565570
}
@@ -572,7 +577,7 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio
572577
return accessToken, "", nil, nil
573578
}
574579

575-
func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
580+
func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64, clientAssertion string, scopes string, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
576581
privateKey, err := generatePrivateKey(2048)
577582
if err != nil {
578583
return nil, "", nil, err
@@ -581,7 +586,19 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt
581586
if err != nil {
582587
return nil, "", nil, err
583588
}
589+
newClientAssertion, err := createClientAssertion(orgURL, clientID, signer)
590+
if err != nil {
591+
return nil, "", nil, err
592+
}
593+
594+
query := url.Values{}
595+
query.Add("grant_type", "client_credentials")
596+
query.Add("scope", scopes)
597+
query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
598+
query.Add("client_assertion", newClientAssertion)
599+
tokenRequest.Body = io.NopCloser(strings.NewReader(query.Encode()))
584600
tokenRequest.Header.Set("DPoP", dpopJWT)
601+
585602
bOff := &oktaBackoff{
586603
ctx: context.TODO(),
587604
maxRetries: maxRetries,
@@ -603,9 +620,9 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt
603620
}
604621

605622
if tokenResponse.StatusCode >= 300 {
606-
if strings.Contains(string(respBody), "use_dpop_nonce") {
623+
if strings.Contains(string(respBody), "use_dpop_nonce") {
607624
newNonce := tokenResponse.Header.Get("Dpop-Nonce")
608-
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff)
625+
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff, clientAssertion, scopes, clientID, signer)
609626
} else {
610627
return nil, "", nil, err
611628
}
@@ -780,9 +797,9 @@ func (c *APIClient) GetConfig() *Configuration {
780797
}
781798

782799
type formFile struct {
783-
fileBytes []byte
784-
fileName string
785-
formFileName string
800+
fileBytes []byte
801+
fileName string
802+
formFileName string
786803
}
787804

788805
// prepareRequest build the request
@@ -836,11 +853,11 @@ func (c *APIClient) prepareRequest(
836853
w.Boundary()
837854
part, err := w.CreateFormFile(formFile.formFileName, filepath.Base(formFile.fileName))
838855
if err != nil {
839-
return nil, err
856+
return nil, err
840857
}
841858
_, err = part.Write(formFile.fileBytes)
842859
if err != nil {
843-
return nil, err
860+
return nil, err
844861
}
845862
}
846863
}
@@ -879,7 +896,7 @@ func (c *APIClient) prepareRequest(
879896
URL.Scheme = c.cfg.Scheme
880897
}
881898

882-
var urlWithoutQuery = *URL
899+
urlWithoutQuery := *URL
883900

884901
// Adding Query Param
885902
query := URL.Query()
@@ -1103,7 +1120,7 @@ func (c *APIClient) RefreshNext() *APIClient {
11031120
return c
11041121
}
11051122

1106-
func (c *APIClient) do(ctx context.Context, req *http.Request)(*http.Response, error){
1123+
func (c *APIClient) do(ctx context.Context, req *http.Request) (*http.Response, error) {
11071124
cacheKey := CreateCacheKey(req)
11081125
if req.Method != http.MethodGet {
11091126
c.cache.Delete(cacheKey)
@@ -1343,9 +1360,9 @@ func (e GenericOpenAPIError) Model() interface{} {
13431360

13441361
// Okta Backoff
13451362
type oktaBackoff struct {
1346-
retryCount, maxRetries int32
1347-
backoffDuration time.Duration
1348-
ctx context.Context
1363+
retryCount, maxRetries int32
1364+
backoffDuration time.Duration
1365+
ctx context.Context
13491366
}
13501367

13511368
// NextBackOff returns the duration to wait before retrying the operation,
@@ -1456,7 +1473,7 @@ func generateDpopJWT(privateKey *rsa.PrivateKey, httpMethod, URL, nonce, accessT
14561473
return "", err
14571474
}
14581475
key := jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}
1459-
var signerOpts = jose.SignerOptions{}
1476+
signerOpts := jose.SignerOptions{}
14601477
signerOpts.WithType("dpop+jwt")
14611478
signerOpts.WithHeader("jwk", set)
14621479
rsaSigner, err := jose.NewSigner(key, &signerOpts)

0 commit comments

Comments
 (0)