Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bug/285 - allow prepared statements to be prepared multiple times #309

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion sqlmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,16 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
if _, ok = next.(*ExpectedPrepare); !ok {
continue
} else {
next.Lock()
}
}

if c.ordered {
if expected, ok = next.(*ExpectedPrepare); ok {
fulfilled--
break
}

Expand All @@ -311,6 +316,7 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
if pr, ok := next.(*ExpectedPrepare); ok {
if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
expected = pr
fulfilled--
break
}
}
Expand All @@ -334,6 +340,13 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
}

func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
for _, e := range c.expected {
if ep, ok := e.(*ExpectedPrepare); ok {
if ep.expectSQL == expectedSQL {
return ep
}
}
}
e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
c.expected = append(c.expected, e)
return e
Expand Down
127 changes: 127 additions & 0 deletions sqlmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,50 @@ func TestUnorderedPreparedQueryExecutions(t *testing.T) {
}
}

func TestParallelPreparedQueryExecutions(t *testing.T) {
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
mock.MatchExpectationsInOrder(false)

mock.ExpectPrepare("INSERT INTO authors \\((.+)\\) VALUES \\((.+)\\)").
ExpectExec().
WithArgs(1, "Jane Doe").
WillReturnResult(NewResult(1, 1))

mock.ExpectPrepare("INSERT INTO authors \\((.+)\\) VALUES \\((.+)\\)").
ExpectExec().
WithArgs(0, "John Doe").
WillReturnResult(NewResult(0, 1))

t.Run("Parallel1", func(t *testing.T) {
t.Parallel()

stmt, err := db.Prepare("INSERT INTO authors (id, name) VALUES (?, ?)")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
} else {
_, err = stmt.Exec(0, "John Doe")
}
})

t.Run("Parallel2", func(t *testing.T) {
t.Parallel()

stmt, err := db.Prepare("INSERT INTO authors (id, name) VALUES (?, ?)")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
} else {
_, err = stmt.Exec(1, "Jane Doe")
}
})

t.Cleanup(func() {
db.Close()
})
}

func TestUnexpectedOperations(t *testing.T) {
t.Parallel()
db, mock, err := New()
Expand Down Expand Up @@ -632,6 +676,89 @@ func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) {
// note this line is important for unordered expectation matching
mock.MatchExpectationsInOrder(false)

data := []interface{}{
1,
"John Doe",
2,
"Jane Doe",
}
rows := NewRows([]string{"id", "name"})
rows.AddRow(data[0], data[1])
rows.AddRow(data[2], data[3])

mock.ExpectExec("DROP TABLE IF EXISTS author").WillReturnResult(NewResult(0, 0))
mock.ExpectExec("TRUNCATE TABLE").WillReturnResult(NewResult(0, 0))

mock.ExpectExec("CREATE TABLE IF NOT EXISTS author").WillReturnResult(NewResult(0, 0))

mock.ExpectQuery("SELECT").WillReturnRows(rows).WithArgs()

mock.ExpectPrepare("INSERT INTO").
ExpectExec().
WithArgs(
data[0],
data[1],
data[2],
data[3],
).
WillReturnResult(NewResult(0, 2))

var wg sync.WaitGroup
queries := []func() error{
func() error {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS author (a varchar(255)")
return err
},
func() error {
_, err := db.Exec("TRUNCATE TABLE author")
return err
},
func() error {
stmt, err := db.Prepare("INSERT INTO author (id,name) VALUES (?,?),(?,?)")
if err != nil {
return err
}
_, err = stmt.Exec(1, "John Doe", 2, "Jane Doe")
return err
},
func() error {
_, err := db.Query("SELECT * FROM author")
return err
},
func() error {
_, err := db.Exec("DROP TABLE IF EXISTS author")
return err
},
}

wg.Add(len(queries))
for _, f := range queries {
go func(f func() error) {
if err := f(); err != nil {
t.Errorf("error was not expected: %s", err)
}
wg.Done()
}(f)
}

wg.Wait()

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestGoroutineExecutionMultiTypes(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

// note this line is important for unordered expectation matching
mock.MatchExpectationsInOrder(false)

result := NewResult(1, 1)

mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
Expand Down