// Copyright 2023 schukai GmbH
// SPDX-License-Identifier: AGPL-3.0

package jobqueue

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/google/uuid"
)

type WorkerStatus int

const (
	WorkerStatusStopped WorkerStatus = iota
	WorkerStatusRunning
)

type WorkerID string

func (id WorkerID) String() string {
	return string(id)
}

// Worker is the interface implemented by job executors: a worker is started,
// receives jobs via AssignJob, and is stopped when it is no longer needed.
type Worker interface {
	Start() error
	Stop() error
	Status() WorkerStatus
	AssignJob(job GenericJob) error

	GetID() WorkerID

	SetManager(manager *Manager)
}

// Statistic holds runtime metrics for a worker.
type Statistic struct {
	TotalThreads       int           // number of worker goroutines
	ActiveThreads      int           // goroutines currently executing a job
	JobsAssigned       int           // jobs received for execution
	JobsCompleted      int           // successful executions
	FailedJobs         int           // failed execution attempts
	TotalExecutionTime time.Duration // accumulated duration of successful executions
}

func (s *Statistic) AverageExecutionTime() time.Duration {
	if s.JobsCompleted == 0 {
		return 0
	}
	return s.TotalExecutionTime / time.Duration(s.JobsCompleted)
}

func (s *Statistic) UtilizationRate() float64 {
	if s.TotalThreads == 0 {
		return 0
	}
	return float64(s.ActiveThreads) / float64(s.TotalThreads) * 100
}
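
// Illustrative example (values are hypothetical):
//
//	s := Statistic{TotalThreads: 8, ActiveThreads: 2, JobsCompleted: 10, TotalExecutionTime: 5 * time.Second}
//	s.AverageExecutionTime() // 500ms
//	s.UtilizationRate()      // 25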

// GenericWorker holds the fields shared by all worker implementations
type GenericWorker struct {
	ID     WorkerID
	status WorkerStatus
}

// LocalWorker is a worker that runs jobs locally
type LocalWorker struct {
	GenericWorker
	jobChannels []chan GenericJob
	stopChans   []chan bool
	cancelChans []chan bool
	maxJobs     int
	mu          sync.Mutex
	statisticMu sync.Mutex
	wg          sync.WaitGroup
	manager     *Manager
	statistic   Statistic
}

// GetID returns the ID of the worker
func (w *GenericWorker) GetID() WorkerID {
	return w.ID
}

// NewLocalWorker creates a new local worker
func NewLocalWorker(maxJobs int) *LocalWorker {
	w := &LocalWorker{maxJobs: maxJobs, statistic: Statistic{TotalThreads: maxJobs}}
	w.jobChannels = make([]chan GenericJob, maxJobs)
	w.stopChans = make([]chan bool, maxJobs)
	w.cancelChans = make([]chan bool, maxJobs)
	w.ID = WorkerID(uuid.New().String())
	return w
}
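
// Illustrative usage sketch (error handling elided; "someJob" stands in for
// any GenericJob implementation and is not part of this package):
//
//	w := NewLocalWorker(4)
//	_ = w.Start()
//	_ = w.AssignJob(someJob)
//	// ... later
//	_ = w.Stop()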

// Start starts the worker
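// It launches one goroutine per job slot (maxJobs in total).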
func (w *LocalWorker) Start() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.status == WorkerStatusRunning {
		return ErrWorkerAlreadyRunning
	}

	for i := 0; i < w.maxJobs; i++ {
		w.wg.Add(1)
		w.jobChannels[i] = make(chan GenericJob)
		w.stopChans[i] = make(chan bool)
		w.cancelChans[i] = make(chan bool)
		go w.run(w.jobChannels[i], w.stopChans[i], w.cancelChans[i])
	}

	time.Sleep(200 * time.Millisecond) // give each goroutine a moment to reach its select loop
	w.wg.Wait()
	w.status = WorkerStatusRunning

	Info("Worker started", "worker", w.ID)

	return nil
}

// UpdateStatisticExtended updates the worker's statistics with job execution details
func (w *LocalWorker) UpdateStatisticExtended(jobDuration time.Duration, jobFailed bool) {
	w.statisticMu.Lock()
	defer w.statisticMu.Unlock()

	if jobFailed {
		w.statistic.FailedJobs++
	} else {
		w.statistic.TotalExecutionTime += jobDuration
		w.statistic.JobsCompleted++
	}

}

// GetStatistic returns the current statistics of the worker
func (w *LocalWorker) GetStatistic() Statistic {
	w.statisticMu.Lock()
	defer w.statisticMu.Unlock()
	return w.statistic
}
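
// Illustrative sketch: callers can poll GetStatistic to observe worker load,
// for example for logging or metrics export:
//
//	stats := w.GetStatistic()
//	Info("worker load", "utilization", stats.UtilizationRate(), "failed", stats.FailedJobs)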

func (w *LocalWorker) SetManager(manager *Manager) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.manager = manager
}

// Stop stops the worker
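// The stop channels are unbuffered, so Stop blocks until every worker
// goroutine has finished its current job and received the stop signal.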
func (w *LocalWorker) Stop() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.status == WorkerStatusStopped {
		return ErrWorkerNotRunning
	}

	w.status = WorkerStatusStopped
	for _, stopChan := range w.stopChans {
		stopChan <- true
	}

	Info("Worker stopped", "worker", w.ID)

	return nil
}

// run is the body of a single worker goroutine: it executes incoming jobs,
// applying the job's timeout and retry settings, until a stop signal arrives.
func (w *LocalWorker) run(jobChannel chan GenericJob, stopChan chan bool, cancelChan chan bool) {

	// Each thread has its own job channel, so its address serves as a stable
	// per-thread identifier for logging.
	workerThreadID := w.ID.String() + "-" + fmt.Sprintf("%p", jobChannel)

	Info("Worker thread with id started", "worker", w.ID, "thread_id", workerThreadID)

	stopFlag := false
	w.wg.Done()

	for {

		select {
		case job := <-jobChannel:

			if stopFlag {
				break
			}

			w.statisticMu.Lock()
			w.statistic.JobsAssigned++
			w.statistic.ActiveThreads++
			w.statisticMu.Unlock()

			ctx, cancel := context.WithCancel(context.Background())
			retries := job.GetMaxRetries()
			retryDelay := job.GetRetryDelay()

			startTime := time.Now()

			if retries == 0 {
				retries = 1
			}

			var err error
			for retries > 0 {

				// Derive a per-attempt context so that an expired timeout from a
				// previous attempt does not cancel the following retry.
				attemptCtx := ctx
				var attemptCancel context.CancelFunc

				timeout := job.GetTimeout()
				if timeout > 0 {
					attemptCtx, attemptCancel = context.WithTimeout(ctx, timeout)
				}

				Info("Executing job on worker thread", "worker", w.ID, "thread_id", workerThreadID, "job_id", job.GetID())

				_, err = job.Execute(attemptCtx)
				jobFailed := err != nil

				w.UpdateStatisticExtended(time.Since(startTime), jobFailed)

				if attemptCancel != nil {
					attemptCancel()
				}

				// Stop retrying on success or when the parent context has been canceled.
				if err == nil || errors.Is(ctx.Err(), context.Canceled) {
					break
				}

				if retryDelay > 0 {
					time.Sleep(retryDelay)
				}

				retries--
			}

			cancel()

			go func() {
				w.mu.Lock()
				defer w.mu.Unlock()
				if w.manager != nil {
					w.manager.Sync(job)
				}

			}()

			w.statisticMu.Lock()
			w.statistic.ActiveThreads--
			w.statisticMu.Unlock()

		case <-stopChan:
			Info("Stopping worker thread", "worker", w.ID, "thread_id", workerThreadID)
			stopFlag = true
			break
		}

		if stopFlag {
			break
		}
	}

	Info("Worker thread with id stopped", "worker", w.ID, "thread_id", workerThreadID)

}

// AssignJob assigns a job to the worker
func (w *LocalWorker) AssignJob(job GenericJob) error {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.status != WorkerStatusRunning {
		return ErrWorkerNotRunning
	}

	for _, ch := range w.jobChannels {
		select {
		case ch <- job:
			return nil
		default:
			continue
		}
	}

	return ErrMaxJobsReached
}
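
// AssignJob uses a non-blocking send, so it returns ErrMaxJobsReached as soon
// as every job channel is busy. Callers will typically fall back to another
// worker or re-queue the job in that case (illustrative sketch):
//
//	if err := w.AssignJob(job); errors.Is(err, ErrMaxJobsReached) {
//		// try another worker or re-queue the job
//	}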

// Status returns the status of the worker
func (w *LocalWorker) Status() WorkerStatus {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.status
}