Skip to content
Snippets Groups Projects
Select Git revision
  • 3cb0b5977f840ce4cd0e1274b3545a08621d9800
  • master default protected
  • v1.23.2
  • v1.23.1
  • v1.23.0
  • v1.22.0
  • v1.21.1
  • v1.21.0
  • v1.20.3
  • v1.20.2
  • v1.20.1
  • v1.20.0
  • v1.19.4
  • v1.19.3
  • v1.19.2
  • v1.19.1
  • v1.19.0
  • v1.18.2
  • v1.18.1
  • v1.18.0
  • v1.17.0
  • v1.16.1
22 results

runnable-sftp.go

Blame
  • runnable-sftp.go 6.80 KiB
    package jobqueue
    
    import (
    	"context"
    	"fmt"
    	"io"
    	"net"
    	"os"
    	"path/filepath"
    	"strconv"

    	"github.com/pkg/sftp"
    	"golang.org/x/crypto/ssh"
    )
    
    func NewSFTPRunnableFromMap(data map[string]interface{}) (*SFTPRunnable, error) {
    	host, ok := data["host"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid Host: %v", ErrInvalidData, data["host"])
    	}
    
    	port, ok := data["port"].(int)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid Port: %v", ErrInvalidData, data["port"])
    	}
    
    	user, ok := data["user"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid User: %v", ErrInvalidData, data["user"])
    	}
    
    	insecure, ok := data["insecure"].(bool)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid Insecure: %v", ErrInvalidData, data["insecure"])
    	}
    
    	credential, ok := data["credential"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid Credential: %v", ErrInvalidData, data["credential"])
    	}
    
    	credentialType, ok := data["credentialtype"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid CredentialType: %v", ErrInvalidData, data["credentialtype"])
    	}
    
    	hostKey, ok := data["hostkey"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid HostKey: %v", ErrInvalidData, data["hostkey"])
    	}
    
    	srcDir, ok := data["srcdir"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid SrcDir: %v", ErrInvalidData, data["srcdir"])
    	}
    
    	dstDir, ok := data["dstdir"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid DstDir: %v", ErrInvalidData, data["dstdir"])
    	}
    
    	transferDirection, ok := data["transferdirection"].(string)
    	if !ok {
    		return nil, fmt.Errorf("%w: Invalid TransferDirection: %v", ErrInvalidData, data["TransferDirection"])
    	}
    
    	return &SFTPRunnable{
    		Host:              host,
    		Port:              port,
    		User:              user,
    		Insecure:          insecure,
    		Credential:        credential,
    		CredentialType:    credentialType,
    		HostKey:           hostKey,
    		SrcDir:            srcDir,
    		DstDir:            dstDir,
    		TransferDirection: Direction(transferDirection),
    	}, nil
    }
    
    // SFTPResult is the data an SFTPRunnable run reports on success.
    type SFTPResult struct {
    	// FilesCopied holds the destination path of every file copied, in the
    	// order the files were processed.
    	FilesCopied []string
    }
    
    // CredentialTypePassword marks SFTPRunnable.Credential as a plain password.
    const CredentialTypePassword = "password"

    // CredentialTypeKey marks SFTPRunnable.Credential as a private key, as
    // accepted by ssh.ParsePrivateKey.
    const CredentialTypeKey = "key"
    
    // Direction selects which way an SFTPRunnable copies files.
    type Direction string

    // LocalToRemote uploads files from the local SrcDir to the remote DstDir.
    const LocalToRemote Direction = "LocalToRemote"

    // RemoteToLocal downloads files from the remote SrcDir to the local DstDir.
    const RemoteToLocal Direction = "RemoteToLocal"
    
    // SFTPRunnable copies the files of one directory into another directory over
    // SFTP, either uploading (LocalToRemote) or downloading (RemoteToLocal).
    //
    // Fields:
    //   - Host, Port: address of the SSH server to connect to.
    //   - User: SSH user name.
    //   - Credential: the authentication secret. It is a password when
    //     CredentialType is CredentialTypePassword, or a private key (parsed
    //     with ssh.ParsePrivateKey) when it is CredentialTypeKey. Any other
    //     CredentialType makes Run fail.
    //   - HostKey: the expected server host key. When it is empty, Insecure
    //     decides what happens: true skips host key verification, and false
    //     means no server key is accepted, so Run fails.
    //   - SrcDir: the directory whose non-directory entries are copied. It is
    //     not descended into recursively.
    //   - DstDir: the destination directory, created if missing.
    //   - TransferDirection: LocalToRemote or RemoteToLocal. Any other value
    //     makes Run fail.
    //
    // NOTE(review): Credential is held, and emitted by GetPersistence, in clear
    // text.
    type SFTPRunnable struct {
    	Host              string
    	Port              int
    	User              string
    	Insecure          bool
    	Credential        string
    	CredentialType    string
    	HostKey           string
    	SrcDir            string
    	DstDir            string
    	TransferDirection Direction
    }
    
    // Run connects to Host:Port over SSH, opens an SFTP session and copies every
    // non-directory entry of SrcDir into DstDir in the configured
    // TransferDirection. On success, the destination paths of the copied files
    // are returned in the result data. On any failure the status is
    // ResultStatusFailed and the error is returned.
    //
    // ctx bounds the whole run:
    //   - the TCP connect honours it;
    //   - its deadline, if any, is applied to the connection;
    //   - cancelling it closes the connection, which aborts an in-flight
    //     handshake or transfer, and ctx.Err() is then returned.
    //
    // A nil ctx is treated as context.Background().
    //
    // Host key verification: a non-empty HostKey pins the server key. Both the
    // SSH wire format and the authorized_keys text format are accepted. When
    // HostKey is empty, Insecure must be true to skip verification; otherwise
    // Run fails without connecting.
    func (s *SFTPRunnable) Run(ctx context.Context) (RunResult[SFTPResult], error) {
    	if ctx == nil {
    		// The context used to be ignored, so tolerate callers that pass nil.
    		ctx = context.Background()
    	}

    	filesCopied, err := s.transfer(ctx)
    	if err != nil {
    		return RunResult[SFTPResult]{Status: ResultStatusFailed}, err
    	}

    	return RunResult[SFTPResult]{Status: ResultStatusSuccess, Data: SFTPResult{FilesCopied: filesCopied}}, nil
    }

    // transfer performs connection set-up and the copy for Run, returning the
    // destination paths of the copied files.
    func (s *SFTPRunnable) transfer(ctx context.Context) ([]string, error) {
    	auth, err := s.authMethod()
    	if err != nil {
    		return nil, err
    	}

    	hkCallback, err := s.hostKeyCallback()
    	if err != nil {
    		return nil, err
    	}

    	// Validate the direction before dialing, so a misconfigured job never opens
    	// a connection.
    	var copyFiles func(*sftp.Client) ([]string, error)
    	switch s.TransferDirection {
    	case LocalToRemote:
    		copyFiles = s.copyLocalToRemote
    	case RemoteToLocal:
    		copyFiles = s.copyRemoteToLocal
    	default:
    		return nil, ErrUnsupportedTransferDirection
    	}

    	config := &ssh.ClientConfig{
    		User:            s.User,
    		Auth:            []ssh.AuthMethod{auth},
    		HostKeyCallback: hkCallback,
    	}

    	// JoinHostPort brackets IPv6 literals; plain "%s:%d" formatting does not.
    	addr := net.JoinHostPort(s.Host, strconv.Itoa(s.Port))

    	var dialer net.Dialer
    	conn, err := dialer.DialContext(ctx, "tcp", addr)
    	if err != nil {
    		return nil, err
    	}
    	// This covers the failure paths before the SSH client owns conn. A second
    	// Close after client.Close is harmless.
    	defer func() { _ = conn.Close() }()

    	// Cancelling ctx closes the raw connection, which unblocks the SSH
    	// handshake and any SFTP I/O running on top of it.
    	stopWatch := sftpCloseOnDone(ctx, conn)
    	defer stopWatch()

    	if deadline, ok := ctx.Deadline(); ok {
    		if err := conn.SetDeadline(deadline); err != nil {
    			return nil, err
    		}
    	}

    	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
    	if err != nil {
    		return nil, sftpContextErr(ctx, err)
    	}
    	client := ssh.NewClient(sshConn, chans, reqs)
    	defer func() { _ = client.Close() }()

    	sftpClient, err := sftp.NewClient(client)
    	if err != nil {
    		return nil, sftpContextErr(ctx, err)
    	}
    	defer func() { _ = sftpClient.Close() }()

    	filesCopied, err := copyFiles(sftpClient)
    	if err != nil {
    		return nil, sftpContextErr(ctx, err)
    	}
    	return filesCopied, nil
    }

    // authMethod builds the SSH authentication method described by
    // CredentialType and Credential. It returns ErrUnsupportedCredentialType for
    // an unknown type.
    func (s *SFTPRunnable) authMethod() (ssh.AuthMethod, error) {
    	switch s.CredentialType {
    	case CredentialTypePassword:
    		return ssh.Password(s.Credential), nil
    	case CredentialTypeKey:
    		signer, err := ssh.ParsePrivateKey([]byte(s.Credential))
    		if err != nil {
    			return nil, err
    		}
    		return ssh.PublicKeys(signer), nil
    	default:
    		return nil, ErrUnsupportedCredentialType
    	}
    }

    // hostKeyCallback returns the server key check for the connection:
    //   - a pinned key when HostKey is set;
    //   - no check when HostKey is empty and Insecure is true;
    //   - an error otherwise.
    func (s *SFTPRunnable) hostKeyCallback() (ssh.HostKeyCallback, error) {
    	if s.HostKey == "" {
    		if !s.Insecure {
    			// ssh.FixedHostKey(nil) would reject every server during the
    			// handshake, so fail before dialing with a clear message instead.
    			return nil, fmt.Errorf("sftp: no host key configured for %q and insecure mode is disabled", s.Host)
    		}
    		// #nosec G106 -- verification is explicitly disabled via Insecure.
    		return ssh.InsecureIgnoreHostKey(), nil
    	}

    	raw := []byte(s.HostKey)
    	hostKey, err := ssh.ParsePublicKey(raw)
    	if err != nil {
    		// Also accept the textual authorized_keys form ("ssh-ed25519 AAAA... comment").
    		textKey, _, _, _, textErr := ssh.ParseAuthorizedKey(raw)
    		if textErr != nil {
    			return nil, err
    		}
    		hostKey = textKey
    	}
    	return ssh.FixedHostKey(hostKey), nil
    }

    // sftpCloseOnDone closes c as soon as ctx is done. The returned stop function
    // releases the watcher goroutine and must be called exactly once.
    func sftpCloseOnDone(ctx context.Context, c io.Closer) (stop func()) {
    	done := make(chan struct{})
    	go func() {
    		select {
    		case <-ctx.Done():
    			_ = c.Close()
    		case <-done:
    		}
    	}()
    	return func() { close(done) }
    }

    // sftpContextErr returns ctx.Err() instead of err once ctx is cancelled or
    // expired; err is then only a symptom of the connection being torn down.
    func sftpContextErr(ctx context.Context, err error) error {
    	if ctxErr := ctx.Err(); ctxErr != nil {
    		return ctxErr
    	}
    	return err
    }
    
    // copyFile streams the full contents of src into dst and returns the first
    // read or write error encountered, or nil once src is exhausted.
    func copyFile(src io.Reader, dst io.Writer) error {
    	if _, copyErr := io.Copy(dst, src); copyErr != nil {
    		return copyErr
    	}
    	return nil
    }
    
    // copyLocalToRemote uploads every non-directory entry of the local SrcDir to
    // the remote DstDir, creating DstDir (and its parents) on the server first.
    // Remote files with the same name are overwritten (sftp Client.Create
    // truncates). Subdirectories are not descended into.
    //
    // It returns the remote paths written, in os.ReadDir order. On the first
    // error it stops and returns nil with that error; files already uploaded
    // stay on the server.
    func (s *SFTPRunnable) copyLocalToRemote(sftpClient *sftp.Client) ([]string, error) {
    	if err := sftpClient.MkdirAll(s.DstDir); err != nil {
    		return nil, err
    	}

    	entries, err := os.ReadDir(s.SrcDir)
    	if err != nil {
    		return nil, err
    	}

    	var filesCopied []string
    	for _, entry := range entries {
    		if entry.IsDir() {
    			continue
    		}

    		srcPath := filepath.Join(s.SrcDir, entry.Name())
    		// Remote paths always use '/', regardless of the local OS.
    		dstPath := sftpClient.Join(s.DstDir, entry.Name())
    		if err := uploadSFTPFile(sftpClient, srcPath, dstPath); err != nil {
    			return nil, err
    		}

    		filesCopied = append(filesCopied, dstPath)
    	}

    	return filesCopied, nil
    }

    // uploadSFTPFile copies the local file srcPath to the remote path dstPath.
    // The error from closing the remote file is reported, because a failed
    // final write may only surface there.
    func uploadSFTPFile(sftpClient *sftp.Client, srcPath, dstPath string) (err error) {
    	src, err := os.Open(srcPath)
    	if err != nil {
    		return err
    	}
    	// A close error on a file that was only read carries no information.
    	defer func() { _ = src.Close() }()

    	dst, err := sftpClient.Create(dstPath)
    	if err != nil {
    		return err
    	}
    	defer func() {
    		if closeErr := dst.Close(); err == nil {
    			err = closeErr
    		}
    	}()

    	return copyFile(src, dst)
    }
    
    // copyRemoteToLocal downloads every non-directory entry of the remote SrcDir
    // into the local DstDir, which is created with mode 0700 if missing. Local
    // files with the same name are truncated and overwritten. Subdirectories are
    // not descended into.
    //
    // Entry names come from the server, so any name that is not a single plain
    // path element (for example "../x" or "a/b") is rejected with an error
    // rather than being allowed to escape DstDir.
    //
    // It returns the local paths written, in the order the server listed the
    // entries. On the first error it stops and returns nil with that error;
    // files already downloaded stay on disk.
    func (s *SFTPRunnable) copyRemoteToLocal(sftpClient *sftp.Client) ([]string, error) {
    	if err := os.MkdirAll(s.DstDir, 0700); err != nil {
    		return nil, err
    	}

    	entries, err := sftpClient.ReadDir(s.SrcDir)
    	if err != nil {
    		return nil, err
    	}

    	var filesCopied []string
    	for _, entry := range entries {
    		if entry.IsDir() {
    			continue
    		}

    		name := entry.Name()
    		if !isSafeSFTPEntryName(name) {
    			return nil, fmt.Errorf("sftp: refusing to download remote entry with unsafe name %q", name)
    		}

    		// Remote paths always use '/', regardless of the local OS.
    		srcPath := sftpClient.Join(s.SrcDir, name)
    		dstPath := filepath.Join(s.DstDir, name)
    		if err := downloadSFTPFile(sftpClient, srcPath, dstPath); err != nil {
    			return nil, err
    		}

    		filesCopied = append(filesCopied, dstPath)
    	}

    	return filesCopied, nil
    }

    // downloadSFTPFile copies the remote file srcPath to the local path dstPath.
    // The error from closing the local file is reported, because a failed final
    // write may only surface there.
    func downloadSFTPFile(sftpClient *sftp.Client, srcPath, dstPath string) (err error) {
    	src, err := sftpClient.Open(srcPath)
    	if err != nil {
    		return err
    	}
    	// A close error on a file that was only read carries no information.
    	defer func() { _ = src.Close() }()

    	dst, err := os.Create(dstPath)
    	if err != nil {
    		return err
    	}
    	defer func() {
    		if closeErr := dst.Close(); err == nil {
    			err = closeErr
    		}
    	}()

    	return copyFile(src, dst)
    }

    // isSafeSFTPEntryName reports whether name is a single plain path element.
    // Such a name cannot point outside the directory it is joined to: it has no
    // separators, no volume name, and is not "." or "..".
    func isSafeSFTPEntryName(name string) bool {
    	if name == "" || name == "." || name == ".." || filepath.VolumeName(name) != "" {
    		return false
    	}
    	for i := 0; i < len(name); i++ {
    		// Reject '/' everywhere and the OS-specific separator (e.g. '\' on Windows).
    		if name[i] == '/' || os.IsPathSeparator(name[i]) {
    			return false
    		}
    	}
    	return true
    }
    
    // GetType returns the identifier of this runnable kind, "sftp".
    func (s *SFTPRunnable) GetType() string {
    	const runnableType = "sftp"
    	return runnableType
    }
    
    // GetPersistence returns the import/persistence form of the runnable: its
    // type ("sftp") plus a JSONMap holding every field.
    //
    // The keys are the ones NewSFTPRunnableFromMap reads, so the map can be fed
    // back into it. TransferDirection is stored as a plain string, the same
    // form a JSON round trip yields.
    //
    // NOTE(review): Credential is emitted in clear text.
    func (s *SFTPRunnable) GetPersistence() RunnableImport {
    	data := JSONMap{
    		"host":              s.Host,
    		"port":              s.Port,
    		"user":              s.User,
    		"insecure":          s.Insecure,
    		"credential":        s.Credential,
    		"credentialtype":    s.CredentialType,
    		"hostkey":           s.HostKey,
    		"srcdir":            s.SrcDir,
    		"dstdir":            s.DstDir,
    		"transferdirection": string(s.TransferDirection),
    	}

    	return RunnableImport{
    		Type: s.GetType(),
    		Data: data,
    	}
    }