// Copyright 2023 schukai GmbH
// SPDX-License-Identifier: AGPL-3.0

package jobqueue

import (
	"context"
	"fmt"
	"io"
	"net"
	"os"
	"sync"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// NewSFTPRunnableFromMap builds an SFTPRunnable from a loosely typed
// configuration map.
//
// Every field is required. Keys are looked up in their compact form
// ("host", "port", "user", "insecure", "credential", "credentialtype",
// "hostkey", "srcdir", "dstdir", "transferdirection"). For the multi-word
// keys, the snake_case spelling written by GetPersistence
// ("credential_type", "src_dir", "dst_dir", "transfer_direction") is
// accepted as a fallback, so a persisted runnable can be read back.
//
// "port" may be an int, or a float64 holding a whole number in 0..65535
// (encoding/json produces float64 for numbers decoded into interface{}).
// "transferdirection" may be a string or a Direction.
//
// A missing or mistyped field yields an error wrapping ErrInvalidData.
func NewSFTPRunnableFromMap(data map[string]interface{}) (*SFTPRunnable, error) {
	// lookup returns the value stored under the first of keys present in
	// data, or nil when none of them is set.
	lookup := func(keys ...string) interface{} {
		for _, key := range keys {
			if value, ok := data[key]; ok {
				return value
			}
		}
		return nil
	}

	// requireString returns the string stored under keys, or an error naming
	// field when the value is missing or not a string.
	requireString := func(field string, keys ...string) (string, error) {
		raw := lookup(keys...)
		value, ok := raw.(string)
		if !ok {
			return "", fmt.Errorf("%w: Invalid %s: %v", ErrInvalidData, field, raw)
		}
		return value, nil
	}

	host, err := requireString("Host", "host")
	if err != nil {
		return nil, err
	}

	var port int
	switch v := data["port"].(type) {
	case int:
		port = v
	case float64:
		// Numbers decoded from JSON into interface{} arrive as float64; only
		// accept whole numbers that are valid TCP ports.
		if v < 0 || v > 65535 || v != float64(int(v)) {
			return nil, fmt.Errorf("%w: Invalid Port: %v", ErrInvalidData, v)
		}
		port = int(v)
	default:
		return nil, fmt.Errorf("%w: Invalid Port: %v", ErrInvalidData, data["port"])
	}

	user, err := requireString("User", "user")
	if err != nil {
		return nil, err
	}

	insecure, ok := data["insecure"].(bool)
	if !ok {
		return nil, fmt.Errorf("%w: Invalid Insecure: %v", ErrInvalidData, data["insecure"])
	}

	credential, err := requireString("Credential", "credential")
	if err != nil {
		return nil, err
	}

	credentialType, err := requireString("CredentialType", "credentialtype", "credential_type")
	if err != nil {
		return nil, err
	}

	hostKey, err := requireString("HostKey", "hostkey")
	if err != nil {
		return nil, err
	}

	srcDir, err := requireString("SrcDir", "srcdir", "src_dir")
	if err != nil {
		return nil, err
	}

	dstDir, err := requireString("DstDir", "dstdir", "dst_dir")
	if err != nil {
		return nil, err
	}

	// GetPersistence stores a Direction value, so accept it alongside string.
	var transferDirection Direction
	switch v := lookup("transferdirection", "transfer_direction").(type) {
	case string:
		transferDirection = Direction(v)
	case Direction:
		transferDirection = v
	default:
		return nil, fmt.Errorf("%w: Invalid TransferDirection: %v", ErrInvalidData, v)
	}

	return &SFTPRunnable{
		Host:              host,
		Port:              port,
		User:              user,
		Insecure:          insecure,
		Credential:        credential,
		CredentialType:    credentialType,
		HostKey:           hostKey,
		SrcDir:            srcDir,
		DstDir:            dstDir,
		TransferDirection: transferDirection,
	}, nil
}

// SFTPResult is the payload produced by an SFTPRunnable run.
type SFTPResult struct {
	// FilesCopied lists the destination path of every file that was written.
	FilesCopied []string
}

// GetResult renders the copied-file list as a human-readable summary.
func (s *SFTPResult) GetResult() string {
	summary := fmt.Sprintf("FilesCopied: %v", s.FilesCopied)
	return summary
}

// GetError reports error details for the result. An SFTPResult carries no
// error information, so the message is always empty and the code always 0.
func (s *SFTPResult) GetError() (string, int) {
	var (
		message string
		code    int
	)
	return message, code
}

// Supported values for SFTPRunnable.CredentialType.
const (
	// CredentialTypePassword authenticates with the password held in Credential.
	CredentialTypePassword = "password"

	// CredentialTypeKey authenticates with the private key held in Credential
	// (parsed with ssh.ParsePrivateKey).
	CredentialTypeKey = "key"
)

// Direction selects which way an SFTPRunnable copies files.
type Direction string

// Supported transfer directions.
const (
	// LocalToRemote uploads files from the local SrcDir to the remote DstDir.
	LocalToRemote Direction = "LocalToRemote"

	// RemoteToLocal downloads files from the remote SrcDir to the local DstDir.
	RemoteToLocal Direction = "RemoteToLocal"
)

// SFTPRunnable copies files between a local directory and a directory on an
// SFTP server.
type SFTPRunnable struct {
	// Host, Port and User identify the SSH server and the account to log in as.
	Host string
	Port int
	User string

	// Insecure disables host key verification, but only when HostKey is empty.
	Insecure bool

	// Credential holds the password or private key selected by CredentialType
	// (CredentialTypePassword or CredentialTypeKey).
	Credential     string
	CredentialType string

	// HostKey pins the server's public key. When it is empty, a connection
	// can only be established with Insecure set.
	HostKey string

	// SrcDir and DstDir are the source and destination directories. Only the
	// top-level entries of SrcDir that are not directories are copied.
	SrcDir string
	DstDir string

	// TransferDirection selects upload (LocalToRemote) or download
	// (RemoteToLocal).
	TransferDirection Direction

	// mu is held by GetPersistence while it reads the fields above.
	mu sync.Mutex
}

// Run connects to Host:Port over SSH, opens an SFTP session and copies the
// top-level, non-directory entries of SrcDir into DstDir in the configured
// TransferDirection.
//
// ctx governs establishing the TCP connection. The SSH handshake and the
// file transfer are not interrupted by cancellation. A nil ctx is treated as
// context.Background().
//
// On success the result has ResultStatusSuccess and lists the destination
// paths written. On any failure the result has ResultStatusFailed with an
// empty SFTPResult. ErrUnsupportedCredentialType and
// ErrUnsupportedTransferDirection are returned unwrapped; other errors are
// wrapped with the step that failed.
func (s *SFTPRunnable) Run(ctx context.Context) (RunResult[SFTPResult], error) {
	failed := func(err error) (RunResult[SFTPResult], error) {
		return RunResult[SFTPResult]{Status: ResultStatusFailed, Data: SFTPResult{}}, err
	}

	auth, err := s.authMethod()
	if err != nil {
		return failed(err)
	}

	verifyHost, err := s.hostKeyCallback()
	if err != nil {
		return failed(err)
	}

	// Validate the direction before touching the network.
	var transfer func(*sftp.Client) ([]string, error)
	switch s.TransferDirection {
	case LocalToRemote:
		transfer = s.copyLocalToRemote
	case RemoteToLocal:
		transfer = s.copyRemoteToLocal
	default:
		return failed(ErrUnsupportedTransferDirection)
	}

	config := &ssh.ClientConfig{
		User:            s.User,
		Auth:            []ssh.AuthMethod{auth},
		HostKeyCallback: verifyHost,
	}

	client, err := s.dialSSH(ctx, config)
	if err != nil {
		return failed(err)
	}
	defer client.Close()

	sftpClient, err := sftp.NewClient(client)
	if err != nil {
		return failed(fmt.Errorf("starting sftp session: %w", err))
	}
	defer sftpClient.Close()

	filesCopied, err := transfer(sftpClient)
	if err != nil {
		return failed(err)
	}

	return RunResult[SFTPResult]{
		Status: ResultStatusSuccess,
		Data:   SFTPResult{FilesCopied: filesCopied},
	}, nil
}

// authMethod builds the SSH authentication method selected by CredentialType.
func (s *SFTPRunnable) authMethod() (ssh.AuthMethod, error) {
	switch s.CredentialType {
	case CredentialTypePassword:
		return ssh.Password(s.Credential), nil
	case CredentialTypeKey:
		signer, err := ssh.ParsePrivateKey([]byte(s.Credential))
		if err != nil {
			return nil, fmt.Errorf("parsing private key: %w", err)
		}
		return ssh.PublicKeys(signer), nil
	default:
		return nil, ErrUnsupportedCredentialType
	}
}

// hostKeyCallback returns the server verification policy: pin HostKey when
// set, skip verification only when Insecure is set, otherwise refuse.
func (s *SFTPRunnable) hostKeyCallback() (ssh.HostKeyCallback, error) {
	if s.HostKey == "" {
		if s.Insecure {
			// #nosec
			return ssh.InsecureIgnoreHostKey(), nil
		}
		// ssh.FixedHostKey(nil) would only fail later during the handshake;
		// report the misconfiguration before dialing.
		return nil, fmt.Errorf("%w: HostKey is empty and Insecure is false", ErrInvalidData)
	}

	raw := []byte(s.HostKey)
	// Accept the textual authorized_keys form ("ssh-ed25519 AAAA... comment")
	// first, then fall back to the binary wire format.
	if key, _, _, _, err := ssh.ParseAuthorizedKey(raw); err == nil {
		return ssh.FixedHostKey(key), nil
	}
	key, err := ssh.ParsePublicKey(raw)
	if err != nil {
		return nil, fmt.Errorf("parsing host key: %w", err)
	}
	return ssh.FixedHostKey(key), nil
}

// dialSSH opens a TCP connection to Host:Port (honouring ctx while
// connecting) and performs the SSH client handshake with config.
func (s *SFTPRunnable) dialSSH(ctx context.Context, config *ssh.ClientConfig) (*ssh.Client, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:22").
	addr := net.JoinHostPort(s.Host, fmt.Sprint(s.Port))

	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("connecting to %s: %w", addr, err)
	}

	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("ssh handshake with %s: %w", addr, err)
	}
	return ssh.NewClient(sshConn, chans, reqs), nil
}

// copyFile streams everything readable from src into dst. It returns the
// first error encountered; the byte count reported by io.Copy is dropped.
func copyFile(src io.Reader, dst io.Writer) error {
	if _, err := io.Copy(dst, src); err != nil {
		return err
	}
	return nil
}

// copyLocalToRemote uploads every non-directory entry of the local SrcDir
// into the remote DstDir, creating DstDir (and any parents) on the server
// first.
//
// Subdirectories are skipped, not recursed into. Entries are processed in
// the order returned by os.ReadDir (sorted by filename). The returned slice
// holds the remote destination paths of the files written.
//
// The first error aborts the transfer and is returned with a nil slice. This
// includes an error from closing a remote file, which may mean the upload
// was incomplete.
func (s *SFTPRunnable) copyLocalToRemote(sftpClient *sftp.Client) ([]string, error) {
	if err := sftpClient.MkdirAll(s.DstDir); err != nil {
		return nil, err
	}

	entries, err := os.ReadDir(s.SrcDir)
	if err != nil {
		return nil, err
	}

	var filesCopied []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		srcPath := fmt.Sprintf("%s/%s", s.SrcDir, entry.Name())
		dstPath := fmt.Sprintf("%s/%s", s.DstDir, entry.Name())

		// The closure scopes the defers to a single file.
		err = func() (err error) {
			src, err := os.Open(srcPath)
			if err != nil {
				return err
			}
			// Read side: a close error cannot lose data.
			defer func() { _ = src.Close() }()

			dst, err := sftpClient.Create(dstPath)
			if err != nil {
				return err
			}
			// Write side: surface a close failure unless an earlier error is
			// already being returned.
			defer func() {
				if closeErr := dst.Close(); err == nil {
					err = closeErr
				}
			}()

			return copyFile(src, dst)
		}()
		if err != nil {
			return nil, err
		}

		filesCopied = append(filesCopied, dstPath)
	}

	return filesCopied, nil
}

// copyRemoteToLocal downloads every non-directory entry of the remote SrcDir
// into the local DstDir. It first creates DstDir (and any parents) with mode
// 0700.
//
// Subdirectories are skipped, not recursed into. Existing local files with
// the same name are truncated and overwritten (os.Create). The returned
// slice holds the local destination paths of the files written.
//
// The first error aborts the transfer and is returned with a nil slice. This
// includes an error from closing a local file, which may mean the data was
// not fully written. Files copied before the error remain on disk.
func (s *SFTPRunnable) copyRemoteToLocal(sftpClient *sftp.Client) ([]string, error) {
	if err := os.MkdirAll(s.DstDir, 0700); err != nil {
		return nil, err
	}

	entries, err := sftpClient.ReadDir(s.SrcDir)
	if err != nil {
		return nil, err
	}

	var filesCopied []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		srcPath := fmt.Sprintf("%s/%s", s.SrcDir, entry.Name())
		dstPath := fmt.Sprintf("%s/%s", s.DstDir, entry.Name())

		// The closure scopes the defers to a single file.
		err = func() (err error) {
			src, err := sftpClient.Open(srcPath)
			if err != nil {
				return err
			}
			// Read side: a close error cannot lose data.
			defer func() { _ = src.Close() }()

			dst, err := os.Create(dstPath)
			if err != nil {
				return err
			}
			// Write side: surface a close failure unless an earlier error is
			// already being returned.
			defer func() {
				if closeErr := dst.Close(); err == nil {
					err = closeErr
				}
			}()

			return copyFile(src, dst)
		}()
		if err != nil {
			return nil, err
		}

		filesCopied = append(filesCopied, dstPath)
	}

	return filesCopied, nil
}

// GetType reports the runnable kind identifier, "sftp".
func (*SFTPRunnable) GetType() string {
	const kind = "sftp"
	return kind
}

// GetPersistence snapshots the runnable's configuration into a
// RunnableImport whose Type is GetType().
//
// The snapshot is taken while holding s.mu. It uses the keys "host", "port",
// "user", "insecure", "credential", "credential_type", "hostkey", "src_dir",
// "dst_dir" and "transfer_direction". TransferDirection is stored as a
// Direction value rather than a plain string. Credential is copied verbatim,
// so the result contains the password or private key in clear text and
// should be treated as sensitive.
func (s *SFTPRunnable) GetPersistence() RunnableImport {
	s.mu.Lock()
	defer s.mu.Unlock()

	data := make(JSONMap, 10)
	data["host"] = s.Host
	data["port"] = s.Port
	data["user"] = s.User
	data["insecure"] = s.Insecure
	data["credential"] = s.Credential
	data["credential_type"] = s.CredentialType
	data["hostkey"] = s.HostKey
	data["src_dir"] = s.SrcDir
	data["dst_dir"] = s.DstDir
	data["transfer_direction"] = s.TransferDirection

	snapshot := RunnableImport{
		Data: data,
		Type: s.GetType(),
	}
	return snapshot
}