From 8345a09bc12ea8334640894933ea6f660f5e9cf2 Mon Sep 17 00:00:00 2001
From: Volker Schukai <volker.schukai@schukai.com>
Date: Sun, 18 Dec 2022 15:53:04 +0100
Subject: [PATCH] fix: optimize getter

---
 get.go      | 61 ++++++++++++++++++++++-------------------------------
 get_test.go | 19 +++++++++++++++++
 2 files changed, 44 insertions(+), 36 deletions(-)

diff --git a/get.go b/get.go
index a49d985..79789c4 100644
--- a/get.go
+++ b/get.go
@@ -5,6 +5,7 @@ package pathfinder
 
 import (
 	"reflect"
+	"strconv"
 	"strings"
 )
 
@@ -13,28 +14,35 @@ func GetValue[D any](obj D, keyWithDots string) (any, error) {
 	keySlice := strings.Split(keyWithDots, ".")
 	v := reflect.ValueOf(obj)
 
-	for _, key := range keySlice[0 : len(keySlice)-1] {
-		for v.Kind() == reflect.Ptr {
-			if v.Kind() == reflect.Invalid {
-				return nil, newInvalidPathError(keyWithDots)
-			}
+	for _, key := range keySlice[0:len(keySlice)] {
+
+		switch v.Kind() {
+		case reflect.Ptr, reflect.Slice, reflect.Array, reflect.Interface:
 			v = v.Elem()
 		}
 
-		if v.Kind() == reflect.Map {
-			switch v.Type().Key().Kind() {
-			case reflect.String:
-				v = v.MapIndex(reflect.ValueOf(key)).Elem()
-		
-				continue
-			default:
-				return nil, newUnsupportedTypePathError(keyWithDots, v.Type())
+		switch v.Kind() {
+		case reflect.Map:
+			v = v.MapIndex(reflect.ValueOf(key))
+			if !v.IsValid() {
+				return nil, newInvalidPathError(keyWithDots)
+			}
+
+		case reflect.Slice, reflect.Array:
+			index, err := strconv.Atoi(key)
+			if err != nil {
+				return nil, newInvalidPathError(keyWithDots)
 			}
-		} else if v.Kind() != reflect.Struct {
-			return nil, newUnsupportedTypePathError(keyWithDots, v.Type())
+			v = v.Index(index)
+		case reflect.Struct:
+			v = v.FieldByName(key)
+			if !v.IsValid() {
+				return nil, newInvalidPathError(keyWithDots)
+			}
+		default:
+			return nil, newInvalidPathError(keyWithDots)
 		}
 
-		v = v.FieldByName(key)
 	}
 
 	if v.Kind() == reflect.Invalid {
@@ -44,26 +52,7 @@ func GetValue[D any](obj D, keyWithDots string) (any, error) {
 	for v.Kind() == reflect.Ptr {
 		v = v.Elem()
 	}
-
-	if v.Kind() == reflect.Map {
-		switch v.Type().Key().Kind() {
-		case reflect.String:
-			return v.MapIndex(reflect.ValueOf(keySlice[len(keySlice)-1])).Interface(), nil
-		default:
-			return nil, newUnsupportedTypePathError(keyWithDots, v.Type())
-		}
-	}
-
-	// non-supporter type at the top of the path
-	if v.Kind() != reflect.Struct {
-		return nil, newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
-	}
-
-	v = v.FieldByName(keySlice[len(keySlice)-1])
-	if !v.IsValid() {
-		return nil, newInvalidPathError(keyWithDots)
-	}
-
+	
 	return v.Interface(), nil
 
 }
diff --git a/get_test.go b/get_test.go
index d079fa3..2340260 100644
--- a/get_test.go
+++ b/get_test.go
@@ -8,6 +8,25 @@ import (
 	"testing"
 )
 
+func TestGetIndexFromArray(t *testing.T) {
+	m := map[string]any{
+		"A": "true",
+		"B": []string{
+			"1",
+			"2",
+			"3",
+		},
+	}
+
+	v, err := GetValue[map[string]any](m, "B.1")
+	if err != nil {
+		t.Error(err)
+	}
+
+	assert.Equal(t, "2", v)
+
+}
+
 func TestGetValueFrom(t *testing.T) {
 
 	m := map[string]string{
-- 
GitLab