Skip to content

Commit f0858f9

Browse files
glebteterinJames Naylor
authored andcommitted
Support MSSQL batch statements (Resolves #652)
1 parent 2788339 commit f0858f9

File tree

3 files changed

+80
-15
lines changed

3 files changed

+80
-15
lines changed

database/sqlserver/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
| `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. |
1818
| `app+name` || The application name (default is go-mssqldb). |
1919
| `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported.
20+
| `x-batch` | | Enable batch statements (default: false) |
2021

2122
See https://github.com/microsoft/go-mssqldb for full parameter list.
2223

database/sqlserver/sqlserver.go

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/golang-migrate/migrate/v4/database"
1717
"github.com/hashicorp/go-multierror"
1818
mssql "github.com/microsoft/go-mssqldb" // mssql support
19+
"github.com/microsoft/go-mssqldb/batch"
1920
)
2021

2122
func init() {
@@ -30,7 +31,7 @@ var (
3031
ErrNoDatabaseName = fmt.Errorf("no database name")
3132
ErrNoSchema = fmt.Errorf("no schema")
3233
ErrDatabaseDirty = fmt.Errorf("database is dirty")
33-
ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
34+
ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed")
3435
)
3536

3637
var lockErrorMap = map[int]string{
@@ -42,9 +43,10 @@ var lockErrorMap = map[int]string{
4243

4344
// Config for database
4445
type Config struct {
45-
MigrationsTable string
46-
DatabaseName string
47-
SchemaName string
46+
MigrationsTable string
47+
DatabaseName string
48+
SchemaName string
49+
BatchStatementEnabled bool
4850
}
4951

5052
// SQL Server connection
@@ -168,9 +170,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) {
168170

169171
migrationsTable := purl.Query().Get("x-migrations-table")
170172

173+
batchStatementEnabled := false
174+
if s := purl.Query().Get("x-batch"); len(s) > 0 {
175+
batchStatementEnabled, err = strconv.ParseBool(s)
176+
if err != nil {
177+
return nil, fmt.Errorf("unable to parse option x-batch: %w", err)
178+
}
179+
}
180+
171181
px, err := WithInstance(db, &Config{
172-
DatabaseName: purl.Path,
173-
MigrationsTable: migrationsTable,
182+
DatabaseName: purl.Path,
183+
MigrationsTable: migrationsTable,
184+
BatchStatementEnabled: batchStatementEnabled,
174185
})
175186

176187
if err != nil {
@@ -247,15 +258,23 @@ func (ss *SQLServer) Run(migration io.Reader) error {
247258

248259
// run migration
249260
query := string(migr[:])
250-
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
251-
if msErr, ok := err.(mssql.Error); ok {
252-
message := fmt.Sprintf("migration failed: %s", msErr.Message)
253-
if msErr.ProcName != "" {
254-
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
261+
scripts := []string{query}
262+
263+
if ss.config.BatchStatementEnabled {
264+
scripts = batch.Split(query, "go")
265+
}
266+
267+
for _, script := range scripts {
268+
if _, err := ss.conn.ExecContext(context.Background(), script); err != nil {
269+
if msErr, ok := err.(mssql.Error); ok {
270+
message := fmt.Sprintf("migration failed: %s", msErr.Message)
271+
if msErr.ProcName != "" {
272+
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
273+
}
274+
return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)}
255275
}
256-
return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
276+
return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)}
257277
}
258-
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
259278
}
260279

261280
return nil

database/sqlserver/sqlserver_test.go

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ var (
3737
}
3838
)
3939

40-
func msConnectionString(host, port string) string {
41-
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
40+
func msConnectionString(host, port string, options ...string) string {
41+
options = append(options, "database=master")
42+
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&"))
4243
}
4344

4445
func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
@@ -87,6 +88,7 @@ func Test(t *testing.T) {
8788
t.Run("test", test)
8889
t.Run("testMigrate", testMigrate)
8990
t.Run("testMultiStatement", testMultiStatement)
91+
t.Run("testBatchedStatement", testBatchedStatement)
9092
t.Run("testErrorParsing", testErrorParsing)
9193
t.Run("testLockWorks", testLockWorks)
9294
t.Run("testMsiTrue", testMsiTrue)
@@ -191,6 +193,49 @@ func testMultiStatement(t *testing.T) {
191193
})
192194
}
193195

196+
func testBatchedStatement(t *testing.T) {
197+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
198+
ip, port, err := c.Port(defaultPort)
199+
if err != nil {
200+
t.Fatal(err)
201+
}
202+
203+
addr := msConnectionString(ip, port, "x-batch=true")
204+
ms := &SQLServer{}
205+
d, err := ms.Open(addr)
206+
if err != nil {
207+
t.Fatal(err)
208+
}
209+
defer func() {
210+
if err := d.Close(); err != nil {
211+
t.Error(err)
212+
}
213+
}()
214+
if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA
215+
AS
216+
BEGIN
217+
SELECT 1;
218+
END;
219+
GO
220+
CREATE PROCEDURE uspB
221+
AS
222+
BEGIN
223+
SELECT 2;
224+
END`)); err != nil {
225+
t.Fatalf("expected err to be nil, got %v", err)
226+
}
227+
228+
// make sure second proc exists
229+
var exists int
230+
if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil {
231+
t.Fatal(err)
232+
}
233+
if exists != 1 {
234+
t.Fatalf("expected proc uspB to exist")
235+
}
236+
})
237+
}
238+
194239
func testErrorParsing(t *testing.T) {
195240
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
196241
SkipIfUnsupportedArch(t, c)

0 commit comments

Comments
 (0)