Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit b19c55a

Browse files
committed
sql/plan: speed up InnerJoin iterator
pkg: gopkg.in/src-d/go-mysql-server.v0/sql/plan BenchmarkInnerJoin/inner_join-4 50000 34188 ns/op 11415 B/op 169 allocs/op BenchmarkInnerJoin/cross_join_with_filter-4 30000 41608 ns/op 12119 B/op 181 allocs/op PASS ok gopkg.in/src-d/go-mysql-server.v0/sql/plan 3.795s Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
1 parent ac59802 commit b19c55a

File tree

2 files changed

+153
-9
lines changed

2 files changed

+153
-9
lines changed

sql/plan/innerjoin.go

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plan
22

33
import (
4+
"io"
45
"reflect"
56

67
opentracing "github.com/opentracing/opentracing-go"
@@ -60,15 +61,12 @@ func (j *InnerJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) {
6061
return nil, err
6162
}
6263

63-
return sql.NewSpanIter(span, NewFilterIter(
64-
ctx,
65-
j.Cond,
66-
&crossJoinIterator{
67-
l: l,
68-
rp: j.Right,
69-
s: ctx,
70-
},
71-
)), nil
64+
return sql.NewSpanIter(span, &innerJoinIter{
65+
l: l,
66+
rp: j.Right,
67+
ctx: ctx,
68+
cond: j.Cond,
69+
}), nil
7270
}
7371

7472
// TransformUp implements the Transformable interface.
@@ -127,3 +125,74 @@ func (j *InnerJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, err
127125

128126
return NewInnerJoin(j.Left, j.Right, cond), nil
129127
}
128+
129+
type innerJoinIter struct {
130+
l sql.RowIter
131+
rp rowIterProvider
132+
r sql.RowIter
133+
ctx *sql.Context
134+
cond sql.Expression
135+
136+
leftRow sql.Row
137+
}
138+
139+
func (i *innerJoinIter) Next() (sql.Row, error) {
140+
for {
141+
if i.leftRow == nil {
142+
r, err := i.l.Next()
143+
if err != nil {
144+
return nil, err
145+
}
146+
147+
i.leftRow = r
148+
}
149+
150+
if i.r == nil {
151+
iter, err := i.rp.RowIter(i.ctx)
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
i.r = iter
157+
}
158+
159+
rightRow, err := i.r.Next()
160+
if err == io.EOF {
161+
i.r = nil
162+
i.leftRow = nil
163+
continue
164+
}
165+
166+
if err != nil {
167+
return nil, err
168+
}
169+
170+
var row sql.Row
171+
row = append(row, i.leftRow...)
172+
row = append(row, rightRow...)
173+
174+
v, err := i.cond.Eval(i.ctx, row)
175+
if err != nil {
176+
return nil, err
177+
}
178+
179+
if v == true {
180+
return row, nil
181+
}
182+
}
183+
}
184+
185+
func (i *innerJoinIter) Close() error {
186+
if err := i.l.Close(); err != nil {
187+
if i.r != nil {
188+
_ = i.r.Close()
189+
}
190+
return err
191+
}
192+
193+
if i.r != nil {
194+
return i.r.Close()
195+
}
196+
197+
return nil
198+
}

sql/plan/innerjoin_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plan
22

33
import (
4+
"fmt"
45
"testing"
56

67
"github.com/stretchr/testify/require"
@@ -57,3 +58,77 @@ func TestInnerJoinEmpty(t *testing.T) {
5758

5859
assertRows(t, iter, 0)
5960
}
61+
62+
func BenchmarkInnerJoin(b *testing.B) {
63+
t1 := mem.NewTable("foo", sql.Schema{
64+
{Name: "a", Source: "foo", Type: sql.Int64},
65+
{Name: "b", Source: "foo", Type: sql.Text},
66+
})
67+
68+
t2 := mem.NewTable("bar", sql.Schema{
69+
{Name: "a", Source: "bar", Type: sql.Int64},
70+
{Name: "b", Source: "bar", Type: sql.Text},
71+
})
72+
73+
for i := 0; i < 5; i++ {
74+
t1.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t1_%d", i)))
75+
t2.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t2_%d", i)))
76+
}
77+
78+
n1 := NewInnerJoin(
79+
NewResolvedTable(t1),
80+
NewResolvedTable(t2),
81+
expression.NewEquals(
82+
expression.NewGetField(0, sql.Int64, "a", false),
83+
expression.NewGetField(2, sql.Int64, "a", false),
84+
),
85+
)
86+
87+
n2 := NewFilter(
88+
expression.NewEquals(
89+
expression.NewGetField(0, sql.Int64, "a", false),
90+
expression.NewGetField(2, sql.Int64, "a", false),
91+
),
92+
NewCrossJoin(
93+
NewResolvedTable(t1),
94+
NewResolvedTable(t2),
95+
),
96+
)
97+
98+
expected := []sql.Row{
99+
{int64(0), "t1_0", int64(0), "t2_0"},
100+
{int64(1), "t1_1", int64(1), "t2_1"},
101+
{int64(2), "t1_2", int64(2), "t2_2"},
102+
{int64(3), "t1_3", int64(3), "t2_3"},
103+
{int64(4), "t1_4", int64(4), "t2_4"},
104+
}
105+
106+
ctx := sql.NewEmptyContext()
107+
b.Run("inner join", func(b *testing.B) {
108+
require := require.New(b)
109+
110+
for i := 0; i < b.N; i++ {
111+
iter, err := n1.RowIter(ctx)
112+
require.NoError(err)
113+
114+
rows, err := sql.RowIterToRows(iter)
115+
require.NoError(err)
116+
117+
require.Equal(expected, rows)
118+
}
119+
})
120+
121+
b.Run("cross join with filter", func(b *testing.B) {
122+
require := require.New(b)
123+
124+
for i := 0; i < b.N; i++ {
125+
iter, err := n2.RowIter(ctx)
126+
require.NoError(err)
127+
128+
rows, err := sql.RowIterToRows(iter)
129+
require.NoError(err)
130+
131+
require.Equal(expected, rows)
132+
}
133+
})
134+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy