Compare commits

...

2 Commits

Author SHA1 Message Date
tx7do
d65e7bb928 feat: ent update fields. 2023-11-08 18:15:48 +08:00
tx7do
78452b1abf feat: field mask util. 2023-11-08 15:16:20 +08:00
4 changed files with 106 additions and 29 deletions

25
entgo/update/update.go Normal file
View File

@@ -0,0 +1,25 @@
package update
import (
"entgo.io/ent/dialect/sql"
"github.com/tx7do/go-utils/stringcase"
)
func BuildSetNullUpdate(u *sql.UpdateBuilder, fields []string) {
if len(fields) > 0 {
for _, field := range fields {
field = stringcase.ToSnakeCase(field)
u.SetNull(field)
}
}
}
func BuildSetNullUpdater(fields []string) (error, func(u *sql.UpdateBuilder)) {
if len(fields) > 0 {
return nil, func(u *sql.UpdateBuilder) {
BuildSetNullUpdate(u, fields)
}
} else {
return nil, nil
}
}

View File

@@ -0,0 +1,28 @@
package update
import (
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"github.com/stretchr/testify/require"
"testing"
)
func TestBuildSetNullUpdate(t *testing.T) {
t.Run("MySQL_Set2", func(t *testing.T) {
s := sql.Dialect(dialect.MySQL).Update("users")
BuildSetNullUpdate(s, []string{"id", "username"})
query, args := s.Query()
require.Equal(t, "UPDATE `users` SET `id` = NULL, `username` = NULL", query)
require.Empty(t, args)
})
t.Run("PostgreSQL_Set2", func(t *testing.T) {
s := sql.Dialect(dialect.Postgres).Update("users")
BuildSetNullUpdate(s, []string{"id", "username"})
query, args := s.Query()
require.Equal(t, `UPDATE "users" SET "id" = NULL, "username" = NULL`, query)
require.Empty(t, args)
})
}

View File

@@ -80,35 +80,37 @@ func (mask NestedMask) Filter(msg proto.Message) {
rft := msg.ProtoReflect()
rft.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
m, ok := mask[string(fd.Name())]
if ok {
if len(m) == 0 {
return true
}
if fd.IsMap() {
xmap := rft.Get(fd).Map()
xmap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
if mi, ok := m[mk.String()]; ok {
if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 {
mi.Filter(i.Interface())
}
} else {
xmap.Clear(mk)
}
return true
})
} else if fd.IsList() {
list := rft.Get(fd).List()
for i := 0; i < list.Len(); i++ {
m.Filter(list.Get(i).Message().Interface())
}
} else if fd.Kind() == protoreflect.MessageKind {
m.Filter(rft.Get(fd).Message().Interface())
}
} else {
if !ok {
rft.Clear(fd)
return true
}
if len(m) == 0 {
return true
}
if fd.IsMap() {
xmap := rft.Get(fd).Map()
xmap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
if mi, ok := m[mk.String()]; ok {
if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 {
mi.Filter(i.Interface())
}
} else {
xmap.Clear(mk)
}
return true
})
} else if fd.IsList() {
list := rft.Get(fd).List()
for i := 0; i < list.Len(); i++ {
m.Filter(list.Get(i).Message().Interface())
}
} else if fd.Kind() == protoreflect.MessageKind {
m.Filter(rft.Get(fd).Message().Interface())
}
return true
})
}
@@ -208,3 +210,25 @@ func isValid(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool {
}
return true
}
func NilValuePaths(msg proto.Message, paths []string) []string {
if len(paths) == 0 {
return nil
}
var out []string
rft := msg.ProtoReflect()
for _, v := range paths {
fd := rft.Descriptor().Fields().ByName(protoreflect.Name(v))
if fd == nil {
continue
}
if !rft.Has(fd) {
out = append(out, v)
}
}
return out
}

View File

@@ -1,6 +1,6 @@
git tag v1.1.7
git tag v1.1.8
git tag bank_card/v1.1.0
git tag entgo/v1.1.8
git tag entgo/v1.1.9
git tag geoip/v1.1.0
git push origin --tags