init
This commit is contained in:
35
vendor/github.com/siddontang/go-mysql/client/BUILD.bazel
generated
vendored
Normal file
35
vendor/github.com/siddontang/go-mysql/client/BUILD.bazel
generated
vendored
Normal file
@ -0,0 +1,35 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"auth.go",
|
||||
"conn.go",
|
||||
"req.go",
|
||||
"resp.go",
|
||||
"stmt.go",
|
||||
],
|
||||
importmap = "go-common/vendor/github.com/siddontang/go-mysql/client",
|
||||
importpath = "github.com/siddontang/go-mysql/client",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//vendor/github.com/juju/errors:go_default_library",
|
||||
"//vendor/github.com/siddontang/go-mysql/mysql:go_default_library",
|
||||
"//vendor/github.com/siddontang/go-mysql/packet:go_default_library",
|
||||
"//vendor/github.com/siddontang/go/hack:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "package-srcs",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all-srcs",
|
||||
srcs = [":package-srcs"],
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
174
vendor/github.com/siddontang/go-mysql/client/auth.go
generated
vendored
Normal file
174
vendor/github.com/siddontang/go-mysql/client/auth.go
generated
vendored
Normal file
@ -0,0 +1,174 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/juju/errors"
|
||||
. "github.com/siddontang/go-mysql/mysql"
|
||||
"github.com/siddontang/go-mysql/packet"
|
||||
)
|
||||
|
||||
func (c *Conn) readInitialHandshake() error {
|
||||
data, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
if data[0] == ERR_HEADER {
|
||||
return errors.New("read initial handshake error")
|
||||
}
|
||||
|
||||
if data[0] < MinProtocolVersion {
|
||||
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
|
||||
}
|
||||
|
||||
//skip mysql version
|
||||
//mysql version end with 0x00
|
||||
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
|
||||
|
||||
//connection id length is 4
|
||||
c.connectionID = uint32(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
||||
pos += 4
|
||||
|
||||
c.salt = []byte{}
|
||||
c.salt = append(c.salt, data[pos:pos+8]...)
|
||||
|
||||
//skip filter
|
||||
pos += 8 + 1
|
||||
|
||||
//capability lower 2 bytes
|
||||
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||
|
||||
pos += 2
|
||||
|
||||
if len(data) > pos {
|
||||
//skip server charset
|
||||
//c.charset = data[pos]
|
||||
pos += 1
|
||||
|
||||
c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
|
||||
pos += 2
|
||||
|
||||
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
|
||||
|
||||
pos += 2
|
||||
|
||||
//skip auth data len or [00]
|
||||
//skip reserved (all [00])
|
||||
pos += 10 + 1
|
||||
|
||||
// The documentation is ambiguous about the length.
|
||||
// The official Python library uses the fixed length 12
|
||||
// mysql-proxy also use 12
|
||||
// which is not documented but seems to work.
|
||||
c.salt = append(c.salt, data[pos:pos+12]...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) writeAuthHandshake() error {
|
||||
// Adjust client capability flags based on server support
|
||||
capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
|
||||
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG
|
||||
|
||||
// To enable TLS / SSL
|
||||
if c.TLSConfig != nil {
|
||||
capability |= CLIENT_PLUGIN_AUTH
|
||||
capability |= CLIENT_SSL
|
||||
}
|
||||
|
||||
capability &= c.capability
|
||||
|
||||
//packet length
|
||||
//capbility 4
|
||||
//max-packet size 4
|
||||
//charset 1
|
||||
//reserved all[0] 23
|
||||
length := 4 + 4 + 1 + 23
|
||||
|
||||
//username
|
||||
length += len(c.user) + 1
|
||||
|
||||
//we only support secure connection
|
||||
auth := CalcPassword(c.salt, []byte(c.password))
|
||||
|
||||
length += 1 + len(auth)
|
||||
|
||||
if len(c.db) > 0 {
|
||||
capability |= CLIENT_CONNECT_WITH_DB
|
||||
|
||||
length += len(c.db) + 1
|
||||
}
|
||||
|
||||
// mysql_native_password + null-terminated
|
||||
length += 21 + 1
|
||||
|
||||
c.capability = capability
|
||||
|
||||
data := make([]byte, length+4)
|
||||
|
||||
//capability [32 bit]
|
||||
data[4] = byte(capability)
|
||||
data[5] = byte(capability >> 8)
|
||||
data[6] = byte(capability >> 16)
|
||||
data[7] = byte(capability >> 24)
|
||||
|
||||
//MaxPacketSize [32 bit] (none)
|
||||
//data[8] = 0x00
|
||||
//data[9] = 0x00
|
||||
//data[10] = 0x00
|
||||
//data[11] = 0x00
|
||||
|
||||
//Charset [1 byte]
|
||||
//use default collation id 33 here, is utf-8
|
||||
data[12] = byte(DEFAULT_COLLATION_ID)
|
||||
|
||||
// SSL Connection Request Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||||
if c.TLSConfig != nil {
|
||||
// Send TLS / SSL request packet
|
||||
if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Switch to TLS
|
||||
tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
currentSequence := c.Sequence
|
||||
c.Conn = packet.NewConn(tlsConn)
|
||||
c.Sequence = currentSequence
|
||||
}
|
||||
|
||||
//Filler [23 bytes] (all 0x00)
|
||||
pos := 13 + 23
|
||||
|
||||
//User [null terminated string]
|
||||
if len(c.user) > 0 {
|
||||
pos += copy(data[pos:], c.user)
|
||||
}
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
|
||||
// auth [length encoded integer]
|
||||
data[pos] = byte(len(auth))
|
||||
pos += 1 + copy(data[pos+1:], auth)
|
||||
|
||||
// db [null terminated string]
|
||||
if len(c.db) > 0 {
|
||||
pos += copy(data[pos:], c.db)
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
}
|
||||
|
||||
// Assume native client during response
|
||||
pos += copy(data[pos:], "mysql_native_password")
|
||||
data[pos] = 0x00
|
||||
|
||||
return c.WritePacket(data)
|
||||
}
|
254
vendor/github.com/siddontang/go-mysql/client/conn.go
generated
vendored
Normal file
254
vendor/github.com/siddontang/go-mysql/client/conn.go
generated
vendored
Normal file
@ -0,0 +1,254 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juju/errors"
|
||||
. "github.com/siddontang/go-mysql/mysql"
|
||||
"github.com/siddontang/go-mysql/packet"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*packet.Conn
|
||||
|
||||
user string
|
||||
password string
|
||||
db string
|
||||
TLSConfig *tls.Config
|
||||
|
||||
capability uint32
|
||||
|
||||
status uint16
|
||||
|
||||
charset string
|
||||
|
||||
salt []byte
|
||||
|
||||
connectionID uint32
|
||||
}
|
||||
|
||||
func getNetProto(addr string) string {
|
||||
proto := "tcp"
|
||||
if strings.Contains(addr, "/") {
|
||||
proto = "unix"
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
// Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
|
||||
// Accepts a series of configuration functions as a variadic argument.
|
||||
func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
|
||||
proto := getNetProto(addr)
|
||||
|
||||
c := new(Conn)
|
||||
|
||||
var err error
|
||||
conn, err := net.DialTimeout(proto, addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
c.Conn = packet.NewConn(conn)
|
||||
c.user = user
|
||||
c.password = password
|
||||
c.db = dbName
|
||||
|
||||
//use default charset here, utf-8
|
||||
c.charset = DEFAULT_CHARSET
|
||||
|
||||
// Apply configuration functions.
|
||||
for i := range options {
|
||||
options[i](c)
|
||||
}
|
||||
|
||||
if err = c.handshake(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) handshake() error {
|
||||
var err error
|
||||
if err = c.readInitialHandshake(); err != nil {
|
||||
c.Close()
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
if err := c.writeAuthHandshake(); err != nil {
|
||||
c.Close()
|
||||
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
if _, err := c.readOK(); err != nil {
|
||||
c.Close()
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) Ping() error {
|
||||
if err := c.writeCommand(COM_PING); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
if _, err := c.readOK(); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) UseDB(dbName string) error {
|
||||
if c.db == dbName {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
if _, err := c.readOK(); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
c.db = dbName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) GetDB() string {
|
||||
return c.db
|
||||
}
|
||||
|
||||
func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
|
||||
if len(args) == 0 {
|
||||
return c.exec(command)
|
||||
} else {
|
||||
if s, err := c.Prepare(command); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
} else {
|
||||
var r *Result
|
||||
r, err = s.Execute(args...)
|
||||
s.Close()
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Begin() error {
|
||||
_, err := c.exec("BEGIN")
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
func (c *Conn) Commit() error {
|
||||
_, err := c.exec("COMMIT")
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
func (c *Conn) Rollback() error {
|
||||
_, err := c.exec("ROLLBACK")
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
func (c *Conn) SetCharset(charset string) error {
|
||||
if c.charset == charset {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := c.exec(fmt.Sprintf("SET NAMES %s", charset)); err != nil {
|
||||
return errors.Trace(err)
|
||||
} else {
|
||||
c.charset = charset
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
|
||||
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
data, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
fs := make([]*Field, 0, 4)
|
||||
var f *Field
|
||||
if data[0] == ERR_HEADER {
|
||||
return nil, c.handleErrorPacket(data)
|
||||
} else {
|
||||
for {
|
||||
if data, err = c.ReadPacket(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// EOF Packet
|
||||
if c.isEOFPacket(data) {
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
if f, err = FieldData(data).Parse(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
fs = append(fs, f)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("field list error")
|
||||
}
|
||||
|
||||
func (c *Conn) SetAutoCommit() error {
|
||||
if !c.IsAutoCommit() {
|
||||
if _, err := c.exec("SET AUTOCOMMIT = 1"); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) IsAutoCommit() bool {
|
||||
return c.status&SERVER_STATUS_AUTOCOMMIT > 0
|
||||
}
|
||||
|
||||
func (c *Conn) IsInTransaction() bool {
|
||||
return c.status&SERVER_STATUS_IN_TRANS > 0
|
||||
}
|
||||
|
||||
func (c *Conn) GetCharset() string {
|
||||
return c.charset
|
||||
}
|
||||
|
||||
func (c *Conn) GetConnectionID() uint32 {
|
||||
return c.connectionID
|
||||
}
|
||||
|
||||
func (c *Conn) HandleOKPacket(data []byte) *Result {
|
||||
r, _ := c.handleOKPacket(data)
|
||||
return r
|
||||
}
|
||||
|
||||
func (c *Conn) HandleErrorPacket(data []byte) error {
|
||||
return c.handleErrorPacket(data)
|
||||
}
|
||||
|
||||
func (c *Conn) ReadOKPacket() (*Result, error) {
|
||||
return c.readOK()
|
||||
}
|
||||
|
||||
func (c *Conn) exec(query string) (*Result, error) {
|
||||
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
return c.readResult(false)
|
||||
}
|
72
vendor/github.com/siddontang/go-mysql/client/req.go
generated
vendored
Normal file
72
vendor/github.com/siddontang/go-mysql/client/req.go
generated
vendored
Normal file
@ -0,0 +1,72 @@
|
||||
package client
|
||||
|
||||
func (c *Conn) writeCommand(command byte) error {
|
||||
c.ResetSequence()
|
||||
|
||||
return c.WritePacket([]byte{
|
||||
0x01, //1 bytes long
|
||||
0x00,
|
||||
0x00,
|
||||
0x00, //sequence
|
||||
command,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) writeCommandBuf(command byte, arg []byte) error {
|
||||
c.ResetSequence()
|
||||
|
||||
length := len(arg) + 1
|
||||
|
||||
data := make([]byte, length+4)
|
||||
|
||||
data[4] = command
|
||||
|
||||
copy(data[5:], arg)
|
||||
|
||||
return c.WritePacket(data)
|
||||
}
|
||||
|
||||
func (c *Conn) writeCommandStr(command byte, arg string) error {
|
||||
c.ResetSequence()
|
||||
|
||||
length := len(arg) + 1
|
||||
|
||||
data := make([]byte, length+4)
|
||||
|
||||
data[4] = command
|
||||
|
||||
copy(data[5:], arg)
|
||||
|
||||
return c.WritePacket(data)
|
||||
}
|
||||
|
||||
func (c *Conn) writeCommandUint32(command byte, arg uint32) error {
|
||||
c.ResetSequence()
|
||||
|
||||
return c.WritePacket([]byte{
|
||||
0x05, //5 bytes long
|
||||
0x00,
|
||||
0x00,
|
||||
0x00, //sequence
|
||||
|
||||
command,
|
||||
|
||||
byte(arg),
|
||||
byte(arg >> 8),
|
||||
byte(arg >> 16),
|
||||
byte(arg >> 24),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) writeCommandStrStr(command byte, arg1 string, arg2 string) error {
|
||||
c.ResetSequence()
|
||||
|
||||
data := make([]byte, 4, 6+len(arg1)+len(arg2))
|
||||
|
||||
data = append(data, command)
|
||||
data = append(data, arg1...)
|
||||
data = append(data, 0)
|
||||
data = append(data, arg2...)
|
||||
|
||||
return c.WritePacket(data)
|
||||
}
|
218
vendor/github.com/siddontang/go-mysql/client/resp.go
generated
vendored
Normal file
218
vendor/github.com/siddontang/go-mysql/client/resp.go
generated
vendored
Normal file
@ -0,0 +1,218 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/juju/errors"
|
||||
. "github.com/siddontang/go-mysql/mysql"
|
||||
"github.com/siddontang/go/hack"
|
||||
)
|
||||
|
||||
func (c *Conn) readUntilEOF() (err error) {
|
||||
var data []byte
|
||||
|
||||
for {
|
||||
data, err = c.ReadPacket()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// EOF Packet
|
||||
if c.isEOFPacket(data) {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) isEOFPacket(data []byte) bool {
|
||||
return data[0] == EOF_HEADER && len(data) <= 5
|
||||
}
|
||||
|
||||
func (c *Conn) handleOKPacket(data []byte) (*Result, error) {
|
||||
var n int
|
||||
var pos int = 1
|
||||
|
||||
r := new(Result)
|
||||
|
||||
r.AffectedRows, _, n = LengthEncodedInt(data[pos:])
|
||||
pos += n
|
||||
r.InsertId, _, n = LengthEncodedInt(data[pos:])
|
||||
pos += n
|
||||
|
||||
if c.capability&CLIENT_PROTOCOL_41 > 0 {
|
||||
r.Status = binary.LittleEndian.Uint16(data[pos:])
|
||||
c.status = r.Status
|
||||
pos += 2
|
||||
|
||||
//todo:strict_mode, check warnings as error
|
||||
//Warnings := binary.LittleEndian.Uint16(data[pos:])
|
||||
//pos += 2
|
||||
} else if c.capability&CLIENT_TRANSACTIONS > 0 {
|
||||
r.Status = binary.LittleEndian.Uint16(data[pos:])
|
||||
c.status = r.Status
|
||||
pos += 2
|
||||
}
|
||||
|
||||
//new ok package will check CLIENT_SESSION_TRACK too, but I don't support it now.
|
||||
|
||||
//skip info
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleErrorPacket(data []byte) error {
|
||||
e := new(MyError)
|
||||
|
||||
var pos int = 1
|
||||
|
||||
e.Code = binary.LittleEndian.Uint16(data[pos:])
|
||||
pos += 2
|
||||
|
||||
if c.capability&CLIENT_PROTOCOL_41 > 0 {
|
||||
//skip '#'
|
||||
pos++
|
||||
e.State = hack.String(data[pos : pos+5])
|
||||
pos += 5
|
||||
}
|
||||
|
||||
e.Message = hack.String(data[pos:])
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
func (c *Conn) readOK() (*Result, error) {
|
||||
data, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if data[0] == OK_HEADER {
|
||||
return c.handleOKPacket(data)
|
||||
} else if data[0] == ERR_HEADER {
|
||||
return nil, c.handleErrorPacket(data)
|
||||
} else {
|
||||
return nil, errors.New("invalid ok packet")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) readResult(binary bool) (*Result, error) {
|
||||
data, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if data[0] == OK_HEADER {
|
||||
return c.handleOKPacket(data)
|
||||
} else if data[0] == ERR_HEADER {
|
||||
return nil, c.handleErrorPacket(data)
|
||||
} else if data[0] == LocalInFile_HEADER {
|
||||
return nil, ErrMalformPacket
|
||||
}
|
||||
|
||||
return c.readResultset(data, binary)
|
||||
}
|
||||
|
||||
func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
|
||||
result := &Result{
|
||||
Status: 0,
|
||||
InsertId: 0,
|
||||
AffectedRows: 0,
|
||||
|
||||
Resultset: &Resultset{},
|
||||
}
|
||||
|
||||
// column count
|
||||
count, _, n := LengthEncodedInt(data)
|
||||
|
||||
if n-len(data) != 0 {
|
||||
return nil, ErrMalformPacket
|
||||
}
|
||||
|
||||
result.Fields = make([]*Field, count)
|
||||
result.FieldNames = make(map[string]int, count)
|
||||
|
||||
if err := c.readResultColumns(result); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err := c.readResultRows(result, binary); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *Conn) readResultColumns(result *Result) (err error) {
|
||||
var i int = 0
|
||||
var data []byte
|
||||
|
||||
for {
|
||||
data, err = c.ReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// EOF Packet
|
||||
if c.isEOFPacket(data) {
|
||||
if c.capability&CLIENT_PROTOCOL_41 > 0 {
|
||||
//result.Warnings = binary.LittleEndian.Uint16(data[1:])
|
||||
//todo add strict_mode, warning will be treat as error
|
||||
result.Status = binary.LittleEndian.Uint16(data[3:])
|
||||
c.status = result.Status
|
||||
}
|
||||
|
||||
if i != len(result.Fields) {
|
||||
err = ErrMalformPacket
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result.Fields[i], err = FieldData(data).Parse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result.FieldNames[hack.String(result.Fields[i].Name)] = i
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) {
|
||||
var data []byte
|
||||
|
||||
for {
|
||||
data, err = c.ReadPacket()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// EOF Packet
|
||||
if c.isEOFPacket(data) {
|
||||
if c.capability&CLIENT_PROTOCOL_41 > 0 {
|
||||
//result.Warnings = binary.LittleEndian.Uint16(data[1:])
|
||||
//todo add strict_mode, warning will be treat as error
|
||||
result.Status = binary.LittleEndian.Uint16(data[3:])
|
||||
c.status = result.Status
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
result.RowDatas = append(result.RowDatas, data)
|
||||
}
|
||||
|
||||
result.Values = make([][]interface{}, len(result.RowDatas))
|
||||
|
||||
for i := range result.Values {
|
||||
result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary)
|
||||
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
215
vendor/github.com/siddontang/go-mysql/client/stmt.go
generated
vendored
Normal file
215
vendor/github.com/siddontang/go-mysql/client/stmt.go
generated
vendored
Normal file
@ -0,0 +1,215 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/juju/errors"
|
||||
. "github.com/siddontang/go-mysql/mysql"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
conn *Conn
|
||||
id uint32
|
||||
query string
|
||||
|
||||
params int
|
||||
columns int
|
||||
}
|
||||
|
||||
func (s *Stmt) ParamNum() int {
|
||||
return s.params
|
||||
}
|
||||
|
||||
func (s *Stmt) ColumnNum() int {
|
||||
return s.columns
|
||||
}
|
||||
|
||||
func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
|
||||
if err := s.write(args...); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return s.conn.readResult(true)
|
||||
}
|
||||
|
||||
func (s *Stmt) Close() error {
|
||||
if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stmt) write(args ...interface{}) error {
|
||||
paramsNum := s.params
|
||||
|
||||
if len(args) != paramsNum {
|
||||
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
|
||||
}
|
||||
|
||||
paramTypes := make([]byte, paramsNum<<1)
|
||||
paramValues := make([][]byte, paramsNum)
|
||||
|
||||
//NULL-bitmap, length: (num-params+7)
|
||||
nullBitmap := make([]byte, (paramsNum+7)>>3)
|
||||
|
||||
var length int = int(1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1))
|
||||
|
||||
var newParamBoundFlag byte = 0
|
||||
|
||||
for i := range args {
|
||||
if args[i] == nil {
|
||||
nullBitmap[i/8] |= (1 << (uint(i) % 8))
|
||||
paramTypes[i<<1] = MYSQL_TYPE_NULL
|
||||
continue
|
||||
}
|
||||
|
||||
newParamBoundFlag = 1
|
||||
|
||||
switch v := args[i].(type) {
|
||||
case int8:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_TINY
|
||||
paramValues[i] = []byte{byte(v)}
|
||||
case int16:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_SHORT
|
||||
paramValues[i] = Uint16ToBytes(uint16(v))
|
||||
case int32:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_LONG
|
||||
paramValues[i] = Uint32ToBytes(uint32(v))
|
||||
case int:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
|
||||
paramValues[i] = Uint64ToBytes(uint64(v))
|
||||
case int64:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
|
||||
paramValues[i] = Uint64ToBytes(uint64(v))
|
||||
case uint8:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_TINY
|
||||
paramTypes[(i<<1)+1] = 0x80
|
||||
paramValues[i] = []byte{v}
|
||||
case uint16:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_SHORT
|
||||
paramTypes[(i<<1)+1] = 0x80
|
||||
paramValues[i] = Uint16ToBytes(uint16(v))
|
||||
case uint32:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_LONG
|
||||
paramTypes[(i<<1)+1] = 0x80
|
||||
paramValues[i] = Uint32ToBytes(uint32(v))
|
||||
case uint:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
|
||||
paramTypes[(i<<1)+1] = 0x80
|
||||
paramValues[i] = Uint64ToBytes(uint64(v))
|
||||
case uint64:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
|
||||
paramTypes[(i<<1)+1] = 0x80
|
||||
paramValues[i] = Uint64ToBytes(uint64(v))
|
||||
case bool:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_TINY
|
||||
if v {
|
||||
paramValues[i] = []byte{1}
|
||||
} else {
|
||||
paramValues[i] = []byte{0}
|
||||
|
||||
}
|
||||
case float32:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_FLOAT
|
||||
paramValues[i] = Uint32ToBytes(math.Float32bits(v))
|
||||
case float64:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
|
||||
paramValues[i] = Uint64ToBytes(math.Float64bits(v))
|
||||
case string:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_STRING
|
||||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
|
||||
case []byte:
|
||||
paramTypes[i<<1] = MYSQL_TYPE_STRING
|
||||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
|
||||
default:
|
||||
return fmt.Errorf("invalid argument type %T", args[i])
|
||||
}
|
||||
|
||||
length += len(paramValues[i])
|
||||
}
|
||||
|
||||
data := make([]byte, 4, 4+length)
|
||||
|
||||
data = append(data, COM_STMT_EXECUTE)
|
||||
data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
|
||||
|
||||
//flag: CURSOR_TYPE_NO_CURSOR
|
||||
data = append(data, 0x00)
|
||||
|
||||
//iteration-count, always 1
|
||||
data = append(data, 1, 0, 0, 0)
|
||||
|
||||
if s.params > 0 {
|
||||
data = append(data, nullBitmap...)
|
||||
|
||||
//new-params-bound-flag
|
||||
data = append(data, newParamBoundFlag)
|
||||
|
||||
if newParamBoundFlag == 1 {
|
||||
//type of each parameter, length: num-params * 2
|
||||
data = append(data, paramTypes...)
|
||||
|
||||
//value of each parameter
|
||||
for _, v := range paramValues {
|
||||
data = append(data, v...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.conn.ResetSequence()
|
||||
|
||||
return s.conn.WritePacket(data)
|
||||
}
|
||||
|
||||
func (c *Conn) Prepare(query string) (*Stmt, error) {
|
||||
if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
data, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if data[0] == ERR_HEADER {
|
||||
return nil, c.handleErrorPacket(data)
|
||||
} else if data[0] != OK_HEADER {
|
||||
return nil, ErrMalformPacket
|
||||
}
|
||||
|
||||
s := new(Stmt)
|
||||
s.conn = c
|
||||
|
||||
pos := 1
|
||||
|
||||
//for statement id
|
||||
s.id = binary.LittleEndian.Uint32(data[pos:])
|
||||
pos += 4
|
||||
|
||||
//number columns
|
||||
s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
|
||||
pos += 2
|
||||
|
||||
//number params
|
||||
s.params = int(binary.LittleEndian.Uint16(data[pos:]))
|
||||
pos += 2
|
||||
|
||||
//warnings
|
||||
//warnings = binary.LittleEndian.Uint16(data[pos:])
|
||||
|
||||
if s.params > 0 {
|
||||
if err := s.conn.readUntilEOF(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.columns > 0 {
|
||||
if err := s.conn.readUntilEOF(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
Reference in New Issue
Block a user