This commit is contained in:
root
2019-04-22 02:59:20 +00:00
commit beccf3fe43
25440 changed files with 4054998 additions and 0 deletions

View File

@ -0,0 +1,35 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"auth.go",
"conn.go",
"req.go",
"resp.go",
"stmt.go",
],
importmap = "go-common/vendor/github.com/siddontang/go-mysql/client",
importpath = "github.com/siddontang/go-mysql/client",
visibility = ["//visibility:public"],
deps = [
"//vendor/github.com/juju/errors:go_default_library",
"//vendor/github.com/siddontang/go-mysql/mysql:go_default_library",
"//vendor/github.com/siddontang/go-mysql/packet:go_default_library",
"//vendor/github.com/siddontang/go/hack:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

174
vendor/github.com/siddontang/go-mysql/client/auth.go generated vendored Normal file
View File

@ -0,0 +1,174 @@
package client
import (
"bytes"
"crypto/tls"
"encoding/binary"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/packet"
)
func (c *Conn) readInitialHandshake() error {
data, err := c.ReadPacket()
if err != nil {
return errors.Trace(err)
}
if data[0] == ERR_HEADER {
return errors.New("read initial handshake error")
}
if data[0] < MinProtocolVersion {
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
}
//skip mysql version
//mysql version end with 0x00
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
//connection id length is 4
c.connectionID = uint32(binary.LittleEndian.Uint32(data[pos : pos+4]))
pos += 4
c.salt = []byte{}
c.salt = append(c.salt, data[pos:pos+8]...)
//skip filter
pos += 8 + 1
//capability lower 2 bytes
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
pos += 2
if len(data) > pos {
//skip server charset
//c.charset = data[pos]
pos += 1
c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
pos += 2
//skip auth data len or [00]
//skip reserved (all [00])
pos += 10 + 1
// The documentation is ambiguous about the length.
// The official Python library uses the fixed length 12
// mysql-proxy also use 12
// which is not documented but seems to work.
c.salt = append(c.salt, data[pos:pos+12]...)
}
return nil
}
func (c *Conn) writeAuthHandshake() error {
// Adjust client capability flags based on server support
capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG
// To enable TLS / SSL
if c.TLSConfig != nil {
capability |= CLIENT_PLUGIN_AUTH
capability |= CLIENT_SSL
}
capability &= c.capability
//packet length
//capbility 4
//max-packet size 4
//charset 1
//reserved all[0] 23
length := 4 + 4 + 1 + 23
//username
length += len(c.user) + 1
//we only support secure connection
auth := CalcPassword(c.salt, []byte(c.password))
length += 1 + len(auth)
if len(c.db) > 0 {
capability |= CLIENT_CONNECT_WITH_DB
length += len(c.db) + 1
}
// mysql_native_password + null-terminated
length += 21 + 1
c.capability = capability
data := make([]byte, length+4)
//capability [32 bit]
data[4] = byte(capability)
data[5] = byte(capability >> 8)
data[6] = byte(capability >> 16)
data[7] = byte(capability >> 24)
//MaxPacketSize [32 bit] (none)
//data[8] = 0x00
//data[9] = 0x00
//data[10] = 0x00
//data[11] = 0x00
//Charset [1 byte]
//use default collation id 33 here, is utf-8
data[12] = byte(DEFAULT_COLLATION_ID)
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if c.TLSConfig != nil {
// Send TLS / SSL request packet
if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}
// Switch to TLS
tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig)
if err := tlsConn.Handshake(); err != nil {
return err
}
currentSequence := c.Sequence
c.Conn = packet.NewConn(tlsConn)
c.Sequence = currentSequence
}
//Filler [23 bytes] (all 0x00)
pos := 13 + 23
//User [null terminated string]
if len(c.user) > 0 {
pos += copy(data[pos:], c.user)
}
data[pos] = 0x00
pos++
// auth [length encoded integer]
data[pos] = byte(len(auth))
pos += 1 + copy(data[pos+1:], auth)
// db [null terminated string]
if len(c.db) > 0 {
pos += copy(data[pos:], c.db)
data[pos] = 0x00
pos++
}
// Assume native client during response
pos += copy(data[pos:], "mysql_native_password")
data[pos] = 0x00
return c.WritePacket(data)
}

254
vendor/github.com/siddontang/go-mysql/client/conn.go generated vendored Normal file
View File

@ -0,0 +1,254 @@
package client
import (
"crypto/tls"
"fmt"
"net"
"strings"
"time"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/packet"
)
type Conn struct {
*packet.Conn
user string
password string
db string
TLSConfig *tls.Config
capability uint32
status uint16
charset string
salt []byte
connectionID uint32
}
func getNetProto(addr string) string {
proto := "tcp"
if strings.Contains(addr, "/") {
proto = "unix"
}
return proto
}
// Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
// Accepts a series of configuration functions as a variadic argument.
func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
proto := getNetProto(addr)
c := new(Conn)
var err error
conn, err := net.DialTimeout(proto, addr, 10*time.Second)
if err != nil {
return nil, errors.Trace(err)
}
c.Conn = packet.NewConn(conn)
c.user = user
c.password = password
c.db = dbName
//use default charset here, utf-8
c.charset = DEFAULT_CHARSET
// Apply configuration functions.
for i := range options {
options[i](c)
}
if err = c.handshake(); err != nil {
return nil, errors.Trace(err)
}
return c, nil
}
func (c *Conn) handshake() error {
var err error
if err = c.readInitialHandshake(); err != nil {
c.Close()
return errors.Trace(err)
}
if err := c.writeAuthHandshake(); err != nil {
c.Close()
return errors.Trace(err)
}
if _, err := c.readOK(); err != nil {
c.Close()
return errors.Trace(err)
}
return nil
}
func (c *Conn) Close() error {
return c.Conn.Close()
}
func (c *Conn) Ping() error {
if err := c.writeCommand(COM_PING); err != nil {
return errors.Trace(err)
}
if _, err := c.readOK(); err != nil {
return errors.Trace(err)
}
return nil
}
func (c *Conn) UseDB(dbName string) error {
if c.db == dbName {
return nil
}
if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil {
return errors.Trace(err)
}
if _, err := c.readOK(); err != nil {
return errors.Trace(err)
}
c.db = dbName
return nil
}
func (c *Conn) GetDB() string {
return c.db
}
func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
if len(args) == 0 {
return c.exec(command)
} else {
if s, err := c.Prepare(command); err != nil {
return nil, errors.Trace(err)
} else {
var r *Result
r, err = s.Execute(args...)
s.Close()
return r, err
}
}
}
func (c *Conn) Begin() error {
_, err := c.exec("BEGIN")
return errors.Trace(err)
}
func (c *Conn) Commit() error {
_, err := c.exec("COMMIT")
return errors.Trace(err)
}
func (c *Conn) Rollback() error {
_, err := c.exec("ROLLBACK")
return errors.Trace(err)
}
func (c *Conn) SetCharset(charset string) error {
if c.charset == charset {
return nil
}
if _, err := c.exec(fmt.Sprintf("SET NAMES %s", charset)); err != nil {
return errors.Trace(err)
} else {
c.charset = charset
return nil
}
}
func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
}
data, err := c.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
fs := make([]*Field, 0, 4)
var f *Field
if data[0] == ERR_HEADER {
return nil, c.handleErrorPacket(data)
} else {
for {
if data, err = c.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
// EOF Packet
if c.isEOFPacket(data) {
return fs, nil
}
if f, err = FieldData(data).Parse(); err != nil {
return nil, errors.Trace(err)
}
fs = append(fs, f)
}
}
return nil, fmt.Errorf("field list error")
}
func (c *Conn) SetAutoCommit() error {
if !c.IsAutoCommit() {
if _, err := c.exec("SET AUTOCOMMIT = 1"); err != nil {
return errors.Trace(err)
}
}
return nil
}
func (c *Conn) IsAutoCommit() bool {
return c.status&SERVER_STATUS_AUTOCOMMIT > 0
}
func (c *Conn) IsInTransaction() bool {
return c.status&SERVER_STATUS_IN_TRANS > 0
}
func (c *Conn) GetCharset() string {
return c.charset
}
func (c *Conn) GetConnectionID() uint32 {
return c.connectionID
}
func (c *Conn) HandleOKPacket(data []byte) *Result {
r, _ := c.handleOKPacket(data)
return r
}
func (c *Conn) HandleErrorPacket(data []byte) error {
return c.handleErrorPacket(data)
}
func (c *Conn) ReadOKPacket() (*Result, error) {
return c.readOK()
}
func (c *Conn) exec(query string) (*Result, error) {
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
return nil, errors.Trace(err)
}
return c.readResult(false)
}

72
vendor/github.com/siddontang/go-mysql/client/req.go generated vendored Normal file
View File

@ -0,0 +1,72 @@
package client
func (c *Conn) writeCommand(command byte) error {
c.ResetSequence()
return c.WritePacket([]byte{
0x01, //1 bytes long
0x00,
0x00,
0x00, //sequence
command,
})
}
func (c *Conn) writeCommandBuf(command byte, arg []byte) error {
c.ResetSequence()
length := len(arg) + 1
data := make([]byte, length+4)
data[4] = command
copy(data[5:], arg)
return c.WritePacket(data)
}
func (c *Conn) writeCommandStr(command byte, arg string) error {
c.ResetSequence()
length := len(arg) + 1
data := make([]byte, length+4)
data[4] = command
copy(data[5:], arg)
return c.WritePacket(data)
}
func (c *Conn) writeCommandUint32(command byte, arg uint32) error {
c.ResetSequence()
return c.WritePacket([]byte{
0x05, //5 bytes long
0x00,
0x00,
0x00, //sequence
command,
byte(arg),
byte(arg >> 8),
byte(arg >> 16),
byte(arg >> 24),
})
}
func (c *Conn) writeCommandStrStr(command byte, arg1 string, arg2 string) error {
c.ResetSequence()
data := make([]byte, 4, 6+len(arg1)+len(arg2))
data = append(data, command)
data = append(data, arg1...)
data = append(data, 0)
data = append(data, arg2...)
return c.WritePacket(data)
}

218
vendor/github.com/siddontang/go-mysql/client/resp.go generated vendored Normal file
View File

@ -0,0 +1,218 @@
package client
import (
"encoding/binary"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go/hack"
)
func (c *Conn) readUntilEOF() (err error) {
var data []byte
for {
data, err = c.ReadPacket()
if err != nil {
return
}
// EOF Packet
if c.isEOFPacket(data) {
return
}
}
return
}
func (c *Conn) isEOFPacket(data []byte) bool {
return data[0] == EOF_HEADER && len(data) <= 5
}
func (c *Conn) handleOKPacket(data []byte) (*Result, error) {
var n int
var pos int = 1
r := new(Result)
r.AffectedRows, _, n = LengthEncodedInt(data[pos:])
pos += n
r.InsertId, _, n = LengthEncodedInt(data[pos:])
pos += n
if c.capability&CLIENT_PROTOCOL_41 > 0 {
r.Status = binary.LittleEndian.Uint16(data[pos:])
c.status = r.Status
pos += 2
//todo:strict_mode, check warnings as error
//Warnings := binary.LittleEndian.Uint16(data[pos:])
//pos += 2
} else if c.capability&CLIENT_TRANSACTIONS > 0 {
r.Status = binary.LittleEndian.Uint16(data[pos:])
c.status = r.Status
pos += 2
}
//new ok package will check CLIENT_SESSION_TRACK too, but I don't support it now.
//skip info
return r, nil
}
func (c *Conn) handleErrorPacket(data []byte) error {
e := new(MyError)
var pos int = 1
e.Code = binary.LittleEndian.Uint16(data[pos:])
pos += 2
if c.capability&CLIENT_PROTOCOL_41 > 0 {
//skip '#'
pos++
e.State = hack.String(data[pos : pos+5])
pos += 5
}
e.Message = hack.String(data[pos:])
return e
}
func (c *Conn) readOK() (*Result, error) {
data, err := c.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
if data[0] == OK_HEADER {
return c.handleOKPacket(data)
} else if data[0] == ERR_HEADER {
return nil, c.handleErrorPacket(data)
} else {
return nil, errors.New("invalid ok packet")
}
}
func (c *Conn) readResult(binary bool) (*Result, error) {
data, err := c.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
if data[0] == OK_HEADER {
return c.handleOKPacket(data)
} else if data[0] == ERR_HEADER {
return nil, c.handleErrorPacket(data)
} else if data[0] == LocalInFile_HEADER {
return nil, ErrMalformPacket
}
return c.readResultset(data, binary)
}
func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
result := &Result{
Status: 0,
InsertId: 0,
AffectedRows: 0,
Resultset: &Resultset{},
}
// column count
count, _, n := LengthEncodedInt(data)
if n-len(data) != 0 {
return nil, ErrMalformPacket
}
result.Fields = make([]*Field, count)
result.FieldNames = make(map[string]int, count)
if err := c.readResultColumns(result); err != nil {
return nil, errors.Trace(err)
}
if err := c.readResultRows(result, binary); err != nil {
return nil, errors.Trace(err)
}
return result, nil
}
func (c *Conn) readResultColumns(result *Result) (err error) {
var i int = 0
var data []byte
for {
data, err = c.ReadPacket()
if err != nil {
return
}
// EOF Packet
if c.isEOFPacket(data) {
if c.capability&CLIENT_PROTOCOL_41 > 0 {
//result.Warnings = binary.LittleEndian.Uint16(data[1:])
//todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}
if i != len(result.Fields) {
err = ErrMalformPacket
}
return
}
result.Fields[i], err = FieldData(data).Parse()
if err != nil {
return
}
result.FieldNames[hack.String(result.Fields[i].Name)] = i
i++
}
}
func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) {
var data []byte
for {
data, err = c.ReadPacket()
if err != nil {
return
}
// EOF Packet
if c.isEOFPacket(data) {
if c.capability&CLIENT_PROTOCOL_41 > 0 {
//result.Warnings = binary.LittleEndian.Uint16(data[1:])
//todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}
break
}
result.RowDatas = append(result.RowDatas, data)
}
result.Values = make([][]interface{}, len(result.RowDatas))
for i := range result.Values {
result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary)
if err != nil {
return errors.Trace(err)
}
}
return nil
}

215
vendor/github.com/siddontang/go-mysql/client/stmt.go generated vendored Normal file
View File

@ -0,0 +1,215 @@
package client
import (
"encoding/binary"
"fmt"
"math"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
)
type Stmt struct {
conn *Conn
id uint32
query string
params int
columns int
}
func (s *Stmt) ParamNum() int {
return s.params
}
func (s *Stmt) ColumnNum() int {
return s.columns
}
func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
if err := s.write(args...); err != nil {
return nil, errors.Trace(err)
}
return s.conn.readResult(true)
}
func (s *Stmt) Close() error {
if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil {
return errors.Trace(err)
}
return nil
}
func (s *Stmt) write(args ...interface{}) error {
paramsNum := s.params
if len(args) != paramsNum {
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
}
paramTypes := make([]byte, paramsNum<<1)
paramValues := make([][]byte, paramsNum)
//NULL-bitmap, length: (num-params+7)
nullBitmap := make([]byte, (paramsNum+7)>>3)
var length int = int(1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1))
var newParamBoundFlag byte = 0
for i := range args {
if args[i] == nil {
nullBitmap[i/8] |= (1 << (uint(i) % 8))
paramTypes[i<<1] = MYSQL_TYPE_NULL
continue
}
newParamBoundFlag = 1
switch v := args[i].(type) {
case int8:
paramTypes[i<<1] = MYSQL_TYPE_TINY
paramValues[i] = []byte{byte(v)}
case int16:
paramTypes[i<<1] = MYSQL_TYPE_SHORT
paramValues[i] = Uint16ToBytes(uint16(v))
case int32:
paramTypes[i<<1] = MYSQL_TYPE_LONG
paramValues[i] = Uint32ToBytes(uint32(v))
case int:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramValues[i] = Uint64ToBytes(uint64(v))
case int64:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramValues[i] = Uint64ToBytes(uint64(v))
case uint8:
paramTypes[i<<1] = MYSQL_TYPE_TINY
paramTypes[(i<<1)+1] = 0x80
paramValues[i] = []byte{v}
case uint16:
paramTypes[i<<1] = MYSQL_TYPE_SHORT
paramTypes[(i<<1)+1] = 0x80
paramValues[i] = Uint16ToBytes(uint16(v))
case uint32:
paramTypes[i<<1] = MYSQL_TYPE_LONG
paramTypes[(i<<1)+1] = 0x80
paramValues[i] = Uint32ToBytes(uint32(v))
case uint:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramTypes[(i<<1)+1] = 0x80
paramValues[i] = Uint64ToBytes(uint64(v))
case uint64:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramTypes[(i<<1)+1] = 0x80
paramValues[i] = Uint64ToBytes(uint64(v))
case bool:
paramTypes[i<<1] = MYSQL_TYPE_TINY
if v {
paramValues[i] = []byte{1}
} else {
paramValues[i] = []byte{0}
}
case float32:
paramTypes[i<<1] = MYSQL_TYPE_FLOAT
paramValues[i] = Uint32ToBytes(math.Float32bits(v))
case float64:
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
paramValues[i] = Uint64ToBytes(math.Float64bits(v))
case string:
paramTypes[i<<1] = MYSQL_TYPE_STRING
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
case []byte:
paramTypes[i<<1] = MYSQL_TYPE_STRING
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
default:
return fmt.Errorf("invalid argument type %T", args[i])
}
length += len(paramValues[i])
}
data := make([]byte, 4, 4+length)
data = append(data, COM_STMT_EXECUTE)
data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
//flag: CURSOR_TYPE_NO_CURSOR
data = append(data, 0x00)
//iteration-count, always 1
data = append(data, 1, 0, 0, 0)
if s.params > 0 {
data = append(data, nullBitmap...)
//new-params-bound-flag
data = append(data, newParamBoundFlag)
if newParamBoundFlag == 1 {
//type of each parameter, length: num-params * 2
data = append(data, paramTypes...)
//value of each parameter
for _, v := range paramValues {
data = append(data, v...)
}
}
}
s.conn.ResetSequence()
return s.conn.WritePacket(data)
}
func (c *Conn) Prepare(query string) (*Stmt, error) {
if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil {
return nil, errors.Trace(err)
}
data, err := c.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
if data[0] == ERR_HEADER {
return nil, c.handleErrorPacket(data)
} else if data[0] != OK_HEADER {
return nil, ErrMalformPacket
}
s := new(Stmt)
s.conn = c
pos := 1
//for statement id
s.id = binary.LittleEndian.Uint32(data[pos:])
pos += 4
//number columns
s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
pos += 2
//number params
s.params = int(binary.LittleEndian.Uint16(data[pos:]))
pos += 2
//warnings
//warnings = binary.LittleEndian.Uint16(data[pos:])
if s.params > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
}
}
if s.columns > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
}
}
return s, nil
}