Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

*: fix query cancellation #565

Merged
merged 1 commit into from
Nov 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) {

if s[1] == "query" {
logrus.Infof("kill query: id %v", id)
h.e.Catalog.Kill(id)
h.e.Catalog.KillConnection(uint32(id))
} else {
logrus.Infof("kill connection: id %v, pid: %v", conn.ConnectionID, id)
h.mu.Lock()
Expand All @@ -189,7 +189,7 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) {
return false, errConnectionNotFound.New(conn.ConnectionID)
}

h.e.Catalog.KillConnection(id)
h.e.Catalog.KillConnection(uint32(id))
h.sm.CloseConn(c)
c.Close()
}
Expand Down
13 changes: 13 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ func TestHandlerKill(t *testing.T) {
require.Equal(conn1, handler.c[1])
require.Equal(conn2, handler.c[2])

assertNoConnProcesses(t, e, conn2.ConnectionID)

err = handler.ComQuery(conn2, "KILL 1", func(res *sqltypes.Result) error {
return nil
})
Expand All @@ -203,4 +205,15 @@ func TestHandlerKill(t *testing.T) {
require.Len(handler.sm.sessions, 0)
require.Len(handler.c, 1)
require.Equal(conn1, handler.c[1])
assertNoConnProcesses(t, e, conn2.ConnectionID)
}

func assertNoConnProcesses(t *testing.T, e *sqle.Engine, conn uint32) {
t.Helper()

for _, p := range e.Catalog.Processes() {
if p.Connection == conn {
t.Errorf("expecting no processes with connection id %d", conn)
}
}
}
2 changes: 0 additions & 2 deletions sql/parse/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package parse

import (
"bufio"
"fmt"
"io"
"strings"

Expand All @@ -11,7 +10,6 @@ import (
)

func parseLockTables(ctx *sql.Context, query string) (sql.Node, error) {
fmt.Println(query)
var r = bufio.NewReader(strings.NewReader(query))
var tables []*plan.TableLock
err := parseFuncs{
Expand Down
12 changes: 8 additions & 4 deletions sql/plan/exchange.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plan

import (
"context"
"fmt"
"io"
"sync"
Expand Down Expand Up @@ -103,7 +104,7 @@ func newExchangeRowIter(
ctx: ctx,
parallelism: parallelism,
rows: make(chan sql.Row, parallelism),
err: make(chan error),
err: make(chan error, 1),
started: false,
tree: tree,
partitions: iter,
Expand Down Expand Up @@ -149,6 +150,7 @@ func (it *exchangeRowIter) start() {
for {
select {
case <-it.ctx.Done():
it.err <- context.Canceled
it.closeTokens()
return
case <-it.quit:
Expand Down Expand Up @@ -186,6 +188,7 @@ func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) {
for {
select {
case <-it.ctx.Done():
it.err <- context.Canceled
return
case <-it.quit:
return
Expand Down Expand Up @@ -232,6 +235,7 @@ func (it *exchangeRowIter) iterPartition(p sql.Partition) {
for {
select {
case <-it.ctx.Done():
it.err <- context.Canceled
return
case <-it.quit:
return
Expand Down Expand Up @@ -259,14 +263,14 @@ func (it *exchangeRowIter) Next() (sql.Row, error) {
}

select {
case err := <-it.err:
_ = it.Close()
return nil, err
case row, ok := <-it.rows:
if !ok {
return nil, io.EOF
}
return row, nil
case err := <-it.err:
_ = it.Close()
return nil, err
}
}

Expand Down
34 changes: 34 additions & 0 deletions sql/plan/exchange_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plan

import (
"context"
"fmt"
"io"
"testing"
Expand Down Expand Up @@ -56,6 +57,39 @@ func TestExchange(t *testing.T) {
}
}

func TestExchangeCancelled(t *testing.T) {
children := NewProject(
[]sql.Expression{
expression.NewGetField(0, sql.Text, "partition", false),
expression.NewArithmetic(
expression.NewGetField(1, sql.Int64, "val", false),
expression.NewLiteral(int64(1), sql.Int64),
"+",
),
},
NewFilter(
expression.NewLessThan(
expression.NewGetField(1, sql.Int64, "val", false),
expression.NewLiteral(int64(4), sql.Int64),
),
&partitionable{nil, 3, 6},
),
)

exchange := NewExchange(3, children)
require := require.New(t)

c, cancel := context.WithCancel(context.Background())
ctx := sql.NewContext(c)
cancel()

iter, err := exchange.RowIter(ctx)
require.NoError(err)

_, err = iter.Next()
require.Equal(context.Canceled, err)
}

type partitionable struct {
sql.Node
partitions int
Expand Down
7 changes: 7 additions & 0 deletions sql/plan/resolved_table.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plan

import (
"context"
"io"

"gopkg.in/src-d/go-mysql-server.v0/sql"
Expand Down Expand Up @@ -62,6 +63,12 @@ type tableIter struct {
}

func (i *tableIter) Next() (sql.Row, error) {
select {
case <-i.ctx.Done():
return nil, context.Canceled
default:
}

if i.partition == nil {
partition, err := i.partitions.Next()
if err != nil {
Expand Down
17 changes: 17 additions & 0 deletions sql/plan/resolved_table_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plan

import (
"context"
"fmt"
"io"
"testing"
Expand Down Expand Up @@ -31,6 +32,22 @@ func TestResolvedTable(t *testing.T) {
}
}

func TestResolvedTableCancelled(t *testing.T) {
var require = require.New(t)

table := NewResolvedTable(newTableTest("test"))
require.NotNil(table)

ctx, cancel := context.WithCancel(context.Background())
cancel()

iter, err := table.RowIter(sql.NewContext(ctx))
require.NoError(err)

_, err = iter.Next()
require.Equal(context.Canceled, err)
}

func newTableTest(source string) sql.Table {
schema := []*sql.Column{
{Name: "col1", Type: sql.Int32, Source: source, Default: int32(0), Nullable: false},
Expand Down
14 changes: 3 additions & 11 deletions sql/processlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,14 @@ func (pl *ProcessList) Kill(pid uint64) {
pl.Done(pid)
}

// KillConnection kills all processes that have the same connection as the one
// of the process with the given process id. If the process does not exist, it
// will do nothing.
func (pl *ProcessList) KillConnection(pid uint64) {
// KillConnection kills all processes from the given connection.
func (pl *ProcessList) KillConnection(conn uint32) {
pl.mu.Lock()
defer pl.mu.Unlock()

proc, ok := pl.procs[pid]
if !ok {
return
}

conn := proc.Connection
for pid, proc := range pl.procs {
if proc.Connection == conn {
proc.Kill()
proc.Done()
delete(pl.procs, pid)
}
}
Expand Down
36 changes: 36 additions & 0 deletions sql/processlist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,39 @@ func sortByPid(slice []Process) {
return slice[i].Pid < slice[j].Pid
})
}

func TestKillConnection(t *testing.T) {
pl := NewProcessList()

s1 := NewSession("", "", "", 1)
s2 := NewSession("", "", "", 2)

var killed = make(map[uint64]bool)
for i := uint64(1); i <= 3; i++ {
// Odds get s1, evens get s2
s := s1
if i%2 == 0 {
s = s2
}

_, err := pl.AddProcess(
NewContext(context.Background(), WithPid(i), WithSession(s)),
QueryProcess,
"foo",
)
require.NoError(t, err)

i := i
pl.procs[i].Kill = func() {
killed[i] = true
}
}

pl.KillConnection(1)
require.Len(t, pl.procs, 1)

// Odds should have been killed
require.True(t, killed[1])
require.False(t, killed[2])
require.True(t, killed[3])
}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy