Something went wrong on our end
Select Git revision
runnable-sftp.go
-
Volker Schukai authored
runnable-sftp.go 6.80 KiB
package jobqueue
import (
	"context"
	"fmt"
	"io"
	"net"
	"os"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)
// NewSFTPRunnableFromMap builds an SFTPRunnable from a generic key/value map,
// for example one decoded from JSON or produced by GetPersistence.
//
// All of these keys are required: host, user, credential and hostkey
// (strings; hostkey may be the empty string), port (integer), insecure
// (bool), and credentialtype, srcdir, dstdir and transferdirection (strings).
// For the last four, the snake_case spellings credential_type, src_dir,
// dst_dir and transfer_direction are accepted as fallbacks.
//
// port may be an int, int32 or int64. It may also be a float64 holding a
// whole number in 0..65535, which is how encoding/json decodes numbers into
// interface{} values.
//
// A missing key or a value of the wrong type yields an error wrapping
// ErrInvalidData. Values are not validated further here; an unknown
// credential type or transfer direction is only rejected by Run.
func NewSFTPRunnableFromMap(data map[string]interface{}) (*SFTPRunnable, error) {
	// invalid builds the error for a missing or mistyped field, keeping the
	// original "<ErrInvalidData>: Invalid <Field>: <value>" message shape.
	invalid := func(field string, value interface{}) error {
		return fmt.Errorf("%w: Invalid %s: %v", ErrInvalidData, field, value)
	}

	// lookup returns the value stored under the first of keys present in
	// data, or nil when none of them is present.
	lookup := func(keys ...string) interface{} {
		for _, key := range keys {
			if value, ok := data[key]; ok {
				return value
			}
		}
		return nil
	}

	// str fetches a required string from the first matching key. The error
	// reports the value that was actually looked up, which the original
	// TransferDirection check did not.
	str := func(field string, keys ...string) (string, error) {
		value := lookup(keys...)
		text, ok := value.(string)
		if !ok {
			return "", invalid(field, value)
		}
		return text, nil
	}

	host, err := str("Host", "host")
	if err != nil {
		return nil, err
	}

	var port int
	switch p := data["port"].(type) {
	case int:
		port = p
	case int32:
		port = int(p)
	case int64:
		port = int(p)
	case float64:
		// encoding/json decodes all numbers as float64. The range is checked
		// first because converting an out-of-range float to int is
		// implementation-defined in Go.
		if p < 0 || p > 65535 || p != float64(int(p)) {
			return nil, invalid("Port", data["port"])
		}
		port = int(p)
	default:
		return nil, invalid("Port", data["port"])
	}

	user, err := str("User", "user")
	if err != nil {
		return nil, err
	}

	insecure, ok := data["insecure"].(bool)
	if !ok {
		return nil, invalid("Insecure", data["insecure"])
	}

	credential, err := str("Credential", "credential")
	if err != nil {
		return nil, err
	}

	credentialType, err := str("CredentialType", "credentialtype", "credential_type")
	if err != nil {
		return nil, err
	}

	hostKey, err := str("HostKey", "hostkey")
	if err != nil {
		return nil, err
	}

	srcDir, err := str("SrcDir", "srcdir", "src_dir")
	if err != nil {
		return nil, err
	}

	dstDir, err := str("DstDir", "dstdir", "dst_dir")
	if err != nil {
		return nil, err
	}

	transferDirection, err := str("TransferDirection", "transferdirection", "transfer_direction")
	if err != nil {
		return nil, err
	}

	return &SFTPRunnable{
		Host:              host,
		Port:              port,
		User:              user,
		Insecure:          insecure,
		Credential:        credential,
		CredentialType:    credentialType,
		HostKey:           hostKey,
		SrcDir:            srcDir,
		DstDir:            dstDir,
		TransferDirection: Direction(transferDirection),
	}, nil
}
// SFTPResult is the payload that SFTPRunnable.Run returns in the result's
// Data field when a transfer succeeds.
type SFTPResult struct {
	// FilesCopied lists the destination path of every file that was copied,
	// in the order the files were processed.
	FilesCopied []string
}
// Credential types understood by SFTPRunnable.Run. With CredentialTypeKey the
// Credential field is parsed as an SSH private key. With
// CredentialTypePassword it is used verbatim as the SSH password.
const (
	CredentialTypeKey      = "key"
	CredentialTypePassword = "password"
)
// Direction tells SFTPRunnable which way to copy files.
type Direction string

// The transfer directions SFTPRunnable.Run supports.
const (
	// RemoteToLocal copies files from the SFTP server into a local directory.
	RemoteToLocal Direction = "RemoteToLocal"
	// LocalToRemote copies files from a local directory onto the SFTP server.
	LocalToRemote Direction = "LocalToRemote"
)
// SFTPRunnable copies the files directly inside one directory to another
// directory, where one side is local and the other is on an SFTP server.
// It is configured through its exported fields; NewSFTPRunnableFromMap builds
// one from a key/value map.
type SFTPRunnable struct {
	// Host and Port locate the SSH server.
	Host string
	Port int
	// User is the SSH login name.
	User string
	// Insecure disables host key verification. It only takes effect when
	// HostKey is empty.
	Insecure bool
	// Credential is a password or a private key, as selected by
	// CredentialType (CredentialTypePassword or CredentialTypeKey).
	Credential     string
	CredentialType string
	// HostKey, when non-empty, pins the server's public key; any other key
	// is rejected during the handshake.
	HostKey string
	// SrcDir is read and DstDir is written. Which of them is local depends on
	// TransferDirection. Subdirectories of SrcDir are skipped, not recursed.
	SrcDir string
	DstDir string
	// TransferDirection is LocalToRemote or RemoteToLocal.
	TransferDirection Direction
}
// Run connects to Host:Port over SSH, opens an SFTP session on that
// connection and copies the files directly inside SrcDir into DstDir in the
// configured TransferDirection.
//
// Credential is used as a password or parsed as a private key, depending on
// CredentialType. A non-empty HostKey pins the server key. Without one, the
// connection is only accepted when Insecure is true.
//
// Cancelling ctx aborts the TCP dial and closes the connection, which
// interrupts the SSH handshake or a transfer in progress. In that case the
// context's error is returned instead of the resulting I/O error.
//
// On success the result's Data lists the destination path of every copied
// file. On failure the result has ResultStatusFailed and the error says why.
func (s *SFTPRunnable) Run(ctx context.Context) (RunResult[SFTPResult], error) {
	failed := RunResult[SFTPResult]{Status: ResultStatusFailed}

	authMethod, err := s.authMethod()
	if err != nil {
		return failed, err
	}
	hkCallback, err := s.hostKeyCallback()
	if err != nil {
		return failed, err
	}

	// Resolve the transfer before dialing so an invalid direction fails
	// without a network round trip.
	var transfer func(*sftp.Client) ([]string, error)
	switch s.TransferDirection {
	case LocalToRemote:
		transfer = s.copyLocalToRemote
	case RemoteToLocal:
		transfer = s.copyRemoteToLocal
	default:
		return failed, ErrUnsupportedTransferDirection
	}

	// JoinHostPort brackets IPv6 literals. The previous "%s:%d" produced
	// unusable addresses such as "::1:22".
	addr := net.JoinHostPort(s.Host, fmt.Sprint(s.Port))
	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return failed, err
	}

	// Neither ssh nor sftp accepts a context, so cancellation is enforced by
	// closing the raw connection, which unblocks any pending read or write.
	// The watcher goroutine exits as soon as Run returns.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.Close()
		case <-done:
		}
	}()

	// ctxErr prefers the context's error once ctx is done, because the I/O
	// error caused by closing conn would hide the real reason.
	ctxErr := func(err error) error {
		if cerr := ctx.Err(); cerr != nil {
			return cerr
		}
		return err
	}

	// This is what ssh.Dial does internally, but on a context-aware net.Conn.
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, &ssh.ClientConfig{
		User:            s.User,
		Auth:            []ssh.AuthMethod{authMethod},
		HostKeyCallback: hkCallback,
	})
	if err != nil {
		// Closing an already-closed conn is harmless; the error is irrelevant.
		_ = conn.Close()
		return failed, ctxErr(err)
	}
	client := ssh.NewClient(sshConn, chans, reqs)
	defer client.Close()

	sftpClient, err := sftp.NewClient(client)
	if err != nil {
		return failed, ctxErr(err)
	}
	defer sftpClient.Close()

	filesCopied, err := transfer(sftpClient)
	if err != nil {
		return failed, ctxErr(err)
	}
	return RunResult[SFTPResult]{Status: ResultStatusSuccess, Data: SFTPResult{FilesCopied: filesCopied}}, nil
}

// authMethod builds the SSH authentication method selected by CredentialType.
// It returns ErrUnsupportedCredentialType for any other type.
func (s *SFTPRunnable) authMethod() (ssh.AuthMethod, error) {
	switch s.CredentialType {
	case CredentialTypePassword:
		return ssh.Password(s.Credential), nil
	case CredentialTypeKey:
		signer, err := ssh.ParsePrivateKey([]byte(s.Credential))
		if err != nil {
			return nil, err
		}
		return ssh.PublicKeys(signer), nil
	default:
		return nil, ErrUnsupportedCredentialType
	}
}

// hostKeyCallback returns the callback used to verify the server's host key.
func (s *SFTPRunnable) hostKeyCallback() (ssh.HostKeyCallback, error) {
	if s.HostKey == "" {
		if s.Insecure {
			// #nosec
			return ssh.InsecureIgnoreHostKey(), nil
		}
		// There is no key to pin and verification was not waived.
		// FixedHostKey(nil) rejects every server key, so the handshake fails.
		return ssh.FixedHostKey(nil), nil
	}
	raw := []byte(s.HostKey)
	// Accept the common authorized_keys text form ("ssh-ed25519 AAAA..."),
	// and fall back to the SSH wire format that was previously the only
	// accepted input.
	if key, _, _, _, err := ssh.ParseAuthorizedKey(raw); err == nil {
		return ssh.FixedHostKey(key), nil
	}
	key, err := ssh.ParsePublicKey(raw)
	if err != nil {
		return nil, err
	}
	return ssh.FixedHostKey(key), nil
}
// copyFile streams all of src into dst. It returns the first read or write
// error, or nil once src is exhausted. The number of bytes copied is
// deliberately discarded because no caller needs it.
func copyFile(src io.Reader, dst io.Writer) error {
	if _, copyErr := io.Copy(dst, src); copyErr != nil {
		return copyErr
	}
	return nil
}
// copyLocalToRemote uploads every non-directory entry directly inside the
// local SrcDir to the remote DstDir. It first creates DstDir, with any
// missing parents, on the server. Subdirectories are skipped.
//
// It returns the remote paths written, each formatted as DstDir + "/" + name.
// It stops at the first failure and returns only the error. Files uploaded
// before that point are left on the server.
func (s *SFTPRunnable) copyLocalToRemote(sftpClient *sftp.Client) ([]string, error) {
	if err := sftpClient.MkdirAll(s.DstDir); err != nil {
		return nil, err
	}

	entries, err := os.ReadDir(s.SrcDir)
	if err != nil {
		return nil, err
	}

	// upload copies one local file to a remote path. The destination's
	// Close error is returned, previously it was discarded, because the
	// server may only report a failed write when the handle is closed.
	upload := func(srcPath, dstPath string) (err error) {
		src, err := os.Open(srcPath)
		if err != nil {
			return err
		}
		// Read side only: a close error here cannot affect the uploaded data.
		defer src.Close()

		dst, err := sftpClient.Create(dstPath)
		if err != nil {
			return err
		}
		defer func() {
			if cerr := dst.Close(); err == nil {
				err = cerr
			}
		}()

		return copyFile(src, dst)
	}

	var filesCopied []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		dstPath := fmt.Sprintf("%s/%s", s.DstDir, entry.Name())
		if err := upload(fmt.Sprintf("%s/%s", s.SrcDir, entry.Name()), dstPath); err != nil {
			return nil, err
		}
		filesCopied = append(filesCopied, dstPath)
	}
	return filesCopied, nil
}
// copyRemoteToLocal downloads every non-directory entry directly inside the
// remote SrcDir into the local DstDir. It first creates DstDir, with mode
// 0700 and any missing parents. Subdirectories are skipped.
//
// It returns the local paths written, each formatted as DstDir + "/" + name.
// It stops at the first failure and returns only the error. Files written
// before that point are left in place.
func (s *SFTPRunnable) copyRemoteToLocal(sftpClient *sftp.Client) ([]string, error) {
	if err := os.MkdirAll(s.DstDir, 0700); err != nil {
		return nil, err
	}

	entries, err := sftpClient.ReadDir(s.SrcDir)
	if err != nil {
		return nil, err
	}

	// download copies one remote file to a local path. The local file's
	// Close error is returned, previously it was discarded, because Close
	// can surface deferred write failures on some filesystems (e.g. NFS).
	download := func(srcPath, dstPath string) (err error) {
		src, err := sftpClient.Open(srcPath)
		if err != nil {
			return err
		}
		// Read side only: a close error here cannot affect the written data.
		defer src.Close()

		dst, err := os.Create(dstPath)
		if err != nil {
			return err
		}
		defer func() {
			if cerr := dst.Close(); err == nil {
				err = cerr
			}
		}()

		return copyFile(src, dst)
	}

	var filesCopied []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		dstPath := fmt.Sprintf("%s/%s", s.DstDir, entry.Name())
		if err := download(fmt.Sprintf("%s/%s", s.SrcDir, entry.Name()), dstPath); err != nil {
			return nil, err
		}
		filesCopied = append(filesCopied, dstPath)
	}
	return filesCopied, nil
}
// GetType returns the fixed identifier "sftp" for this kind of runnable.
// GetPersistence uses it as the Type of the exported RunnableImport.
func (s *SFTPRunnable) GetType() string {
	const runnableType = "sftp"
	return runnableType
}
// GetPersistence exports the runnable's configuration as a RunnableImport
// tagged with GetType().
//
// The keys are exactly the ones NewSFTPRunnableFromMap reads, so the map can
// be passed back to it. Previously, credential_type, src_dir, dst_dir and
// transfer_direction did not match. TransferDirection is stored as a plain
// string because NewSFTPRunnableFromMap type-asserts that value to string.
//
// Note that the credential (password or private key) is included in clear
// text.
func (s *SFTPRunnable) GetPersistence() RunnableImport {
	data := JSONMap{
		"host":              s.Host,
		"port":              s.Port,
		"user":              s.User,
		"insecure":          s.Insecure,
		"credential":        s.Credential,
		"credentialtype":    s.CredentialType,
		"hostkey":           s.HostKey,
		"srcdir":            s.SrcDir,
		"dstdir":            s.DstDir,
		"transferdirection": string(s.TransferDirection),
	}

	return RunnableImport{
		Type: s.GetType(),
		Data: data,
	}
}