From aa99843ec5a11ae8ac4af99f6ce2052ad410d716 Mon Sep 17 00:00:00 2001
From: Volker Schukai <volker.schukai@schukai.com>
Date: Sat, 16 Sep 2023 10:48:21 +0200
Subject: [PATCH] fix: map as last element causes panic #9

---
 issue_7_test.go |  38 +++++++++-
 set.go          | 188 ++++++++++++++++++++++++++++++++----------------
 2 files changed, 161 insertions(+), 65 deletions(-)

diff --git a/issue_7_test.go b/issue_7_test.go
index 4b6bf62..5033b28 100644
--- a/issue_7_test.go
+++ b/issue_7_test.go
@@ -25,7 +25,7 @@ func TestSetValue(t *testing.T) {
 		t.Error(err)
 	}
 
-	assert.Equal(t, v, PathValue("oldValue"))
+	assert.Equal(t, PathValue("oldValue"), v)
 
 	nv := PathValue("newValue")
 	err = SetValue[*Outer](obj, "InnerField.Field", nv)
@@ -75,7 +75,7 @@ func TestPathRewrite(t *testing.T) {
 	v, err := GetValue[*Issue7Config](&obj, "Server.Routing.0.P")
 	assert.Nil(t, err)
 
-	assert.Equal(t, v, PathValue("./test"))
+	assert.Equal(t, PathValue("./test"), v)
 
 	nv := PathValue("newValue")
 	err = SetValue[*Issue7Config](&obj, "Server.Routing.0.P", nv)
@@ -90,3 +90,37 @@ func TestPathRewrite(t *testing.T) {
 	assert.Equal(t, v, nv)
 
 }
+
+// Test data structs
+type Issue7TestStruct2 struct {
+	D map[string]PathValue
+}
+
+func TestPathRewrite2(t *testing.T) {
+	// Test case 2
+	ts2 := Issue7TestStruct2{
+		D: map[string]PathValue{
+			"key1": "yyy",
+			"key2": "zzz",
+		},
+	}
+
+	v, err := GetValue[Issue7TestStruct2](ts2, "D.key1")
+	assert.Nil(t, err)
+	assert.Equal(t, PathValue("yyy"), v)
+
+	v, err = GetValue[Issue7TestStruct2](ts2, "D.key2")
+	assert.Nil(t, err)
+	assert.Equal(t, PathValue("zzz"), v)
+
+	err = SetValue[*Issue7TestStruct2](&ts2, "D.key1", "xxx")
+	assert.Nil(t, err)
+
+	v, err = GetValue[Issue7TestStruct2](ts2, "D.key1")
+	assert.Nil(t, err)
+	assert.Equal(t, PathValue("xxx"), v)
+
+	v, err = GetValue[Issue7TestStruct2](ts2, "D.key2")
+	assert.Nil(t, err)
+	assert.Equal(t, PathValue("zzz"), v)
+}
diff --git a/set.go b/set.go
index 8004d05..92b6509 100644
--- a/set.go
+++ b/set.go
@@ -29,17 +29,17 @@ func deepCopy(src, dst interface{}) error {
 func SetValue[D any](obj D, keyWithDots string, newValue any) error {
 
 	keySlice := strings.Split(keyWithDots, ".")
-	v := reflect.ValueOf(obj)
+	reflectionOfObject := reflect.ValueOf(obj)
 
 	for keyIndex, key := range keySlice[0 : len(keySlice)-1] {
 
-		if v.Kind() == reflect.Map {
+		if reflectionOfObject.Kind() == reflect.Map {
 
-			if v.IsNil() {
+			if reflectionOfObject.IsNil() {
 				return newInvalidPathError(keyWithDots)
 			}
 
-			currentValue := v.MapIndex(reflect.ValueOf(key)).Interface()
+			currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface()
 			newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface()
 			if err := deepCopy(currentValue, newValueCopy); err != nil {
 				return err
@@ -56,213 +56,275 @@ func SetValue[D any](obj D, keyWithDots string, newValue any) error {
 				return err
 			}
 
-			v.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem())
+			reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem())
 			return nil
 
 		}
 
-		if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Interface {
-			v = v.Elem().Elem()
+		if reflectionOfObject.Kind() == reflect.Ptr && reflectionOfObject.Elem().Kind() == reflect.Interface {
+			reflectionOfObject = reflectionOfObject.Elem().Elem()
 		}
 
-		for v.Kind() != reflect.Ptr {
-			if v.Kind() == reflect.Invalid {
+		for reflectionOfObject.Kind() != reflect.Ptr {
+			if reflectionOfObject.Kind() == reflect.Invalid {
 				return newInvalidPathError(keyWithDots)
 			}
 
-			if v.CanAddr() {
-				v = v.Addr()
+			if reflectionOfObject.CanAddr() {
+				reflectionOfObject = reflectionOfObject.Addr()
 			} else {
 				return newCannotSetError(keyWithDots)
 			}
 
 		}
 
-		if v.Kind() != reflect.Ptr {
-			return newUnsupportedTypePathError(keyWithDots, v.Type())
+		if reflectionOfObject.Kind() != reflect.Ptr {
+			return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
 		}
 
-		switch v.Elem().Kind() {
+		switch reflectionOfObject.Elem().Kind() {
 		case reflect.Struct:
-			v = v.Elem().FieldByName(key)
+			reflectionOfObject = reflectionOfObject.Elem().FieldByName(key)
 
 		case reflect.Slice:
-			// index is a number and get v from slice with index
+			// index is a number and get reflectionOfObject from slice with index
 			index, err := strconv.Atoi(key)
 			if err != nil {
 				return newInvalidPathError(keyWithDots)
 			}
 
-			if index >= v.Elem().Len() {
+			if index >= reflectionOfObject.Elem().Len() {
 				return newInvalidPathError(keyWithDots)
 			}
 
-			v = v.Elem().Index(index)
+			reflectionOfObject = reflectionOfObject.Elem().Index(index)
 		default:
-			return newUnsupportedTypePathError(keyWithDots, v.Type())
+			return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
 		}
 
 	}
 
-	if v.Kind() == reflect.Invalid {
+	if reflectionOfObject.Kind() == reflect.Invalid {
 		return newInvalidPathError(keyWithDots)
 	}
 
-	for v.Kind() == reflect.Ptr {
-		v = v.Elem()
+	for reflectionOfObject.Kind() == reflect.Ptr {
+		reflectionOfObject = reflectionOfObject.Elem()
 	}
 
 	// non-supporter type at the top of the path
-	switch v.Kind() {
+	switch reflectionOfObject.Kind() {
 	case reflect.Struct:
 
-		v = v.FieldByName(keySlice[len(keySlice)-1])
-		if !v.IsValid() {
+		reflectionOfObject = reflectionOfObject.FieldByName(keySlice[len(keySlice)-1])
+		if !reflectionOfObject.IsValid() {
 			return newInvalidPathError(keyWithDots)
 		}
 
-		if !v.CanSet() {
+		if !reflectionOfObject.CanSet() {
 			return newCannotSetError(keyWithDots)
 		}
 
 	case reflect.Map:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+
+		key := keySlice[len(keySlice)-1]
+		m := reflectionOfObject
+
+		keyVal := reflect.ValueOf(key)
+		newVal := reflect.ValueOf(newValue)
+
+		if !keyVal.Type().ConvertibleTo(m.Type().Key()) {
+			return fmt.Errorf("key type mismatch")
+		}
+
+		if !newVal.Type().ConvertibleTo(m.Type().Elem()) {
+			return fmt.Errorf("value type mismatch")
+		}
+
+		keyValConverted := keyVal.Convert(m.Type().Key())
+		newValConverted := newVal.Convert(m.Type().Elem())
+		m.SetMapIndex(keyValConverted, newValConverted)
+		return nil
+
+		//currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface()
+		//newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface()
+		//if err := deepCopy(currentValue, newValueCopy); err != nil {
+		//	return err
+		//}
+		//newValueCopyPtr := &newValueCopy
+		//newValueCopyReflect := reflect.ValueOf(newValueCopyPtr).Elem()
+		//if !newValueCopyReflect.CanAddr() {
+		//	return newCannotSetError("Wert ist nicht adressierbar")
+		//}
+		////newKey := strings.Join(keySlice[keyIndex+1:], ".")
+		////err := SetValue(newValueCopyPtr, newKey, newValue)
+		////if err != nil {
+		////	return err
+		////}
+		//
+		//reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem())
+		//return nil
+
+		//if reflectionOfObject.IsNil() {
+		//	return newInvalidPathError(keyWithDots)
+		//}
+		//
+		//index := keySlice[len(keySlice)-1]
+		//reflectedIndex := reflect.ValueOf(index)
+		//
+		//if !reflectedIndex.Type().AssignableTo(reflectionOfObject.Type().Key()) {
+		//	return newInvalidPathError(keyWithDots)
+		//}
+		//
+		//currentValue := reflectionOfObject.MapIndex(reflectedIndex).Interface()
+		//newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface()
+		//if err := deepCopy(currentValue, newValueCopy); err != nil {
+		//	return err
+		//}
+		//
+		//if !reflect.ValueOf(newValueCopy).Elem().Type().AssignableTo(reflectionOfObject.Type().Elem()) {
+		//	return newInvalidPathError(keyWithDots)
+		//}
+		//
+		//newValueCopyX := reflect.ValueOf(newValueCopy).Elem()
+		//reflectionOfObject.SetMapIndex(reflectedIndex, newValueCopyX)
+
 	case reflect.Slice:
 
-		// index is a number and get v from slice with index
+		// index is a number and get reflectionOfObject from slice with index
 		index, err := strconv.Atoi(keySlice[len(keySlice)-1])
 		if err != nil {
 			return newInvalidPathError(keyWithDots)
 		}
 
 		// index out of range
-		if index >= v.Len() {
+		if index >= reflectionOfObject.Len() {
 			return newInvalidPathError(keyWithDots)
 		}
 
-		v = v.Index(index)
+		reflectionOfObject = reflectionOfObject.Index(index)
 
 	case reflect.Array:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Ptr:
 		if newValue == nil {
-			v.Set(reflect.Zero(v.Type()))
+			reflectionOfObject.Set(reflect.Zero(reflectionOfObject.Type()))
 		} else {
-			v.Set(reflect.ValueOf(&newValue))
+			reflectionOfObject.Set(reflect.ValueOf(&newValue))
 		}
 		return nil
 	case reflect.Interface:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Chan:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Func:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.UnsafePointer:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Uintptr:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Complex64:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Complex128:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	case reflect.Invalid:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	default:
-		return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
+		return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
 	}
 
 	newValueType := reflect.TypeOf(newValue)
 	if newValueType == nil {
-		return newUnsupportedTypePathError(keyWithDots, v.Type())
+		return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
 	}
 
 	newValueKind := reflect.TypeOf(newValue).Kind()
 
-	switch v.Kind() {
+	switch reflectionOfObject.Kind() {
 	case reflect.String:
-		if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
-			if v.Elem().CanSet() && v.Elem().Kind() == reflect.String {
+		if reflectionOfObject.Kind() == reflect.Ptr || reflectionOfObject.Kind() == reflect.Interface {
+			if reflectionOfObject.Elem().CanSet() && reflectionOfObject.Elem().Kind() == reflect.String {
 				if newValueKind == reflect.String {
-					v.Elem().SetString(newValue.(string))
+					reflectionOfObject.Elem().SetString(newValue.(string))
 				} else {
-					v.Elem().SetString(fmt.Sprintf("%v", newValue))
+					reflectionOfObject.Elem().SetString(fmt.Sprintf("%v", newValue))
 				}
 			}
 		} else if newValueKind == reflect.String {
 
 			if reflect.TypeOf(newValue).ConvertibleTo(reflect.TypeOf("")) {
 				newValueString := reflect.ValueOf(newValue).Convert(reflect.TypeOf("")).Interface().(string)
-				v.SetString(newValueString)
+				reflectionOfObject.SetString(newValueString)
 			} else {
-				return newUnsupportedTypePathError(keyWithDots, v.Type())
+				return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
 			}
 		} else {
-			v.SetString(fmt.Sprintf("%v", newValue))
+			reflectionOfObject.SetString(fmt.Sprintf("%v", newValue))
 		}
 
 	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
 
 		if newValueKind == reflect.Int {
-			v.SetInt(int64(newValue.(int)))
+			reflectionOfObject.SetInt(int64(newValue.(int)))
 		} else {
 			s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64)
 			if err != nil {
 				return err
 			}
-			v.SetInt(s)
+			reflectionOfObject.SetInt(s)
 		}
 
 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 
 		if newValueKind == reflect.Int {
-			v.SetUint(uint64(newValue.(int)))
+			reflectionOfObject.SetUint(uint64(newValue.(int)))
 		} else {
 			s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64)
 			if err != nil {
 				return err
 			}
-			v.SetUint(uint64(s))
+			reflectionOfObject.SetUint(uint64(s))
 		}
 
 	case reflect.Bool:
 
 		if newValueKind == reflect.Bool {
-			v.SetBool(newValue.(bool))
+			reflectionOfObject.SetBool(newValue.(bool))
 		} else {
 			b, err := strconv.ParseBool(fmt.Sprintf("%v", newValue))
 			if err != nil {
 				return err
 			}
 
-			v.SetBool(b)
+			reflectionOfObject.SetBool(b)
 		}
 
 	case reflect.Float64, reflect.Float32:
 
 		if newValueKind == reflect.Float64 {
-			v.SetFloat(newValue.(float64))
+			reflectionOfObject.SetFloat(newValue.(float64))
 		} else {
 			s, err := strconv.ParseFloat(fmt.Sprintf("%v", newValue), 64)
 			if err != nil {
 				return err
 			}
 
-			v.SetFloat(s)
+			reflectionOfObject.SetFloat(s)
 		}
 
 	case reflect.Slice, reflect.Array:
 
 		if newValueKind == reflect.Ptr {
 			newValue = reflect.ValueOf(newValue).Elem().Interface()
-			v.Set(reflect.ValueOf(newValue))
+			reflectionOfObject.Set(reflect.ValueOf(newValue))
 		} else if newValueKind == reflect.Slice {
-			v.Set(reflect.ValueOf(newValue))
+			reflectionOfObject.Set(reflect.ValueOf(newValue))
 		} else {
-			return newUnsupportedTypePathError(keyWithDots, v.Type())
+			return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
 		}
 
 	default:
-		return newInvalidTypeForPathError(keyWithDots, v.Type().String(), newValueKind.String())
+		return newInvalidTypeForPathError(keyWithDots, reflectionOfObject.Type().String(), newValueKind.String())
 	}
 
 	return nil
-- 
GitLab