// Copyright 2023 schukai GmbH // SPDX-License-Identifier: AGPL-3.0 package jobqueue import ( "context" "fmt" "gorm.io/driver/mysql" "gorm.io/gorm" "sync" ) func NewDBRunnableFromMap(data map[string]interface{}) (*DBRunnable, error) { t, ok := data["type"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid Type: %v", ErrInvalidData, data["type"]) } dsn, ok := data["dsn"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid DSN: %v", ErrInvalidData, data["dsn"]) } query, ok := data["query"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid Query: %v", ErrInvalidData, data["query"]) } return &DBRunnable{ Type: t, DSN: dsn, Query: query, }, nil } // DBResult is a result of a db query type DBResult struct { RowsAffected int } func (d *DBResult) GetResult() string { return fmt.Sprintf("RowsAffected: %d", d.RowsAffected) } func (d *DBResult) GetError() (string, int) { return "", 0 } type DBRunnable struct { Type string DSN string Query string mu sync.Mutex } type dbKey struct{} // DBRunnableWithDB returns a new context with the provided gorm.DB injected. func DBRunnableWithDB(ctx context.Context, db *gorm.DB) context.Context { return context.WithValue(ctx, dbKey{}, db) } // GetDBFromContext tries to retrieve a *gorm.DB from the context. If it exists, returns the db and true, otherwise returns nil and false. func getDBFromContext(ctx context.Context) (*gorm.DB, bool) { db, ok := ctx.Value(dbKey{}).(*gorm.DB) return db, ok } func (d *DBRunnable) Run(ctx context.Context) (RunResult[DBResult], error) { var db *gorm.DB var ok bool var err error if db, ok = getDBFromContext(ctx); !ok { // No *gorm.DB in context, create a new connection switch d.Type { case "mysql": db, err = gorm.Open(mysql.Open(d.DSN), &gorm.Config{}) default: return RunResult[DBResult]{Status: ResultStatusFailed, Data: DBResult{}, }, ErrUnsupportedDatabaseType } if err != nil { return RunResult[DBResult]{Status: ResultStatusFailed, Data: DBResult{}, }, err } } result := db.Exec(d.Query) if result.Error != nil { return RunResult[DBResult]{Status: ResultStatusFailed, Data: DBResult{}, }, result.Error } return RunResult[DBResult]{ Status: ResultStatusSuccess, Data: DBResult{ RowsAffected: int(result.RowsAffected), }, }, nil } func (d *DBRunnable) GetType() string { return "db" } func (d *DBRunnable) GetPersistence() RunnableImport { d.mu.Lock() defer d.mu.Unlock() data := JSONMap{ "type": d.Type, "dsn": d.DSN, "query": d.Query, } return RunnableImport{ Type: d.GetType(), Data: data, } }