diff --git a/issue_7_test.go b/issue_7_test.go index 4b6bf622a08af8d0e73ded83ba6b437a0cd3e361..5033b289852cbad971d25b165c2e8dc20615e4b3 100644 --- a/issue_7_test.go +++ b/issue_7_test.go @@ -25,7 +25,7 @@ func TestSetValue(t *testing.T) { t.Error(err) } - assert.Equal(t, v, PathValue("oldValue")) + assert.Equal(t, PathValue("oldValue"), v) nv := PathValue("newValue") err = SetValue[*Outer](obj, "InnerField.Field", nv) @@ -75,7 +75,7 @@ func TestPathRewrite(t *testing.T) { v, err := GetValue[*Issue7Config](&obj, "Server.Routing.0.P") assert.Nil(t, err) - assert.Equal(t, v, PathValue("./test")) + assert.Equal(t, PathValue("./test"), v) nv := PathValue("newValue") err = SetValue[*Issue7Config](&obj, "Server.Routing.0.P", nv) @@ -90,3 +90,37 @@ func TestPathRewrite(t *testing.T) { assert.Equal(t, v, nv) } + +// Test data structs +type Issue7TestStruct2 struct { + D map[string]PathValue +} + +func TestPathRewrite2(t *testing.T) { + // Test case 2 + ts2 := Issue7TestStruct2{ + D: map[string]PathValue{ + "key1": "yyy", + "key2": "zzz", + }, + } + + v, err := GetValue[Issue7TestStruct2](ts2, "D.key1") + assert.Nil(t, err) + assert.Equal(t, PathValue("yyy"), v) + + v, err = GetValue[Issue7TestStruct2](ts2, "D.key2") + assert.Nil(t, err) + assert.Equal(t, PathValue("zzz"), v) + + err = SetValue[*Issue7TestStruct2](&ts2, "D.key1", "xxx") + assert.Nil(t, err) + + v, err = GetValue[Issue7TestStruct2](ts2, "D.key1") + assert.Nil(t, err) + assert.Equal(t, PathValue("xxx"), v) + + v, err = GetValue[Issue7TestStruct2](ts2, "D.key2") + assert.Nil(t, err) + assert.Equal(t, PathValue("zzz"), v) +} diff --git a/set.go b/set.go index 8004d051a926312633de027306a3a0f6f733a28d..92b650947c7d46f0ff6bd197ac59662db76dd130 100644 --- a/set.go +++ b/set.go @@ -29,17 +29,17 @@ func deepCopy(src, dst interface{}) error { func SetValue[D any](obj D, keyWithDots string, newValue any) error { keySlice := strings.Split(keyWithDots, ".") - v := reflect.ValueOf(obj) + reflectionOfObject := reflect.ValueOf(obj) for keyIndex, key := range keySlice[0 : len(keySlice)-1] { - if v.Kind() == reflect.Map { + if reflectionOfObject.Kind() == reflect.Map { - if v.IsNil() { + if reflectionOfObject.IsNil() { return newInvalidPathError(keyWithDots) } - currentValue := v.MapIndex(reflect.ValueOf(key)).Interface() + currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface() newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface() if err := deepCopy(currentValue, newValueCopy); err != nil { return err @@ -56,213 +56,275 @@ func SetValue[D any](obj D, keyWithDots string, newValue any) error { return err } - v.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem()) + reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem()) return nil } - if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Interface { - v = v.Elem().Elem() + if reflectionOfObject.Kind() == reflect.Ptr && reflectionOfObject.Elem().Kind() == reflect.Interface { + reflectionOfObject = reflectionOfObject.Elem().Elem() } - for v.Kind() != reflect.Ptr { - if v.Kind() == reflect.Invalid { + for reflectionOfObject.Kind() != reflect.Ptr { + if reflectionOfObject.Kind() == reflect.Invalid { return newInvalidPathError(keyWithDots) } - if v.CanAddr() { - v = v.Addr() + if reflectionOfObject.CanAddr() { + reflectionOfObject = reflectionOfObject.Addr() } else { return newCannotSetError(keyWithDots) } } - if v.Kind() != reflect.Ptr { - return newUnsupportedTypePathError(keyWithDots, v.Type()) + if reflectionOfObject.Kind() != reflect.Ptr { + return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } - switch v.Elem().Kind() { + switch reflectionOfObject.Elem().Kind() { case reflect.Struct: - v = v.Elem().FieldByName(key) + reflectionOfObject = reflectionOfObject.Elem().FieldByName(key) case reflect.Slice: - // index is a number and get v from slice with index + // index is a number and get reflectionOfObject from slice with index index, err := strconv.Atoi(key) if err != nil { return newInvalidPathError(keyWithDots) } - if index >= v.Elem().Len() { + if index >= reflectionOfObject.Elem().Len() { return newInvalidPathError(keyWithDots) } - v = v.Elem().Index(index) + reflectionOfObject = reflectionOfObject.Elem().Index(index) default: - return newUnsupportedTypePathError(keyWithDots, v.Type()) + return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } } - if v.Kind() == reflect.Invalid { + if reflectionOfObject.Kind() == reflect.Invalid { return newInvalidPathError(keyWithDots) } - for v.Kind() == reflect.Ptr { - v = v.Elem() + for reflectionOfObject.Kind() == reflect.Ptr { + reflectionOfObject = reflectionOfObject.Elem() } // non-supporter type at the top of the path - switch v.Kind() { + switch reflectionOfObject.Kind() { case reflect.Struct: - v = v.FieldByName(keySlice[len(keySlice)-1]) - if !v.IsValid() { + reflectionOfObject = reflectionOfObject.FieldByName(keySlice[len(keySlice)-1]) + if !reflectionOfObject.IsValid() { return newInvalidPathError(keyWithDots) } - if !v.CanSet() { + if !reflectionOfObject.CanSet() { return newCannotSetError(keyWithDots) } case reflect.Map: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + + key := keySlice[len(keySlice)-1] + m := reflectionOfObject + + keyVal := reflect.ValueOf(key) + newVal := reflect.ValueOf(newValue) + + if !keyVal.Type().ConvertibleTo(m.Type().Key()) { + return fmt.Errorf("key type mismatch") + } + + if !newVal.Type().ConvertibleTo(m.Type().Elem()) { + return fmt.Errorf("value type mismatch") + } + + keyValConverted := keyVal.Convert(m.Type().Key()) + newValConverted := newVal.Convert(m.Type().Elem()) + m.SetMapIndex(keyValConverted, newValConverted) + return nil + + //currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface() + //newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface() + //if err := deepCopy(currentValue, newValueCopy); err != nil { + // return err + //} + //newValueCopyPtr := &newValueCopy + //newValueCopyReflect := reflect.ValueOf(newValueCopyPtr).Elem() + //if !newValueCopyReflect.CanAddr() { + // return newCannotSetError("Wert ist nicht adressierbar") + //} + ////newKey := strings.Join(keySlice[keyIndex+1:], ".") + ////err := SetValue(newValueCopyPtr, newKey, newValue) + ////if err != nil { + //// return err + ////} + // + //reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem()) + //return nil + + //if reflectionOfObject.IsNil() { + // return newInvalidPathError(keyWithDots) + //} + // + //index := keySlice[len(keySlice)-1] + //reflectedIndex := reflect.ValueOf(index) + // + //if !reflectedIndex.Type().AssignableTo(reflectionOfObject.Type().Key()) { + // return newInvalidPathError(keyWithDots) + //} + // + //currentValue := reflectionOfObject.MapIndex(reflectedIndex).Interface() + //newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface() + //if err := deepCopy(currentValue, newValueCopy); err != nil { + // return err + //} + // + //if !reflect.ValueOf(newValueCopy).Elem().Type().AssignableTo(reflectionOfObject.Type().Elem()) { + // return newInvalidPathError(keyWithDots) + //} + // + //newValueCopyX := reflect.ValueOf(newValueCopy).Elem() + //reflectionOfObject.SetMapIndex(reflectedIndex, newValueCopyX) + case reflect.Slice: - // index is a number and get v from slice with index + // index is a number and get reflectionOfObject from slice with index index, err := strconv.Atoi(keySlice[len(keySlice)-1]) if err != nil { return newInvalidPathError(keyWithDots) } // index out of range - if index >= v.Len() { + if index >= reflectionOfObject.Len() { return newInvalidPathError(keyWithDots) } - v = v.Index(index) + reflectionOfObject = reflectionOfObject.Index(index) case reflect.Array: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Ptr: if newValue == nil { - v.Set(reflect.Zero(v.Type())) + reflectionOfObject.Set(reflect.Zero(reflectionOfObject.Type())) } else { - v.Set(reflect.ValueOf(&newValue)) + reflectionOfObject.Set(reflect.ValueOf(&newValue)) } return nil case reflect.Interface: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Chan: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Func: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.UnsafePointer: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Uintptr: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Complex64: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Complex128: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Invalid: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) default: - return newUnsupportedTypeAtTopOfPathError(keyWithDots, v.Type()) + return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) } newValueType := reflect.TypeOf(newValue) if newValueType == nil { - return newUnsupportedTypePathError(keyWithDots, v.Type()) + return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } newValueKind := reflect.TypeOf(newValue).Kind() - switch v.Kind() { + switch reflectionOfObject.Kind() { case reflect.String: - if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { - if v.Elem().CanSet() && v.Elem().Kind() == reflect.String { + if reflectionOfObject.Kind() == reflect.Ptr || reflectionOfObject.Kind() == reflect.Interface { + if reflectionOfObject.Elem().CanSet() && reflectionOfObject.Elem().Kind() == reflect.String { if newValueKind == reflect.String { - v.Elem().SetString(newValue.(string)) + reflectionOfObject.Elem().SetString(newValue.(string)) } else { - v.Elem().SetString(fmt.Sprintf("%v", newValue)) + reflectionOfObject.Elem().SetString(fmt.Sprintf("%v", newValue)) } } } else if newValueKind == reflect.String { if reflect.TypeOf(newValue).ConvertibleTo(reflect.TypeOf("")) { newValueString := reflect.ValueOf(newValue).Convert(reflect.TypeOf("")).Interface().(string) - v.SetString(newValueString) + reflectionOfObject.SetString(newValueString) } else { - return newUnsupportedTypePathError(keyWithDots, v.Type()) + return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } } else { - v.SetString(fmt.Sprintf("%v", newValue)) + reflectionOfObject.SetString(fmt.Sprintf("%v", newValue)) } case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: if newValueKind == reflect.Int { - v.SetInt(int64(newValue.(int))) + reflectionOfObject.SetInt(int64(newValue.(int))) } else { s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64) if err != nil { return err } - v.SetInt(s) + reflectionOfObject.SetInt(s) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if newValueKind == reflect.Int { - v.SetUint(uint64(newValue.(int))) + reflectionOfObject.SetUint(uint64(newValue.(int))) } else { s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64) if err != nil { return err } - v.SetUint(uint64(s)) + reflectionOfObject.SetUint(uint64(s)) } case reflect.Bool: if newValueKind == reflect.Bool { - v.SetBool(newValue.(bool)) + reflectionOfObject.SetBool(newValue.(bool)) } else { b, err := strconv.ParseBool(fmt.Sprintf("%v", newValue)) if err != nil { return err } - v.SetBool(b) + reflectionOfObject.SetBool(b) } case reflect.Float64, reflect.Float32: if newValueKind == reflect.Float64 { - v.SetFloat(newValue.(float64)) + reflectionOfObject.SetFloat(newValue.(float64)) } else { s, err := strconv.ParseFloat(fmt.Sprintf("%v", newValue), 64) if err != nil { return err } - v.SetFloat(s) + reflectionOfObject.SetFloat(s) } case reflect.Slice, reflect.Array: if newValueKind == reflect.Ptr { newValue = reflect.ValueOf(newValue).Elem().Interface() - v.Set(reflect.ValueOf(newValue)) + reflectionOfObject.Set(reflect.ValueOf(newValue)) } else if newValueKind == reflect.Slice { - v.Set(reflect.ValueOf(newValue)) + reflectionOfObject.Set(reflect.ValueOf(newValue)) } else { - return newUnsupportedTypePathError(keyWithDots, v.Type()) + return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } default: - return newInvalidTypeForPathError(keyWithDots, v.Type().String(), newValueKind.String()) + return newInvalidTypeForPathError(keyWithDots, reflectionOfObject.Type().String(), newValueKind.String()) } return nil