// Copyright 2022 schukai GmbH // SPDX-License-Identifier: AGPL-3.0 package pathfinder import ( "bytes" "encoding/gob" "fmt" "reflect" "strconv" "strings" ) func deepCopy(src, dst interface{}) error { var buf bytes.Buffer enc := gob.NewEncoder(&buf) dec := gob.NewDecoder(&buf) if err := enc.Encode(src); err != nil { return err } return dec.Decode(dst) } // SetValue sets the value of a field in a struct, given a path to the field. // The object must be a pointer to a struct, otherwise an error is returned. func SetValue[D any](obj D, keyWithDots string, newValue any) error { keySlice := strings.Split(keyWithDots, ".") reflectionOfObject := reflect.ValueOf(obj) for keyIndex, key := range keySlice[0 : len(keySlice)-1] { if reflectionOfObject.Kind() == reflect.Map { if reflectionOfObject.IsNil() { return newInvalidPathError(keyWithDots) } currentValue := reflectionOfObject.MapIndex(reflect.ValueOf(key)).Interface() newValueCopy := reflect.New(reflect.TypeOf(currentValue)).Interface() if err := deepCopy(currentValue, newValueCopy); err != nil { return err } newValueCopyPtr := &newValueCopy newValueCopyReflect := reflect.ValueOf(newValueCopyPtr).Elem() if !newValueCopyReflect.CanAddr() { return newCannotSetError("Wert ist nicht adressierbar") } newKey := strings.Join(keySlice[keyIndex+1:], ".") err := SetValue(newValueCopyPtr, newKey, newValue) if err != nil { return err } reflectionOfObject.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValueCopy).Elem()) return nil } if reflectionOfObject.Kind() == reflect.Ptr && reflectionOfObject.Elem().Kind() == reflect.Interface { reflectionOfObject = reflectionOfObject.Elem().Elem() } for reflectionOfObject.Kind() != reflect.Ptr { if reflectionOfObject.Kind() == reflect.Invalid { return newInvalidPathError(keyWithDots) } if reflectionOfObject.CanAddr() { reflectionOfObject = reflectionOfObject.Addr() } else { return newCannotSetError(keyWithDots) } } if reflectionOfObject.Kind() != reflect.Ptr { return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } switch reflectionOfObject.Elem().Kind() { case reflect.Struct: reflectionOfObject = reflectionOfObject.Elem().FieldByName(key) case reflect.Slice: // index is a number and get reflectionOfObject from slice with index index, err := strconv.Atoi(key) if err != nil { return newInvalidPathError(keyWithDots) } if index >= reflectionOfObject.Elem().Len() { return newInvalidPathError(keyWithDots) } reflectionOfObject = reflectionOfObject.Elem().Index(index) default: return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } } if reflectionOfObject.Kind() == reflect.Invalid { return newInvalidPathError(keyWithDots) } for reflectionOfObject.Kind() == reflect.Ptr { reflectionOfObject = reflectionOfObject.Elem() } // non-supporter type at the top of the path switch reflectionOfObject.Kind() { case reflect.Struct: reflectionOfObject = reflectionOfObject.FieldByName(keySlice[len(keySlice)-1]) if !reflectionOfObject.IsValid() { return newInvalidPathError(keyWithDots) } if !reflectionOfObject.CanSet() { return newCannotSetError(keyWithDots) } case reflect.Map: key := keySlice[len(keySlice)-1] m := reflectionOfObject keyVal := reflect.ValueOf(key) newVal := reflect.ValueOf(newValue) if !keyVal.Type().ConvertibleTo(m.Type().Key()) { return fmt.Errorf("key type mismatch") } if !newVal.Type().ConvertibleTo(m.Type().Elem()) { return fmt.Errorf("value type mismatch") } keyValConverted := keyVal.Convert(m.Type().Key()) newValConverted := newVal.Convert(m.Type().Elem()) m.SetMapIndex(keyValConverted, newValConverted) return nil case reflect.Slice: index, err := strconv.Atoi(keySlice[len(keySlice)-1]) if err != nil { return newInvalidPathError(keyWithDots) } if index >= reflectionOfObject.Len() { return newInvalidPathError(keyWithDots) } reflectionOfObject = reflectionOfObject.Index(index) case reflect.Array: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Ptr: if newValue == nil { reflectionOfObject.Set(reflect.Zero(reflectionOfObject.Type())) } else { reflectionOfObject.Set(reflect.ValueOf(&newValue)) } return nil case reflect.Interface: // check if reflectionOfObject is an interface to an struct pointer if reflectionOfObject.Elem().Kind() == reflect.Ptr && reflectionOfObject.Elem().Elem().Kind() == reflect.Struct { return SetValue(reflectionOfObject.Elem().Interface(), keySlice[len(keySlice)-1], newValue) } case reflect.Chan: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Func: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.UnsafePointer: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Uintptr: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Complex64: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Complex128: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) case reflect.Invalid: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) default: return newUnsupportedTypeAtTopOfPathError(keyWithDots, reflectionOfObject.Type()) } newValueType := reflect.TypeOf(newValue) if newValueType == nil { return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } newValueKind := reflect.TypeOf(newValue).Kind() switch reflectionOfObject.Kind() { case reflect.String: if reflectionOfObject.Kind() == reflect.Ptr || reflectionOfObject.Kind() == reflect.Interface { if reflectionOfObject.Elem().CanSet() && reflectionOfObject.Elem().Kind() == reflect.String { if newValueKind == reflect.String { reflectionOfObject.Elem().SetString(newValue.(string)) } else { reflectionOfObject.Elem().SetString(fmt.Sprintf("%v", newValue)) } } } else if newValueKind == reflect.String { if reflect.TypeOf(newValue).ConvertibleTo(reflect.TypeOf("")) { newValueString := reflect.ValueOf(newValue).Convert(reflect.TypeOf("")).Interface().(string) reflectionOfObject.SetString(newValueString) } else { return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } } else { reflectionOfObject.SetString(fmt.Sprintf("%v", newValue)) } case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: if newValueKind == reflect.Int { reflectionOfObject.SetInt(int64(newValue.(int))) } else { s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64) if err != nil { return err } reflectionOfObject.SetInt(s) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if newValueKind == reflect.Int { reflectionOfObject.SetUint(uint64(newValue.(int))) } else { s, err := strconv.ParseInt(fmt.Sprintf("%v", newValue), 10, 64) if err != nil { return err } reflectionOfObject.SetUint(uint64(s)) } case reflect.Bool: if newValueKind == reflect.Bool { reflectionOfObject.SetBool(newValue.(bool)) } else { b, err := strconv.ParseBool(fmt.Sprintf("%v", newValue)) if err != nil { return err } reflectionOfObject.SetBool(b) } case reflect.Float64, reflect.Float32: if newValueKind == reflect.Float64 { reflectionOfObject.SetFloat(newValue.(float64)) } else { s, err := strconv.ParseFloat(fmt.Sprintf("%v", newValue), 64) if err != nil { return err } reflectionOfObject.SetFloat(s) } case reflect.Slice, reflect.Array: if newValueKind == reflect.Ptr { newValue = reflect.ValueOf(newValue).Elem().Interface() reflectionOfObject.Set(reflect.ValueOf(newValue)) } else if newValueKind == reflect.Slice { reflectionOfObject.Set(reflect.ValueOf(newValue)) } else { return newUnsupportedTypePathError(keyWithDots, reflectionOfObject.Type()) } default: return newInvalidTypeForPathError(keyWithDots, reflectionOfObject.Type().String(), newValueKind.String()) } return nil }