Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions pkg/cdc/sinker_v2_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"io"
"net"
"strconv"
"sync"
"testing"
"time"

Expand All @@ -29,6 +35,175 @@ import (
"github.com/stretchr/testify/require"
)

type fakeMySQLServer struct {
listener net.Listener
queries chan string
errs chan error
wg sync.WaitGroup
}

func startFakeMySQLServer(t *testing.T) *fakeMySQLServer {
t.Helper()

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

server := &fakeMySQLServer{
listener: listener,
queries: make(chan string, 4),
errs: make(chan error, 1),
}
server.wg.Add(1)
go server.serve()

t.Cleanup(func() {
_ = listener.Close()
server.wg.Wait()
})

return server
}

func (s *fakeMySQLServer) addr(t *testing.T) (string, int) {
t.Helper()

host, portStr, err := net.SplitHostPort(s.listener.Addr().String())
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
return host, port
}

func (s *fakeMySQLServer) serve() {
defer s.wg.Done()

conn, err := s.listener.Accept()
if err != nil {
if !errorsIsNetClosed(err) {
s.reportErr(err)
}
return
}
defer conn.Close()

if err := writeMySQLPacket(conn, 0, mysqlHandshakePayload()); err != nil {
s.reportErr(err)
return
}
if _, _, err := readMySQLPacket(conn); err != nil {
s.reportErr(err)
return
}
if err := writeMySQLOK(conn, 2); err != nil {
s.reportErr(err)
return
}

for {
_, payload, err := readMySQLPacket(conn)
if err != nil {
if err != io.EOF && !errorsIsNetClosed(err) {
s.reportErr(err)
}
return
}
if len(payload) == 0 {
continue
}

switch payload[0] {
case 0x01: // COM_QUIT
return
case 0x03: // COM_QUERY
s.queries <- string(payload[1:])
if err := writeMySQLOK(conn, 1); err != nil {
s.reportErr(err)
return
}
case 0x0e: // COM_PING
if err := writeMySQLOK(conn, 1); err != nil {
s.reportErr(err)
return
}
default:
s.reportErr(io.ErrUnexpectedEOF)
return
}
}
}

func (s *fakeMySQLServer) reportErr(err error) {
select {
case s.errs <- err:
default:
}
}

func errorsIsNetClosed(err error) bool {
return errors.Is(err, net.ErrClosed) || err.Error() == "use of closed network connection"
}

func mysqlHandshakePayload() []byte {
const (
clientLongPassword uint32 = 1 << 0
clientLongFlag uint32 = 1 << 2
clientProtocol41 uint32 = 1 << 9
clientTransactions uint32 = 1 << 13
clientSecureConn uint32 = 1 << 15
clientMultiStatements uint32 = 1 << 16
clientPluginAuth uint32 = 1 << 19
)

caps := clientLongPassword | clientLongFlag | clientProtocol41 |
clientTransactions | clientSecureConn | clientMultiStatements | clientPluginAuth
authData := []byte("12345678abcdefghijklmnop")

payload := []byte{0x0a}
payload = append(payload, []byte("5.7.0-cdc-test")...)
payload = append(payload, 0x00)
payload = binary.LittleEndian.AppendUint32(payload, 1)
payload = append(payload, authData[:8]...)
payload = append(payload, 0x00)
payload = binary.LittleEndian.AppendUint16(payload, uint16(caps))
payload = append(payload, 0x21)
payload = binary.LittleEndian.AppendUint16(payload, 0x0002)
payload = binary.LittleEndian.AppendUint16(payload, uint16(caps>>16))
payload = append(payload, 21)
payload = append(payload, make([]byte, 10)...)
payload = append(payload, authData[8:21]...)
payload = append(payload, 0x00)
payload = append(payload, []byte("mysql_native_password")...)
payload = append(payload, 0x00)
return payload
}

func readMySQLPacket(conn net.Conn) (byte, []byte, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return 0, nil, err
}

length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
payload := make([]byte, length)
if _, err := io.ReadFull(conn, payload); err != nil {
return 0, nil, err
}
return header[3], payload, nil
}

func writeMySQLPacket(conn net.Conn, sequence byte, payload []byte) error {
header := []byte{byte(len(payload)), byte(len(payload) >> 8), byte(len(payload) >> 16), sequence}
if _, err := conn.Write(header); err != nil {
return err
}
_, err := conn.Write(payload)
return err
}

func writeMySQLOK(conn net.Conn, sequence byte) error {
return writeMySQLPacket(conn, sequence, []byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00})
}

func TestExecutor_BeginTx(t *testing.T) {
t.Run("SuccessfulBegin", func(t *testing.T) {
db, mock, err := sqlmock.New()
Expand Down Expand Up @@ -461,6 +636,36 @@ func TestExecutor_ExecSQL(t *testing.T) {
})
}

func TestExecutor_ExecSQLAfterTryConnUsesReuseQueryBuf(t *testing.T) {
server := startFakeMySQLServer(t)
host, port := server.addr(t)

cfg, err := makeMysqlConfig("user", "password", host, port, "5s")
require.NoError(t, err)
cfg.MaxAllowedPacket = 64 << 20

db, err := tryConn(cfg)
require.NoError(t, err)

executor := &Executor{conn: db}
defer func() {
require.NoError(t, executor.Close())
}()

sqlBuf := append(make([]byte, v2SQLBufReserved), []byte("CREATE DATABASE cdc_regression")...)
err = executor.ExecSQL(context.Background(), nil, sqlBuf, false)
require.NoError(t, err)

select {
case query := <-server.queries:
require.Equal(t, "CREATE DATABASE cdc_regression", query)
case err := <-server.errs:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timed out waiting for COM_QUERY")
}
}

func TestExecutor_Close(t *testing.T) {
t.Run("CloseWithActiveTransaction", func(t *testing.T) {
db, mock, err := sqlmock.New()
Expand Down
84 changes: 55 additions & 29 deletions pkg/cdc/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"fmt"
"math"
"math/rand"
"net"
"net/url"
"slices"
"strconv"
"strings"
Expand All @@ -44,6 +46,7 @@ import (
"github.com/matrixorigin/matrixone/pkg/txn/client"
v2 "github.com/matrixorigin/matrixone/pkg/util/metric/v2"
"github.com/matrixorigin/matrixone/pkg/vm/engine"
"github.com/matrixorigin/mysql"
)

// escapeSQLString escapes special characters in SQL string literals to prevent SQL injection.
Expand Down Expand Up @@ -557,10 +560,12 @@ func floatArrayToString[T float32 | float64](arr []T) string {

var OpenDbConn = func(user, password string, ip string, port int, timeout string) (db *sql.DB, err error) {
logutil.Info("cdc.util.open_db_conn", zap.String("timeout", timeout))
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?readTimeout=%s&timeout=%s&writeTimeout=%s&multiStatements=true",
user, password, ip, port, timeout, timeout, timeout)
cfg, err := makeMysqlConfig(user, password, ip, port, timeout)
if err != nil {
return nil, err
}
for i := 0; i < 3; i++ {
if db, err = tryConn(dsn); err == nil {
if db, err = tryConn(cfg); err == nil {
// TODO check table existence
return
}
Expand All @@ -571,23 +576,40 @@ var OpenDbConn = func(user, password string, ip string, port int, timeout string
return
}

var openDb = sql.Open
func makeMysqlConfig(user, password string, ip string, port int, timeout string) (*mysql.Config, error) {
timeoutDuration, err := time.ParseDuration(timeout)
if err != nil {
return nil, err
}
cfg := mysql.NewConfig()
cfg.User = user
cfg.Passwd = password
cfg.Net = "tcp"
cfg.Addr = net.JoinHostPort(ip, strconv.Itoa(port))
cfg.Timeout = timeoutDuration
cfg.ReadTimeout = timeoutDuration
cfg.WriteTimeout = timeoutDuration
cfg.MultiStatements = true
return cfg, nil
}

var tryConn = func(dsn string) (*sql.DB, error) {
db, err := openDb("mysql-mo", dsn)
var openDbWithConnector = sql.OpenDB

var tryConn = func(cfg *mysql.Config) (*sql.DB, error) {
connector, err := mysql.NewConnector(cfg)
if err != nil {
return nil, err
}
db := openDbWithConnector(connector)
db.SetConnMaxLifetime(time.Minute * 3)
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
time.Sleep(time.Millisecond * 100)

//ping opens the connection
err = db.Ping()
if err != nil {
return nil, err
} else {
db.SetConnMaxLifetime(time.Minute * 3)
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
time.Sleep(time.Millisecond * 100)

//ping opens the connection
err = db.Ping()
if err != nil {
return nil, err
}
}
return db, err
}
Expand Down Expand Up @@ -820,31 +842,35 @@ func ExtractUriInfo(

// compositedUriInfo uri according to the format: mysql://root:111@127.0.0.1:6001
// if valid, return true and extracted info
// !!!NOTE!!!
// user and password does not have the special character ( ':' '@' )
func compositedUriInfo(uri string, uriPrefix string) (bool, UriInfo) {
if !uriHasPrefix(uri, uriPrefix) {
return false, UriInfo{}
}
//locate user password
rest := uri[len(uriPrefix):]
seps := strings.Split(rest, "@")
if len(seps) != 2 || len(seps[0]) == 0 || len(seps[1]) == 0 {
atIdx := strings.LastIndex(rest, "@")
if atIdx <= 0 || atIdx == len(rest)-1 {
return false, UriInfo{}
}
seps2 := strings.Split(seps[0], ":")
if len(seps2) < 2 {
userInfo := rest[:atIdx]
hostInfo := rest[atIdx+1:]

colonIdx := strings.LastIndex(userInfo, ":")
if colonIdx <= 0 || colonIdx == len(userInfo)-1 {
return false, UriInfo{}
}
userName, err := url.PathUnescape(userInfo[:colonIdx])
if err != nil || userName == "" {
return false, UriInfo{}
}
userName := strings.Join(seps2[0:len(seps2)-1], ":")
password := seps2[len(seps2)-1]
passwordStart := len(uriPrefix) + len(userName) + 1
passwordEnd := passwordStart + len(password)
if passwordEnd > len(uri) || password != uri[passwordStart:passwordEnd] {
passwordStart := len(uriPrefix) + colonIdx + 1
passwordEnd := len(uriPrefix) + atIdx
password, err := url.PathUnescape(uri[passwordStart:passwordEnd])
if err != nil {
return false, UriInfo{}
}

sep3 := strings.Split(seps[1], ":")
sep3 := strings.Split(hostInfo, ":")
if len(sep3) != 2 || len(sep3[0]) == 0 || len(sep3[1]) == 0 {
return false, UriInfo{}
}
Expand Down
Loading
Loading