package traceroute

import (
	"errors"
	"fmt"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"math/rand"
	"net"

	"time"
)

// Result is the outcome of probing a single hop of a trace.
type Result struct {
	Hop     int           // probe index; also the TTL the probe was sent with
	Station string        // responding node's address, or "*" if none answered
	Latency time.Duration // measured round-trip time; zero if no usable reply
	Err     error         // set on socket/parse failure, timeout or unexpected reply
}

// Results is a collection of hops: one Result per probe, in the order the
// probes were sent (increasing TTL).
type Results struct {
	Hops []Result
}

// doHop sends one ICMP echo request with TTL i to s.Destination and waits up
// to s.Timeout for the answer that belongs to that probe.
//
// The returned Result always has Hop set to i. Station is the address of the
// node that answered ("*" when nothing did). Latency is the time between
// sending the probe and reading its answer. Err carries any marshal, socket
// or parse failure, the read timeout, or an answer whose type is neither Time
// Exceeded nor Echo Reply.
//
// Side effects: a Time Exceeded answer increments s.nextHop, and an Echo
// Reply (the destination itself answered) sets s.isFinale.
//
// The raw ICMP socket receives every ICMP packet addressed to this host. That
// includes other ping/traceroute traffic and late answers to earlier TTLs that
// already timed out. Such packets are skipped rather than attributed to this
// hop.
func (s *Session) doHop(i int) Result {
	// IANA protocol number of ICMP for IPv4, as icmp.ParseMessage expects.
	const protocolICMP = 1

	probe := s.icmpEcho.Body.(*icmp.Echo)
	probe.Seq = i

	r := Result{
		Hop:     i,
		Station: "*",
	}

	writeBuffer, err := s.icmpEcho.Marshal(nil)
	if err != nil {
		r.Err = err
		return r
	}

	if err := s.ipV4Sock.SetTTL(i); err != nil {
		r.Err = fmt.Errorf("socket: %w", err)
		return r
	}

	sentAt := time.Now()

	dst := net.IPAddr{IP: s.Destination.ip}
	if _, err := s.ipV4Sock.WriteTo(writeBuffer, nil, &dst); err != nil {
		r.Err = err
		return r
	}

	// A single deadline covers the whole hop, so skipping unrelated packets
	// below never extends the wait beyond s.Timeout.
	if err := s.ipV4Sock.SetReadDeadline(time.Now().Add(s.Timeout)); err != nil {
		r.Err = err
		return r
	}

	for {
		n, _, peer, err := s.ipV4Sock.ReadFrom(s.readBuffer)
		// Measure before parsing so the latency reflects the network only.
		rtt := time.Since(sentAt)
		if err != nil {
			if peer != nil {
				r.Station = peer.String()
			}
			r.Err = err
			return r
		}

		answer, err := icmp.ParseMessage(protocolICMP, s.readBuffer[:n])
		if err != nil {
			if peer != nil {
				r.Station = peer.String()
			}
			r.Err = err
			return r
		}

		if !answersProbe(answer, probe.ID, i) {
			continue
		}

		if peer != nil {
			r.Station = peer.String()
		}
		r.Latency = rtt

		switch answer.Type {
		case ipv4.ICMPTypeTimeExceeded:
			s.nextHop++
		case ipv4.ICMPTypeEchoReply:
			s.isFinale = true
		default:
			// Report the ICMP type itself; Type.Protocol() would only yield
			// the IP protocol number, which is always 1 here.
			r.Err = fmt.Errorf("unknown icmp answer: %v", answer.Type)
		}
		return r
	}
}

// answersProbe reports whether m answers the echo request with the given
// identifier and sequence number. Both are compared as the 16-bit values
// that actually travel on the wire.
func answersProbe(m *icmp.Message, id, seq int) bool {
	switch body := m.Body.(type) {
	case *icmp.Echo:
		return m.Type == ipv4.ICMPTypeEchoReply &&
			uint16(body.ID) == uint16(id) && uint16(body.Seq) == uint16(seq)
	case *icmp.TimeExceeded:
		return quotesProbe(body.Data, id, seq)
	case *icmp.DstUnreach:
		return quotesProbe(body.Data, id, seq)
	case *icmp.ParamProb:
		return quotesProbe(body.Data, id, seq)
	default:
		return false
	}
}

// quotesProbe reports whether data is our echo request with the given
// identifier and sequence number. Here data is the "original datagram"
// quoted by an ICMP error message: the IPv4 header followed by at least the
// first 8 bytes of its payload.
func quotesProbe(data []byte, id, seq int) bool {
	if len(data) < ipv4.HeaderLen {
		return false
	}
	hdrLen := int(data[0]&0x0f) << 2 // IHL counts 32-bit words
	if hdrLen < ipv4.HeaderLen || len(data) < hdrLen+8 {
		return false
	}
	echo := data[hdrLen:]
	// Echo layout: type, code, checksum(2), identifier(2), sequence(2).
	quotedID := uint16(echo[4])<<8 | uint16(echo[5])
	quotedSeq := uint16(echo[6])<<8 | uint16(echo[7])
	return echo[0] == byte(ipv4.ICMPTypeEcho) &&
		quotedID == uint16(id) && quotedSeq == uint16(seq)
}

// TraceRouteV4 traces the IPv4 route from s.Source to s.Destination.
//
// It opens a raw ICMP socket ("ip4:icmp") bound to s.Source. It then sends
// one echo request per TTL, from 1 up to and including s.MaxHops, and stops
// early once the destination answers. Each hop's Result is appended to the
// returned Results and, if s.CallBack is set, passed to it as soon as it is
// available.
//
// Per-hop failures such as timeouts are recorded in Result.Err. Only socket
// setup errors are returned as error.
func (s *Session) TraceRouteV4() (*Results, error) {
	sock, err := net.ListenPacket("ip4:icmp", s.Source.ip.String())
	if err != nil {
		return nil, err
	}

	s.ipV4Sock = ipv4.NewPacketConn(sock)
	// Closing the ipv4.PacketConn also closes sock, so one Close is enough.
	defer s.ipV4Sock.Close()

	if err := s.ipV4Sock.SetControlMessage(ipv4.FlagTTL|ipv4.FlagDst|ipv4.FlagInterface|ipv4.FlagSrc, true); err != nil {
		return nil, err
	}

	s.icmpEcho = icmp.Message{
		Type: ipv4.ICMPTypeEcho, Code: 0, Body: &icmp.Echo{ID: rand.Int(), Data: []byte("")},
	}
	s.readBuffer = make([]byte, 1500)

	// A Session may be traced more than once. Clear the "destination reached"
	// flag left over from a previous run; otherwise the loop below would stop
	// after the first hop.
	s.isFinale = false

	results := Results{}
	for ttl := 1; ttl <= s.MaxHops; ttl++ {
		r := s.doHop(ttl)
		results.Hops = append(results.Hops, r)
		if s.CallBack != nil {
			s.CallBack(r)
		}
		if s.isFinale {
			break
		}
	}

	return &results, nil
}

// traceRouteV6 is not implemented yet; TraceRoute returns an error for non-IPv4 destinations.
//func (s *Session) traceRouteV6() error {
//	return nil
//}

// TraceRoute traces the route to s.Destination and returns one Result per
// probed hop. Only IPv4 destinations are handled (via TraceRouteV4). Any
// other destination yields an error, because IPv6 tracing is not implemented.
func (s *Session) TraceRoute() (*Results, error) {
	if !s.Destination.isV4() {
		// IPv6 support (a traceRouteV6 counterpart) is still to be written.
		return nil, errors.New("could not traceroute " + s.Destination.String())
	}

	hops, err := s.TraceRouteV4()
	if err != nil {
		return nil, err
	}
	return hops, nil
}