Skip to content
Snippets Groups Projects
runnable-gorm.go 2.59 KiB
// Copyright 2023 schukai GmbH
// SPDX-License-Identifier: AGPL-3.0

package jobqueue

import (
	"context"
	"fmt"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"sync"
)

// NewDBRunnableFromMap builds a DBRunnable from a generic key/value map.
//
// The map must contain the keys "type", "dsn" and "query", each holding a
// string. The keys are checked in that order; for the first one that is
// missing or not a string, an error wrapping ErrInvalidData is returned and
// the resulting runnable is nil.
func NewDBRunnableFromMap(data map[string]interface{}) (*DBRunnable, error) {
	rawType, rawDSN, rawQuery := data["type"], data["dsn"], data["query"]

	dbType, isString := rawType.(string)
	if !isString {
		return nil, fmt.Errorf("%w: Invalid Type: %v", ErrInvalidData, rawType)
	}

	dsn, isString := rawDSN.(string)
	if !isString {
		return nil, fmt.Errorf("%w: Invalid DSN: %v", ErrInvalidData, rawDSN)
	}

	query, isString := rawQuery.(string)
	if !isString {
		return nil, fmt.Errorf("%w: Invalid Query: %v", ErrInvalidData, rawQuery)
	}

	runnable := new(DBRunnable)
	runnable.Type = dbType
	runnable.DSN = dsn
	runnable.Query = query

	return runnable, nil
}

// DBResult holds the outcome of a database statement.
type DBResult struct {
	// RowsAffected is the affected-row count reported for the statement.
	RowsAffected int
}

// GetResult renders the result as the text "RowsAffected: <n>".
func (d *DBResult) GetResult() string {
	affected := d.RowsAffected
	return fmt.Sprintf("RowsAffected: %d", affected)
}

// GetError always returns an empty string and zero; a DBResult carries no
// error details of its own.
// NOTE(review): the pair is presumably (message, code) — confirm against the
// interface this method satisfies.
func (d *DBResult) GetError() (string, int) {
	var (
		message string
		code    int
	)
	return message, code
}

// DBRunnable is a runnable that executes a single SQL statement.
//
// Because it embeds a sync.Mutex it must not be copied; use it through a
// pointer.
type DBRunnable struct {
	// Type selects the database driver used when Run has to open its own
	// connection; Run only recognises "mysql".
	Type string
	// DSN is the data source name passed to the driver when Run opens its
	// own connection.
	DSN string
	// Query is the SQL statement executed by Run.
	Query string
	// mu is held by GetPersistence while it reads the fields above.
	mu sync.Mutex
}

// dbKey is the unexported context key type under which DBRunnableWithDB
// stores a *gorm.DB and getDBFromContext looks it up.
type dbKey struct{}

// DBRunnableWithDB derives a child context from ctx that carries db, stored
// under the package-private dbKey.
func DBRunnableWithDB(ctx context.Context, db *gorm.DB) context.Context {
	var key dbKey
	return context.WithValue(ctx, key, db)
}

// getDBFromContext tries to retrieve the *gorm.DB stored in ctx under dbKey.
//
// It returns the handle and true only when a non-nil *gorm.DB is present.
// A missing value, a value of another type, or a nil *gorm.DB all yield
// (nil, false), so callers can use the returned handle without a further nil
// check.
func getDBFromContext(ctx context.Context) (*gorm.DB, bool) {
	db, ok := ctx.Value(dbKey{}).(*gorm.DB)
	// A typed-nil *gorm.DB satisfies the type assertion; treat it as absent
	// so it is never dereferenced.
	if !ok || db == nil {
		return nil, false
	}
	return db, true
}

// Run executes d.Query and reports the number of affected rows.
//
// If ctx carries a *gorm.DB (see DBRunnableWithDB) that handle is used and
// left open. Otherwise a new connection is opened according to d.Type and
// d.DSN and closed again before Run returns. Only the "mysql" type is
// supported; any other type yields ErrUnsupportedDatabaseType.
//
// The statement is bound to ctx, so cancelling ctx or exceeding its deadline
// aborts the query.
//
// On any failure the result has Status ResultStatusFailed and an empty
// DBResult, and the underlying error is returned. On success the Status is
// ResultStatusSuccess.
func (d *DBRunnable) Run(ctx context.Context) (RunResult[DBResult], error) {
	failed := RunResult[DBResult]{
		Status: ResultStatusFailed,
		Data:   DBResult{},
	}

	db, ok := getDBFromContext(ctx)
	if !ok || db == nil {
		// No usable *gorm.DB in context, create a new connection.
		var err error
		switch d.Type {
		case "mysql":
			db, err = gorm.Open(mysql.Open(d.DSN), &gorm.Config{})
		default:
			return failed, ErrUnsupportedDatabaseType
		}
		if err != nil {
			return failed, err
		}

		// This connection is owned by Run, so release its pool on return;
		// otherwise every run would leak connections.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			defer func() {
				// The close error is deliberately ignored: the query outcome
				// has already been determined at this point.
				_ = sqlDB.Close()
			}()
		}
	}

	result := db.WithContext(ctx).Exec(d.Query)
	if result.Error != nil {
		return failed, result.Error
	}

	return RunResult[DBResult]{
		Status: ResultStatusSuccess,
		Data: DBResult{
			RowsAffected: int(result.RowsAffected),
		},
	}, nil
}

// GetType returns the fixed type identifier "db" for this runnable.
func (d *DBRunnable) GetType() string {
	const runnableType = "db"
	return runnableType
}

// GetPersistence returns a RunnableImport describing this runnable: its Type
// is the value of GetType, and its Data maps "type", "dsn" and "query" to the
// corresponding fields. The fields are read while holding d.mu.
func (d *DBRunnable) GetPersistence() RunnableImport {
	d.mu.Lock()
	defer d.mu.Unlock()

	return RunnableImport{
		Type: d.GetType(),
		Data: JSONMap{
			"type":  d.Type,
			"dsn":   d.DSN,
			"query": d.Query,
		},
	}
}