From 6e63daa66d9ce15025296e79cdd2ab830d9946b8 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Sat, 14 Jun 2025 10:29:47 +0100 Subject: [PATCH 01/35] xds: read JWT credentials from file as per A97 --- credentials/jwt/doc.go | 57 ++ credentials/jwt/example_test.go | 58 ++ credentials/jwt/jwt_token_file.go | 287 +++++++ credentials/jwt/jwt_token_file_test.go | 784 ++++++++++++++++++ internal/envconfig/xds.go | 5 + internal/xds/bootstrap/bootstrap.go | 87 +- internal/xds/bootstrap/bootstrap_test.go | 606 +++++++++++++- internal/xds/bootstrap/jwtcreds/bundle.go | 81 ++ .../xds/bootstrap/jwtcreds/bundle_test.go | 214 +++++ xds/bootstrap/bootstrap.go | 4 + xds/bootstrap/bootstrap_test.go | 42 +- xds/bootstrap/credentials.go | 14 + xds/internal/xdsclient/clientimpl.go | 4 +- xds/internal/xdsclient/clientimpl_test.go | 91 ++ 14 files changed, 2320 insertions(+), 14 deletions(-) create mode 100644 credentials/jwt/doc.go create mode 100644 credentials/jwt/example_test.go create mode 100644 credentials/jwt/jwt_token_file.go create mode 100644 credentials/jwt/jwt_token_file_test.go create mode 100644 internal/xds/bootstrap/jwtcreds/bundle.go create mode 100644 internal/xds/bootstrap/jwtcreds/bundle_test.go diff --git a/credentials/jwt/doc.go b/credentials/jwt/doc.go new file mode 100644 index 000000000000..a3d561ee12a1 --- /dev/null +++ b/credentials/jwt/doc.go @@ -0,0 +1,57 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwt implements JWT token file-based call credentials. +// +// This package provides support for A97 JWT Call Credentials, allowing gRPC +// clients to authenticate using JWT tokens read from files. While originally +// designed for xDS environments, these credentials are general-purpose. +// +// # Usage +// +// The credentials can be used directly: +// +// import "google.golang.org/grpc/credentials/jwt" +// +// creds, err := jwt.NewTokenFileCallCredentials("/path/to/jwt.token") +// if err != nil { +// log.Fatal(err) +// } +// +// conn, err := grpc.NewClient("example.com:443", grpc.WithPerRPCCredentials(creds)) +// +// Or configured via xDS bootstrap file; see grpc/xds/bootstrap for details. +// +// # Token Requirements +// +// JWT tokens must: +// - Be valid, well-formed JWT tokens with header, payload, and signature +// - Include an "exp" (expiration) claim +// - Be readable from the specified file path +// +// # Considerations +// +// - Tokens are cached until expiration to avoid excessive file I/O +// - Transport security is required (RequireTransportSecurity returns true) +// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or UNAUTHENTICATED errors +// - These errors are cached and retried with exponential backoff. +// +// This implementation is originally intended for use in service mesh +// environments like Istio where JWT tokens are provisioned and rotated by the +// infrastructure. +package jwt diff --git a/credentials/jwt/example_test.go b/credentials/jwt/example_test.go new file mode 100644 index 000000000000..14716f8e35ac --- /dev/null +++ b/credentials/jwt/example_test.go @@ -0,0 +1,58 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwt_test + +import ( + "context" + "log" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/jwt" +) + +// ExampleNewTokenFileCallCredentials demonstrates how to create and use JWT +// token file call credentials for authentication. +func ExampleNewTokenFileCallCredentials() { + // Create JWT call credentials that read tokens from a file + creds, err := jwt.NewTokenFileCallCredentials( + "/path/to/jwt.token", // Path to JWT token file + ) + if err != nil { + log.Fatalf("Failed to create JWT credentials: %v", err) + } + + // Use the credentials when creating a gRPC connection + conn, err := grpc.NewClient( + "service.example.com:443", + grpc.WithPerRPCCredentials(creds), + // ... other dial options + ) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Use the connection for RPC calls + // The JWT token will be automatically included in the authorization header + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + // ... make RPC calls using ctx and conn + _ = ctx +} diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go new file mode 100644 index 000000000000..62b63e963f3f --- /dev/null +++ b/credentials/jwt/jwt_token_file.go @@ -0,0 +1,287 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwt implements gRPC credentials using JWT tokens from files. +package jwt + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + "sync" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/status" +) + +// jwtClaims represents the JWT claims structure for extracting expiration time. +type jwtClaims struct { + Exp int64 `json:"exp"` +} + +// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads +// tokens from a file. +// This implementation follows the A97 JWT Call Credentials specification. +type jwtTokenFileCallCreds struct { + tokenFilePath string + + // Cached token data + mu sync.RWMutex + cachedToken string + cachedExpiration time.Time // Slightly reduced expiration time compared to the actual exp + + // Error caching with backoff + cachedError error // Cached error from last failed attempt + cachedErrorTime time.Time // When the error was cached + backoffStrategy backoff.Strategy // Backoff strategy when error occurs + retryAttempt int // Current retry attempt number + nextRetryTime time.Time // When next retry is allowed + + // Pre-emptive refresh mutex + refreshMu sync.Mutex +} + +// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens +// from the specified file path. +// +// tokenFilePath is the filepath to the JWT token file. +func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) { + if tokenFilePath == "" { + return nil, fmt.Errorf("tokenFilePath cannot be empty") + } + + return &jwtTokenFileCallCreds{ + tokenFilePath: tokenFilePath, + backoffStrategy: backoff.DefaultExponential, + }, nil +} + +// GetRequestMetadata gets the current request metadata, refreshing tokens +// if required. This implementation follows the PerRPCCredentials interface. +// The tokens will get automatically refreshed if they are about to expire or if +// they haven't been loaded successfully yet. In the latter case, a backoff is +// applied before retrying. +// If it's not possible to extract a token from the file, UNAVAILABLE is returned. +// If the token is extracted but invalid, then UNAUTHENTICATED is returned. +func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err) + } + + // this may be delayed if the token needs to be refreshed from file + token, err := c.getToken(ctx) + if err != nil { + return nil, err + } + + return map[string]string{ + "authorization": "Bearer " + token, + }, nil +} + +// RequireTransportSecurity indicates whether the credentials requires +// transport security. +func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool { + return true +} + +// getToken returns a valid JWT token, reading from file if necessary. +// Implements pre-emptive refresh and caches errors with backoff. +func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { + c.mu.RLock() + + if c.isTokenValid() { + token := c.cachedToken + shouldRefresh := c.needsPreemptiveRefresh() + c.mu.RUnlock() + + if shouldRefresh { + c.triggerPreemptiveRefresh() + } + return token, nil + } + + // if still within backoff period, return cached error to avoid repeated file reads + if c.cachedError != nil && time.Now().Before(c.nextRetryTime) { + err := c.cachedError + c.mu.RUnlock() + return "", err + } + + c.mu.RUnlock() + // Token is expired or missing or the retry backoff period has expired. So + // refresh synchronously. + // NOTE: refreshTokenSync itself acquires the write lock + return c.refreshTokenSync(ctx, false) +} + +// isTokenValid checks if the cached token is still valid. +// Caller must hold c.mu.RLock(). +func (c *jwtTokenFileCallCreds) isTokenValid() bool { + if c.cachedToken == "" { + return false + } + return c.cachedExpiration.After(time.Now()) +} + +// needsPreemptiveRefresh checks if a pre-emptive refresh should be triggered. +// Returns true if the cached token is valid but expires within 1 minute. +// We only trigger pre-emptive refresh for valid tokens - if the token is invalid +// or expired, the next RPC will handle synchronous refresh instead. +// Caller must hold c.mu.RLock(). +func (c *jwtTokenFileCallCreds) needsPreemptiveRefresh() bool { + return c.isTokenValid() && time.Until(c.cachedExpiration) < time.Minute +} + +// triggerPreemptiveRefresh starts a background refresh if needed. +// Multiple concurrent calls are safe - only one refresh will run at a time. +// The refresh runs in a separate goroutine and does not block the caller. +func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { + go func() { + c.refreshMu.Lock() + defer c.refreshMu.Unlock() + + // Re-check if refresh is still needed under mutex + c.mu.RLock() + stillNeeded := c.needsPreemptiveRefresh() + c.mu.RUnlock() + + if !stillNeeded { + return // Another goroutine already refreshed or token expired + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Force refresh to read new token even if current one is still valid + _, _ = c.refreshTokenSync(ctx, true) + }() +} + +// refreshTokenSync reads a new token from the file and updates the cache. If +// preemptiveRefresh is true, bypasses the validity check of the currently cached +// token and always reads from file. +// This is used for pre-emptive refresh to ensure new tokens are loaded even when +// the cached token is still valid. If preemptiveRefresh is false, skips file read +// when cached token is still valid, optimizing concurrent synchronous refresh calls +// where one RPC may have already updated the cache while another was waiting on the lock. +func (c *jwtTokenFileCallCreds) refreshTokenSync(_ context.Context, preemptiveRefresh bool) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + + // Double-check under write lock but skip if preemptive refresh is requested + if !preemptiveRefresh && c.isTokenValid() { + return c.cachedToken, nil + } + + tokenBytes, err := os.ReadFile(c.tokenFilePath) + if err != nil { + err = status.Errorf(codes.Unavailable, "failed to read token file %q: %v", c.tokenFilePath, err) + c.setErrorWithBackoff(err) + return "", err + } + + token := strings.TrimSpace(string(tokenBytes)) + if token == "" { + err := status.Errorf(codes.Unavailable, "token file %q is empty", c.tokenFilePath) + c.setErrorWithBackoff(err) + return "", err + } + + // Parse JWT to extract expiration + exp, err := c.extractExpiration(token) + if err != nil { + err = status.Errorf(codes.Unauthenticated, "failed to parse JWT from token file %q: %v", c.tokenFilePath, err) + c.setErrorWithBackoff(err) + return "", err + } + + // Success - clear any cached error and backoff state, update token cache + c.clearErrorAndBackoff() + c.cachedToken = token + // Per RFC A97: consider token invalid if it expires within the next 30 + // seconds to accommodate for clock skew and server processing time. + c.cachedExpiration = exp.Add(-30 * time.Second) + + return token, nil +} + +// extractExpiration parses the JWT token to extract the expiration time. +func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + + // Decode the payload (second part) + payload := parts[1] + + // Add padding if necessary for base64 decoding + for len(payload)%4 != 0 { + payload += "=" + } + + payloadBytes, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err) + } + + var claims jwtClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err) + } + + if claims.Exp == 0 { + return time.Time{}, fmt.Errorf("JWT token has no expiration claim") + } + + expTime := time.Unix(claims.Exp, 0) + + // Check if token is already expired + if expTime.Before(time.Now()) { + return time.Time{}, fmt.Errorf("JWT token is expired") + } + + return expTime, nil +} + +// setErrorWithBackoff caches an error and calculates the next retry time using exponential backoff. +// Caller must hold c.mu write lock. +func (c *jwtTokenFileCallCreds) setErrorWithBackoff(err error) { + c.cachedError = err + c.cachedErrorTime = time.Now() + c.retryAttempt++ + backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1) + c.nextRetryTime = time.Now().Add(backoffDelay) +} + +// clearErrorAndBackoff clears the cached error and resets backoff state. +// Caller must hold c.mu write lock. +func (c *jwtTokenFileCallCreds) clearErrorAndBackoff() { + c.cachedError = nil + c.cachedErrorTime = time.Time{} + c.retryAttempt = 0 + c.nextRetryTime = time.Time{} +} diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go new file mode 100644 index 000000000000..c611cc838d28 --- /dev/null +++ b/credentials/jwt/jwt_token_file_test.go @@ -0,0 +1,784 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwt + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/status" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestNewTokenFileCallCredentials(t *testing.T) { + tests := []struct { + name string + tokenFilePath string + wantErr bool + wantErrContains string + }{ + { + name: "valid parameters", + tokenFilePath: "/path/to/token", + wantErr: false, + }, + { + name: "empty token file path", + tokenFilePath: "", + wantErr: true, + wantErrContains: "tokenFilePath cannot be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds, err := NewTokenFileCallCredentials(tt.tokenFilePath) + if tt.wantErr { + if err == nil { + t.Fatalf("NewTokenFileCallCredentials() expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErrContains) + } + return + } + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err) + } + if creds == nil { + t.Fatal("NewTokenFileCallCredentials() returned nil credentials") + } + }) + } +} + +func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) { + creds, err := NewTokenFileCallCredentials("/path/to/token") + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + if !creds.RequireTransportSecurity() { + t.Error("RequireTransportSecurity() = false, want true") + } +} + +func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { + tempDir, err := os.MkdirTemp("", "jwt_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + now := time.Now().Truncate(time.Second) + tests := []struct { + name string + tokenContent string + authInfo credentials.AuthInfo + wantErr bool + wantErrContains string + wantMetadata map[string]string + }{ + { + name: "valid token without expiration errors", + tokenContent: createTestJWT(t, "", time.Time{}), + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + wantErr: true, + wantErrContains: "JWT token has no expiration claim", + }, + { + name: "valid token with future expiration succeeds", + tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)), + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + wantErr: false, + wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, "https://example.com", now.Add(time.Hour))}, + }, + { + name: "insufficient security level", + tokenContent: createTestJWT(t, "", time.Time{}), + authInfo: &testAuthInfo{secLevel: credentials.NoSecurity}, + wantErr: true, + wantErrContains: "unable to transfer JWT token file PerRPCCredentials", + }, + { + name: "expired token errors", + tokenContent: createTestJWT(t, "", now.Add(-time.Hour)), + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + wantErr: true, + wantErrContains: "JWT token is expired", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenFile := filepath.Join(tempDir, "token") + if err := os.WriteFile(tokenFile, []byte(tt.tokenContent), 0600); err != nil { + t.Fatalf("Failed to write token file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: tt.authInfo, + }) + + metadata, err := creds.GetRequestMetadata(ctx) + if tt.wantErr { + if err == nil { + t.Fatalf("GetRequestMetadata() expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains) + } + return + } + + if err != nil { + t.Fatalf("GetRequestMetadata() unexpected error: %v", err) + } + + if len(metadata) != len(tt.wantMetadata) { + t.Fatalf("GetRequestMetadata() returned %d metadata entries, want %d", len(metadata), len(tt.wantMetadata)) + } + + for k, v := range tt.wantMetadata { + if metadata[k] != v { + t.Errorf("GetRequestMetadata() metadata[%q] = %q, want %q", k, metadata[k], v) + } + } + }) + } +} + +func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { + tempDir, err := os.MkdirTemp("", "jwt_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + tokenFile := filepath.Join(tempDir, "token") + token := createTestJWT(t, "", time.Now().Add(time.Hour)) + + if err := os.WriteFile(tokenFile, []byte(token), 0600); err != nil { + t.Fatalf("Failed to write token file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // First call should read from file + metadata1, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("First GetRequestMetadata() failed: %v", err) + } + + // Update the file with a different token + newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { + t.Fatalf("Failed to update token file: %v", err) + } + + // Second call should return cached token (not the updated one) + metadata2, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("Second GetRequestMetadata() failed: %v", err) + } + + if metadata1["authorization"] != metadata2["authorization"] { + t.Error("Expected cached token to be returned, but got different token") + } +} + +func (s) TestTokenFileCallCreds_FileErrors(t *testing.T) { + tests := []struct { + name string + setupFile func(string) error + wantErrContains string + }{ + { + name: "nonexistent file", + setupFile: func(_ string) error { + return nil // Don't create the file + }, + wantErrContains: "failed to read token file", + }, + { + name: "empty file", + setupFile: func(path string) error { + return os.WriteFile(path, []byte(""), 0600) + }, + wantErrContains: "token file", + }, + { + name: "file with whitespace only", + setupFile: func(path string) error { + return os.WriteFile(path, []byte(" \n\t "), 0600) + }, + wantErrContains: "token file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "jwt_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + tokenFile := filepath.Join(tempDir, "token") + if err := tt.setupFile(tokenFile); err != nil { + t.Fatalf("Failed to setup test file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + _, err = creds.GetRequestMetadata(ctx) + if err == nil { + t.Fatal("GetRequestMetadata() expected error, got nil") + } + + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains) + } + }) + } +} + +// testAuthInfo implements credentials.AuthInfo for testing. +type testAuthInfo struct { + secLevel credentials.SecurityLevel +} + +func (t *testAuthInfo) AuthType() string { + return "test" +} + +func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} +} + +// createTestJWT creates a test JWT token with the specified audience and expiration. +func createTestJWT(t *testing.T, audience string, expiration time.Time) string { + t.Helper() + + header := map[string]any{ + "typ": "JWT", + "alg": "HS256", + } + + claims := map[string]any{} + if audience != "" { + claims["aud"] = audience + } + if !expiration.IsZero() { + claims["exp"] = expiration.Unix() + } + + headerBytes, err := json.Marshal(header) + if err != nil { + t.Fatalf("Failed to marshal header: %v", err) + } + + claimsBytes, err := json.Marshal(claims) + if err != nil { + t.Fatalf("Failed to marshal claims: %v", err) + } + + headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes) + + // Remove padding for URL-safe base64 + headerB64 = strings.TrimRight(headerB64, "=") + claimsB64 = strings.TrimRight(claimsB64, "=") + + // For testing, we'll use a fake signature + signature := base64.URLEncoding.EncodeToString([]byte("fake_signature")) + signature = strings.TrimRight(signature, "=") + + return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature) +} + +// Tests that cached token expiration is set to 30 seconds before actual token expiration. +func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { + tempDir := t.TempDir() + tokenFile := filepath.Join(tempDir, "token") + + // Create token that expires in 2 hours + tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) + token := createTestJWT(t, "", tokenExp) + if err := os.WriteFile(tokenFile, []byte(token), 0600); err != nil { + t.Fatalf("Failed to write token file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // Get token to trigger caching + _, err = creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata() failed: %v", err) + } + + // Verify cached expiration is 30 seconds before actual token expiration + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + cachedExp := impl.cachedExpiration + impl.mu.RUnlock() + + expectedExp := tokenExp.Add(-30 * time.Second) + if !cachedExp.Equal(expectedExp) { + t.Errorf("cache expiration = %v, want %v", cachedExp, expectedExp) + } +} + +// Tests that pre-emptive refresh is triggered within 1 minute of expiration. +func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { + tempDir := t.TempDir() + tokenFile := filepath.Join(tempDir, "token") + + // Create token that expires in 80 seconds (=> cache expires in ~50s) + // This ensures pre-emptive refresh triggers since 50s < the 1 minute check + tokenExp := time.Now().Add(80 * time.Second) + expiringToken := createTestJWT(t, "", tokenExp) + if err := os.WriteFile(tokenFile, []byte(expiringToken), 0600); err != nil { + t.Fatalf("Failed to write token file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // Get token - should trigger pre-emptive refresh + metadata1, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata() failed: %v", err) + } + + // Verify token was cached and check if refresh should be triggered + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + cacheExp := impl.cachedExpiration + tokenCached := impl.cachedToken != "" + shouldTriggerRefresh := impl.needsPreemptiveRefresh() + impl.mu.RUnlock() + + if !tokenCached { + t.Error("token should be cached after successful GetRequestMetadata") + } + + if !shouldTriggerRefresh { + timeUntilExp := time.Until(cacheExp) + t.Errorf("cache expires in %v, should be < 1 minute to trigger pre-emptive refresh", timeUntilExp) + } + + // Create new token file with different expiration while refresh is happening + newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { + t.Fatalf("Failed to write updated token file: %v", err) + } + + // Get token again - should trigger a refresh given that the first one was + // cached but expiring soon + // However, the function should have returned right away with the current cached token + metadata2, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("Second GetRequestMetadata() failed: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // now should get the new token + metadata3, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("Second GetRequestMetadata() failed: %v", err) + } + + // If pre-emptive refresh worked, we should get the new token + expectedAuth1 := "Bearer " + expiringToken + expectedAuth2 := "Bearer " + expiringToken + expectedAuth3 := "Bearer " + newToken + + actualAuth1 := metadata1["authorization"] + actualAuth2 := metadata2["authorization"] + actualAuth3 := metadata3["authorization"] + + if actualAuth1 != expectedAuth1 { + t.Errorf("First call should return original token: got %q, want %q", actualAuth1, expectedAuth1) + } + + if actualAuth2 != expectedAuth2 { + t.Errorf("Second call should return the original token: got %q, want %q", actualAuth2, expectedAuth2) + } + if actualAuth3 != expectedAuth3 { + t.Errorf("Third call should return the original token: got %q, want %q", actualAuth3, expectedAuth3) + } +} + +// Tests that backoff behavior handles file read errors correctly. +func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { + // This test has the following flow: + // First call to GetRequestMetadata() fails with UNAVAILABLE due to a missing file. + // Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff. + // Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry. + // Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff even though file exists. + // Fifth call to GetRequestMetadata() succeeds after creating the file. + tempDir := t.TempDir() + nonExistentFile := filepath.Join(tempDir, "nonexistent") + + creds, err := NewTokenFileCallCredentials(nonExistentFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // First call should fail with UNAVAILABLE + _, err1 := creds.GetRequestMetadata(ctx) + if err1 == nil { + t.Fatal("Expected error from nonexistent file") + } + if status.Code(err1) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err1)) + } + + // Verify error is cached internally + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + cachedErr := impl.cachedError + cachedErrTime := impl.cachedErrorTime + retryAttempt := impl.retryAttempt + nextRetryTime := impl.nextRetryTime + impl.mu.RUnlock() + + if cachedErr == nil { + t.Error("error should be cached internally after failed file read") + } + if cachedErrTime.IsZero() { + t.Error("error cache time should be set") + } + if retryAttempt != 1 { + t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt) + } + if nextRetryTime.IsZero() || nextRetryTime.Before(time.Now()) { + t.Error("Next retry time should be set to future time") + } + + // Second call should still return cached error + _, err2 := creds.GetRequestMetadata(ctx) + if err2 == nil { + t.Fatal("Expected cached error") + } + if status.Code(err2) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err2)) + } + if err1.Error() != err2.Error() { + t.Errorf("cached error = %q, want %q", err2.Error(), err1.Error()) + } + + impl.mu.RLock() + retryAttempt2 := impl.retryAttempt + nextRetryTime2 := impl.nextRetryTime + impl.mu.RUnlock() + + if !nextRetryTime2.Equal(nextRetryTime) { + t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, nextRetryTime) + } + if retryAttempt2 != 1 { + t.Error("retry attempt should not change due to backoff") + } + + // fast-forward the backoff retry time to allow next retry attempt + impl.mu.Lock() + impl.nextRetryTime = time.Now().Add(-1 * time.Minute) + impl.mu.Unlock() + + // Third call should retry but still fail with UNAVAILABLE + _, err3 := creds.GetRequestMetadata(ctx) + if err3 == nil { + t.Fatal("Expected cached error") + } + if status.Code(err3) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err3)) + } + if err3.Error() != err1.Error() { + t.Errorf("cached error = %q, want %q", err3.Error(), err1.Error()) + } + + impl.mu.RLock() + retryAttempt3 := impl.retryAttempt + nextRetryTime3 := impl.nextRetryTime + impl.mu.RUnlock() + + if !nextRetryTime3.After(nextRetryTime2) { + t.Error("nextRetryTime should not change due to backoff") + } + if retryAttempt3 != 2 { + t.Error("retry attempt should not change due to backoff") + } + + // Create valid token file + validToken := createTestJWT(t, "", time.Now().Add(time.Hour)) + if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil { + t.Fatalf("Failed to create valid token file: %v", err) + } + + // Forth call should still fail even though the file now exists + _, err4 := creds.GetRequestMetadata(ctx) + if err4 == nil { + t.Fatal("Expected cached error") + } + if status.Code(err4) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err4)) + } + if err4.Error() != err3.Error() { + t.Errorf("cached error = %q, want %q", err4.Error(), err3.Error()) + } + + impl.mu.RLock() + retryAttempt4 := impl.retryAttempt + nextRetryTime4 := impl.nextRetryTime + impl.mu.RUnlock() + + if !nextRetryTime4.Equal(nextRetryTime3) { + t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, nextRetryTime3) + } + if retryAttempt4 != retryAttempt3 { + t.Error("retry attempt should not change due to backoff") + } + + // fast-forward the backoff retry time to allow next retry attempt + impl.mu.Lock() + impl.nextRetryTime = time.Now().Add(-1 * time.Minute) + impl.mu.Unlock() + // Fifth call should succeed since the file now exists + // and the backoff has expired + _, err5 := creds.GetRequestMetadata(ctx) + if err5 != nil { + t.Errorf("after creating valid token file, GetRequestMetadata() should eventually succeed, but got: %v", err5) + t.Error("backoff should expire and trigger new attempt on next RPC") + } else { + // If successful, verify error cache and backoff state were cleared + impl.mu.RLock() + clearedErr := impl.cachedError + clearedErrTime := impl.cachedErrorTime + retryAttempt := impl.retryAttempt + nextRetryTime := impl.nextRetryTime + impl.mu.RUnlock() + + if clearedErr != nil { + t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr) + } + if !clearedErrTime.IsZero() { + t.Error("after successful retry, cached error time should be cleared") + } + if retryAttempt != 0 { + t.Errorf("after successful retry, retry attempt should be reset, got: %d", retryAttempt) + } + if !nextRetryTime.IsZero() { + t.Error("after successful retry, next retry time should be cleared") + } + } +} + +// Tests that invalid JWT tokens are handled with UNAUTHENTICATED status. +func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { + tempDir := t.TempDir() + tokenFile := filepath.Join(tempDir, "token") + + // Write invalid JWT (missing exp field) + invalidJWT := createTestJWT(t, "", time.Time{}) // No expiration + if err := os.WriteFile(tokenFile, []byte(invalidJWT), 0600); err != nil { + t.Fatalf("Failed to write token file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + _, err = creds.GetRequestMetadata(ctx) + if err == nil { + t.Fatal("Expected UNAUTHENTICATED from invalid JWT") + } + if status.Code(err) != codes.Unauthenticated { + t.Errorf("GetRequestMetadata() = %v, want UNAUTHENTICATED for invalid JWT", status.Code(err)) + } +} + +// Tests that RPCs are queued during file operations and all receive the same result. +func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { + tempDir := t.TempDir() + tokenFile := filepath.Join(tempDir, "token") + + // Start with no token file to force file read during first RPC + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // Launch multiple concurrent RPCs before creating the token file + const numConcurrentRPCs = 5 + results := make(chan error, numConcurrentRPCs) + + for range numConcurrentRPCs { + go func() { + _, err := creds.GetRequestMetadata(ctx) + results <- err + }() + } + + // Collect all results - they should all be the same error (UNAVAILABLE) + var errors []error + for range numConcurrentRPCs { + err := <-results + errors = append(errors, err) + } + + // All RPCs should fail with the same error (file not found) + for i, err := range errors { + if err == nil { + t.Errorf("RPC %d should have failed with UNAVAILABLE", i) + continue + } + if status.Code(err) != codes.Unavailable { + t.Errorf("RPC %d = %v, want UNAVAILABLE", i, status.Code(err)) + } + if i > 0 && err.Error() != errors[0].Error() { + t.Errorf("RPC %d error should match first RPC error for proper queueing", i) + } + } + + // Verify error was cached after concurrent RPCs + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + finalCachedErr := impl.cachedError + impl.mu.RUnlock() + + if finalCachedErr == nil { + t.Error("error should be cached after failed concurrent RPCs") + } + if finalCachedErr.Error() != errors[0].Error() { + t.Error("cached error should match the errors returned to RPCs") + } +} + +// Tests that no background retries occur when channel is idle. +func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { + tempDir := t.TempDir() + tokenFilepath := filepath.Join(tempDir, "token") + + newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + if err := os.WriteFile(tokenFilepath, []byte(newToken), 0600); err != nil { + t.Fatalf("Failed to write updated token file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFilepath) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + impl := creds.(*jwtTokenFileCallCreds) + + // Verify state unchanged - no background file reads attempted + impl.mu.RLock() + token := impl.cachedToken + cachedErr := impl.cachedError + impl.mu.RUnlock() + + time.Sleep(100 * time.Millisecond) + + if token != "" { + t.Errorf("after idle period, cached token = %q, want empty (no background reads)", token) + } + if cachedErr != nil { + t.Errorf("after idle period, cached error = %v, want nil (no background reads)", cachedErr) + } +} diff --git a/internal/envconfig/xds.go b/internal/envconfig/xds.go index e87551552ad7..6420558c0b7a 100644 --- a/internal/envconfig/xds.go +++ b/internal/envconfig/xds.go @@ -68,4 +68,9 @@ var ( // trust. For more details, see: // https://github.com/grpc/proposal/blob/master/A87-mtls-spiffe-support.md XDSSPIFFEEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false) + + // XDSBootstrapCallCredsEnabled controls if JWT call credentials can be used + // in xDS bootstrap configuration. For more details, see: + // https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md + XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false) ) diff --git a/internal/xds/bootstrap/bootstrap.go b/internal/xds/bootstrap/bootstrap.go index f409e4bd77b2..46dbf6bc98bc 100644 --- a/internal/xds/bootstrap/bootstrap.go +++ b/internal/xds/bootstrap/bootstrap.go @@ -31,6 +31,7 @@ import ( "strings" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -64,11 +65,26 @@ type ChannelCreds struct { Config json.RawMessage `json:"config,omitempty"` } +// CallCreds contains the call credentials configuration for individual RPCs. +// This type implements RFC A97 call credentials structure. +type CallCreds struct { + // Type contains a unique name identifying the call credentials type. + // Currently only "jwt_token_file" is supported. + Type string `json:"type,omitempty"` + // Config contains the JSON configuration associated with the call credentials. + Config json.RawMessage `json:"config,omitempty"` +} + // Equal reports whether cc and other are considered equal. func (cc ChannelCreds) Equal(other ChannelCreds) bool { return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) } +// Equal reports whether cc and other are considered equal. +func (cc CallCreds) Equal(other CallCreds) bool { + return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) +} + // String returns a string representation of the credentials. It contains the // type and the config (if non-nil) separated by a "-". func (cc ChannelCreds) String() string { @@ -172,13 +188,15 @@ type ServerConfig struct { serverURI string channelCreds []ChannelCreds serverFeatures []string + callCreds []CallCreds // As part of unmarshalling the JSON config into this struct, we ensure that // the credentials config is valid by building an instance of the specified // credentials and store it here for easy access. - selectedCreds ChannelCreds - credsDialOption grpc.DialOption - extraDialOptions []grpc.DialOption + selectedCreds ChannelCreds + credsDialOption grpc.DialOption + extraDialOptions []grpc.DialOption + selectedCallCreds []credentials.PerRPCCredentials // Built call credentials cleanups []func() } @@ -200,6 +218,17 @@ func (sc *ServerConfig) ServerFeatures() []string { return sc.serverFeatures } +// CallCreds returns the call credentials configuration for this server. +func (sc *ServerConfig) CallCreds() []CallCreds { + return sc.callCreds +} + +// SelectedCallCreds returns the built call credentials that are ready to use. +// These are the credentials that were successfully built from the call_creds configuration. +func (sc *ServerConfig) SelectedCallCreds() []credentials.PerRPCCredentials { + return sc.selectedCallCreds +} + // ServerFeaturesIgnoreResourceDeletion returns true if this server supports a // feature where the xDS client can ignore resource deletions from this server, // as described in gRFC A53. @@ -233,6 +262,28 @@ func (sc *ServerConfig) DialOptions() []grpc.DialOption { return dopts } +// DialOptionsWithCallCredsForTransport returns dial options including call credentials +// only if they are compatible with the specified transport credentials type. +// Call credentials that require transport security will be skipped for insecure transports. +func (sc *ServerConfig) DialOptionsWithCallCredsForTransport(transportCredsType string, transportCreds credentials.TransportCredentials) []grpc.DialOption { + dopts := sc.DialOptions() + + // Check if transport is insecure + isInsecureTransport := transportCredsType == "insecure" || + (transportCreds != nil && transportCreds.Info().SecurityProtocol == "insecure") + + // Add call credentials only if compatible with transport security + for _, callCred := range sc.selectedCallCreds { + // Skip call credentials that require transport security on insecure transports + if isInsecureTransport && callCred.RequireTransportSecurity() { + continue + } + dopts = append(dopts, grpc.WithPerRPCCredentials(callCred)) + } + + return dopts +} + // Cleanups returns a collection of functions to be called when the xDS client // for this server is closed. Allows cleaning up resources created specifically // for this server. @@ -251,6 +302,8 @@ func (sc *ServerConfig) Equal(other *ServerConfig) bool { return false case !slices.EqualFunc(sc.channelCreds, other.channelCreds, func(a, b ChannelCreds) bool { return a.Equal(b) }): return false + case !slices.EqualFunc(sc.callCreds, other.callCreds, func(a, b CallCreds) bool { return a.Equal(b) }): + return false case !slices.Equal(sc.serverFeatures, other.serverFeatures): return false case !sc.selectedCreds.Equal(other.selectedCreds): @@ -273,6 +326,7 @@ type serverConfigJSON struct { ServerURI string `json:"server_uri,omitempty"` ChannelCreds []ChannelCreds `json:"channel_creds,omitempty"` ServerFeatures []string `json:"server_features,omitempty"` + CallCreds []CallCreds `json:"call_creds,omitempty"` } // MarshalJSON returns marshaled JSON bytes corresponding to this server config. @@ -281,6 +335,7 @@ func (sc *ServerConfig) MarshalJSON() ([]byte, error) { ServerURI: sc.serverURI, ChannelCreds: sc.channelCreds, ServerFeatures: sc.serverFeatures, + CallCreds: sc.callCreds, } return json.Marshal(server) } @@ -301,6 +356,7 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.serverURI = server.ServerURI sc.channelCreds = server.ChannelCreds sc.serverFeatures = server.ServerFeatures + sc.callCreds = server.CallCreds for _, cc := range server.ChannelCreds { // We stop at the first credential type that we support. @@ -320,6 +376,27 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.cleanups = append(sc.cleanups, cancel) break } + + // Process call credentials - unlike channel creds, we use ALL supported types + // Call credentials are optional per RFC A97 + for _, callCred := range server.CallCreds { + c := bootstrap.GetCredentials(callCred.Type) + if c == nil { + // Skip unsupported call credential types (don't fail bootstrap) + continue + } + bundle, cancel, err := c.Build(callCred.Config) + if err != nil { + // Call credential validation failed - this should fail bootstrap + return fmt.Errorf("failed to build call credentials from bootstrap for %q: %v", callCred.Type, err) + } + // Extract the PerRPCCredentials from the bundle. Sanity check for nil just in case + if callCredentials := bundle.PerRPCCredentials(); callCredentials != nil { + sc.selectedCallCreds = append(sc.selectedCallCreds, callCredentials) + } + sc.cleanups = append(sc.cleanups, cancel) + } + if sc.serverURI == "" { return fmt.Errorf("xds: `server_uri` field in server config cannot be empty: %s", string(data)) } @@ -341,6 +418,9 @@ type ServerConfigTestingOptions struct { ChannelCreds []ChannelCreds // ServerFeatures represents the list of features supported by this server. ServerFeatures []string + // CallCreds contains a list of call credentials to use for individual RPCs + // to this server. Optional. + CallCreds []CallCreds } // ServerConfigForTesting creates a new ServerConfig from the passed in options, @@ -356,6 +436,7 @@ func ServerConfigForTesting(opts ServerConfigTestingOptions) (*ServerConfig, err ServerURI: opts.URI, ChannelCreds: cc, ServerFeatures: opts.ServerFeatures, + CallCreds: opts.CallCreds, } scJSON, err := json.Marshal(scInternal) if err != nil { diff --git a/internal/xds/bootstrap/bootstrap_test.go b/internal/xds/bootstrap/bootstrap_test.go index d057197804d6..93e90144fd28 100644 --- a/internal/xds/bootstrap/bootstrap_test.go +++ b/internal/xds/bootstrap/bootstrap_test.go @@ -19,15 +19,21 @@ package bootstrap import ( + "context" "encoding/json" "errors" "fmt" + "net" "os" + "strings" "testing" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/jwt" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -196,6 +202,74 @@ var ( "server_features" : ["ignore_resource_deletion", "xds_v3"] }] }`, + // example data seeded from + // https://github.com/istio/istio/blob/master/pkg/istio-agent/testdata/grpc-bootstrap.json + "istioStyleWithJWTCallCreds": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "insecure" } + ], + "call_creds": [ + { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } + ], + "server_features" : ["xds_v3"] + }] + }`, + "istioStyleWithoutCallCreds": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "insecure" } + ], + "server_features" : ["xds_v3"] + }] + }`, + "istioStyleWithTLSAndJWT": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "tls", "config": {} } + ], + "call_creds": [ + { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } + ], + "server_features" : ["xds_v3"] + }] + }`, } metadata = &structpb.Struct{ Fields: map[string]*structpb.Value{ @@ -276,6 +350,82 @@ var ( node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } + + istioNodeMetadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "GENERATOR": { + Kind: &structpb.Value_StringValue{StringValue: "grpc"}, + }, + "INSTANCE_IPS": { + Kind: &structpb.Value_StringValue{StringValue: "127.0.0.1"}, + }, + "ISTIO_VERSION": { + Kind: &structpb.Value_StringValue{StringValue: "1.26.2"}, + }, + "WORKLOAD_IDENTITY_SOCKET_FILE": { + Kind: &structpb.Value_StringValue{StringValue: "socket"}, + }, + }, + } + jwtCallCreds, _ = jwt.NewTokenFileCallCredentials("/var/run/secrets/tokens/istio-token") + selectedJWTCallCreds = []credentials.PerRPCCredentials{jwtCallCreds} + configWithIstioJWTCallCreds = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: []CallCreds{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, + serverFeatures: []string{"xds_v3"}, + selectedCreds: ChannelCreds{Type: "insecure"}, + selectedCallCreds: selectedJWTCallCreds, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } + + configWithIstioStyleNoCallCreds = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + serverFeatures: []string{"xds_v3"}, + selectedCreds: ChannelCreds{Type: "insecure"}, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } + + configWithIstioStyleWithTLSAndJWT = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "tls", Config: json.RawMessage("{}")}}, + callCreds: []CallCreds{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, + serverFeatures: []string{"xds_v3"}, + selectedCreds: ChannelCreds{Type: "tls", Config: json.RawMessage("{}")}, + selectedCallCreds: selectedJWTCallCreds, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } ) func fileReadFromFileMap(bootstrapFileMap map[string]string, name string) ([]byte, error) { @@ -425,6 +575,35 @@ func (s) TestGetConfiguration_Success(t *testing.T) { {"goodBootstrap", configWithGoogleDefaultCredsAndV3}, {"multipleXDSServers", configWithMultipleServers}, {"serverSupportsIgnoreResourceDeletion", configWithGoogleDefaultCredsAndIgnoreResourceDeletion}, + {"istioStyleWithoutCallCreds", configWithIstioStyleNoCallCreds}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testGetConfigurationWithFileNameEnv(t, test.name, false, test.wantConfig) + testGetConfigurationWithFileContentEnv(t, test.name, false, test.wantConfig) + }) + } +} + +// Tests Istio-style bootstrap configurations with JWT call credentials +func (s) TestGetConfiguration_IstioStyleWithCallCreds(t *testing.T) { + // Enable JWT call credentials feature + original := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = original + }() + + cancel := setupBootstrapOverride(v3BootstrapFileMap) + defer cancel() + + tests := []struct { + name string + wantConfig *Config + }{ + {"istioStyleWithJWTCallCreds", configWithIstioJWTCallCreds}, + {"istioStyleWithTLSAndJWT", configWithIstioStyleWithTLSAndJWT}, } for _, test := range tests { @@ -1018,12 +1197,203 @@ func (s) TestDefaultBundles(t *testing.T) { } } -type s struct { - grpctest.Tester +func (s) TestCallCreds_Equal(t *testing.T) { + tests := []struct { + name string + cc1 CallCreds + cc2 CallCreds + expect bool + }{ + { + name: "identical configs", + cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + expect: true, + }, + { + name: "different types", + cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCreds{Type: "other_type", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + expect: false, + }, + { + name: "different configs", + cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/different/path"}`)}, + expect: false, + }, + { + name: "nil vs non-nil configs", + cc1: CallCreds{Type: "jwt_token_file", Config: nil}, + cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + expect: false, + }, + { + name: "both nil configs", + cc1: CallCreds{Type: "jwt_token_file", Config: nil}, + cc2: CallCreds{Type: "jwt_token_file", Config: nil}, + expect: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.cc1.Equal(test.cc2) + if result != test.expect { + t.Errorf("CallCreds.Equal() = %v, want %v", result, test.expect) + } + }) + } } -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) +func (s) TestServerConfig_UnmarshalJSON_WithCallCreds(t *testing.T) { + original := envconfig.XDSBootstrapCallCredsEnabled + defer func() { envconfig.XDSBootstrapCallCredsEnabled = original }() + envconfig.XDSBootstrapCallCredsEnabled = true // Enable call creds in bootstrap + tests := []struct { + name string + json string + wantCallCreds []CallCreds + wantErr bool + errContains string + }{ + { + name: "valid call_creds with jwt_token_file", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/path/to/token.jwt"} + } + ] + }`, + wantCallCreds: []CallCreds{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file": "/path/to/token.jwt"}`), + }}, + }, + { + name: "multiple call_creds types", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + {"type": "jwt_token_file", "config": {"jwt_token_file": "/token1.jwt"}}, + {"type": "unsupported_type", "config": {}} + ] + }`, + wantCallCreds: []CallCreds{ + {Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/token1.jwt"}`)}, + {Type: "unsupported_type", Config: json.RawMessage(`{}`)}, + }, + }, + { + name: "empty call_creds array", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [] + }`, + wantCallCreds: []CallCreds{}, + }, + { + name: "missing call_creds field", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}] + }`, + wantCallCreds: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var sc ServerConfig + err := sc.UnmarshalJSON([]byte(test.json)) + + if test.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + if test.errContains != "" && !strings.Contains(err.Error(), test.errContains) { + t.Errorf("Error %v should contain %q", err, test.errContains) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if diff := cmp.Diff(test.wantCallCreds, sc.CallCreds()); diff != "" { + t.Errorf("CallCreds mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func (s) TestServerConfig_Equal_WithCallCreds(t *testing.T) { + callCreds := []CallCreds{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file": "/test/token.jwt"}`), + }} + sc1 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: callCreds, + serverFeatures: []string{"feature1"}, + } + sc2 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: callCreds, + serverFeatures: []string{"feature1"}, + } + sc3 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: []CallCreds{{Type: "different"}}, + serverFeatures: []string{"feature1"}, + } + + if !sc1.Equal(sc2) { + t.Error("Equal ServerConfigs with same call creds should be equal") + } + if sc1.Equal(sc3) { + t.Error("ServerConfigs with different call creds should not be equal") + } +} + +func (s) TestServerConfig_MarshalJSON_WithCallCreds(t *testing.T) { + original := envconfig.XDSBootstrapCallCredsEnabled + defer func() { envconfig.XDSBootstrapCallCredsEnabled = original }() + envconfig.XDSBootstrapCallCredsEnabled = true // Enable call creds in bootstrap + sc := &ServerConfig{ + serverURI: "test-server:443", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: []CallCreds{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file":"/test/token.jwt"}`), + }}, + serverFeatures: []string{"test_feature"}, + } + + data, err := sc.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + // confirm Marshal/Unmarshal symmetry + var unmarshaled ServerConfig + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if diff := cmp.Diff(sc.CallCreds(), unmarshaled.CallCreds()); diff != "" { + t.Errorf("Marshal/Unmarshal call credentials produces differences:\n%s", diff) + } } func newStructProtoFromMap(t *testing.T, input map[string]any) *structpb.Struct { @@ -1269,3 +1639,231 @@ func (s) TestGetConfiguration_FallbackDisabled(t *testing.T) { testGetConfigurationWithFileContentEnv(t, "multipleXDSServers", false, wantConfig) }) } + +func (s) TestBootstrap_SelectedCredsAndCallCreds(t *testing.T) { + // Enable JWT call credentials + original := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = original + }() + + tokenFile := "/token.jwt" + tests := []struct { + name string + bootstrapConfig string + expectCallCreds int + expectTransportType string + }{ + { + name: "JWT call creds with TLS channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }`, + expectCallCreds: 1, + expectTransportType: "tls", + }, + { + name: "JWT call creds with multiple channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}, {"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + }, + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }`, + expectCallCreds: 2, + expectTransportType: "tls", // the first channel creds is selected + }, + { + name: "JWT call creds with insecure channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }`, + expectCallCreds: 1, + expectTransportType: "insecure", + }, + { + name: "No call creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}] + }`, + expectCallCreds: 0, + expectTransportType: "insecure", + }, + { + name: "No call creds multiple channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}, {"type": "tls", "config": {}}] + }`, + expectCallCreds: 0, + expectTransportType: "insecure", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var sc ServerConfig + err := sc.UnmarshalJSON([]byte(test.bootstrapConfig)) + if err != nil { + t.Fatalf("Failed to unmarshal bootstrap config: %v", err) + } + + // Verify call credentials processing + callCreds := sc.CallCreds() + selectedCallCreds := sc.SelectedCallCreds() + + if len(callCreds) != test.expectCallCreds { + t.Errorf("Call creds count = %d, want %d", len(callCreds), test.expectCallCreds) + } + if len(selectedCallCreds) != test.expectCallCreds { + t.Errorf("Selected call creds count = %d, want %d", len(selectedCallCreds), test.expectCallCreds) + } + + // Verify transport credentials are properly selected + if sc.SelectedCreds().Type != test.expectTransportType { + t.Errorf("Selected transport creds type = %q, want %q", + sc.SelectedCreds().Type, test.expectTransportType) + } + }) + } +} + +func (s) TestDialOptionsWithCallCredsForTransport(t *testing.T) { + // Create test JWT credentials that require transport security + testJWTCreds := &testPerRPCCreds{requireSecurity: true} + testInsecureCreds := &testPerRPCCreds{requireSecurity: false} + + sc := &ServerConfig{ + selectedCallCreds: []credentials.PerRPCCredentials{ + testJWTCreds, + testInsecureCreds, + }, + extraDialOptions: []grpc.DialOption{ + grpc.WithUserAgent("test-agent"), // Test extra option + }, + } + + tests := []struct { + name string + transportType string + transportCreds credentials.TransportCredentials + expectJWTCreds bool + expectOtherCreds bool + }{ + { + name: "insecure transport by type", + transportType: "insecure", + transportCreds: nil, + expectJWTCreds: false, // JWT requires security + expectOtherCreds: true, // Non-security creds allowed + }, + { + name: "insecure transport by protocol", + transportType: "custom", + transportCreds: insecure.NewCredentials(), + expectJWTCreds: false, // JWT requires security + expectOtherCreds: true, // Non-security creds allowed + }, + { + name: "secure transport", + transportType: "tls", + transportCreds: &testTransportCreds{securityProtocol: "tls"}, + expectJWTCreds: true, // JWT allowed on secure transport + expectOtherCreds: true, // All creds allowed + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := sc.DialOptionsWithCallCredsForTransport(test.transportType, test.transportCreds) + + // Count dial options (should include extra options + applicable call creds) + expectedCount := 2 // extraDialOptions + always include non-security creds + if test.expectJWTCreds { + expectedCount++ + } + + if len(opts) != expectedCount { + t.Errorf("DialOptions count = %d, want %d", len(opts), expectedCount) + } + }) + } +} + +type testPerRPCCreds struct { + requireSecurity bool +} + +func (c *testPerRPCCreds) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + return map[string]string{"test": "metadata"}, nil +} + +func (c *testPerRPCCreds) RequireTransportSecurity() bool { + return c.requireSecurity +} + +type testTransportCreds struct { + securityProtocol string +} + +func (c *testTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &testAuthInfo{}, nil +} + +func (c *testTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &testAuthInfo{}, nil +} + +func (c *testTransportCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: c.securityProtocol} +} + +func (c *testTransportCreds) Clone() credentials.TransportCredentials { + return &testTransportCreds{securityProtocol: c.securityProtocol} +} + +func (c *testTransportCreds) OverrideServerName(string) error { + return nil +} + +type testAuthInfo struct{} + +func (a *testAuthInfo) AuthType() string { + return "test" +} + +func (a *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{} +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} diff --git a/internal/xds/bootstrap/jwtcreds/bundle.go b/internal/xds/bootstrap/jwtcreds/bundle.go new file mode 100644 index 000000000000..2b2b2103e908 --- /dev/null +++ b/internal/xds/bootstrap/jwtcreds/bundle.go @@ -0,0 +1,81 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwtcreds implements JWT Call Credentials in xDS Bootstrap File. +// See gRFC A97: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md +package jwtcreds + +import ( + "encoding/json" + "fmt" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/jwt" +) + +// bundle is an implementation of credentials.Bundle which implements JWT +// Call Credentials in xDS Bootstrap File per RFC A97. +// This bundle only provides call credentials, not transport credentials. +type bundle struct { + transportCreds credentials.TransportCredentials // Always nil for JWT call creds + callCreds credentials.PerRPCCredentials +} + +// NewBundle returns a credentials.Bundle which implements JWT Call Credentials +// in xDS Bootstrap File per RFC A97. This implementation focuses on call credentials +// only and expects the config to match RFC A97 structure. +// See gRFC A97: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md +func NewBundle(configJSON json.RawMessage) (credentials.Bundle, func(), error) { + var cfg struct { + JWTTokenFile string `json:"jwt_token_file"` + } + + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal JWT call credentials config: %v", err) + } + + if cfg.JWTTokenFile == "" { + return nil, nil, fmt.Errorf("jwt_token_file is required in JWT call credentials config") + } + + // Create JWT call credentials + callCreds, err := jwt.NewTokenFileCallCredentials(cfg.JWTTokenFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to create JWT call credentials: %v", err) + } + + bundle := &bundle{ + transportCreds: nil, // JWT call creds don't provide transport security + callCreds: callCreds, + } + + return bundle, func() {}, nil +} + +func (b *bundle) TransportCredentials() credentials.TransportCredentials { + // Transport credentials should be configured separately via channel_creds + return nil +} + +func (b *bundle) PerRPCCredentials() credentials.PerRPCCredentials { + return b.callCreds +} + +func (b *bundle) NewWithMode(_ string) (credentials.Bundle, error) { + return nil, fmt.Errorf("JWT call credentials bundle does not support mode switching") +} diff --git a/internal/xds/bootstrap/jwtcreds/bundle_test.go b/internal/xds/bootstrap/jwtcreds/bundle_test.go new file mode 100644 index 000000000000..74f49a710246 --- /dev/null +++ b/internal/xds/bootstrap/jwtcreds/bundle_test.go @@ -0,0 +1,214 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwtcreds + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" +) + +func TestNewBundle(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + + tests := []struct { + name string + config string + wantErr bool + wantErrContains string + }{ + { + name: "valid RFC A97 config with jwt_token_file", + config: `{ + "jwt_token_file": "` + tokenFile + `" + }`, + wantErr: false, + }, + { + name: "empty config", + config: `""`, + wantErr: true, + wantErrContains: "unmarshal", + }, + { + name: "empty config", + config: `{}`, + wantErr: true, + wantErrContains: "jwt_token_file is required", + }, + { + name: "empty path", + config: `{ + "jwt_token_file": "" + }`, + wantErr: true, + wantErrContains: "jwt_token_file is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bundle, cleanup, err := NewBundle(json.RawMessage(tt.config)) + + if tt.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("Error %v should contain %q", err, tt.wantErrContains) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if bundle == nil { + t.Fatal("Expected non-nil bundle") + } + + if cleanup == nil { + t.Error("Expected non-nil cleanup function") + } else { + defer cleanup() + } + + // JWT bundle only deals with PerRPCCredentials, not TransportCredentials + if bundle.TransportCredentials() != nil { + t.Error("Expected nil transport credentials for JWT call creds bundle") + } + + if bundle.PerRPCCredentials() == nil { + t.Error("Expected non-nil per-RPC credentials for valid JWT config") + } + + // Test that call credentials work + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + metadata, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata failed: %v", err) + } + + if len(metadata) == 0 { + t.Error("Expected metadata to be returned") + } + + authHeader, ok := metadata["authorization"] + if !ok { + t.Error("Expected authorization header in metadata") + } + + if !strings.HasPrefix(authHeader, "Bearer ") { + t.Errorf("Authorization header should start with 'Bearer ', got %q", authHeader) + } + }) + } +} + +func TestBundle_NewWithMode(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + config := `{"jwt_token_file": "` + tokenFile + `"}` + bundle, cleanup, err := NewBundle(json.RawMessage(config)) + if err != nil { + t.Fatalf("NewBundle failed: %v", err) + } + defer cleanup() + + _, err = bundle.NewWithMode("test_mode") + if err == nil { + t.Error("Expected error from NewWithMode, got nil") + } + if !strings.Contains(err.Error(), "does not support mode switching") { + t.Errorf("Error should mention mode switching, got: %v", err) + } +} + +func TestBundle_Cleanup(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + config := `{"jwt_token_file": "` + tokenFile + `"}` + _, cleanup, err := NewBundle(json.RawMessage(config)) + if err != nil { + t.Fatalf("NewBundle failed: %v", err) + } + + if cleanup == nil { + t.Fatal("Expected non-nil cleanup function") + } + + // Cleanup should not panic + cleanup() + + // Multiple cleanup calls should be safe + cleanup() +} + +// testAuthInfo implements credentials.AuthInfo for testing +type testAuthInfo struct { + secLevel credentials.SecurityLevel +} + +func (t *testAuthInfo) AuthType() string { + return "test" +} + +func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} +} + +// createTestJWT creates a test JWT token for testing +func createTestJWT(t *testing.T) string { + t.Helper() + + // Create a valid JWT with proper base64 encoding for testing + // Header: {"typ":"JWT","alg":"HS256"} + header := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" + + // Claims: {"aud":"https://example.com","exp":future_timestamp} + claims := "eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tIiwiZXhwIjoyMDAwMDAwMDAwfQ" + + // Fake signature for testing + signature := "fake_signature_for_testing" + + return header + "." + claims + "." + signature +} + +func writeTempFile(t *testing.T, content string) string { + t.Helper() + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, "tempfile") + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + return filePath +} diff --git a/xds/bootstrap/bootstrap.go b/xds/bootstrap/bootstrap.go index ef55ff0c02db..b1a5e831b2a6 100644 --- a/xds/bootstrap/bootstrap.go +++ b/xds/bootstrap/bootstrap.go @@ -29,6 +29,7 @@ import ( "encoding/json" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/envconfig" ) // registry is a map from credential type name to Credential builder. @@ -58,6 +59,9 @@ func RegisterCredentials(c Credentials) { // GetCredentials returns the credentials associated with a given name. // If no credentials are registered with the name, nil will be returned. func GetCredentials(name string) Credentials { + if name == "jwt_token_file" && !envconfig.XDSBootstrapCallCredsEnabled { + return nil + } if c, ok := registry[name]; ok { return c } diff --git a/xds/bootstrap/bootstrap_test.go b/xds/bootstrap/bootstrap_test.go index d1f7a1b64ee5..935976975513 100644 --- a/xds/bootstrap/bootstrap_test.go +++ b/xds/bootstrap/bootstrap_test.go @@ -22,6 +22,7 @@ import ( "testing" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/envconfig" ) const testCredsBuilderName = "test_creds" @@ -64,12 +65,14 @@ func TestRegisterNew(t *testing.T) { func TestCredsBuilders(t *testing.T) { tests := []struct { - typename string - builder Credentials + typename string + builder Credentials + minimumRequiredConfig json.RawMessage }{ - {"google_default", &googleDefaultCredsBuilder{}}, - {"insecure", &insecureCredsBuilder{}}, - {"tls", &tlsCredsBuilder{}}, + {"google_default", &googleDefaultCredsBuilder{}, nil}, + {"insecure", &insecureCredsBuilder{}, nil}, + {"tls", &tlsCredsBuilder{}, nil}, + {"jwt_token_file", &jwtCallCredsBuilder{}, json.RawMessage(`{"jwt_token_file":"/path/to/token.jwt"}`)}, } for _, test := range tests { @@ -78,10 +81,13 @@ func TestCredsBuilders(t *testing.T) { t.Errorf("%T.Name = %v, want %v", test.builder, got, want) } - _, stop, err := test.builder.Build(nil) + bundle, stop, err := test.builder.Build(test.minimumRequiredConfig) if err != nil { t.Fatalf("%T.Build failed: %v", test.builder, err) } + if bundle == nil { + t.Errorf("%T.Build returned nil bundle, expected non-nil", test.builder) + } stop() }) } @@ -100,3 +106,27 @@ func TestTlsCredsBuilder(t *testing.T) { stop() } } + +func TestJwtCallCredentials_BuildDisabledIfFeatureNotEnabled(t *testing.T) { + builder := GetCredentials("jwt_call_creds") + if builder != nil { + t.Fatal("Expected nil Credentials for jwt_call_creds when the feature is disabled.") + } + + // Enable JWT call credentials + original := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = original + }() + + // Test that GetCredentials returns the JWT builder + builder = GetCredentials("jwt_token_file") + if builder == nil { + t.Fatal("GetCredentials(\"jwt_token_file\") returned nil") + } + + if got, want := builder.Name(), "jwt_token_file"; got != want { + t.Errorf("Retrieved builder name = %q, want %q", got, want) + } +} diff --git a/xds/bootstrap/credentials.go b/xds/bootstrap/credentials.go index 578e1278970d..38018972f383 100644 --- a/xds/bootstrap/credentials.go +++ b/xds/bootstrap/credentials.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/google" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/xds/bootstrap/jwtcreds" "google.golang.org/grpc/internal/xds/bootstrap/tlscreds" ) @@ -31,6 +32,7 @@ func init() { RegisterCredentials(&insecureCredsBuilder{}) RegisterCredentials(&googleDefaultCredsBuilder{}) RegisterCredentials(&tlsCredsBuilder{}) + RegisterCredentials(&jwtCallCredsBuilder{}) } // insecureCredsBuilder implements the `Credentials` interface defined in @@ -68,3 +70,15 @@ func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func (d *googleDefaultCredsBuilder) Name() string { return "google_default" } + +// jwtCallCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates JWT call credentials. +type jwtCallCredsBuilder struct{} + +func (j *jwtCallCredsBuilder) Build(configJSON json.RawMessage) (credentials.Bundle, func(), error) { + return jwtcreds.NewBundle(configJSON) +} + +func (j *jwtCallCredsBuilder) Name() string { + return "jwt_token_file" +} diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 967182740719..80bf8d0e8183 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -229,7 +229,9 @@ func populateGRPCTransportConfigsFromServerConfig(sc *bootstrap.ServerConfig, gr grpcTransportConfigs[cc.Type] = grpctransport.Config{ Credentials: bundle, GRPCNewClient: func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - opts = append(opts, sc.DialOptions()...) + // Only add call credentials that are compatible with this transport type + // Call credentials requiring transport security are skipped for insecure transports + opts = append(opts, sc.DialOptionsWithCallCredsForTransport(cc.Type, bundle.TransportCredentials())...) return grpc.NewClient(target, opts...) }, } diff --git a/xds/internal/xdsclient/clientimpl_test.go b/xds/internal/xdsclient/clientimpl_test.go index fbfc24a074ec..c7884e8ebff6 100644 --- a/xds/internal/xdsclient/clientimpl_test.go +++ b/xds/internal/xdsclient/clientimpl_test.go @@ -19,8 +19,10 @@ package xdsclient import ( + "context" "encoding/json" "fmt" + "net" "reflect" "sync" "testing" @@ -28,7 +30,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/testutils/stats" "google.golang.org/grpc/internal/xds/bootstrap" "google.golang.org/grpc/xds/internal/clients" @@ -259,3 +263,90 @@ func (s) TestBuildXDSClientConfig_Success(t *testing.T) { }) } } + +func TestServerConfigCallCredsIntegration(t *testing.T) { + // Enable JWT call credentials + originalJWTEnabled := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = originalJWTEnabled + }() + + tokenFile := "/token.jwt" + // Test server config with both channel and call credentials + serverConfigJSON := `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }` + + var sc bootstrap.ServerConfig + if err := sc.UnmarshalJSON([]byte(serverConfigJSON)); err != nil { + t.Fatalf("Failed to unmarshal server config: %v", err) + } + + // Verify call credentials are processed + callCreds := sc.CallCreds() + if len(callCreds) != 1 { + t.Errorf("Expected 1 call credential, got %d", len(callCreds)) + } + + selectedCallCreds := sc.SelectedCallCreds() + if len(selectedCallCreds) != 1 { + t.Errorf("Expected 1 selected call credential, got %d", len(selectedCallCreds)) + } + + // Test dial options for secure transport (should include JWT) + secureOpts := sc.DialOptionsWithCallCredsForTransport("tls", &mockTransportCreds{protocol: "tls"}) + if len(secureOpts) != 1 { + t.Errorf("Expected dial options for secure transport. Got: %#v", secureOpts) + } + + // Test dial options for insecure transport (should exclude JWT) + insecureOpts := sc.DialOptionsWithCallCredsForTransport("insecure", &mockTransportCreds{protocol: "insecure"}) + + // JWT should be filtered out for insecure transport + if len(insecureOpts) >= len(secureOpts) { + t.Error("Expected fewer dial options for insecure transport (JWT should be filtered)") + } +} + +// Mock transport credentials for testing +type mockTransportCreds struct { + protocol string +} + +func (m *mockTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &mockAuthInfo{}, nil +} + +func (m *mockTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &mockAuthInfo{}, nil +} + +func (m *mockTransportCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: m.protocol} +} + +func (m *mockTransportCreds) Clone() credentials.TransportCredentials { + return &mockTransportCreds{protocol: m.protocol} +} + +func (m *mockTransportCreds) OverrideServerName(string) error { + return nil +} + +type mockAuthInfo struct{} + +func (m *mockAuthInfo) AuthType() string { + return "mock" +} + +func (m *mockAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{} +} From 3268ea5aab9a16e1f547ca512372dba30d9dc901 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Sun, 6 Jul 2025 18:50:28 +0100 Subject: [PATCH 02/35] remove example --- credentials/jwt/example_test.go | 58 --------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 credentials/jwt/example_test.go diff --git a/credentials/jwt/example_test.go b/credentials/jwt/example_test.go deleted file mode 100644 index 14716f8e35ac..000000000000 --- a/credentials/jwt/example_test.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - * - * Copyright 2025 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package jwt_test - -import ( - "context" - "log" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/jwt" -) - -// ExampleNewTokenFileCallCredentials demonstrates how to create and use JWT -// token file call credentials for authentication. -func ExampleNewTokenFileCallCredentials() { - // Create JWT call credentials that read tokens from a file - creds, err := jwt.NewTokenFileCallCredentials( - "/path/to/jwt.token", // Path to JWT token file - ) - if err != nil { - log.Fatalf("Failed to create JWT credentials: %v", err) - } - - // Use the credentials when creating a gRPC connection - conn, err := grpc.NewClient( - "service.example.com:443", - grpc.WithPerRPCCredentials(creds), - // ... other dial options - ) - if err != nil { - log.Fatalf("Failed to connect: %v", err) - } - defer conn.Close() - - // Use the connection for RPC calls - // The JWT token will be automatically included in the authorization header - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - // ... make RPC calls using ctx and conn - _ = ctx -} From b18a1f561fa36d056aec1ff0e2a1dd604c74e45c Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Sun, 6 Jul 2025 19:23:25 +0100 Subject: [PATCH 03/35] refactor test creation --- credentials/jwt/jwt_token_file.go | 4 +- credentials/jwt/jwt_token_file_test.go | 54 ++++++++------------------ 2 files changed, 18 insertions(+), 40 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index 62b63e963f3f..9d78e7bebc98 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -80,10 +80,10 @@ func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCreden // GetRequestMetadata gets the current request metadata, refreshing tokens // if required. This implementation follows the PerRPCCredentials interface. // The tokens will get automatically refreshed if they are about to expire or if -// they haven't been loaded successfully yet. In the latter case, a backoff is -// applied before retrying. +// they haven't been loaded successfully yet. // If it's not possible to extract a token from the file, UNAVAILABLE is returned. // If the token is extracted but invalid, then UNAUTHENTICATED is returned. +// If errors are encoutered, a backoff is applied before retrying. func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { ri, _ := credentials.RequestInfoFromContext(ctx) if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index c611cc838d28..6405ef9c0bf1 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -144,10 +144,7 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenFile := filepath.Join(tempDir, "token") - if err := os.WriteFile(tokenFile, []byte(tt.tokenContent), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", tt.tokenContent) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -189,18 +186,9 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { } func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { - tempDir, err := os.MkdirTemp("", "jwt_test") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tempDir) - tokenFile := filepath.Join(tempDir, "token") token := createTestJWT(t, "", time.Now().Add(time.Hour)) - - if err := os.WriteFile(tokenFile, []byte(token), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", token) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -357,15 +345,10 @@ func createTestJWT(t *testing.T, audience string, expiration time.Time) string { // Tests that cached token expiration is set to 30 seconds before actual token expiration. func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - // Create token that expires in 2 hours tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) token := createTestJWT(t, "", tokenExp) - if err := os.WriteFile(tokenFile, []byte(token), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", token) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -398,16 +381,11 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin // Tests that pre-emptive refresh is triggered within 1 minute of expiration. func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - // Create token that expires in 80 seconds (=> cache expires in ~50s) // This ensures pre-emptive refresh triggers since 50s < the 1 minute check tokenExp := time.Now().Add(80 * time.Second) expiringToken := createTestJWT(t, "", tokenExp) - if err := os.WriteFile(tokenFile, []byte(expiringToken), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", expiringToken) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -658,14 +636,9 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // Tests that invalid JWT tokens are handled with UNAUTHENTICATED status. func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - // Write invalid JWT (missing exp field) invalidJWT := createTestJWT(t, "", time.Time{}) // No expiration - if err := os.WriteFile(tokenFile, []byte(invalidJWT), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", invalidJWT) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -752,13 +725,8 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { // Tests that no background retries occur when channel is idle. func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { - tempDir := t.TempDir() - tokenFilepath := filepath.Join(tempDir, "token") - newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) - if err := os.WriteFile(tokenFilepath, []byte(newToken), 0600); err != nil { - t.Fatalf("Failed to write updated token file: %v", err) - } + tokenFilepath := writeTempFile(t, "token", newToken) creds, err := NewTokenFileCallCredentials(tokenFilepath) if err != nil { @@ -782,3 +750,13 @@ func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { t.Errorf("after idle period, cached error = %v, want nil (no background reads)", cachedErr) } } + +func writeTempFile(t *testing.T, name, content string) string { + t.Helper() + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, name) + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + return filePath +} From eb391affff0a878dca66742625dc1dd624afa7ee Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Sat, 26 Jul 2025 11:00:33 +0100 Subject: [PATCH 04/35] refactor token string padding --- credentials/jwt/jwt_token_file.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index 9d78e7bebc98..e780472ca4e0 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -235,12 +235,10 @@ func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, erro return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) } - // Decode the payload (second part) payload := parts[1] - // Add padding if necessary for base64 decoding - for len(payload)%4 != 0 { - payload += "=" + if m := len(payload) % 4; m != 0 { + payload += strings.Repeat("=", 4-m) } payloadBytes, err := base64.URLEncoding.DecodeString(payload) From d43893ab374e6de5103a922d8f4a14b2425fed38 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 08:20:56 +0100 Subject: [PATCH 05/35] remove example; mark as experimental --- credentials/jwt/doc.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/credentials/jwt/doc.go b/credentials/jwt/doc.go index a3d561ee12a1..f74d3446afb4 100644 --- a/credentials/jwt/doc.go +++ b/credentials/jwt/doc.go @@ -22,20 +22,7 @@ // clients to authenticate using JWT tokens read from files. While originally // designed for xDS environments, these credentials are general-purpose. // -// # Usage -// -// The credentials can be used directly: -// -// import "google.golang.org/grpc/credentials/jwt" -// -// creds, err := jwt.NewTokenFileCallCredentials("/path/to/jwt.token") -// if err != nil { -// log.Fatal(err) -// } -// -// conn, err := grpc.NewClient("example.com:443", grpc.WithPerRPCCredentials(creds)) -// -// Or configured via xDS bootstrap file; see grpc/xds/bootstrap for details. +// The credentials can be used directly in gRPC clients or configured via xDS. // // # Token Requirements // @@ -48,10 +35,16 @@ // // - Tokens are cached until expiration to avoid excessive file I/O // - Transport security is required (RequireTransportSecurity returns true) -// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or UNAUTHENTICATED errors +// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or +// UNAUTHENTICATED errors // - These errors are cached and retried with exponential backoff. // // This implementation is originally intended for use in service mesh // environments like Istio where JWT tokens are provisioned and rotated by the // infrastructure. +// +// # Experimental +// +// Notice: All APIs in this package are experimental and may be removed in a +// later release. package jwt From 167b86ec46a760680cb7a9a4f1aca17eb8615279 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 08:33:24 +0100 Subject: [PATCH 06/35] reorganise struct attributes --- credentials/jwt/jwt_token_file.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index e780472ca4e0..9a24b4f890c4 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -44,19 +44,17 @@ type jwtClaims struct { // tokens from a file. // This implementation follows the A97 JWT Call Credentials specification. type jwtTokenFileCallCreds struct { - tokenFilePath string + tokenFilePath string + backoffStrategy backoff.Strategy // Backoff strategy when error occurs // Cached token data mu sync.RWMutex cachedToken string cachedExpiration time.Time // Slightly reduced expiration time compared to the actual exp - - // Error caching with backoff - cachedError error // Cached error from last failed attempt - cachedErrorTime time.Time // When the error was cached - backoffStrategy backoff.Strategy // Backoff strategy when error occurs - retryAttempt int // Current retry attempt number - nextRetryTime time.Time // When next retry is allowed + cachedError error // Cached error from last failed attempt + cachedErrorTime time.Time // When the error was cached + retryAttempt int // Current retry attempt number + nextRetryTime time.Time // When next retry is allowed // Pre-emptive refresh mutex refreshMu sync.Mutex @@ -64,8 +62,6 @@ type jwtTokenFileCallCreds struct { // NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens // from the specified file path. -// -// tokenFilePath is the filepath to the JWT token file. func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) { if tokenFilePath == "" { return nil, fmt.Errorf("tokenFilePath cannot be empty") From 439d28c4da484536cb919c400c26c164d784bce7 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 08:52:35 +0100 Subject: [PATCH 07/35] rename methods with Locked suffix --- credentials/jwt/jwt_token_file.go | 34 +++++++++++++------------- credentials/jwt/jwt_token_file_test.go | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index 9a24b4f890c4..fdcaf7801692 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -108,9 +108,9 @@ func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool { func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { c.mu.RLock() - if c.isTokenValid() { + if c.isTokenValidLocked() { token := c.cachedToken - shouldRefresh := c.needsPreemptiveRefresh() + shouldRefresh := c.needsPreemptiveRefreshLocked() c.mu.RUnlock() if shouldRefresh { @@ -133,22 +133,22 @@ func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { return c.refreshTokenSync(ctx, false) } -// isTokenValid checks if the cached token is still valid. +// isTokenValidLocked checks if the cached token is still valid. // Caller must hold c.mu.RLock(). -func (c *jwtTokenFileCallCreds) isTokenValid() bool { +func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { if c.cachedToken == "" { return false } return c.cachedExpiration.After(time.Now()) } -// needsPreemptiveRefresh checks if a pre-emptive refresh should be triggered. +// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be triggered. // Returns true if the cached token is valid but expires within 1 minute. // We only trigger pre-emptive refresh for valid tokens - if the token is invalid // or expired, the next RPC will handle synchronous refresh instead. // Caller must hold c.mu.RLock(). -func (c *jwtTokenFileCallCreds) needsPreemptiveRefresh() bool { - return c.isTokenValid() && time.Until(c.cachedExpiration) < time.Minute +func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { + return c.isTokenValidLocked() && time.Until(c.cachedExpiration) < time.Minute } // triggerPreemptiveRefresh starts a background refresh if needed. @@ -161,7 +161,7 @@ func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { // Re-check if refresh is still needed under mutex c.mu.RLock() - stillNeeded := c.needsPreemptiveRefresh() + stillNeeded := c.needsPreemptiveRefreshLocked() c.mu.RUnlock() if !stillNeeded { @@ -188,21 +188,21 @@ func (c *jwtTokenFileCallCreds) refreshTokenSync(_ context.Context, preemptiveRe defer c.mu.Unlock() // Double-check under write lock but skip if preemptive refresh is requested - if !preemptiveRefresh && c.isTokenValid() { + if !preemptiveRefresh && c.isTokenValidLocked() { return c.cachedToken, nil } tokenBytes, err := os.ReadFile(c.tokenFilePath) if err != nil { err = status.Errorf(codes.Unavailable, "failed to read token file %q: %v", c.tokenFilePath, err) - c.setErrorWithBackoff(err) + c.setErrorWithBackoffLocked(err) return "", err } token := strings.TrimSpace(string(tokenBytes)) if token == "" { err := status.Errorf(codes.Unavailable, "token file %q is empty", c.tokenFilePath) - c.setErrorWithBackoff(err) + c.setErrorWithBackoffLocked(err) return "", err } @@ -210,12 +210,12 @@ func (c *jwtTokenFileCallCreds) refreshTokenSync(_ context.Context, preemptiveRe exp, err := c.extractExpiration(token) if err != nil { err = status.Errorf(codes.Unauthenticated, "failed to parse JWT from token file %q: %v", c.tokenFilePath, err) - c.setErrorWithBackoff(err) + c.setErrorWithBackoffLocked(err) return "", err } // Success - clear any cached error and backoff state, update token cache - c.clearErrorAndBackoff() + c.clearErrorAndBackoffLocked() c.cachedToken = token // Per RFC A97: consider token invalid if it expires within the next 30 // seconds to accommodate for clock skew and server processing time. @@ -261,9 +261,9 @@ func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, erro return expTime, nil } -// setErrorWithBackoff caches an error and calculates the next retry time using exponential backoff. +// setErrorWithBackoffLocked caches an error and calculates the next retry time using exponential backoff. // Caller must hold c.mu write lock. -func (c *jwtTokenFileCallCreds) setErrorWithBackoff(err error) { +func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked(err error) { c.cachedError = err c.cachedErrorTime = time.Now() c.retryAttempt++ @@ -271,9 +271,9 @@ func (c *jwtTokenFileCallCreds) setErrorWithBackoff(err error) { c.nextRetryTime = time.Now().Add(backoffDelay) } -// clearErrorAndBackoff clears the cached error and resets backoff state. +// clearErrorAndBackoffLocked clears the cached error and resets backoff state. // Caller must hold c.mu write lock. -func (c *jwtTokenFileCallCreds) clearErrorAndBackoff() { +func (c *jwtTokenFileCallCreds) clearErrorAndBackoffLocked() { c.cachedError = nil c.cachedErrorTime = time.Time{} c.retryAttempt = 0 diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index 6405ef9c0bf1..6c7b10e41c33 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -409,7 +409,7 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { impl.mu.RLock() cacheExp := impl.cachedExpiration tokenCached := impl.cachedToken != "" - shouldTriggerRefresh := impl.needsPreemptiveRefresh() + shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked() impl.mu.RUnlock() if !tokenCached { From b36d4b656b9df4e19d573989d574d1596d5329f5 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 09:07:51 +0100 Subject: [PATCH 08/35] remove context param from refreshTokenSync --- credentials/jwt/jwt_token_file.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index fdcaf7801692..a7c3f7232a0f 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -130,7 +130,7 @@ func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { // Token is expired or missing or the retry backoff period has expired. So // refresh synchronously. // NOTE: refreshTokenSync itself acquires the write lock - return c.refreshTokenSync(ctx, false) + return c.refreshTokenSync(false) } // isTokenValidLocked checks if the cached token is still valid. @@ -168,11 +168,8 @@ func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { return // Another goroutine already refreshed or token expired } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - // Force refresh to read new token even if current one is still valid - _, _ = c.refreshTokenSync(ctx, true) + _, _ = c.refreshTokenSync(true) }() } @@ -183,7 +180,7 @@ func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { // the cached token is still valid. If preemptiveRefresh is false, skips file read // when cached token is still valid, optimizing concurrent synchronous refresh calls // where one RPC may have already updated the cache while another was waiting on the lock. -func (c *jwtTokenFileCallCreds) refreshTokenSync(_ context.Context, preemptiveRefresh bool) (string, error) { +func (c *jwtTokenFileCallCreds) refreshTokenSync(preemptiveRefresh bool) (string, error) { c.mu.Lock() defer c.mu.Unlock() From 26e04513f8da5dc441bd0890aaae888abf7c0274 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 09:37:24 +0100 Subject: [PATCH 09/35] reformat comments; remove redundant cachedErrorTime field --- credentials/jwt/jwt_token_file.go | 42 ++++++++++++++------------ credentials/jwt/jwt_token_file_test.go | 23 ++++++-------- 2 files changed, 32 insertions(+), 33 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index a7c3f7232a0f..a32f9bf2b22f 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -50,9 +50,8 @@ type jwtTokenFileCallCreds struct { // Cached token data mu sync.RWMutex cachedToken string - cachedExpiration time.Time // Slightly reduced expiration time compared to the actual exp - cachedError error // Cached error from last failed attempt - cachedErrorTime time.Time // When the error was cached + cachedExpiration time.Time // Slightly less than actual expiration time + cachedError error // Error from last failed attempt retryAttempt int // Current retry attempt number nextRetryTime time.Time // When next retry is allowed @@ -73,11 +72,12 @@ func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCreden }, nil } -// GetRequestMetadata gets the current request metadata, refreshing tokens -// if required. This implementation follows the PerRPCCredentials interface. -// The tokens will get automatically refreshed if they are about to expire or if +// GetRequestMetadata gets the current request metadata, refreshing tokens if +// required. This implementation follows the PerRPCCredentials interface. The +// tokens will get automatically refreshed if they are about to expire or if // they haven't been loaded successfully yet. -// If it's not possible to extract a token from the file, UNAVAILABLE is returned. +// If it's not possible to extract a token from the file, UNAVAILABLE is +// returned. // If the token is extracted but invalid, then UNAUTHENTICATED is returned. // If errors are encoutered, a backoff is applied before retrying. func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { @@ -119,7 +119,8 @@ func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { return token, nil } - // if still within backoff period, return cached error to avoid repeated file reads + // if still within backoff period, return cached error to avoid repeated + // file reads if c.cachedError != nil && time.Now().Before(c.nextRetryTime) { err := c.cachedError c.mu.RUnlock() @@ -142,10 +143,11 @@ func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { return c.cachedExpiration.After(time.Now()) } -// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be triggered. +// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be +// triggered. // Returns true if the cached token is valid but expires within 1 minute. -// We only trigger pre-emptive refresh for valid tokens - if the token is invalid -// or expired, the next RPC will handle synchronous refresh instead. +// We only trigger pre-emptive refresh for valid tokens - if the token is +// invalid or expired, the next RPC will handle synchronous refresh instead. // Caller must hold c.mu.RLock(). func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { return c.isTokenValidLocked() && time.Until(c.cachedExpiration) < time.Minute @@ -174,12 +176,13 @@ func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { } // refreshTokenSync reads a new token from the file and updates the cache. If -// preemptiveRefresh is true, bypasses the validity check of the currently cached -// token and always reads from file. -// This is used for pre-emptive refresh to ensure new tokens are loaded even when -// the cached token is still valid. If preemptiveRefresh is false, skips file read -// when cached token is still valid, optimizing concurrent synchronous refresh calls -// where one RPC may have already updated the cache while another was waiting on the lock. +// preemptiveRefresh is true, bypasses the validity check of the currently +// cached token and always reads from file. +// This is used for pre-emptive refresh to ensure new tokens are loaded even +// when the cached token is still valid. If preemptiveRefresh is false, skips +// file read when cached token is still valid, optimizing concurrent synchronous +// refresh calls where one RPC may have already updated the cache while another +// was waiting on the lock. func (c *jwtTokenFileCallCreds) refreshTokenSync(preemptiveRefresh bool) (string, error) { c.mu.Lock() defer c.mu.Unlock() @@ -258,11 +261,11 @@ func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, erro return expTime, nil } -// setErrorWithBackoffLocked caches an error and calculates the next retry time using exponential backoff. +// setErrorWithBackoffLocked caches an error and calculates the next retry time +// using exponential backoff. // Caller must hold c.mu write lock. func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked(err error) { c.cachedError = err - c.cachedErrorTime = time.Now() c.retryAttempt++ backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1) c.nextRetryTime = time.Now().Add(backoffDelay) @@ -272,7 +275,6 @@ func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked(err error) { // Caller must hold c.mu write lock. func (c *jwtTokenFileCallCreds) clearErrorAndBackoffLocked() { c.cachedError = nil - c.cachedErrorTime = time.Time{} c.retryAttempt = 0 c.nextRetryTime = time.Time{} } diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index 6c7b10e41c33..df5bb0017c1d 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -343,7 +343,8 @@ func createTestJWT(t *testing.T, audience string, expiration time.Time) string { return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature) } -// Tests that cached token expiration is set to 30 seconds before actual token expiration. +// Tests that cached token expiration is set to 30 seconds before actual token +// expiration. func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { // Create token that expires in 2 hours tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) @@ -421,7 +422,8 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { t.Errorf("cache expires in %v, should be < 1 minute to trigger pre-emptive refresh", timeUntilExp) } - // Create new token file with different expiration while refresh is happening + // Create new token file with different expiration while refresh is + // happening newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { t.Fatalf("Failed to write updated token file: %v", err) @@ -429,7 +431,8 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { // Get token again - should trigger a refresh given that the first one was // cached but expiring soon - // However, the function should have returned right away with the current cached token + // However, the function should have returned right away with the current + // cached token metadata2, err := creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("Second GetRequestMetadata() failed: %v", err) @@ -467,10 +470,12 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { // Tests that backoff behavior handles file read errors correctly. func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // This test has the following flow: - // First call to GetRequestMetadata() fails with UNAVAILABLE due to a missing file. + // First call to GetRequestMetadata() fails with UNAVAILABLE due to a + // missing file. // Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff. // Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry. - // Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff even though file exists. + // Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff + // even though file exists. // Fifth call to GetRequestMetadata() succeeds after creating the file. tempDir := t.TempDir() nonExistentFile := filepath.Join(tempDir, "nonexistent") @@ -499,7 +504,6 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { impl := creds.(*jwtTokenFileCallCreds) impl.mu.RLock() cachedErr := impl.cachedError - cachedErrTime := impl.cachedErrorTime retryAttempt := impl.retryAttempt nextRetryTime := impl.nextRetryTime impl.mu.RUnlock() @@ -507,9 +511,6 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { if cachedErr == nil { t.Error("error should be cached internally after failed file read") } - if cachedErrTime.IsZero() { - t.Error("error cache time should be set") - } if retryAttempt != 1 { t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt) } @@ -614,7 +615,6 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // If successful, verify error cache and backoff state were cleared impl.mu.RLock() clearedErr := impl.cachedError - clearedErrTime := impl.cachedErrorTime retryAttempt := impl.retryAttempt nextRetryTime := impl.nextRetryTime impl.mu.RUnlock() @@ -622,9 +622,6 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { if clearedErr != nil { t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr) } - if !clearedErrTime.IsZero() { - t.Error("after successful retry, cached error time should be cleared") - } if retryAttempt != 0 { t.Errorf("after successful retry, retry attempt should be reset, got: %d", retryAttempt) } From da2de8c3b3eb4b6aebd8097421cda3041d26f96f Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 10:00:59 +0100 Subject: [PATCH 10/35] add defaultTestTimeout const --- credentials/jwt/jwt_token_file_test.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index df5bb0017c1d..70ab2c035ddf 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -35,6 +35,8 @@ import ( "google.golang.org/grpc/status" ) +const defaultTestTimeout = 5 * time.Second + type s struct { grpctest.Tester } @@ -151,7 +153,7 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: tt.authInfo, @@ -186,7 +188,6 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { } func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { - token := createTestJWT(t, "", time.Now().Add(time.Hour)) tokenFile := writeTempFile(t, "token", token) @@ -195,7 +196,7 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, @@ -271,7 +272,7 @@ func (s) TestTokenFileCallCreds_FileErrors(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, @@ -356,7 +357,7 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, @@ -393,7 +394,7 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, @@ -485,7 +486,7 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, @@ -642,7 +643,7 @@ func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, @@ -668,7 +669,7 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, From 51ce34c7cf8155e6333b188fa8c81f914236072f Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 10:03:28 +0100 Subject: [PATCH 11/35] refactor test to use wantErr string only --- credentials/jwt/jwt_token_file_test.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index 70ab2c035ddf..2f25e6228012 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -47,33 +47,31 @@ func Test(t *testing.T) { func (s) TestNewTokenFileCallCredentials(t *testing.T) { tests := []struct { - name string - tokenFilePath string - wantErr bool - wantErrContains string + name string + tokenFilePath string + wantErr string }{ { name: "valid parameters", tokenFilePath: "/path/to/token", - wantErr: false, + wantErr: "", }, { - name: "empty token file path", - tokenFilePath: "", - wantErr: true, - wantErrContains: "tokenFilePath cannot be empty", + name: "empty token file path", + tokenFilePath: "", + wantErr: "tokenFilePath cannot be empty", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { creds, err := NewTokenFileCallCredentials(tt.tokenFilePath) - if tt.wantErr { + if tt.wantErr != "" { if err == nil { t.Fatalf("NewTokenFileCallCredentials() expected error, got nil") } - if !strings.Contains(err.Error(), tt.wantErrContains) { - t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErrContains) + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErr) } return } From f87f1f25c48d6c800dbf1a4676a84b6f537e6bd3 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 11:04:14 +0100 Subject: [PATCH 12/35] fix punctuation --- credentials/jwt/jwt_token_file.go | 31 ++++++------ credentials/jwt/jwt_token_file_test.go | 70 +++++++++++++------------- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index a32f9bf2b22f..0f796a9b0770 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -45,7 +45,7 @@ type jwtClaims struct { // This implementation follows the A97 JWT Call Credentials specification. type jwtTokenFileCallCreds struct { tokenFilePath string - backoffStrategy backoff.Strategy // Backoff strategy when error occurs + backoffStrategy backoff.Strategy // Strategy when error occurs // Cached token data mu sync.RWMutex @@ -86,7 +86,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err) } - // this may be delayed if the token needs to be refreshed from file + // This may be delayed if the token needs to be refreshed from file. token, err := c.getToken(ctx) if err != nil { return nil, err @@ -119,8 +119,8 @@ func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { return token, nil } - // if still within backoff period, return cached error to avoid repeated - // file reads + // If still within backoff period, return cached error to avoid repeated + // file reads. if c.cachedError != nil && time.Now().Before(c.nextRetryTime) { err := c.cachedError c.mu.RUnlock() @@ -128,9 +128,9 @@ func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { } c.mu.RUnlock() - // Token is expired or missing or the retry backoff period has expired. So - // refresh synchronously. - // NOTE: refreshTokenSync itself acquires the write lock + // Token is expired or missing or the retry backoff period has expired. + // So we should refresh synchronously. + // NOTE: refreshTokenSync itself acquires the write lock. return c.refreshTokenSync(false) } @@ -161,16 +161,16 @@ func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { c.refreshMu.Lock() defer c.refreshMu.Unlock() - // Re-check if refresh is still needed under mutex + // Re-check if refresh is still needed under mutex. c.mu.RLock() stillNeeded := c.needsPreemptiveRefreshLocked() c.mu.RUnlock() if !stillNeeded { - return // Another goroutine already refreshed or token expired + return // Another goroutine already refreshed or token expired. } - // Force refresh to read new token even if current one is still valid + // Force refresh to read new token even if current one is still valid. _, _ = c.refreshTokenSync(true) }() } @@ -187,7 +187,8 @@ func (c *jwtTokenFileCallCreds) refreshTokenSync(preemptiveRefresh bool) (string c.mu.Lock() defer c.mu.Unlock() - // Double-check under write lock but skip if preemptive refresh is requested + // Double-check under write lock but skip if preemptive refresh is + // requested. if !preemptiveRefresh && c.isTokenValidLocked() { return c.cachedToken, nil } @@ -206,7 +207,7 @@ func (c *jwtTokenFileCallCreds) refreshTokenSync(preemptiveRefresh bool) (string return "", err } - // Parse JWT to extract expiration + // Parse JWT to extract expiration. exp, err := c.extractExpiration(token) if err != nil { err = status.Errorf(codes.Unauthenticated, "failed to parse JWT from token file %q: %v", c.tokenFilePath, err) @@ -214,7 +215,7 @@ func (c *jwtTokenFileCallCreds) refreshTokenSync(preemptiveRefresh bool) (string return "", err } - // Success - clear any cached error and backoff state, update token cache + // Success - clear any cached error and backoff state, update token cache. c.clearErrorAndBackoffLocked() c.cachedToken = token // Per RFC A97: consider token invalid if it expires within the next 30 @@ -232,7 +233,7 @@ func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, erro } payload := parts[1] - // Add padding if necessary for base64 decoding + // Add padding if necessary for base64 decoding. if m := len(payload) % 4; m != 0 { payload += strings.Repeat("=", 4-m) } @@ -253,7 +254,7 @@ func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, erro expTime := time.Unix(claims.Exp, 0) - // Check if token is already expired + // Check if token is already expired. if expTime.Before(time.Now()) { return time.Time{}, fmt.Errorf("JWT token is expired") } diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index 2f25e6228012..932737b7ec2a 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -200,19 +200,19 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, }) - // First call should read from file + // First call should read from file. metadata1, err := creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("First GetRequestMetadata() failed: %v", err) } - // Update the file with a different token + // Update the file with a different token. newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { t.Fatalf("Failed to update token file: %v", err) } - // Second call should return cached token (not the updated one) + // Second call should return cached token (not the updated one). metadata2, err := creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("Second GetRequestMetadata() failed: %v", err) @@ -301,7 +301,8 @@ func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} } -// createTestJWT creates a test JWT token with the specified audience and expiration. +// createTestJWT creates a test JWT token with the specified audience and +// expiration. func createTestJWT(t *testing.T, audience string, expiration time.Time) string { t.Helper() @@ -345,7 +346,7 @@ func createTestJWT(t *testing.T, audience string, expiration time.Time) string { // Tests that cached token expiration is set to 30 seconds before actual token // expiration. func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { - // Create token that expires in 2 hours + // Create token that expires in 2 hours. tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) token := createTestJWT(t, "", tokenExp) tokenFile := writeTempFile(t, "token", token) @@ -361,13 +362,13 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, }) - // Get token to trigger caching + // Get token to trigger caching. _, err = creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("GetRequestMetadata() failed: %v", err) } - // Verify cached expiration is 30 seconds before actual token expiration + // Verify cached expiration is 30 seconds before actual token expiration. impl := creds.(*jwtTokenFileCallCreds) impl.mu.RLock() cachedExp := impl.cachedExpiration @@ -382,7 +383,7 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin // Tests that pre-emptive refresh is triggered within 1 minute of expiration. func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { // Create token that expires in 80 seconds (=> cache expires in ~50s) - // This ensures pre-emptive refresh triggers since 50s < the 1 minute check + // This ensures pre-emptive refresh triggers since 50s < the 1 minute check. tokenExp := time.Now().Add(80 * time.Second) expiringToken := createTestJWT(t, "", tokenExp) tokenFile := writeTempFile(t, "token", expiringToken) @@ -398,13 +399,13 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, }) - // Get token - should trigger pre-emptive refresh + // Get token - should trigger pre-emptive refresh. metadata1, err := creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("GetRequestMetadata() failed: %v", err) } - // Verify token was cached and check if refresh should be triggered + // Verify token was cached and check if refresh should be triggered. impl := creds.(*jwtTokenFileCallCreds) impl.mu.RLock() cacheExp := impl.cachedExpiration @@ -422,16 +423,16 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { } // Create new token file with different expiration while refresh is - // happening + // happening. newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { t.Fatalf("Failed to write updated token file: %v", err) } // Get token again - should trigger a refresh given that the first one was - // cached but expiring soon + // cached but expiring soon. // However, the function should have returned right away with the current - // cached token + // cached token. metadata2, err := creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("Second GetRequestMetadata() failed: %v", err) @@ -439,13 +440,13 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { time.Sleep(50 * time.Millisecond) - // now should get the new token + // Now should get the new token. metadata3, err := creds.GetRequestMetadata(ctx) if err != nil { t.Fatalf("Second GetRequestMetadata() failed: %v", err) } - // If pre-emptive refresh worked, we should get the new token + // If pre-emptive refresh worked, we should get the new token. expectedAuth1 := "Bearer " + expiringToken expectedAuth2 := "Bearer " + expiringToken expectedAuth3 := "Bearer " + newToken @@ -490,7 +491,7 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, }) - // First call should fail with UNAVAILABLE + // First call should fail with UNAVAILABLE. _, err1 := creds.GetRequestMetadata(ctx) if err1 == nil { t.Fatal("Expected error from nonexistent file") @@ -499,7 +500,7 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err1)) } - // Verify error is cached internally + // Verify error is cached internally. impl := creds.(*jwtTokenFileCallCreds) impl.mu.RLock() cachedErr := impl.cachedError @@ -517,7 +518,7 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Error("Next retry time should be set to future time") } - // Second call should still return cached error + // Second call should still return cached error. _, err2 := creds.GetRequestMetadata(ctx) if err2 == nil { t.Fatal("Expected cached error") @@ -541,12 +542,12 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Error("retry attempt should not change due to backoff") } - // fast-forward the backoff retry time to allow next retry attempt + // Fast-forward the backoff retry time to allow next retry attempt. impl.mu.Lock() impl.nextRetryTime = time.Now().Add(-1 * time.Minute) impl.mu.Unlock() - // Third call should retry but still fail with UNAVAILABLE + // Third call should retry but still fail with UNAVAILABLE. _, err3 := creds.GetRequestMetadata(ctx) if err3 == nil { t.Fatal("Expected cached error") @@ -570,13 +571,13 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Error("retry attempt should not change due to backoff") } - // Create valid token file + // Create valid token file. validToken := createTestJWT(t, "", time.Now().Add(time.Hour)) if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil { t.Fatalf("Failed to create valid token file: %v", err) } - // Forth call should still fail even though the file now exists + // Fourth call should still fail even though the file now exists. _, err4 := creds.GetRequestMetadata(ctx) if err4 == nil { t.Fatal("Expected cached error") @@ -600,18 +601,18 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Error("retry attempt should not change due to backoff") } - // fast-forward the backoff retry time to allow next retry attempt + // Fast-forward the backoff retry time to allow next retry attempt. impl.mu.Lock() impl.nextRetryTime = time.Now().Add(-1 * time.Minute) impl.mu.Unlock() // Fifth call should succeed since the file now exists - // and the backoff has expired + // and the backoff has expired. _, err5 := creds.GetRequestMetadata(ctx) if err5 != nil { t.Errorf("after creating valid token file, GetRequestMetadata() should eventually succeed, but got: %v", err5) t.Error("backoff should expire and trigger new attempt on next RPC") } else { - // If successful, verify error cache and backoff state were cleared + // If successful, verify error cache and backoff state were cleared. impl.mu.RLock() clearedErr := impl.cachedError retryAttempt := impl.retryAttempt @@ -632,8 +633,8 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // Tests that invalid JWT tokens are handled with UNAUTHENTICATED status. func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { - // Write invalid JWT (missing exp field) - invalidJWT := createTestJWT(t, "", time.Time{}) // No expiration + // Write invalid JWT (missing exp field). + invalidJWT := createTestJWT(t, "", time.Time{}) tokenFile := writeTempFile(t, "token", invalidJWT) creds, err := NewTokenFileCallCredentials(tokenFile) @@ -656,12 +657,13 @@ func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { } } -// Tests that RPCs are queued during file operations and all receive the same result. +// Tests that RPCs are queued during file operations and all receive the same +// result. func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { tempDir := t.TempDir() tokenFile := filepath.Join(tempDir, "token") - // Start with no token file to force file read during first RPC + // Start with no token file to force file read during first RPC. creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) @@ -673,7 +675,7 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, }) - // Launch multiple concurrent RPCs before creating the token file + // Launch multiple concurrent RPCs before creating the token file. const numConcurrentRPCs = 5 results := make(chan error, numConcurrentRPCs) @@ -684,14 +686,14 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { }() } - // Collect all results - they should all be the same error (UNAVAILABLE) + // Collect all results - they should all be the same error (UNAVAILABLE). var errors []error for range numConcurrentRPCs { err := <-results errors = append(errors, err) } - // All RPCs should fail with the same error (file not found) + // All RPCs should fail with the same error (file not found). for i, err := range errors { if err == nil { t.Errorf("RPC %d should have failed with UNAVAILABLE", i) @@ -705,7 +707,7 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { } } - // Verify error was cached after concurrent RPCs + // Verify error was cached after concurrent RPCs. impl := creds.(*jwtTokenFileCallCreds) impl.mu.RLock() finalCachedErr := impl.cachedError @@ -731,7 +733,7 @@ func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { impl := creds.(*jwtTokenFileCallCreds) - // Verify state unchanged - no background file reads attempted + // Verify state unchanged - no background file reads attempted. impl.mu.RLock() token := impl.cachedToken cachedErr := impl.cachedError From 15dd0573f859c72fc614cc72fc0b49250ff0a136 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 29 Jul 2025 11:15:33 +0100 Subject: [PATCH 13/35] less prosaic subtest names --- credentials/jwt/jwt_token_file_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index 932737b7ec2a..b67acf383019 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -52,12 +52,12 @@ func (s) TestNewTokenFileCallCredentials(t *testing.T) { wantErr string }{ { - name: "valid parameters", + name: "some filepath", tokenFilePath: "/path/to/token", wantErr: "", }, { - name: "empty token file path", + name: "empty filepath", tokenFilePath: "", wantErr: "tokenFilePath cannot be empty", }, @@ -113,14 +113,14 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { wantMetadata map[string]string }{ { - name: "valid token without expiration errors", + name: "valid token without expiration", tokenContent: createTestJWT(t, "", time.Time{}), authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, wantErr: true, wantErrContains: "JWT token has no expiration claim", }, { - name: "valid token with future expiration succeeds", + name: "valid token with future expiration", tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)), authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, wantErr: false, @@ -134,7 +134,7 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { wantErrContains: "unable to transfer JWT token file PerRPCCredentials", }, { - name: "expired token errors", + name: "expired token", tokenContent: createTestJWT(t, "", now.Add(-time.Hour)), authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, wantErr: true, From 54cbbcb7dc245a5a3d322c2d710bf204e451e2d0 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Wed, 30 Jul 2025 14:29:49 +0100 Subject: [PATCH 14/35] remove unit test --- credentials/jwt/jwt_token_file_test.go | 37 ++++---------------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index b67acf383019..afad2b152602 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -382,7 +382,7 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin // Tests that pre-emptive refresh is triggered within 1 minute of expiration. func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { - // Create token that expires in 80 seconds (=> cache expires in ~50s) + // Create token that expires in 80 seconds (=> cache expires in ~50s). // This ensures pre-emptive refresh triggers since 50s < the 1 minute check. tokenExp := time.Now().Add(80 * time.Second) expiringToken := createTestJWT(t, "", tokenExp) @@ -463,20 +463,21 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { t.Errorf("Second call should return the original token: got %q, want %q", actualAuth2, expectedAuth2) } if actualAuth3 != expectedAuth3 { - t.Errorf("Third call should return the original token: got %q, want %q", actualAuth3, expectedAuth3) + t.Errorf("Third call should return the new token: got %q, want %q", actualAuth3, expectedAuth3) } } // Tests that backoff behavior handles file read errors correctly. func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { - // This test has the following flow: + // This test has the following expectations: // First call to GetRequestMetadata() fails with UNAVAILABLE due to a // missing file. // Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff. // Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry. // Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff // even though file exists. - // Fifth call to GetRequestMetadata() succeeds after creating the file. + // Fifth call to GetRequestMetadata() succeeds after reading the file and + // backoff has expired. tempDir := t.TempDir() nonExistentFile := filepath.Join(tempDir, "nonexistent") @@ -721,34 +722,6 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { } } -// Tests that no background retries occur when channel is idle. -func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { - newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) - tokenFilepath := writeTempFile(t, "token", newToken) - - creds, err := NewTokenFileCallCredentials(tokenFilepath) - if err != nil { - t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) - } - - impl := creds.(*jwtTokenFileCallCreds) - - // Verify state unchanged - no background file reads attempted. - impl.mu.RLock() - token := impl.cachedToken - cachedErr := impl.cachedError - impl.mu.RUnlock() - - time.Sleep(100 * time.Millisecond) - - if token != "" { - t.Errorf("after idle period, cached token = %q, want empty (no background reads)", token) - } - if cachedErr != nil { - t.Errorf("after idle period, cached error = %v, want nil (no background reads)", cachedErr) - } -} - func writeTempFile(t *testing.T, name, content string) string { t.Helper() tempDir := t.TempDir() From 9c5035d5e4b03062bd2aece960c2f27a715f1c48 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 31 Jul 2025 14:30:10 +0100 Subject: [PATCH 15/35] rename preemptiveRefresh to forceRefresh --- credentials/jwt/jwt_token_file.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index 0f796a9b0770..3000110cd87d 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -176,20 +176,20 @@ func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { } // refreshTokenSync reads a new token from the file and updates the cache. If -// preemptiveRefresh is true, bypasses the validity check of the currently +// forceRefresh is true, bypasses the validity check of the currently // cached token and always reads from file. // This is used for pre-emptive refresh to ensure new tokens are loaded even -// when the cached token is still valid. If preemptiveRefresh is false, skips +// when the cached token is still valid. If forceRefresh is false, skips // file read when cached token is still valid, optimizing concurrent synchronous // refresh calls where one RPC may have already updated the cache while another // was waiting on the lock. -func (c *jwtTokenFileCallCreds) refreshTokenSync(preemptiveRefresh bool) (string, error) { +func (c *jwtTokenFileCallCreds) refreshTokenSync(forceRefresh bool) (string, error) { c.mu.Lock() defer c.mu.Unlock() // Double-check under write lock but skip if preemptive refresh is // requested. - if !preemptiveRefresh && c.isTokenValidLocked() { + if !forceRefresh && c.isTokenValidLocked() { return c.cachedToken, nil } From ec915dc35973b78ccbd7df2e5b3137548e5f5c2c Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 31 Jul 2025 14:32:58 +0100 Subject: [PATCH 16/35] remove unused context param --- credentials/jwt/jwt_token_file.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index 3000110cd87d..c42dbccd114d 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -87,7 +87,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str } // This may be delayed if the token needs to be refreshed from file. - token, err := c.getToken(ctx) + token, err := c.getToken() if err != nil { return nil, err } @@ -105,7 +105,7 @@ func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool { // getToken returns a valid JWT token, reading from file if necessary. // Implements pre-emptive refresh and caches errors with backoff. -func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) { +func (c *jwtTokenFileCallCreds) getToken() (string, error) { c.mu.RLock() if c.isTokenValidLocked() { From 1d95fa2e050aaa866e52663696e52775a0a06e5a Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 15:34:26 +0100 Subject: [PATCH 17/35] rename files --- ...n_file.go => jwt_token_file_call_creds.go} | 0 ...t.go => jwt_token_file_call_creds_test.go} | 225 +++++------------- 2 files changed, 62 insertions(+), 163 deletions(-) rename credentials/jwt/{jwt_token_file.go => jwt_token_file_call_creds.go} (100%) rename credentials/jwt/{jwt_token_file_test.go => jwt_token_file_call_creds_test.go} (84%) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file_call_creds.go similarity index 100% rename from credentials/jwt/jwt_token_file.go rename to credentials/jwt/jwt_token_file_call_creds.go diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go similarity index 84% rename from credentials/jwt/jwt_token_file_test.go rename to credentials/jwt/jwt_token_file_call_creds_test.go index afad2b152602..1eaacc5b3ab7 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -112,13 +112,6 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { wantErrContains string wantMetadata map[string]string }{ - { - name: "valid token without expiration", - tokenContent: createTestJWT(t, "", time.Time{}), - authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - wantErr: true, - wantErrContains: "JWT token has no expiration claim", - }, { name: "valid token with future expiration", tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)), @@ -128,18 +121,11 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { }, { name: "insufficient security level", - tokenContent: createTestJWT(t, "", time.Time{}), + tokenContent: createTestJWT(t, "", now.Add(time.Hour)), authInfo: &testAuthInfo{secLevel: credentials.NoSecurity}, wantErr: true, wantErrContains: "unable to transfer JWT token file PerRPCCredentials", }, - { - name: "expired token", - tokenContent: createTestJWT(t, "", now.Add(-time.Hour)), - authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - wantErr: true, - wantErrContains: "JWT token is expired", - }, } for _, tt := range tests { @@ -223,70 +209,6 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { } } -func (s) TestTokenFileCallCreds_FileErrors(t *testing.T) { - tests := []struct { - name string - setupFile func(string) error - wantErrContains string - }{ - { - name: "nonexistent file", - setupFile: func(_ string) error { - return nil // Don't create the file - }, - wantErrContains: "failed to read token file", - }, - { - name: "empty file", - setupFile: func(path string) error { - return os.WriteFile(path, []byte(""), 0600) - }, - wantErrContains: "token file", - }, - { - name: "file with whitespace only", - setupFile: func(path string) error { - return os.WriteFile(path, []byte(" \n\t "), 0600) - }, - wantErrContains: "token file", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tempDir, err := os.MkdirTemp("", "jwt_test") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tempDir) - - tokenFile := filepath.Join(tempDir, "token") - if err := tt.setupFile(tokenFile); err != nil { - t.Fatalf("Failed to setup test file: %v", err) - } - - creds, err := NewTokenFileCallCredentials(tokenFile) - if err != nil { - t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ - AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - }) - - _, err = creds.GetRequestMetadata(ctx) - if err == nil { - t.Fatal("GetRequestMetadata() expected error, got nil") - } - - if !strings.Contains(err.Error(), tt.wantErrContains) { - t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains) - } - }) - } -} // testAuthInfo implements credentials.AuthInfo for testing. type testAuthInfo struct { @@ -301,47 +223,6 @@ func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} } -// createTestJWT creates a test JWT token with the specified audience and -// expiration. -func createTestJWT(t *testing.T, audience string, expiration time.Time) string { - t.Helper() - - header := map[string]any{ - "typ": "JWT", - "alg": "HS256", - } - - claims := map[string]any{} - if audience != "" { - claims["aud"] = audience - } - if !expiration.IsZero() { - claims["exp"] = expiration.Unix() - } - - headerBytes, err := json.Marshal(header) - if err != nil { - t.Fatalf("Failed to marshal header: %v", err) - } - - claimsBytes, err := json.Marshal(claims) - if err != nil { - t.Fatalf("Failed to marshal claims: %v", err) - } - - headerB64 := base64.URLEncoding.EncodeToString(headerBytes) - claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes) - - // Remove padding for URL-safe base64 - headerB64 = strings.TrimRight(headerB64, "=") - claimsB64 = strings.TrimRight(claimsB64, "=") - - // For testing, we'll use a fake signature - signature := base64.URLEncoding.EncodeToString([]byte("fake_signature")) - signature = strings.TrimRight(signature, "=") - - return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature) -} // Tests that cached token expiration is set to 30 seconds before actual token // expiration. @@ -370,9 +251,9 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin // Verify cached expiration is 30 seconds before actual token expiration. impl := creds.(*jwtTokenFileCallCreds) - impl.mu.RLock() - cachedExp := impl.cachedExpiration - impl.mu.RUnlock() + impl.mu.Lock() + cachedExp := impl.cachedExpiry + impl.mu.Unlock() expectedExp := tokenExp.Add(-30 * time.Second) if !cachedExp.Equal(expectedExp) { @@ -407,11 +288,11 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { // Verify token was cached and check if refresh should be triggered. impl := creds.(*jwtTokenFileCallCreds) - impl.mu.RLock() - cacheExp := impl.cachedExpiration + impl.mu.Lock() + cacheExp := impl.cachedExpiry tokenCached := impl.cachedToken != "" shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked() - impl.mu.RUnlock() + impl.mu.Unlock() if !tokenCached { t.Error("token should be cached after successful GetRequestMetadata") @@ -503,11 +384,11 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // Verify error is cached internally. impl := creds.(*jwtTokenFileCallCreds) - impl.mu.RLock() + impl.mu.Lock() cachedErr := impl.cachedError retryAttempt := impl.retryAttempt nextRetryTime := impl.nextRetryTime - impl.mu.RUnlock() + impl.mu.Unlock() if cachedErr == nil { t.Error("error should be cached internally after failed file read") @@ -531,10 +412,10 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Errorf("cached error = %q, want %q", err2.Error(), err1.Error()) } - impl.mu.RLock() + impl.mu.Lock() retryAttempt2 := impl.retryAttempt nextRetryTime2 := impl.nextRetryTime - impl.mu.RUnlock() + impl.mu.Unlock() if !nextRetryTime2.Equal(nextRetryTime) { t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, nextRetryTime) @@ -560,10 +441,10 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Errorf("cached error = %q, want %q", err3.Error(), err1.Error()) } - impl.mu.RLock() + impl.mu.Lock() retryAttempt3 := impl.retryAttempt nextRetryTime3 := impl.nextRetryTime - impl.mu.RUnlock() + impl.mu.Unlock() if !nextRetryTime3.After(nextRetryTime2) { t.Error("nextRetryTime should not change due to backoff") @@ -590,10 +471,10 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Errorf("cached error = %q, want %q", err4.Error(), err3.Error()) } - impl.mu.RLock() + impl.mu.Lock() retryAttempt4 := impl.retryAttempt nextRetryTime4 := impl.nextRetryTime - impl.mu.RUnlock() + impl.mu.Unlock() if !nextRetryTime4.Equal(nextRetryTime3) { t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, nextRetryTime3) @@ -614,11 +495,11 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { t.Error("backoff should expire and trigger new attempt on next RPC") } else { // If successful, verify error cache and backoff state were cleared. - impl.mu.RLock() + impl.mu.Lock() clearedErr := impl.cachedError retryAttempt := impl.retryAttempt nextRetryTime := impl.nextRetryTime - impl.mu.RUnlock() + impl.mu.Unlock() if clearedErr != nil { t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr) @@ -632,31 +513,6 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { } } -// Tests that invalid JWT tokens are handled with UNAUTHENTICATED status. -func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { - // Write invalid JWT (missing exp field). - invalidJWT := createTestJWT(t, "", time.Time{}) - tokenFile := writeTempFile(t, "token", invalidJWT) - - creds, err := NewTokenFileCallCredentials(tokenFile) - if err != nil { - t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ - AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - }) - - _, err = creds.GetRequestMetadata(ctx) - if err == nil { - t.Fatal("Expected UNAUTHENTICATED from invalid JWT") - } - if status.Code(err) != codes.Unauthenticated { - t.Errorf("GetRequestMetadata() = %v, want UNAUTHENTICATED for invalid JWT", status.Code(err)) - } -} // Tests that RPCs are queued during file operations and all receive the same // result. @@ -710,9 +566,9 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { // Verify error was cached after concurrent RPCs. impl := creds.(*jwtTokenFileCallCreds) - impl.mu.RLock() + impl.mu.Lock() finalCachedErr := impl.cachedError - impl.mu.RUnlock() + impl.mu.Unlock() if finalCachedErr == nil { t.Error("error should be cached after failed concurrent RPCs") @@ -722,6 +578,48 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { } } +// createTestJWT creates a test JWT token with the specified audience and +// expiration. +func createTestJWT(t *testing.T, audience string, expiration time.Time) string { + t.Helper() + + header := map[string]any{ + "typ": "JWT", + "alg": "HS256", + } + + claims := map[string]any{} + if audience != "" { + claims["aud"] = audience + } + if !expiration.IsZero() { + claims["exp"] = expiration.Unix() + } + + headerBytes, err := json.Marshal(header) + if err != nil { + t.Fatalf("Failed to marshal header: %v", err) + } + + claimsBytes, err := json.Marshal(claims) + if err != nil { + t.Fatalf("Failed to marshal claims: %v", err) + } + + headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes) + + // Remove padding for URL-safe base64 + headerB64 = strings.TrimRight(headerB64, "=") + claimsB64 = strings.TrimRight(claimsB64, "=") + + // For testing, we'll use a fake signature + signature := base64.URLEncoding.EncodeToString([]byte("fake_signature")) + signature = strings.TrimRight(signature, "=") + + return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature) +} + func writeTempFile(t *testing.T, name, content string) string { t.Helper() tempDir := t.TempDir() @@ -731,3 +629,4 @@ func writeTempFile(t *testing.T, name, content string) string { } return filePath } + From a797ed983c2ccd8deb61a094bb376c15693c0a14 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 15:34:48 +0100 Subject: [PATCH 18/35] use cond variable --- credentials/jwt/jwt_file_reader.go | 103 +++++++ credentials/jwt/jwt_file_reader_test.go | 180 ++++++++++++ credentials/jwt/jwt_token_file_call_creds.go | 263 ++++++------------ .../jwt/jwt_token_file_call_creds_test.go | 68 ----- 4 files changed, 375 insertions(+), 239 deletions(-) create mode 100644 credentials/jwt/jwt_file_reader.go create mode 100644 credentials/jwt/jwt_file_reader_test.go diff --git a/credentials/jwt/jwt_file_reader.go b/credentials/jwt/jwt_file_reader.go new file mode 100644 index 000000000000..005662eea364 --- /dev/null +++ b/credentials/jwt/jwt_file_reader.go @@ -0,0 +1,103 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwt + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + "time" +) + +// jwtClaims represents the JWT claims structure for extracting expiration time. +type jwtClaims struct { + Exp int64 `json:"exp"` +} + +// jWTFileReader handles reading and parsing JWT tokens from files. +type jWTFileReader struct { + tokenFilePath string +} + +// newJWTFileReader creates a new JWTFileReader for the specified file path. +func newJWTFileReader(tokenFilePath string) *jWTFileReader { + return &jWTFileReader{ + tokenFilePath: tokenFilePath, + } +} + +// ReadToken reads and parses a JWT token from the configured file. +// Returns the token string, expiration time, and any error encountered. +func (r *jWTFileReader) ReadToken() (string, time.Time, error) { + tokenBytes, err := os.ReadFile(r.tokenFilePath) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to read token file %q: %v", r.tokenFilePath, err) + } + + token := strings.TrimSpace(string(tokenBytes)) + if token == "" { + return "", time.Time{}, fmt.Errorf("token file %q is empty", r.tokenFilePath) + } + + exp, err := r.extractExpiration(token) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to parse JWT from token file %q: %v", r.tokenFilePath, err) + } + + return token, exp, nil +} + +// extractExpiration parses the JWT token to extract the expiration time. +func (r *jWTFileReader) extractExpiration(token string) (time.Time, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + + payload := parts[1] + // Add padding if necessary for base64 decoding. + if m := len(payload) % 4; m != 0 { + payload += strings.Repeat("=", 4-m) + } + + payloadBytes, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err) + } + + var claims jwtClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err) + } + + if claims.Exp == 0 { + return time.Time{}, fmt.Errorf("JWT token has no expiration claim") + } + + expTime := time.Unix(claims.Exp, 0) + + // Check if token is already expired. + if expTime.Before(time.Now()) { + return time.Time{}, fmt.Errorf("JWT token is expired") + } + + return expTime, nil +} diff --git a/credentials/jwt/jwt_file_reader_test.go b/credentials/jwt/jwt_file_reader_test.go new file mode 100644 index 000000000000..b57a93fa54ef --- /dev/null +++ b/credentials/jwt/jwt_file_reader_test.go @@ -0,0 +1,180 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwt + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { + tests := []struct { + name string + setupFile func(string) error + wantErrContains string + }{ + { + name: "nonexistent file", + setupFile: func(_ string) error { + return nil // Don't create the file + }, + wantErrContains: "failed to read token file", + }, + { + name: "empty file", + setupFile: func(path string) error { + return os.WriteFile(path, []byte(""), 0600) + }, + wantErrContains: "token file", + }, + { + name: "file with whitespace only", + setupFile: func(path string) error { + return os.WriteFile(path, []byte(" \n\t "), 0600) + }, + wantErrContains: "token file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + tokenFile := filepath.Join(tempDir, "token") + if err := tt.setupFile(tokenFile); err != nil { + t.Fatalf("Failed to setup test file: %v", err) + } + + reader := newJWTFileReader(tokenFile) + _, _, err := reader.ReadToken() + if err == nil { + t.Fatal("ReadToken() expected error, got nil") + } + + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains) + } + }) + } +} + +func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { + now := time.Now().Truncate(time.Second) + tests := []struct { + name string + tokenContent string + wantErrContains string + }{ + { + name: "valid token without expiration", + tokenContent: createTestJWT(t, "", time.Time{}), + wantErrContains: "JWT token has no expiration claim", + }, + { + name: "expired token", + tokenContent: createTestJWT(t, "", now.Add(-time.Hour)), + wantErrContains: "JWT token is expired", + }, + { + name: "malformed JWT - not enough parts", + tokenContent: "invalid.jwt", + wantErrContains: "invalid JWT format: expected 3 parts, got 2", + }, + { + name: "malformed JWT - invalid base64", + tokenContent: "header.invalid_base64!@#.signature", + wantErrContains: "failed to decode JWT payload", + }, + { + name: "malformed JWT - invalid JSON", + tokenContent: createInvalidJSONJWT(t), + wantErrContains: "failed to unmarshal JWT claims", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenFile := writeTempFile(t, "token", tt.tokenContent) + + reader := newJWTFileReader(tokenFile) + _, _, err := reader.ReadToken() + if err == nil { + t.Fatal("ReadToken() expected error, got nil") + } + + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains) + } + }) + } +} + +func TestJWTFileReader_ReadToken_ValidToken(t *testing.T) { + now := time.Now().Truncate(time.Second) + tokenExp := now.Add(time.Hour) + token := createTestJWT(t, "https://example.com", tokenExp) + tokenFile := writeTempFile(t, "token", token) + + reader := newJWTFileReader(tokenFile) + readToken, expiry, err := reader.ReadToken() + if err != nil { + t.Fatalf("ReadToken() unexpected error: %v", err) + } + + if readToken != token { + t.Errorf("ReadToken() token = %q, want %q", readToken, token) + } + + if !expiry.Equal(tokenExp) { + t.Errorf("ReadToken() expiry = %v, want %v", expiry, tokenExp) + } +} + +// createInvalidJSONJWT creates a JWT with invalid JSON in the payload. +func createInvalidJSONJWT(t *testing.T) string { + t.Helper() + + header := map[string]any{ + "typ": "JWT", + "alg": "HS256", + } + + headerBytes, err := json.Marshal(header) + if err != nil { + t.Fatalf("Failed to marshal header: %v", err) + } + + headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + headerB64 = strings.TrimRight(headerB64, "=") + + // Create invalid JSON payload + invalidJSON := "invalid json content" + payloadB64 := base64.URLEncoding.EncodeToString([]byte(invalidJSON)) + payloadB64 = strings.TrimRight(payloadB64, "=") + + signature := base64.URLEncoding.EncodeToString([]byte("fake_signature")) + signature = strings.TrimRight(signature, "=") + + return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signature) +} diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index c42dbccd114d..90479310690b 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -21,10 +21,7 @@ package jwt import ( "context" - "encoding/base64" - "encoding/json" "fmt" - "os" "strings" "sync" "time" @@ -35,28 +32,23 @@ import ( "google.golang.org/grpc/status" ) -// jwtClaims represents the JWT claims structure for extracting expiration time. -type jwtClaims struct { - Exp int64 `json:"exp"` -} - // jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads // tokens from a file. // This implementation follows the A97 JWT Call Credentials specification. type jwtTokenFileCallCreds struct { - tokenFilePath string - backoffStrategy backoff.Strategy // Strategy when error occurs - - // Cached token data - mu sync.RWMutex - cachedToken string - cachedExpiration time.Time // Slightly less than actual expiration time - cachedError error // Error from last failed attempt - retryAttempt int // Current retry attempt number - nextRetryTime time.Time // When next retry is allowed - - // Pre-emptive refresh mutex - refreshMu sync.Mutex + fileReader *jWTFileReader + backoffStrategy backoff.Strategy + + // The below state is protected by mu. The cond field is initialised with + // &mu and used to coordinate an async token refresh. + mu sync.Mutex + cond *sync.Cond + cachedToken string + cachedExpiry time.Time // Slightly less than actual expiration time + cachedError error // Error from last failed attempt + retryAttempt int // Current retry attempt number + nextRetryTime time.Time // When next retry is allowed + pendingRefresh bool // Whether a refresh is currently in progress } // NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens @@ -66,10 +58,13 @@ func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCreden return nil, fmt.Errorf("tokenFilePath cannot be empty") } - return &jwtTokenFileCallCreds{ - tokenFilePath: tokenFilePath, + creds := &jwtTokenFileCallCreds{ + fileReader: newJWTFileReader(tokenFilePath), backoffStrategy: backoff.DefaultExponential, - }, nil + } + creds.cond = sync.NewCond(&creds.mu) + + return creds, nil } // GetRequestMetadata gets the current request metadata, refreshing tokens if @@ -86,14 +81,49 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err) } - // This may be delayed if the token needs to be refreshed from file. - token, err := c.getToken() - if err != nil { - return nil, err + c.mu.Lock() + defer c.mu.Unlock() + + if c.isTokenValidLocked() { + if c.needsPreemptiveRefreshLocked() { + // Start refresh if not pending (handling the prior RPC may have + // just spwaned a goroutine). + if !c.pendingRefresh { + c.pendingRefresh = true + go c.refreshToken() + } + } + return map[string]string{ + "authorization": "Bearer " + c.cachedToken, + }, nil + } + + // If in backoff state, just return the cached error. + if c.cachedError != nil && time.Now().Before(c.nextRetryTime) { + return nil, c.cachedError + } + + // At this point, the token is either invalid or expired and we are no + // longer backing off. So refresh it. + + if !c.pendingRefresh { + c.pendingRefresh = true + go c.refreshToken() + } + // Wait for refresh to complete. + // NOTE: cond is initialised with &mu, so it gets released while waiting. + for c.pendingRefresh { + c.cond.Wait() + } + + // Refresh completed, re-check the state. + + if c.cachedError != nil { + return nil, c.cachedError } return map[string]string{ - "authorization": "Bearer " + token, + "authorization": "Bearer " + c.cachedToken, }, nil } @@ -103,44 +133,13 @@ func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool { return true } -// getToken returns a valid JWT token, reading from file if necessary. -// Implements pre-emptive refresh and caches errors with backoff. -func (c *jwtTokenFileCallCreds) getToken() (string, error) { - c.mu.RLock() - - if c.isTokenValidLocked() { - token := c.cachedToken - shouldRefresh := c.needsPreemptiveRefreshLocked() - c.mu.RUnlock() - - if shouldRefresh { - c.triggerPreemptiveRefresh() - } - return token, nil - } - - // If still within backoff period, return cached error to avoid repeated - // file reads. - if c.cachedError != nil && time.Now().Before(c.nextRetryTime) { - err := c.cachedError - c.mu.RUnlock() - return "", err - } - - c.mu.RUnlock() - // Token is expired or missing or the retry backoff period has expired. - // So we should refresh synchronously. - // NOTE: refreshTokenSync itself acquires the write lock. - return c.refreshTokenSync(false) -} - // isTokenValidLocked checks if the cached token is still valid. -// Caller must hold c.mu.RLock(). +// Caller must hold c.mu lock. func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { if c.cachedToken == "" { return false } - return c.cachedExpiration.After(time.Now()) + return c.cachedExpiry.After(time.Now()) } // needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be @@ -148,132 +147,54 @@ func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { // Returns true if the cached token is valid but expires within 1 minute. // We only trigger pre-emptive refresh for valid tokens - if the token is // invalid or expired, the next RPC will handle synchronous refresh instead. -// Caller must hold c.mu.RLock(). +// Caller must hold c.mu lock. func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { - return c.isTokenValidLocked() && time.Until(c.cachedExpiration) < time.Minute + return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < time.Minute } -// triggerPreemptiveRefresh starts a background refresh if needed. -// Multiple concurrent calls are safe - only one refresh will run at a time. -// The refresh runs in a separate goroutine and does not block the caller. -func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { - go func() { - c.refreshMu.Lock() - defer c.refreshMu.Unlock() +// refreshToken reads the token from file. +// The file read is done without holding the mutex to avoid blocking RPCs in +// GetRequestMetadata while waiting for the file read. +// Updates the cache and broadcasts to waiting goroutines when complete. +func (c *jwtTokenFileCallCreds) refreshToken() { + // Deliberately not locking c.mu here + token, expiry, err := c.fileReader.ReadToken() - // Re-check if refresh is still needed under mutex. - c.mu.RLock() - stillNeeded := c.needsPreemptiveRefreshLocked() - c.mu.RUnlock() - - if !stillNeeded { - return // Another goroutine already refreshed or token expired. - } - - // Force refresh to read new token even if current one is still valid. - _, _ = c.refreshTokenSync(true) - }() -} - -// refreshTokenSync reads a new token from the file and updates the cache. If -// forceRefresh is true, bypasses the validity check of the currently -// cached token and always reads from file. -// This is used for pre-emptive refresh to ensure new tokens are loaded even -// when the cached token is still valid. If forceRefresh is false, skips -// file read when cached token is still valid, optimizing concurrent synchronous -// refresh calls where one RPC may have already updated the cache while another -// was waiting on the lock. -func (c *jwtTokenFileCallCreds) refreshTokenSync(forceRefresh bool) (string, error) { c.mu.Lock() defer c.mu.Unlock() - - // Double-check under write lock but skip if preemptive refresh is - // requested. - if !forceRefresh && c.isTokenValidLocked() { - return c.cachedToken, nil - } - - tokenBytes, err := os.ReadFile(c.tokenFilePath) if err != nil { - err = status.Errorf(codes.Unavailable, "failed to read token file %q: %v", c.tokenFilePath, err) - c.setErrorWithBackoffLocked(err) - return "", err - } - - token := strings.TrimSpace(string(tokenBytes)) - if token == "" { - err := status.Errorf(codes.Unavailable, "token file %q is empty", c.tokenFilePath) - c.setErrorWithBackoffLocked(err) - return "", err - } - - // Parse JWT to extract expiration. - exp, err := c.extractExpiration(token) - if err != nil { - err = status.Errorf(codes.Unauthenticated, "failed to parse JWT from token file %q: %v", c.tokenFilePath, err) - c.setErrorWithBackoffLocked(err) - return "", err - } - - // Success - clear any cached error and backoff state, update token cache. - c.clearErrorAndBackoffLocked() - c.cachedToken = token - // Per RFC A97: consider token invalid if it expires within the next 30 - // seconds to accommodate for clock skew and server processing time. - c.cachedExpiration = exp.Add(-30 * time.Second) - - return token, nil -} - -// extractExpiration parses the JWT token to extract the expiration time. -func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) - } - - payload := parts[1] - // Add padding if necessary for base64 decoding. - if m := len(payload) % 4; m != 0 { - payload += strings.Repeat("=", 4-m) - } - - payloadBytes, err := base64.URLEncoding.DecodeString(payload) - if err != nil { - return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err) - } - - var claims jwtClaims - if err := json.Unmarshal(payloadBytes, &claims); err != nil { - return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err) - } - - if claims.Exp == 0 { - return time.Time{}, fmt.Errorf("JWT token has no expiration claim") - } - - expTime := time.Unix(claims.Exp, 0) - - // Check if token is already expired. - if expTime.Before(time.Now()) { - return time.Time{}, fmt.Errorf("JWT token is expired") - } - - return expTime, nil + // Convert to gRPC status codes + if strings.Contains(err.Error(), "failed to read token file") || strings.Contains(err.Error(), "token file") && strings.Contains(err.Error(), "is empty") { + c.cachedError = status.Errorf(codes.Unavailable, "%v", err) + } else { + c.cachedError = status.Errorf(codes.Unauthenticated, "%v", err) + } + c.setErrorWithBackoffLocked() + } else { + // Success - clear any cached error and update token cache + c.clearErrorAndBackoffLocked() + c.cachedToken = token + // Per RFC A97: consider token invalid if it expires within the next 30 + // seconds to accommodate for clock skew and server processing time. + c.cachedExpiry = expiry.Add(-30 * time.Second) + } + + // Reset pending refresh and broadcast to waiting goroutines + c.pendingRefresh = false + c.cond.Broadcast() } // setErrorWithBackoffLocked caches an error and calculates the next retry time // using exponential backoff. -// Caller must hold c.mu write lock. -func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked(err error) { - c.cachedError = err +// Caller must hold c.mu lock. +func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked() { c.retryAttempt++ backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1) c.nextRetryTime = time.Now().Add(backoffDelay) } // clearErrorAndBackoffLocked clears the cached error and resets backoff state. -// Caller must hold c.mu write lock. +// Caller must hold c.mu lock. func (c *jwtTokenFileCallCreds) clearErrorAndBackoffLocked() { c.cachedError = nil c.retryAttempt = 0 diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index 1eaacc5b3ab7..83a4fd03dc5d 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -209,7 +209,6 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { } } - // testAuthInfo implements credentials.AuthInfo for testing. type testAuthInfo struct { secLevel credentials.SecurityLevel @@ -223,7 +222,6 @@ func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} } - // Tests that cached token expiration is set to 30 seconds before actual token // expiration. func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { @@ -513,71 +511,6 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { } } - -// Tests that RPCs are queued during file operations and all receive the same -// result. -func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - - // Start with no token file to force file read during first RPC. - creds, err := NewTokenFileCallCredentials(tokenFile) - if err != nil { - t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ - AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - }) - - // Launch multiple concurrent RPCs before creating the token file. - const numConcurrentRPCs = 5 - results := make(chan error, numConcurrentRPCs) - - for range numConcurrentRPCs { - go func() { - _, err := creds.GetRequestMetadata(ctx) - results <- err - }() - } - - // Collect all results - they should all be the same error (UNAVAILABLE). - var errors []error - for range numConcurrentRPCs { - err := <-results - errors = append(errors, err) - } - - // All RPCs should fail with the same error (file not found). - for i, err := range errors { - if err == nil { - t.Errorf("RPC %d should have failed with UNAVAILABLE", i) - continue - } - if status.Code(err) != codes.Unavailable { - t.Errorf("RPC %d = %v, want UNAVAILABLE", i, status.Code(err)) - } - if i > 0 && err.Error() != errors[0].Error() { - t.Errorf("RPC %d error should match first RPC error for proper queueing", i) - } - } - - // Verify error was cached after concurrent RPCs. - impl := creds.(*jwtTokenFileCallCreds) - impl.mu.Lock() - finalCachedErr := impl.cachedError - impl.mu.Unlock() - - if finalCachedErr == nil { - t.Error("error should be cached after failed concurrent RPCs") - } - if finalCachedErr.Error() != errors[0].Error() { - t.Error("cached error should match the errors returned to RPCs") - } -} - // createTestJWT creates a test JWT token with the specified audience and // expiration. func createTestJWT(t *testing.T, audience string, expiration time.Time) string { @@ -629,4 +562,3 @@ func writeTempFile(t *testing.T, name, content string) string { } return filePath } - From fd388d189e9ae10525b4e195b010eee7b0ea6543 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 18:10:02 +0100 Subject: [PATCH 19/35] refactor to no longer need cond --- credentials/jwt/jwt_token_file_call_creds.go | 63 +++++++------------- 1 file changed, 21 insertions(+), 42 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 90479310690b..3fcd34a1091a 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -39,10 +39,8 @@ type jwtTokenFileCallCreds struct { fileReader *jWTFileReader backoffStrategy backoff.Strategy - // The below state is protected by mu. The cond field is initialised with - // &mu and used to coordinate an async token refresh. + // cached data protected by mu mu sync.Mutex - cond *sync.Cond cachedToken string cachedExpiry time.Time // Slightly less than actual expiration time cachedError error // Error from last failed attempt @@ -62,7 +60,6 @@ func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCreden fileReader: newJWTFileReader(tokenFilePath), backoffStrategy: backoff.DefaultExponential, } - creds.cond = sync.NewCond(&creds.mu) return creds, nil } @@ -105,23 +102,12 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str // At this point, the token is either invalid or expired and we are no // longer backing off. So refresh it. - - if !c.pendingRefresh { - c.pendingRefresh = true - go c.refreshToken() - } - // Wait for refresh to complete. - // NOTE: cond is initialised with &mu, so it gets released while waiting. - for c.pendingRefresh { - c.cond.Wait() - } - - // Refresh completed, re-check the state. + token, expiry, err := c.fileReader.ReadToken() + c.setCacheLocked(token, expiry, err) if c.cachedError != nil { return nil, c.cachedError } - return map[string]string{ "authorization": "Bearer " + c.cachedToken, }, nil @@ -153,8 +139,6 @@ func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { } // refreshToken reads the token from file. -// The file read is done without holding the mutex to avoid blocking RPCs in -// GetRequestMetadata while waiting for the file read. // Updates the cache and broadcasts to waiting goroutines when complete. func (c *jwtTokenFileCallCreds) refreshToken() { // Deliberately not locking c.mu here @@ -162,6 +146,17 @@ func (c *jwtTokenFileCallCreds) refreshToken() { c.mu.Lock() defer c.mu.Unlock() + c.setCacheLocked(token, expiry, err) + + // Reset pending refresh and broadcast to waiting goroutines + c.pendingRefresh = false +} + +// setCacheLocked updates the cached token, expiry, and error state. +// If an error is provided, it determines whether to set it as an UNAVAILABLE +// or UNAUTHENTICATED error based on the error type. +// Caller must hold c.mu lock. +func (c *jwtTokenFileCallCreds) setCacheLocked(token string, expiry time.Time, err error) { if err != nil { // Convert to gRPC status codes if strings.Contains(err.Error(), "failed to read token file") || strings.Contains(err.Error(), "token file") && strings.Contains(err.Error(), "is empty") { @@ -169,34 +164,18 @@ func (c *jwtTokenFileCallCreds) refreshToken() { } else { c.cachedError = status.Errorf(codes.Unauthenticated, "%v", err) } - c.setErrorWithBackoffLocked() + c.retryAttempt++ + backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1) + c.nextRetryTime = time.Now().Add(backoffDelay) } else { // Success - clear any cached error and update token cache - c.clearErrorAndBackoffLocked() + c.cachedError = nil + c.retryAttempt = 0 + c.nextRetryTime = time.Time{} + c.cachedToken = token // Per RFC A97: consider token invalid if it expires within the next 30 // seconds to accommodate for clock skew and server processing time. c.cachedExpiry = expiry.Add(-30 * time.Second) } - - // Reset pending refresh and broadcast to waiting goroutines - c.pendingRefresh = false - c.cond.Broadcast() -} - -// setErrorWithBackoffLocked caches an error and calculates the next retry time -// using exponential backoff. -// Caller must hold c.mu lock. -func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked() { - c.retryAttempt++ - backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1) - c.nextRetryTime = time.Now().Add(backoffDelay) -} - -// clearErrorAndBackoffLocked clears the cached error and resets backoff state. -// Caller must hold c.mu lock. -func (c *jwtTokenFileCallCreds) clearErrorAndBackoffLocked() { - c.cachedError = nil - c.retryAttempt = 0 - c.nextRetryTime = time.Time{} } From 790a2d9e8cdc817aa873a0030e4d531e2ee9bcfe Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 18:23:39 +0100 Subject: [PATCH 20/35] fix docstring comment --- credentials/jwt/doc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/credentials/jwt/doc.go b/credentials/jwt/doc.go index f74d3446afb4..bba687f2151e 100644 --- a/credentials/jwt/doc.go +++ b/credentials/jwt/doc.go @@ -36,8 +36,8 @@ // - Tokens are cached until expiration to avoid excessive file I/O // - Transport security is required (RequireTransportSecurity returns true) // - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or -// UNAUTHENTICATED errors -// - These errors are cached and retried with exponential backoff. +// UNAUTHENTICATED errors. The errors are cached and retried with exponential +// backoff. // // This implementation is originally intended for use in service mesh // environments like Istio where JWT tokens are provisioned and rotated by the From 67131908f34b9c0b78dc14a7603a426501a973f6 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 18:24:08 +0100 Subject: [PATCH 21/35] cache authorization header instead of token --- credentials/jwt/jwt_token_file_call_creds.go | 26 ++++++++++--------- .../jwt/jwt_token_file_call_creds_test.go | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 3fcd34a1091a..81e1165bd421 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -32,6 +32,8 @@ import ( "google.golang.org/grpc/status" ) +const preemptiveRefreshThreshold = time.Minute + // jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads // tokens from a file. // This implementation follows the A97 JWT Call Credentials specification. @@ -40,13 +42,13 @@ type jwtTokenFileCallCreds struct { backoffStrategy backoff.Strategy // cached data protected by mu - mu sync.Mutex - cachedToken string - cachedExpiry time.Time // Slightly less than actual expiration time - cachedError error // Error from last failed attempt - retryAttempt int // Current retry attempt number - nextRetryTime time.Time // When next retry is allowed - pendingRefresh bool // Whether a refresh is currently in progress + mu sync.Mutex + cachedAuthHeader string // "Bearer " + token + cachedExpiry time.Time // Slightly less than actual expiration time + cachedError error // Error from last failed attempt + retryAttempt int // Current retry attempt number + nextRetryTime time.Time // When next retry is allowed + pendingRefresh bool // Whether a refresh is currently in progress } // NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens @@ -91,7 +93,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str } } return map[string]string{ - "authorization": "Bearer " + c.cachedToken, + "authorization": c.cachedAuthHeader, }, nil } @@ -109,7 +111,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str return nil, c.cachedError } return map[string]string{ - "authorization": "Bearer " + c.cachedToken, + "authorization": c.cachedAuthHeader, }, nil } @@ -122,7 +124,7 @@ func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool { // isTokenValidLocked checks if the cached token is still valid. // Caller must hold c.mu lock. func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { - if c.cachedToken == "" { + if c.cachedAuthHeader == "" { return false } return c.cachedExpiry.After(time.Now()) @@ -135,7 +137,7 @@ func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { // invalid or expired, the next RPC will handle synchronous refresh instead. // Caller must hold c.mu lock. func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { - return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < time.Minute + return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < preemptiveRefreshThreshold } // refreshToken reads the token from file. @@ -173,7 +175,7 @@ func (c *jwtTokenFileCallCreds) setCacheLocked(token string, expiry time.Time, e c.retryAttempt = 0 c.nextRetryTime = time.Time{} - c.cachedToken = token + c.cachedAuthHeader = "Bearer " + token // Per RFC A97: consider token invalid if it expires within the next 30 // seconds to accommodate for clock skew and server processing time. c.cachedExpiry = expiry.Add(-30 * time.Second) diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index 83a4fd03dc5d..c09718d702da 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -288,7 +288,7 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { impl := creds.(*jwtTokenFileCallCreds) impl.mu.Lock() cacheExp := impl.cachedExpiry - tokenCached := impl.cachedToken != "" + tokenCached := impl.cachedAuthHeader != "" shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked() impl.mu.Unlock() From 3f563eb09cbb7b04c249c5b1f2078197ecadb939 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 18:25:11 +0100 Subject: [PATCH 22/35] remove internal/ and xds/ changes --- credentials/jwt/jwt_token_file_call_creds.go | 10 +- internal/envconfig/xds.go | 5 - internal/xds/bootstrap/bootstrap.go | 87 +-- internal/xds/bootstrap/bootstrap_test.go | 606 +------------------ xds/bootstrap/bootstrap.go | 4 - xds/bootstrap/bootstrap_test.go | 42 +- xds/bootstrap/credentials.go | 14 - xds/internal/xdsclient/clientimpl.go | 4 +- xds/internal/xdsclient/clientimpl_test.go | 91 --- 9 files changed, 19 insertions(+), 844 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 81e1165bd421..3e691b016cc8 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -86,7 +86,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str if c.isTokenValidLocked() { if c.needsPreemptiveRefreshLocked() { // Start refresh if not pending (handling the prior RPC may have - // just spwaned a goroutine). + // just spawned a goroutine). if !c.pendingRefresh { c.pendingRefresh = true go c.refreshToken() @@ -105,7 +105,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str // At this point, the token is either invalid or expired and we are no // longer backing off. So refresh it. token, expiry, err := c.fileReader.ReadToken() - c.setCacheLocked(token, expiry, err) + c.updateCacheLocked(token, expiry, err) if c.cachedError != nil { return nil, c.cachedError @@ -148,17 +148,17 @@ func (c *jwtTokenFileCallCreds) refreshToken() { c.mu.Lock() defer c.mu.Unlock() - c.setCacheLocked(token, expiry, err) + c.updateCacheLocked(token, expiry, err) // Reset pending refresh and broadcast to waiting goroutines c.pendingRefresh = false } -// setCacheLocked updates the cached token, expiry, and error state. +// updateCacheLocked updates the cached token, expiry, and error state. // If an error is provided, it determines whether to set it as an UNAVAILABLE // or UNAUTHENTICATED error based on the error type. // Caller must hold c.mu lock. -func (c *jwtTokenFileCallCreds) setCacheLocked(token string, expiry time.Time, err error) { +func (c *jwtTokenFileCallCreds) updateCacheLocked(token string, expiry time.Time, err error) { if err != nil { // Convert to gRPC status codes if strings.Contains(err.Error(), "failed to read token file") || strings.Contains(err.Error(), "token file") && strings.Contains(err.Error(), "is empty") { diff --git a/internal/envconfig/xds.go b/internal/envconfig/xds.go index 6420558c0b7a..e87551552ad7 100644 --- a/internal/envconfig/xds.go +++ b/internal/envconfig/xds.go @@ -68,9 +68,4 @@ var ( // trust. For more details, see: // https://github.com/grpc/proposal/blob/master/A87-mtls-spiffe-support.md XDSSPIFFEEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false) - - // XDSBootstrapCallCredsEnabled controls if JWT call credentials can be used - // in xDS bootstrap configuration. For more details, see: - // https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md - XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false) ) diff --git a/internal/xds/bootstrap/bootstrap.go b/internal/xds/bootstrap/bootstrap.go index 46dbf6bc98bc..f409e4bd77b2 100644 --- a/internal/xds/bootstrap/bootstrap.go +++ b/internal/xds/bootstrap/bootstrap.go @@ -31,7 +31,6 @@ import ( "strings" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -65,26 +64,11 @@ type ChannelCreds struct { Config json.RawMessage `json:"config,omitempty"` } -// CallCreds contains the call credentials configuration for individual RPCs. -// This type implements RFC A97 call credentials structure. -type CallCreds struct { - // Type contains a unique name identifying the call credentials type. - // Currently only "jwt_token_file" is supported. - Type string `json:"type,omitempty"` - // Config contains the JSON configuration associated with the call credentials. - Config json.RawMessage `json:"config,omitempty"` -} - // Equal reports whether cc and other are considered equal. func (cc ChannelCreds) Equal(other ChannelCreds) bool { return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) } -// Equal reports whether cc and other are considered equal. -func (cc CallCreds) Equal(other CallCreds) bool { - return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) -} - // String returns a string representation of the credentials. It contains the // type and the config (if non-nil) separated by a "-". func (cc ChannelCreds) String() string { @@ -188,15 +172,13 @@ type ServerConfig struct { serverURI string channelCreds []ChannelCreds serverFeatures []string - callCreds []CallCreds // As part of unmarshalling the JSON config into this struct, we ensure that // the credentials config is valid by building an instance of the specified // credentials and store it here for easy access. - selectedCreds ChannelCreds - credsDialOption grpc.DialOption - extraDialOptions []grpc.DialOption - selectedCallCreds []credentials.PerRPCCredentials // Built call credentials + selectedCreds ChannelCreds + credsDialOption grpc.DialOption + extraDialOptions []grpc.DialOption cleanups []func() } @@ -218,17 +200,6 @@ func (sc *ServerConfig) ServerFeatures() []string { return sc.serverFeatures } -// CallCreds returns the call credentials configuration for this server. -func (sc *ServerConfig) CallCreds() []CallCreds { - return sc.callCreds -} - -// SelectedCallCreds returns the built call credentials that are ready to use. -// These are the credentials that were successfully built from the call_creds configuration. -func (sc *ServerConfig) SelectedCallCreds() []credentials.PerRPCCredentials { - return sc.selectedCallCreds -} - // ServerFeaturesIgnoreResourceDeletion returns true if this server supports a // feature where the xDS client can ignore resource deletions from this server, // as described in gRFC A53. @@ -262,28 +233,6 @@ func (sc *ServerConfig) DialOptions() []grpc.DialOption { return dopts } -// DialOptionsWithCallCredsForTransport returns dial options including call credentials -// only if they are compatible with the specified transport credentials type. -// Call credentials that require transport security will be skipped for insecure transports. -func (sc *ServerConfig) DialOptionsWithCallCredsForTransport(transportCredsType string, transportCreds credentials.TransportCredentials) []grpc.DialOption { - dopts := sc.DialOptions() - - // Check if transport is insecure - isInsecureTransport := transportCredsType == "insecure" || - (transportCreds != nil && transportCreds.Info().SecurityProtocol == "insecure") - - // Add call credentials only if compatible with transport security - for _, callCred := range sc.selectedCallCreds { - // Skip call credentials that require transport security on insecure transports - if isInsecureTransport && callCred.RequireTransportSecurity() { - continue - } - dopts = append(dopts, grpc.WithPerRPCCredentials(callCred)) - } - - return dopts -} - // Cleanups returns a collection of functions to be called when the xDS client // for this server is closed. Allows cleaning up resources created specifically // for this server. @@ -302,8 +251,6 @@ func (sc *ServerConfig) Equal(other *ServerConfig) bool { return false case !slices.EqualFunc(sc.channelCreds, other.channelCreds, func(a, b ChannelCreds) bool { return a.Equal(b) }): return false - case !slices.EqualFunc(sc.callCreds, other.callCreds, func(a, b CallCreds) bool { return a.Equal(b) }): - return false case !slices.Equal(sc.serverFeatures, other.serverFeatures): return false case !sc.selectedCreds.Equal(other.selectedCreds): @@ -326,7 +273,6 @@ type serverConfigJSON struct { ServerURI string `json:"server_uri,omitempty"` ChannelCreds []ChannelCreds `json:"channel_creds,omitempty"` ServerFeatures []string `json:"server_features,omitempty"` - CallCreds []CallCreds `json:"call_creds,omitempty"` } // MarshalJSON returns marshaled JSON bytes corresponding to this server config. @@ -335,7 +281,6 @@ func (sc *ServerConfig) MarshalJSON() ([]byte, error) { ServerURI: sc.serverURI, ChannelCreds: sc.channelCreds, ServerFeatures: sc.serverFeatures, - CallCreds: sc.callCreds, } return json.Marshal(server) } @@ -356,7 +301,6 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.serverURI = server.ServerURI sc.channelCreds = server.ChannelCreds sc.serverFeatures = server.ServerFeatures - sc.callCreds = server.CallCreds for _, cc := range server.ChannelCreds { // We stop at the first credential type that we support. @@ -376,27 +320,6 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.cleanups = append(sc.cleanups, cancel) break } - - // Process call credentials - unlike channel creds, we use ALL supported types - // Call credentials are optional per RFC A97 - for _, callCred := range server.CallCreds { - c := bootstrap.GetCredentials(callCred.Type) - if c == nil { - // Skip unsupported call credential types (don't fail bootstrap) - continue - } - bundle, cancel, err := c.Build(callCred.Config) - if err != nil { - // Call credential validation failed - this should fail bootstrap - return fmt.Errorf("failed to build call credentials from bootstrap for %q: %v", callCred.Type, err) - } - // Extract the PerRPCCredentials from the bundle. Sanity check for nil just in case - if callCredentials := bundle.PerRPCCredentials(); callCredentials != nil { - sc.selectedCallCreds = append(sc.selectedCallCreds, callCredentials) - } - sc.cleanups = append(sc.cleanups, cancel) - } - if sc.serverURI == "" { return fmt.Errorf("xds: `server_uri` field in server config cannot be empty: %s", string(data)) } @@ -418,9 +341,6 @@ type ServerConfigTestingOptions struct { ChannelCreds []ChannelCreds // ServerFeatures represents the list of features supported by this server. ServerFeatures []string - // CallCreds contains a list of call credentials to use for individual RPCs - // to this server. Optional. - CallCreds []CallCreds } // ServerConfigForTesting creates a new ServerConfig from the passed in options, @@ -436,7 +356,6 @@ func ServerConfigForTesting(opts ServerConfigTestingOptions) (*ServerConfig, err ServerURI: opts.URI, ChannelCreds: cc, ServerFeatures: opts.ServerFeatures, - CallCreds: opts.CallCreds, } scJSON, err := json.Marshal(scInternal) if err != nil { diff --git a/internal/xds/bootstrap/bootstrap_test.go b/internal/xds/bootstrap/bootstrap_test.go index 93e90144fd28..d057197804d6 100644 --- a/internal/xds/bootstrap/bootstrap_test.go +++ b/internal/xds/bootstrap/bootstrap_test.go @@ -19,21 +19,15 @@ package bootstrap import ( - "context" "encoding/json" "errors" "fmt" - "net" "os" - "strings" "testing" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/credentials/jwt" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -202,74 +196,6 @@ var ( "server_features" : ["ignore_resource_deletion", "xds_v3"] }] }`, - // example data seeded from - // https://github.com/istio/istio/blob/master/pkg/istio-agent/testdata/grpc-bootstrap.json - "istioStyleWithJWTCallCreds": ` - { - "node": { - "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", - "metadata": { - "GENERATOR": "grpc", - "INSTANCE_IPS": "127.0.0.1", - "ISTIO_VERSION": "1.26.2", - "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" - }, - "locality": {} - }, - "xds_servers" : [{ - "server_uri": "unix:///etc/istio/XDS", - "channel_creds": [ - { "type": "insecure" } - ], - "call_creds": [ - { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } - ], - "server_features" : ["xds_v3"] - }] - }`, - "istioStyleWithoutCallCreds": ` - { - "node": { - "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", - "metadata": { - "GENERATOR": "grpc", - "INSTANCE_IPS": "127.0.0.1", - "ISTIO_VERSION": "1.26.2", - "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" - }, - "locality": {} - }, - "xds_servers" : [{ - "server_uri": "unix:///etc/istio/XDS", - "channel_creds": [ - { "type": "insecure" } - ], - "server_features" : ["xds_v3"] - }] - }`, - "istioStyleWithTLSAndJWT": ` - { - "node": { - "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", - "metadata": { - "GENERATOR": "grpc", - "INSTANCE_IPS": "127.0.0.1", - "ISTIO_VERSION": "1.26.2", - "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" - }, - "locality": {} - }, - "xds_servers" : [{ - "server_uri": "unix:///etc/istio/XDS", - "channel_creds": [ - { "type": "tls", "config": {} } - ], - "call_creds": [ - { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } - ], - "server_features" : ["xds_v3"] - }] - }`, } metadata = &structpb.Struct{ Fields: map[string]*structpb.Value{ @@ -350,82 +276,6 @@ var ( node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } - - istioNodeMetadata = &structpb.Struct{ - Fields: map[string]*structpb.Value{ - "GENERATOR": { - Kind: &structpb.Value_StringValue{StringValue: "grpc"}, - }, - "INSTANCE_IPS": { - Kind: &structpb.Value_StringValue{StringValue: "127.0.0.1"}, - }, - "ISTIO_VERSION": { - Kind: &structpb.Value_StringValue{StringValue: "1.26.2"}, - }, - "WORKLOAD_IDENTITY_SOCKET_FILE": { - Kind: &structpb.Value_StringValue{StringValue: "socket"}, - }, - }, - } - jwtCallCreds, _ = jwt.NewTokenFileCallCredentials("/var/run/secrets/tokens/istio-token") - selectedJWTCallCreds = []credentials.PerRPCCredentials{jwtCallCreds} - configWithIstioJWTCallCreds = &Config{ - xDSServers: []*ServerConfig{{ - serverURI: "unix:///etc/istio/XDS", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - callCreds: []CallCreds{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "insecure"}, - selectedCallCreds: selectedJWTCallCreds, - }}, - node: node{ - ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", - Metadata: istioNodeMetadata, - userAgentName: gRPCUserAgentName, - userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, - clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, - }, - certProviderConfigs: map[string]*certprovider.BuildableConfig{}, - clientDefaultListenerResourceNameTemplate: "%s", - } - - configWithIstioStyleNoCallCreds = &Config{ - xDSServers: []*ServerConfig{{ - serverURI: "unix:///etc/istio/XDS", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "insecure"}, - }}, - node: node{ - ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", - Metadata: istioNodeMetadata, - userAgentName: gRPCUserAgentName, - userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, - clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, - }, - certProviderConfigs: map[string]*certprovider.BuildableConfig{}, - clientDefaultListenerResourceNameTemplate: "%s", - } - - configWithIstioStyleWithTLSAndJWT = &Config{ - xDSServers: []*ServerConfig{{ - serverURI: "unix:///etc/istio/XDS", - channelCreds: []ChannelCreds{{Type: "tls", Config: json.RawMessage("{}")}}, - callCreds: []CallCreds{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "tls", Config: json.RawMessage("{}")}, - selectedCallCreds: selectedJWTCallCreds, - }}, - node: node{ - ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", - Metadata: istioNodeMetadata, - userAgentName: gRPCUserAgentName, - userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, - clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, - }, - certProviderConfigs: map[string]*certprovider.BuildableConfig{}, - clientDefaultListenerResourceNameTemplate: "%s", - } ) func fileReadFromFileMap(bootstrapFileMap map[string]string, name string) ([]byte, error) { @@ -575,35 +425,6 @@ func (s) TestGetConfiguration_Success(t *testing.T) { {"goodBootstrap", configWithGoogleDefaultCredsAndV3}, {"multipleXDSServers", configWithMultipleServers}, {"serverSupportsIgnoreResourceDeletion", configWithGoogleDefaultCredsAndIgnoreResourceDeletion}, - {"istioStyleWithoutCallCreds", configWithIstioStyleNoCallCreds}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - testGetConfigurationWithFileNameEnv(t, test.name, false, test.wantConfig) - testGetConfigurationWithFileContentEnv(t, test.name, false, test.wantConfig) - }) - } -} - -// Tests Istio-style bootstrap configurations with JWT call credentials -func (s) TestGetConfiguration_IstioStyleWithCallCreds(t *testing.T) { - // Enable JWT call credentials feature - original := envconfig.XDSBootstrapCallCredsEnabled - envconfig.XDSBootstrapCallCredsEnabled = true - defer func() { - envconfig.XDSBootstrapCallCredsEnabled = original - }() - - cancel := setupBootstrapOverride(v3BootstrapFileMap) - defer cancel() - - tests := []struct { - name string - wantConfig *Config - }{ - {"istioStyleWithJWTCallCreds", configWithIstioJWTCallCreds}, - {"istioStyleWithTLSAndJWT", configWithIstioStyleWithTLSAndJWT}, } for _, test := range tests { @@ -1197,203 +1018,12 @@ func (s) TestDefaultBundles(t *testing.T) { } } -func (s) TestCallCreds_Equal(t *testing.T) { - tests := []struct { - name string - cc1 CallCreds - cc2 CallCreds - expect bool - }{ - { - name: "identical configs", - cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, - cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, - expect: true, - }, - { - name: "different types", - cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, - cc2: CallCreds{Type: "other_type", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, - expect: false, - }, - { - name: "different configs", - cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, - cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/different/path"}`)}, - expect: false, - }, - { - name: "nil vs non-nil configs", - cc1: CallCreds{Type: "jwt_token_file", Config: nil}, - cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, - expect: false, - }, - { - name: "both nil configs", - cc1: CallCreds{Type: "jwt_token_file", Config: nil}, - cc2: CallCreds{Type: "jwt_token_file", Config: nil}, - expect: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - result := test.cc1.Equal(test.cc2) - if result != test.expect { - t.Errorf("CallCreds.Equal() = %v, want %v", result, test.expect) - } - }) - } -} - -func (s) TestServerConfig_UnmarshalJSON_WithCallCreds(t *testing.T) { - original := envconfig.XDSBootstrapCallCredsEnabled - defer func() { envconfig.XDSBootstrapCallCredsEnabled = original }() - envconfig.XDSBootstrapCallCredsEnabled = true // Enable call creds in bootstrap - tests := []struct { - name string - json string - wantCallCreds []CallCreds - wantErr bool - errContains string - }{ - { - name: "valid call_creds with jwt_token_file", - json: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}], - "call_creds": [ - { - "type": "jwt_token_file", - "config": {"jwt_token_file": "/path/to/token.jwt"} - } - ] - }`, - wantCallCreds: []CallCreds{{ - Type: "jwt_token_file", - Config: json.RawMessage(`{"jwt_token_file": "/path/to/token.jwt"}`), - }}, - }, - { - name: "multiple call_creds types", - json: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}], - "call_creds": [ - {"type": "jwt_token_file", "config": {"jwt_token_file": "/token1.jwt"}}, - {"type": "unsupported_type", "config": {}} - ] - }`, - wantCallCreds: []CallCreds{ - {Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/token1.jwt"}`)}, - {Type: "unsupported_type", Config: json.RawMessage(`{}`)}, - }, - }, - { - name: "empty call_creds array", - json: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}], - "call_creds": [] - }`, - wantCallCreds: []CallCreds{}, - }, - { - name: "missing call_creds field", - json: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}] - }`, - wantCallCreds: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var sc ServerConfig - err := sc.UnmarshalJSON([]byte(test.json)) - - if test.wantErr { - if err == nil { - t.Fatal("Expected error, got nil") - } - if test.errContains != "" && !strings.Contains(err.Error(), test.errContains) { - t.Errorf("Error %v should contain %q", err, test.errContains) - } - return - } - - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if diff := cmp.Diff(test.wantCallCreds, sc.CallCreds()); diff != "" { - t.Errorf("CallCreds mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func (s) TestServerConfig_Equal_WithCallCreds(t *testing.T) { - callCreds := []CallCreds{{ - Type: "jwt_token_file", - Config: json.RawMessage(`{"jwt_token_file": "/test/token.jwt"}`), - }} - sc1 := &ServerConfig{ - serverURI: "server1", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - callCreds: callCreds, - serverFeatures: []string{"feature1"}, - } - sc2 := &ServerConfig{ - serverURI: "server1", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - callCreds: callCreds, - serverFeatures: []string{"feature1"}, - } - sc3 := &ServerConfig{ - serverURI: "server1", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - callCreds: []CallCreds{{Type: "different"}}, - serverFeatures: []string{"feature1"}, - } - - if !sc1.Equal(sc2) { - t.Error("Equal ServerConfigs with same call creds should be equal") - } - if sc1.Equal(sc3) { - t.Error("ServerConfigs with different call creds should not be equal") - } +type s struct { + grpctest.Tester } -func (s) TestServerConfig_MarshalJSON_WithCallCreds(t *testing.T) { - original := envconfig.XDSBootstrapCallCredsEnabled - defer func() { envconfig.XDSBootstrapCallCredsEnabled = original }() - envconfig.XDSBootstrapCallCredsEnabled = true // Enable call creds in bootstrap - sc := &ServerConfig{ - serverURI: "test-server:443", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - callCreds: []CallCreds{{ - Type: "jwt_token_file", - Config: json.RawMessage(`{"jwt_token_file":"/test/token.jwt"}`), - }}, - serverFeatures: []string{"test_feature"}, - } - - data, err := sc.MarshalJSON() - if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) - } - - // confirm Marshal/Unmarshal symmetry - var unmarshaled ServerConfig - if err := json.Unmarshal(data, &unmarshaled); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if diff := cmp.Diff(sc.CallCreds(), unmarshaled.CallCreds()); diff != "" { - t.Errorf("Marshal/Unmarshal call credentials produces differences:\n%s", diff) - } +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) } func newStructProtoFromMap(t *testing.T, input map[string]any) *structpb.Struct { @@ -1639,231 +1269,3 @@ func (s) TestGetConfiguration_FallbackDisabled(t *testing.T) { testGetConfigurationWithFileContentEnv(t, "multipleXDSServers", false, wantConfig) }) } - -func (s) TestBootstrap_SelectedCredsAndCallCreds(t *testing.T) { - // Enable JWT call credentials - original := envconfig.XDSBootstrapCallCredsEnabled - envconfig.XDSBootstrapCallCredsEnabled = true - defer func() { - envconfig.XDSBootstrapCallCredsEnabled = original - }() - - tokenFile := "/token.jwt" - tests := []struct { - name string - bootstrapConfig string - expectCallCreds int - expectTransportType string - }{ - { - name: "JWT call creds with TLS channel creds", - bootstrapConfig: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "tls", "config": {}}], - "call_creds": [ - { - "type": "jwt_token_file", - "config": {"jwt_token_file": "` + tokenFile + `"} - } - ] - }`, - expectCallCreds: 1, - expectTransportType: "tls", - }, - { - name: "JWT call creds with multiple channel creds", - bootstrapConfig: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "tls", "config": {}}, {"type": "insecure"}], - "call_creds": [ - { - "type": "jwt_token_file", - "config": {"jwt_token_file": "` + tokenFile + `"} - }, - { - "type": "jwt_token_file", - "config": {"jwt_token_file": "` + tokenFile + `"} - } - ] - }`, - expectCallCreds: 2, - expectTransportType: "tls", // the first channel creds is selected - }, - { - name: "JWT call creds with insecure channel creds", - bootstrapConfig: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}], - "call_creds": [ - { - "type": "jwt_token_file", - "config": {"jwt_token_file": "` + tokenFile + `"} - } - ] - }`, - expectCallCreds: 1, - expectTransportType: "insecure", - }, - { - name: "No call creds", - bootstrapConfig: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}] - }`, - expectCallCreds: 0, - expectTransportType: "insecure", - }, - { - name: "No call creds multiple channel creds", - bootstrapConfig: `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "insecure"}, {"type": "tls", "config": {}}] - }`, - expectCallCreds: 0, - expectTransportType: "insecure", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var sc ServerConfig - err := sc.UnmarshalJSON([]byte(test.bootstrapConfig)) - if err != nil { - t.Fatalf("Failed to unmarshal bootstrap config: %v", err) - } - - // Verify call credentials processing - callCreds := sc.CallCreds() - selectedCallCreds := sc.SelectedCallCreds() - - if len(callCreds) != test.expectCallCreds { - t.Errorf("Call creds count = %d, want %d", len(callCreds), test.expectCallCreds) - } - if len(selectedCallCreds) != test.expectCallCreds { - t.Errorf("Selected call creds count = %d, want %d", len(selectedCallCreds), test.expectCallCreds) - } - - // Verify transport credentials are properly selected - if sc.SelectedCreds().Type != test.expectTransportType { - t.Errorf("Selected transport creds type = %q, want %q", - sc.SelectedCreds().Type, test.expectTransportType) - } - }) - } -} - -func (s) TestDialOptionsWithCallCredsForTransport(t *testing.T) { - // Create test JWT credentials that require transport security - testJWTCreds := &testPerRPCCreds{requireSecurity: true} - testInsecureCreds := &testPerRPCCreds{requireSecurity: false} - - sc := &ServerConfig{ - selectedCallCreds: []credentials.PerRPCCredentials{ - testJWTCreds, - testInsecureCreds, - }, - extraDialOptions: []grpc.DialOption{ - grpc.WithUserAgent("test-agent"), // Test extra option - }, - } - - tests := []struct { - name string - transportType string - transportCreds credentials.TransportCredentials - expectJWTCreds bool - expectOtherCreds bool - }{ - { - name: "insecure transport by type", - transportType: "insecure", - transportCreds: nil, - expectJWTCreds: false, // JWT requires security - expectOtherCreds: true, // Non-security creds allowed - }, - { - name: "insecure transport by protocol", - transportType: "custom", - transportCreds: insecure.NewCredentials(), - expectJWTCreds: false, // JWT requires security - expectOtherCreds: true, // Non-security creds allowed - }, - { - name: "secure transport", - transportType: "tls", - transportCreds: &testTransportCreds{securityProtocol: "tls"}, - expectJWTCreds: true, // JWT allowed on secure transport - expectOtherCreds: true, // All creds allowed - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := sc.DialOptionsWithCallCredsForTransport(test.transportType, test.transportCreds) - - // Count dial options (should include extra options + applicable call creds) - expectedCount := 2 // extraDialOptions + always include non-security creds - if test.expectJWTCreds { - expectedCount++ - } - - if len(opts) != expectedCount { - t.Errorf("DialOptions count = %d, want %d", len(opts), expectedCount) - } - }) - } -} - -type testPerRPCCreds struct { - requireSecurity bool -} - -func (c *testPerRPCCreds) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { - return map[string]string{"test": "metadata"}, nil -} - -func (c *testPerRPCCreds) RequireTransportSecurity() bool { - return c.requireSecurity -} - -type testTransportCreds struct { - securityProtocol string -} - -func (c *testTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, &testAuthInfo{}, nil -} - -func (c *testTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, &testAuthInfo{}, nil -} - -func (c *testTransportCreds) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{SecurityProtocol: c.securityProtocol} -} - -func (c *testTransportCreds) Clone() credentials.TransportCredentials { - return &testTransportCreds{securityProtocol: c.securityProtocol} -} - -func (c *testTransportCreds) OverrideServerName(string) error { - return nil -} - -type testAuthInfo struct{} - -func (a *testAuthInfo) AuthType() string { - return "test" -} - -func (a *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { - return credentials.CommonAuthInfo{} -} - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} diff --git a/xds/bootstrap/bootstrap.go b/xds/bootstrap/bootstrap.go index b1a5e831b2a6..ef55ff0c02db 100644 --- a/xds/bootstrap/bootstrap.go +++ b/xds/bootstrap/bootstrap.go @@ -29,7 +29,6 @@ import ( "encoding/json" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal/envconfig" ) // registry is a map from credential type name to Credential builder. @@ -59,9 +58,6 @@ func RegisterCredentials(c Credentials) { // GetCredentials returns the credentials associated with a given name. // If no credentials are registered with the name, nil will be returned. func GetCredentials(name string) Credentials { - if name == "jwt_token_file" && !envconfig.XDSBootstrapCallCredsEnabled { - return nil - } if c, ok := registry[name]; ok { return c } diff --git a/xds/bootstrap/bootstrap_test.go b/xds/bootstrap/bootstrap_test.go index 935976975513..d1f7a1b64ee5 100644 --- a/xds/bootstrap/bootstrap_test.go +++ b/xds/bootstrap/bootstrap_test.go @@ -22,7 +22,6 @@ import ( "testing" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal/envconfig" ) const testCredsBuilderName = "test_creds" @@ -65,14 +64,12 @@ func TestRegisterNew(t *testing.T) { func TestCredsBuilders(t *testing.T) { tests := []struct { - typename string - builder Credentials - minimumRequiredConfig json.RawMessage + typename string + builder Credentials }{ - {"google_default", &googleDefaultCredsBuilder{}, nil}, - {"insecure", &insecureCredsBuilder{}, nil}, - {"tls", &tlsCredsBuilder{}, nil}, - {"jwt_token_file", &jwtCallCredsBuilder{}, json.RawMessage(`{"jwt_token_file":"/path/to/token.jwt"}`)}, + {"google_default", &googleDefaultCredsBuilder{}}, + {"insecure", &insecureCredsBuilder{}}, + {"tls", &tlsCredsBuilder{}}, } for _, test := range tests { @@ -81,13 +78,10 @@ func TestCredsBuilders(t *testing.T) { t.Errorf("%T.Name = %v, want %v", test.builder, got, want) } - bundle, stop, err := test.builder.Build(test.minimumRequiredConfig) + _, stop, err := test.builder.Build(nil) if err != nil { t.Fatalf("%T.Build failed: %v", test.builder, err) } - if bundle == nil { - t.Errorf("%T.Build returned nil bundle, expected non-nil", test.builder) - } stop() }) } @@ -106,27 +100,3 @@ func TestTlsCredsBuilder(t *testing.T) { stop() } } - -func TestJwtCallCredentials_BuildDisabledIfFeatureNotEnabled(t *testing.T) { - builder := GetCredentials("jwt_call_creds") - if builder != nil { - t.Fatal("Expected nil Credentials for jwt_call_creds when the feature is disabled.") - } - - // Enable JWT call credentials - original := envconfig.XDSBootstrapCallCredsEnabled - envconfig.XDSBootstrapCallCredsEnabled = true - defer func() { - envconfig.XDSBootstrapCallCredsEnabled = original - }() - - // Test that GetCredentials returns the JWT builder - builder = GetCredentials("jwt_token_file") - if builder == nil { - t.Fatal("GetCredentials(\"jwt_token_file\") returned nil") - } - - if got, want := builder.Name(), "jwt_token_file"; got != want { - t.Errorf("Retrieved builder name = %q, want %q", got, want) - } -} diff --git a/xds/bootstrap/credentials.go b/xds/bootstrap/credentials.go index 38018972f383..578e1278970d 100644 --- a/xds/bootstrap/credentials.go +++ b/xds/bootstrap/credentials.go @@ -24,7 +24,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/google" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/internal/xds/bootstrap/jwtcreds" "google.golang.org/grpc/internal/xds/bootstrap/tlscreds" ) @@ -32,7 +31,6 @@ func init() { RegisterCredentials(&insecureCredsBuilder{}) RegisterCredentials(&googleDefaultCredsBuilder{}) RegisterCredentials(&tlsCredsBuilder{}) - RegisterCredentials(&jwtCallCredsBuilder{}) } // insecureCredsBuilder implements the `Credentials` interface defined in @@ -70,15 +68,3 @@ func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func (d *googleDefaultCredsBuilder) Name() string { return "google_default" } - -// jwtCallCredsBuilder implements the `Credentials` interface defined in -// package `xds/bootstrap` and encapsulates JWT call credentials. -type jwtCallCredsBuilder struct{} - -func (j *jwtCallCredsBuilder) Build(configJSON json.RawMessage) (credentials.Bundle, func(), error) { - return jwtcreds.NewBundle(configJSON) -} - -func (j *jwtCallCredsBuilder) Name() string { - return "jwt_token_file" -} diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 80bf8d0e8183..967182740719 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -229,9 +229,7 @@ func populateGRPCTransportConfigsFromServerConfig(sc *bootstrap.ServerConfig, gr grpcTransportConfigs[cc.Type] = grpctransport.Config{ Credentials: bundle, GRPCNewClient: func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - // Only add call credentials that are compatible with this transport type - // Call credentials requiring transport security are skipped for insecure transports - opts = append(opts, sc.DialOptionsWithCallCredsForTransport(cc.Type, bundle.TransportCredentials())...) + opts = append(opts, sc.DialOptions()...) return grpc.NewClient(target, opts...) }, } diff --git a/xds/internal/xdsclient/clientimpl_test.go b/xds/internal/xdsclient/clientimpl_test.go index c7884e8ebff6..fbfc24a074ec 100644 --- a/xds/internal/xdsclient/clientimpl_test.go +++ b/xds/internal/xdsclient/clientimpl_test.go @@ -19,10 +19,8 @@ package xdsclient import ( - "context" "encoding/json" "fmt" - "net" "reflect" "sync" "testing" @@ -30,9 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/testutils/stats" "google.golang.org/grpc/internal/xds/bootstrap" "google.golang.org/grpc/xds/internal/clients" @@ -263,90 +259,3 @@ func (s) TestBuildXDSClientConfig_Success(t *testing.T) { }) } } - -func TestServerConfigCallCredsIntegration(t *testing.T) { - // Enable JWT call credentials - originalJWTEnabled := envconfig.XDSBootstrapCallCredsEnabled - envconfig.XDSBootstrapCallCredsEnabled = true - defer func() { - envconfig.XDSBootstrapCallCredsEnabled = originalJWTEnabled - }() - - tokenFile := "/token.jwt" - // Test server config with both channel and call credentials - serverConfigJSON := `{ - "server_uri": "xds-server:443", - "channel_creds": [{"type": "tls", "config": {}}], - "call_creds": [ - { - "type": "jwt_token_file", - "config": {"jwt_token_file": "` + tokenFile + `"} - } - ] - }` - - var sc bootstrap.ServerConfig - if err := sc.UnmarshalJSON([]byte(serverConfigJSON)); err != nil { - t.Fatalf("Failed to unmarshal server config: %v", err) - } - - // Verify call credentials are processed - callCreds := sc.CallCreds() - if len(callCreds) != 1 { - t.Errorf("Expected 1 call credential, got %d", len(callCreds)) - } - - selectedCallCreds := sc.SelectedCallCreds() - if len(selectedCallCreds) != 1 { - t.Errorf("Expected 1 selected call credential, got %d", len(selectedCallCreds)) - } - - // Test dial options for secure transport (should include JWT) - secureOpts := sc.DialOptionsWithCallCredsForTransport("tls", &mockTransportCreds{protocol: "tls"}) - if len(secureOpts) != 1 { - t.Errorf("Expected dial options for secure transport. Got: %#v", secureOpts) - } - - // Test dial options for insecure transport (should exclude JWT) - insecureOpts := sc.DialOptionsWithCallCredsForTransport("insecure", &mockTransportCreds{protocol: "insecure"}) - - // JWT should be filtered out for insecure transport - if len(insecureOpts) >= len(secureOpts) { - t.Error("Expected fewer dial options for insecure transport (JWT should be filtered)") - } -} - -// Mock transport credentials for testing -type mockTransportCreds struct { - protocol string -} - -func (m *mockTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, &mockAuthInfo{}, nil -} - -func (m *mockTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, &mockAuthInfo{}, nil -} - -func (m *mockTransportCreds) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{SecurityProtocol: m.protocol} -} - -func (m *mockTransportCreds) Clone() credentials.TransportCredentials { - return &mockTransportCreds{protocol: m.protocol} -} - -func (m *mockTransportCreds) OverrideServerName(string) error { - return nil -} - -type mockAuthInfo struct{} - -func (m *mockAuthInfo) AuthType() string { - return "mock" -} - -func (m *mockAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { - return credentials.CommonAuthInfo{} -} From a38573b48a74f82668546c52242d5a8789184598 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 18:35:29 +0100 Subject: [PATCH 23/35] remove xds/bootstrap --- internal/xds/bootstrap/jwtcreds/bundle.go | 81 ------- .../xds/bootstrap/jwtcreds/bundle_test.go | 214 ------------------ 2 files changed, 295 deletions(-) delete mode 100644 internal/xds/bootstrap/jwtcreds/bundle.go delete mode 100644 internal/xds/bootstrap/jwtcreds/bundle_test.go diff --git a/internal/xds/bootstrap/jwtcreds/bundle.go b/internal/xds/bootstrap/jwtcreds/bundle.go deleted file mode 100644 index 2b2b2103e908..000000000000 --- a/internal/xds/bootstrap/jwtcreds/bundle.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * - * Copyright 2025 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package jwtcreds implements JWT Call Credentials in xDS Bootstrap File. -// See gRFC A97: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md -package jwtcreds - -import ( - "encoding/json" - "fmt" - - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/jwt" -) - -// bundle is an implementation of credentials.Bundle which implements JWT -// Call Credentials in xDS Bootstrap File per RFC A97. -// This bundle only provides call credentials, not transport credentials. -type bundle struct { - transportCreds credentials.TransportCredentials // Always nil for JWT call creds - callCreds credentials.PerRPCCredentials -} - -// NewBundle returns a credentials.Bundle which implements JWT Call Credentials -// in xDS Bootstrap File per RFC A97. This implementation focuses on call credentials -// only and expects the config to match RFC A97 structure. -// See gRFC A97: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md -func NewBundle(configJSON json.RawMessage) (credentials.Bundle, func(), error) { - var cfg struct { - JWTTokenFile string `json:"jwt_token_file"` - } - - if err := json.Unmarshal(configJSON, &cfg); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal JWT call credentials config: %v", err) - } - - if cfg.JWTTokenFile == "" { - return nil, nil, fmt.Errorf("jwt_token_file is required in JWT call credentials config") - } - - // Create JWT call credentials - callCreds, err := jwt.NewTokenFileCallCredentials(cfg.JWTTokenFile) - if err != nil { - return nil, nil, fmt.Errorf("failed to create JWT call credentials: %v", err) - } - - bundle := &bundle{ - transportCreds: nil, // JWT call creds don't provide transport security - callCreds: callCreds, - } - - return bundle, func() {}, nil -} - -func (b *bundle) TransportCredentials() credentials.TransportCredentials { - // Transport credentials should be configured separately via channel_creds - return nil -} - -func (b *bundle) PerRPCCredentials() credentials.PerRPCCredentials { - return b.callCreds -} - -func (b *bundle) NewWithMode(_ string) (credentials.Bundle, error) { - return nil, fmt.Errorf("JWT call credentials bundle does not support mode switching") -} diff --git a/internal/xds/bootstrap/jwtcreds/bundle_test.go b/internal/xds/bootstrap/jwtcreds/bundle_test.go deleted file mode 100644 index 74f49a710246..000000000000 --- a/internal/xds/bootstrap/jwtcreds/bundle_test.go +++ /dev/null @@ -1,214 +0,0 @@ -/* - * - * Copyright 2025 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package jwtcreds - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "google.golang.org/grpc/credentials" -) - -func TestNewBundle(t *testing.T) { - token := createTestJWT(t) - tokenFile := writeTempFile(t, token) - - tests := []struct { - name string - config string - wantErr bool - wantErrContains string - }{ - { - name: "valid RFC A97 config with jwt_token_file", - config: `{ - "jwt_token_file": "` + tokenFile + `" - }`, - wantErr: false, - }, - { - name: "empty config", - config: `""`, - wantErr: true, - wantErrContains: "unmarshal", - }, - { - name: "empty config", - config: `{}`, - wantErr: true, - wantErrContains: "jwt_token_file is required", - }, - { - name: "empty path", - config: `{ - "jwt_token_file": "" - }`, - wantErr: true, - wantErrContains: "jwt_token_file is required", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bundle, cleanup, err := NewBundle(json.RawMessage(tt.config)) - - if tt.wantErr { - if err == nil { - t.Fatal("Expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("Error %v should contain %q", err, tt.wantErrContains) - } - return - } - - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if bundle == nil { - t.Fatal("Expected non-nil bundle") - } - - if cleanup == nil { - t.Error("Expected non-nil cleanup function") - } else { - defer cleanup() - } - - // JWT bundle only deals with PerRPCCredentials, not TransportCredentials - if bundle.TransportCredentials() != nil { - t.Error("Expected nil transport credentials for JWT call creds bundle") - } - - if bundle.PerRPCCredentials() == nil { - t.Error("Expected non-nil per-RPC credentials for valid JWT config") - } - - // Test that call credentials work - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ - AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - }) - - metadata, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx) - if err != nil { - t.Fatalf("GetRequestMetadata failed: %v", err) - } - - if len(metadata) == 0 { - t.Error("Expected metadata to be returned") - } - - authHeader, ok := metadata["authorization"] - if !ok { - t.Error("Expected authorization header in metadata") - } - - if !strings.HasPrefix(authHeader, "Bearer ") { - t.Errorf("Authorization header should start with 'Bearer ', got %q", authHeader) - } - }) - } -} - -func TestBundle_NewWithMode(t *testing.T) { - token := createTestJWT(t) - tokenFile := writeTempFile(t, token) - config := `{"jwt_token_file": "` + tokenFile + `"}` - bundle, cleanup, err := NewBundle(json.RawMessage(config)) - if err != nil { - t.Fatalf("NewBundle failed: %v", err) - } - defer cleanup() - - _, err = bundle.NewWithMode("test_mode") - if err == nil { - t.Error("Expected error from NewWithMode, got nil") - } - if !strings.Contains(err.Error(), "does not support mode switching") { - t.Errorf("Error should mention mode switching, got: %v", err) - } -} - -func TestBundle_Cleanup(t *testing.T) { - token := createTestJWT(t) - tokenFile := writeTempFile(t, token) - config := `{"jwt_token_file": "` + tokenFile + `"}` - _, cleanup, err := NewBundle(json.RawMessage(config)) - if err != nil { - t.Fatalf("NewBundle failed: %v", err) - } - - if cleanup == nil { - t.Fatal("Expected non-nil cleanup function") - } - - // Cleanup should not panic - cleanup() - - // Multiple cleanup calls should be safe - cleanup() -} - -// testAuthInfo implements credentials.AuthInfo for testing -type testAuthInfo struct { - secLevel credentials.SecurityLevel -} - -func (t *testAuthInfo) AuthType() string { - return "test" -} - -func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { - return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} -} - -// createTestJWT creates a test JWT token for testing -func createTestJWT(t *testing.T) string { - t.Helper() - - // Create a valid JWT with proper base64 encoding for testing - // Header: {"typ":"JWT","alg":"HS256"} - header := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" - - // Claims: {"aud":"https://example.com","exp":future_timestamp} - claims := "eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tIiwiZXhwIjoyMDAwMDAwMDAwfQ" - - // Fake signature for testing - signature := "fake_signature_for_testing" - - return header + "." + claims + "." + signature -} - -func writeTempFile(t *testing.T, content string) string { - t.Helper() - tempDir := t.TempDir() - filePath := filepath.Join(tempDir, "tempfile") - if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { - t.Fatalf("Failed to write temp file: %v", err) - } - return filePath -} From 12fedd5964e195de1166d7ce42a05347c498fc9e Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Thu, 21 Aug 2025 18:38:12 +0100 Subject: [PATCH 24/35] fix comment docstrings --- credentials/jwt/jwt_token_file_call_creds.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 3e691b016cc8..a7c847584334 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -140,8 +140,7 @@ func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < preemptiveRefreshThreshold } -// refreshToken reads the token from file. -// Updates the cache and broadcasts to waiting goroutines when complete. +// refreshToken reads the token from file and updates the cached data. func (c *jwtTokenFileCallCreds) refreshToken() { // Deliberately not locking c.mu here token, expiry, err := c.fileReader.ReadToken() @@ -150,7 +149,6 @@ func (c *jwtTokenFileCallCreds) refreshToken() { defer c.mu.Unlock() c.updateCacheLocked(token, expiry, err) - // Reset pending refresh and broadcast to waiting goroutines c.pendingRefresh = false } From 52445c7f3830601972ddd4e1edcbba1b3276dffb Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 11:54:39 +0100 Subject: [PATCH 25/35] remove newJWTFileReader --- credentials/jwt/jwt_file_reader.go | 7 ------- credentials/jwt/jwt_file_reader_test.go | 6 +++--- credentials/jwt/jwt_token_file_call_creds.go | 2 +- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/credentials/jwt/jwt_file_reader.go b/credentials/jwt/jwt_file_reader.go index 005662eea364..cfde36a53e31 100644 --- a/credentials/jwt/jwt_file_reader.go +++ b/credentials/jwt/jwt_file_reader.go @@ -37,13 +37,6 @@ type jWTFileReader struct { tokenFilePath string } -// newJWTFileReader creates a new JWTFileReader for the specified file path. -func newJWTFileReader(tokenFilePath string) *jWTFileReader { - return &jWTFileReader{ - tokenFilePath: tokenFilePath, - } -} - // ReadToken reads and parses a JWT token from the configured file. // Returns the token string, expiration time, and any error encountered. func (r *jWTFileReader) ReadToken() (string, time.Time, error) { diff --git a/credentials/jwt/jwt_file_reader_test.go b/credentials/jwt/jwt_file_reader_test.go index b57a93fa54ef..fab0815504cc 100644 --- a/credentials/jwt/jwt_file_reader_test.go +++ b/credentials/jwt/jwt_file_reader_test.go @@ -66,7 +66,7 @@ func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { t.Fatalf("Failed to setup test file: %v", err) } - reader := newJWTFileReader(tokenFile) + reader := jWTFileReader{tokenFilePath: tokenFile} _, _, err := reader.ReadToken() if err == nil { t.Fatal("ReadToken() expected error, got nil") @@ -117,7 +117,7 @@ func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tokenFile := writeTempFile(t, "token", tt.tokenContent) - reader := newJWTFileReader(tokenFile) + reader := jWTFileReader{tokenFilePath: tokenFile} _, _, err := reader.ReadToken() if err == nil { t.Fatal("ReadToken() expected error, got nil") @@ -136,7 +136,7 @@ func TestJWTFileReader_ReadToken_ValidToken(t *testing.T) { token := createTestJWT(t, "https://example.com", tokenExp) tokenFile := writeTempFile(t, "token", token) - reader := newJWTFileReader(tokenFile) + reader := jWTFileReader{tokenFilePath: tokenFile} readToken, expiry, err := reader.ReadToken() if err != nil { t.Fatalf("ReadToken() unexpected error: %v", err) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index a7c847584334..271fd03f0649 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -59,7 +59,7 @@ func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCreden } creds := &jwtTokenFileCallCreds{ - fileReader: newJWTFileReader(tokenFilePath), + fileReader: &jWTFileReader{tokenFilePath: tokenFilePath}, backoffStrategy: backoff.DefaultExponential, } From 16780166f225ca4e17017e3a90c74192320bd16b Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 11:58:15 +0100 Subject: [PATCH 26/35] make ReadToken private method --- credentials/jwt/jwt_file_reader.go | 4 ++-- credentials/jwt/jwt_file_reader_test.go | 6 +++--- credentials/jwt/jwt_token_file_call_creds.go | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/credentials/jwt/jwt_file_reader.go b/credentials/jwt/jwt_file_reader.go index cfde36a53e31..12bffc939e42 100644 --- a/credentials/jwt/jwt_file_reader.go +++ b/credentials/jwt/jwt_file_reader.go @@ -37,9 +37,9 @@ type jWTFileReader struct { tokenFilePath string } -// ReadToken reads and parses a JWT token from the configured file. +// readToken reads and parses a JWT token from the configured file. // Returns the token string, expiration time, and any error encountered. -func (r *jWTFileReader) ReadToken() (string, time.Time, error) { +func (r *jWTFileReader) readToken() (string, time.Time, error) { tokenBytes, err := os.ReadFile(r.tokenFilePath) if err != nil { return "", time.Time{}, fmt.Errorf("failed to read token file %q: %v", r.tokenFilePath, err) diff --git a/credentials/jwt/jwt_file_reader_test.go b/credentials/jwt/jwt_file_reader_test.go index fab0815504cc..a9f4bd629a70 100644 --- a/credentials/jwt/jwt_file_reader_test.go +++ b/credentials/jwt/jwt_file_reader_test.go @@ -67,7 +67,7 @@ func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { } reader := jWTFileReader{tokenFilePath: tokenFile} - _, _, err := reader.ReadToken() + _, _, err := reader.readToken() if err == nil { t.Fatal("ReadToken() expected error, got nil") } @@ -118,7 +118,7 @@ func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { tokenFile := writeTempFile(t, "token", tt.tokenContent) reader := jWTFileReader{tokenFilePath: tokenFile} - _, _, err := reader.ReadToken() + _, _, err := reader.readToken() if err == nil { t.Fatal("ReadToken() expected error, got nil") } @@ -137,7 +137,7 @@ func TestJWTFileReader_ReadToken_ValidToken(t *testing.T) { tokenFile := writeTempFile(t, "token", token) reader := jWTFileReader{tokenFilePath: tokenFile} - readToken, expiry, err := reader.ReadToken() + readToken, expiry, err := reader.readToken() if err != nil { t.Fatalf("ReadToken() unexpected error: %v", err) } diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 271fd03f0649..422e21485674 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -104,7 +104,7 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str // At this point, the token is either invalid or expired and we are no // longer backing off. So refresh it. - token, expiry, err := c.fileReader.ReadToken() + token, expiry, err := c.fileReader.readToken() c.updateCacheLocked(token, expiry, err) if c.cachedError != nil { @@ -143,7 +143,7 @@ func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { // refreshToken reads the token from file and updates the cached data. func (c *jwtTokenFileCallCreds) refreshToken() { // Deliberately not locking c.mu here - token, expiry, err := c.fileReader.ReadToken() + token, expiry, err := c.fileReader.readToken() c.mu.Lock() defer c.mu.Unlock() From 1be843b97f044c8ad1757c8f128cc19eaafe28cc Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 14:30:21 +0100 Subject: [PATCH 27/35] use subtests --- credentials/jwt/jwt_file_reader_test.go | 12 +++++++++--- credentials/jwt/jwt_token_file_call_creds_test.go | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/credentials/jwt/jwt_file_reader_test.go b/credentials/jwt/jwt_file_reader_test.go index a9f4bd629a70..d5303b2f5063 100644 --- a/credentials/jwt/jwt_file_reader_test.go +++ b/credentials/jwt/jwt_file_reader_test.go @@ -27,9 +27,15 @@ import ( "strings" "testing" "time" + + "google.golang.org/grpc/internal/grpctest" ) -func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { +func TestJWTFileReader(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { tests := []struct { name string setupFile func(string) error @@ -79,7 +85,7 @@ func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { } } -func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { +func (s) TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { now := time.Now().Truncate(time.Second) tests := []struct { name string @@ -130,7 +136,7 @@ func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { } } -func TestJWTFileReader_ReadToken_ValidToken(t *testing.T) { +func (s) TestJWTFileReader_ReadToken_ValidToken(t *testing.T) { now := time.Now().Truncate(time.Second) tokenExp := now.Add(time.Hour) token := createTestJWT(t, "https://example.com", tokenExp) diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index c09718d702da..5328a15eaaf3 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -41,7 +41,7 @@ type s struct { grpctest.Tester } -func Test(t *testing.T) { +func TestTokenFileCallCreds(t *testing.T) { grpctest.RunSubTests(t, s{}) } From 8ac32968046ec4ddd0035ee4b91f8ee48efc3b2d Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 15:19:52 +0100 Subject: [PATCH 28/35] use writeTempFile --- credentials/jwt/jwt_file_reader_test.go | 42 +++++++++++-------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/credentials/jwt/jwt_file_reader_test.go b/credentials/jwt/jwt_file_reader_test.go index d5303b2f5063..5991b1c7f9f7 100644 --- a/credentials/jwt/jwt_file_reader_test.go +++ b/credentials/jwt/jwt_file_reader_test.go @@ -22,8 +22,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "os" - "path/filepath" "strings" "testing" "time" @@ -38,38 +36,37 @@ func TestJWTFileReader(t *testing.T) { func (s) TestJWTFileReader_ReadToken_FileErrors(t *testing.T) { tests := []struct { name string - setupFile func(string) error + create bool + contents string wantErrContains string }{ { - name: "nonexistent file", - setupFile: func(_ string) error { - return nil // Don't create the file - }, + name: "nonexistent file", + create: false, + contents: "", wantErrContains: "failed to read token file", }, { - name: "empty file", - setupFile: func(path string) error { - return os.WriteFile(path, []byte(""), 0600) - }, + name: "empty file", + create: true, + contents: "", wantErrContains: "token file", }, { - name: "file with whitespace only", - setupFile: func(path string) error { - return os.WriteFile(path, []byte(" \n\t "), 0600) - }, + name: "file with whitespace only", + create: true, + contents: " \n\t ", wantErrContains: "token file", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - if err := tt.setupFile(tokenFile); err != nil { - t.Fatalf("Failed to setup test file: %v", err) + var tokenFile string + if !tt.create { + tokenFile = "/does-not-exixt" + } else { + tokenFile = writeTempFile(t, "token", tt.contents) } reader := jWTFileReader{tokenFilePath: tokenFile} @@ -124,12 +121,9 @@ func (s) TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { tokenFile := writeTempFile(t, "token", tt.tokenContent) reader := jWTFileReader{tokenFilePath: tokenFile} - _, _, err := reader.readToken() - if err == nil { + if _, _, err := reader.readToken(); err == nil { t.Fatal("ReadToken() expected error, got nil") - } - - if !strings.Contains(err.Error(), tt.wantErrContains) { + } else if !strings.Contains(err.Error(), tt.wantErrContains) { t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains) } }) From b0bdc7092e442d9d497f3d8c7f623a6429aa9309 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 15:23:39 +0100 Subject: [PATCH 29/35] add comment about RPC queue behaviour --- credentials/jwt/jwt_token_file_call_creds.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 422e21485674..5c5b8363b3ea 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -103,7 +103,11 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str } // At this point, the token is either invalid or expired and we are no - // longer backing off. So refresh it. + // longer backing off from any encountered errors. So refresh it. + // NB: We are holding the lock while reading the token from file. This will + // cause other RPCs to block until the read completes (sucecssfully or not) + // and the cache is updated. Subsequent RPCs will end up using the cache. + // This is per A97. token, expiry, err := c.fileReader.readToken() c.updateCacheLocked(token, expiry, err) From e4f955c17462b608ce2a907cd840e7aaf2621a37 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 15:41:37 +0100 Subject: [PATCH 30/35] remove needsPreemptiveRefreshLocked method --- credentials/jwt/jwt_token_file_call_creds.go | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds.go b/credentials/jwt/jwt_token_file_call_creds.go index 5c5b8363b3ea..a5fdad5bf77b 100644 --- a/credentials/jwt/jwt_token_file_call_creds.go +++ b/credentials/jwt/jwt_token_file_call_creds.go @@ -84,7 +84,8 @@ func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...str defer c.mu.Unlock() if c.isTokenValidLocked() { - if c.needsPreemptiveRefreshLocked() { + needsPreemptiveRefresh := time.Until(c.cachedExpiry) < preemptiveRefreshThreshold + if needsPreemptiveRefresh { // Start refresh if not pending (handling the prior RPC may have // just spawned a goroutine). if !c.pendingRefresh { @@ -134,16 +135,6 @@ func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { return c.cachedExpiry.After(time.Now()) } -// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be -// triggered. -// Returns true if the cached token is valid but expires within 1 minute. -// We only trigger pre-emptive refresh for valid tokens - if the token is -// invalid or expired, the next RPC will handle synchronous refresh instead. -// Caller must hold c.mu lock. -func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { - return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < preemptiveRefreshThreshold -} - // refreshToken reads the token from file and updates the cached data. func (c *jwtTokenFileCallCreds) refreshToken() { // Deliberately not locking c.mu here From f78178c130adac2f6079608f24d81534e899e15f Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 15:44:25 +0100 Subject: [PATCH 31/35] split NewTokenFileCallCredentials tests --- .../jwt/jwt_token_file_call_creds_test.go | 53 ++++++------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index 5328a15eaaf3..c8e17fbd002e 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -45,43 +45,24 @@ func TestTokenFileCallCreds(t *testing.T) { grpctest.RunSubTests(t, s{}) } -func (s) TestNewTokenFileCallCredentials(t *testing.T) { - tests := []struct { - name string - tokenFilePath string - wantErr string - }{ - { - name: "some filepath", - tokenFilePath: "/path/to/token", - wantErr: "", - }, - { - name: "empty filepath", - tokenFilePath: "", - wantErr: "tokenFilePath cannot be empty", - }, +func (s) TestNewTokenFileCallCredentialsValidFilepath(t *testing.T) { + creds, err := NewTokenFileCallCredentials("/path/to/token") + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err) + } + if creds == nil { + t.Fatal("NewTokenFileCallCredentials() returned nil credentials") } +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - creds, err := NewTokenFileCallCredentials(tt.tokenFilePath) - if tt.wantErr != "" { - if err == nil { - t.Fatalf("NewTokenFileCallCredentials() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err) - } - if creds == nil { - t.Fatal("NewTokenFileCallCredentials() returned nil credentials") - } - }) +func (s) TestNewTokenFileCallCredentialsMissingFilepath(t *testing.T) { + _, err := NewTokenFileCallCredentials("") + if err == nil { + t.Fatalf("NewTokenFileCallCredentials() expected error, got nil") + } + expectedErr := "tokenFilePath cannot be empty" + if !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, expectedErr) } } @@ -289,7 +270,7 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { impl.mu.Lock() cacheExp := impl.cachedExpiry tokenCached := impl.cachedAuthHeader != "" - shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked() + shouldTriggerRefresh := time.Until(cacheExp) < preemptiveRefreshThreshold impl.mu.Unlock() if !tokenCached { From bbeb759953d80919abbcec4d54e145bab68c46da Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 15:49:39 +0100 Subject: [PATCH 32/35] remove leftover os.MkdirTemp --- credentials/jwt/jwt_token_file_call_creds_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index c8e17fbd002e..c67d238a9474 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -78,12 +78,6 @@ func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) { } func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { - tempDir, err := os.MkdirTemp("", "jwt_test") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tempDir) - now := time.Now().Truncate(time.Second) tests := []struct { name string From bba5d344e5bc392aeaa737b49f473d89291586ec Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 16:30:53 +0100 Subject: [PATCH 33/35] remove audience parameter and do not set it at all for test tokens --- credentials/jwt/jwt_file_reader_test.go | 6 ++-- .../jwt/jwt_token_file_call_creds_test.go | 35 ++++++++----------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/credentials/jwt/jwt_file_reader_test.go b/credentials/jwt/jwt_file_reader_test.go index 5991b1c7f9f7..e2a1425dd884 100644 --- a/credentials/jwt/jwt_file_reader_test.go +++ b/credentials/jwt/jwt_file_reader_test.go @@ -91,12 +91,12 @@ func (s) TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { }{ { name: "valid token without expiration", - tokenContent: createTestJWT(t, "", time.Time{}), + tokenContent: createTestJWT(t, time.Time{}), wantErrContains: "JWT token has no expiration claim", }, { name: "expired token", - tokenContent: createTestJWT(t, "", now.Add(-time.Hour)), + tokenContent: createTestJWT(t, now.Add(-time.Hour)), wantErrContains: "JWT token is expired", }, { @@ -133,7 +133,7 @@ func (s) TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) { func (s) TestJWTFileReader_ReadToken_ValidToken(t *testing.T) { now := time.Now().Truncate(time.Second) tokenExp := now.Add(time.Hour) - token := createTestJWT(t, "https://example.com", tokenExp) + token := createTestJWT(t, tokenExp) tokenFile := writeTempFile(t, "token", token) reader := jWTFileReader{tokenFilePath: tokenFile} diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index c67d238a9474..2166e81edb8f 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -89,14 +89,14 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { }{ { name: "valid token with future expiration", - tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)), + tokenContent: createTestJWT(t, now.Add(time.Hour)), authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, wantErr: false, - wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, "https://example.com", now.Add(time.Hour))}, + wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, now.Add(time.Hour))}, }, { name: "insufficient security level", - tokenContent: createTestJWT(t, "", now.Add(time.Hour)), + tokenContent: createTestJWT(t, now.Add(time.Hour)), authInfo: &testAuthInfo{secLevel: credentials.NoSecurity}, wantErr: true, wantErrContains: "unable to transfer JWT token file PerRPCCredentials", @@ -147,7 +147,7 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { } func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { - token := createTestJWT(t, "", time.Now().Add(time.Hour)) + token := createTestJWT(t, time.Now().Add(time.Hour)) tokenFile := writeTempFile(t, "token", token) creds, err := NewTokenFileCallCredentials(tokenFile) @@ -168,7 +168,7 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { } // Update the file with a different token. - newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + newToken := createTestJWT(t, time.Now().Add(2*time.Hour)) if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { t.Fatalf("Failed to update token file: %v", err) } @@ -202,7 +202,7 @@ func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { // Create token that expires in 2 hours. tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) - token := createTestJWT(t, "", tokenExp) + token := createTestJWT(t, tokenExp) tokenFile := writeTempFile(t, "token", token) creds, err := NewTokenFileCallCredentials(tokenFile) @@ -217,8 +217,7 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin }) // Get token to trigger caching. - _, err = creds.GetRequestMetadata(ctx) - if err != nil { + if _, err = creds.GetRequestMetadata(ctx); err != nil { t.Fatalf("GetRequestMetadata() failed: %v", err) } @@ -239,7 +238,7 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { // Create token that expires in 80 seconds (=> cache expires in ~50s). // This ensures pre-emptive refresh triggers since 50s < the 1 minute check. tokenExp := time.Now().Add(80 * time.Second) - expiringToken := createTestJWT(t, "", tokenExp) + expiringToken := createTestJWT(t, tokenExp) tokenFile := writeTempFile(t, "token", expiringToken) creds, err := NewTokenFileCallCredentials(tokenFile) @@ -278,7 +277,7 @@ func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { // Create new token file with different expiration while refresh is // happening. - newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + newToken := createTestJWT(t, time.Now().Add(2*time.Hour)) if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { t.Fatalf("Failed to write updated token file: %v", err) } @@ -427,7 +426,7 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { } // Create valid token file. - validToken := createTestJWT(t, "", time.Now().Add(time.Hour)) + validToken := createTestJWT(t, time.Now().Add(time.Hour)) if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil { t.Fatalf("Failed to create valid token file: %v", err) } @@ -488,22 +487,18 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // createTestJWT creates a test JWT token with the specified audience and // expiration. -func createTestJWT(t *testing.T, audience string, expiration time.Time) string { +func createTestJWT(t *testing.T, expiration time.Time) string { t.Helper() - header := map[string]any{ - "typ": "JWT", - "alg": "HS256", - } - claims := map[string]any{} - if audience != "" { - claims["aud"] = audience - } if !expiration.IsZero() { claims["exp"] = expiration.Unix() } + header := map[string]any{ + "typ": "JWT", + "alg": "HS256", + } headerBytes, err := json.Marshal(header) if err != nil { t.Fatalf("Failed to marshal header: %v", err) From 607868b4433aa8cf8921b5a363a4d4f86c45b24e Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 18:42:04 +0100 Subject: [PATCH 34/35] test for grpc codes in TestTokenFileCallCreds_GetRequestMetadata --- .../jwt/jwt_token_file_call_creds_test.go | 59 +++++++++---------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index 2166e81edb8f..8204e168465e 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpctest" @@ -80,26 +81,36 @@ func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) { func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { now := time.Now().Truncate(time.Second) tests := []struct { - name string - tokenContent string - authInfo credentials.AuthInfo - wantErr bool - wantErrContains string - wantMetadata map[string]string + name string + tokenContent string + authInfo credentials.AuthInfo + grpcCode codes.Code + wantMetadata map[string]string }{ { name: "valid token with future expiration", tokenContent: createTestJWT(t, now.Add(time.Hour)), authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, - wantErr: false, + grpcCode: codes.OK, wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, now.Add(time.Hour))}, }, { - name: "insufficient security level", - tokenContent: createTestJWT(t, now.Add(time.Hour)), - authInfo: &testAuthInfo{secLevel: credentials.NoSecurity}, - wantErr: true, - wantErrContains: "unable to transfer JWT token file PerRPCCredentials", + name: "insufficient security level", + tokenContent: createTestJWT(t, now.Add(time.Hour)), + authInfo: &testAuthInfo{secLevel: credentials.NoSecurity}, + grpcCode: codes.Unknown, + }, + { + name: "unreachable token file", + tokenContent: "", + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + grpcCode: codes.Unavailable, + }, + { + name: "malformed JWT token", + tokenContent: "bad contents", + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + grpcCode: codes.Unauthenticated, }, } @@ -119,28 +130,12 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { }) metadata, err := creds.GetRequestMetadata(ctx) - if tt.wantErr { - if err == nil { - t.Fatalf("GetRequestMetadata() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErrContains) { - t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains) - } - return - } - - if err != nil { - t.Fatalf("GetRequestMetadata() unexpected error: %v", err) - } - - if len(metadata) != len(tt.wantMetadata) { - t.Fatalf("GetRequestMetadata() returned %d metadata entries, want %d", len(metadata), len(tt.wantMetadata)) + if status.Code(err) != tt.grpcCode { + t.Fatalf("GetRequestMetadata() = %v, want %v", status.Code(err), tt.grpcCode) } - for k, v := range tt.wantMetadata { - if metadata[k] != v { - t.Errorf("GetRequestMetadata() metadata[%q] = %q, want %q", k, metadata[k], v) - } + if diff := cmp.Diff(tt.wantMetadata, metadata); diff != "" { + t.Errorf("GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff) } }) } From bc2d3279ec5d2721545a26cc79300fff7d899de0 Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Tue, 26 Aug 2025 18:47:08 +0100 Subject: [PATCH 35/35] use cmp.Diff in TestTokenFileCallCreds_TokenCaching --- credentials/jwt/jwt_token_file_call_creds_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/credentials/jwt/jwt_token_file_call_creds_test.go b/credentials/jwt/jwt_token_file_call_creds_test.go index 8204e168465e..1501db13ba53 100644 --- a/credentials/jwt/jwt_token_file_call_creds_test.go +++ b/credentials/jwt/jwt_token_file_call_creds_test.go @@ -161,6 +161,10 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { if err != nil { t.Fatalf("First GetRequestMetadata() failed: %v", err) } + wantMetadata := map[string]string{"authorization": "Bearer " + token} + if diff := cmp.Diff(wantMetadata, metadata1); diff != "" { + t.Errorf("First GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff) + } // Update the file with a different token. newToken := createTestJWT(t, time.Now().Add(2*time.Hour)) @@ -174,8 +178,8 @@ func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { t.Fatalf("Second GetRequestMetadata() failed: %v", err) } - if metadata1["authorization"] != metadata2["authorization"] { - t.Error("Expected cached token to be returned, but got different token") + if diff := cmp.Diff(metadata1, metadata2); diff != "" { + t.Errorf("Second GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff) } }