Files
go-utils/fieldmaskutil/fieldmaskutil.go
2023-11-08 15:16:20 +08:00

235 lines
6.3 KiB
Go

package fieldmaskutil
import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Filter keeps the msg fields that are listed in the paths and clears all the rest.
//
// This is a handy wrapper for NestedMask.Filter method.
// If the same paths are used to process multiple proto messages use NestedMask.Filter method directly.
func Filter(msg proto.Message, paths []string) {
NestedMaskFromPaths(paths).Filter(msg)
}
// Prune clears all the fields listed in paths from the given msg.
//
// This is a handy wrapper for NestedMask.Prune method.
// If the same paths are used to process multiple proto messages use NestedMask.Filter method directly.
func Prune(msg proto.Message, paths []string) {
NestedMaskFromPaths(paths).Prune(msg)
}
// Overwrite overwrites all the fields listed in paths in the dest msg using values from src msg.
//
// This is a handy wrapper for NestedMask.Overwrite method.
// If the same paths are used to process multiple proto messages use NestedMask.Overwrite method directly.
func Overwrite(src, dest proto.Message, paths []string) {
NestedMaskFromPaths(paths).Overwrite(src, dest)
}
// NestedMask represents a field mask as a recursive map.
type NestedMask map[string]NestedMask
// NestedMaskFromPaths creates an instance of NestedMask for the given paths.
func NestedMaskFromPaths(paths []string) NestedMask {
mask := make(NestedMask)
for _, path := range paths {
curr := mask
var letters []rune
for _, letter := range path {
if letter == '.' {
if len(letters) == 0 {
continue
}
key := string(letters)
c, ok := curr[key]
if !ok {
c = make(NestedMask)
curr[key] = c
}
curr = c
letters = nil
continue
}
letters = append(letters, letter)
}
if len(letters) != 0 {
key := string(letters)
if _, ok := curr[key]; !ok {
curr[key] = make(NestedMask)
}
}
}
return mask
}
// Filter keeps the msg fields that are listed in the paths and clears all the rest.
//
// If the mask is empty then all the fields are kept.
// Paths are assumed to be valid and normalized otherwise the function may panic.
// See google.golang.org/protobuf/types/known/fieldmaskpb for details.
func (mask NestedMask) Filter(msg proto.Message) {
if len(mask) == 0 {
return
}
rft := msg.ProtoReflect()
rft.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
m, ok := mask[string(fd.Name())]
if !ok {
rft.Clear(fd)
return true
}
if len(m) == 0 {
return true
}
if fd.IsMap() {
xmap := rft.Get(fd).Map()
xmap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
if mi, ok := m[mk.String()]; ok {
if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 {
mi.Filter(i.Interface())
}
} else {
xmap.Clear(mk)
}
return true
})
} else if fd.IsList() {
list := rft.Get(fd).List()
for i := 0; i < list.Len(); i++ {
m.Filter(list.Get(i).Message().Interface())
}
} else if fd.Kind() == protoreflect.MessageKind {
m.Filter(rft.Get(fd).Message().Interface())
}
return true
})
}
// Prune clears all the fields listed in paths from the given msg.
//
// All other fields are kept untouched. If the mask is empty no fields are cleared.
// This operation is the opposite of NestedMask.Filter.
// Paths are assumed to be valid and normalized otherwise the function may panic.
// See google.golang.org/protobuf/types/known/fieldmaskpb for details.
func (mask NestedMask) Prune(msg proto.Message) {
if len(mask) == 0 {
return
}
rft := msg.ProtoReflect()
rft.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
m, ok := mask[string(fd.Name())]
if ok {
if len(m) == 0 {
rft.Clear(fd)
return true
}
if fd.IsMap() {
xmap := rft.Get(fd).Map()
xmap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
if mi, ok := m[mk.String()]; ok {
if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 {
mi.Prune(i.Interface())
} else {
xmap.Clear(mk)
}
}
return true
})
} else if fd.IsList() {
list := rft.Get(fd).List()
for i := 0; i < list.Len(); i++ {
m.Prune(list.Get(i).Message().Interface())
}
} else if fd.Kind() == protoreflect.MessageKind {
m.Prune(rft.Get(fd).Message().Interface())
}
}
return true
})
}
// Overwrite overwrites all the fields listed in paths in the dest msg using values from src msg.
//
// All other fields are kept untouched. If the mask is empty, no fields are overwritten.
// Supports scalars, messages, repeated fields, and maps.
// If the parent of the field is nil message, the parent is initiated before overwriting the field
// If the field in src is empty value, the field in dest is cleared.
// Paths are assumed to be valid and normalized otherwise the function may panic.
func (mask NestedMask) Overwrite(src, dest proto.Message) {
mask.overwrite(src.ProtoReflect(), dest.ProtoReflect())
}
func (mask NestedMask) overwrite(src, dest protoreflect.Message) {
for k, v := range mask {
srcFD := src.Descriptor().Fields().ByName(protoreflect.Name(k))
destFD := dest.Descriptor().Fields().ByName(protoreflect.Name(k))
if srcFD == nil || destFD == nil {
continue
}
// Leaf mask -> copy value from src to dest
if len(v) == 0 {
if srcFD.Kind() == destFD.Kind() { // TODO: Full type equality check
val := src.Get(srcFD)
if isValid(srcFD, val) {
dest.Set(destFD, val)
} else {
dest.Clear(destFD)
}
}
} else if srcFD.Kind() == protoreflect.MessageKind {
// If dest field is nil
if !dest.Get(destFD).Message().IsValid() {
dest.Set(destFD, protoreflect.ValueOf(dest.Get(destFD).Message().New()))
}
v.overwrite(src.Get(srcFD).Message(), dest.Get(destFD).Message())
}
}
}
func isValid(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool {
if fd.IsMap() {
return val.Map().IsValid()
} else if fd.IsList() {
return val.List().IsValid()
} else if fd.Message() != nil {
return val.Message().IsValid()
}
return true
}
func NilValuePaths(msg proto.Message, paths []string) []string {
if len(paths) == 0 {
return nil
}
var out []string
rft := msg.ProtoReflect()
for _, v := range paths {
fd := rft.Descriptor().Fields().ByName(protoreflect.Name(v))
if fd == nil {
continue
}
if !rft.Has(fd) {
out = append(out, v)
}
}
return out
}