Skip to content

Commit 81d3026

Browse files
committed
Added sql.DB wrapper to satisfy squirrel's hidden type assertions
Fixes #256, hopefully for good
1 parent 572ee7e commit 81d3026

File tree

2 files changed

+101
-39
lines changed

2 files changed

+101
-39
lines changed

store.go

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -64,62 +64,69 @@ func defaultLogger(message string, args ...interface{}) {
6464
log.Printf("%s, args: %v", message, args)
6565
}
6666

67-
// basicLogger is a database runner that logs all SQL statements executed.
68-
type basicLogger struct {
69-
logger LoggerFunc
70-
runner squirrel.BaseRunner
71-
}
72-
73-
// basicLogger is a database runner that logs all SQL statements executed.
67+
// runnerLogger is a database runner that logs all SQL statements executed.
7468
type proxyLogger struct {
75-
basicLogger
69+
squirrel.DBProxyContext
70+
logger LoggerFunc
7671
}
7772

78-
func (p *basicLogger) Exec(query string, args ...interface{}) (sql.Result, error) {
73+
func (p *proxyLogger) Exec(query string, args ...interface{}) (sql.Result, error) {
7974
p.logger(fmt.Sprintf("kallax: Exec: %s", query), args...)
80-
return p.runner.Exec(query, args...)
75+
return p.DBProxyContext.Exec(query, args...)
8176
}
8277

83-
func (p *basicLogger) Query(query string, args ...interface{}) (*sql.Rows, error) {
78+
func (p *proxyLogger) Query(query string, args ...interface{}) (*sql.Rows, error) {
8479
p.logger(fmt.Sprintf("kallax: Query: %s", query), args...)
85-
return p.runner.Query(query, args...)
80+
return p.DBProxyContext.Query(query, args...)
8681
}
8782

8883
func (p *proxyLogger) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
89-
p.basicLogger.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
90-
if queryRower, ok := p.basicLogger.runner.(squirrel.QueryRower); ok {
91-
return queryRower.QueryRow(query, args...)
92-
} else {
93-
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
94-
}
84+
p.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
85+
return p.DBProxyContext.QueryRow(query, args...)
9586
}
9687

9788
func (p *proxyLogger) Prepare(query string) (*sql.Stmt, error) {
98-
// If chained runner is a proxy, run Prepare(). Otherwise, noop.
99-
if preparer, ok := p.basicLogger.runner.(squirrel.Preparer); ok {
100-
p.basicLogger.logger(fmt.Sprintf("kallax: Prepare: %s", query))
101-
return preparer.Prepare(query)
102-
} else {
103-
panic("Called proxyLogger with a runner which doesn't implement Preparer")
104-
}
89+
//If chained runner is a proxy, run Prepare(). Otherwise, noop.
90+
p.logger(fmt.Sprintf("kallax: Prepare: %s", query))
91+
return p.DBProxyContext.Prepare(query)
92+
}
93+
94+
// PrepareContext will not be logged
95+
96+
// dbRunner is a copypaste from squirrel.dbRunner, used to make sql.DB implement squirrel.QueryRower.
97+
// squirrel will silently fail and return nil if BaseRunner(s) supplied to RunWith don't implement QueryRower, so
98+
// it has been copied there to avoid that.
99+
// TODO: Delete this when squirrel dependency is dropped.
100+
type dbRunner struct {
101+
*sql.DB
102+
}
103+
104+
func (r *dbRunner) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
105+
return r.DB.QueryRow(query, args...)
106+
}
107+
108+
// txRunner does the analogous for sql.Tx
109+
type txRunner struct {
110+
*sql.Tx
111+
}
112+
113+
func (r *txRunner) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
114+
return r.Tx.QueryRow(query, args...)
105115
}
106116

107117
// Store is a structure capable of retrieving records from a concrete table in
108118
// the database.
109119
type Store struct {
110-
db interface {
111-
squirrel.BaseRunner
112-
squirrel.PreparerContext
113-
}
114-
runner squirrel.BaseRunner
120+
db squirrel.DBProxyContext
121+
runner squirrel.DBProxyContext
115122
useCacher bool
116123
logger LoggerFunc
117124
}
118125

119126
// NewStore returns a new Store instance.
120127
func NewStore(db *sql.DB) *Store {
121128
return (&Store{
122-
db: db,
129+
db: &dbRunner{db},
123130
useCacher: true,
124131
}).init()
125132
}
@@ -132,12 +139,8 @@ func (s *Store) init() *Store {
132139
s.runner = squirrel.NewStmtCacher(s.db)
133140
}
134141

135-
if s.logger != nil && !s.useCacher {
136-
// Use BasicLogger as wrapper
137-
s.runner = &basicLogger{s.logger, s.db}
138-
} else if s.logger != nil && s.useCacher {
139-
// We're using a proxy (cacher), so use proxyLogger instead
140-
s.runner = &proxyLogger{basicLogger{s.logger, s.runner}}
142+
if s.logger != nil {
143+
s.runner = &proxyLogger{logger: s.logger, DBProxyContext: s.runner}
141144
}
142145

143146
return s
@@ -469,7 +472,7 @@ func (s *Store) MustCount(q Query) int64 {
469472
func (s *Store) Transaction(callback func(*Store) error) error {
470473
var tx *sql.Tx
471474
var err error
472-
if db, ok := s.db.(*sql.DB); ok {
475+
if db, ok := s.db.(*dbRunner); ok {
473476
// db is *sql.DB, not *sql.Tx
474477
tx, err = db.Begin()
475478
if err != nil {
@@ -481,7 +484,7 @@ func (s *Store) Transaction(callback func(*Store) error) error {
481484
}
482485

483486
txStore := (&Store{
484-
db: tx,
487+
db: &txRunner{tx},
485488
logger: s.logger,
486489
useCacher: s.useCacher,
487490
}).init()

tests/store_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,65 @@ func (s *StoreSuite) TestFindOne() {
232232
}
233233
}
234234

235+
func (s *StoreSuite) TestDebug() {
236+
store := NewStoreWithConstructFixtureStore(s.db)
237+
238+
docInserted := NewStoreWithConstructFixture("bar")
239+
s.Nil(store.DisableCacher().Insert(docInserted))
240+
241+
query := NewStoreWithConstructFixtureQuery()
242+
243+
// Normal find
244+
docFound, err := store.FindOne(query)
245+
246+
s.resultOrError(docFound, err)
247+
if s.NotNil(docFound) {
248+
s.Equal(docInserted.Foo, docFound.Foo)
249+
}
250+
251+
// Debug
252+
docFound, err = store.Debug().FindOne(query)
253+
254+
s.resultOrError(docFound, err)
255+
if s.NotNil(docFound) {
256+
s.Equal(docInserted.Foo, docFound.Foo)
257+
}
258+
}
259+
260+
func (s *StoreSuite) TestDebugWithoutCacher() {
261+
store := NewStoreWithConstructFixtureStore(s.db)
262+
263+
docInserted := NewStoreWithConstructFixture("bar")
264+
s.Nil(store.DisableCacher().Insert(docInserted))
265+
266+
query := NewStoreWithConstructFixtureQuery()
267+
268+
// Normal find
269+
docFound, err := store.FindOne(query)
270+
271+
s.resultOrError(docFound, err)
272+
if s.NotNil(docFound) {
273+
s.Equal(docInserted.Foo, docFound.Foo)
274+
}
275+
276+
// No cacher -> debug
277+
noCacherDebugStore := store.DisableCacher().Debug()
278+
docFound, err = noCacherDebugStore.FindOne(query)
279+
280+
s.resultOrError(docFound, err)
281+
if s.NotNil(docFound) {
282+
s.Equal(docInserted.Foo, docFound.Foo)
283+
}
284+
285+
// Debug -> no cacher
286+
docFound, err = store.Debug().DisableCacher().FindOne(query)
287+
288+
s.resultOrError(docFound, err)
289+
if s.NotNil(docFound) {
290+
s.Equal(docInserted.Foo, docFound.Foo)
291+
}
292+
}
293+
235294
func (s *StoreSuite) TestFindAliasSlice() {
236295
store := NewStoreFixtureStore(s.db)
237296

0 commit comments

Comments
 (0)