Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
Loading items

Target

Select target project
  • oss/libraries/go/services/job-queues
1 result
Select Git revision
Loading items
Show changes
Showing
with 413 additions and 238 deletions
...@@ -16,9 +16,9 @@ import ( ...@@ -16,9 +16,9 @@ import (
) )
func TestErrorsSetLogger(t *testing.T) { func TestErrorsSetLogger(t *testing.T) {
previous := errLog previous := defaultLogger
defer func() { defer func() {
errLog = previous defaultLogger = previous
}() }()
// set up logger // set up logger
...@@ -28,7 +28,7 @@ func TestErrorsSetLogger(t *testing.T) { ...@@ -28,7 +28,7 @@ func TestErrorsSetLogger(t *testing.T) {
// print // print
SetLogger(logger) SetLogger(logger)
errLog.Print("test") defaultLogger.Print("test")
// check result // check result
if actual := buffer.String(); actual != expected { if actual := buffer.String(); actual != expected {
......
...@@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeBit: case fieldTypeBit:
return "BIT" return "BIT"
case fieldTypeBLOB: case fieldTypeBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "TEXT" return "TEXT"
} }
return "BLOB" return "BLOB"
...@@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeGeometry: case fieldTypeGeometry:
return "GEOMETRY" return "GEOMETRY"
case fieldTypeInt24: case fieldTypeInt24:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED MEDIUMINT"
}
return "MEDIUMINT" return "MEDIUMINT"
case fieldTypeJSON: case fieldTypeJSON:
return "JSON" return "JSON"
...@@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "INT" return "INT"
case fieldTypeLongBLOB: case fieldTypeLongBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "LONGTEXT" return "LONGTEXT"
} }
return "LONGBLOB" return "LONGBLOB"
...@@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "BIGINT" return "BIGINT"
case fieldTypeMediumBLOB: case fieldTypeMediumBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "MEDIUMTEXT" return "MEDIUMTEXT"
} }
return "MEDIUMBLOB" return "MEDIUMBLOB"
...@@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "SMALLINT" return "SMALLINT"
case fieldTypeString: case fieldTypeString:
if mf.charSet == collations[binaryCollation] { if mf.flags&flagEnum != 0 {
return "ENUM"
} else if mf.flags&flagSet != 0 {
return "SET"
}
if mf.charSet == binaryCollationID {
return "BINARY" return "BINARY"
} }
return "CHAR" return "CHAR"
...@@ -88,17 +96,17 @@ func (mf *mysqlField) typeDatabaseName() string { ...@@ -88,17 +96,17 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "TINYINT" return "TINYINT"
case fieldTypeTinyBLOB: case fieldTypeTinyBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "TINYTEXT" return "TINYTEXT"
} }
return "TINYBLOB" return "TINYBLOB"
case fieldTypeVarChar: case fieldTypeVarChar:
if mf.charSet == collations[binaryCollation] { if mf.charSet == binaryCollationID {
return "VARBINARY" return "VARBINARY"
} }
return "VARCHAR" return "VARCHAR"
case fieldTypeVarString: case fieldTypeVarString:
if mf.charSet == collations[binaryCollation] { if mf.charSet == binaryCollationID {
return "VARBINARY" return "VARBINARY"
} }
return "VARCHAR" return "VARCHAR"
...@@ -123,7 +131,9 @@ var ( ...@@ -123,7 +131,9 @@ var (
scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0)) scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0)) scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) scanTypeString = reflect.TypeOf("")
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeUnknown = reflect.TypeOf(new(interface{})) scanTypeUnknown = reflect.TypeOf(new(interface{}))
) )
...@@ -187,12 +197,18 @@ func (mf *mysqlField) scanType() reflect.Type { ...@@ -187,12 +197,18 @@ func (mf *mysqlField) scanType() reflect.Type {
} }
return scanTypeNullFloat return scanTypeNullFloat
case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
if mf.charSet == binaryCollationID {
return scanTypeBytes
}
fallthrough
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime:
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, if mf.flags&flagNotNULL != 0 {
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, return scanTypeString
fieldTypeTime: }
return scanTypeRawBytes return scanTypeNullString
case fieldTypeDate, fieldTypeNewDate, case fieldTypeDate, fieldTypeNewDate,
fieldTypeTimestamp, fieldTypeDateTime: fieldTypeTimestamp, fieldTypeDateTime:
......
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build gofuzz
// +build gofuzz
package mysql
import (
"database/sql"
)
func Fuzz(data []byte) int {
db, err := sql.Open("mysql", string(data))
if err != nil {
return 0
}
db.Close()
return 1
}
module github.com/go-sql-driver/mysql module github.com/go-sql-driver/mysql
go 1.13 go 1.18
require filippo.io/edwards25519 v1.1.0
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
...@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) { ...@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
func (mc *mysqlConn) handleInFileRequest(name string) (err error) { func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader var rdr io.Reader
var data []byte var data []byte
packetSize := defaultPacketSize packetSize := defaultPacketSize
...@@ -116,10 +116,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -116,10 +116,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
defer deferredClose(&err, cl) defer deferredClose(&err, cl)
} }
} else { } else {
err = fmt.Errorf("Reader '%s' is <nil>", name) err = fmt.Errorf("reader '%s' is <nil>", name)
} }
} else { } else {
err = fmt.Errorf("Reader '%s' is not registered", name) err = fmt.Errorf("reader '%s' is not registered", name)
} }
} else { // File } else { // File
name = strings.Trim(name, `"`) name = strings.Trim(name, `"`)
...@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
for err == nil { for err == nil {
n, err = rdr.Read(data[4:]) n, err = rdr.Read(data[4:])
if n > 0 { if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
return ioErr return ioErr
} }
} }
...@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if data == nil { if data == nil {
data = make([]byte, 4) data = make([]byte, 4)
} }
if ioErr := mc.writePacket(data[:4]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr return ioErr
} }
...@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { ...@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
return mc.readResultOK() return mc.readResultOK()
} }
mc.readPacket() mc.conn().readPacket()
return err return err
} }
...@@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) { ...@@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
} }
nt.Valid = false nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value) return fmt.Errorf("can't convert %T to time.Time", value)
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
......
...@@ -14,10 +14,10 @@ import ( ...@@ -14,10 +14,10 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"strconv"
"time" "time"
) )
...@@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
errLog.Print(err) mc.cfg.Logger.Print(err)
mc.Close() mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
...@@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// check packet sync [8 bit] // check packet sync [8 bit]
if data[3] != mc.sequence { if data[3] != mc.sequence {
mc.Close()
if data[3] > mc.sequence { if data[3] > mc.sequence {
return nil, ErrPktSyncMul return nil, ErrPktSyncMul
} }
...@@ -56,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -56,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if pktLen == 0 { if pktLen == 0 {
// there was no previous packet // there was no previous packet
if prevData == nil { if prevData == nil {
errLog.Print(ErrMalformPkt) mc.cfg.Logger.Print(ErrMalformPkt)
mc.Close() mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
...@@ -70,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { ...@@ -70,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
errLog.Print(err) mc.cfg.Logger.Print(err)
mc.Close() mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
...@@ -97,34 +98,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { ...@@ -97,34 +98,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
return ErrPktTooLarge return ErrPktTooLarge
} }
// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.reset {
mc.reset = false
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
if mc.cfg.CheckConnLiveness {
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
}
}
if err != nil {
errLog.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
}
for { for {
var size int var size int
if pktLen >= maxPacketSize { if pktLen >= maxPacketSize {
...@@ -161,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { ...@@ -161,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
// Handle error // Handle error
if err == nil { // n != len(data) if err == nil { // n != len(data)
mc.cleanup() mc.cleanup()
errLog.Print(ErrMalformPkt) mc.cfg.Logger.Print(ErrMalformPkt)
} else { } else {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return cerr return cerr
...@@ -171,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { ...@@ -171,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
return errBadConnNoWrite return errBadConnNoWrite
} }
mc.cleanup() mc.cleanup()
errLog.Print(err) mc.cfg.Logger.Print(err)
} }
return ErrInvalidConn return ErrInvalidConn
} }
...@@ -239,7 +212,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro ...@@ -239,7 +212,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// reserved (all [00]) [10 bytes] // reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10 pos += 1 + 2 + 2 + 1 + 10
// second part of the password cipher [mininum 13 bytes], // second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8) // where len=MAX(13, length of auth-plugin-data - 8)
// //
// The web documentation is ambiguous about the length. However, // The web documentation is ambiguous about the length. However,
...@@ -285,6 +258,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -285,6 +258,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles | clientLocalFiles |
clientPluginAuth | clientPluginAuth |
clientMultiResults | clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag mc.flags&clientLongFlag
if mc.cfg.ClientFoundRows { if mc.cfg.ClientFoundRows {
...@@ -318,11 +292,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -318,11 +292,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1 pktLen += n + 1
} }
// encode length of the connection attributes
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
// Calculate packet length and get buffer with that size // Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -338,14 +318,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -338,14 +318,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[10] = 0x00 data[10] = 0x00
data[11] = 0x00 data[11] = 0x00
// Charset [1 byte] // Collation ID [1 byte]
cname := mc.cfg.Collation
if cname == "" {
cname = defaultCollation
}
var found bool var found bool
data[12], found = collations[mc.cfg.Collation] data[12], found = collations[cname]
if !found { if !found {
// Note possibility for false negatives: // Note possibility for false negatives:
// could be triggered although the collation is valid if the // could be triggered although the collation is valid if the
// collations map does not contain entries the server supports. // collations map does not contain entries the server supports.
return errors.New("unknown collation") return fmt.Errorf("unknown collation: %q", cname)
} }
// Filler [23 bytes] (all 0x00) // Filler [23 bytes] (all 0x00)
...@@ -394,6 +378,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string ...@@ -394,6 +378,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00 data[pos] = 0x00
pos++ pos++
// Connection Attributes
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
// Send Auth packet // Send Auth packet
return mc.writePacket(data[:pos]) return mc.writePacket(data[:pos])
} }
...@@ -404,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { ...@@ -404,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
data, err := mc.buf.takeSmallBuffer(pktLen) data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -424,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { ...@@ -424,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
data, err := mc.buf.takeSmallBuffer(4 + 1) data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -443,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { ...@@ -443,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
data, err := mc.buf.takeBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -464,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { ...@@ -464,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -495,7 +483,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { ...@@ -495,7 +483,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
switch data[0] { switch data[0] {
case iOK: case iOK:
return nil, "", mc.handleOkPacket(data) // resultUnchanged, since auth happens before any queries or
// commands have been executed.
return nil, "", mc.resultUnchanged().handleOkPacket(data)
case iAuthMoreData: case iAuthMoreData:
return data[1:], "", err return data[1:], "", err
...@@ -518,9 +508,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { ...@@ -518,9 +508,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
} }
} }
// Returns error if Packet is not an 'Result OK'-Packet // Returns error if Packet is not a 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error { func (mc *okHandler) readResultOK() error {
data, err := mc.readPacket() data, err := mc.conn().readPacket()
if err != nil { if err != nil {
return err return err
} }
...@@ -528,13 +518,17 @@ func (mc *mysqlConn) readResultOK() error { ...@@ -528,13 +518,17 @@ func (mc *mysqlConn) readResultOK() error {
if data[0] == iOK { if data[0] == iOK {
return mc.handleOkPacket(data) return mc.handleOkPacket(data)
} }
return mc.handleErrorPacket(data) return mc.conn().handleErrorPacket(data)
} }
// Result Set Header Packet // Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket() // handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
data, err := mc.conn().readPacket()
if err == nil { if err == nil {
switch data[0] { switch data[0] {
...@@ -542,20 +536,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { ...@@ -542,20 +536,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
return 0, mc.handleOkPacket(data) return 0, mc.handleOkPacket(data)
case iERR: case iERR:
return 0, mc.handleErrorPacket(data) return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile: case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:])) return 0, mc.handleInFileRequest(string(data[1:]))
} }
// column count // column count
num, _, n := readLengthEncodedInteger(data) num, _, _ := readLengthEncodedInteger(data)
if n-len(data) == 0 { // ignore remaining data in the packet. see #1478.
return int(num), nil return int(num), nil
} }
return 0, ErrMalformPkt
}
return 0, err return 0, err
} }
...@@ -607,18 +598,61 @@ func readStatus(b []byte) statusFlag { ...@@ -607,18 +598,61 @@ func readStatus(b []byte) statusFlag {
return statusFlag(b[0]) | statusFlag(b[1])<<8 return statusFlag(b[0]) | statusFlag(b[1])<<8
} }
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
// need to be cleared first (e.g. during authentication, or while additional
// resultsets are being fetched.)
func (mc *mysqlConn) resultUnchanged() *okHandler {
return (*okHandler)(mc)
}
// okHandler represents the state of the connection when mysqlConn.result has
// been prepared for processing of OK packets.
//
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
// callpaths must either:
//
// 1. first clear it using clearResult(), or
// 2. confirm that they don't need to (by calling resultUnchanged()).
//
// Both return an instance of type *okHandler.
type okHandler mysqlConn
// Exposes the underlying type's methods.
func (mc *okHandler) conn() *mysqlConn {
return (*mysqlConn)(mc)
}
// clearResult clears the connection's stored affectedRows and insertIds
// fields.
//
// It returns a handler that can process OK responses.
func (mc *mysqlConn) clearResult() *okHandler {
mc.result = mysqlResult{}
return (*okHandler)(mc)
}
// Ok Packet // Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error { func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int var n, m int
var affectedRows, insertId uint64
// 0x00 [1 byte] // 0x00 [1 byte]
// Affected rows [Length Coded Binary] // Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary] // Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) insertId, _, m = readLengthEncodedInteger(data[1+n:])
// Update for the current statement result (only used by
// readResultSetHeaderPacket).
if len(mc.result.affectedRows) > 0 {
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.result.insertIds) > 0 {
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}
// server_status [2 bytes] // server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2]) mc.status = readStatus(data[1+n+m : 1+n+m+2])
...@@ -769,7 +803,8 @@ func (rows *textRows) readRow(dest []driver.Value) error { ...@@ -769,7 +803,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
for i := range dest { for i := range dest {
// Read bytes and convert to string // Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) var buf []byte
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n pos += n
if err != nil { if err != nil {
...@@ -781,19 +816,40 @@ func (rows *textRows) readRow(dest []driver.Value) error { ...@@ -781,19 +816,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
continue continue
} }
if !mc.parseTime {
continue
}
// Parse time field
switch rows.rs.columns[i].fieldType { switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp, case fieldTypeTimestamp,
fieldTypeDateTime, fieldTypeDateTime,
fieldTypeDate, fieldTypeDate,
fieldTypeNewDate: fieldTypeNewDate:
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { if mc.parseTime {
return err dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
} }
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}
case fieldTypeFloat:
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)
case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)
default:
dest[i] = buf
}
if err != nil {
return err
} }
} }
...@@ -938,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -938,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
} }
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection // cannot take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
...@@ -1116,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -1116,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() { if v.IsZero() {
b = append(b, "0000-00-00"...) b = append(b, "0000-00-00"...)
} else { } else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc)) b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil { if err != nil {
return err return err
} }
...@@ -1137,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -1137,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) { if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...) data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil { if err = mc.buf.store(data); err != nil {
errLog.Print(err) mc.cfg.Logger.Print(err)
return errBadConnNoWrite return errBadConnNoWrite
} }
} }
...@@ -1149,7 +1205,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ...@@ -1149,7 +1205,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
return mc.writePacket(data) return mc.writePacket(data)
} }
func (mc *mysqlConn) discardResults() error { // For each remaining resultset in the stream, discards its rows and updates
// mc.affectedRows and mc.insertIds.
func (mc *okHandler) discardResults() error {
for mc.status&statusMoreResultsExists != 0 { for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket() resLen, err := mc.readResultSetHeaderPacket()
if err != nil { if err != nil {
...@@ -1157,11 +1215,11 @@ func (mc *mysqlConn) discardResults() error { ...@@ -1157,11 +1215,11 @@ func (mc *mysqlConn) discardResults() error {
} }
if resLen > 0 { if resLen > 0 {
// columns // columns
if err := mc.readUntilEOF(); err != nil { if err := mc.conn().readUntilEOF(); err != nil {
return err return err
} }
// rows // rows
if err := mc.readUntilEOF(); err != nil { if err := mc.conn().readUntilEOF(); err != nil {
return err return err
} }
} }
......
...@@ -96,9 +96,11 @@ var _ net.Conn = new(mockConn) ...@@ -96,9 +96,11 @@ var _ net.Conn = new(mockConn)
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn) conn := new(mockConn)
connector := newConnector(NewConfig())
mc := &mysqlConn{ mc := &mysqlConn{
buf: newBuffer(conn), buf: newBuffer(conn),
cfg: NewConfig(), cfg: connector.cfg,
connector: connector,
netConn: conn, netConn: conn,
closech: make(chan struct{}), closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket, maxAllowedPacket: defaultMaxAllowedPacket,
...@@ -128,30 +130,34 @@ func TestReadPacketSingleByte(t *testing.T) { ...@@ -128,30 +130,34 @@ func TestReadPacketSingleByte(t *testing.T) {
} }
func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn) for _, testCase := range []struct {
mc := &mysqlConn{ ClientSequenceID byte
buf: newBuffer(conn), ServerSequenceID byte
} ExpectedErr error
}{
// too low sequence id {
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} ClientSequenceID: 1,
conn.maxReads = 1 ServerSequenceID: 0,
mc.sequence = 1 ExpectedErr: ErrPktSync,
},
{
ClientSequenceID: 0,
ServerSequenceID: 0x42,
ExpectedErr: ErrPktSyncMul,
},
} {
conn, mc := newRWMockConn(testCase.ClientSequenceID)
conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff}
_, err := mc.readPacket() _, err := mc.readPacket()
if err != ErrPktSync { if err != testCase.ExpectedErr {
t.Errorf("expected ErrPktSync, got %v", err) t.Errorf("expected %v, got %v", testCase.ExpectedErr, err)
} }
// reset // connection should not be returned to the pool in this state
conn.reads = 0 if mc.IsValid() {
mc.sequence = 0 t.Errorf("expected IsValid() to be false")
mc.buf = newBuffer(conn) }
// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
_, err = mc.readPacket()
if err != ErrPktSyncMul {
t.Errorf("expected ErrPktSyncMul, got %v", err)
} }
} }
...@@ -179,7 +185,7 @@ func TestReadPacketSplit(t *testing.T) { ...@@ -179,7 +185,7 @@ func TestReadPacketSplit(t *testing.T) {
data[4] = 0x11 data[4] = 0x11
data[maxPacketSize+3] = 0x22 data[maxPacketSize+3] = 0x22
// 2nd packet has payload length 0 and squence id 1 // 2nd packet has payload length 0 and sequence id 1
// 00 00 00 01 // 00 00 00 01
data[pkt2ofs+3] = 0x01 data[pkt2ofs+3] = 0x01
...@@ -211,7 +217,7 @@ func TestReadPacketSplit(t *testing.T) { ...@@ -211,7 +217,7 @@ func TestReadPacketSplit(t *testing.T) {
data[pkt2ofs+4] = 0x33 data[pkt2ofs+4] = 0x33
data[pkt2ofs+maxPacketSize+3] = 0x44 data[pkt2ofs+maxPacketSize+3] = 0x44
// 3rd packet has payload length 0 and squence id 2 // 3rd packet has payload length 0 and sequence id 2
// 00 00 00 02 // 00 00 00 02
data[pkt3ofs+3] = 0x02 data[pkt3ofs+3] = 0x02
...@@ -265,6 +271,7 @@ func TestReadPacketFail(t *testing.T) { ...@@ -265,6 +271,7 @@ func TestReadPacketFail(t *testing.T) {
mc := &mysqlConn{ mc := &mysqlConn{
buf: newBuffer(conn), buf: newBuffer(conn),
closech: make(chan struct{}), closech: make(chan struct{}),
cfg: NewConfig(),
} }
// illegal empty (stand-alone) packet // illegal empty (stand-alone) packet
......
...@@ -8,15 +8,43 @@ ...@@ -8,15 +8,43 @@
package mysql package mysql
import "database/sql/driver"
// Result exposes data not available through *connection.Result.
//
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
// executed statement.
AllRowsAffected() []int64
// AllLastInsertIds returns a slice containing the last inserted ID for each
// executed statement.
AllLastInsertIds() []int64
}
type mysqlResult struct { type mysqlResult struct {
affectedRows int64 // One entry in both slices is created for every executed statement result.
insertId int64 affectedRows []int64
insertIds []int64
} }
func (res *mysqlResult) LastInsertId() (int64, error) { func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil return res.insertIds[len(res.insertIds)-1], nil
} }
func (res *mysqlResult) RowsAffected() (int64, error) { func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil return res.affectedRows[len(res.affectedRows)-1], nil
}
func (res *mysqlResult) AllLastInsertIds() []int64 {
return append([]int64{}, res.insertIds...) // defensive copy
}
func (res *mysqlResult) AllRowsAffected() []int64 {
return append([]int64{}, res.affectedRows...) // defensive copy
} }
...@@ -123,7 +123,8 @@ func (rows *mysqlRows) Close() (err error) { ...@@ -123,7 +123,8 @@ func (rows *mysqlRows) Close() (err error) {
err = mc.readUntilEOF() err = mc.readUntilEOF()
} }
if err == nil { if err == nil {
if err = mc.discardResults(); err != nil { handleOk := mc.clearResult()
if err = handleOk.discardResults(); err != nil {
return err return err
} }
} }
...@@ -160,7 +161,15 @@ func (rows *mysqlRows) nextResultSet() (int, error) { ...@@ -160,7 +161,15 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
return 0, io.EOF return 0, io.EOF
} }
rows.rs = resultSet{} rows.rs = resultSet{}
return rows.mc.readResultSetHeaderPacket() // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
// nextResultSet.
resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
if err != nil {
// Clean up about multi-results flag
rows.rs.done = true
rows.mc.status = rows.mc.status & (^statusMoreResultsExists)
}
return resLen, err
} }
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
......
...@@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { ...@@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.Load() { if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn) stmt.mc.cfg.Logger.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
...@@ -61,12 +61,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { ...@@ -61,12 +61,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
} }
mc := stmt.mc mc := stmt.mc
handleOk := stmt.mc.clearResult()
mc.affectedRows = 0
mc.insertId = 0
// Read Result // Read Result
resLen, err := mc.readResultSetHeaderPacket() resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -83,14 +81,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { ...@@ -83,14 +81,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
} }
} }
if err := mc.discardResults(); err != nil { if err := handleOk.discardResults(); err != nil {
return nil, err return nil, err
} }
return &mysqlResult{ copied := mc.result
affectedRows: int64(mc.affectedRows), return &copied, nil
insertId: int64(mc.insertId),
}, nil
} }
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
...@@ -99,7 +95,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { ...@@ -99,7 +95,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.Load() { if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn) stmt.mc.cfg.Logger.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
...@@ -111,7 +107,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { ...@@ -111,7 +107,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
mc := stmt.mc mc := stmt.mc
// Read Result // Read Result
resLen, err := mc.readResultSetHeaderPacket() handleOk := stmt.mc.clearResult()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -36,7 +36,7 @@ var ( ...@@ -36,7 +36,7 @@ var (
// registering it. // registering it.
// //
// rootCertPool := x509.NewCertPool() // rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem") // pem, err := os.ReadFile("/path/ca-cert.pem")
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
...@@ -265,7 +265,11 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va ...@@ -265,7 +265,11 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
return nil, fmt.Errorf("invalid DATETIME packet length %d", num) return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
} }
func appendDateTime(buf []byte, t time.Time) ([]byte, error) { func appendDateTime(buf []byte, t time.Time, timeTruncate time.Duration) ([]byte, error) {
if timeTruncate > 0 {
t = t.Truncate(timeTruncate)
}
year, month, day := t.Date() year, month, day := t.Date()
hour, min, sec := t.Clock() hour, min, sec := t.Clock()
nsec := t.Nanosecond() nsec := t.Nanosecond()
...@@ -616,6 +620,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { ...@@ -616,6 +620,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
} }
func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer. // If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte { func reserveBuffer(buf []byte, appendSize int) []byte {
......
...@@ -239,6 +239,8 @@ func TestAppendDateTime(t *testing.T) { ...@@ -239,6 +239,8 @@ func TestAppendDateTime(t *testing.T) {
tests := []struct { tests := []struct {
t time.Time t time.Time
str string str string
timeTruncate time.Duration
expectedErr bool
}{ }{
{ {
t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC), t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC),
...@@ -276,32 +278,73 @@ func TestAppendDateTime(t *testing.T) { ...@@ -276,32 +278,73 @@ func TestAppendDateTime(t *testing.T) {
t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
str: "0001-01-01", str: "0001-01-01",
}, },
} // Truncated time
for _, v := range tests { {
buf := make([]byte, 0, 32) t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC),
buf, _ = appendDateTime(buf, v.t) str: "1234-05-06",
if str := string(buf); str != v.str { timeTruncate: time.Second,
t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str) },
} {
} t: time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC),
str: "4567-12-31 12:00:00",
timeTruncate: time.Minute,
},
{
t: time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC),
str: "2020-05-30 12:34:00",
timeTruncate: 0,
},
{
t: time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC),
str: "2020-05-30 12:34:56",
timeTruncate: time.Second,
},
{
t: time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC),
str: "2020-05-30 22:33:44",
timeTruncate: time.Second,
},
{
t: time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC),
str: "2020-05-30 22:33:44.123",
timeTruncate: time.Millisecond,
},
{
t: time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC),
str: "2020-05-30 22:33:44",
timeTruncate: time.Second,
},
{
t: time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC),
str: "9999-12-31 23:59:59.999999999",
timeTruncate: 0,
},
{
t: time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC),
str: "0001-01-01",
timeTruncate: 365 * 24 * time.Hour,
},
// year out of range // year out of range
{ {
v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) t: time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC),
expectedErr: true,
},
{
t: time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC),
expectedErr: true,
},
}
for _, v := range tests {
buf := make([]byte, 0, 32) buf := make([]byte, 0, 32)
_, err := appendDateTime(buf, v) buf, err := appendDateTime(buf, v.t, v.timeTruncate)
if err == nil { if err != nil {
t.Error("want an error") if !v.expectedErr {
return t.Errorf("appendDateTime(%v) returned an errror: %v", v.t, err)
} }
continue
} }
{ if str := string(buf); str != v.str {
v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC) t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str)
buf := make([]byte, 0, 32)
_, err := appendDateTime(buf, v)
if err == nil {
t.Error("want an error")
return
} }
} }
} }
......
...@@ -35,7 +35,7 @@ type Manager struct { ...@@ -35,7 +35,7 @@ type Manager struct {
//logger Logger //logger Logger
database *gorm.DB database *gorm.DB
dbSaver *DBSaver jobSyncer *JobSyncer
mu sync.Mutex mu sync.Mutex
} }
...@@ -152,9 +152,9 @@ func (m *Manager) DeleteJob(id JobID) error { ...@@ -152,9 +152,9 @@ func (m *Manager) DeleteJob(id JobID) error {
return err return err
} }
if m.dbSaver != nil { if m.jobSyncer != nil {
err := m.dbSaver.DeleteJob(job) err := m.jobSyncer.DeleteJob(job)
if err != nil { if err != nil {
return err return err
} }
...@@ -186,9 +186,9 @@ func (m *Manager) ResetJobLogs(id JobID) error { ...@@ -186,9 +186,9 @@ func (m *Manager) ResetJobLogs(id JobID) error {
return ErrJobNotActive return ErrJobNotActive
} }
if m.dbSaver != nil { if m.jobSyncer != nil {
err := m.dbSaver.ResetLogs(m.activeJobs[id]) err := m.jobSyncer.ResetLogs(m.activeJobs[id])
if err != nil { if err != nil {
return err return err
} }
...@@ -206,9 +206,9 @@ func (m *Manager) ResetJobStats(id JobID) error { ...@@ -206,9 +206,9 @@ func (m *Manager) ResetJobStats(id JobID) error {
return ErrJobNotActive return ErrJobNotActive
} }
if m.dbSaver != nil { if m.jobSyncer != nil {
err := m.dbSaver.ResetStats(m.activeJobs[id]) err := m.jobSyncer.ResetStats(m.activeJobs[id])
if err != nil { if err != nil {
return err return err
} }
...@@ -290,12 +290,10 @@ func (m *Manager) SetDB(db *gorm.DB) *Manager { ...@@ -290,12 +290,10 @@ func (m *Manager) SetDB(db *gorm.DB) *Manager {
defer m.mu.Unlock() defer m.mu.Unlock()
m.database = db m.database = db
if m.dbSaver != nil { if m.jobSyncer != nil {
return m return m
} }
m.jobSyncer = NewJobSyncer(m)
m.dbSaver = NewDBSaver()
m.dbSaver.SetManager(m)
return m return m
} }
...@@ -391,7 +389,6 @@ func (m *Manager) RemoveWorker(worker Worker) error { ...@@ -391,7 +389,6 @@ func (m *Manager) RemoveWorker(worker Worker) error {
// Start starts the manager // Start starts the manager
func (m *Manager) Start() error { func (m *Manager) Start() error {
var err error
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
...@@ -400,21 +397,31 @@ func (m *Manager) Start() error { ...@@ -400,21 +397,31 @@ func (m *Manager) Start() error {
return ErrManagerAlreadyRunning return ErrManagerAlreadyRunning
} }
if m.dbSaver != nil { if m.jobSyncer != nil {
p := StartDBSaver(m.dbSaver) p := CreateAndStartJobSyncer(m)
ready := make(chan struct{}) ready := make(chan struct{})
Then[bool, bool](p, func(value bool) (bool, error) { var jobSyncerErr error
Then[*JobSyncer, *JobSyncer](p, func(value *JobSyncer) (*JobSyncer, error) {
close(ready) close(ready)
m.mu.Lock()
m.jobSyncer = value
m.mu.Unlock()
return value, nil return value, nil
}, func(e error) error { }, func(e error) error {
close(ready) close(ready)
Error("Error while starting db saver", "error", err) Error("Error while starting db saver", "error", e)
jobSyncerErr = e
return nil return nil
}) })
<-ready <-ready
if jobSyncerErr != nil {
return jobSyncerErr
}
} }
if len(m.workerMap) == 0 { if len(m.workerMap) == 0 {
...@@ -449,7 +456,7 @@ func (m *Manager) Start() error { ...@@ -449,7 +456,7 @@ func (m *Manager) Start() error {
go m.handleJobEvents() go m.handleJobEvents()
err = m.checkAndSetRunningState() err := m.checkAndSetRunningState()
if err != nil { if err != nil {
wrappedErr = fmt.Errorf("%w\n%s", wrappedErr, err.Error()) wrappedErr = fmt.Errorf("%w\n%s", wrappedErr, err.Error())
...@@ -481,7 +488,7 @@ func (m *Manager) Stop() error { ...@@ -481,7 +488,7 @@ func (m *Manager) Stop() error {
for _, worker := range m.workerMap { for _, worker := range m.workerMap {
err := worker.Stop() err := worker.Stop()
if err != nil && err != ErrWorkerAlreadyStopped { if err != nil && !errors.Is(err, ErrWorkerAlreadyStopped) {
if wrappedErr == nil { if wrappedErr == nil {
wrappedErr = fmt.Errorf("Error: ") wrappedErr = fmt.Errorf("Error: ")
} }
...@@ -500,8 +507,16 @@ func (m *Manager) Stop() error { ...@@ -500,8 +507,16 @@ func (m *Manager) Stop() error {
m.cronInstance.Stop() m.cronInstance.Stop()
} }
if m.dbSaver != nil { if m.jobSyncer != nil {
m.dbSaver.Stop() err = m.jobSyncer.Stop()
if err != nil {
if wrappedErr == nil {
wrappedErr = fmt.Errorf("Error: ")
}
wrappedErr = fmt.Errorf("%w\n%s", wrappedErr, err.Error())
}
} }
return wrappedErr return wrappedErr
...@@ -564,11 +579,8 @@ func (m *Manager) ScheduleJob(job GenericJob, scheduler Scheduler) error { ...@@ -564,11 +579,8 @@ func (m *Manager) ScheduleJob(job GenericJob, scheduler Scheduler) error {
m.activeJobs[job.GetID()] = job m.activeJobs[job.GetID()] = job
if m.dbSaver != nil { if m.jobSyncer != nil {
err := m.dbSaver.SaveJob(job) m.jobSyncer.AddJob(job)
if err != nil {
return err
}
} }
return nil return nil
...@@ -621,7 +633,7 @@ func (m *Manager) handleJobEvents() { ...@@ -621,7 +633,7 @@ func (m *Manager) handleJobEvents() {
job := event.Data.(GenericJob) job := event.Data.(GenericJob)
err := m.queue.Enqueue(job) err := m.queue.Enqueue(job)
if err != nil && err != ErrJobAlreadyExists { if err != nil && !errors.Is(err, ErrJobAlreadyExists) {
Error("Error while queueing job", "error", err) Error("Error while queueing job", "error", err)
} }
......
...@@ -64,6 +64,10 @@ func (m *MockGenericJob) ResetStats() { ...@@ -64,6 +64,10 @@ func (m *MockGenericJob) ResetStats() {
} }
func (m *MockGenericJob) GetStats() JobStats {
return JobStats{}
}
func (m *MockGenericJob) GetMaxRetries() uint { func (m *MockGenericJob) GetMaxRetries() uint {
return 0 return 0
} }
...@@ -232,7 +236,7 @@ func TestManager_CancelJob(t *testing.T) { ...@@ -232,7 +236,7 @@ func TestManager_CancelJob(t *testing.T) {
func TestManagerEventHandling(t *testing.T) { func TestManagerEventHandling(t *testing.T) {
mgr := NewManager() mgr := NewManager()
worker := NewLocalWorker(1) worker := NewLocalWorker(10)
err := mgr.AddWorker(worker) err := mgr.AddWorker(worker)
assert.Nil(t, err) assert.Nil(t, err)
......
...@@ -68,6 +68,8 @@ func (c *CounterRunnable) GetType() string { ...@@ -68,6 +68,8 @@ func (c *CounterRunnable) GetType() string {
} }
func (c *CounterRunnable) GetPersistence() RunnableImport { func (c *CounterRunnable) GetPersistence() RunnableImport {
c.mu.Lock()
defer c.mu.Unlock()
data := JSONMap{ data := JSONMap{
"count": c.Count, "count": c.Count,
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
package jobqueue package jobqueue
import "context" import (
"context"
"sync"
)
func NewDummyRunnableFromMap(data map[string]any) (*DummyRunnable, error) { func NewDummyRunnableFromMap(data map[string]any) (*DummyRunnable, error) {
return &DummyRunnable{}, nil return &DummyRunnable{}, nil
...@@ -13,7 +16,9 @@ func NewDummyRunnableFromMap(data map[string]any) (*DummyRunnable, error) { ...@@ -13,7 +16,9 @@ func NewDummyRunnableFromMap(data map[string]any) (*DummyRunnable, error) {
type DummyResult struct { type DummyResult struct {
} }
type DummyRunnable struct{} type DummyRunnable struct {
mu sync.Mutex
}
func (d *DummyRunnable) Run(_ context.Context) (RunResult[DummyResult], error) { func (d *DummyRunnable) Run(_ context.Context) (RunResult[DummyResult], error) {
return RunResult[DummyResult]{ return RunResult[DummyResult]{
...@@ -26,6 +31,8 @@ func (d *DummyRunnable) GetType() string { ...@@ -26,6 +31,8 @@ func (d *DummyRunnable) GetType() string {
} }
func (c *DummyRunnable) GetPersistence() RunnableImport { func (c *DummyRunnable) GetPersistence() RunnableImport {
c.mu.Lock()
defer c.mu.Unlock()
data := JSONMap{} data := JSONMap{}
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"sync"
) )
func NewFileOperationRunnableFromMap(data map[string]interface{}) (*FileOperationRunnable, error) { func NewFileOperationRunnableFromMap(data map[string]interface{}) (*FileOperationRunnable, error) {
...@@ -60,6 +61,7 @@ type FileOperationRunnable struct { ...@@ -60,6 +61,7 @@ type FileOperationRunnable struct {
Operation string // z.B. "read", "write", "delete" Operation string // z.B. "read", "write", "delete"
FilePath string FilePath string
Content string // Optional, je nach Operation Content string // Optional, je nach Operation
mu sync.Mutex
} }
func (f *FileOperationRunnable) Run(_ context.Context) (RunResult[FileOperationResult], error) { func (f *FileOperationRunnable) Run(_ context.Context) (RunResult[FileOperationResult], error) {
...@@ -142,6 +144,8 @@ func (f *FileOperationRunnable) GetType() string { ...@@ -142,6 +144,8 @@ func (f *FileOperationRunnable) GetType() string {
} }
func (f *FileOperationRunnable) GetPersistence() RunnableImport { func (f *FileOperationRunnable) GetPersistence() RunnableImport {
f.mu.Lock()
defer f.mu.Unlock()
data := JSONMap{ data := JSONMap{
"operation": f.Operation, "operation": f.Operation,
......