diff --git a/sqlite3.go b/sqlite3.go
index 3025a500..f7bfa363 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -399,6 +399,9 @@ type SQLiteRows struct {
 	decltype []string
 	ctx      context.Context // no better alternative to pass context into Next() method
 	closemu  sync.Mutex
+	// semaphore to signal the goroutine used to interrupt queries when a
+	// cancellable context is passed to QueryContext
+	sema chan struct{}
 }
 
 type functionInfo struct {
@@ -2050,36 +2053,37 @@ func isInterruptErr(err error) bool {
 
 // exec executes a query that doesn't return rows. Attempts to honor context timeout.
 func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
-	if ctx.Done() == nil {
+	done := ctx.Done()
+	if done == nil {
 		return s.execSync(args)
 	}
-
-	type result struct {
-		r   driver.Result
-		err error
+	if err := ctx.Err(); err != nil {
+		return nil, err // Fast check if the channel is closed
 	}
-	resultCh := make(chan result)
-	defer close(resultCh)
+
+	sema := make(chan struct{})
 	go func() {
-		r, err := s.execSync(args)
-		resultCh <- result{r, err}
-	}()
-	var rv result
-	select {
-	case rv = <-resultCh:
-	case <-ctx.Done():
 		select {
-		case rv = <-resultCh: // no need to interrupt, operation completed in db
-		default:
-			// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
+		case <-done:
 			C.sqlite3_interrupt(s.c.db)
-			rv = <-resultCh // wait for goroutine completed
-			if isInterruptErr(rv.err) {
-				return nil, ctx.Err()
-			}
+			// Wait until signaled. We need to ensure that this goroutine
+			// will not call interrupt after this method returns.
+			<-sema
+		case <-sema:
 		}
+	}()
+	r, err := s.execSync(args)
+	// Signal the goroutine to exit. This send will only succeed at a point
+	// where it is impossible for the goroutine to call sqlite3_interrupt.
+	//
+	// This is necessary to ensure the goroutine does not interrupt an
+	// unrelated query if the context is cancelled after this method returns
+	// but before the goroutine exits (we don't wait for it to exit).
+	sema <- struct{}{}
+	if err != nil && isInterruptErr(err) {
+		return nil, ctx.Err()
 	}
-	return rv.r, rv.err
+	return r, err
 }
 
 func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
@@ -2117,6 +2121,9 @@ func (rc *SQLiteRows) Close() error {
 		return nil
 	}
 	rc.s = nil // remove reference to SQLiteStmt
+	if rc.sema != nil {
+		close(rc.sema)
+	}
 	s.mu.Lock()
 	if s.closed {
 		s.mu.Unlock()
@@ -2174,27 +2181,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
 		return io.EOF
 	}
 
-	if rc.ctx.Done() == nil {
+	done := rc.ctx.Done()
+	if done == nil {
 		return rc.nextSyncLocked(dest)
 	}
-	resultCh := make(chan error)
-	defer close(resultCh)
+	if err := rc.ctx.Err(); err != nil {
+		return err // Fast check if the channel is closed
+	}
+
+	if rc.sema == nil {
+		rc.sema = make(chan struct{})
+	}
 	go func() {
-		resultCh <- rc.nextSyncLocked(dest)
-	}()
-	select {
-	case err := <-resultCh:
-		return err
-	case <-rc.ctx.Done():
 		select {
-		case <-resultCh: // no need to interrupt
-		default:
-			// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
+		case <-done:
 			C.sqlite3_interrupt(rc.s.c.db)
-			<-resultCh // ensure goroutine completed
+			// Wait until signaled. We need to ensure that this goroutine
+			// will not call interrupt after this method returns.
+			<-rc.sema
+		case <-rc.sema:
 		}
-		return rc.ctx.Err()
+	}()
+
+	err := rc.nextSyncLocked(dest)
+	// Signal the goroutine to exit. This send will only succeed at a point
+	// where it is impossible for the goroutine to call sqlite3_interrupt.
+	//
+	// This is necessary to ensure the goroutine does not interrupt an
+	// unrelated query if the context is cancelled after this method returns
+	// but before the goroutine exits (we don't wait for it to exit).
+	rc.sema <- struct{}{}
+	if err != nil && isInterruptErr(err) {
+		err = rc.ctx.Err()
	}
+	return err
 }
 
 // nextSyncLocked moves cursor to next; must be called with locked mutex.
diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go
index eec7479d..879d9fda 100644
--- a/sqlite3_go18_test.go
+++ b/sqlite3_go18_test.go
@@ -11,10 +11,12 @@ package sqlite3
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"math/rand"
 	"os"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
 	}
 }
 
+// Test that we can successfully interrupt a long running query when
+// the context is canceled. The previous two QueryRowContext tests
+// only test that we handle a previously cancelled context and thus
+// do not call sqlite3_interrupt.
+func TestQueryRowContextCancelInterrupt(t *testing.T) {
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	// Test that we have the unixepoch function and if not skip the test.
+	if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
+		libVersion, libVersionNumber, sourceID := Version()
+		if strings.Contains(err.Error(), "no such function: unixepoch") {
+			t.Skip("Skipping: the 'unixepoch' function is not implemented in "+
+				"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
+		}
+		t.Fatal(err)
+	}
+
+	const createTableStmt = `
+	CREATE TABLE timestamps (
+		ts TIMESTAMP NOT NULL
+	);`
+	if _, err := db.Exec(createTableStmt); err != nil {
+		t.Fatal(err)
+	}
+
+	stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer stmt.Close()
+
+	// Computationally expensive query that consumes many rows. This is needed
+	// to test cancellation because queries are not interrupted immediately.
+	// Instead, queries are only halted at certain checkpoints where sqlite3
+	// checks its internal isInterrupted flag and finds it set.
+	queryStmt := `
+	SELECT
+		SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
+		SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
+		SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
+		SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
+	FROM
+		timestamps
+	WHERE datetime(ts, 'unixepoch', 'localtime')
+	LIKE
+		?;`
+
+	query := func(t *testing.T, timeout time.Duration) (int, error) {
+		// Create a complicated pattern to match timestamps
+		const pattern = "%2%0%2%4%-%-%:%:%"
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		defer cancel()
+		rows, err := db.QueryContext(ctx, queryStmt, pattern)
+		if err != nil {
+			return 0, err
+		}
+		var count int
+		for rows.Next() {
+			var n int64
+			if err := rows.Scan(&n, &n, &n, &n); err != nil {
+				return count, err
+			}
+			count++
+		}
+		return count, rows.Err()
+	}
+
+	average := func(n int, fn func()) time.Duration {
+		start := time.Now()
+		for i := 0; i < n; i++ {
+			fn()
+		}
+		return time.Since(start) / time.Duration(n)
+	}
+
+	createRows := func(n int) {
+		t.Logf("Creating %d rows", n)
+		if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
+			t.Fatal(err)
+		}
+		ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
+		rr := rand.New(rand.NewSource(1234))
+		for i := 0; i < n; i++ {
+			if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
+				t.Fatal(err)
+			}
+		}
+	}
+
+	const TargetRuntime = 200 * time.Millisecond
+	const N = 5_000 // Number of rows to insert at a time
+
+	// Create enough rows that the query takes ~200ms to run.
+	start := time.Now()
+	createRows(N)
+	baseAvg := average(4, func() {
+		if _, err := query(t, time.Hour); err != nil {
+			t.Fatal(err)
+		}
+	})
+	t.Log("Base average:", baseAvg)
+	rowCount := N * (int(TargetRuntime/baseAvg) + 1)
+	createRows(rowCount)
+	t.Log("Table setup time:", time.Since(start))
+
+	// Set the timeout to 1/10 of the average query time.
+	avg := average(2, func() {
+		n, err := query(t, time.Hour)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n == 0 {
+			t.Fatal("scanned zero rows")
+		}
+	})
+	// Guard against the timeout being too short to reliably test.
+	if avg < TargetRuntime/2 {
+		t.Fatalf("Average query runtime should be around %s, got: %s",
+			TargetRuntime, avg)
+	}
+	timeout := (avg / 10).Round(100 * time.Microsecond)
+	t.Logf("Average: %s Timeout: %s", avg, timeout)
+
+	for i := 0; i < 10; i++ {
+		tt := time.Now()
+		n, err := query(t, timeout)
+		if !errors.Is(err, context.DeadlineExceeded) {
+			fn := t.Errorf
+			if err != nil {
+				fn = t.Fatalf
+			}
+			fn("expected error %v, got %v", context.DeadlineExceeded, err)
+		}
+		d := time.Since(tt)
+		t.Logf("%d: rows: %d duration: %s", i, n, d)
+		if d > timeout*4 {
+			t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
+		}
+	}
+}
+
 func TestExecCancel(t *testing.T) {
 	db, err := sql.Open("sqlite3", ":memory:")
 	if err != nil {
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 94de7386..e3dcace3 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -10,6 +10,7 @@ package sqlite3
 
 import (
 	"bytes"
+	"context"
 	"database/sql"
 	"database/sql/driver"
 	"errors"
@@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) {
 }
 
 func TestSuite(t *testing.T) {
-	initializeTestDB(t)
+	initializeTestDB(t, false)
 	defer freeTestDB()
 
 	for _, test := range tests {
@@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) {
 }
 
 func BenchmarkSuite(b *testing.B) {
-	initializeTestDB(b)
+	initializeTestDB(b, true)
 	defer freeTestDB()
 
 	for _, benchmark := range benchmarks {
@@ -2068,8 +2069,13 @@ type TestDB struct {
 
 var db *TestDB
 
-func initializeTestDB(t testing.TB) {
-	tempFilename := TempFilename(t)
+func initializeTestDB(t testing.TB, memory bool) {
+	var tempFilename string
+	if memory {
+		tempFilename = ":memory:"
+	} else {
+		tempFilename = TempFilename(t)
+	}
 	d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
 	if err != nil {
 		os.Remove(tempFilename)
@@ -2084,9 +2090,11 @@ func freeTestDB() {
 	if err != nil {
 		panic(err)
 	}
-	err = os.Remove(db.tempFilename)
-	if err != nil {
-		panic(err)
+	if db.tempFilename != "" && db.tempFilename != ":memory:" {
+		err := os.Remove(db.tempFilename)
+		if err != nil {
+			panic(err)
+		}
 	}
 }
 
@@ -2106,7 +2114,9 @@ var tests = []testing.InternalTest{
 
 var benchmarks = []testing.InternalBenchmark{
 	{Name: "BenchmarkExec", F: benchmarkExec},
+	{Name: "BenchmarkExecContext", F: benchmarkExecContext},
 	{Name: "BenchmarkQuery", F: benchmarkQuery},
+	{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
 	{Name: "BenchmarkParams", F: benchmarkParams},
 	{Name: "BenchmarkStmt", F: benchmarkStmt},
 	{Name: "BenchmarkRows", F: benchmarkRows},
@@ -2466,6 +2476,16 @@ func benchmarkExec(b *testing.B) {
 	}
 }
 
+func benchmarkExecContext(b *testing.B) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	for i := 0; i < b.N; i++ {
+		if _, err := db.ExecContext(ctx, "select 1"); err != nil {
+			panic(err)
+		}
+	}
+}
+
 // benchmarkQuery is benchmark for query
 func benchmarkQuery(b *testing.B) {
 	for i := 0; i < b.N; i++ {
@@ -2480,6 +2500,65 @@ func benchmarkQuery(b *testing.B) {
 	}
 }
 
+// benchmarkQueryContext is benchmark for QueryContext
+func benchmarkQueryContext(b *testing.B) {
+	const createTableStmt = `
+	CREATE TABLE IF NOT EXISTS query_context(
+		id INTEGER PRIMARY KEY
+	);
+	DELETE FROM query_context;
+	VACUUM;`
+	test := func(ctx context.Context, b *testing.B) {
+		if _, err := db.Exec(createTableStmt); err != nil {
+			b.Fatal(err)
+		}
+		for i := 0; i < 10; i++ {
+			_, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i))
+			if err != nil {
+				b.Fatal(err)
+			}
+		}
+		stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`)
+		if err != nil {
+			b.Fatal(err)
+		}
+		b.Cleanup(func() { stmt.Close() })
+
+		var n int
+		for i := 0; i < b.N; i++ {
+			rows, err := stmt.QueryContext(ctx)
+			if err != nil {
+				b.Fatal(err)
+			}
+			for rows.Next() {
+				if err := rows.Scan(&n); err != nil {
+					b.Fatal(err)
+				}
+			}
+			if err := rows.Err(); err != nil {
+				b.Fatal(err)
+			}
+		}
+	}
+
+	// When the context does not have a Done channel we should use
+	// the fast path that directly handles the query instead of
+	// handling it in a goroutine. This benchmark also serves to
+	// highlight the performance impact of using a cancelable
+	// context.
+	b.Run("Background", func(b *testing.B) {
+		test(context.Background(), b)
+	})
+
+	// Benchmark a query with a context that can be canceled. This
+	// requires using a goroutine and is thus much slower.
+	b.Run("WithCancel", func(b *testing.B) {
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		test(ctx, b)
+	})
+}
+
 // benchmarkParams is benchmark for params
 func benchmarkParams(b *testing.B) {
 	for i := 0; i < b.N; i++ {
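
Note (illustration, not part of the patch): the semaphore hand-off used above by SQLiteStmt.exec and SQLiteRows.Next can be distilled into a small standalone sketch. The program below is a minimal, hypothetical rendering of that pattern; runInterruptible, errInterrupted, and the work/interrupt callbacks are invented names standing in for execSync/nextSyncLocked and C.sqlite3_interrupt, and are not go-sqlite3 APIs.

// Sketch of the cancellation pattern from the patch: a watcher goroutine
// either observes context cancellation and calls interrupt, or is told to
// stand down. The send on sema after work() returns guarantees the watcher
// can never interrupt a later, unrelated call.
package main

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

var errInterrupted = errors.New("interrupted")

func runInterruptible(ctx context.Context, interrupt func(), work func() error) error {
	done := ctx.Done()
	if done == nil {
		return work() // fast path: this context can never be cancelled
	}
	if err := ctx.Err(); err != nil {
		return err // fast check if the channel is already closed
	}
	sema := make(chan struct{})
	go func() {
		select {
		case <-done:
			interrupt()
			<-sema // wait until signaled; never interrupt after return
		case <-sema:
		}
	}()
	err := work()
	sema <- struct{}{} // hand-off: the watcher can no longer interrupt
	if err != nil && errors.Is(err, errInterrupted) {
		return ctx.Err()
	}
	return err
}

func main() {
	var stop atomic.Bool
	// Stand-in for a long-running query loop that polls an interrupt flag.
	work := func() error {
		for i := 0; i < 100; i++ {
			if stop.Load() {
				return errInterrupted
			}
			time.Sleep(10 * time.Millisecond)
		}
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	err := runInterruptible(ctx, func() { stop.Store(true) }, work)
	fmt.Println(err) // expected: context deadline exceeded
}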