diff --git a/Snowflake.Client.Tests/Models/TestConfiguration.cs b/Snowflake.Client.Tests/Models/TestConfiguration.cs
index 1fda5df..f8c052e 100644
--- a/Snowflake.Client.Tests/Models/TestConfiguration.cs
+++ b/Snowflake.Client.Tests/Models/TestConfiguration.cs
@@ -3,5 +3,10 @@
public class TestConfiguration
{
public SnowflakeConnectionInfo Connection { get; set; }
+ public string AdClientId { get; set; }
+ public string AdClientSecret { get; set; }
+ public string AdServicePrincipalObjectId { get; set; }
+ public string AdTenantId { get; set; }
+ public string AdScope { get; set; }
}
}
diff --git a/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj b/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj
index e1ce1f6..c7f4ed3 100644
--- a/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj
+++ b/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj
@@ -7,6 +7,8 @@
+
+
diff --git a/Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs b/Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs
new file mode 100644
index 0000000..2100e91
--- /dev/null
+++ b/Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs
@@ -0,0 +1,34 @@
+using System;
+using NUnit.Framework;
+using Snowflake.Client.Tests.Models;
+using Snowflake.Client.Model;
+using System.IO;
+using System.Text.Json;
+
+namespace Snowflake.Client.Tests.IntegrationTests
+{
+ [TestFixture]
+ public class AzureAdAuthInfoTests
+ {
+ protected readonly AzureAdAuthInfo _azureAdAuthInfo;
+
+ public AzureAdAuthInfoTests()
+ {
+ var configJson = File.ReadAllText("testconfig.json");
+ var testParameters = JsonSerializer.Deserialize(configJson, new JsonSerializerOptions() { PropertyNameCaseInsensitive = true });
+ var connectionInfo = testParameters.Connection;
+
+ _azureAdAuthInfo = new AzureAdAuthInfo(
+ testParameters.AdClientId,
+ testParameters.AdClientSecret,
+ testParameters.AdServicePrincipalObjectId,
+ testParameters.AdTenantId,
+ testParameters.AdScope,
+ connectionInfo.Region,
+ connectionInfo.Account,
+ connectionInfo.User,
+ connectionInfo.Host,
+ connectionInfo.Role);
+ }
+ }
+}
diff --git a/Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs b/Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs
new file mode 100644
index 0000000..55d3050
--- /dev/null
+++ b/Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs
@@ -0,0 +1,32 @@
+using Microsoft.Identity.Client;
+using Moq;
+using NUnit.Framework;
+using Snowflake.Client;
+using Snowflake.Client.Model;
+using Snowflake.Client.Tests.IntegrationTests;
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Snowflake.Client.Tests
+{
+ public class AzureAdTokenProviderTests : AzureAdAuthInfoTests
+ {
+ [Test]
+ public async Task GetAzureAdAccessTokenAsync_ReturnsAccessToken()
+ {
+ var expectedAccessToken = "accessToken";
+ var mockTokenProvider = new Mock();
+
+ mockTokenProvider
+ .Setup(provider => provider.GetAzureAdAccessTokenAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(expectedAccessToken);
+
+ // Act
+ string actualAccessToken = await mockTokenProvider.Object.GetAzureAdAccessTokenAsync(_azureAdAuthInfo);
+
+ // Assert
+ Assert.AreEqual(expectedAccessToken, actualAccessToken);
+ }
+ }
+}
diff --git a/Snowflake.Client/AzureAdTokenProvider.cs b/Snowflake.Client/AzureAdTokenProvider.cs
new file mode 100644
index 0000000..20b7196
--- /dev/null
+++ b/Snowflake.Client/AzureAdTokenProvider.cs
@@ -0,0 +1,41 @@
+using Microsoft.Identity.Client;
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Snowflake.Client.Model;
+
+namespace Snowflake.Client
+{
+ public class AzureAdTokenProvider : IAzureAdTokenProvider
+ {
+ public async Task GetAzureAdAccessTokenAsync(AzureAdAuthInfo authInfo, CancellationToken ct = default)
+ {
+ try
+ {
+ if (authInfo.ClientId == null || authInfo.ClientSecret == null || authInfo.ServicePrincipalObjectId == null || authInfo.TenantId == null || authInfo.Scope == null)
+ {
+ throw new SnowflakeException("Error: One or more required environment variables are missing.", 400);
+ }
+
+ return await GetAccessTokenAsync(authInfo.ClientId, authInfo.ClientSecret, authInfo.ServicePrincipalObjectId, authInfo.TenantId, authInfo.Scope);
+ }
+ catch (Exception ex)
+ {
+ throw new SnowflakeException($"Failed getting the Azure Token. Message: {ex.Message}", ex);
+ }
+ }
+
+ private async Task GetAccessTokenAsync(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope)
+ {
+ IConfidentialClientApplication app = ConfidentialClientApplicationBuilder.Create(clientId)
+ .WithClientSecret(clientSecret)
+ .WithAuthority(new Uri($"https://login.microsoftonline.com/{tenantId}/"))
+ .Build();
+
+ var scopes = new[] { scope };
+
+ AuthenticationResult result = await app.AcquireTokenForClient(scopes).ExecuteAsync();
+ return result.AccessToken;
+ }
+ }
+}
\ No newline at end of file
diff --git a/Snowflake.Client/IAzureAdTokenProvider.cs b/Snowflake.Client/IAzureAdTokenProvider.cs
new file mode 100644
index 0000000..d5a4386
--- /dev/null
+++ b/Snowflake.Client/IAzureAdTokenProvider.cs
@@ -0,0 +1,11 @@
+using System.Threading;
+using System.Threading.Tasks;
+using Snowflake.Client.Model;
+
+namespace Snowflake.Client
+{
+ public interface IAzureAdTokenProvider
+ {
+ Task GetAzureAdAccessTokenAsync(AzureAdAuthInfo authInfo, CancellationToken ct = default);
+ }
+}
\ No newline at end of file
diff --git a/Snowflake.Client/Model/AuthInfo.cs b/Snowflake.Client/Model/AuthInfo.cs
index d0e0ba2..fec3d73 100644
--- a/Snowflake.Client/Model/AuthInfo.cs
+++ b/Snowflake.Client/Model/AuthInfo.cs
@@ -3,7 +3,7 @@
///
/// Snowflake Authentication information.
///
- public class AuthInfo
+ public class AuthInfo : IAuthInfo
{
///
/// Your Snowflake account name
diff --git a/Snowflake.Client/Model/AzureAdAuthInfo.cs b/Snowflake.Client/Model/AzureAdAuthInfo.cs
new file mode 100644
index 0000000..45eb805
--- /dev/null
+++ b/Snowflake.Client/Model/AzureAdAuthInfo.cs
@@ -0,0 +1,29 @@
+namespace Snowflake.Client.Model
+{
+ public class AzureAdAuthInfo : AuthInfo
+ {
+ public string ClientId { get; set; }
+ public string ClientSecret { get; set; }
+ public string ServicePrincipalObjectId { get; set; }
+ public string TenantId { get; set; }
+ public string Scope { get; set; }
+ public string Host {get; set; }
+ public string Role {get; set; }
+
+
+ public AzureAdAuthInfo(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope, string region, string account, string user, string host, string role)
+ : base(user, account, region)
+ {
+ ClientId = clientId;
+ ClientSecret = clientSecret;
+ ServicePrincipalObjectId = servicePrincipalObjectId;
+ TenantId = tenantId;
+ Scope = scope;
+ Region = region;
+ Account = account;
+ User = user;
+ Host = host;
+ Role = role;
+ }
+ }
+}
\ No newline at end of file
diff --git a/Snowflake.Client/Model/IAuthInfo.cs b/Snowflake.Client/Model/IAuthInfo.cs
new file mode 100644
index 0000000..8976247
--- /dev/null
+++ b/Snowflake.Client/Model/IAuthInfo.cs
@@ -0,0 +1,11 @@
+namespace Snowflake.Client.Model
+{
+ public interface IAuthInfo
+ {
+ string Account { get; set; }
+ string User { get; set; }
+ string Region { get; set; }
+
+ string ToString();
+ }
+}
\ No newline at end of file
diff --git a/Snowflake.Client/RequestBuilder.cs b/Snowflake.Client/RequestBuilder.cs
index 04fff78..96bccbb 100644
--- a/Snowflake.Client/RequestBuilder.cs
+++ b/Snowflake.Client/RequestBuilder.cs
@@ -51,19 +51,27 @@ internal void ClearSessionTokens()
_masterToken = null;
}
- internal HttpRequestMessage BuildLoginRequest(AuthInfo authInfo, SessionInfo sessionInfo)
+ internal HttpRequestMessage BuildLoginRequest(AuthInfo authInfo, SessionInfo sessionInfo, String azureAdAccessToken = null)
{
var requestUri = BuildLoginUrl(sessionInfo);
+ var data = new LoginRequestData();
+
+ if (authInfo is AzureAdAuthInfo azureAdAuthInfo) {
+ data = new LoginRequestData() {
+ Authenticator = "OAUTH",
+ Token = azureAdAccessToken,
+ };
+ } else {
+ data = new LoginRequestData() {
+ Password = authInfo.Password,
+ };
+ }
- var data = new LoginRequestData()
- {
- LoginName = authInfo.User,
- Password = authInfo.Password,
- AccountName = authInfo.Account,
- ClientAppId = _clientInfo.DriverName,
- ClientAppVersion = _clientInfo.DriverVersion,
- ClientEnvironment = _clientInfo.Environment
- };
+ data.LoginName = authInfo.User;
+ data.AccountName = authInfo.Account;
+ data.ClientAppId = _clientInfo.DriverName;
+ data.ClientAppVersion = _clientInfo.DriverVersion;
+ data.ClientEnvironment = _clientInfo.Environment;
var requestBody = new LoginRequest() { Data = data };
var jsonBody = JsonSerializer.Serialize(requestBody, _jsonSerializerOptions);
diff --git a/Snowflake.Client/Snowflake.Client.csproj b/Snowflake.Client/Snowflake.Client.csproj
index 1b235d2..4ceab60 100644
--- a/Snowflake.Client/Snowflake.Client.csproj
+++ b/Snowflake.Client/Snowflake.Client.csproj
@@ -32,4 +32,8 @@ Provides straightforward and efficient way to execute SQL queries in Snowflake a
+
+
+
+
diff --git a/Snowflake.Client/SnowflakeClient.cs b/Snowflake.Client/SnowflakeClient.cs
index b1b6272..616dc00 100644
--- a/Snowflake.Client/SnowflakeClient.cs
+++ b/Snowflake.Client/SnowflakeClient.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Identity.Client;
namespace Snowflake.Client
{
@@ -23,11 +24,41 @@ public class SnowflakeClient : ISnowflakeClient
///
public SnowflakeClientSettings Settings => _clientSettings;
+ ///
+ /// Azure AD Token Provider
+ ///
+ private readonly AzureAdTokenProvider _azureAdTokenProvider;
+
private SnowflakeSession _snowflakeSession;
private readonly RestClient _restClient;
private readonly RequestBuilder _requestBuilder;
private readonly SnowflakeClientSettings _clientSettings;
+ ///
+ /// Creates new Snowflake client.
+ ///
+ /// Client ID
+ /// Client Secret
+ /// Service Principal Object ID
+ /// Tenant ID
+ /// Scope
+ /// Region: "us-east-1", etc. Required for all except for US West Oregon (us-west-2).
+ /// Account
+ /// Username
+ /// Host
+ /// Role
+ public SnowflakeClient(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope, string region, string account, string user, string host, string role)
+ : this(new AzureAdAuthInfo(clientId, clientSecret, servicePrincipalObjectId, tenantId, scope, region, account, user, host, role), urlInfo: new UrlInfo
+ {
+ Host = host,
+ },
+ sessionInfo: new SessionInfo
+ {
+ Role = role,
+ })
+ {
+ }
+
///
/// Creates new Snowflake client.
///
@@ -52,6 +83,11 @@ public SnowflakeClient(AuthInfo authInfo, SessionInfo sessionInfo = null, UrlInf
{
}
+ public SnowflakeClient(AzureAdAuthInfo authInfo, SessionInfo sessionInfo = null, UrlInfo urlInfo = null, JsonSerializerOptions jsonMapperOptions = null)
+ : this(new SnowflakeClientSettings(authInfo, sessionInfo, urlInfo, jsonMapperOptions))
+ {
+ }
+
///
/// Creates new Snowflake client.
///
@@ -63,6 +99,7 @@ public SnowflakeClient(SnowflakeClientSettings settings)
_clientSettings = settings;
_restClient = new RestClient();
_requestBuilder = new RequestBuilder(settings.UrlInfo);
+ _azureAdTokenProvider = new AzureAdTokenProvider();
SnowflakeDataMapper.Configure(settings.JsonMapperOptions);
ChunksDownloader.Configure(settings.ChunksDownloaderOptions);
@@ -104,10 +141,24 @@ public async Task InitNewSessionAsync(CancellationToken ct = default)
return true;
}
+ ///
+ /// Authenticates user and returns new Snowflake session.
+ ///
+ /// New Snowflake session
private async Task AuthenticateAsync(AuthInfo authInfo, SessionInfo sessionInfo, CancellationToken ct)
{
var loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo);
+ if(authInfo is AzureAdAuthInfo azureAdAuthInfo)
+ {
+ var azureAdAccessToken = await _azureAdTokenProvider.GetAzureAdAccessTokenAsync(azureAdAuthInfo, ct).ConfigureAwait(false);
+ loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo, azureAdAccessToken);
+ }
+ else
+ {
+ loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo);
+ }
+
var response = await _restClient.SendAsync(loginRequest, ct).ConfigureAwait(false);
if (!response.Success)