// Copyright 2023 schukai GmbH // SPDX-License-Identifier: AGPL-3.0 package jobqueue import ( "context" "fmt" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "io" "os" "sync" ) func NewSFTPRunnableFromMap(data map[string]interface{}) (*SFTPRunnable, error) { host, ok := data["host"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid Host: %v", ErrInvalidData, data["host"]) } port, ok := data["port"].(int) if !ok { return nil, fmt.Errorf("%w: Invalid Port: %v", ErrInvalidData, data["port"]) } user, ok := data["user"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid User: %v", ErrInvalidData, data["user"]) } insecure, ok := data["insecure"].(bool) if !ok { return nil, fmt.Errorf("%w: Invalid Insecure: %v", ErrInvalidData, data["insecure"]) } credential, ok := data["credential"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid Credential: %v", ErrInvalidData, data["credential"]) } credentialType, ok := data["credentialtype"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid CredentialType: %v", ErrInvalidData, data["credentialtype"]) } hostKey, ok := data["hostkey"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid HostKey: %v", ErrInvalidData, data["hostkey"]) } srcDir, ok := data["srcdir"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid SrcDir: %v", ErrInvalidData, data["srcdir"]) } dstDir, ok := data["dstdir"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid DstDir: %v", ErrInvalidData, data["dstdir"]) } transferDirection, ok := data["transferdirection"].(string) if !ok { return nil, fmt.Errorf("%w: Invalid TransferDirection: %v", ErrInvalidData, data["TransferDirection"]) } return &SFTPRunnable{ Host: host, Port: port, User: user, Insecure: insecure, Credential: credential, CredentialType: credentialType, HostKey: hostKey, SrcDir: srcDir, DstDir: dstDir, TransferDirection: Direction(transferDirection), }, nil } // SFTPResult is a result of a sftp type SFTPResult struct { FilesCopied []string } func (s *SFTPResult) GetResult() string { return fmt.Sprintf("FilesCopied: %v", s.FilesCopied) } func (s *SFTPResult) GetError() (string, int) { return "", 0 } const ( CredentialTypePassword = "password" CredentialTypeKey = "key" ) type Direction string const ( LocalToRemote Direction = "LocalToRemote" RemoteToLocal Direction = "RemoteToLocal" ) type SFTPRunnable struct { Host string Port int User string Insecure bool Credential string CredentialType string HostKey string SrcDir string DstDir string TransferDirection Direction mu sync.Mutex } func (s *SFTPRunnable) Run(_ context.Context) (RunResult[SFTPResult], error) { var authMethod ssh.AuthMethod // Auth switch s.CredentialType { case CredentialTypePassword: authMethod = ssh.Password(s.Credential) case CredentialTypeKey: key, err := ssh.ParsePrivateKey([]byte(s.Credential)) if err != nil { return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}, }, err } authMethod = ssh.PublicKeys(key) default: return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}}, ErrUnsupportedCredentialType } var hkCallback ssh.HostKeyCallback if s.HostKey != "" { hostkeyBytes := []byte(s.HostKey) hostKey, err := ssh.ParsePublicKey(hostkeyBytes) if err != nil { return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}, }, err } hkCallback = ssh.FixedHostKey(hostKey) } else { if s.Insecure { // #nosec hkCallback = ssh.InsecureIgnoreHostKey() } else { hkCallback = ssh.FixedHostKey(nil) } } config := &ssh.ClientConfig{ User: s.User, Auth: []ssh.AuthMethod{ authMethod, }, HostKeyCallback: hkCallback, } client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", s.Host, s.Port), config) if err != nil { return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}, }, err } defer client.Close() sftpClient, err := sftp.NewClient(client) if err != nil { return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}, }, err } defer sftpClient.Close() var filesCopied []string switch s.TransferDirection { case LocalToRemote: filesCopied, err = s.copyLocalToRemote(sftpClient) case RemoteToLocal: filesCopied, err = s.copyRemoteToLocal(sftpClient) default: return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}, }, ErrUnsupportedTransferDirection } if err != nil { return RunResult[SFTPResult]{Status: ResultStatusFailed}, err } if err != nil { return RunResult[SFTPResult]{Status: ResultStatusFailed}, err } return RunResult[SFTPResult]{Status: ResultStatusSuccess, Data: SFTPResult{FilesCopied: filesCopied}}, nil } func copyFile(src io.Reader, dst io.Writer) error { _, err := io.Copy(dst, src) return err } func (s *SFTPRunnable) copyLocalToRemote(sftpClient *sftp.Client) ([]string, error) { var filesCopied []string // create destination directory err := sftpClient.MkdirAll(s.DstDir) if err != nil { return nil, err } // copy files files, err := os.ReadDir(s.SrcDir) if err != nil { return nil, err } for _, file := range files { if file.IsDir() { continue } srcFile, err := os.Open(fmt.Sprintf("%s/%s", s.SrcDir, file.Name())) if err != nil { return nil, err } dstFile, err := sftpClient.Create(fmt.Sprintf("%s/%s", s.DstDir, file.Name())) if err != nil { _ = srcFile.Close() return nil, err } err = copyFile(srcFile, dstFile) _ = srcFile.Close() _ = dstFile.Close() if err != nil { return nil, err } filesCopied = append(filesCopied, fmt.Sprintf("%s/%s", s.DstDir, file.Name())) } return filesCopied, nil } func (s *SFTPRunnable) copyRemoteToLocal(sftpClient *sftp.Client) ([]string, error) { var filesCopied []string // create destination directory err := os.MkdirAll(s.DstDir, 0700) if err != nil { return nil, err } // copy files files, err := sftpClient.ReadDir(s.SrcDir) if err != nil { return nil, err } for _, file := range files { if file.IsDir() { continue } srcFile, err := sftpClient.Open(fmt.Sprintf("%s/%s", s.SrcDir, file.Name())) if err != nil { return nil, err } dstFile, err := os.Create(fmt.Sprintf("%s/%s", s.DstDir, file.Name())) if err != nil { _ = srcFile.Close() return nil, err } err = copyFile(srcFile, dstFile) _ = srcFile.Close() _ = dstFile.Close() if err != nil { return nil, err } filesCopied = append(filesCopied, fmt.Sprintf("%s/%s", s.DstDir, file.Name())) } return filesCopied, nil } func (s *SFTPRunnable) GetType() string { return "sftp" } func (s *SFTPRunnable) GetPersistence() RunnableImport { s.mu.Lock() defer s.mu.Unlock() data := JSONMap{ "host": s.Host, "port": s.Port, "user": s.User, "insecure": s.Insecure, "credential": s.Credential, "credential_type": s.CredentialType, "hostkey": s.HostKey, "src_dir": s.SrcDir, "dst_dir": s.DstDir, "transfer_direction": s.TransferDirection, } return RunnableImport{ Type: s.GetType(), Data: data, } }