diff --git a/database-1_test.go b/database-1_test.go index 4c57ee35df07116869043f2fc2e66e86bd8cefe2..db3401655e7a7474b5016afcadf432badae0ba29 100644 --- a/database-1_test.go +++ b/database-1_test.go @@ -23,15 +23,17 @@ func TestWriteToDB1(t *testing.T) { gormDB.Logger = gormDB.Logger.LogMode(4) manager := &Manager{database: gormDB} - saver := NewDBSaver().SetManager(manager) // Starte den DBSaver - p := StartDBSaver(saver) + promise := CreateAndStartJobSyncer(manager) ready := make(chan struct{}) - Then[bool, bool](p, func(value bool) (bool, error) { + var saver *JobSyncer + + Then[*JobSyncer, *JobSyncer](promise, func(value *JobSyncer) (*JobSyncer, error) { close(ready) + saver = value return value, nil }, func(e error) error { close(ready) @@ -51,15 +53,13 @@ func TestWriteToDB1(t *testing.T) { job.scheduler = scheduler - err = saver.SaveJob(job) - assert.Nil(t, err) - - err = saver.SaveJob(job) - assert.Nil(t, err) + saver.AddJob(job) + saver.AddJob(job) time.Sleep(1 * time.Second) - saver.Wait() + err = saver.Stop() + assert.Nil(t, err) // check if stats are in database var stats JobPersistence @@ -69,7 +69,4 @@ func TestWriteToDB1(t *testing.T) { assert.Equal(t, job.GetID(), stats.ID) assert.Equal(t, job.GetID(), stats.ID) - err = saver.Stop() - assert.Nil(t, err) - } diff --git a/database-2_test.go b/database-2_test.go index 5e5c6245773ab2ce62006829d92228df521ec75b..8673f8a114824cfb3262a5fa01d42036ebea0123 100644 --- a/database-2_test.go +++ b/database-2_test.go @@ -23,15 +23,18 @@ func TestWriteToDB2(t *testing.T) { gormDB.Logger = gormDB.Logger.LogMode(4) manager := &Manager{database: gormDB} - saver := NewDBSaver().SetManager(manager) + // Starte den DBSaver - p := StartDBSaver(saver) + p := CreateAndStartJobSyncer(manager) ready := make(chan struct{}) - Then[bool, bool](p, func(value bool) (bool, error) { + //var saver *JobSyncer + + Then[*JobSyncer, *JobSyncer](p, func(value *JobSyncer) (*JobSyncer, error) { close(ready) + // saver = value return value, nil }, func(e error) error { close(ready) @@ -89,15 +92,14 @@ func TestWriteToDB2(t *testing.T) { time.Sleep(1 * time.Second) - if mgr.dbSaver == nil { - t.Error("mgr.dbSaver == nil") + if mgr.jobSyncer == nil { + t.Error("mgr.JobSyncer == nil") return } time.Sleep(1 * time.Second) - err = mgr.dbSaver.SaveJob(job) - assert.Nil(t, err) + mgr.jobSyncer.AddJob(job) runtime.Gosched() time.Sleep(1 * time.Second) @@ -114,8 +116,7 @@ func TestWriteToDB2(t *testing.T) { runtime.Gosched() time.Sleep(1 * time.Second) - err = mgr.dbSaver.SaveJob(job) - assert.Nil(t, err) + mgr.jobSyncer.AddJob(job) time.Sleep(2 * time.Second) err = mgr.CancelJobSchedule("job1") @@ -142,6 +143,6 @@ func TestWriteToDB2(t *testing.T) { time.Sleep(1 * time.Second) - err = mgr.dbSaver.SaveJob(job) + mgr.jobSyncer.AddJob(job) } diff --git a/database-4_test.go b/database-4_test.go index 18970b5ecc9b8269c7eecf22ab486354d9f66d96..2cc20f2b515904459595ad78819027ffe78eefc3 100644 --- a/database-4_test.go +++ b/database-4_test.go @@ -39,7 +39,7 @@ func TestWriteToDB4(t *testing.T) { err = manager.ScheduleJob(job, scheduler) assert.Nil(t, err) - time.Sleep(200 * time.Millisecond) + time.Sleep(500 * time.Millisecond) // check is stats are the values above var tmpJob JobPersistence @@ -49,15 +49,15 @@ func TestWriteToDB4(t *testing.T) { // Validate the fields assert.Equal(t, JobID("job3"), tmpJob.ID) - assert.Equal(t, 21, tmpJob.Stats.RunCount) // +1 because of the first run - assert.Equal(t, 31, tmpJob.Stats.SuccessCount) // +1 because of the first run + assert.Equal(t, 
21, tmpJob.Stats.RunCount) + assert.Equal(t, 31, tmpJob.Stats.SuccessCount) assert.Equal(t, 40, tmpJob.Stats.ErrorCount) // reset stats err = manager.ResetJobStats(job.GetID()) assert.Nil(t, err) - time.Sleep(2 * time.Second) + time.Sleep(500 * time.Millisecond) var tmpJob2 JobPersistence // check is stats are the values above diff --git a/database-5_test.go b/database-5_test.go index 33bce67b1bf306ee665dbcb07b0b5780988d9e0c..399a61217eb0797f9508f0bff2479831f0644b30 100644 --- a/database-5_test.go +++ b/database-5_test.go @@ -35,14 +35,15 @@ func TestWriteToDB5(t *testing.T) { sameIDJob := NewJob[CounterResult]("jobSameID", runner) - // Trying to save a job with the same ID should return an error - err = mgr.dbSaver.SaveJob(sameIDJob) - assert.Nil(t, err) + // Trying to save a job with the same ID should do nothing + mgr.mu.Lock() + mgr.jobSyncer.AddJob(sameIDJob) + mgr.mu.Unlock() err = mgr.CancelJobSchedule("jobSameID") assert.Nil(t, err) - err = mgr.dbSaver.Stop() + err = mgr.jobSyncer.Stop() assert.Nil(t, err) diff --git a/database-6_test.go b/database-6_test.go index 2dd30284a1e174fad63559ffce1f477b93a218e8..fee7927b2628f0a449caacfa16ced150f22582db 100644 --- a/database-6_test.go +++ b/database-6_test.go @@ -7,7 +7,7 @@ package jobqueue import ( "fmt" "github.com/stretchr/testify/assert" - "gorm.io/driver/sqlite" + "gorm.io/driver/mysql" "gorm.io/gorm" "testing" "time" @@ -15,10 +15,13 @@ import ( func TestWriteToDB6(t *testing.T) { - db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) - if err != nil { - t.Fatalf("a error occurred while opening the database: %v", err) - } + // it is necessary to have a running mysql server + // docker rm -f mysql-test && \ + // docker run --name mysql-test -e MYSQL_ROOT_PASSWORD=my-secret-pw -e MYSQL_DATABASE=testdb -p 3306:3306 -d mysql:latest && \ + // docker logs -f mysql-test + + dsn := "root:my-secret-pw@tcp(localhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) manager := NewManager() manager.SetDB(db) @@ -34,7 +37,7 @@ func TestWriteToDB6(t *testing.T) { jobIDs := make([]JobID, numJobs) for i := 0; i < numJobs; i++ { - jobID := JobID(fmt.Sprintf("burstJob%d", i)) + jobID := JobID(fmt.Sprintf("burstJobA%d", i)) jobIDs[i] = jobID runner := &CounterRunnable{} @@ -44,11 +47,14 @@ func TestWriteToDB6(t *testing.T) { err = mgr.ScheduleJob(job, scheduler) assert.Nil(t, err) - err = mgr.dbSaver.SaveJob(job) - assert.Nil(t, err) + mgr.jobSyncer.AddJob(job) + time.Sleep(10 * time.Millisecond) + } - time.Sleep(10 * time.Second) + mgr.jobSyncer.Stop() + + time.Sleep(2 * time.Second) for _, jobID := range jobIDs { var tmpJob JobPersistence diff --git a/database-7_test.go b/database-7_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2bb2b1352b101b740f902aa1069d8a30d6ca90b7 --- /dev/null +++ b/database-7_test.go @@ -0,0 +1,33 @@ +package jobqueue + +import ( + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "testing" +) + +func TestCreateOrUpdateJob(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + assert.NoError(t, err) + + // Migrate the schema + assert.NoError(t, db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{})) + + runner := &CounterRunnable{} + job := NewJob[CounterResult]("job1", runner) + + assert.NoError(t, createOrUpdateJob(job, db)) + + var jobPersistence JobPersistence + assert.NoError(t, db.First(&jobPersistence, "id = ?", job.GetID()).Error) + 
assert.Equal(t, job.GetID(), jobPersistence.ID) + + assert.Equal(t, "", jobPersistence.Description) + assert.Equal(t, Priority(1), jobPersistence.Priority) + job.description = "Updated description" + assert.NoError(t, createOrUpdateJob(job, db)) + assert.NoError(t, db.First(&jobPersistence, "id = ?", job.GetID()).Error) + assert.Equal(t, "Updated description", jobPersistence.Description) + +} diff --git a/database.go b/database.go index 040885740a94a3e3cfb803119ca8a076367d6a37..edd611c26784170258756441e015a643d106da3d 100644 --- a/database.go +++ b/database.go @@ -1,169 +1,85 @@ -// Copyright 2023 schukai GmbH -// SPDX-License-Identifier: AGPL-3.0 - package jobqueue import ( - "context" "errors" "gorm.io/gorm" - "math/rand" - "sync" - "time" ) -type DBSaverStatus int +func (s *JobSyncer) DeleteJob(job GenericJob) error { + s.mu.Lock() + defer s.mu.Unlock() -const ( - DBSaverStatusStopped = iota - DBSaverStatusRunning -) + if s.manager == nil || s.manager.database == nil { + return ErrNoDatabaseConnection + } -type DBSaver struct { - saveChannel chan GenericJob - stopChan chan struct{} - migrateFlag bool - manager *Manager - status DBSaverStatus - mu sync.Mutex - jobSaveProgress sync.WaitGroup -} + return s.manager.database.Transaction(func(tx *gorm.DB) error { + permJob := job.GetPersistence() -type RunnerData string -type SchedulerData string + if err := tx.Where("job_id = ?", permJob.GetID()).Delete(&JobLog{}).Error; err != nil { + return err + } -// NewDBSaver creates a new DBSaver -func NewDBSaver() *DBSaver { - return &DBSaver{ - saveChannel: make(chan GenericJob, 1000), - stopChan: make(chan struct{}), - } -} + if err := tx.Where("job_id = ?", permJob.GetID()).Delete(&JobStats{}).Error; err != nil { + return err + } -// SetManager sets the manager of the DBSaver -func (s *DBSaver) SetManager(manager *Manager) *DBSaver { - s.mu.Lock() - defer s.mu.Unlock() + if err := tx.Delete(&permJob).Error; err != nil { + return err + } - s.manager = manager - return s + return nil + }) } -func (s *DBSaver) setStatus(status DBSaverStatus) *DBSaver { +func (s *JobSyncer) ResetLogs(job GenericJob) error { s.mu.Lock() defer s.mu.Unlock() - s.status = status - return s -} - -// isStatus returns true if the DBSaver has the given status -// the lock is not needed here, because it is only used in the Start() method -func (s *DBSaver) isStatus(status DBSaverStatus, lock bool) bool { - if lock { - s.mu.Lock() - defer s.mu.Unlock() + if s.manager == nil || s.manager.database == nil { + return ErrNoDatabaseConnection } - return s.status == status + return s.manager.database.Transaction(func(tx *gorm.DB) error { + permJob := job.GetPersistence() + if err := tx.Unscoped().Where("job_id = ?", permJob.GetID()).Delete(&JobLog{}).Error; err != nil { + return err + } + return nil + }) } -func StartDBSaver[P *Promise[bool]](s *DBSaver) *Promise[bool] { +func (s *JobSyncer) ResetStats(job GenericJob) error { s.mu.Lock() defer s.mu.Unlock() - return NewPromise[bool](func(resolve func(bool), reject func(error)) { - - if s.manager == nil || s.manager.database == nil { - reject(ErrNoDatabaseConnection) - return - } - - if s.isStatus(DBSaverStatusRunning, false) { - resolve(true) - return - } - - db := s.manager.database - if !s.migrateFlag { - err := db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{}) - if err != nil { - reject(err) - return - } - s.migrateFlag = true - } + if s.manager == nil || s.manager.database == nil { + return ErrNoDatabaseConnection + } - ready := make(chan struct{}) - go runSaver(s, db, 
ready) + job.ResetStats() + stats := job.GetStats() + return s.manager.database.Transaction(func(tx *gorm.DB) error { + return tx.Model(&JobStats{}).Where("job_id = ?", job.GetID()).Select("*").Omit("deleted_at", "created_at", "job_id").Updates(stats).Error - <-ready - resolve(true) }) } -// -//// Start starts the DBSaver -//func (s *DBSaver) Start() error { -// s.mu.Lock() -// defer s.mu.Unlock() -// -// if s.manager == nil || s.manager.database == nil { -// return ErrNoDatabaseConnection -// } -// -// if s.isStatus(DBSaverStatusRunning, false) { -// return nil -// } -// -// db := s.manager.database -// -// if !s.migrateFlag { -// err := db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{}) -// if err != nil { -// return err -// } -// s.migrateFlag = true -// } -// -// go runSaver(s, db) -// return nil -//} - -func runSaver(s *DBSaver, db *gorm.DB, ready chan struct{}) { - s.setStatus(DBSaverStatusRunning) - - defer func() { - // this runs after the function returns - // and needs to be protected by the lock - // of the setStatus method - //s.status = DBSaverStatusStopped - s.setStatus(DBSaverStatusStopped) - }() - - close(ready) - for { - - select { - case job := <-s.saveChannel: - s.jobSaveProgress.Add(1) - - err := CreateOrUpdateJob(job, db) +func (s *JobSyncer) CreateOrUpdateJob(job GenericJob) error { - if err != nil { - Error("Error while saving job", "error", err) - } - - s.jobSaveProgress.Done() + s.mu.Lock() + defer s.mu.Unlock() - case <-s.stopChan: - return - } + if s.manager == nil || s.manager.database == nil { + return ErrNoDatabaseConnection } + + return createOrUpdateJob(job, s.manager.database) + } -func CreateOrUpdateJob(job GenericJob, db *gorm.DB) error { +func createOrUpdateJob(job GenericJob, db *gorm.DB) error { return db.Transaction(func(tx *gorm.DB) error { @@ -220,18 +136,8 @@ func CreateOrUpdateJob(job GenericJob, db *gorm.DB) error { return tx.Error } - tx.Model(&permJob.Stats).Where("job_id = ?", permJob.GetID()). - Select( - []string{ - "run_count", - "success_count", - "error_count", - "time_metrics_avg_run_time", - "time_metrics_max_run_time", - "time_metrics_min_run_time", - "time_metrics_total_run_time", - }, - ). + tx.Model(&permJob.Stats). + Select("*").Omit("deleted_at", "created_at", "job_id"). 
UpdateColumns(permJob.Stats) if tx.Error != nil { @@ -250,185 +156,3 @@ func CreateOrUpdateJob(job GenericJob, db *gorm.DB) error { }) } - -func (s *DBSaver) Wait() { - s.jobSaveProgress.Wait() -} - -// Stop stops the DBSaver -func (s *DBSaver) Stop() error { - s.mu.Lock() - defer s.mu.Unlock() - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - go func() { - s.stopChan <- struct{}{} - s.jobSaveProgress.Wait() - cancel() - }() - - <-ctx.Done() - - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - Error("DBSaver did not stop in time") - return ctx.Err() - } - - s.status = DBSaverStatusStopped - return nil -} - -func exponentialBackoff(retry int) time.Duration { - waitTime := 100 * time.Millisecond - for i := 0; i < retry; i++ { - waitTime *= 2 - waitTime += time.Duration(rand.Int63n(int64(waitTime))) // #nosec G404 - } - return waitTime -} - -// SaveJob saves a job to the database -func (s *DBSaver) SaveJob(job GenericJob) error { - s.jobSaveProgress.Add(1) - s.mu.Lock() - defer func() { - if r := recover(); r != nil { - Error("Error while saving job", "error", r) - } - s.mu.Unlock() - s.jobSaveProgress.Done() - }() - - if s.saveChannel == nil { - return ErrDBSaverNotInitialized - } - if s.status != DBSaverStatusRunning { - return ErrDBSaverNotRunning - } - - maxRetries := 5 - - for retries := maxRetries; retries > 0; retries-- { - select { - case s.saveChannel <- job: - return nil - default: - Error("DBSaver channel is full, dropping safe for job with ID", "job_id", job.GetID()) - backoff := exponentialBackoff(maxRetries - retries) - Trace("DBSaver channel is full, retrying in", "backoff", backoff) - time.Sleep(backoff) - } - } - - return errors.New("failed to save job after multiple attempts") - -} - -func checkRunningSaver(s *DBSaver) (*gorm.DB, error) { - if s.manager == nil { - return nil, ErrNoManager - } - - if s.manager.database == nil { - return nil, ErrNoDatabaseConnection - } - - if !s.isStatus(DBSaverStatusRunning, false) { - return nil, ErrDBSaverNotRunning - } - - return s.manager.database, nil - -} - -// DeleteJob deletes a job from the database -func (s *DBSaver) DeleteJob(job GenericJob) error { - s.mu.Lock() - defer s.mu.Unlock() - var db *gorm.DB - var err error - - if db, err = checkRunningSaver(s); err != nil { - return err - } - - s.jobSaveProgress.Add(1) - return db.Transaction(func(tx *gorm.DB) error { - defer s.jobSaveProgress.Done() - permJob := job.GetPersistence() - - dbErr := tx.Where("job_id = ?", permJob.GetID()).Delete(&JobLog{}).Error - if dbErr != nil { - return dbErr - } - - dbErr = tx.Where("job_id = ?", permJob.GetID()).Delete(&JobStats{}).Error - if dbErr != nil { - return dbErr - } - - dbErr = tx.Delete(&permJob).Error - if dbErr != nil { - return dbErr - } - - return nil - }) - -} - -func (s *DBSaver) ResetLogs(job GenericJob) error { - s.mu.Lock() - defer s.mu.Unlock() - var db *gorm.DB - var err error - - if db, err = checkRunningSaver(s); err != nil { - return err - } - - s.jobSaveProgress.Add(1) - return db.Transaction(func(tx *gorm.DB) error { - defer s.jobSaveProgress.Done() - permJob := job.GetPersistence() - - // unscoped because we want to delete the logs finally - dbErr := tx.Unscoped().Where("job_id = ?", permJob.GetID()).Delete(&JobLog{}).Error - if dbErr != nil { - return dbErr - } - - return nil - }) -} - -func (s *DBSaver) ResetStats(job GenericJob) error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.saveChannel == nil { - return ErrDBSaverNotInitialized - } - - if s.status != 
DBSaverStatusRunning { - return ErrDBSaverNotRunning - } - - defer func() { - if r := recover(); r != nil { - Error("Error while saving job", "error", r) - } - }() - - job.ResetStats() - - select { - case s.saveChannel <- job: - default: - Error("DBSaver channel is full, dropping job with ID", "job_id", job.GetID()) - } - - return nil -} diff --git a/database_test.go b/database_test.go index b9f3b73e2d0262a1b79dc3d696ca5cd60ad88d33..f2a8269ccf523045c8c7114dc41c3044b9470139 100644 --- a/database_test.go +++ b/database_test.go @@ -1,148 +1,135 @@ -//go:build !runOnTask - -// Copyright 2023 schukai GmbH -// SPDX-License-Identifier: AGPL-3.0 package jobqueue import ( + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" "gorm.io/gorm" "testing" "time" - - "github.com/stretchr/testify/assert" - "gorm.io/driver/sqlite" ) -// -//func startTestMySQLDockerImageAndContainer(t *testing.T, port string, ctx context.Context) error { -// t.Helper() -// -// cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) -// if err != nil { -// return err -// } -// -// imageName := "mysql:8" -// -// reader, err := cli.ImagePull(ctx, imageName, types.ImagePullOptions{}) -// if err != nil { -// return err -// } -// -// // if debug image pull, comment out the following lines -// //_, _ = io.Copy(os.Stdout, reader) -// _ = reader -// -// hostConfig := &container.HostConfig{ -// PortBindings: nat.PortMap{ -// "3306/tcp": []nat.PortBinding{ -// { -// HostIP: DOCKER_TEST_HOST_IP, -// HostPort: port, -// }, -// }, -// }, -// } -// -// resp, err := cli.ContainerCreate(ctx, &container.Config{ -// Image: imageName, -// Env: []string{ -// "MYSQL_ROOT_PASSWORD=secret", -// "MYSQL_USER=user", -// "MYSQL_PASSWORD=secret", -// "MYSQL_DATABASE=test", -// }, -// }, hostConfig, nil, nil, "") -// -// if err != nil { -// return err -// } -// -// if err := cli.ContainerStart(ctx, resp.ID, types.ContainerStartOptions{}); err != nil { -// return err -// } -// -// go func() { -// <-ctx.Done() -// -// timeout := 0 -// stopOptions := container.StopOptions{ -// Timeout: &timeout, -// Signal: "SIGKILL", -// } -// newCtx, _ := context.WithTimeout(context.Background(), 60*time.Second) -// if err := cli.ContainerStop(newCtx, resp.ID, stopOptions); err != nil { -// t.Errorf("ContainerStop returned error: %v", err) -// } -// if err := cli.ContainerRemove(newCtx, resp.ID, types.ContainerRemoveOptions{ -// Force: true, -// }); err != nil { -// t.Errorf("ContainerRemove returned error: %v", err) -// } -// -// }() -// -// statusCh, errCh := cli.ContainerWait(ctx, resp.ID, container.WaitConditionNotRunning) -// select { -// case err := <-errCh: -// if err != nil { -// // empty error means container exited normally (see container_wait.go) -// if err.Error() == "" { -// return nil -// } -// -// return err -// } -// case <-statusCh: -// -// } -// -// return nil -//} - -func TestSaveJobWithSQLite(t *testing.T) { - - gormDB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) - if err != nil { - t.Fatalf("a error occurred while opening the database: %v", err) - } +func TestDeleteJob(t *testing.T) { + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) - manager := &Manager{database: gormDB} - saver := NewDBSaver().SetManager(manager) + db.Logger = db.Logger.LogMode(4) - // Starte den DBSaver - p := StartDBSaver(saver) + assert.Nil(t, err) + err = db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{}) + assert.Nil(t, err) - ready := make(chan struct{}) + 
manager := NewManager() + manager.SetDB(db) + jobSyncer := NewJobSyncer(manager) - Then[bool, bool](p, func(value bool) (bool, error) { - close(ready) - return value, nil - }, func(e error) error { - close(ready) - Error("Error while starting db saver", "error", err) - return nil - }) + // Erstelle einen Job zum Löschen + runner := &CounterRunnable{} + job := NewJob[CounterResult]("testJobID", runner) + err = createOrUpdateJob(job, db) + assert.Nil(t, err) - <-ready + var count int64 + db.Model(&JobPersistence{}).Where("id = ?", "testJobID").Count(&count) + assert.Equal(t, int64(1), count) - jobID := JobID("testJob") - job := NewJob[CounterResult](jobID, &CounterRunnable{}) + // Lösche den Job + err = jobSyncer.DeleteJob(job) + assert.Nil(t, err) - err = saver.SaveJob(job) - assert.NoError(t, err) + // Überprüfe, ob der Job gelöscht wurde + db.Model(&JobPersistence{}).Where("id = ?", "testJobID").Count(&count) + assert.Equal(t, int64(0), count) +} +func TestResetLogs(t *testing.T) { + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + assert.Nil(t, err) + + db.Logger = db.Logger.LogMode(4) + + // Automatische Migration für benötigte Strukturen + err = db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{}) + assert.Nil(t, err) + + manager := NewManager() + manager.SetDB(db) + jobSyncer := NewJobSyncer(manager) + + // Erstelle einen Job und füge einige Logs hinzu + runner := &CounterRunnable{} + job := NewJob[CounterResult]("testJobID", runner) + err = createOrUpdateJob(job, db) + assert.Nil(t, err) + + // Füge Logs zum Job hinzu + for i := 0; i < 5; i++ { + log := JobLog{JobID: job.GetID(), Result: "Test Message"} + err = db.Create(&log).Error + assert.Nil(t, err) + } - time.Sleep(100 * time.Millisecond) + var logCount int64 + db.Model(&JobLog{}).Where("job_id = ?", job.GetID()).Count(&logCount) + assert.Equal(t, int64(5), logCount) - saver.Stop() + // Setze die Logs zurück + err = jobSyncer.ResetLogs(job) + assert.Nil(t, err) - var count int64 - gormDB.Model(&JobPersistence{}).Count(&count) - assert.Equal(t, int64(1), count, "It should be 1 job in the database") + // Überprüfe, ob die Logs gelöscht wurden - // get job from database - var jobFromDB JobPersistence - gormDB.First(&jobFromDB, "id = ?", jobID) - assert.Equal(t, jobID, jobFromDB.ID, "JobID should be the same") + db.Model(&JobLog{}).Where("job_id = ?", job.GetID()).Count(&logCount) + assert.Equal(t, int64(0), logCount) +} +func TestResetStats(t *testing.T) { + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + assert.Nil(t, err) + + db.Logger = db.Logger.LogMode(4) + + // Automatische Migration für benötigte Strukturen + err = db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{}) + assert.Nil(t, err) + + manager := NewManager() + manager.SetDB(db) + jobSyncer := NewJobSyncer(manager) + + // Erstelle einen Job und setze einige Statistiken + runner := &CounterRunnable{} + job := NewJob[CounterResult]("testJobID", runner) + err = createOrUpdateJob(job, db) + assert.Nil(t, err) + + // Aktualisiere die Job-Statistiken + jobStats := &JobStats{ + JobID: job.GetID(), + RunCount: 5, + SuccessCount: 3, + ErrorCount: 2, + TimeMetrics: TimeMetrics{ + AvgRunTime: 10 * time.Second, + MaxRunTime: 15 * time.Second, + MinRunTime: 5 * time.Second, + TotalRunTime: 50 * time.Second, + }, + } + err = db.Save(jobStats).Error + assert.Nil(t, err) + + // Setze die Statistiken zurück + err = jobSyncer.ResetStats(job) + assert.Nil(t, err) + + // Überprüfe, ob die Statistiken 
zurückgesetzt wurden + var resetStats JobStats + err = db.First(&resetStats, "job_id = ?", job.GetID()).Error + assert.Nil(t, err) + + assert.Equal(t, int(0), resetStats.RunCount) + assert.Equal(t, int(0), resetStats.SuccessCount) + assert.Equal(t, int(0), resetStats.ErrorCount) + assert.Equal(t, time.Duration(0), resetStats.TimeMetrics.AvgRunTime) + assert.Equal(t, time.Duration(0), resetStats.TimeMetrics.MaxRunTime) + assert.Equal(t, time.Duration(0), resetStats.TimeMetrics.MinRunTime) + assert.Equal(t, time.Duration(0), resetStats.TimeMetrics.TotalRunTime) } diff --git a/devenv.lock b/devenv.lock index 760ce636be6b188e97f061413cd3de47443e99f7..6b47afb1fd265375d548fa1e63cb09538c404a1d 100644 --- a/devenv.lock +++ b/devenv.lock @@ -74,11 +74,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1710162809, - "narHash": "sha256-i2R2bcnQp+85de67yjgZVvJhd6rRnJbSYNpGmB6Leb8=", + "lastModified": 1710695816, + "narHash": "sha256-3Eh7fhEID17pv9ZxrPwCLfqXnYP006RKzSs0JptsN84=", "owner": "nixos", "repo": "nixpkgs", - "rev": "ddcd7598b2184008c97e6c9c6a21c5f37590b8d2", + "rev": "614b4613980a522ba49f0d194531beddbb7220d3", "type": "github" }, "original": { @@ -106,11 +106,11 @@ }, "nixpkgs_2": { "locked": { - "lastModified": 1710162809, - "narHash": "sha256-i2R2bcnQp+85de67yjgZVvJhd6rRnJbSYNpGmB6Leb8=", + "lastModified": 1710695816, + "narHash": "sha256-3Eh7fhEID17pv9ZxrPwCLfqXnYP006RKzSs0JptsN84=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "ddcd7598b2184008c97e6c9c6a21c5f37590b8d2", + "rev": "614b4613980a522ba49f0d194531beddbb7220d3", "type": "github" }, "original": { diff --git a/errors.go b/errors.go index 94e08d647da31650a0acad04d1d266aa1890cba1..eafc6bbbb96dedc934bc0a6fd9bfa203d596dd4c 100644 --- a/errors.go +++ b/errors.go @@ -50,4 +50,6 @@ var ( ErrInvalidTime = fmt.Errorf("invalid time") ErrSchedulerMisconfiguration = fmt.Errorf("scheduler misconfiguration") ErrInvalidDuration = fmt.Errorf("invalid duration") + ErrJobSyncerAlreadyRunning = fmt.Errorf("JobSyncer is already running") + ErrJobSyncerNotRunning = fmt.Errorf("JobSyncer is not running") ) diff --git a/event-bus.go b/event-bus.go index bf9f491454c43beccf484d226ec5c460bf380821..4038de7755732476d0b6666456ea0b6140c59a98 100644 --- a/event-bus.go +++ b/event-bus.go @@ -83,6 +83,11 @@ func (eb *EventBus) Unsubscribe(name EventName, ch chan interface{}) { for i := range channels { if channels[i] == ch { eb.subscribers[name] = append(channels[:i], channels[i+1:]...) 
+ + if len(eb.subscribers[name]) == 0 { + delete(eb.subscribers, name) + } + break } } diff --git a/go.mod b/go.mod index 07b3053dd0bb0ad4413d1e11591ecaf1ab4ef727..6baeff1dcb6d38f7dd8f3aeb4ed8ac21802b41d5 100644 --- a/go.mod +++ b/go.mod @@ -16,8 +16,8 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/crypto v0.21.0 gopkg.in/yaml.v3 v3.0.1 - gorm.io/driver/mysql v1.5.4 - gorm.io/gorm v1.25.7 + gorm.io/driver/mysql v1.5.5 + gorm.io/gorm v1.25.8 ) require ( diff --git a/go.sum b/go.sum index 0e4b77cef74ebf07b51e167498a695aef09417c7..03a350f93722601728cec9f6e0ac5faec79e81d1 100644 --- a/go.sum +++ b/go.sum @@ -208,6 +208,8 @@ gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs= gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8= gorm.io/driver/mysql v1.5.4 h1:igQmHfKcbaTVyAIHNhhB888vvxh8EdQ2uSUT0LPcBso= gorm.io/driver/mysql v1.5.4/go.mod h1:9rYxJph/u9SWkWc9yY4XJ1F/+xO0S/ChOmbk3+Z5Tvs= +gorm.io/driver/mysql v1.5.5 h1:WxklwX6FozMs1gk9yVadxGfjGiJjrBKPvIIvYZOMyws= +gorm.io/driver/mysql v1.5.5/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= @@ -216,5 +218,7 @@ gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.8 h1:WAGEZ/aEcznN4D03laj8DKnehe1e9gYQAjW8xyPRdeo= +gorm.io/gorm v1.25.8/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/job-generic.go b/job-generic.go index eaaaf87b71c1f9356f263b8468f456e1571a3551..4c3c9fa92208af30f5d654f9577ae95933edad92 100644 --- a/job-generic.go +++ b/job-generic.go @@ -38,4 +38,6 @@ type GenericJob interface { IsPaused() bool ResetStats() + + GetStats() JobStats } diff --git a/job-queues.iml b/job-queues.iml new file mode 100644 index 0000000000000000000000000000000000000000..789c0e8b4a221b9a6b43b7164dd4a46ca78d5822 --- /dev/null +++ b/job-queues.iml @@ -0,0 +1,12 @@ +<?xml version="1.0" encoding="UTF-8"?> +<module type="JAVA_MODULE" version="4"> + <component name="NewModuleRootManager" inherit-compiler-output="true"> + <exclude-output /> + <content url="file://$MODULE_DIR$"> + <sourceFolder url="file://$MODULE_DIR$/.devenv/state/go/pkg/mod/github.com/google/addlicense@v1.1.1/testdata/expected" isTestSource="false" /> + <sourceFolder url="file://$MODULE_DIR$/.devenv/state/go/pkg/mod/github.com/google/addlicense@v1.1.1/testdata/initial" isTestSource="false" /> + </content> + <orderEntry type="inheritedJdk" /> + <orderEntry type="sourceFolder" forTests="false" /> + </component> +</module> \ No newline at end of file diff --git a/job-syncer.go b/job-syncer.go new file mode 100644 index 0000000000000000000000000000000000000000..53b4a1a22e80384b07a394b6afee609671e46bec --- /dev/null +++ b/job-syncer.go @@ -0,0 +1,169 @@ +// Copyright 2023 schukai GmbH +// SPDX-License-Identifier: AGPL-3.0 + +package jobqueue + +import ( + "sync" +) + +type JobSyncer struct { + jobQueue []GenericJob + queueLock 
sync.Mutex + notifyChannel chan struct{} + stopChan chan struct{} + jobSaveProgress sync.WaitGroup + status Status + manager *Manager + mu sync.Mutex + migrateFlag bool +} + +type Status int + +const ( + JobSyncerStatusStopped Status = iota + JobSyncerStatusRunning +) + +func CreateAndStartJobSyncer[P *Promise[*JobSyncer]](manager *Manager) *Promise[*JobSyncer] { + + s := NewJobSyncer(manager) + + s.mu.Lock() + defer s.mu.Unlock() + + return NewPromise[*JobSyncer](func(resolve func(*JobSyncer), reject func(error)) { + + if s.manager == nil || s.manager.database == nil { + reject(ErrNoDatabaseConnection) + return + } + + if s.status == JobSyncerStatusRunning { + resolve(s) + return + } + + db := s.manager.database + if !s.migrateFlag { + err := db.AutoMigrate(&JobPersistence{}, &JobLog{}, &JobStats{}) + if err != nil { + reject(err) + return + } + s.migrateFlag = true + } + + err := s.Start() + if err != nil { + reject(err) + return + } + resolve(s) + }) +} + +func NewJobSyncer(manager *Manager) *JobSyncer { + return &JobSyncer{ + jobQueue: make([]GenericJob, 0), + manager: manager, + } +} + +func (js *JobSyncer) Start() error { + js.mu.Lock() + defer js.mu.Unlock() + + if js.status == JobSyncerStatusRunning { + return ErrJobSyncerAlreadyRunning + } + + js.notifyChannel = make(chan struct{}, 1) // Buffer to avoid blocking + js.stopChan = make(chan struct{}) + js.status = JobSyncerStatusRunning + + go js.runWorker() + return nil +} + +func (js *JobSyncer) runWorker() { + for { + select { + case <-js.notifyChannel: + js.processJobs() + case <-js.stopChan: + js.cleanup() + return + } + } +} + +func (js *JobSyncer) processJobs() { + for { + js.queueLock.Lock() + if len(js.jobQueue) == 0 { + js.queueLock.Unlock() + return + } + job := js.jobQueue[0] + js.jobQueue = js.jobQueue[1:] + js.queueLock.Unlock() + + js.jobSaveProgress.Add(1) + js.processJob(job) + js.jobSaveProgress.Done() + } +} + +func (js *JobSyncer) AddJob(job GenericJob) { + js.queueLock.Lock() + + // check if job is already in queue + for _, j := range js.jobQueue { + if j.GetID() == job.GetID() { + js.queueLock.Unlock() + return + } + } + + js.jobQueue = append(js.jobQueue, job) + js.queueLock.Unlock() + + js.jobSaveProgress.Add(1) + + // Non-blocking notify + select { + case js.notifyChannel <- struct{}{}: + default: + } +} + +func (js *JobSyncer) processJob(job GenericJob) { + defer js.jobSaveProgress.Done() + err := createOrUpdateJob(job, js.manager.database) + Error("Error while creating or updating job", "error", err) +} + +func (js *JobSyncer) Stop() error { + js.mu.Lock() + if js.status != JobSyncerStatusRunning { + js.mu.Unlock() + return ErrJobSyncerNotRunning + } + js.status = JobSyncerStatusStopped + js.mu.Unlock() + + close(js.stopChan) + js.jobSaveProgress.Wait() + + return nil +} + +func (js *JobSyncer) cleanup() { + js.mu.Lock() + defer js.mu.Unlock() + + js.jobSaveProgress.Wait() + js.status = JobSyncerStatusStopped +} diff --git a/job-syncer_test.go b/job-syncer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c0453f3886592959c62228e50b4eb901a9bbe4d7 --- /dev/null +++ b/job-syncer_test.go @@ -0,0 +1,62 @@ +//go:build !runOnTask + +// Copyright 2023 schukai GmbH +// SPDX-License-Identifier: AGPL-3.0 +package jobqueue + +import ( + "gorm.io/gorm" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" +) + +func TestSaveJobWithSQLite(t *testing.T) { + + gormDB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + 
t.Fatalf("a error occurred while opening the database: %v", err) + } + + manager := &Manager{database: gormDB} + //saver := NewJobSyncer(manager) + + // Starte den DBSaver + p := CreateAndStartJobSyncer(manager) + + ready := make(chan struct{}) + var saver *JobSyncer + + Then[*JobSyncer, *JobSyncer](p, func(value *JobSyncer) (*JobSyncer, error) { + saver = value + close(ready) + return value, nil + }, func(e error) error { + close(ready) + Error("Error while starting db saver", "error", err) + return nil + }) + + <-ready + + jobID := JobID("testJob") + job := NewJob[CounterResult](jobID, &CounterRunnable{}) + + saver.AddJob(job) + + time.Sleep(100 * time.Millisecond) + + saver.Stop() + + var count int64 + gormDB.Model(&JobPersistence{}).Count(&count) + assert.Equal(t, int64(1), count, "It should be 1 job in the database") + + // get job from database + var jobFromDB JobPersistence + gormDB.First(&jobFromDB, "id = ?", jobID) + assert.Equal(t, jobID, jobFromDB.ID, "JobID should be the same") + +} diff --git a/licenses/filippo.io/edwards25519/LICENSE b/licenses/filippo.io/edwards25519/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6a66aea5eafe0ca6a688840c47219556c552488e --- /dev/null +++ b/licenses/filippo.io/edwards25519/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/licenses/github.com/go-sql-driver/mysql/.github/workflows/codeql.yml b/licenses/github.com/go-sql-driver/mysql/.github/workflows/codeql.yml index d9d29a8b7e4fe092012c9e6847b48324a206a838..83a3d6ee8b0ccb23d032cddfddbe39759696d696 100644 --- a/licenses/github.com/go-sql-driver/mysql/.github/workflows/codeql.yml +++ b/licenses/github.com/go-sql-driver/mysql/.github/workflows/codeql.yml @@ -24,18 +24,18 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} queries: +security-and-quality - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 with: category: "/language:${{ matrix.language }}" diff --git a/licenses/github.com/go-sql-driver/mysql/.github/workflows/test.yml b/licenses/github.com/go-sql-driver/mysql/.github/workflows/test.yml index d45ed0fa945d9837f960afeb334a526252a242fd..f5a115802013e057eddfd8a95dadab7ff195eba7 100644 --- a/licenses/github.com/go-sql-driver/mysql/.github/workflows/test.yml +++ b/licenses/github.com/go-sql-driver/mysql/.github/workflows/test.yml @@ -11,6 +11,14 @@ env: MYSQL_TEST_CONCURRENT: 1 jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dominikh/staticcheck-action@v1.3.0 + with: + version: "2023.1.6" + list: runs-on: ubuntu-latest outputs: @@ -23,17 +31,14 @@ jobs: import os go = [ # Keep the most recent production release at the top - '1.20', + '1.21', # Older production releases + '1.20', '1.19', '1.18', - '1.17', - '1.16', - '1.15', - '1.14', - '1.13', ] mysql = [ + '8.1', '8.0', '5.7', '5.6', @@ -47,7 +52,7 @@ jobs: includes = [] # Go versions compatibility check for v in go[1:]: - includes.append({'os': 'ubuntu-latest', 'go': v, 'mysql': mysql[0]}) + includes.append({'os': 'ubuntu-latest', 'go': v, 'mysql': mysql[0]}) matrix = { # OS vs MySQL versions @@ -68,11 +73,11 @@ jobs: fail-fast: false matrix: ${{ fromJSON(needs.list.outputs.matrix) }} steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - - uses: shogo82148/actions-setup-mysql@v1.15.0 + - uses: shogo82148/actions-setup-mysql@v1 with: mysql-version: ${{ matrix.mysql }} user: ${{ env.MYSQL_TEST_USER }} @@ -84,13 +89,14 @@ jobs: ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 + performance_schema=on - name: setup database run: | mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;' - name: test run: | - go test -v '-covermode=count' '-coverprofile=coverage.out' + go test -v '-race' '-covermode=atomic' '-coverprofile=coverage.out' -parallel 10 - name: Send coverage uses: shogo82148/actions-goveralls@v1 diff --git a/licenses/github.com/go-sql-driver/mysql/AUTHORS b/licenses/github.com/go-sql-driver/mysql/AUTHORS index fb1478c3bc67d53611fd903e60eb601d2dc28aa3..4021b96cc0c51b4dec36ad26000531964910011c 100644 --- a/licenses/github.com/go-sql-driver/mysql/AUTHORS +++ b/licenses/github.com/go-sql-driver/mysql/AUTHORS @@ -13,6 +13,7 @@ Aaron Hopkins <go-sql-driver at die.net> Achille Roussel <achille.roussel at gmail.com> +Aidan <aidan.liu at pingcap.com> Alex Snast <alexsn at fb.com> Alexey Palazhchenko <alexey.palazhchenko at gmail.com> Andrew Reid <andrew.reid at 
tixtrack.com> @@ -20,12 +21,14 @@ Animesh Ray <mail.rayanimesh at gmail.com> Arne Hormann <arnehormann at gmail.com> Ariel Mashraki <ariel at mashraki.co.il> Asta Xie <xiemengjun at gmail.com> +Brian Hendriks <brian at dolthub.com> Bulat Gaifullin <gaifullinbf at gmail.com> Caine Jette <jette at alum.mit.edu> Carlos Nieto <jose.carlos at menteslibres.net> Chris Kirkland <chriskirkland at github.com> Chris Moos <chris at tech9computers.com> Craig Wilson <craiggwilson at gmail.com> +Daemonxiao <735462752 at qq.com> Daniel Montoya <dsmontoyam at gmail.com> Daniel Nichter <nil at codenode.com> Daniël van Eeden <git at myname.nl> @@ -33,9 +36,11 @@ Dave Protasowski <dprotaso at gmail.com> DisposaBoy <disposaboy at dby.me> Egor Smolyakov <egorsmkv at gmail.com> Erwan Martin <hello at erwan.io> +Evan Elias <evan at skeema.net> Evan Shaw <evan at vendhq.com> Frederick Mayle <frederickmayle at gmail.com> Gustavo Kristic <gkristic at gmail.com> +Gusted <postmaster at gusted.xyz> Hajime Nakagami <nakagami at gmail.com> Hanno Braun <mail at hannobraun.com> Henri Yandell <flamefew at gmail.com> @@ -47,8 +52,11 @@ INADA Naoki <songofacandy at gmail.com> Jacek Szwec <szwec.jacek at gmail.com> James Harr <james.harr at gmail.com> Janek Vedock <janekvedock at comcast.net> +Jason Ng <oblitorum at gmail.com> +Jean-Yves Pellé <jy at pelle.link> Jeff Hodges <jeff at somethingsimilar.com> Jeffrey Charles <jeffreycharles at gmail.com> +Jennifer Purevsuren <jennifer at dolthub.com> Jerome Meyer <jxmeyer at gmail.com> Jiajia Zhong <zhong2plus at gmail.com> Jian Zhen <zhenjl at gmail.com> @@ -74,9 +82,11 @@ Maciej Zimnoch <maciej.zimnoch at codilime.com> Michael Woolnough <michael.woolnough at gmail.com> Nathanial Murphy <nathanial.murphy at gmail.com> Nicola Peduzzi <thenikso at gmail.com> +Oliver Bone <owbone at github.com> Olivier Mengué <dolmen at cpan.org> oscarzhao <oscarzhaosl at gmail.com> Paul Bonser <misterpib at gmail.com> +Paulius Lozys <pauliuslozys at gmail.com> Peter Schultz <peter.schultz at classmarkets.com> Phil Porada <philporada at gmail.com> Rebecca Chin <rchin at pivotal.io> @@ -95,6 +105,7 @@ Stan Putrya <root.vagner at gmail.com> Stanley Gunawan <gunawan.stanley at gmail.com> Steven Hartland <steven.hartland at multiplay.co.uk> Tan Jinhua <312841925 at qq.com> +Tetsuro Aoki <t.aoki1130 at gmail.com> Thomas Wodarek <wodarekwebpage at gmail.com> Tim Ruffles <timruffles at gmail.com> Tom Jenkinson <tom at tjenkinson.me> @@ -104,6 +115,7 @@ Xiangyu Hu <xiangyu.hu at outlook.com> Xiaobing Jiang <s7v7nislands at gmail.com> Xiuming Chen <cc at cxm.cc> Xuehong Chan <chanxuehong at gmail.com> +Zhang Xiang <angwerzx at 126.com> Zhenye Xie <xiezhenye at gmail.com> Zhixin Wen <john.wenzhixin at gmail.com> Ziheng Lyu <zihenglv at gmail.com> @@ -113,14 +125,18 @@ Ziheng Lyu <zihenglv at gmail.com> Barracuda Networks, Inc. Counting Ltd. DigitalOcean Inc. +Dolthub Inc. dyves labs AG Facebook Inc. GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. +Microsoft Corp. Multiplay Ltd. Percona LLC +PingCAP Inc. Pivotal Inc. +Shattered Silicon Ltd. Stripe Inc. Zendesk Inc. 
diff --git a/licenses/github.com/go-sql-driver/mysql/CHANGELOG.md b/licenses/github.com/go-sql-driver/mysql/CHANGELOG.md index 5166e4adb57addae33a435bdbb6979796f87e1e9..213215c8d55b853cc2ca240ace7f2e23500cee08 100644 --- a/licenses/github.com/go-sql-driver/mysql/CHANGELOG.md +++ b/licenses/github.com/go-sql-driver/mysql/CHANGELOG.md @@ -162,7 +162,7 @@ New Features: - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) - Support for returning table alias on Columns() (#289, #359, #382) - - Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490) + - Placeholder interpolation, can be activated with the DSN parameter `interpolateParams=true` (#309, #318, #490) - Support for uint64 parameters with high bit set (#332, #345) - Cleartext authentication plugin support (#327) - Exported ParseDSN function and the Config struct (#403, #419, #429) @@ -206,7 +206,7 @@ Changes: - Also exported the MySQLWarning type - mysqlConn.Close returns the first error encountered instead of ignoring all errors - writePacket() automatically writes the packet size to the header - - readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets + - readPacket() uses an iterative approach instead of the recursive approach to merge split packets New Features: @@ -254,7 +254,7 @@ Bugfixes: - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification - Convert to DB timezone when inserting `time.Time` - - Splitted packets (more than 16MB) are now merged correctly + - Split packets (more than 16MB) are now merged correctly - Fixed false positive `io.EOF` errors when the data was fully read - Avoid panics on reuse of closed connections - Fixed empty string producing false nil values diff --git a/licenses/github.com/go-sql-driver/mysql/README.md b/licenses/github.com/go-sql-driver/mysql/README.md index 3b5d229aae97744c4d92993f0bf0bfa3c1c365d7..9d0d806ef2c12db93aac8a837079267e36ae149e 100644 --- a/licenses/github.com/go-sql-driver/mysql/README.md +++ b/licenses/github.com/go-sql-driver/mysql/README.md @@ -40,15 +40,23 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Optional placeholder interpolation ## Requirements - * Go 1.13 or higher. We aim to support the 3 latest versions of Go. - * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) + +* Go 1.19 or higher. We aim to support the 3 latest versions of Go. +* MySQL (5.7+) and MariaDB (10.3+) are supported. +* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. + * Do not ask questions about TiDB in our issue tracker or forum. + * [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang) + * [Forum](https://ask.pingcap.com/) +* go-mysql would work with Percona Server, Google CloudSQL or Sphinx (2.2.3+). + * Maintainers won't support them. Do not expect issues are investigated and resolved by maintainers. + * Investigate issues yourself and please send a pull request to fix it. --------------------------------------- ## Installation Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: ```bash -$ go get -u github.com/go-sql-driver/mysql +go get -u github.com/go-sql-driver/mysql ``` Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. 
@@ -114,6 +122,12 @@ This has the same effect as an empty DSN string: ``` +`dbname` is escaped by [PathEscape()](https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes: + +``` +/dbname%2Fwithslash +``` + Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. #### Password @@ -121,7 +135,7 @@ Passwords can consist of any character. Escaping is **not** necessary. #### Protocol See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. -In general you should use an Unix domain socket if available and TCP otherwise for best performance. +In general you should use a Unix domain socket if available and TCP otherwise for best performance. #### Address For TCP and UDP networks, addresses have the form `host[:port]`. @@ -145,7 +159,7 @@ Default: false ``` `allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. -[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) +[*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local) ##### `allowCleartextPasswords` @@ -194,10 +208,9 @@ Valid Values: <name> Default: none ``` -Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). +Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). -Usage of the `charset` parameter is discouraged because it issues additional queries to the server. -Unless you need the fallback behavior, please use `collation` instead. +See also [Unicode Support](#unicode-support). ##### `checkConnLiveness` @@ -226,6 +239,7 @@ The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You s Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). +See also [Unicode Support](#unicode-support). ##### `clientFoundRows` @@ -279,6 +293,15 @@ Note that this sets the location for time.Time values but does not change MySQL' Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. +##### `timeTruncate` + +``` +Type: duration +Default: 0 +``` + +[Truncate time values](https://pkg.go.dev/time#Duration.Truncate) to the specified duration. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + ##### `maxAllowedPacket` ``` Type: decimal number @@ -295,9 +318,25 @@ Valid Values: true, false Default: false ``` -Allow multiple statements in one query. 
While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. +Allow multiple statements in one query. This can be used to bach multiple queries. Use [Rows.NextResultSet()](https://pkg.go.dev/database/sql#Rows.NextResultSet) to get result of the second and subsequent queries. + +When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly. + +It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example: -When `multiStatements` is used, `?` parameters must only be used in the first statement. +```go +conn, _ := db.Conn(ctx) +conn.Raw(func(conn interface{}) error { + ex := conn.(driver.Execer) + res, err := ex.Exec(` + UPDATE point SET x = 1 WHERE y = 2; + UPDATE point SET x = 2 WHERE y = 3; + `, nil) + // Both slices have 2 elements. + log.Print(res.(mysql.Result).AllRowsAffected()) + log.Print(res.(mysql.Result).AllLastInsertIds()) +}) +``` ##### `parseTime` @@ -393,6 +432,15 @@ Default: 0 I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. +##### `connectionAttributes` + +``` +Type: comma-delimited string of user-defined "key:value" pairs +Valid Values: (<name1>:<value1>,<name2>:<value2>,...) +Default: none +``` + +[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time. ##### System Variables @@ -465,7 +513,7 @@ user:password@/ The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. ## `ColumnType` Support -This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`. +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `BIGINT`. ## `context.Context` Support Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. @@ -478,7 +526,7 @@ For this feature you need direct access to the package. 
Therefore you must chang import "github.com/go-sql-driver/mysql" ``` -Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. @@ -496,9 +544,11 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v ### Unicode support Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. -Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. +Other charsets / collations can be set using the [`charset`](#charset) or [`collation`](#collation) DSN parameter. -Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. +- When only the `charset` is specified, the `SET NAMES <charset>` query is sent and the server's default collation is used. +- When both the `charset` and `collation` are specified, the `SET NAMES <charset> COLLATE <collation>` query is sent. +- When only the `collation` is specified, the collation is specified in the protocol handshake and the `SET NAMES` query is not sent. This can save one roundtrip, but note that the server may ignore the specified collation silently and use the server's default charset/collation instead. See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. diff --git a/licenses/github.com/go-sql-driver/mysql/auth.go b/licenses/github.com/go-sql-driver/mysql/auth.go index 1ff203e57bb029c352a2d190180993dfffc71a1d..658259b248d042d7029cb56e3f466621a235199c 100644 --- a/licenses/github.com/go-sql-driver/mysql/auth.go +++ b/licenses/github.com/go-sql-driver/mysql/auth.go @@ -13,10 +13,13 @@ import ( "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/pem" "fmt" "sync" + + "filippo.io/edwards25519" ) // server pub keys registry @@ -33,7 +36,7 @@ var ( // Note: The provided rsa.PublicKey instance is exclusively owned by the driver // after registering it and may not be modified. // -// data, err := ioutil.ReadFile("mykey.pem") +// data, err := os.ReadFile("mykey.pem") // if err != nil { // log.Fatal(err) // } @@ -225,6 +228,44 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } +// authEd25519 does ed25519 authentication used by MariaDB. 
+func authEd25519(scramble []byte, password string) ([]byte, error) { + // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c + // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 + h := sha512.Sum512([]byte(password)) + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + return nil, err + } + A := (&edwards25519.Point{}).ScalarBaseMult(s) + + mh := sha512.New() + mh.Write(h[32:]) + mh.Write(scramble) + messageDigest := mh.Sum(nil) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + return nil, err + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + kh.Write(R.Bytes()) + kh.Write(A.Bytes()) + kh.Write(scramble) + hramDigest := kh.Sum(nil) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + return nil, err + } + + S := k.MultiplyAdd(k, s, r) + + return append(R.Bytes(), S.Bytes()...), nil +} + func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) if err != nil { @@ -290,8 +331,14 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) return enc, err + case "client_ed25519": + if len(authData) != 32 { + return nil, ErrMalformPkt + } + return authEd25519(authData, mc.cfg.Passwd) + default: - errLog.Print("unknown auth plugin:", plugin) + mc.cfg.Logger.Print("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin } } @@ -338,7 +385,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { switch plugin { - // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + // https://dev.mysql.com/blog-archive/preparing-your-community-connector-for-mysql-8-part-2-sha256/ case "caching_sha2_password": switch len(authData) { case 0: @@ -346,7 +393,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case 1: switch authData[0] { case cachingSha2PasswordFastAuthSuccess: - if err = mc.readResultOK(); err == nil { + if err = mc.resultUnchanged().readResultOK(); err == nil { return nil // auth successful } @@ -376,13 +423,13 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } if data[0] != iAuthMoreData { - return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication") + return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication") } // parse public key block, rest := pem.Decode(data[1:]) if block == nil { - return fmt.Errorf("No Pem data found, data: %s", rest) + return fmt.Errorf("no pem data found, data: %s", rest) } pkix, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { @@ -397,7 +444,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return err } } - return mc.readResultOK() + return mc.resultUnchanged().readResultOK() default: return ErrMalformPkt @@ -426,7 +473,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { if err != nil { return err } - return mc.readResultOK() + return mc.resultUnchanged().readResultOK() } default: diff --git a/licenses/github.com/go-sql-driver/mysql/auth_test.go b/licenses/github.com/go-sql-driver/mysql/auth_test.go index 
3ce0ea6e0c8bfc46f2a415134162c29885eef050..8caed1fff74b47f40368ef99c8ad378c25b2097e 100644 --- a/licenses/github.com/go-sql-driver/mysql/auth_test.go +++ b/licenses/github.com/go-sql-driver/mysql/auth_test.go @@ -1328,3 +1328,54 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { t.Errorf("got unexpected data: %v", conn.written) } } + +// Derived from https://github.com/MariaDB/server/blob/6b2287fff23fbdc362499501c562f01d0d2db52e/plugin/auth_ed25519/ed25519-t.c +func TestEd25519Auth(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "foobar" + + authData := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + plugin := "client_ed25519" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{ + 232, 61, 201, 63, 67, 63, 51, 53, 86, 73, 238, 35, 170, 117, 146, + 214, 26, 17, 35, 9, 8, 132, 245, 141, 48, 99, 66, 58, 36, 228, 48, + 84, 115, 254, 187, 168, 88, 162, 249, 57, 35, 85, 79, 238, 167, 106, + 68, 117, 56, 135, 171, 47, 20, 14, 133, 79, 15, 229, 124, 160, 176, + 100, 138, 14, + } + if writtenAuthRespLen != 64 { + t.Fatalf("expected 64 bytes from client, got %d", writtenAuthRespLen) + } + if !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("auth response did not match expected value:\n%v\n%v", writtenAuthResp, expectedAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} diff --git a/licenses/github.com/go-sql-driver/mysql/benchmark_test.go b/licenses/github.com/go-sql-driver/mysql/benchmark_test.go index 97ed781f8f10e6d65ff9517f2c03143b59011e85..a4ecc0a63e2eb3143c930ba16714bc0752c28fef 100644 --- a/licenses/github.com/go-sql-driver/mysql/benchmark_test.go +++ b/licenses/github.com/go-sql-driver/mysql/benchmark_test.go @@ -48,7 +48,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { func initDB(b *testing.B, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -105,7 +105,7 @@ func BenchmarkExec(b *testing.B) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) db.SetMaxIdleConns(concurrencyLevel) defer db.Close() @@ -151,7 +151,7 @@ func BenchmarkRoundtripTxt(b *testing.B) { sampleString := string(sample) b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() b.StartTimer() var result string @@ -184,7 +184,7 @@ func BenchmarkRoundtripBin(b *testing.B) { sample, min, max := initRoundtripBenchmarks() b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() stmt := tb.checkStmt(db.Prepare("SELECT ?")) 
defer stmt.Close() @@ -372,3 +372,59 @@ func BenchmarkQueryRawBytes(b *testing.B) { }) } } + +// BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. +func BenchmarkReceiveMassiveRows(b *testing.B) { + // Setup -- prepare 10000 rows. + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") + defer db.Close() + + sval := strings.Repeat("x", 50) + stmt, err := db.Prepare(`INSERT INTO foo (id, val) VALUES (?, ?)` + strings.Repeat(",(?,?)", 99)) + if err != nil { + b.Errorf("failed to prepare query: %v", err) + return + } + for i := 0; i < 10000; i += 100 { + args := make([]any, 200) + for j := 0; j < 100; j++ { + args[j*2] = i + j + args[j*2+1] = sval + } + _, err := stmt.Exec(args...) + if err != nil { + b.Error(err) + return + } + } + stmt.Close() + + // Use b.Run() to skip expensive setup. + b.Run("query", func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + rows, err := db.Query(`SELECT id, val FROM foo`) + if err != nil { + b.Errorf("failed to select: %v", err) + return + } + for rows.Next() { + var i int + var s sql.RawBytes + err = rows.Scan(&i, &s) + if err != nil { + b.Errorf("failed to scan: %v", err) + _ = rows.Close() + return + } + } + if err = rows.Err(); err != nil { + b.Errorf("failed to read rows: %v", err) + } + _ = rows.Close() + } + }) +} diff --git a/licenses/github.com/go-sql-driver/mysql/collations.go b/licenses/github.com/go-sql-driver/mysql/collations.go index 295bfbe52af369ceb5cb5d00ad377534ff50b365..1cdf97b67e87beea4dac0f8f78980a4b20891908 100644 --- a/licenses/github.com/go-sql-driver/mysql/collations.go +++ b/licenses/github.com/go-sql-driver/mysql/collations.go @@ -9,7 +9,7 @@ package mysql const defaultCollation = "utf8mb4_general_ci" -const binaryCollation = "binary" +const binaryCollationID = 63 // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: diff --git a/licenses/github.com/go-sql-driver/mysql/conncheck_test.go b/licenses/github.com/go-sql-driver/mysql/conncheck_test.go index f7e025680d285a57651ff9e4582d267717ec178c..6b60cb7d6d760cc4a288117edf55656ac3753577 100644 --- a/licenses/github.com/go-sql-driver/mysql/conncheck_test.go +++ b/licenses/github.com/go-sql-driver/mysql/conncheck_test.go @@ -17,7 +17,7 @@ import ( ) func TestStaleConnectionChecks(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { dbt.mustExec("SET @@SESSION.wait_timeout = 2") if err := dbt.db.Ping(); err != nil { diff --git a/licenses/github.com/go-sql-driver/mysql/connection.go b/licenses/github.com/go-sql-driver/mysql/connection.go index 947a883e304daf6500ec525d1987fb9e8998ca33..c170114feb7c6378e2a041481e4008d0ee77ecb0 100644 --- a/licenses/github.com/go-sql-driver/mysql/connection.go +++ b/licenses/github.com/go-sql-driver/mysql/connection.go @@ -23,10 +23,10 @@ import ( type mysqlConn struct { buf buffer netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - affectedRows uint64 - insertId uint64 + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). 
cfg *Config + connector *connector maxAllowedPacket int maxWriteSize int writeTimeout time.Duration @@ -34,7 +34,6 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool - reset bool // set when the Go SQL package calls ResetSession // for context support (Go 1.8+) watching bool @@ -48,14 +47,19 @@ type mysqlConn struct { // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder + for param, val := range mc.cfg.Params { switch param { // Charset: character_set_connection, character_set_client, character_set_results case "charset": charsets := strings.Split(val, ",") - for i := range charsets { + for _, cs := range charsets { // ignore errors here - a charset may not exist - err = mc.exec("SET NAMES " + charsets[i]) + if mc.cfg.Collation != "" { + err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) + } else { + err = mc.exec("SET NAMES " + cs) + } if err == nil { break } @@ -68,7 +72,7 @@ func (mc *mysqlConn) handleParams() (err error) { default: if cmdSet.Len() == 0 { // Heuristic: 29 chars for each other key=value to reduce reallocations - cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) cmdSet.WriteString("SET ") } else { cmdSet.WriteString(", ") @@ -105,7 +109,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -147,8 +151,9 @@ func (mc *mysqlConn) cleanup() { return } if err := mc.netConn.Close(); err != nil { - errLog.Print(err) + mc.cfg.Logger.Print(err) } + mc.clearResult() } func (mc *mysqlConn) error() error { @@ -163,14 +168,14 @@ func (mc *mysqlConn) error() error { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. - errLog.Print(err) + mc.cfg.Logger.Print(err) return nil, driver.ErrBadConn } @@ -204,7 +209,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf, err := mc.buf.takeCompleteBuffer() if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -246,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf = append(buf, "'0000-00-00'"...) 
} else { buf = append(buf, '\'') - buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) if err != nil { return "", err } @@ -296,7 +301,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -310,28 +315,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } query = prepared } - mc.affectedRows = 0 - mc.insertId = 0 err := mc.exec(query) if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, err + copied := mc.result + return &copied, err } return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { + handleOk := mc.clearResult() // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { return mc.markBadConn(err) } // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } @@ -348,7 +350,7 @@ func (mc *mysqlConn) exec(query string) error { } } - return mc.discardResults() + return handleOk.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { @@ -356,8 +358,10 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + handleOk := mc.clearResult() + if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -376,7 +380,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) if err == nil { // Read Result var resLen int - resLen, err = mc.readResultSetHeaderPacket() + resLen, err = handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -404,12 +408,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command + handleOk := mc.clearResult() if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { return nil, err } // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -451,7 +456,7 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return driver.ErrBadConn } @@ -460,11 +465,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { } defer mc.finish() + handleOk := mc.clearResult() if err = mc.writeCommandPacket(comPing); err != nil { return mc.markBadConn(err) } - return mc.readResultOK() + return handleOk.readResultOK() } // BeginTx implements driver.ConnBeginTx interface @@ -639,7 +645,31 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { if mc.closed.Load() { return driver.ErrBadConn } - mc.reset = true + + // Perform a stale connection check. 
We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.cfg.CheckConnLiveness { + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) + } + if err == nil { + err = connCheck(conn) + } + if err != nil { + mc.cfg.Logger.Print("closing bad idle connection: ", err) + return driver.ErrBadConn + } + } + return nil } diff --git a/licenses/github.com/go-sql-driver/mysql/connection_test.go b/licenses/github.com/go-sql-driver/mysql/connection_test.go index b6764a2f61edb83d58cde835bdd611290d450b66..98c985ae1217e46e165bece253c8b5a73daa9cca 100644 --- a/licenses/github.com/go-sql-driver/mysql/connection_test.go +++ b/licenses/github.com/go-sql-driver/mysql/connection_test.go @@ -179,6 +179,7 @@ func TestPingErrInvalidConn(t *testing.T) { buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), + cfg: NewConfig(), } err := ms.Ping(context.Background()) diff --git a/licenses/github.com/go-sql-driver/mysql/connector.go b/licenses/github.com/go-sql-driver/mysql/connector.go index d567b4e4fc0ffe719d3e6d9e2a07178d474be04d..a0ee62839c6e24896ea34a0e1464d038db888e05 100644 --- a/licenses/github.com/go-sql-driver/mysql/connector.go +++ b/licenses/github.com/go-sql-driver/mysql/connector.go @@ -12,10 +12,53 @@ import ( "context" "database/sql/driver" "net" + "os" + "strconv" + "strings" ) type connector struct { - cfg *Config // immutable private copy. + cfg *Config // immutable private copy. + encodedAttributes string // Encoded connection attributes. +} + +func encodeConnectionAttributes(cfg *Config) string { + connAttrsBuf := make([]byte, 0) + + // default connection attributes + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + serverHost, _, _ := net.SplitHostPort(cfg.Addr) + if serverHost != "" { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost) + } + + // user-defined connection attributes + for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") { + k, v, found := strings.Cut(connAttr, ":") + if !found { + continue + } + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, k) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) + } + + return string(connAttrsBuf) +} + +func newConnector(cfg *Config) *connector { + encodedAttributes := encodeConnectionAttributes(cfg) + return &connector{ + cfg: cfg, + encodedAttributes: encodedAttributes, + } } // Connect implements driver.Connector interface. 
@@ -23,12 +66,23 @@ type connector struct { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { var err error + // Invoke beforeConnect if present, with a copy of the configuration + cfg := c.cfg + if c.cfg.beforeConnect != nil { + cfg = c.cfg.Clone() + err = c.cfg.beforeConnect(ctx, cfg) + if err != nil { + return nil, err + } + } + // New mysqlConn mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), - cfg: c.cfg, + cfg: cfg, + connector: c, } mc.parseTime = mc.cfg.ParseTime @@ -56,10 +110,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err + c.cfg.Logger.Print(err) } } @@ -92,7 +143,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { authResp, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed - errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin authResp, err = mc.auth(authData, plugin) if err != nil { diff --git a/licenses/github.com/go-sql-driver/mysql/connector_test.go b/licenses/github.com/go-sql-driver/mysql/connector_test.go index 976903c5b5ae3e9aa74acb5137b52226eb92f67b..82d8c5989dc90094469059f84f29db2eb45d5f0f 100644 --- a/licenses/github.com/go-sql-driver/mysql/connector_test.go +++ b/licenses/github.com/go-sql-driver/mysql/connector_test.go @@ -8,11 +8,11 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector := &connector{&Config{ + connector := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, - }} + }) _, err := connector.Connect(context.Background()) if err == nil { diff --git a/licenses/github.com/go-sql-driver/mysql/const.go b/licenses/github.com/go-sql-driver/mysql/const.go index 64e2bced6f8485a22ddc44d919b9f7f3186eb49f..22526e0317f76e37fe5f3917213a9095ecafae7a 100644 --- a/licenses/github.com/go-sql-driver/mysql/const.go +++ b/licenses/github.com/go-sql-driver/mysql/const.go @@ -8,12 +8,25 @@ package mysql +import "runtime" + const ( defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. 
See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" + + // Connection attributes + // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available + connAttrClientName = "_client_name" + connAttrClientNameValue = "Go-MySQL-Driver" + connAttrOS = "_os" + connAttrOSValue = runtime.GOOS + connAttrPlatform = "_platform" + connAttrPlatformValue = runtime.GOARCH + connAttrPid = "_pid" + connAttrServerHost = "_server_host" ) // MySQL constants documentation: diff --git a/licenses/github.com/go-sql-driver/mysql/driver.go b/licenses/github.com/go-sql-driver/mysql/driver.go index ad7aec215c6c3f860f70d72695af4decea52624d..105316b8164963854cc4f99328578259e80c1d6b 100644 --- a/licenses/github.com/go-sql-driver/mysql/driver.go +++ b/licenses/github.com/go-sql-driver/mysql/driver.go @@ -55,6 +55,15 @@ func RegisterDialContext(net string, dial DialContextFunc) { dials[net] = dial } +// DeregisterDialContext removes the custom dial function registered with the given net. +func DeregisterDialContext(net string) { + dialsLock.Lock() + defer dialsLock.Unlock() + if dials != nil { + delete(dials, net) + } +} + // RegisterDial registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. // addr is passed as a parameter to the dial function. @@ -74,14 +83,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c := &connector{ - cfg: cfg, - } + c := newConnector(cfg) return c.Connect(context.Background()) } +// This variable can be replaced with -ldflags like below: +// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom" +var driverName = "mysql" + func init() { - sql.Register("mysql", &MySQLDriver{}) + if driverName != "" { + sql.Register(driverName, &MySQLDriver{}) + } } // NewConnector returns new driver.Connector. @@ -92,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return &connector{cfg: cfg}, nil + return newConnector(cfg), nil } // OpenConnector implements driver.DriverContext. 
@@ -101,7 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return &connector{ - cfg: cfg, - }, nil + return newConnector(cfg), nil } diff --git a/licenses/github.com/go-sql-driver/mysql/driver_test.go b/licenses/github.com/go-sql-driver/mysql/driver_test.go index a1c77672869e824b3ba6f5680b4e05cdbf1901aa..001957244077487b2cf8454c5b37e6f3d2a051cf 100644 --- a/licenses/github.com/go-sql-driver/mysql/driver_test.go +++ b/licenses/github.com/go-sql-driver/mysql/driver_test.go @@ -11,13 +11,13 @@ package mysql import ( "bytes" "context" + "crypto/rand" "crypto/tls" "database/sql" "database/sql/driver" "encoding/json" "fmt" "io" - "io/ioutil" "log" "math" "net" @@ -25,6 +25,7 @@ import ( "os" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -32,6 +33,16 @@ import ( "time" ) +// This variable can be replaced with -ldflags like below: +// go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom" +var driverNameTest string + +func init() { + if driverNameTest == "" { + driverNameTest = driverName + } +} + // Ensure that all the driver interfaces are implemented var ( _ driver.Rows = &binaryRows{} @@ -83,7 +94,7 @@ func init() { } type DBTest struct { - *testing.T + testing.TB db *sql.DB } @@ -112,12 +123,14 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT dsn += "&multiStatements=true" var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) + db, err = sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } defer db.Close() } + // Previous test may be skipped without dropping the test table + db.Exec("DROP TABLE IF EXISTS test") dbt := &DBTest{t, db} for _, test := range tests { @@ -131,52 +144,103 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { t.Skipf("MySQL server not running on %s", netAddr) } - db, err := sql.Open("mysql", dsn) + db, err := sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } defer db.Close() - db.Exec("DROP TABLE IF EXISTS test") + cleanup := func() { + db.Exec("DROP TABLE IF EXISTS test") + } dsn2 := dsn + "&interpolateParams=true" var db2 *sql.DB if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { - db2, err = sql.Open("mysql", dsn2) + db2, err = sql.Open(driverNameTest, dsn2) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } defer db2.Close() } - dsn3 := dsn + "&multiStatements=true" - var db3 *sql.DB - if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { - db3, err = sql.Open("mysql", dsn3) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + for _, test := range tests { + test := test + t.Run("default", func(t *testing.T) { + dbt := &DBTest{t, db} + t.Cleanup(cleanup) + test(dbt) + }) + if db2 != nil { + t.Run("interpolateParams", func(t *testing.T) { + dbt2 := &DBTest{t, db2} + t.Cleanup(cleanup) + test(dbt2) + }) } - defer db3.Close() } +} - dbt := &DBTest{t, db} - dbt2 := &DBTest{t, db2} - dbt3 := &DBTest{t, db3} - for _, test := range tests { - test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") - if db2 != nil { - test(dbt2) - dbt2.db.Exec("DROP TABLE IF EXISTS test") +// runTestsParallel runs the tests in parallel with a separate database connection for each test. 
+func runTestsParallel(t *testing.T, dsn string, tests ...func(dbt *DBTest, tableName string)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + newTableName := func(t *testing.T) string { + t.Helper() + var buf [8]byte + if _, err := rand.Read(buf[:]); err != nil { + t.Fatal(err) } - if db3 != nil { - test(dbt3) - dbt3.db.Exec("DROP TABLE IF EXISTS test") + return fmt.Sprintf("test_%x", buf[:]) + } + + t.Parallel() + for _, test := range tests { + test := test + + t.Run("default", func(t *testing.T) { + t.Parallel() + + tableName := newTableName(t) + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + t.Cleanup(func() { + db.Exec("DROP TABLE IF EXISTS " + tableName) + db.Close() + }) + + dbt := &DBTest{t, db} + test(dbt, tableName) + }) + + dsn2 := dsn + "&interpolateParams=true" + if _, err := ParseDSN(dsn2); err == errInvalidDSNUnsafeCollation { + t.Run("interpolateParams", func(t *testing.T) { + t.Parallel() + + tableName := newTableName(t) + db, err := sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + t.Cleanup(func() { + db.Exec("DROP TABLE IF EXISTS " + tableName) + db.Close() + }) + + dbt := &DBTest{t, db} + test(dbt, tableName) + }) } } } func (dbt *DBTest) fail(method, query string, err error) { + dbt.Helper() if len(query) > 300 { query = "[query too large to print]" } @@ -184,6 +248,7 @@ func (dbt *DBTest) fail(method, query string, err error) { } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + dbt.Helper() res, err := dbt.db.Exec(query, args...) if err != nil { dbt.fail("exec", query, err) @@ -192,6 +257,7 @@ func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) } func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + dbt.Helper() rows, err := dbt.db.Query(query, args...) if err != nil { dbt.fail("query", query, err) @@ -211,7 +277,7 @@ func maybeSkip(t *testing.T, err error, skipErrno uint16) { } func TestEmptyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { // just a comment, no query rows := dbt.mustQuery("--") defer rows.Close() @@ -223,20 +289,20 @@ func TestEmptyQuery(t *testing.T) { } func TestCRUD(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { // Create Table - dbt.mustExec("CREATE TABLE test (value BOOL)") + dbt.mustExec("CREATE TABLE " + tbl + " (value BOOL)") // Test for unexpected data var out bool - rows := dbt.mustQuery("SELECT * FROM test") + rows := dbt.mustQuery("SELECT * FROM " + tbl) if rows.Next() { dbt.Error("unexpected data in empty table") } rows.Close() // Create Data - res := dbt.mustExec("INSERT INTO test VALUES (1)") + res := dbt.mustExec("INSERT INTO " + tbl + " VALUES (1)") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -254,7 +320,7 @@ func TestCRUD(t *testing.T) { } // Read - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if true != out { @@ -270,7 +336,7 @@ func TestCRUD(t *testing.T) { rows.Close() // Update - res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + res = dbt.mustExec("UPDATE "+tbl+" SET value = ? 
WHERE value = ?", false, true) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -280,7 +346,7 @@ func TestCRUD(t *testing.T) { } // Check Update - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if false != out { @@ -296,7 +362,7 @@ func TestCRUD(t *testing.T) { rows.Close() // Delete - res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + res = dbt.mustExec("DELETE FROM "+tbl+" WHERE value = ?", false) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -306,7 +372,7 @@ func TestCRUD(t *testing.T) { } // Check for unexpected rows - res = dbt.mustExec("DELETE FROM test") + res = dbt.mustExec("DELETE FROM " + tbl) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -317,6 +383,51 @@ func TestCRUD(t *testing.T) { }) } +// TestNumbers test that selecting numeric columns. +// Both of textRows and binaryRows should return same type and value. +func TestNumbersToAny(t *testing.T) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " + + "i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, iu32 INT UNSIGNED)") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5, 4294967295)") + + // Use binaryRows for interpolateParams=false and textRows for interpolateParams=true. + rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64, iu32 FROM "+tbl+" WHERE id=?", 1) + if !rows.Next() { + dbt.Fatal("no data") + } + var b, i8, i16, i32, i64, f32, f64, iu32 any + err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64, &iu32) + if err != nil { + dbt.Fatal(err) + } + if b.(int64) != 1 { + dbt.Errorf("b != 1") + } + if i8.(int64) != 127 { + dbt.Errorf("i8 != 127") + } + if i16.(int64) != 32767 { + dbt.Errorf("i16 != 32767") + } + if i32.(int64) != 2147483647 { + dbt.Errorf("i32 != 2147483647") + } + if i64.(int64) != 9223372036854775807 { + dbt.Errorf("i64 != 9223372036854775807") + } + if f32.(float32) != 1.25 { + dbt.Errorf("f32 != 1.25") + } + if f64.(float64) != 2.5 { + dbt.Errorf("f64 != 2.5") + } + if iu32.(int64) != 4294967295 { + dbt.Errorf("iu32 != 4294967295") + } + }) +} + func TestMultiQuery(t *testing.T) { runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { // Create Table @@ -347,8 +458,8 @@ func TestMultiQuery(t *testing.T) { rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") if rows.Next() { rows.Scan(&out) - if 5 != out { - dbt.Errorf("5 != %d", out) + if out != 5 { + dbt.Errorf("expected 5, got %d", out) } if rows.Next() { @@ -363,7 +474,7 @@ func TestMultiQuery(t *testing.T) { } func TestInt(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} in := int64(42) var out int64 @@ -371,11 +482,11 @@ func TestInt(t *testing.T) { // SIGNED for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in 
!= out { @@ -386,16 +497,16 @@ func TestInt(t *testing.T) { } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } // UNSIGNED ZEROFILL for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + " ZEROFILL)") - dbt.mustExec("INSERT INTO test VALUES (?)", in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -406,21 +517,21 @@ func TestInt(t *testing.T) { } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestFloat32(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")") + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -430,21 +541,21 @@ func TestFloat32(t *testing.T) { dbt.Errorf("%s: no data", v) } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestFloat64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [2]string{"FLOAT", "DOUBLE"} var expected float64 = 42.23 var out float64 var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (42.23)") - rows = dbt.mustQuery("SELECT value FROM test") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if expected != out { @@ -454,21 +565,21 @@ func TestFloat64(t *testing.T) { dbt.Errorf("%s: no data", v) } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestFloat64Placeholder(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [2]string{"FLOAT", "DOUBLE"} var expected float64 = 42.23 var out float64 var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") - rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + dbt.mustExec("CREATE TABLE " + tbl + " (id int, value " + v + ")") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM "+tbl+" WHERE id = ?", 1) if rows.Next() { rows.Scan(&out) if expected != out { @@ -478,24 +589,24 @@ func TestFloat64Placeholder(t *testing.T) { dbt.Errorf("%s: no data", v) } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestString(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} in := "κόσμε 
üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" var out string var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ") CHARACTER SET utf8") - dbt.mustExec("INSERT INTO test VALUES (?)", in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -506,11 +617,11 @@ func TestString(t *testing.T) { } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } // BLOB - dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + dbt.mustExec("CREATE TABLE " + tbl + " (id int, value BLOB) CHARACTER SET utf8") id := 2 in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + @@ -521,9 +632,9 @@ func TestString(t *testing.T) { "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." - dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?, ?)", id, in) - err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + err := dbt.db.QueryRow("SELECT value FROM "+tbl+" WHERE id = ?", id).Scan(&out) if err != nil { dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { @@ -533,7 +644,7 @@ func TestString(t *testing.T) { } func TestRawBytes(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { v1 := []byte("aaa") v2 := []byte("bbb") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) @@ -562,7 +673,7 @@ func TestRawBytes(t *testing.T) { } func TestRawMessage(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { v1 := json.RawMessage("{}") v2 := json.RawMessage("[]") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) @@ -593,14 +704,14 @@ func (tv testValuer) Value() (driver.Value, error) { } func TestValuer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { in := testValuer{"a_value"} var out string var rows *sql.Rows - dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + dbt.mustExec("CREATE TABLE " + tbl + " (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in.value != out { @@ -610,8 +721,6 @@ func TestValuer(t *testing.T) { dbt.Errorf("Valuer: no data") } rows.Close() - - dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -628,15 +737,15 @@ func (tv testValuerWithValidation) Value() (driver.Value, error) { } func TestValuerWithValidation(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { in := testValuerWithValidation{"a_value"} var out string var rows *sql.Rows - dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") - dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + dbt.mustExec("CREATE TABLE " + tbl + " (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT 
INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM testValuer") + rows = dbt.mustQuery("SELECT value FROM " + tbl) defer rows.Close() if rows.Next() { @@ -648,19 +757,17 @@ func TestValuerWithValidation(t *testing.T) { dbt.Errorf("Valuer: no data") } - if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + if _, err := dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", testValuerWithValidation{""}); err == nil { dbt.Errorf("Failed to check valuer error") } - if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + if _, err := dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", nil); err != nil { dbt.Errorf("Failed to check nil") } - if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + if _, err := dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", map[string]bool{}); err == nil { dbt.Errorf("Failed to check not valuer") } - - dbt.mustExec("DROP TABLE IF EXISTS testValuer") }) } @@ -894,7 +1001,7 @@ func TestTimestampMicros(t *testing.T) { f0 := format[:19] f1 := format[:21] f6 := format[:26] - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { // check if microseconds are supported. // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width // and not precision. @@ -909,7 +1016,7 @@ func TestTimestampMicros(t *testing.T) { return } _, err := dbt.db.Exec(` - CREATE TABLE test ( + CREATE TABLE ` + tbl + ` ( value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' @@ -918,10 +1025,10 @@ func TestTimestampMicros(t *testing.T) { if err != nil { dbt.Error(err) } - defer dbt.mustExec("DROP TABLE IF EXISTS test") - dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + defer dbt.mustExec("DROP TABLE IF EXISTS " + tbl) + dbt.mustExec("INSERT INTO "+tbl+" SET value0=?, value1=?, value6=?", f0, f1, f6) var res0, res1, res6 string - rows := dbt.mustQuery("SELECT * FROM test") + rows := dbt.mustQuery("SELECT * FROM " + tbl) defer rows.Close() if !rows.Next() { dbt.Errorf("test contained no selectable values") @@ -943,7 +1050,7 @@ func TestTimestampMicros(t *testing.T) { } func TestNULL(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { nullStmt, err := dbt.db.Prepare("SELECT NULL") if err != nil { dbt.Fatal(err) @@ -1075,12 +1182,12 @@ func TestNULL(t *testing.T) { } // Insert NULL - dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + dbt.mustExec("CREATE TABLE " + tbl + " (dummmy1 int, value int, dummy2 int)") - dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?, ?, ?)", 1, nil, 2) var out interface{} - rows := dbt.mustQuery("SELECT * FROM test") + rows := dbt.mustQuery("SELECT * FROM " + tbl) defer rows.Close() if rows.Next() { rows.Scan(&out) @@ -1104,7 +1211,7 @@ func TestUint64(t *testing.T) { shigh = int64(uhigh) stop = ^shigh ) - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? 
,?, ?, ?, ?, ?`) if err != nil { dbt.Fatal(err) @@ -1168,7 +1275,7 @@ func TestLongData(t *testing.T) { dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) } if rows.Next() { - dbt.Error("LONGBLOB: unexpexted row") + dbt.Error("LONGBLOB: unexpected row") } } else { dbt.Fatalf("LONGBLOB: no data") @@ -1187,7 +1294,7 @@ func TestLongData(t *testing.T) { dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) } if rows.Next() { - dbt.Error("LONGBLOB: unexpexted row") + dbt.Error("LONGBLOB: unexpected row") } } else { if err = rows.Err(); err != nil { @@ -1245,7 +1352,7 @@ func TestLoadData(t *testing.T) { dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") // Local File - file, err := ioutil.TempFile("", "gotest") + file, err := os.CreateTemp("", "gotest") defer os.Remove(file.Name()) if err != nil { dbt.Fatal(err) @@ -1263,7 +1370,7 @@ func TestLoadData(t *testing.T) { dbt.Fatalf("unexpected row count: got %d, want 0", count) } - // Then fille File with data and try to load it + // Then fill File with data and try to load it file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") file.Close() dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) @@ -1294,18 +1401,18 @@ func TestLoadData(t *testing.T) { _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") if err == nil { dbt.Fatal("load non-existent Reader didn't fail") - } else if err.Error() != "Reader 'doesnotexist' is not registered" { + } else if err.Error() != "reader 'doesnotexist' is not registered" { dbt.Fatal(err.Error()) } }) } -func TestFoundRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") +func TestFoundRows1(t *testing.T) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO " + tbl + " (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + res := dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 0") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1313,7 +1420,7 @@ func TestFoundRows(t *testing.T) { if count != 2 { dbt.Fatalf("Expected 2 affected rows, got %d", count) } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + res = dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 1") count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1322,11 +1429,14 @@ func TestFoundRows(t *testing.T) { dbt.Fatalf("Expected 2 affected rows, got %d", count) } }) - runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") +} + +func TestFoundRows2(t *testing.T) { + runTestsParallel(t, dsn+"&clientFoundRows=true", func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO " + tbl + " (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - res := dbt.mustExec("UPDATE test SET data = 
1 WHERE id = 0") + res := dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 0") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1334,7 +1444,7 @@ func TestFoundRows(t *testing.T) { if count != 2 { dbt.Fatalf("Expected 2 matched rows, got %d", count) } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + res = dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 1") count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1402,6 +1512,7 @@ func TestReuseClosedConnection(t *testing.T) { if err != nil { t.Fatalf("error preparing statement: %s", err.Error()) } + //lint:ignore SA1019 this is a test _, err = stmt.Exec(nil) if err != nil { t.Fatalf("error executing statement: %s", err.Error()) @@ -1416,6 +1527,7 @@ func TestReuseClosedConnection(t *testing.T) { t.Errorf("panic after reusing a closed connection: %v", err) } }() + //lint:ignore SA1019 this is a test _, err = stmt.Exec(nil) if err != nil && err != driver.ErrBadConn { t.Errorf("unexpected error '%s', expected '%s'", @@ -1458,7 +1570,7 @@ func TestCharset(t *testing.T) { } func TestFailingCharset(t *testing.T) { - runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + runTestsParallel(t, dsn+"&charset=none", func(dbt *DBTest, _ string) { // run query to really establish connection... _, err := dbt.db.Exec("SELECT 1") if err == nil { @@ -1507,7 +1619,7 @@ func TestCollation(t *testing.T) { } func TestColumnsWithAlias(t *testing.T) { - runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + runTestsParallel(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest, _ string) { rows := dbt.mustQuery("SELECT 1 AS A") defer rows.Close() cols, _ := rows.Columns() @@ -1531,7 +1643,7 @@ func TestColumnsWithAlias(t *testing.T) { } func TestRawBytesResultExceedsBuffer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { // defaultBufSize from buffer.go expected := strings.Repeat("abc", defaultBufSize) @@ -1590,7 +1702,7 @@ func TestTimezoneConversion(t *testing.T) { // Special cases func TestRowsClose(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { rows, err := dbt.db.Query("SELECT 1") if err != nil { dbt.Fatal(err) @@ -1615,7 +1727,7 @@ func TestRowsClose(t *testing.T) { // dangling statements // http://code.google.com/p/go/issues/detail?id=3865 func TestCloseStmtBeforeRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { stmt, err := dbt.db.Prepare("SELECT 1") if err != nil { dbt.Fatal(err) @@ -1656,7 +1768,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { // It is valid to have multiple Rows for the same Stmt // http://code.google.com/p/go/issues/detail?id=3734 func TestStmtMultiRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") if err != nil { dbt.Fatal(err) @@ -1807,13 +1919,13 @@ func TestConcurrent(t *testing.T) { } runTests(t, dsn, func(dbt *DBTest) { - var version string - if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { - dbt.Fatalf("%s", err.Error()) - } - if strings.Contains(strings.ToLower(version), "mariadb") { - t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" 
on MariaDB`) - } + // var version string + // if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { + // dbt.Fatal(err) + // } + // if strings.Contains(strings.ToLower(version), "mariadb") { + // t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`) + // } var max int err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) @@ -1840,7 +1952,6 @@ func TestConcurrent(t *testing.T) { defer wg.Done() tx, err := dbt.db.Begin() - atomic.AddInt32(&remaining, -1) if err != nil { if err.Error() != "Error 1040: Too many connections" { @@ -1850,7 +1961,7 @@ func TestConcurrent(t *testing.T) { } // keep the connection busy until all connections are open - for remaining > 0 { + for atomic.AddInt32(&remaining, -1) > 0 { if _, err = tx.Exec("DO 1"); err != nil { fatalf("error on conn %d: %s", id, err.Error()) return @@ -1867,7 +1978,7 @@ func TestConcurrent(t *testing.T) { }(i) } - // wait until all conections are open + // wait until all connections are open wg.Wait() if fatalError != "" { @@ -1883,7 +1994,7 @@ func testDialError(t *testing.T, dialErr error, expectErr error) { return nil, dialErr }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1916,13 +2027,13 @@ func TestCustomDial(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - // our custom dial function which justs wraps net.Dial here + // our custom dial function which just wraps net.Dial here RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { var d net.Dialer return d.DialContext(ctx, prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1933,6 +2044,40 @@ func TestCustomDial(t *testing.T) { } } +func TestBeforeConnect(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // dbname is set in the BeforeConnect handle + cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_")) + if err != nil { + t.Fatalf("error parsing DSN: %v", err) + } + + cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error { + c.DBName = dbname + return nil + })) + + connector, err := NewConnector(cfg) + if err != nil { + t.Fatalf("error creating connector: %v", err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var connectedDb string + err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb) + if err != nil { + t.Fatalf("error executing query: %v", err) + } + if connectedDb != dbname { + t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb) + } +} + func TestSQLInjection(t *testing.T) { createTest := func(arg string) func(dbt *DBTest) { return func(dbt *DBTest) { @@ -1995,7 +2140,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) { func TestUnixSocketAuthFail(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // Save the current logger so we can restore it. - oldLogger := errLog + oldLogger := defaultLogger // Set a new logger so we can capture its output. 
buffer := bytes.NewBuffer(make([]byte, 0, 64)) @@ -2020,7 +2165,7 @@ func TestUnixSocketAuthFail(t *testing.T) { } t.Logf("socket: %s", socket) badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) - db, err := sql.Open("mysql", badDSN) + db, err := sql.Open(driverNameTest, badDSN) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -2155,11 +2300,51 @@ func TestRejectReadOnly(t *testing.T) { } func TestPing(t *testing.T) { + ctx := context.Background() runTests(t, dsn, func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { dbt.fail("Ping", "Ping", err) } }) + + runTests(t, dsn, func(dbt *DBTest) { + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.fail("db", "Conn", err) + } + + // Check that affectedRows and insertIds are cleared after each call. + conn.Raw(func(conn interface{}) error { + c := conn.(*mysqlConn) + + // Issue a query that sets affectedRows and insertIds. + q, err := c.Query(`SELECT 1`, nil) + if err != nil { + dbt.fail("Conn", "Query", err) + } + if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) { + dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want) + } + if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) { + dbt.Fatalf("bad insertIds: got %v, want=%v", got, want) + } + q.Close() + + // Verify that Ping() clears both fields. + for i := 0; i < 2; i++ { + if err := c.Ping(ctx); err != nil { + dbt.fail("Pinger", "Ping", err) + } + if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + } + return nil + }) + }) } // See Issue #799 @@ -2169,7 +2354,7 @@ func TestEmptyPassword(t *testing.T) { } dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) - db, err := sql.Open("mysql", dsn) + db, err := sql.Open(driverNameTest, dsn) if err == nil { defer db.Close() err = db.Ping() @@ -2379,10 +2564,47 @@ func TestMultiResultSetNoSelect(t *testing.T) { }) } +func TestExecMultipleResults(t *testing.T) { + ctx := context.Background() + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + CREATE TABLE test ( + id INT NOT NULL AUTO_INCREMENT, + value VARCHAR(255), + PRIMARY KEY (id) + )`) + conn, err := dbt.db.Conn(ctx) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + conn.Raw(func(conn interface{}) error { + //lint:ignore SA1019 this is a test + ex := conn.(driver.Execer) + res, err := ex.Exec(` + INSERT INTO test (value) VALUES ('a'), ('b'); + INSERT INTO test (value) VALUES ('c'), ('d'), ('e'); + `, nil) + if err != nil { + t.Fatalf("insert statements failed: %v", err) + } + mres := res.(Result) + if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) { + t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want) + } + // For INSERTs containing multiple rows, LAST_INSERT_ID() returns the + // first inserted ID, not the last. + if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) { + t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want) + } + return nil + }) + }) +} + // tests if rows are set in a proper state if some results were ignored before // calling rows.NextResultSet. 
func TestSkipResults(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { rows := dbt.mustQuery("SELECT 1, 2") defer rows.Close() @@ -2400,8 +2622,44 @@ func TestSkipResults(t *testing.T) { }) } +func TestQueryMultipleResults(t *testing.T) { + ctx := context.Background() + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + CREATE TABLE test ( + id INT NOT NULL AUTO_INCREMENT, + value VARCHAR(255), + PRIMARY KEY (id) + )`) + conn, err := dbt.db.Conn(ctx) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + conn.Raw(func(conn interface{}) error { + //lint:ignore SA1019 this is a test + qr := conn.(driver.Queryer) + c := conn.(*mysqlConn) + + // Demonstrate that repeated queries reset the affectedRows + for i := 0; i < 2; i++ { + _, err := qr.Query(` + INSERT INTO test (value) VALUES ('a'), ('b'); + INSERT INTO test (value) VALUES ('c'), ('d'), ('e'); + `, nil) + if err != nil { + t.Fatalf("insert statements failed: %v", err) + } + if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + } + return nil + }) + }) +} + func TestPingContext(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { ctx, cancel := context.WithCancel(context.Background()) cancel() if err := dbt.db.PingContext(ctx); err != context.Canceled { @@ -2411,8 +2669,8 @@ func TestPingContext(t *testing.T) { } func TestContextCancelExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) // Delay execution for just a bit until db.ExecContext has begun. @@ -2420,7 +2678,7 @@ func TestContextCancelExec(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2432,7 +2690,7 @@ func TestContextCancelExec(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2440,14 +2698,14 @@ func TestContextCancelExec(t *testing.T) { } // Context is already canceled, so error should come before execution. - if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (1)"); err == nil { dbt.Error("expected error") } else if err.Error() != "context canceled" { dbt.Fatalf("unexpected error: %s", err) } // The second insert query will fail, so the table has no changes. 
- if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { @@ -2457,8 +2715,8 @@ func TestContextCancelExec(t *testing.T) { } func TestContextCancelQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) // Delay execution for just a bit until db.ExecContext has begun. @@ -2466,7 +2724,7 @@ func TestContextCancelQuery(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2478,7 +2736,7 @@ func TestContextCancelQuery(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2486,12 +2744,12 @@ func TestContextCancelQuery(t *testing.T) { } // Context is already canceled, so error should come before execution. - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO "+tbl+" VALUES (1)"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } // The second insert query will fail, so the table has no changes. 
- if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { @@ -2501,12 +2759,12 @@ func TestContextCancelQuery(t *testing.T) { } func TestContextCancelQueryRow(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") - dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1), (2), (3)") ctx, cancel := context.WithCancel(context.Background()) - rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM "+tbl) if err != nil { dbt.Fatalf("%s", err.Error()) } @@ -2534,7 +2792,7 @@ func TestContextCancelQueryRow(t *testing.T) { } func TestContextCancelPrepare(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { ctx, cancel := context.WithCancel(context.Background()) cancel() if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { @@ -2544,10 +2802,10 @@ func TestContextCancelPrepare(t *testing.T) { } func TestContextCancelStmtExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))") if err != nil { dbt.Fatalf("unexpected error: %v", err) } @@ -2569,7 +2827,7 @@ func TestContextCancelStmtExec(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2579,10 +2837,10 @@ func TestContextCancelStmtExec(t *testing.T) { } func TestContextCancelStmtQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))") if err != nil { dbt.Fatalf("unexpected error: %v", err) } @@ -2604,7 +2862,7 @@ func TestContextCancelStmtQuery(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. 
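The conversions above and below switch the TestContextCancel* cases from runTests to runTestsParallel, which hands each case its own table name (tbl) instead of the shared `test` table so the cases can run concurrently. The helper itself is defined in a part of driver_test.go not shown in these hunks; the following is only a minimal sketch of what such a helper could look like, assuming it sits in the same test package next to DBTest, available, netAddr and driverNameTest, and the real implementation may well differ (for example in how it names or cleans up the per-test table):

// Sketch (assumption): a runTests variant that marks each case parallel and
// hands it a private table name derived from the subtest name.
func runTestsParallel(t *testing.T, dsn string, tests ...func(dbt *DBTest, tableName string)) {
	if !available {
		t.Skipf("MySQL server not running on %s", netAddr)
	}
	for _, test := range tests {
		test := test
		t.Run("default", func(t *testing.T) {
			t.Parallel()

			db, err := sql.Open(driverNameTest, dsn)
			if err != nil {
				t.Fatalf("error connecting: %s", err.Error())
			}
			defer db.Close()

			// One table per subtest, so parallel cases never interfere.
			tbl := "test_" + strings.ReplaceAll(t.Name(), "/", "_")
			dbt := &DBTest{t, db}
			defer db.Exec("DROP TABLE IF EXISTS " + tbl)
			test(dbt, tbl)
		})
	}
}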
@@ -2618,8 +2876,8 @@ func TestContextCancelBegin(t *testing.T) { t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`) } - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) conn, err := dbt.db.Conn(ctx) if err != nil { @@ -2636,7 +2894,7 @@ func TestContextCancelBegin(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err := tx.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2674,8 +2932,8 @@ func TestContextCancelBegin(t *testing.T) { } func TestContextBeginIsolationLevel(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2693,13 +2951,13 @@ func TestContextBeginIsolationLevel(t *testing.T) { dbt.Fatal(err) } - _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)") + _, err = tx1.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (1)") if err != nil { dbt.Fatal(err) } var v int - row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tbl) if err := row.Scan(&v); err != nil { dbt.Fatal(err) } @@ -2713,7 +2971,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { dbt.Fatal(err) } - row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tbl) if err := row.Scan(&v); err != nil { dbt.Fatal(err) } @@ -2726,8 +2984,8 @@ func TestContextBeginIsolationLevel(t *testing.T) { } func TestContextBeginReadOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2742,14 +3000,14 @@ func TestContextBeginReadOnly(t *testing.T) { } // INSERT queries fail in a READ ONLY transaction. - _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + _, err = tx.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (1)") if _, ok := err.(*MySQLError); !ok { dbt.Errorf("expected MySQLError, got %v", err) } // SELECT queries can be executed. 
var v int - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tbl) if err := row.Scan(&v); err != nil { dbt.Fatal(err) } @@ -2778,13 +3036,18 @@ func TestRowsColumnTypes(t *testing.T) { nd1 := sql.NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} nd2 := sql.NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} ndNULL := sql.NullTime{Time: time.Time{}, Valid: false} - rbNULL := sql.RawBytes(nil) - rb0 := sql.RawBytes("0") - rb42 := sql.RawBytes("42") - rbTest := sql.RawBytes("Test") - rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 - rbx0 := sql.RawBytes("\x00") - rbx42 := sql.RawBytes("\x42") + bNULL := []byte(nil) + nsNULL := sql.NullString{String: "", Valid: false} + // Helper function to build NullString from string literal. + ns := func(s string) sql.NullString { return sql.NullString{String: s, Valid: true} } + ns0 := ns("0") + b0 := []byte("0") + b42 := []byte("42") + nsTest := ns("Test") + bTest := []byte("Test") + b0pad4 := []byte("0\x00\x00\x00") // BINARY right-pads values with 0x00 + bx0 := []byte("\x00") + bx42 := []byte("\x42") var columns = []struct { name string @@ -2797,7 +3060,7 @@ func TestRowsColumnTypes(t *testing.T) { valuesIn [3]string valuesOut [3]interface{} }{ - {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, + {"bit8null", "BIT(8)", "BIT", scanTypeBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{bx0, bNULL, bx42}}, {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, @@ -2811,35 +3074,38 @@ func TestRowsColumnTypes(t *testing.T) { {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"mediumuint", "MEDIUMINT UNSIGNED NOT NULL", "UNSIGNED MEDIUMINT", scanTypeUint32, false, 0, 0, [3]string{"0", "16777215", "42"}, [3]interface{}{uint32(0), uint32(16777215), uint32(42)}}, {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", 
"13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, - {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, - {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, - {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, - {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, - {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, - {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, - {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, - {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeString, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.000000", "13.370000", "1234.123456"}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeNullString, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.000000"), nsNULL, ns("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", 
scanTypeString, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.0000", "13.3700", "1234.1235"}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeNullString, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.0000"), nsNULL, ns("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeString, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{"0", "13", "-12345"}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeNullString, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{ns0, nsNULL, ns("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0pad4, bNULL, bTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}}, + {"textnull", "TEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, + {"enum", "ENUM('', 'v1', 'v2')", "ENUM", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v2'"}, [3]interface{}{ns(""), ns("v1"), ns("v2")}}, + {"set", "set('', 'v1', 'v2')", "SET", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v1,v2'"}, [3]interface{}{ns(""), ns("v1"), ns("v1,v2")}}, } schema := "" @@ -2945,7 +3211,10 @@ func 
TestRowsColumnTypes(t *testing.T) { continue } } - + // Avoid panic caused by nil scantype. + if t.Failed() { + return + } values := make([]interface{}, len(tt)) for i := range values { values[i] = reflect.New(types[i]).Interface() @@ -2956,14 +3225,10 @@ func TestRowsColumnTypes(t *testing.T) { if err != nil { t.Fatalf("failed to scan values in %v", err) } - for j := range values { - value := reflect.ValueOf(values[j]).Elem().Interface() + for j, value := range values { + value := reflect.ValueOf(value).Elem().Interface() if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { - if columns[j].scanType == scanTypeRawBytes { - t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) - } else { - t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) - } + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) } } i++ @@ -2979,9 +3244,9 @@ func TestRowsColumnTypes(t *testing.T) { } func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") - dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", (*testValuer)(nil)) // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() }) } @@ -3015,27 +3280,28 @@ func TestRawBytesAreNotModified(t *testing.T) { rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) if err != nil { - t.Fatal(err) + dbt.Fatal(err) } + defer rows.Close() var b int var raw sql.RawBytes - for rows.Next() { - if err := rows.Scan(&b, &raw); err != nil { - t.Fatal(err) - } + if !rows.Next() { + dbt.Fatal("expected at least one row") + } + if err := rows.Scan(&b, &raw); err != nil { + dbt.Fatal(err) + } - before := string(raw) - // Ensure cancelling the query does not corrupt the contents of `raw` - cancel() - time.Sleep(time.Microsecond * 100) - after := string(raw) + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) - if before != after { - t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) - } + if before != after { + dbt.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } - rows.Close() }() } }) @@ -3058,7 +3324,7 @@ func TestConnectorObeysDialTimeouts(t *testing.T) { return d.DialContext(ctx, prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -3209,3 +3475,105 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) { t.Errorf("connection not closed") } } + +func TestConnectionAttributes(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + defaultAttrs := []string{ + connAttrClientName, + connAttrOS, + connAttrPlatform, + connAttrPid, + connAttrServerHost, + } + host, _, _ := net.SplitHostPort(addr) + defaultAttrValues := []string{ + connAttrClientNameValue, + connAttrOSValue, + connAttrPlatformValue, + strconv.Itoa(os.Getpid()), + host, + } + + customAttrs := 
[]string{"attr1", "fo/o"} + customAttrValues := []string{"value1", "bo/o"} + + customAttrStrs := make([]string, len(customAttrs)) + for i := range customAttrs { + customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i]) + } + dsn += "&connectionAttributes=" + url.QueryEscape(strings.Join(customAttrStrs, ",")) + + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open(driverNameTest, dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + + queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" + rows := dbt.mustQuery(queryString) + defer rows.Close() + + rowsMap := make(map[string]string) + for rows.Next() { + var attrName, attrValue string + rows.Scan(&attrName, &attrValue) + rowsMap[attrName] = attrValue + } + + connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...) + expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...) + for i := range connAttrs { + if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] { + dbt.Errorf("expected %q, got %q", expectedAttrValues[i], gotValue) + } + } +} + +func TestErrorInMultiResult(t *testing.T) { + // https://github.com/go-sql-driver/mysql/issues/1361 + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + query := ` +CREATE PROCEDURE test_proc1() +BEGIN + SELECT 1,2; + SELECT 3,4; + SIGNAL SQLSTATE '10000' SET MESSAGE_TEXT = "some error", MYSQL_ERRNO = 10000; +END; +` + runCallCommand(dbt, query, "test_proc1") +} + +func runCallCommand(dbt *DBTest, query, name string) { + dbt.mustExec(fmt.Sprintf("DROP PROCEDURE IF EXISTS %s", name)) + dbt.mustExec(query) + defer dbt.mustExec("DROP PROCEDURE " + name) + rows, err := dbt.db.Query(fmt.Sprintf("CALL %s", name)) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + } + for rows.NextResultSet() { + for rows.Next() { + } + } +} diff --git a/licenses/github.com/go-sql-driver/mysql/dsn.go b/licenses/github.com/go-sql-driver/mysql/dsn.go index 4b71aaab0bf271bea2fbe0a79633bf42752814a4..65f5a0242fb5fbf15369b63c76bc007d9e0a29f5 100644 --- a/licenses/github.com/go-sql-driver/mysql/dsn.go +++ b/licenses/github.com/go-sql-driver/mysql/dsn.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "context" "crypto/rsa" "crypto/tls" "errors" @@ -34,22 +35,27 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. 
type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + // non boolean fields + + User string // Username + Passwd string // Password (requires User) + Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") + Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger + + // boolean fields AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -63,17 +69,57 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + + // unexported fields. new options should be come here + + beforeConnect func(context.Context, *Config) error // Invoked before a connection is established + pubKey *rsa.PublicKey // Server public key + timeTruncate time.Duration // Truncate time.Time values to the specified duration } +// Functional Options Pattern +// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis +type Option func(*Config) error + // NewConfig creates a new Config and sets default values. func NewConfig() *Config { - return &Config{ - Collation: defaultCollation, + cfg := &Config{ Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, + Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, } + + return cfg +} + +// Apply applies the given options to the Config object. +func (c *Config) Apply(opts ...Option) error { + for _, opt := range opts { + err := opt(c) + if err != nil { + return err + } + } + return nil +} + +// TimeTruncate sets the time duration to truncate time.Time values in +// query parameters. +func TimeTruncate(d time.Duration) Option { + return func(cfg *Config) error { + cfg.timeTruncate = d + return nil + } +} + +// BeforeConnect sets the function to be invoked before a connection is established. 
+func BeforeConnect(fn func(context.Context, *Config) error) Option { + return func(cfg *Config) error { + cfg.beforeConnect = fn + return nil + } } func (cfg *Config) Clone() *Config { @@ -97,7 +143,7 @@ func (cfg *Config) Clone() *Config { } func (cfg *Config) normalize() error { - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] { return errInvalidDSNUnsafeCollation } @@ -153,6 +199,10 @@ func (cfg *Config) normalize() error { } } + if cfg.Logger == nil { + cfg.Logger = defaultLogger + } + return nil } @@ -171,6 +221,8 @@ func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) { // FormatDSN formats the given Config into a DSN string which can be passed to // the driver. +// +// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config]. func (cfg *Config) FormatDSN() string { var buf bytes.Buffer @@ -196,7 +248,7 @@ func (cfg *Config) FormatDSN() string { // /dbname buf.WriteByte('/') - buf.WriteString(cfg.DBName) + buf.WriteString(url.PathEscape(cfg.DBName)) // [?param1=value1&...¶mN=valueN] hasParam := false @@ -230,7 +282,7 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") } - if col := cfg.Collation; col != defaultCollation && len(col) > 0 { + if col := cfg.Collation; col != "" { writeDSNParam(&buf, &hasParam, "collation", col) } @@ -254,6 +306,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "parseTime", "true") } + if cfg.timeTruncate > 0 { + writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String()) + } + if cfg.ReadTimeout > 0 { writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String()) } @@ -358,7 +414,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) { break } } - cfg.DBName = dsn[i+1 : j] + + dbname := dsn[i+1 : j] + if cfg.DBName, err = url.PathUnescape(dbname); err != nil { + return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err) + } break } @@ -378,13 +438,13 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // Values must be url.QueryEscape'ed func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { - param := strings.SplitN(v, "=", 2) - if len(param) != 2 { + key, value, found := strings.Cut(v, "=") + if !found { continue } // cfg params - switch value := param[1]; param[0] { + switch key { // Disable INFILE allowlist / enable all files case "allowAllFiles": var isBool bool @@ -490,6 +550,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // time.Time truncation + case "timeTruncate": + cfg.timeTruncate, err = time.ParseDuration(value) + if err != nil { + return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err) + } + // I/O read Timeout case "readTimeout": cfg.ReadTimeout, err = time.ParseDuration(value) @@ -554,13 +621,22 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + + // Connection attributes + case "connectionAttributes": + connectionAttributes, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid connectionAttributes value: %v", err) + } + cfg.ConnectionAttributes = connectionAttributes + default: // lazy init if cfg.Params == nil { cfg.Params = make(map[string]string) } - if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[key], err = 
url.QueryUnescape(value); err != nil { return } } diff --git a/licenses/github.com/go-sql-driver/mysql/dsn_fuzz_test.go b/licenses/github.com/go-sql-driver/mysql/dsn_fuzz_test.go new file mode 100644 index 0000000000000000000000000000000000000000..04c56ad4500080c2404a9f997c2d452f3e8e3c06 --- /dev/null +++ b/licenses/github.com/go-sql-driver/mysql/dsn_fuzz_test.go @@ -0,0 +1,47 @@ +//go:build go1.18 +// +build go1.18 + +package mysql + +import ( + "net" + "testing" +) + +func FuzzFormatDSN(f *testing.F) { + for _, test := range testDSNs { // See dsn_test.go + f.Add(test.in) + } + + f.Fuzz(func(t *testing.T, dsn1 string) { + // Do not waste resources + if len(dsn1) > 1000 { + t.Skip("ignore: too long") + } + + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Skipf("invalid DSN: %v", err) + } + + dsn2 := cfg1.FormatDSN() + if dsn2 == dsn1 { + return + } + + // Skip known cases of bad config that are not strictly checked by ParseDSN + if _, _, err := net.SplitHostPort(cfg1.Addr); err != nil { + t.Skipf("invalid addr %q: %v", cfg1.Addr, err) + } + + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Fatalf("%q rewritten as %q: %v", dsn1, dsn2, err) + } + + dsn3 := cfg2.FormatDSN() + if dsn3 != dsn2 { + t.Errorf("%q rewritten as %q", dsn2, dsn3) + } + }) +} diff --git a/licenses/github.com/go-sql-driver/mysql/dsn_test.go b/licenses/github.com/go-sql-driver/mysql/dsn_test.go index 41a6a29fa6dad4f080eb93879e0e8f56684a9e24..dd8cd935c1138ae43baa8ca2a07d1cb456be69c8 100644 --- a/licenses/github.com/go-sql-driver/mysql/dsn_test.go +++ b/licenses/github.com/go-sql-driver/mysql/dsn_test.go @@ -22,71 +22,80 @@ var testDSNs = []struct { out *Config }{{ "username:password@protocol(address)/dbname?param=value", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: 
map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, Logger: defaultLogger, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, }, { 
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: 0, Logger: defaultLogger, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", - &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/dbname%2Fwithslash", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "@/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:p@/ssword@/", - &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "unix/?arg=%2Fsome%2Fpath.ext", - &Config{Net: "unix", Addr: 
"/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(127.0.0.1)/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(de:ad:be:ef::ca:fe)/dbname", - &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour}, }, } func TestDSNParser(t *testing.T) { for i, tst := range testDSNs { - cfg, err := ParseDSN(tst.in) - if err != nil { - t.Error(err.Error()) - } + t.Run(tst.in, func(t *testing.T) { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + return + } - // pointer not static - cfg.TLS = nil + // pointer not static + cfg.TLS = nil - if !reflect.DeepEqual(cfg, tst.out) { - t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) - } + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + }) } } @@ -113,27 +122,39 @@ func TestDSNParserInvalid(t *testing.T) { func TestDSNReformat(t *testing.T) { for i, tst := range testDSNs { - dsn1 := tst.in - cfg1, err := ParseDSN(dsn1) - if err != nil { - t.Error(err.Error()) - continue - } - cfg1.TLS = nil // pointer not static - res1 := fmt.Sprintf("%+v", cfg1) - - dsn2 := cfg1.FormatDSN() - cfg2, err := ParseDSN(dsn2) - if err != nil { - t.Error(err.Error()) - continue - } - cfg2.TLS = nil // pointer not static - res2 := fmt.Sprintf("%+v", cfg2) + t.Run(tst.in, func(t *testing.T) { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + return + } + cfg1.TLS = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) - if res1 != res2 { - t.Errorf("%d. %q does not match %q", i, res2, res1) - } + dsn2 := cfg1.FormatDSN() + if dsn2 != dsn1 { + // Just log + t.Logf("%d. %q reformatted as %q", i, dsn1, dsn2) + } + + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + return + } + cfg2.TLS = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. 
%q does not match %q", i, res2, res1) + } + + dsn3 := cfg2.FormatDSN() + if dsn3 != dsn2 { + t.Errorf("%d. %q does not match %q", i, dsn2, dsn3) + } + }) } } diff --git a/licenses/github.com/go-sql-driver/mysql/errors.go b/licenses/github.com/go-sql-driver/mysql/errors.go index ff9a8f088c395457f1155ff64f7db2a81d0838f0..a9a3060c982123d799970a58f658ea8c71b29cdb 100644 --- a/licenses/github.com/go-sql-driver/mysql/errors.go +++ b/licenses/github.com/go-sql-driver/mysql/errors.go @@ -21,7 +21,7 @@ var ( ErrMalformPkt = errors.New("malformed packet") ErrNoTLS = errors.New("TLS requested but server does not support TLS") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") - ErrNativePassword = errors.New("this user requires mysql native password authentication.") + ErrNativePassword = errors.New("this user requires mysql native password authentication") ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") ErrUnknownPlugin = errors.New("this authentication plugin is not supported") ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") @@ -37,20 +37,26 @@ var ( errBadConnNoWrite = errors.New("bad connection") ) -var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) +var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) // Logger is used to log critical error messages. type Logger interface { Print(v ...interface{}) } -// SetLogger is used to set the logger for critical errors. +// NopLogger is a nop implementation of the Logger interface. +type NopLogger struct{} + +// Print implements Logger interface. +func (nl *NopLogger) Print(_ ...interface{}) {} + +// SetLogger is used to set the default logger for critical errors. // The initial logger is os.Stderr. 
func SetLogger(logger Logger) error { if logger == nil { return errors.New("logger is nil") } - errLog = logger + defaultLogger = logger return nil } diff --git a/licenses/github.com/go-sql-driver/mysql/errors_test.go b/licenses/github.com/go-sql-driver/mysql/errors_test.go index 43213f98e670f421e6d07632d2f1845ac190b086..53d634454aa30689058404891907a3e2ae1e2aca 100644 --- a/licenses/github.com/go-sql-driver/mysql/errors_test.go +++ b/licenses/github.com/go-sql-driver/mysql/errors_test.go @@ -16,9 +16,9 @@ import ( ) func TestErrorsSetLogger(t *testing.T) { - previous := errLog + previous := defaultLogger defer func() { - errLog = previous + defaultLogger = previous }() // set up logger @@ -28,7 +28,7 @@ func TestErrorsSetLogger(t *testing.T) { // print SetLogger(logger) - errLog.Print("test") + defaultLogger.Print("test") // check result if actual := buffer.String(); actual != expected { diff --git a/licenses/github.com/go-sql-driver/mysql/fields.go b/licenses/github.com/go-sql-driver/mysql/fields.go index e0654a83d9989358cefd4563ffaf22a09b35f060..2a397b2456c7b409d5de7712105685dd5068cfde 100644 --- a/licenses/github.com/go-sql-driver/mysql/fields.go +++ b/licenses/github.com/go-sql-driver/mysql/fields.go @@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeBit: return "BIT" case fieldTypeBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "TEXT" } return "BLOB" @@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeGeometry: return "GEOMETRY" case fieldTypeInt24: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED MEDIUMINT" + } return "MEDIUMINT" case fieldTypeJSON: return "JSON" @@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string { } return "INT" case fieldTypeLongBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "LONGTEXT" } return "LONGBLOB" @@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string { } return "BIGINT" case fieldTypeMediumBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "MEDIUMTEXT" } return "MEDIUMBLOB" @@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string { } return "SMALLINT" case fieldTypeString: - if mf.charSet == collations[binaryCollation] { + if mf.flags&flagEnum != 0 { + return "ENUM" + } else if mf.flags&flagSet != 0 { + return "SET" + } + if mf.charSet == binaryCollationID { return "BINARY" } return "CHAR" @@ -88,17 +96,17 @@ func (mf *mysqlField) typeDatabaseName() string { } return "TINYINT" case fieldTypeTinyBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "TINYTEXT" } return "TINYBLOB" case fieldTypeVarChar: - if mf.charSet == collations[binaryCollation] { + if mf.charSet == binaryCollationID { return "VARBINARY" } return "VARCHAR" case fieldTypeVarString: - if mf.charSet == collations[binaryCollation] { + if mf.charSet == binaryCollationID { return "VARBINARY" } return "VARCHAR" @@ -110,21 +118,23 @@ func (mf *mysqlField) typeDatabaseName() string { } var ( - scanTypeFloat32 = reflect.TypeOf(float32(0)) - scanTypeFloat64 = reflect.TypeOf(float64(0)) - scanTypeInt8 = reflect.TypeOf(int8(0)) - scanTypeInt16 = reflect.TypeOf(int16(0)) - scanTypeInt32 = reflect.TypeOf(int32(0)) - scanTypeInt64 = reflect.TypeOf(int64(0)) - scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) - scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = 
reflect.TypeOf(sql.NullTime{}) - scanTypeUint8 = reflect.TypeOf(uint8(0)) - scanTypeUint16 = reflect.TypeOf(uint16(0)) - scanTypeUint32 = reflect.TypeOf(uint32(0)) - scanTypeUint64 = reflect.TypeOf(uint64(0)) - scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) - scanTypeUnknown = reflect.TypeOf(new(interface{})) + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeString = reflect.TypeOf("") + scanTypeNullString = reflect.TypeOf(sql.NullString{}) + scanTypeBytes = reflect.TypeOf([]byte{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) ) type mysqlField struct { @@ -187,12 +197,18 @@ func (mf *mysqlField) scanType() reflect.Type { } return scanTypeNullFloat + case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, + fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: + if mf.charSet == binaryCollationID { + return scanTypeBytes + } + fallthrough case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, - fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, - fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, - fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, - fieldTypeTime: - return scanTypeRawBytes + fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime: + if mf.flags&flagNotNULL != 0 { + return scanTypeString + } + return scanTypeNullString case fieldTypeDate, fieldTypeNewDate, fieldTypeTimestamp, fieldTypeDateTime: diff --git a/licenses/github.com/go-sql-driver/mysql/fuzz.go b/licenses/github.com/go-sql-driver/mysql/fuzz.go deleted file mode 100644 index 3a4ec25a9e42aa8c485582f2e66d29493e3ad7f7..0000000000000000000000000000000000000000 --- a/licenses/github.com/go-sql-driver/mysql/fuzz.go +++ /dev/null @@ -1,25 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. -// -// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. 
- -//go:build gofuzz -// +build gofuzz - -package mysql - -import ( - "database/sql" -) - -func Fuzz(data []byte) int { - db, err := sql.Open("mysql", string(data)) - if err != nil { - return 0 - } - db.Close() - return 1 -} diff --git a/licenses/github.com/go-sql-driver/mysql/go.mod b/licenses/github.com/go-sql-driver/mysql/go.mod index 2511104786fd325e626ee358841f1e79810420eb..4629714c0ced4cf635d88b182ce7d8fcf3faf83f 100644 --- a/licenses/github.com/go-sql-driver/mysql/go.mod +++ b/licenses/github.com/go-sql-driver/mysql/go.mod @@ -1,3 +1,5 @@ module github.com/go-sql-driver/mysql -go 1.13 +go 1.18 + +require filippo.io/edwards25519 v1.1.0 diff --git a/licenses/github.com/go-sql-driver/mysql/go.sum b/licenses/github.com/go-sql-driver/mysql/go.sum new file mode 100644 index 0000000000000000000000000000000000000000..359ca94b4be4cc4313a261940c87f02fc8bb6f62 --- /dev/null +++ b/licenses/github.com/go-sql-driver/mysql/go.sum @@ -0,0 +1,2 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= diff --git a/licenses/github.com/go-sql-driver/mysql/infile.go b/licenses/github.com/go-sql-driver/mysql/infile.go index 3279dcffd7e310a83e10f1cf7a22e09b8f585463..0c8af9f110e7747b197b0f806ee980e6ceea5eab 100644 --- a/licenses/github.com/go-sql-driver/mysql/infile.go +++ b/licenses/github.com/go-sql-driver/mysql/infile.go @@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) { const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP -func (mc *mysqlConn) handleInFileRequest(name string) (err error) { +func (mc *okHandler) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte packetSize := defaultPacketSize @@ -116,10 +116,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { defer deferredClose(&err, cl) } } else { - err = fmt.Errorf("Reader '%s' is <nil>", name) + err = fmt.Errorf("reader '%s' is <nil>", name) } } else { - err = fmt.Errorf("Reader '%s' is not registered", name) + err = fmt.Errorf("reader '%s' is not registered", name) } } else { // File name = strings.Trim(name, `"`) @@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil { return ioErr } } @@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if data == nil { data = make([]byte, 4) } - if ioErr := mc.writePacket(data[:4]); ioErr != nil { + if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } @@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { return mc.readResultOK() } - mc.readPacket() + mc.conn().readPacket() return err } diff --git a/licenses/github.com/go-sql-driver/mysql/nulltime.go b/licenses/github.com/go-sql-driver/mysql/nulltime.go index 36c8a42c57b561acf1c44f56ec22939b9107c3a5..7d381d5c28f2522f3b097c4e562b3ae17773b17a 100644 --- a/licenses/github.com/go-sql-driver/mysql/nulltime.go +++ b/licenses/github.com/go-sql-driver/mysql/nulltime.go @@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) { } nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) + return fmt.Errorf("can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. 
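The fields.go hunks above replace scanTypeRawBytes with scanTypeString, scanTypeNullString and scanTypeBytes, which is exactly what the rewritten TestRowsColumnTypes table now expects for DECIMAL, CHAR/TEXT, ENUM/SET and the binary types. A short, hedged sketch of how a caller sees that change through database/sql; the items table and its columns are purely illustrative:

package example

import (
	"database/sql"
	"log"
	"reflect"

	_ "github.com/go-sql-driver/mysql"
)

// dumpRows scans every column via the driver-reported ScanType. With this
// patch, NOT NULL DECIMAL/TEXT columns report string and nullable ones report
// sql.NullString (previously both were sql.RawBytes), so the values created
// by reflect.New can be used directly after Scan.
func dumpRows(db *sql.DB) error {
	rows, err := db.Query("SELECT price, note FROM items") // price DECIMAL(8,2) NOT NULL, note TEXT
	if err != nil {
		return err
	}
	defer rows.Close()

	cols, err := rows.ColumnTypes()
	if err != nil {
		return err
	}
	dest := make([]interface{}, len(cols))
	for i, c := range cols {
		dest[i] = reflect.New(c.ScanType()).Interface()
	}
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			return err
		}
		log.Printf("price=%v note=%v",
			reflect.ValueOf(dest[0]).Elem().Interface(),
			reflect.ValueOf(dest[1]).Elem().Interface())
	}
	return rows.Err()
}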
diff --git a/licenses/github.com/go-sql-driver/mysql/packets.go b/licenses/github.com/go-sql-driver/mysql/packets.go
index ee05c95a8704706c3b3ac47c7dbd45b4ac5b7f41..3d6e5308cec4d71dafa75a2f0ff16e44eefd6e46 100644
--- a/licenses/github.com/go-sql-driver/mysql/packets.go
+++ b/licenses/github.com/go-sql-driver/mysql/packets.go
@@ -14,10 +14,10 @@ import (
 	"database/sql/driver"
 	"encoding/binary"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
 	"math"
+	"strconv"
 	"time"
 )
 
@@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 			if cerr := mc.canceled.Value(); cerr != nil {
 				return nil, cerr
 			}
-			errLog.Print(err)
+			mc.cfg.Logger.Print(err)
 			mc.Close()
 			return nil, ErrInvalidConn
 		}
@@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 
 		// check packet sync [8 bit]
 		if data[3] != mc.sequence {
+			mc.Close()
 			if data[3] > mc.sequence {
 				return nil, ErrPktSyncMul
 			}
@@ -56,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		if pktLen == 0 {
 			// there was no previous packet
 			if prevData == nil {
-				errLog.Print(ErrMalformPkt)
+				mc.cfg.Logger.Print(ErrMalformPkt)
 				mc.Close()
 				return nil, ErrInvalidConn
 			}
@@ -70,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 			if cerr := mc.canceled.Value(); cerr != nil {
 				return nil, cerr
 			}
-			errLog.Print(err)
+			mc.cfg.Logger.Print(err)
 			mc.Close()
 			return nil, ErrInvalidConn
 		}
@@ -97,34 +98,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 		return ErrPktTooLarge
 	}
 
-	// Perform a stale connection check. We only perform this check for
-	// the first query on a connection that has been checked out of the
-	// connection pool: a fresh connection from the pool is more likely
-	// to be stale, and it has not performed any previous writes that
-	// could cause data corruption, so it's safe to return ErrBadConn
-	// if the check fails.
-	if mc.reset {
-		mc.reset = false
-		conn := mc.netConn
-		if mc.rawConn != nil {
-			conn = mc.rawConn
-		}
-		var err error
-		if mc.cfg.CheckConnLiveness {
-			if mc.cfg.ReadTimeout != 0 {
-				err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
-			}
-			if err == nil {
-				err = connCheck(conn)
-			}
-		}
-		if err != nil {
-			errLog.Print("closing bad idle connection: ", err)
-			mc.Close()
-			return driver.ErrBadConn
-		}
-	}
-
 	for {
 		var size int
 		if pktLen >= maxPacketSize {
@@ -161,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 		// Handle error
 		if err == nil { // n != len(data)
 			mc.cleanup()
-			errLog.Print(ErrMalformPkt)
+			mc.cfg.Logger.Print(ErrMalformPkt)
 		} else {
 			if cerr := mc.canceled.Value(); cerr != nil {
 				return cerr
@@ -171,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 				return errBadConnNoWrite
 			}
 			mc.cleanup()
-			errLog.Print(err)
+			mc.cfg.Logger.Print(err)
 		}
 		return ErrInvalidConn
 	}
@@ -239,7 +212,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
 	// reserved (all [00]) [10 bytes]
 	pos += 1 + 2 + 2 + 1 + 10
 
-	// second part of the password cipher [mininum 13 bytes],
+	// second part of the password cipher [minimum 13 bytes],
 	// where len=MAX(13, length of auth-plugin-data - 8)
 	//
 	// The web documentation is ambiguous about the length. However,
@@ -285,6 +258,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
 		clientLocalFiles |
 		clientPluginAuth |
 		clientMultiResults |
+		clientConnectAttrs |
 		mc.flags&clientLongFlag
 
 	if mc.cfg.ClientFoundRows {
@@ -318,11 +292,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
 		pktLen += n + 1
 	}
 
+	// encode length of the connection attributes
+	var connAttrsLEIBuf [9]byte
+	connAttrsLen := len(mc.connector.encodedAttributes)
+	connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
+	pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
+
 	// Calculate packet length and get buffer with that size
-	data, err := mc.buf.takeSmallBuffer(pktLen + 4)
+	data, err := mc.buf.takeBuffer(pktLen + 4)
 	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(err)
+		mc.cfg.Logger.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -338,14 +318,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
 	data[10] = 0x00
 	data[11] = 0x00
 
-	// Charset [1 byte]
+	// Collation ID [1 byte]
+	cname := mc.cfg.Collation
+	if cname == "" {
+		cname = defaultCollation
+	}
 	var found bool
-	data[12], found = collations[mc.cfg.Collation]
+	data[12], found = collations[cname]
 	if !found {
 		// Note possibility for false negatives:
 		// could be triggered although the collation is valid if the
 		// collations map does not contain entries the server supports.
-		return errors.New("unknown collation")
+		return fmt.Errorf("unknown collation: %q", cname)
 	}
 
 	// Filler [23 bytes] (all 0x00)
@@ -394,6 +378,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
 		data[pos] = 0x00
 		pos++
 
+	// Connection Attributes
+	pos += copy(data[pos:], connAttrsLEI)
+	pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
+
 	// Send Auth packet
 	return mc.writePacket(data[:pos])
 }
@@ -404,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
 	data, err := mc.buf.takeSmallBuffer(pktLen)
 	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(err)
+		mc.cfg.Logger.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -424,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 	data, err := mc.buf.takeSmallBuffer(4 + 1)
 	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(err)
+		mc.cfg.Logger.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -443,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	data, err := mc.buf.takeBuffer(pktLen + 4)
 	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(err)
+		mc.cfg.Logger.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -464,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 	data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
 	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(err)
+		mc.cfg.Logger.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -495,7 +483,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
 
 	switch data[0] {
 	case iOK:
-		return nil, "", mc.handleOkPacket(data)
+		// resultUnchanged, since auth happens before any queries or
+		// commands have been executed.
+		return nil, "", mc.resultUnchanged().handleOkPacket(data)
 
 	case iAuthMoreData:
 		return data[1:], "", err
@@ -518,9 +508,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
 	}
 }
 
-// Returns error if Packet is not an 'Result OK'-Packet
-func (mc *mysqlConn) readResultOK() error {
-	data, err := mc.readPacket()
+// Returns error if Packet is not a 'Result OK'-Packet
+func (mc *okHandler) readResultOK() error {
+	data, err := mc.conn().readPacket()
 	if err != nil {
 		return err
 	}
@@ -528,13 +518,17 @@ func (mc *mysqlConn) readResultOK() error {
 	if data[0] == iOK {
 		return mc.handleOkPacket(data)
 	}
-	return mc.handleErrorPacket(data)
+	return mc.conn().handleErrorPacket(data)
 }
 
 // Result Set Header Packet
 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
-func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
-	data, err := mc.readPacket()
+func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
+	// handleOkPacket replaces both values; other cases leave the values unchanged.
+	mc.result.affectedRows = append(mc.result.affectedRows, 0)
+	mc.result.insertIds = append(mc.result.insertIds, 0)
+
+	data, err := mc.conn().readPacket()
 	if err == nil {
 		switch data[0] {
 
@@ -542,19 +536,16 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 			return 0, mc.handleOkPacket(data)
 
 		case iERR:
-			return 0, mc.handleErrorPacket(data)
+			return 0, mc.conn().handleErrorPacket(data)
 
 		case iLocalInFile:
 			return 0, mc.handleInFileRequest(string(data[1:]))
 		}
 
 		// column count
-		num, _, n := readLengthEncodedInteger(data)
-		if n-len(data) == 0 {
-			return int(num), nil
-		}
-
-		return 0, ErrMalformPkt
+		num, _, _ := readLengthEncodedInteger(data)
+		// ignore remaining data in the packet. see #1478.
+		return int(num), nil
 	}
 	return 0, err
 }
@@ -607,18 +598,61 @@ func readStatus(b []byte) statusFlag {
 	return statusFlag(b[0]) | statusFlag(b[1])<<8
 }
 
+// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
+// need to be cleared first (e.g. during authentication, or while additional
+// resultsets are being fetched.)
+func (mc *mysqlConn) resultUnchanged() *okHandler {
+	return (*okHandler)(mc)
+}
+
+// okHandler represents the state of the connection when mysqlConn.result has
+// been prepared for processing of OK packets.
+//
+// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
+// callpaths must either:
+//
+// 1. first clear it using clearResult(), or
+// 2. confirm that they don't need to (by calling resultUnchanged()).
+//
+// Both return an instance of type *okHandler.
+type okHandler mysqlConn

+// Exposes the underlying type's methods.
+func (mc *okHandler) conn() *mysqlConn {
+	return (*mysqlConn)(mc)
+}
+
+// clearResult clears the connection's stored affectedRows and insertIds
+// fields.
+//
+// It returns a handler that can process OK responses.
+func (mc *mysqlConn) clearResult() *okHandler {
+	mc.result = mysqlResult{}
+	return (*okHandler)(mc)
+}
+
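// Editor's note (illustration only, not part of the patch): okHandler above is a
// named type whose underlying struct is mysqlConn, so converting between the two
// is free and only switches which method set applies. A standalone sketch of the
// same pattern, with hypothetical names:
package main

import "fmt"

type conn struct{ results []int64 }

// okView is a typed view of *conn whose methods are allowed to record results.
type okView conn

// clearResult resets the accumulated results and hands back the typed view.
func (c *conn) clearResult() *okView {
	c.results = c.results[:0]
	return (*okView)(c)
}

// recordOK appends one statement result through the view.
func (v *okView) recordOK(affected int64) {
	v.results = append(v.results, affected)
}

func main() {
	c := &conn{}
	h := c.clearResult() // obtain the view; results are now empty
	h.recordOK(3)
	fmt.Println(c.results) // [3] — both names share the same memory
}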
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
-func (mc *mysqlConn) handleOkPacket(data []byte) error {
+func (mc *okHandler) handleOkPacket(data []byte) error {
 	var n, m int
+	var affectedRows, insertId uint64
 
 	// 0x00 [1 byte]
 
 	// Affected rows [Length Coded Binary]
-	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
+	affectedRows, _, n = readLengthEncodedInteger(data[1:])
 
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
+	insertId, _, m = readLengthEncodedInteger(data[1+n:])
+
+	// Update for the current statement result (only used by
+	// readResultSetHeaderPacket).
+	if len(mc.result.affectedRows) > 0 {
+		mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
+	}
+	if len(mc.result.insertIds) > 0 {
+		mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
+	}
 
 	// server_status [2 bytes]
 	mc.status = readStatus(data[1+n+m : 1+n+m+2])
@@ -769,7 +803,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 
 	for i := range dest {
 		// Read bytes and convert to string
-		dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
+		var buf []byte
+		buf, isNull, n, err = readLengthEncodedString(data[pos:])
 		pos += n
 
 		if err != nil {
@@ -781,19 +816,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 			continue
 		}
 
-		if !mc.parseTime {
-			continue
-		}
-
-		// Parse time field
 		switch rows.rs.columns[i].fieldType {
 		case fieldTypeTimestamp, fieldTypeDateTime,
 			fieldTypeDate, fieldTypeNewDate:
-			if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
-				return err
+			if mc.parseTime {
+				dest[i], err = parseDateTime(buf, mc.cfg.Loc)
+			} else {
+				dest[i] = buf
 			}
+
+		case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
+			dest[i], err = strconv.ParseInt(string(buf), 10, 64)
+
+		case fieldTypeLongLong:
+			if rows.rs.columns[i].flags&flagUnsigned != 0 {
+				dest[i], err = strconv.ParseUint(string(buf), 10, 64)
+			} else {
+				dest[i], err = strconv.ParseInt(string(buf), 10, 64)
+			}
+
+		case fieldTypeFloat:
+			var d float64
+			d, err = strconv.ParseFloat(string(buf), 32)
+			dest[i] = float32(d)
+
+		case fieldTypeDouble:
+			dest[i], err = strconv.ParseFloat(string(buf), 64)
+
+		default:
+			dest[i] = buf
+		}
+		if err != nil {
+			return err
 		}
 	}
 
@@ -938,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	}
 	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(err)
+		mc.cfg.Logger.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -1116,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				if v.IsZero() {
 					b = append(b, "0000-00-00"...)
 				} else {
-					b, err = appendDateTime(b, v.In(mc.cfg.Loc))
+					b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
 					if err != nil {
 						return err
 					}
@@ -1137,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		if valuesCap != cap(paramValues) {
 			data = append(data[:pos], paramValues...)
 			if err = mc.buf.store(data); err != nil {
-				errLog.Print(err)
+				mc.cfg.Logger.Print(err)
 				return errBadConnNoWrite
 			}
 		}
@@ -1149,7 +1205,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	return mc.writePacket(data)
 }
 
-func (mc *mysqlConn) discardResults() error {
+// For each remaining resultset in the stream, discards its rows and updates
+// mc.affectedRows and mc.insertIds.
+func (mc *okHandler) discardResults() error {
 	for mc.status&statusMoreResultsExists != 0 {
 		resLen, err := mc.readResultSetHeaderPacket()
 		if err != nil {
@@ -1157,11 +1215,11 @@ func (mc *mysqlConn) discardResults() error {
 		}
 		if resLen > 0 {
 			// columns
-			if err := mc.readUntilEOF(); err != nil {
+			if err := mc.conn().readUntilEOF(); err != nil {
 				return err
 			}
 			// rows
-			if err := mc.readUntilEOF(); err != nil {
+			if err := mc.conn().readUntilEOF(); err != nil {
 				return err
 			}
 		}
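// Editor's note (sketch, not part of the patch): with the readRow change above,
// the text protocol converts numeric columns to Go numeric types instead of
// handing back raw []byte for every column. The DSN below is a placeholder.
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/testdb")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	var v any
	// An integer column scanned into interface{} should now arrive as int64,
	// a FLOAT as float32/float64, and text columns as []byte.
	if err := db.QueryRow("SELECT 42").Scan(&v); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%T %v\n", v, v) // expected: int64 42
}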
diff --git a/licenses/github.com/go-sql-driver/mysql/packets_test.go b/licenses/github.com/go-sql-driver/mysql/packets_test.go
index b61e4dbf777bec389c8e38adb6b768b360a14c50..fa4683eab3a5ab0ba63c5d322e7c16a2b765de36 100644
--- a/licenses/github.com/go-sql-driver/mysql/packets_test.go
+++ b/licenses/github.com/go-sql-driver/mysql/packets_test.go
@@ -96,9 +96,11 @@ var _ net.Conn = new(mockConn)
 
 func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
 	conn := new(mockConn)
+	connector := newConnector(NewConfig())
 	mc := &mysqlConn{
 		buf:              newBuffer(conn),
-		cfg:              NewConfig(),
+		cfg:              connector.cfg,
+		connector:        connector,
 		netConn:          conn,
 		closech:          make(chan struct{}),
 		maxAllowedPacket: defaultMaxAllowedPacket,
@@ -128,30 +130,34 @@ func TestReadPacketSingleByte(t *testing.T) {
 }
 
 func TestReadPacketWrongSequenceID(t *testing.T) {
-	conn := new(mockConn)
-	mc := &mysqlConn{
-		buf: newBuffer(conn),
-	}
-
-	// too low sequence id
-	conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
-	conn.maxReads = 1
-	mc.sequence = 1
-	_, err := mc.readPacket()
-	if err != ErrPktSync {
-		t.Errorf("expected ErrPktSync, got %v", err)
-	}
-
-	// reset
-	conn.reads = 0
-	mc.sequence = 0
-	mc.buf = newBuffer(conn)
-
-	// too high sequence id
-	conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
-	_, err = mc.readPacket()
-	if err != ErrPktSyncMul {
-		t.Errorf("expected ErrPktSyncMul, got %v", err)
+	for _, testCase := range []struct {
+		ClientSequenceID byte
+		ServerSequenceID byte
+		ExpectedErr      error
+	}{
+		{
+			ClientSequenceID: 1,
+			ServerSequenceID: 0,
+			ExpectedErr:      ErrPktSync,
+		},
+		{
+			ClientSequenceID: 0,
+			ServerSequenceID: 0x42,
+			ExpectedErr:      ErrPktSyncMul,
+		},
+	} {
+		conn, mc := newRWMockConn(testCase.ClientSequenceID)
+
+		conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff}
+		_, err := mc.readPacket()
+		if err != testCase.ExpectedErr {
+			t.Errorf("expected %v, got %v", testCase.ExpectedErr, err)
+		}
+
+		// connection should not be returned to the pool in this state
+		if mc.IsValid() {
+			t.Errorf("expected IsValid() to be false")
+		}
 	}
 }
 
@@ -179,7 +185,7 @@ func TestReadPacketSplit(t *testing.T) {
 	data[4] = 0x11
 	data[maxPacketSize+3] = 0x22
 
-	// 2nd packet has payload length 0 and squence id 1
+	// 2nd packet has payload length 0 and sequence id 1
 	// 00 00 00 01
 	data[pkt2ofs+3] = 0x01
 
@@ -211,7 +217,7 @@ func TestReadPacketSplit(t *testing.T) {
 	data[pkt2ofs+4] = 0x33
 	data[pkt2ofs+maxPacketSize+3] = 0x44
 
-	// 3rd packet has payload length 0 and squence id 2
+	// 3rd packet has payload length 0 and sequence id 2
 	// 00 00 00 02
 	data[pkt3ofs+3] = 0x02
 
@@ -265,6 +271,7 @@ func TestReadPacketFail(t *testing.T) {
 	mc := &mysqlConn{
 		buf:     newBuffer(conn),
 		closech: make(chan struct{}),
+		cfg:     NewConfig(),
 	}
 
 	// illegal empty (stand-alone) packet
diff --git a/licenses/github.com/go-sql-driver/mysql/result.go b/licenses/github.com/go-sql-driver/mysql/result.go
index c6438d0347db5c570524ba792975883631cce8e7..d516314683719c8601e541d3036aff0b4118e144 100644
--- a/licenses/github.com/go-sql-driver/mysql/result.go
+++ b/licenses/github.com/go-sql-driver/mysql/result.go
@@ -8,15 +8,43 @@
 
 package mysql
 
+import "database/sql/driver"
+
+// Result exposes data not available through *connection.Result.
+//
+// This is accessible by executing statements using sql.Conn.Raw() and
+// downcasting the returned result:
+//
+//	res, err := rawConn.Exec(...)
+//	res.(mysql.Result).AllRowsAffected()
+type Result interface {
+	driver.Result
+	// AllRowsAffected returns a slice containing the affected rows for each
+	// executed statement.
+	AllRowsAffected() []int64
+	// AllLastInsertIds returns a slice containing the last inserted ID for each
+	// executed statement.
+	AllLastInsertIds() []int64
+}
+
 type mysqlResult struct {
-	affectedRows int64
-	insertId     int64
+	// One entry in both slices is created for every executed statement result.
+	affectedRows []int64
+	insertIds    []int64
 }
 
 func (res *mysqlResult) LastInsertId() (int64, error) {
-	return res.insertId, nil
+	return res.insertIds[len(res.insertIds)-1], nil
 }
 
 func (res *mysqlResult) RowsAffected() (int64, error) {
-	return res.affectedRows, nil
+	return res.affectedRows[len(res.affectedRows)-1], nil
+}
+
+func (res *mysqlResult) AllLastInsertIds() []int64 {
+	return append([]int64{}, res.insertIds...) // defensive copy
+}
+
+func (res *mysqlResult) AllRowsAffected() []int64 {
+	return append([]int64{}, res.affectedRows...) // defensive copy
 }
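// Editor's note (sketch, not part of the patch): the new mysql.Result interface
// above exposes per-statement counts for multi-statement executions. A rough
// usage sketch; the DSN and table are placeholders, and multiStatements must be
// enabled for the two INSERTs to run as one Exec.
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"log"

	"github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/testdb?multiStatements=true")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	conn, err := db.Conn(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Raw exposes the driver connection, so the driver-level result can be
	// downcast to mysql.Result as the doc comment above describes.
	err = conn.Raw(func(dc any) error {
		ex, ok := dc.(driver.ExecerContext)
		if !ok {
			return fmt.Errorf("driver connection does not implement ExecerContext")
		}
		res, err := ex.ExecContext(context.Background(),
			"INSERT INTO t VALUES (1); INSERT INTO t VALUES (2), (3)", nil)
		if err != nil {
			return err
		}
		fmt.Println(res.(mysql.Result).AllRowsAffected()) // expected: [1 2]
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}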
diff --git a/licenses/github.com/go-sql-driver/mysql/rows.go b/licenses/github.com/go-sql-driver/mysql/rows.go
index 888bdb5f0ada819c2bbf32d681d335a13d32e591..81fa6062cd6884525443a6397b7259c0c67a759a 100644
--- a/licenses/github.com/go-sql-driver/mysql/rows.go
+++ b/licenses/github.com/go-sql-driver/mysql/rows.go
@@ -123,7 +123,8 @@ func (rows *mysqlRows) Close() (err error) {
 		err = mc.readUntilEOF()
 	}
 	if err == nil {
-		if err = mc.discardResults(); err != nil {
+		handleOk := mc.clearResult()
+		if err = handleOk.discardResults(); err != nil {
 			return err
 		}
 	}
@@ -160,7 +161,15 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
 		return 0, io.EOF
 	}
 	rows.rs = resultSet{}
-	return rows.mc.readResultSetHeaderPacket()
+	// rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
+	// nextResultSet.
+	resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
+	if err != nil {
+		// Clean up about multi-results flag
+		rows.rs.done = true
+		rows.mc.status = rows.mc.status & (^statusMoreResultsExists)
+	}
+	return resLen, err
 }
 
 func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
diff --git a/licenses/github.com/go-sql-driver/mysql/statement.go b/licenses/github.com/go-sql-driver/mysql/statement.go
index 10ece8bd6a166c01299340a3d21b7641b464785e..31e7799c4b6bc433ba4664e1882958657335cba0 100644
--- a/licenses/github.com/go-sql-driver/mysql/statement.go
+++ b/licenses/github.com/go-sql-driver/mysql/statement.go
@@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	if stmt.mc.closed.Load() {
-		errLog.Print(ErrInvalidConn)
+		stmt.mc.cfg.Logger.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
 	// Send command
@@ -61,12 +61,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	}
 
 	mc := stmt.mc
-
-	mc.affectedRows = 0
-	mc.insertId = 0
+	handleOk := stmt.mc.clearResult()
 
 	// Read Result
-	resLen, err := mc.readResultSetHeaderPacket()
+	resLen, err := handleOk.readResultSetHeaderPacket()
 	if err != nil {
 		return nil, err
 	}
@@ -83,14 +81,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 		}
 	}
 
-	if err := mc.discardResults(); err != nil {
+	if err := handleOk.discardResults(); err != nil {
 		return nil, err
 	}
 
-	return &mysqlResult{
-		affectedRows: int64(mc.affectedRows),
-		insertId:     int64(mc.insertId),
-	}, nil
+	copied := mc.result
+	return &copied, nil
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
@@ -99,7 +95,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 
 func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
 	if stmt.mc.closed.Load() {
-		errLog.Print(ErrInvalidConn)
+		stmt.mc.cfg.Logger.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
 	// Send command
@@ -111,7 +107,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
 	mc := stmt.mc
 
 	// Read Result
-	resLen, err := mc.readResultSetHeaderPacket()
+	handleOk := stmt.mc.clearResult()
+	resLen, err := handleOk.readResultSetHeaderPacket()
 	if err != nil {
 		return nil, err
 	}
diff --git a/licenses/github.com/go-sql-driver/mysql/utils.go b/licenses/github.com/go-sql-driver/mysql/utils.go
index 15dbd8d16ab98bc3ffd117c38a8f193bc72db33f..cda24fe744ede31ea41d6672ee809de753d30d47 100644
--- a/licenses/github.com/go-sql-driver/mysql/utils.go
+++ b/licenses/github.com/go-sql-driver/mysql/utils.go
@@ -36,7 +36,7 @@ var (
 // registering it.
 //
 //	rootCertPool := x509.NewCertPool()
-//	pem, err := ioutil.ReadFile("/path/ca-cert.pem")
+//	pem, err := os.ReadFile("/path/ca-cert.pem")
 //	if err != nil {
 //		log.Fatal(err)
 //	}
@@ -265,7 +265,11 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
 	return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
 }
 
-func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
+func appendDateTime(buf []byte, t time.Time, timeTruncate time.Duration) ([]byte, error) {
+	if timeTruncate > 0 {
+		t = t.Truncate(timeTruncate)
+	}
+
 	year, month, day := t.Date()
 	hour, min, sec := t.Clock()
 	nsec := t.Nanosecond()
@@ -616,6 +620,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
 		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
 }
 
+func appendLengthEncodedString(b []byte, s string) []byte {
+	b = appendLengthEncodedInteger(b, uint64(len(s)))
+	return append(b, s...)
+}
+
 // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
 // If cap(buf) is not enough, reallocate new buffer.
 func reserveBuffer(buf []byte, appendSize int) []byte {
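// Editor's note (illustration only, not part of the patch): the new timeTruncate
// argument simply applies time.Time.Truncate before formatting, so a timestamp
// with microseconds truncated to time.Second is sent without the fractional
// part, matching the TestAppendDateTime cases below. Standalone illustration:
package main

import (
	"fmt"
	"time"
)

func main() {
	ts := time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC)

	// Without truncation the fractional seconds are kept.
	fmt.Println(ts.Format("2006-01-02 15:04:05.000000")) // 2020-05-30 22:33:44.123456

	// Truncated to whole seconds before formatting.
	fmt.Println(ts.Truncate(time.Second).Format("2006-01-02 15:04:05")) // 2020-05-30 22:33:44
}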
diff --git a/licenses/github.com/go-sql-driver/mysql/utils_test.go b/licenses/github.com/go-sql-driver/mysql/utils_test.go
index 4e5fc3cb748bf3772991bd005f1c3a32bcd39ea3..80aebddfff27a35b6fb1c4d67f400d95ef663852 100644
--- a/licenses/github.com/go-sql-driver/mysql/utils_test.go
+++ b/licenses/github.com/go-sql-driver/mysql/utils_test.go
@@ -237,8 +237,10 @@ func TestIsolationLevelMapping(t *testing.T) {
 
 func TestAppendDateTime(t *testing.T) {
 	tests := []struct {
-		t   time.Time
-		str string
+		t            time.Time
+		str          string
+		timeTruncate time.Duration
+		expectedErr  bool
 	}{
 		{
 			t:   time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC),
 			str: "1234-05-06",
 		},
@@ -276,34 +278,75 @@ func TestAppendDateTime(t *testing.T) {
 			t:   time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
 			str: "0001-01-01",
 		},
+		// Truncated time
+		{
+			t:            time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC),
+			str:          "1234-05-06",
+			timeTruncate: time.Second,
+		},
+		{
+			t:            time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC),
+			str:          "4567-12-31 12:00:00",
+			timeTruncate: time.Minute,
+		},
+		{
+			t:            time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC),
+			str:          "2020-05-30 12:34:00",
+			timeTruncate: 0,
+		},
+		{
+			t:            time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC),
+			str:          "2020-05-30 12:34:56",
+			timeTruncate: time.Second,
+		},
+		{
+			t:            time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC),
+			str:          "2020-05-30 22:33:44",
+			timeTruncate: time.Second,
+		},
+		{
+			t:            time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC),
+			str:          "2020-05-30 22:33:44.123",
+			timeTruncate: time.Millisecond,
+		},
+		{
+			t:            time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC),
+			str:          "2020-05-30 22:33:44",
+			timeTruncate: time.Second,
+		},
+		{
+			t:            time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC),
+			str:          "9999-12-31 23:59:59.999999999",
+			timeTruncate: 0,
+		},
+		{
+			t:            time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC),
+			str:          "0001-01-01",
+			timeTruncate: 365 * 24 * time.Hour,
+		},
+		// year out of range
+		{
+			t:           time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC),
+			expectedErr: true,
+		},
+		{
+			t:           time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC),
+			expectedErr: true,
+		},
 	}
 
 	for _, v := range tests {
 		buf := make([]byte, 0, 32)
-		buf, _ = appendDateTime(buf, v.t)
+		buf, err := appendDateTime(buf, v.t, v.timeTruncate)
+		if err != nil {
+			if !v.expectedErr {
+				t.Errorf("appendDateTime(%v) returned an errror: %v", v.t, err)
+			}
+			continue
+		}
 		if str := string(buf); str != v.str {
 			t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str)
 		}
 	}
-
-	// year out of range
-	{
-		v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
-		buf := make([]byte, 0, 32)
-		_, err := appendDateTime(buf, v)
-		if err == nil {
-			t.Error("want an error")
-			return
-		}
-	}
-	{
-		v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)
-		buf := make([]byte, 0, 32)
-		_, err := appendDateTime(buf, v)
-		if err == nil {
-			t.Error("want an error")
-			return
-		}
-	}
 }
 
 func TestParseDateTime(t *testing.T) {
diff --git a/licenses/go.uber.org/zap/LICENSE.txt b/licenses/go.uber.org/zap/LICENSE
similarity index 100%
rename from licenses/go.uber.org/zap/LICENSE.txt
rename to licenses/go.uber.org/zap/LICENSE
diff --git a/manager.go b/manager.go
index e30f6e9e41b6155780aa9b5dce4ccf0e81a71bea..70307a30564383328c3ce23bf6b38881aae74152 100644
--- a/manager.go
+++ b/manager.go
@@ -34,8 +34,8 @@ type Manager struct {
 	cronInstance *cron.Cron
 	//logger       Logger
 
-	database *gorm.DB
-	dbSaver  *DBSaver
+	database  *gorm.DB
+	jobSyncer *JobSyncer
 
 	mu sync.Mutex
 }
@@ -152,9 +152,9 @@ func (m *Manager) DeleteJob(id JobID) error {
 		return err
 	}
 
-	if m.dbSaver != nil {
+	if m.jobSyncer != nil {
 
-		err := m.dbSaver.DeleteJob(job)
+		err := m.jobSyncer.DeleteJob(job)
 		if err != nil {
 			return err
 		}
@@ -186,9 +186,9 @@ func (m *Manager) ResetJobLogs(id JobID) error {
 		return ErrJobNotActive
 	}
 
-	if m.dbSaver != nil {
+	if m.jobSyncer != nil {
 
-		err := m.dbSaver.ResetLogs(m.activeJobs[id])
+		err := m.jobSyncer.ResetLogs(m.activeJobs[id])
 		if err != nil {
 			return err
 		}
@@ -206,9 +206,9 @@ func (m *Manager) ResetJobStats(id JobID) error {
 		return ErrJobNotActive
 	}
 
-	if m.dbSaver != nil {
+	if m.jobSyncer != nil {
 
-		err := m.dbSaver.ResetStats(m.activeJobs[id])
+		err := m.jobSyncer.ResetStats(m.activeJobs[id])
 		if err != nil {
 			return err
 		}
@@ -290,12 +290,10 @@ func (m *Manager) SetDB(db *gorm.DB) *Manager {
 	defer m.mu.Unlock()
 
 	m.database = db
-	if m.dbSaver != nil {
+	if m.jobSyncer != nil {
 		return m
 	}
-
-	m.dbSaver = NewDBSaver()
-	m.dbSaver.SetManager(m)
+	m.jobSyncer = NewJobSyncer(m)
 
 	return m
 }
@@ -391,7 +389,7 @@ func (m *Manager) RemoveWorker(worker Worker) error {
 // Start starts the manager
 func (m *Manager) Start() error {
 
-	var err error
+	//var err error
 
 	m.mu.Lock()
 	defer m.mu.Unlock()
@@ -400,21 +398,31 @@ func (m *Manager) Start() error {
 		return ErrManagerAlreadyRunning
 	}
 
-	if m.dbSaver != nil {
-		p := StartDBSaver(m.dbSaver)
+	if m.jobSyncer != nil {
+		p := CreateAndStartJobSyncer(m)
 
 		ready := make(chan struct{})
 
-		Then[bool, bool](p, func(value bool) (bool, error) {
+		var jobSyncerErr error
+
+		Then[*JobSyncer, *JobSyncer](p, func(value *JobSyncer) (*JobSyncer, error) {
 			close(ready)
+			m.mu.Lock()
+			m.jobSyncer = value
+			m.mu.Unlock()
 			return value, nil
 		}, func(e error) error {
 			close(ready)
-			Error("Error while starting db saver", "error", err)
+			Error("Error while starting db saver", "error", e)
+			jobSyncerErr = e
 			return nil
 		})
 
 		<-ready
+
+		if jobSyncerErr != nil {
+			return jobSyncerErr
+		}
 	}
 
 	if len(m.workerMap) == 0 {
@@ -449,7 +457,7 @@ func (m *Manager) Start() error {
 
 	go m.handleJobEvents()
 
-	err = m.checkAndSetRunningState()
+	err := m.checkAndSetRunningState()
 	if err != nil {
 
 		wrappedErr = fmt.Errorf("%w\n%s", wrappedErr, err.Error())
@@ -481,7 +489,7 @@ func (m *Manager) Stop() error {
 	for _, worker := range m.workerMap {
 		err := worker.Stop()
 
-		if err != nil && err != ErrWorkerAlreadyStopped {
+		if err != nil && !errors.Is(err, ErrWorkerAlreadyStopped) {
 			if wrappedErr == nil {
 				wrappedErr = fmt.Errorf("Error: ")
 			}
@@ -500,15 +508,15 @@ func (m *Manager) Stop() error {
 		m.cronInstance.Stop()
 	}
 
-	if m.dbSaver != nil {
-		err = m.dbSaver.Stop()
+	if m.jobSyncer != nil {
+		err = m.jobSyncer.Stop()
 		if err != nil {
 			if wrappedErr == nil {
 				wrappedErr = fmt.Errorf("Error: ")
 			}
 			wrappedErr = fmt.Errorf("%w\n%s", wrappedErr, err.Error())
-
+
 		}
 	}
@@ -572,11 +580,8 @@ func (m *Manager) ScheduleJob(job GenericJob, scheduler Scheduler) error {
 
 	m.activeJobs[job.GetID()] = job
 
-	if m.dbSaver != nil {
-		err := m.dbSaver.SaveJob(job)
-		if err != nil {
-			return err
-		}
+	if m.jobSyncer != nil {
+		m.jobSyncer.AddJob(job)
 	}
 
 	return nil
@@ -629,7 +634,7 @@ func (m *Manager) handleJobEvents() {
 			job := event.Data.(GenericJob)
 
 			err := m.queue.Enqueue(job)
-			if err != nil && err != ErrJobAlreadyExists {
+			if err != nil && !errors.Is(err, ErrJobAlreadyExists) {
 				Error("Error while queueing job", "error", err)
 			}
 
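// Editor's note (illustration only, not part of the patch): the switch in
// manager.go from `err != ErrX` to !errors.Is(err, ErrX) matters once sentinel
// errors are wrapped with extra context. Standalone sketch with a hypothetical
// sentinel mirroring ErrJobAlreadyExists:
package main

import (
	"errors"
	"fmt"
)

var ErrJobAlreadyExists = errors.New("job already exists")

// enqueue wraps the sentinel, which breaks a plain == comparison but not errors.Is.
func enqueue() error {
	return fmt.Errorf("enqueue job1: %w", ErrJobAlreadyExists)
}

func main() {
	err := enqueue()
	fmt.Println(err == ErrJobAlreadyExists)          // false
	fmt.Println(errors.Is(err, ErrJobAlreadyExists)) // true
}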
diff --git a/manager_test.go b/manager_test.go
index 22824848b4ae791e572518c99022faa25394e954..58356733ac0e8e6ee75cb757b1fa43584c92d7c3 100644
--- a/manager_test.go
+++ b/manager_test.go
@@ -64,6 +64,10 @@ func (m *MockGenericJob) ResetStats() {
 
 }
 
+func (m *MockGenericJob) GetStats() JobStats {
+	return JobStats{}
+}
+
 func (m *MockGenericJob) GetMaxRetries() uint {
 	return 0
 }
@@ -232,7 +236,7 @@ func TestManager_CancelJob(t *testing.T) {
 func TestManagerEventHandling(t *testing.T) {
 	mgr := NewManager()
 
-	worker := NewLocalWorker(1)
+	worker := NewLocalWorker(10)
 	err := mgr.AddWorker(worker)
 	assert.Nil(t, err)
 
diff --git a/schedule-interval.go b/schedule-interval.go
index 4766a0e2856f5ea544cf450999ed5ed0adad7bb4..0a2c0d2b3bba7719ee245eebc5b047a1535dbe30 100644
--- a/schedule-interval.go
+++ b/schedule-interval.go
@@ -37,7 +37,6 @@ func (s *IntervalScheduler) Schedule(job GenericJob, eventBus *EventBus) error {
 		for {
 			select {
 			case <-ticker.C:
-
 				if !job.IsPaused() {
 					eventBus.Publish(QueueJob, job)
 				}
diff --git a/worker.go b/worker.go
index d2fbb0e8dfcb565e3a5ec4cf72e9f770dca5c801..836a474f2a68bf24d4868317409cabbd0946570d 100644
--- a/worker.go
+++ b/worker.go
@@ -172,7 +172,7 @@ func (w *LocalWorker) run(jobChannel chan GenericJob, stopChan chan bool, cancel
 	workerThreadID := w.ID.String() + "-" + fmt.Sprintf("%p", &w)
 
 	Info("Worker thread with id started", "worker", w.ID, "thread_id", workerThreadID)
-
+
 	stopFlag := false
 
 	w.wg.Done()
@@ -181,6 +181,10 @@ func (w *LocalWorker) run(jobChannel chan GenericJob, stopChan chan bool, cancel
 		select {
 		case job := <-jobChannel:
 
+			if stopFlag {
+				break
+			}
+
 			w.statisticMu.Lock()
 			w.statistic.JobsAssigned++
 			w.statistic.ActiveThreads++
@@ -233,19 +237,22 @@ func (w *LocalWorker) run(jobChannel chan GenericJob, stopChan chan bool, cancel
 
 			cancel()
 
-			if w.manager != nil && w.manager.dbSaver != nil {
-				err = w.manager.dbSaver.SaveJob(job)
-				if err != nil {
-					Error("Error while saving job", "job_id", job.GetID(), "error", err)
-
-				}
+			if w.manager != nil {
+				go func() {
+					w.manager.mu.Lock()
+					if w.manager.jobSyncer != nil {
+						w.manager.jobSyncer.AddJob(job)
+					}
+					w.manager.mu.Unlock()
+				}()
 			}
-
+
 			w.statisticMu.Lock()
 			w.statistic.ActiveThreads--
 			w.statisticMu.Unlock()
 
 		case <-stopChan:
+			Info("Stopping worker thread", "worker", w.ID, "thread_id", workerThreadID)
 			stopFlag = true
 			break
 		}
diff --git a/worker_test.go b/worker_test.go
index b5ab55b0ab52aaa9409af4f4be426b015c22f706..84bdb4c3bd32a71554abc160bcdd2440a4f87f10 100644
--- a/worker_test.go
+++ b/worker_test.go
@@ -39,6 +39,10 @@ func (j DummyJob) ResetStats() {
 
 }
 
+func (j DummyJob) GetStats() JobStats {
+	return JobStats{}
+}
+
 func (j DummyJob) GetMaxRetries() uint {
 	return 0
 }