105 lines
2.5 KiB
Go
105 lines
2.5 KiB
Go
package entgo
|
||
|
||
import (
|
||
"fmt"
|
||
"strings"
|
||
|
||
"entgo.io/ent/dialect/sql"
|
||
|
||
"github.com/tx7do/go-utils/fieldmaskutil"
|
||
"github.com/tx7do/go-utils/stringcase"
|
||
|
||
"google.golang.org/protobuf/proto"
|
||
"google.golang.org/protobuf/reflect/protoreflect"
|
||
)
|
||
|
||
func BuildSetNullUpdate(u *sql.UpdateBuilder, fields []string) {
|
||
if len(fields) > 0 {
|
||
for _, field := range fields {
|
||
field = stringcase.ToSnakeCase(field)
|
||
u.SetNull(field)
|
||
}
|
||
}
|
||
}
|
||
|
||
// BuildSetNullUpdater 构建一个UpdateBuilder,用于清空字段的值
|
||
func BuildSetNullUpdater(fields []string) func(u *sql.UpdateBuilder) {
|
||
if len(fields) == 0 {
|
||
return nil
|
||
}
|
||
|
||
return func(u *sql.UpdateBuilder) {
|
||
BuildSetNullUpdate(u, fields)
|
||
}
|
||
}
|
||
|
||
// ExtractJsonFieldKeyValues 提取json字段的键值对
|
||
func ExtractJsonFieldKeyValues(msg proto.Message, paths []string, needToSnakeCase bool) []string {
|
||
var keyValues []string
|
||
rft := msg.ProtoReflect()
|
||
for _, path := range paths {
|
||
fd := rft.Descriptor().Fields().ByName(protoreflect.Name(path))
|
||
if fd == nil {
|
||
continue
|
||
}
|
||
if !rft.Has(fd) {
|
||
continue
|
||
}
|
||
|
||
var k string
|
||
if needToSnakeCase {
|
||
k = stringcase.ToSnakeCase(path)
|
||
} else {
|
||
k = path
|
||
}
|
||
|
||
keyValues = append(keyValues, fmt.Sprintf("'%s'", k))
|
||
|
||
v := rft.Get(fd)
|
||
switch v.Interface().(type) {
|
||
case bool:
|
||
keyValues = append(keyValues, fmt.Sprintf("%t", v.Interface()))
|
||
case int32, int64, uint32, uint64, float32, float64:
|
||
keyValues = append(keyValues, fmt.Sprintf("%d", v.Interface()))
|
||
case string:
|
||
keyValues = append(keyValues, fmt.Sprintf("'%s'", v.Interface()))
|
||
default:
|
||
keyValues = append(keyValues, fmt.Sprintf("%v", v.Interface()))
|
||
}
|
||
}
|
||
|
||
return keyValues
|
||
}
|
||
|
||
// SetJsonNullFieldUpdateBuilder 设置json字段的空值
|
||
func SetJsonNullFieldUpdateBuilder(fieldName string, msg proto.Message, paths []string) func(u *sql.UpdateBuilder) {
|
||
nilPaths := fieldmaskutil.NilValuePaths(msg, paths)
|
||
if len(nilPaths) == 0 {
|
||
return nil
|
||
}
|
||
|
||
return func(u *sql.UpdateBuilder) {
|
||
u.Set(fieldName,
|
||
sql.Expr(
|
||
fmt.Sprintf("\"%s\" - '{%s}'::text[]", fieldName, strings.Join(nilPaths, ",")),
|
||
),
|
||
)
|
||
}
|
||
}
|
||
|
||
// SetJsonFieldValueUpdateBuilder 设置json字段的值
|
||
func SetJsonFieldValueUpdateBuilder(fieldName string, msg proto.Message, paths []string, needToSnakeCase bool) func(u *sql.UpdateBuilder) {
|
||
keyValues := ExtractJsonFieldKeyValues(msg, paths, needToSnakeCase)
|
||
if len(keyValues) == 0 {
|
||
return nil
|
||
}
|
||
|
||
return func(u *sql.UpdateBuilder) {
|
||
u.Set(fieldName,
|
||
sql.Expr(
|
||
fmt.Sprintf("\"%s\" || jsonb_build_object(%s)", fieldName, strings.Join(keyValues, ",")),
|
||
),
|
||
)
|
||
}
|
||
}
|