Skip to content

feat: Cache awss3/awssqs Client #735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions aws/awscognito/awscognito.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package awscognito
import (
"context"
"fmt"
"sync/atomic"

"github.com/aws/aws-sdk-go-v2/aws"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
Expand All @@ -12,22 +13,28 @@ import (
"github.com/88labs/go-utils/aws/ctxawslocal"
)

var cognitoidentityClientAtomic atomic.Pointer[cognitoidentity.Client]

// GetCredentialsForIdentity
// aws-sdk-go v2 GetCredentialsForIdentity
//
// Mocks: Using ctxawslocal.WithContext, you can make requests for local mocks.
func GetCredentialsForIdentity(ctx context.Context, region awsconfig.Region, identityId string, logins map[string]string) (*cognitoidentity.GetCredentialsForIdentityOutput, error) {
localProfile, _ := getLocalEndpoint(ctx)
// Cognito Client
awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region.String()))
if err != nil {
return nil, fmt.Errorf("unable to load SDK config, %w", err)
}
client := cognitoidentity.NewFromConfig(awsCfg)
if err != nil {
return nil, err
func GetCredentialsForIdentity(
ctx context.Context, region awsconfig.Region, identityId string, logins map[string]string,
) (*cognitoidentity.GetCredentialsForIdentityOutput, error) {
var client *cognitoidentity.Client
if v := cognitoidentityClientAtomic.Load(); v != nil {
client = v
} else {
// Cognito Client
awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region.String()))
if err != nil {
return nil, fmt.Errorf("unable to load SDK config, %w", err)
}
client = cognitoidentity.NewFromConfig(awsCfg)
cognitoidentityClientAtomic.Store(client)
}

localProfile, _ := getLocalEndpoint(ctx)
res, err := client.GetCredentialsForIdentity(
ctx,
&cognitoidentity.GetCredentialsForIdentityInput{
Expand Down
121 changes: 60 additions & 61 deletions aws/awsdynamo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package awsdynamo

import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -17,74 +16,74 @@ import (
"github.com/88labs/go-utils/aws/ctxawslocal"
)

var dynamoDBClient *dynamodb.Client
var once sync.Once
var dynamoDBClientAtomic atomic.Pointer[dynamodb.Client]

func GetClient(ctx context.Context, region awsconfig.Region, limitAttempts int, limitBackOffDelay time.Duration) (*dynamodb.Client, error) {
if localProfile, ok := getLocalEndpoint(ctx); ok {
return getClientLocal(ctx, *localProfile)
func GetClient(
ctx context.Context, region awsconfig.Region, limitAttempts int, limitBackOffDelay time.Duration,
) (*dynamodb.Client, error) {
if v := dynamoDBClientAtomic.Load(); v != nil {
return v, nil
}
var responseError error
once.Do(func() {
// S3 Client
awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region.String()),
awsConfig.WithRetryer(func() aws.Retryer {
r := retry.AddWithMaxAttempts(retry.NewStandard(), limitAttempts)
r = retry.AddWithMaxBackoffDelay(r, limitBackOffDelay)
r = retry.AddWithErrorCodes(r,
string(types.BatchStatementErrorCodeEnumItemCollectionSizeLimitExceeded),
string(types.BatchStatementErrorCodeEnumRequestLimitExceeded),
string(types.BatchStatementErrorCodeEnumProvisionedThroughputExceeded),
string(types.BatchStatementErrorCodeEnumInternalServerError),
string(types.BatchStatementErrorCodeEnumThrottlingError),
)
return r
}),
)
if localProfile, ok := getLocalEndpoint(ctx); ok {
c, err := getClientLocal(ctx, *localProfile)
if err != nil {
responseError = fmt.Errorf("unable to load SDK config, %w", err)
} else {
responseError = nil
return nil, err
}
dynamoDBClient = dynamodb.NewFromConfig(awsCfg)
})
return dynamoDBClient, responseError
dynamoDBClientAtomic.Store(c)
return c, nil
}
// S3 Client
awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region.String()),
awsConfig.WithRetryer(func() aws.Retryer {
r := retry.AddWithMaxAttempts(retry.NewStandard(), limitAttempts)
r = retry.AddWithMaxBackoffDelay(r, limitBackOffDelay)
r = retry.AddWithErrorCodes(r,
string(types.BatchStatementErrorCodeEnumItemCollectionSizeLimitExceeded),
string(types.BatchStatementErrorCodeEnumRequestLimitExceeded),
string(types.BatchStatementErrorCodeEnumProvisionedThroughputExceeded),
string(types.BatchStatementErrorCodeEnumInternalServerError),
string(types.BatchStatementErrorCodeEnumThrottlingError),
)
return r
}),
)
if err != nil {
return nil, err
}
c := dynamodb.NewFromConfig(awsCfg)
dynamoDBClientAtomic.Store(c)
return c, nil
}

func getClientLocal(ctx context.Context, localProfile LocalProfile) (*dynamodb.Client, error) {
var responseError error
once.Do(func() {
// https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/endpoints/
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
if service == dynamodb.ServiceID {
return aws.Endpoint{
PartitionID: "aws",
URL: localProfile.Endpoint,
SigningRegion: region,
HostnameImmutable: true,
}, nil
}
// returning EndpointNotFoundError will allow the service to fallback to it's default resolution
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})
awsCfg, err := awsConfig.LoadDefaultConfig(ctx,
awsConfig.WithEndpointResolverWithOptions(customResolver),
awsConfig.WithCredentialsProvider(credentials.StaticCredentialsProvider{
Value: aws.Credentials{
AccessKeyID: localProfile.AccessKey,
SecretAccessKey: localProfile.SecretAccessKey,
},
}),
)
if err != nil {
responseError = fmt.Errorf("unable to load SDK config, %w", err)
return
} else {
responseError = nil
// https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/endpoints/
customResolver := aws.EndpointResolverWithOptionsFunc(func(
service, region string, options ...interface{},
) (aws.Endpoint, error) {
if service == dynamodb.ServiceID {
return aws.Endpoint{
PartitionID: "aws",
URL: localProfile.Endpoint,
SigningRegion: region,
HostnameImmutable: true,
}, nil
}
dynamoDBClient = dynamodb.NewFromConfig(awsCfg)
// returning EndpointNotFoundError will allow the service to fallback to it's default resolution
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})
return dynamoDBClient, responseError
awsCfg, err := awsConfig.LoadDefaultConfig(ctx,
awsConfig.WithEndpointResolverWithOptions(customResolver),
awsConfig.WithCredentialsProvider(credentials.StaticCredentialsProvider{
Value: aws.Credentials{
AccessKeyID: localProfile.AccessKey,
SecretAccessKey: localProfile.SecretAccessKey,
},
}),
)
if err != nil {
return nil, err
}
return dynamodb.NewFromConfig(awsCfg), nil
}

type LocalProfile struct {
Expand Down
25 changes: 13 additions & 12 deletions aws/awss3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"encoding/gob"
"fmt"
"net"
"sync"
"sync/atomic"

"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
Expand All @@ -21,24 +21,23 @@ import (

var (
// GlobalDialer Global http dialer settings for awss3 library
GlobalDialer *s3dialer.ConfGlobalDialer

customMu sync.Mutex
customEndpointClient *s3.Client
GlobalDialer *s3dialer.ConfGlobalDialer
s3ClientAtomic atomic.Pointer[s3.Client]
)

// GetClient
// Get s3 client for aws-sdk-go v2.
// Using ctxawslocal.WithContext, you can make requests for local mocks
func GetClient(ctx context.Context, region awsconfig.Region) (*s3.Client, error) {
if v := s3ClientAtomic.Load(); v != nil {
return v, nil
}
if localProfile, ok := getLocalEndpoint(ctx); ok {
customMu.Lock()
defer customMu.Unlock()
var err error
if customEndpointClient != nil {
return customEndpointClient, err
customEndpointClient, err := getClientLocal(ctx, *localProfile)
if err != nil {
return nil, err
}
customEndpointClient, err = getClientLocal(ctx, *localProfile)
s3ClientAtomic.Store(customEndpointClient)
return customEndpointClient, err
}
awsHttpClient := awshttp.NewBuildableClient()
Expand All @@ -64,7 +63,9 @@ func GetClient(ctx context.Context, region awsconfig.Region) (*s3.Client, error)
if err != nil {
return nil, fmt.Errorf("unable to load SDK config, %w", err)
}
return s3.NewFromConfig(awsCfg), nil
c := s3.NewFromConfig(awsCfg)
s3ClientAtomic.Store(c)
return c, nil
}

func getClientLocal(ctx context.Context, localProfile LocalProfile) (*s3.Client, error) {
Expand Down
21 changes: 18 additions & 3 deletions aws/awssqs/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package awssqs
import (
"context"
"fmt"
"sync/atomic"

"github.com/aws/aws-sdk-go-v2/aws"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
Expand All @@ -13,24 +14,38 @@ import (
"github.com/88labs/go-utils/aws/ctxawslocal"
)

var sqsClientAtomic atomic.Pointer[sqs.Client]

// GetClient
// Get s3 client for aws-sdk-go v2.
// Using ctxawslocal.WithContext, you can make requests for local mocks
func GetClient(ctx context.Context, region awsconfig.Region) (*sqs.Client, error) {
if v := sqsClientAtomic.Load(); v != nil {
return v, nil
}
if localProfile, ok := getLocalEndpoint(ctx); ok {
return getClientLocal(ctx, *localProfile)
c, err := getClientLocal(ctx, *localProfile)
if err != nil {
return nil, err
}
sqsClientAtomic.Store(c)
return c, nil
}
// SQS Client
awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region.String()))
if err != nil {
return nil, fmt.Errorf("unable to load SDK config, %w", err)
}
return sqs.NewFromConfig(awsCfg), nil
c := sqs.NewFromConfig(awsCfg)
sqsClientAtomic.Store(c)
return c, nil
}

func getClientLocal(ctx context.Context, localProfile LocalProfile) (*sqs.Client, error) {
// https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/endpoints/
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
customResolver := aws.EndpointResolverWithOptionsFunc(func(
service, region string, options ...interface{},
) (aws.Endpoint, error) {
if service == sqs.ServiceID {
return aws.Endpoint{
PartitionID: "aws",
Expand Down
Loading