feat: ent support update json field.

This commit is contained in:
tx7do
2024-12-14 13:07:01 +08:00
parent c9a0909d46
commit 801deb98cd
4 changed files with 118 additions and 47 deletions

View File

@@ -1,8 +1,16 @@
package entgo
import (
"fmt"
"strings"
"entgo.io/ent/dialect/sql"
"github.com/tx7do/go-utils/fieldmaskutil"
"github.com/tx7do/go-utils/stringcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
func BuildSetNullUpdate(u *sql.UpdateBuilder, fields []string) {
@@ -14,12 +22,72 @@ func BuildSetNullUpdate(u *sql.UpdateBuilder, fields []string) {
}
}
func BuildSetNullUpdater(fields []string) (error, func(u *sql.UpdateBuilder)) {
if len(fields) > 0 {
return nil, func(u *sql.UpdateBuilder) {
BuildSetNullUpdate(u, fields)
}
} else {
return nil, nil
// BuildSetNullUpdater 构建一个UpdateBuilder,用于清空字段的值
func BuildSetNullUpdater(fields []string) func(u *sql.UpdateBuilder) {
if len(fields) == 0 {
return nil
}
return func(u *sql.UpdateBuilder) {
BuildSetNullUpdate(u, fields)
}
}
// ExtractJsonFieldKeyValues 提取json字段的键值对
func ExtractJsonFieldKeyValues(msg proto.Message, paths []string) []string {
var keyValues []string
rft := msg.ProtoReflect()
for _, path := range paths {
fd := rft.Descriptor().Fields().ByName(protoreflect.Name(path))
if fd == nil {
continue
}
if !rft.Has(fd) {
continue
}
keyValues = append(keyValues, fmt.Sprintf("'%s'", stringcase.ToSnakeCase(path)))
v := rft.Get(fd)
switch v.Interface().(type) {
case int32, int64, uint32, uint64, float32, float64, bool:
keyValues = append(keyValues, fmt.Sprintf("%d", v.Interface()))
case string:
keyValues = append(keyValues, fmt.Sprintf("'%s'", v.Interface()))
}
}
return keyValues
}
// SetJsonNullFieldUpdateBuilder 设置json字段的空值
func SetJsonNullFieldUpdateBuilder(fieldName string, msg proto.Message, paths []string) func(u *sql.UpdateBuilder) {
nilPaths := fieldmaskutil.NilValuePaths(msg, paths)
if len(nilPaths) == 0 {
return nil
}
return func(u *sql.UpdateBuilder) {
u.Set(fieldName,
sql.Expr(
fmt.Sprintf("\"%s\" - '{%s}'::text[]", fieldName, strings.Join(nilPaths, ",")),
),
)
}
}
// SetJsonFieldValueUpdateBuilder 设置json字段的值
func SetJsonFieldValueUpdateBuilder(fieldName string, msg proto.Message, paths []string) func(u *sql.UpdateBuilder) {
keyValues := ExtractJsonFieldKeyValues(msg, paths)
if len(keyValues) == 0 {
return nil
}
return func(u *sql.UpdateBuilder) {
u.Set(fieldName,
sql.Expr(
fmt.Sprintf("\"%s\" || jsonb_build_object(%s)", fieldName, strings.Join(keyValues, ",")),
),
)
}
}