Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • master
  • v1.0.0
  • v1.0.1
  • v1.1.0
  • v1.10.0
  • v1.10.1
  • v1.10.2
  • v1.11.0
  • v1.12.0
  • v1.12.1
  • v1.12.2
  • v1.12.3
  • v1.12.4
  • v1.12.5
  • v1.12.6
  • v1.12.7
  • v1.12.8
  • v1.13.0
  • v1.13.1
  • v1.13.2
  • v1.14.0
  • v1.15.0
  • v1.15.1
  • v1.15.10
  • v1.15.11
  • v1.15.12
  • v1.15.13
  • v1.15.14
  • v1.15.15
  • v1.15.16
  • v1.15.17
  • v1.15.2
  • v1.15.3
  • v1.15.4
  • v1.15.5
  • v1.15.6
  • v1.15.7
  • v1.15.8
  • v1.15.9
  • v1.16.0
  • v1.16.1
  • v1.17.0
  • v1.18.0
  • v1.18.1
  • v1.18.2
  • v1.19.0
  • v1.19.1
  • v1.19.2
  • v1.19.3
  • v1.19.4
  • v1.2.0
  • v1.20.0
  • v1.20.1
  • v1.20.2
  • v1.20.3
  • v1.21.0
  • v1.21.1
  • v1.22.0
  • v1.23.0
  • v1.23.1
  • v1.23.2
  • v1.3.0
  • v1.3.1
  • v1.3.2
  • v1.4.0
  • v1.5.0
  • v1.5.1
  • v1.6.0
  • v1.6.1
  • v1.7.0
  • v1.7.1
  • v1.7.2
  • v1.7.3
  • v1.8.0
  • v1.8.1
  • v1.9.0
76 results

Target

Select target project
  • oss/libraries/go/services/job-queues
1 result
Select Git revision
  • master
  • v1.0.0
  • v1.0.1
  • v1.1.0
  • v1.10.0
  • v1.10.1
  • v1.10.2
  • v1.11.0
  • v1.12.0
  • v1.12.1
  • v1.12.2
  • v1.12.3
  • v1.12.4
  • v1.12.5
  • v1.12.6
  • v1.12.7
  • v1.12.8
  • v1.13.0
  • v1.13.1
  • v1.13.2
  • v1.14.0
  • v1.15.0
  • v1.15.1
  • v1.15.10
  • v1.15.11
  • v1.15.12
  • v1.15.13
  • v1.15.14
  • v1.15.15
  • v1.15.16
  • v1.15.17
  • v1.15.2
  • v1.15.3
  • v1.15.4
  • v1.15.5
  • v1.15.6
  • v1.15.7
  • v1.15.8
  • v1.15.9
  • v1.16.0
  • v1.16.1
  • v1.17.0
  • v1.18.0
  • v1.18.1
  • v1.18.2
  • v1.19.0
  • v1.19.1
  • v1.19.2
  • v1.19.3
  • v1.19.4
  • v1.2.0
  • v1.20.0
  • v1.20.1
  • v1.20.2
  • v1.20.3
  • v1.21.0
  • v1.21.1
  • v1.22.0
  • v1.23.0
  • v1.23.1
  • v1.23.2
  • v1.3.0
  • v1.3.1
  • v1.3.2
  • v1.4.0
  • v1.5.0
  • v1.5.1
  • v1.6.0
  • v1.6.1
  • v1.7.0
  • v1.7.1
  • v1.7.2
  • v1.7.3
  • v1.8.0
  • v1.8.1
  • v1.9.0
76 results
Show changes
Showing
with 413 additions and 238 deletions
...@@ -16,9 +16,9 @@ import ( ...@@ -16,9 +16,9 @@ import (
) )
func TestErrorsSetLogger(t *testing.T) { func TestErrorsSetLogger(t *testing.T) {
previous := errLog previous := defaultLogger
defer func() { defer func() {
errLog = previous defaultLogger = previous
}() }()
// set up logger // set up logger
...@@ -28,7 +28,7 @@ func TestErrorsSetLogger(t *testing.T) { ...@@ -28,7 +28,7 @@ func TestErrorsSetLogger(t *testing.T) {
// print // print
SetLogger(logger) SetLogger(logger)
errLog.Print("test") defaultLogger.Print("test")
// check result // check result
if actual := buffer.String(); actual != expected { if actual := buffer.String(); actual != expected {
......
...@@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeBit: case fieldTypeBit:
return "BIT" return "BIT"
case fieldTypeBLOB: case fieldTypeBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "TEXT" return "TEXT"
} }
return "BLOB" return "BLOB"
...@@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeGeometry: case fieldTypeGeometry:
return "GEOMETRY" return "GEOMETRY"
case fieldTypeInt24: case fieldTypeInt24:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED MEDIUMINT"
}
return "MEDIUMINT" return "MEDIUMINT"
case fieldTypeJSON: case fieldTypeJSON:
return "JSON" return "JSON"
...@@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "INT" return "INT"
case fieldTypeLongBLOB: case fieldTypeLongBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "LONGTEXT" return "LONGTEXT"
} }
return "LONGBLOB" return "LONGBLOB"
...@@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "BIGINT" return "BIGINT"
case fieldTypeMediumBLOB: case fieldTypeMediumBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "MEDIUMTEXT" return "MEDIUMTEXT"
} }
return "MEDIUMBLOB" return "MEDIUMBLOB"
...@@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "SMALLINT" return "SMALLINT"
case fieldTypeString: case fieldTypeString:
if mf.charSet == collations[binaryCollation] { if mf.flags&flagEnum != 0 {
return "ENUM"
} else if mf.flags&flagSet != 0 {
return "SET"
}
if mf.charSet == binaryCollationID {
return "BINARY" return "BINARY"
} }
return "CHAR" return "CHAR"
...@@ -88,17 +96,17 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -88,17 +96,17 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "TINYINT" return "TINYINT"
case fieldTypeTinyBLOB: case fieldTypeTinyBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "TINYTEXT" return "TINYTEXT"
} }
return "TINYBLOB" return "TINYBLOB"
case fieldTypeVarChar: case fieldTypeVarChar:
if mf.charSet == collations[binaryCollation] { if mf.charSet == binaryCollationID {
return "VARBINARY" return "VARBINARY"
} }
return "VARCHAR" return "VARCHAR"
case fieldTypeVarString: case fieldTypeVarString:
if mf.charSet == collations[binaryCollation] { if mf.charSet == binaryCollationID {
return "VARBINARY" return "VARBINARY"
} }
return "VARCHAR" return "VARCHAR"
...@@ -123,7 +131,9 @@ var ( ...@@ -123,7 +131,9 @@ var (
scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0)) scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0)) scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) scanTypeString = reflect.TypeOf("")
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeUnknown = reflect.TypeOf(new(interface{})) scanTypeUnknown = reflect.TypeOf(new(interface{}))
) )
...@@ -187,12 +197,18 @@ func (mf *mysqlField) scanType() reflect.Type { ...@@ -187,12 +197,18 @@ func (mf *mysqlField) scanType() reflect.Type {
} }
return scanTypeNullFloat return scanTypeNullFloat
case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
if mf.charSet == binaryCollationID {
return scanTypeBytes
}
fallthrough
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime:
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, if mf.flags&flagNotNULL != 0 {
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, return scanTypeString
fieldTypeTime: }
return scanTypeRawBytes return scanTypeNullString
case fieldTypeDate, fieldTypeNewDate, case fieldTypeDate, fieldTypeNewDate,
fieldTypeTimestamp, fieldTypeDateTime: fieldTypeTimestamp, fieldTypeDateTime:
......
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build gofuzz
// +build gofuzz
package mysql
import (
"database/sql"
)
func Fuzz(data []byte) int {
db, err := sql.Open("mysql", string(data))
if err != nil {
return 0
}
db.Close()
return 1
}
module github.com/go-sql-driver/mysql module github.com/go-sql-driver/mysql
go 1.13 go 1.18
require filippo.io/edwards25519 v1.1.0
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
...@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) { ...@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
func (mc *mysqlConn) handleInFileRequest(name string) (err error) { func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader var rdr io.Reader
var data []byte var data []byte
packetSize := defaultPacketSize packetSize := defaultPacketSize
...@@ -116,10 +116,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -116,10 +116,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
defer deferredClose(&err, cl) defer deferredClose(&err, cl)
} }
} else { } else {
err = fmt.Errorf("Reader '%s' is <nil>", name) err = fmt.Errorf("reader '%s' is <nil>", name)
} }
} else { } else {
err = fmt.Errorf("Reader '%s' is not registered", name) err = fmt.Errorf("reader '%s' is not registered", name)
} }
} else { // File } else { // File
name = strings.Trim(name, `"`) name = strings.Trim(name, `"`)
...@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
for err == nil { for err == nil {
n, err = rdr.Read(data[4:]) n, err = rdr.Read(data[4:])
if n > 0 { if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
return ioErr return ioErr
} }
} }
...@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if data == nil { if data == nil {
data = make([]byte, 4) data = make([]byte, 4)
} }
if ioErr := mc.writePacket(data[:4]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr return ioErr
} }
...@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
return mc.readResultOK() return mc.readResultOK()
} }
mc.readPacket() mc.conn().readPacket()
return err return err
} }
...@@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) { ...@@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
} }
nt.Valid = false nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value) return fmt.Errorf("can't convert %T to time.Time", value)
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
......
...@@ -14,10 +14,10 @@ import ( ...@@ -14,10 +14,10 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"strconv"
"time" "time"
) )
...@@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
errLog.Print(err) mc.cfg.Logger.Print(err)
mc.Close() mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
...@@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// check packet sync [8 bit] // check packet sync [8 bit]
if data[3] != mc.sequence { if data[3] != mc.sequence {
mc.Close()
if data[3] > mc.sequence { if data[3] > mc.sequence {
return nil, ErrPktSyncMul return nil, ErrPktSyncMul
} }
...@@ -56,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -56,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if pktLen == 0 { if pktLen == 0 {
// there was no previous packet // there was no previous packet
if prevData == nil { if prevData == nil {
errLog.Print(ErrMalformPkt) mc.cfg.Logger.Print(ErrMalformPkt)
mc.Close() mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
...@@ -70,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -70,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
errLog.Print(err) mc.cfg.Logger.Print(err)
mc.Close() mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
...@@ -97,34 +98,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { ...@@ -97,34 +98,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
return ErrPktTooLarge return ErrPktTooLarge
} }
// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.reset {
mc.reset = false
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
if mc.cfg.CheckConnLiveness {
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
}
}
if err != nil {
errLog.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
}
for { for {
var size int var size int
if pktLen >= maxPacketSize { if pktLen >= maxPacketSize {
...@@ -161,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { ...@@ -161,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
// Handle error // Handle error
if err == nil { // n != len(data) if err == nil { // n != len(data)
mc.cleanup() mc.cleanup()
errLog.Print(ErrMalformPkt) mc.cfg.Logger.Print(ErrMalformPkt)
} else { } else {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return cerr return cerr
...@@ -171,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { ...@@ -171,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
return errBadConnNoWrite return errBadConnNoWrite
} }
mc.cleanup() mc.cleanup()
errLog.Print(err) mc.cfg.Logger.Print(err)
} }
return ErrInvalidConn return ErrInvalidConn
} }
...@@ -239,7 +212,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro ...@@ -239,7 +212,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// reserved (all [00]) [10 bytes] // reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10 pos += 1 + 2 + 2 + 1 + 10
// second part of the password cipher [mininum 13 bytes], // second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8) // where len=MAX(13, length of auth-plugin-data - 8)
// //
// The web documentation is ambiguous about the length. However, // The web documentation is ambiguous about the length. However,
...@@ -285,6 +258,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -285,6 +258,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles | clientLocalFiles |
clientPluginAuth | clientPluginAuth |
clientMultiResults | clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag mc.flags&clientLongFlag
if mc.cfg.ClientFoundRows { if mc.cfg.ClientFoundRows {
...@@ -318,11 +292,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -318,11 +292,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1 pktLen += n + 1
} }
// encode length of the connection attributes
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
// Calculate packet length and get buffer with that size // Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -338,14 +318,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -338,14 +318,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[10] = 0x00 data[10] = 0x00
data[11] = 0x00 data[11] = 0x00
// Charset [1 byte] // Collation ID [1 byte]
cname := mc.cfg.Collation
if cname == "" {
cname = defaultCollation
}
var found bool var found bool
data[12], found = collations[mc.cfg.Collation] data[12], found = collations[cname]
if !found { if !found {
// Note possibility for false negatives: // Note possibility for false negatives:
// could be triggered although the collation is valid if the // could be triggered although the collation is valid if the
// collations map does not contain entries the server supports. // collations map does not contain entries the server supports.
return errors.New("unknown collation") return fmt.Errorf("unknown collation: %q", cname)
} }
// Filler [23 bytes] (all 0x00) // Filler [23 bytes] (all 0x00)
...@@ -394,6 +378,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -394,6 +378,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00 data[pos] = 0x00
pos++ pos++
// Connection Attributes
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
// Send Auth packet // Send Auth packet
return mc.writePacket(data[:pos]) return mc.writePacket(data[:pos])
} }
...@@ -404,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { ...@@ -404,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
data, err := mc.buf.takeSmallBuffer(pktLen) data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -424,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { ...@@ -424,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
data, err := mc.buf.takeSmallBuffer(4 + 1) data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -443,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { ...@@ -443,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
data, err := mc.buf.takeBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -464,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { ...@@ -464,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -495,7 +483,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { ...@@ -495,7 +483,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
switch data[0] { switch data[0] {
case iOK: case iOK:
return nil, "", mc.handleOkPacket(data) // resultUnchanged, since auth happens before any queries or
// commands have been executed.
return nil, "", mc.resultUnchanged().handleOkPacket(data)
case iAuthMoreData: case iAuthMoreData:
return data[1:], "", err return data[1:], "", err
...@@ -518,9 +508,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { ...@@ -518,9 +508,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
} }
} }
// Returns error if Packet is not an 'Result OK'-Packet // Returns error if Packet is not a 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error { func (mc *okHandler) readResultOK() error {
data, err := mc.readPacket() data, err := mc.conn().readPacket()
if err != nil { if err != nil {
return err return err
} }
...@@ -528,13 +518,17 @@ func (mc *mysqlConn) readResultOK() error { ...@@ -528,13 +518,17 @@ func (mc *mysqlConn) readResultOK() error {
if data[0] == iOK { if data[0] == iOK {
return mc.handleOkPacket(data) return mc.handleOkPacket(data)
} }
return mc.handleErrorPacket(data) return mc.conn().handleErrorPacket(data)
} }
// Result Set Header Packet // Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket() // handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
data, err := mc.conn().readPacket()
if err == nil { if err == nil {
switch data[0] { switch data[0] {
...@@ -542,20 +536,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { ...@@ -542,20 +536,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
return 0, mc.handleOkPacket(data) return 0, mc.handleOkPacket(data)
case iERR: case iERR:
return 0, mc.handleErrorPacket(data) return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile: case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:])) return 0, mc.handleInFileRequest(string(data[1:]))
} }
// column count // column count
num, _, n := readLengthEncodedInteger(data) num, _, _ := readLengthEncodedInteger(data)
if n-len(data) == 0 { // ignore remaining data in the packet. see #1478.
return int(num), nil return int(num), nil
} }
return 0, ErrMalformPkt
}
return 0, err return 0, err
} }
...@@ -607,18 +598,61 @@ func readStatus(b []byte) statusFlag { ...@@ -607,18 +598,61 @@ func readStatus(b []byte) statusFlag {
return statusFlag(b[0]) | statusFlag(b[1])<<8 return statusFlag(b[0]) | statusFlag(b[1])<<8
} }
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
// need to be cleared first (e.g. during authentication, or while additional
// resultsets are being fetched.)
func (mc *mysqlConn) resultUnchanged() *okHandler {
return (*okHandler)(mc)
}
// okHandler represents the state of the connection when mysqlConn.result has
// been prepared for processing of OK packets.
//
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
// callpaths must either:
//
// 1. first clear it using clearResult(), or
// 2. confirm that they don't need to (by calling resultUnchanged()).
//
// Both return an instance of type *okHandler.
type okHandler mysqlConn
// Exposes the underlying type's methods.
func (mc *okHandler) conn() *mysqlConn {
return (*mysqlConn)(mc)
}
// clearResult clears the connection's stored affectedRows and insertIds
// fields.
//
// It returns a handler that can process OK responses.
func (mc *mysqlConn) clearResult() *okHandler {
mc.result = mysqlResult{}
return (*okHandler)(mc)
}
// Ok Packet // Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error { func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int var n, m int
var affectedRows, insertId uint64
// 0x00 [1 byte] // 0x00 [1 byte]
// Affected rows [Length Coded Binary] // Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary] // Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) insertId, _, m = readLengthEncodedInteger(data[1+n:])
// Update for the current statement result (only used by
// readResultSetHeaderPacket).
if len(mc.result.affectedRows) > 0 {
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.result.insertIds) > 0 {
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}
// server_status [2 bytes] // server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2]) mc.status = readStatus(data[1+n+m : 1+n+m+2])
...@@ -769,7 +803,8 @@ func (rows *textRows) readRow(dest []driver.Value) error { ...@@ -769,7 +803,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
for i := range dest { for i := range dest {
// Read bytes and convert to string // Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) var buf []byte
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n pos += n
if err != nil { if err != nil {
...@@ -781,19 +816,40 @@ func (rows *textRows) readRow(dest []driver.Value) error { ...@@ -781,19 +816,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
continue continue
} }
if !mc.parseTime {
continue
}
// Parse time field
switch rows.rs.columns[i].fieldType { switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp, case fieldTypeTimestamp,
fieldTypeDateTime, fieldTypeDateTime,
fieldTypeDate, fieldTypeDate,
fieldTypeNewDate: fieldTypeNewDate:
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { if mc.parseTime {
return err dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
} }
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}
case fieldTypeFloat:
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)
case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)
default:
dest[i] = buf
}
if err != nil {
return err
} }
} }
...@@ -938,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -938,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
} }
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -1116,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -1116,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() { if v.IsZero() {
b = append(b, "0000-00-00"...) b = append(b, "0000-00-00"...)
} else { } else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc)) b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil { if err != nil {
return err return err
} }
...@@ -1137,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -1137,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) { if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...) data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil { if err = mc.buf.store(data); err != nil {
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
} }
...@@ -1149,7 +1205,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -1149,7 +1205,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
return mc.writePacket(data) return mc.writePacket(data)
} }
func (mc *mysqlConn) discardResults() error { // For each remaining resultset in the stream, discards its rows and updates
// mc.affectedRows and mc.insertIds.
func (mc *okHandler) discardResults() error {
for mc.status&statusMoreResultsExists != 0 { for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket() resLen, err := mc.readResultSetHeaderPacket()
if err != nil { if err != nil {
...@@ -1157,11 +1215,11 @@ func (mc *mysqlConn) discardResults() error { ...@@ -1157,11 +1215,11 @@ func (mc *mysqlConn) discardResults() error {
} }
if resLen > 0 { if resLen > 0 {
// columns // columns
if err := mc.readUntilEOF(); err != nil { if err := mc.conn().readUntilEOF(); err != nil {
return err return err
} }
// rows // rows
if err := mc.readUntilEOF(); err != nil { if err := mc.conn().readUntilEOF(); err != nil {
return err return err
} }
} }
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.