Skip to content
Snippets Groups Projects
Verified Commit aa99843e authored by Volker Schukai's avatar Volker Schukai :alien:
Browse files

fix: map as last element causes panic #9

parent 77af24d7
No related branches found
No related tags found
No related merge requests found
...@@ -25,7 +25,7 @@ func TestSetValue(t *testing.T) { ...@@ -25,7 +25,7 @@ func TestSetValue(t *testing.T) {
t.Error(err) t.Error(err)
} }
assert.Equal(t, v, PathValue("oldValue")) assert.Equal(t, PathValue("oldValue"), v)
nv := PathValue("newValue") nv := PathValue("newValue")
err = SetValue[*Outer](obj, "InnerField.Field", nv) err = SetValue[*Outer](obj, "InnerField.Field", nv)
...@@ -75,7 +75,7 @@ func TestPathRewrite(t *testing.T) { ...@@ -75,7 +75,7 @@ func TestPathRewrite(t *testing.T) {
v, err := GetValue[*Issue7Config](&obj, "Server.Routing.0.P") v, err := GetValue[*Issue7Config](&obj, "Server.Routing.0.P")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, v, PathValue("./test")) assert.Equal(t, PathValue("./test"), v)
nv := PathValue("newValue") nv := PathValue("newValue")
err = SetValue[*Issue7Config](&obj, "Server.Routing.0.P", nv) err = SetValue[*Issue7Config](&obj, "Server.Routing.0.P", nv)
...@@ -90,3 +90,37 @@ func TestPathRewrite(t *testing.T) { ...@@ -90,3 +90,37 @@ func TestPathRewrite(t *testing.T) {
assert.Equal(t, v, nv) assert.Equal(t, v, nv)
} }
// Test data structs
type Issue7TestStruct2 struct {
D map[string]PathValue
}
func TestPathRewrite2(t *testing.T) {
// Test case 2
ts2 := Issue7TestStruct2{
D: map[string]PathValue{
"key1": "yyy",
"key2": "zzz",
},
}
v, err := GetValue[Issue7TestStruct2](ts2, "D.key1")
assert.Nil(t, err)
assert.Equal(t, PathValue("yyy"), v)
v, err = GetValue[Issue7TestStruct2](ts2, "D.key2")
assert.Nil(t, err)
assert.Equal(t, PathValue("zzz"), v)
err = SetValue[*Issue7TestStruct2](&ts2, "D.key1", "xxx")
assert.Nil(t, err)
v, err = GetValue[Issue7TestStruct2](ts2, "D.key1")
assert.Nil(t, err)
assert.Equal(t, PathValue("xxx"), v)
v, err = GetValue[Issue7TestStruct2](ts2, "D.key2")
assert.Nil(t, err)
assert.Equal(t, PathValue("zzz"), v)
}
...@@ -29,17 +29,17 @@ func deepCopy(src, dst interface{}) error { ...@@ -29,17 +29,17 @@ func deepCopy(src, dst interface{}) error {
func SetValue[D any](obj D, keyWithDots string, newValue any) error { func SetValue[D any](obj D, keyWithDots string, newValue any) error {
keySlice := strings.Split(keyWithDots, ".") keySlice := strings.Split(keyWithDots, ".")
v := reflect.ValueOf(obj) reflectionOfObject := reflect.ValueOf(obj)
for keyIndex, key := range keySlice[0 : len(keySlice)-1] { for keyIndex, key := range keySlice[0 : len(keySlice)-1] {
if v.Kind() == reflect.Map { if reflectionOfObject.Kind() == reflect.Map {
if v.IsNil() { if reflectionOfObject.IsNil() {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
currentValue := v.MapIndex(reflect.ValueOf(key)).Interface() currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface()
newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface() newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface()
if err := deepCopy(currentValue, newValueCopy); err != nil { if err := deepCopy(currentValue, newValueCopy); err != nil {
return err return err
...@@ -56,213 +56,275 @@ func SetValue[D any](obj D, keyWithDots string, newValue any) error { ...@@ -56,213 +56,275 @@ func SetValue[D any](obj D, keyWithDots string, newValue any) error {
return err return err
} }
v.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem()) reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem())
return nil return nil
} }
if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Interface { if reflectionOfObject.Kind() == reflect.Ptr && reflectionOfObject.Elem().Kind() == reflect.Interface {
v = v.Elem().Elem() reflectionOfObject = reflectionOfObject.Elem().Elem()
} }
for v.Kind() != reflect.Ptr { for reflectionOfObject.Kind() != reflect.Ptr {
if v.Kind() == reflect.Invalid { if reflectionOfObject.Kind() == reflect.Invalid {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
if v.CanAddr() { if reflectionOfObject.CanAddr() {
v = v.Addr() reflectionOfObject = reflectionOfObject.Addr()
} else { } else {
return newCannotSetError(keyWithDots) return newCannotSetError(keyWithDots)
} }
} }
if v.Kind() != reflect.Ptr { if reflectionOfObject.Kind() != reflect.Ptr {
return newUnsupportedTypePathError(keyWithDots, v.Type()) return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
} }
switch v.Elem().Kind() { switch reflectionOfObject.Elem().Kind() {
case reflect.Struct: case reflect.Struct:
v = v.Elem().FieldByName(key) reflectionOfObject = reflectionOfObject.Elem().FieldByName(key)
case reflect.Slice: case reflect.Slice:
// index is a number and get v from slice with index // index is a number and get reflectionOfObject from slice with index
index, err := strconv.Atoi(key) index, err := strconv.Atoi(key)
if err != nil { if err != nil {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
if index >= v.Elem().Len() { if index >= reflectionOfObject.Elem().Len() {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
v = v.Elem().Index(index) reflectionOfObject = reflectionOfObject.Elem().Index(index)
default: default:
return newUnsupportedTypePathError(keyWithDots, v.Type()) return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
} }
} }
if v.Kind() == reflect.Invalid { if reflectionOfObject.Kind() == reflect.Invalid {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
for v.Kind() == reflect.Ptr { for reflectionOfObject.Kind() == reflect.Ptr {
v = v.Elem() reflectionOfObject = reflectionOfObject.Elem()
} }
// non-supporter type at the top of the path // non-supporter type at the top of the path
switch v.Kind() { switch reflectionOfObject.Kind() {
case reflect.Struct: case reflect.Struct:
v = v.FieldByName(keySlice[len(keySlice)-1]) reflectionOfObject = reflectionOfObject.FieldByName(keySlice[len(keySlice)-1])
if !v.IsValid() { if !reflectionOfObject.IsValid() {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
if !v.CanSet() { if !reflectionOfObject.CanSet() {
return newCannotSetError(keyWithDots) return newCannotSetError(keyWithDots)
} }
case reflect.Map: case reflect.Map:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type())
key := keySlice[len(keySlice)-1]
m := reflectionOfObject
keyVal := reflect.ValueOf(key)
newVal := reflect.ValueOf(newValue)
if !keyVal.Type().ConvertibleTo(m.Type().Key()) {
return fmt.Errorf("key type mismatch")
}
if !newVal.Type().ConvertibleTo(m.Type().Elem()) {
return fmt.Errorf("value type mismatch")
}
keyValConverted := keyVal.Convert(m.Type().Key())
newValConverted := newVal.Convert(m.Type().Elem())
m.SetMapIndex(keyValConverted, newValConverted)
return nil
//currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface()
//newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface()
//if err := deepCopy(currentValue, newValueCopy); err != nil {
// return err
//}
//newValueCopyPtr := &newValueCopy
//newValueCopyReflect := reflect.ValueOf(newValueCopyPtr).Elem()
//if !newValueCopyReflect.CanAddr() {
// return newCannotSetError("Wert ist nicht adressierbar")
//}
////newKey := strings.Join(keySlice[keyIndex+1:], ".")
////err := SetValue(newValueCopyPtr, newKey, newValue)
////if err != nil {
//// return err
////}
//
//reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem())
//return nil
//if reflectionOfObject.IsNil() {
// return newInvalidPathError(keyWithDots)
//}
//
//index := keySlice[len(keySlice)-1]
//reflectedIndex := reflect.ValueOf(index)
//
//if !reflectedIndex.Type().AssignableTo(reflectionOfObject.Type().Key()) {
// return newInvalidPathError(keyWithDots)
//}
//
//currentValue := reflectionOfObject.MapIndex(reflectedIndex).Interface()
//newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface()
//if err := deepCopy(currentValue, newValueCopy); err != nil {
// return err
//}
//
//if !reflect.ValueOf(newValueCopy).Elem().Type().AssignableTo(reflectionOfObject.Type().Elem()) {
// return newInvalidPathError(keyWithDots)
//}
//
//newValueCopyX := reflect.ValueOf(newValueCopy).Elem()
//reflectionOfObject.SetMapIndex(reflectedIndex, newValueCopyX)
case reflect.Slice: case reflect.Slice:
// index is a number and get v from slice with index // index is a number and get reflectionOfObject from slice with index
index, err := strconv.Atoi(keySlice[len(keySlice)-1]) index, err := strconv.Atoi(keySlice[len(keySlice)-1])
if err != nil { if err != nil {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
// index out of range // index out of range
if index >= v.Len() { if index >= reflectionOfObject.Len() {
return newInvalidPathError(keyWithDots) return newInvalidPathError(keyWithDots)
} }
v = v.Index(index) reflectionOfObject = reflectionOfObject.Index(index)
case reflect.Array: case reflect.Array:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Ptr: case reflect.Ptr:
if newValue == nil { if newValue == nil {
v.Set(reflect.Zero(v.Type())) reflectionOfObject.Set(reflect.Zero(reflectionOfObject.Type()))
} else { } else {
v.Set(reflect.ValueOf(&newValue)) reflectionOfObject.Set(reflect.ValueOf(&newValue))
} }
return nil return nil
case reflect.Interface: case reflect.Interface:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Chan: case reflect.Chan:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Func: case reflect.Func:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.UnsafePointer: case reflect.UnsafePointer:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Uintptr: case reflect.Uintptr:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Complex64: case reflect.Complex64:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Complex128: case reflect.Complex128:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
case reflect.Invalid: case reflect.Invalid:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
default: default:
return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type())
} }
newValueType := reflect.TypeOf(newValue) newValueType := reflect.TypeOf(newValue)
if newValueType == nil { if newValueType == nil {
return newUnsupportedTypePathError(keyWithDots, v.Type()) return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
} }
newValueKind := reflect.TypeOf(newValue).Kind() newValueKind := reflect.TypeOf(newValue).Kind()
switch v.Kind() { switch reflectionOfObject.Kind() {
case reflect.String: case reflect.String:
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { if reflectionOfObject.Kind() == reflect.Ptr || reflectionOfObject.Kind() == reflect.Interface {
if v.Elem().CanSet() && v.Elem().Kind() == reflect.String { if reflectionOfObject.Elem().CanSet() && reflectionOfObject.Elem().Kind() == reflect.String {
if newValueKind == reflect.String { if newValueKind == reflect.String {
v.Elem().SetString(newValue.(string)) reflectionOfObject.Elem().SetString(newValue.(string))
} else { } else {
v.Elem().SetString(fmt.Sprintf("%v", newValue)) reflectionOfObject.Elem().SetString(fmt.Sprintf("%v", newValue))
} }
} }
} else if newValueKind == reflect.String { } else if newValueKind == reflect.String {
if reflect.TypeOf(newValue).ConvertibleTo(reflect.TypeOf("")) { if reflect.TypeOf(newValue).ConvertibleTo(reflect.TypeOf("")) {
newValueString := reflect.ValueOf(newValue).Convert(reflect.TypeOf("")).Interface().(string) newValueString := reflect.ValueOf(newValue).Convert(reflect.TypeOf("")).Interface().(string)
v.SetString(newValueString) reflectionOfObject.SetString(newValueString)
} else { } else {
return newUnsupportedTypePathError(keyWithDots, v.Type()) return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
} }
} else { } else {
v.SetString(fmt.Sprintf("%v", newValue)) reflectionOfObject.SetString(fmt.Sprintf("%v", newValue))
} }
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
if newValueKind == reflect.Int { if newValueKind == reflect.Int {
v.SetInt(int64(newValue.(int))) reflectionOfObject.SetInt(int64(newValue.(int)))
} else { } else {
s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64) s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64)
if err != nil { if err != nil {
return err return err
} }
v.SetInt(s) reflectionOfObject.SetInt(s)
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if newValueKind == reflect.Int { if newValueKind == reflect.Int {
v.SetUint(uint64(newValue.(int))) reflectionOfObject.SetUint(uint64(newValue.(int)))
} else { } else {
s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64) s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64)
if err != nil { if err != nil {
return err return err
} }
v.SetUint(uint64(s)) reflectionOfObject.SetUint(uint64(s))
} }
case reflect.Bool: case reflect.Bool:
if newValueKind == reflect.Bool { if newValueKind == reflect.Bool {
v.SetBool(newValue.(bool)) reflectionOfObject.SetBool(newValue.(bool))
} else { } else {
b, err := strconv.ParseBool(fmt.Sprintf("%v", newValue)) b, err := strconv.ParseBool(fmt.Sprintf("%v", newValue))
if err != nil { if err != nil {
return err return err
} }
v.SetBool(b) reflectionOfObject.SetBool(b)
} }
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
if newValueKind == reflect.Float64 { if newValueKind == reflect.Float64 {
v.SetFloat(newValue.(float64)) reflectionOfObject.SetFloat(newValue.(float64))
} else { } else {
s, err := strconv.ParseFloat(fmt.Sprintf("%v", newValue), 64) s, err := strconv.ParseFloat(fmt.Sprintf("%v", newValue), 64)
if err != nil { if err != nil {
return err return err
} }
v.SetFloat(s) reflectionOfObject.SetFloat(s)
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if newValueKind == reflect.Ptr { if newValueKind == reflect.Ptr {
newValue = reflect.ValueOf(newValue).Elem().Interface() newValue = reflect.ValueOf(newValue).Elem().Interface()
v.Set(reflect.ValueOf(newValue)) reflectionOfObject.Set(reflect.ValueOf(newValue))
} else if newValueKind == reflect.Slice { } else if newValueKind == reflect.Slice {
v.Set(reflect.ValueOf(newValue)) reflectionOfObject.Set(reflect.ValueOf(newValue))
} else { } else {
return newUnsupportedTypePathError(keyWithDots, v.Type()) return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type())
} }
default: default:
return newInvalidTypeForPathError(keyWithDots, v.Type().String(), newValueKind.String()) return newInvalidTypeForPathError(keyWithDots, reflectionOfObject.Type().String(), newValueKind.String())
} }
return nil return nil
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment