Skip to content

Commit 991022a

Browse files
authored
improve the DialCloud function (#159)
DialCloud function now takes almost any endpoint to the cloud, be it grpc endpoint or the graphql endpoint.
1 parent adce802 commit 991022a

File tree

2 files changed

+75
-8
lines changed

2 files changed

+75
-8
lines changed

client.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package dgo
1919
import (
2020
"context"
2121
"crypto/x509"
22+
"errors"
2223
"fmt"
2324
"math/rand"
2425
"net/url"
@@ -82,22 +83,37 @@ func NewDgraphClient(clients ...api.DgraphClient) *Dgraph {
8283
// defer conn.Close()
8384
// dgraphClient := dgo.NewDgraphClient(api.NewDgraphClient(conn))
8485
func DialCloud(endpoint, key string) (*grpc.ClientConn, error) {
85-
u, err := url.Parse(endpoint)
86-
if err != nil {
87-
return nil, err
86+
var grpcHost string
87+
switch {
88+
case strings.Contains(endpoint, ".grpc.") && strings.Contains(endpoint, ":"+cloudPort):
89+
// if we already have the grpc URL with the port, we don't need to do anything
90+
grpcHost = endpoint
91+
case strings.Contains(endpoint, ".grpc.") && !strings.Contains(endpoint, ":"+cloudPort):
92+
// if we have the grpc URL without the port, just add the port
93+
grpcHost = endpoint + ":" + cloudPort
94+
default:
95+
// otherwise, parse the non-grpc URL and add ".grpc." along with port to it.
96+
if !strings.HasPrefix(endpoint, "http") {
97+
endpoint = "https://" + endpoint
98+
}
99+
u, err := url.Parse(endpoint)
100+
if err != nil {
101+
return nil, err
102+
}
103+
urlParts := strings.SplitN(u.Host, ".", 2)
104+
if len(urlParts) < 2 {
105+
return nil, errors.New("invalid URL to Dgraph Cloud")
106+
}
107+
grpcHost = urlParts[0] + ".grpc." + urlParts[1] + ":" + cloudPort
88108
}
89109

90-
urlParts := strings.SplitN(u.Host, ".", 2)
91-
92-
host := urlParts[0] + ".grpc." + urlParts[1] + ":" + cloudPort
93110
pool, err := x509.SystemCertPool()
94111
if err != nil {
95112
return nil, err
96113
}
97-
98114
creds := credentials.NewClientTLSFromCert(pool, "")
99115
return grpc.Dial(
100-
host,
116+
grpcHost,
101117
grpc.WithTransportCredentials(creds),
102118
grpc.WithPerRPCCredentials(&authCreds{key}),
103119
)

cloud_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (C) 2023 Dgraph Labs, Inc. and Contributors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package dgo_test
18+
19+
import (
20+
"testing"
21+
22+
"github.com/stretchr/testify/require"
23+
24+
"github.com/dgraph-io/dgo/v210"
25+
)
26+
27+
func TestDialCLoud(t *testing.T) {
28+
cases := []struct {
29+
endpoint string
30+
err string
31+
}{
32+
{endpoint: "godly.grpc.region.aws.cloud.dgraph.io"},
33+
{endpoint: "godly.grpc.region.aws.cloud.dgraph.io:443"},
34+
{endpoint: "https://godly.region.aws.cloud.dgraph.io/graphql"},
35+
{endpoint: "godly.region.aws.cloud.dgraph.io"},
36+
{endpoint: "https://godly.region.aws.cloud.dgraph.io"},
37+
{endpoint: "random:url", err: "invalid port"},
38+
{endpoint: "google", err: "invalid URL"},
39+
}
40+
41+
for _, tc := range cases {
42+
t.Run(tc.endpoint, func(t *testing.T) {
43+
_, err := dgo.DialCloud(tc.endpoint, "abc123")
44+
if tc.err == "" {
45+
require.NoError(t, err)
46+
} else {
47+
require.Contains(t, err.Error(), tc.err)
48+
}
49+
})
50+
}
51+
}

0 commit comments

Comments
 (0)