package jobqueue

import (
	"context"
	"github.com/shirou/gopsutil/v3/cpu"
	"math"
	"runtime"
	"sync"
	"sync/atomic"
	"time"
)

// mainResourceStats is the package-level monitor driven by
// StartResourceMonitoring/StopResourceMonitoring and read by the package-level
// GetCpuUsage/GetMemoryUsage accessors, which treat nil as "not started".
// NOTE(review): this pointer is read and written without synchronization, so
// (re)starting monitoring concurrently with the accessors is a data race.
var mainResourceStats *ResourceStats

func StartResourceMonitoring(interval time.Duration) error {
	mainResourceStats = NewResourceStats()
	return mainResourceStats.StartMonitoring(interval)
}

// StopResourceMonitoring stops the package-level monitor. It does nothing
// when monitoring was never started.
func StopResourceMonitoring() {
	if s := mainResourceStats; s != nil {
		s.StopMonitoring()
	}
}

func resetResourceStatsForTesting() {
	if mainResourceStats != nil {
		StopResourceMonitoring()
	}
}

// GetCpuUsage returns the latest CPU percentage published by the
// package-level monitor, or 0 if monitoring has not been started.
func GetCpuUsage() float64 {
	s := mainResourceStats
	if s == nil {
		return 0
	}
	return s.GetCpuUsage()
}

// GetMemoryUsage returns the latest heap-allocation figure (bytes) published
// by the package-level monitor, or 0 if monitoring has not been started.
func GetMemoryUsage() uint64 {
	s := mainResourceStats
	if s == nil {
		return 0
	}
	return s.GetMemoryUsage()
}

// ResourceStats samples CPU and memory usage on a background goroutine and
// publishes the latest values through atomic loads/stores.
type ResourceStats struct {
	// cpuUsage and memoryUsage are accessed only via sync/atomic. Keep them as
	// the first two fields: Go only guarantees 64-bit alignment for the first
	// word of an allocated struct, and 64-bit atomics require that alignment
	// on 32-bit platforms.
	//
	// cpuUsage stores the math.Float64bits pattern of the latest CPU percentage.
	cpuUsage    uint64
	// memoryUsage stores the latest runtime.MemStats.Alloc value, in bytes.
	memoryUsage uint64
	// context and cancel are set by StartMonitoring and cleared by
	// StopMonitoring; both are guarded by mu.
	context     context.Context
	cancel      context.CancelFunc
	mu          sync.Mutex
}

// NewResourceStats returns an idle ResourceStats. Call StartMonitoring to
// begin sampling.
func NewResourceStats() *ResourceStats {
	return new(ResourceStats)
}

// getMemoryUsage reports the bytes of allocated heap objects
// (runtime.MemStats.Alloc) at the time of the call.
func (stats *ResourceStats) getMemoryUsage() uint64 {
	ms := new(runtime.MemStats)
	runtime.ReadMemStats(ms)
	return ms.Alloc
}

// getCPUPercentage returns the combined (not per-CPU) CPU utilisation from
// gopsutil's cpu.Percent. The call samples over a 100ms window and blocks for
// roughly that long. It returns ErrCPUPercentage if gopsutil yields no value.
func (stats *ResourceStats) getCPUPercentage() (float64, error) {
	const sampleWindow = 100 * time.Millisecond

	values, err := cpu.Percent(sampleWindow, false)
	switch {
	case err != nil:
		return 0, err
	case len(values) == 0:
		return 0, ErrCPUPercentage
	default:
		return values[0], nil
	}
}

// assignResourceStats takes one sample of memory and CPU usage and publishes
// it atomically. Memory is always published. The CPU value is only replaced
// when sampling succeeds, so a CPU error leaves the previous CPU reading in
// place without discarding the fresh memory reading.
func (stats *ResourceStats) assignResourceStats() {
	atomic.StoreUint64(&stats.memoryUsage, stats.getMemoryUsage())

	cpuP, err := stats.getCPUPercentage()
	if err != nil {
		// Best effort: keep the last good CPU value rather than publishing 0.
		return
	}
	atomic.StoreUint64(&stats.cpuUsage, math.Float64bits(cpuP))
}

// MonitorResources re-samples resource usage every interval until the
// monitoring context (installed by StartMonitoring) is cancelled. It returns
// immediately if no context is installed or if interval is not positive.
func (stats *ResourceStats) MonitorResources(interval time.Duration) {
	stats.mu.Lock()
	ctx := stats.context
	stats.mu.Unlock()

	// time.NewTicker panics on a non-positive interval. This method is
	// normally run in its own goroutine, so that panic would crash the process.
	if ctx == nil || interval <= 0 {
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			stats.assignResourceStats()
		case <-ctx.Done():
			return
		}
	}
}

// StartMonitoring takes an initial sample synchronously, then starts a
// goroutine that re-samples every interval until StopMonitoring is called.
// It returns nil without doing anything if monitoring is already running,
// and ErrIntervalIsZero if interval is zero or negative.
func (stats *ResourceStats) StartMonitoring(interval time.Duration) error {
	stats.mu.Lock()
	defer stats.mu.Unlock()

	if stats.context != nil && stats.context.Err() == nil {
		return nil
	}

	// Validate before installing a context. A live context left behind by a
	// rejected call would make every later call report "already running"
	// without ever starting the sampler. Negative intervals are rejected too,
	// because time.NewTicker panics on them.
	if interval <= 0 {
		return ErrIntervalIsZero
	}

	ctx, cancel := context.WithCancel(context.Background())
	stats.context = ctx
	stats.cancel = cancel

	// Publish values immediately so readers don't see zeros until the first tick.
	// NOTE(review): this sample is taken while holding mu, so the call blocks
	// for the duration of one CPU sample.
	stats.assignResourceStats()
	go stats.MonitorResources(interval)

	return nil
}

// StopMonitoring cancels the sampler started by StartMonitoring and clears
// the monitoring state, so StartMonitoring may be called again. It is a no-op
// when monitoring is not running. It does not wait for the sampler goroutine
// to exit, so a sample already in progress may still be published after it
// returns.
func (stats *ResourceStats) StopMonitoring() {
	stats.mu.Lock()
	defer stats.mu.Unlock()

	if stats.cancel == nil {
		return
	}

	// Cancelling closes the context's Done channel, which is what the sampler
	// goroutine selects on.
	stats.cancel()
	stats.context = nil
	stats.cancel = nil
}

// GetCpuUsage returns the most recently published CPU percentage. It uses an
// atomic load, so it is safe to call while the sampler is running.
func (stats *ResourceStats) GetCpuUsage() float64 {
	return math.Float64frombits(atomic.LoadUint64(&stats.cpuUsage))
}

// GetMemoryUsage returns the most recently published heap allocation in
// bytes, read atomically.
func (stats *ResourceStats) GetMemoryUsage() uint64 {
	latest := atomic.LoadUint64(&stats.memoryUsage)
	return latest
}