Files
go-utils/entgo/query/filter.go
2025-05-07 14:07:13 +08:00

674 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package entgo
import (
"fmt"
"strings"
"encoding/json"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/tx7do/go-utils/stringcase"
)
type FilterOp int
const (
FilterNot DatePart = iota // 不等于
FilterIn // 检查值是否在列表中
FilterNotIn // 不在列表中
FilterGTE // 大于或等于传递的值
FilterGT // 大于传递值
FilterLTE // 小于或等于传递值
FilterLT // 小于传递值
FilterRange // 是否介于和给定的两个值之间
FilterIsNull // 是否为空
FilterNotIsNull // 是否不为空
FilterContains // 是否包含指定的子字符串
FilterInsensitiveContains // 不区分大小写,是否包含指定的子字符串
FilterStartsWith // 以值开头
FilterInsensitiveStartsWith // 不区分大小写,以值开头
FilterEndsWith // 以值结尾
FilterInsensitiveEndsWith // 不区分大小写,以值结尾
FilterExact // 精确匹配
FilterInsensitiveExact // 不区分大小写,精确匹配
FilterRegex // 正则表达式
FilterInsensitiveRegex // 不区分大小写,正则表达式
FilterSearch // 全文搜索
)
var ops = [...]string{
FilterNot: "not",
FilterIn: "in",
FilterNotIn: "not_in",
FilterGTE: "gte",
FilterGT: "gt",
FilterLTE: "lte",
FilterLT: "lt",
FilterRange: "range",
FilterIsNull: "isnull",
FilterNotIsNull: "not_isnull",
FilterContains: "contains",
FilterInsensitiveContains: "icontains",
FilterStartsWith: "startswith",
FilterInsensitiveStartsWith: "istartswith",
FilterEndsWith: "endswith",
FilterInsensitiveEndsWith: "iendswith",
FilterExact: "exact",
FilterInsensitiveExact: "iexact",
FilterRegex: "regex",
FilterInsensitiveRegex: "iregex",
FilterSearch: "search",
}
type DatePart int
const (
DatePartDate DatePart = iota // 日期
DatePartYear // 年
DatePartISOYear // ISO 8601 一年中的周数
DatePartQuarter // 季度
DatePartMonth // 月
DatePartWeek // ISO 8601 周编号 一年中的周数
DatePartWeekDay // 星期几
DatePartISOWeekDay // 星期几
DatePartDay // 日
DatePartTime // 小时:分钟:秒
DatePartHour // 小时
DatePartMinute // 分钟
DatePartSecond // 秒
DatePartMicrosecond // 微秒
)
var dateParts = [...]string{
DatePartDate: "date",
DatePartYear: "year",
DatePartISOYear: "iso_year",
DatePartQuarter: "quarter",
DatePartMonth: "month",
DatePartWeek: "week",
DatePartWeekDay: "week_day",
DatePartISOWeekDay: "iso_week_day",
DatePartDay: "day",
DatePartTime: "time",
DatePartHour: "hour",
DatePartMinute: "minute",
DatePartSecond: "second",
DatePartMicrosecond: "microsecond",
}
const (
QueryDelimiter = "__" // 分隔符
JsonFieldDelimiter = "." // JSONB字段分隔符
)
// splitQueryKey 分割查询键
func splitQueryKey(key string) []string {
return strings.Split(key, QueryDelimiter)
}
// splitJsonFieldKey 分割JSON字段键
func splitJsonFieldKey(key string) []string {
return strings.Split(key, JsonFieldDelimiter)
}
// isJsonFieldKey 是否为JSON字段键
func isJsonFieldKey(key string) bool {
return strings.Contains(key, JsonFieldDelimiter)
}
// hasOperations 是否有操作
func hasOperations(str string) bool {
str = strings.ToLower(str)
for _, item := range ops {
if str == item {
return true
}
}
return false
}
// hasDatePart 是否有日期部分
func hasDatePart(str string) bool {
str = strings.ToLower(str)
for _, item := range dateParts {
if str == item {
return true
}
}
return false
}
// BuildFilterSelector 构建过滤选择器
func BuildFilterSelector(andFilterJsonString, orFilterJsonString string) (error, []func(s *sql.Selector)) {
var err error
var queryConditions []func(s *sql.Selector)
var andSelector func(s *sql.Selector)
err, andSelector = QueryCommandToWhereConditions(andFilterJsonString, false)
if err != nil {
return err, nil
}
if andSelector != nil {
queryConditions = append(queryConditions, andSelector)
}
var orSelector func(s *sql.Selector)
err, orSelector = QueryCommandToWhereConditions(orFilterJsonString, true)
if err != nil {
return err, nil
}
if orSelector != nil {
queryConditions = append(queryConditions, orSelector)
}
return nil, queryConditions
}
// QueryCommandToWhereConditions 查询命令转换为选择条件
func QueryCommandToWhereConditions(strJson string, isOr bool) (error, func(s *sql.Selector)) {
if len(strJson) == 0 {
return nil, nil
}
codec := encoding.GetCodec("json")
queryMap := make(map[string]string)
var queryMapArray []map[string]string
if err1 := codec.Unmarshal([]byte(strJson), &queryMap); err1 != nil {
if err2 := codec.Unmarshal([]byte(strJson), &queryMapArray); err2 != nil {
return err2, nil
}
}
return nil, func(s *sql.Selector) {
var ps []*sql.Predicate
ps = append(ps, processQueryMap(s, queryMap)...)
for _, v := range queryMapArray {
ps = append(ps, processQueryMap(s, v)...)
}
if isOr {
s.Where(sql.Or(ps...))
} else {
s.Where(sql.And(ps...))
}
}
}
// processQueryMap 处理查询映射表
func processQueryMap(s *sql.Selector, queryMap map[string]string) []*sql.Predicate {
var ps []*sql.Predicate
for k, v := range queryMap {
keys := splitQueryKey(k)
if cond := makeFieldFilter(s, keys, v); cond != nil {
ps = append(ps, cond)
}
}
return ps
}
// makeFieldFilter 构建一个字段过滤器
func makeFieldFilter(s *sql.Selector, keys []string, value string) *sql.Predicate {
if len(keys) == 0 {
return nil
}
if len(value) == 0 {
return nil
}
field := keys[0]
if len(field) == 0 {
return nil
}
p := sql.P()
switch len(keys) {
case 1:
if isJsonFieldKey(field) {
jsonFields := splitJsonFieldKey(field)
if len(jsonFields) != 2 {
field = stringcase.ToSnakeCase(field)
return filterEqual(s, p, field, value)
}
//value = "'" + value + "'"
return filterJsonb(
s, p,
stringcase.ToSnakeCase(jsonFields[1]),
stringcase.ToSnakeCase(jsonFields[0]),
).
EQ("", value)
}
field = stringcase.ToSnakeCase(field)
return filterEqual(s, p, field, value)
case 2:
op := keys[1]
if len(op) == 0 {
return nil
}
if isJsonFieldKey(field) {
jsonFields := splitJsonFieldKey(field)
if len(jsonFields) == 2 {
field = filterJsonbField(s,
stringcase.ToSnakeCase(jsonFields[1]),
stringcase.ToSnakeCase(jsonFields[0]),
)
//value = "'" + value + "'"
}
} else {
field = stringcase.ToSnakeCase(field)
}
var cond *sql.Predicate
if hasOperations(op) {
return processOp(s, p, op, field, value)
} else if hasDatePart(op) {
cond = filterDatePart(s, p, op, field).EQ("", value)
} else {
cond = filterJsonb(s, p, op, field).EQ("", value)
}
return cond
case 3:
op1 := keys[1]
if len(op1) == 0 {
return nil
}
op2 := keys[2]
if len(op2) == 0 {
return nil
}
// 第二个参数要么是提取日期要么是json字段。
//var cond *sql.Predicate
if hasDatePart(op1) {
if isJsonFieldKey(field) {
jsonFields := splitJsonFieldKey(field)
if len(jsonFields) == 2 {
field = filterJsonbField(s, jsonFields[1], jsonFields[0])
//value = "'" + value + "'"
}
} else {
field = stringcase.ToSnakeCase(field)
}
str := filterDatePartField(s, op1, field)
if hasOperations(op2) {
return processOp(s, p, op2, str, value)
}
return nil
} else {
str := filterJsonbField(s, op1, field)
if hasOperations(op2) {
return processOp(s, p, op2, str, value)
} else if hasDatePart(op2) {
return filterDatePart(s, p, op2, str)
}
return nil
}
default:
return nil
}
}
func processOp(s *sql.Selector, p *sql.Predicate, op, field, value string) *sql.Predicate {
var cond *sql.Predicate
switch op {
case ops[FilterNot]:
cond = filterNot(s, p, field, value)
case ops[FilterIn]:
cond = filterIn(s, p, field, value)
case ops[FilterNotIn]:
cond = filterNotIn(s, p, field, value)
case ops[FilterGTE]:
cond = filterGTE(s, p, field, value)
case ops[FilterGT]:
cond = filterGT(s, p, field, value)
case ops[FilterLTE]:
cond = filterLTE(s, p, field, value)
case ops[FilterLT]:
cond = filterLT(s, p, field, value)
case ops[FilterRange]:
cond = filterRange(s, p, field, value)
case ops[FilterIsNull]:
cond = filterIsNull(s, p, field, value)
case ops[FilterNotIsNull]:
cond = filterIsNotNull(s, p, field, value)
case ops[FilterContains]:
cond = filterContains(s, p, field, value)
case ops[FilterInsensitiveContains]:
cond = filterInsensitiveContains(s, p, field, value)
case ops[FilterStartsWith]:
cond = filterStartsWith(s, p, field, value)
case ops[FilterInsensitiveStartsWith]:
cond = filterInsensitiveStartsWith(s, p, field, value)
case ops[FilterEndsWith]:
cond = filterEndsWith(s, p, field, value)
case ops[FilterInsensitiveEndsWith]:
cond = filterInsensitiveEndsWith(s, p, field, value)
case ops[FilterExact]:
cond = filterExact(s, p, field, value)
case ops[FilterInsensitiveExact]:
cond = filterInsensitiveExact(s, p, field, value)
case ops[FilterRegex]:
cond = filterRegex(s, p, field, value)
case ops[FilterInsensitiveRegex]:
cond = filterInsensitiveRegex(s, p, field, value)
case ops[FilterSearch]:
cond = filterSearch(s, p, field, value)
default:
return nil
}
return cond
}
// filterEqual = 相等操作
// SQL: WHERE "name" = "tom"
func filterEqual(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.EQ(s.C(field), value)
}
// filterNot NOT 不相等操作
// SQL: WHERE NOT ("name" = "tom")
// 或者: WHERE "name" <> "tom"
// 用NOT可以过滤出NULL而用<>、!=则不能。
func filterNot(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.Not().EQ(s.C(field), value)
}
// filterIn IN操作
// SQL: WHERE name IN ("tom", "jimmy")
func filterIn(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
var values []any
if err := json.Unmarshal([]byte(value), &values); err == nil {
return p.In(s.C(field), values...)
}
return nil
}
// filterNotIn NOT IN操作
// SQL: WHERE name NOT IN ("tom", "jimmy")`
func filterNotIn(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
var values []any
if err := json.Unmarshal([]byte(value), &values); err == nil {
return p.NotIn(s.C(field), values...)
}
return nil
}
// filterGTE GTE (Greater Than or Equal) 大于等于 >=操作
// SQL: WHERE "create_time" >= "2023-10-25"
func filterGTE(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.GTE(s.C(field), value)
}
// filterGT GT (Greater than) 大于 >操作
// SQL: WHERE "create_time" > "2023-10-25"
func filterGT(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.GT(s.C(field), value)
}
// filterLTE LTE (Less Than or Equal) 小于等于 <=操作
// SQL: WHERE "create_time" <= "2023-10-25"
func filterLTE(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.LTE(s.C(field), value)
}
// filterLT LT (Less than) 小于 <操作
// SQL: WHERE "create_time" < "2023-10-25"
func filterLT(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.LT(s.C(field), value)
}
// filterRange 在值域之中 BETWEEN操作
// SQL: WHERE "create_time" BETWEEN "2023-10-25" AND "2024-10-25"
// 或者: WHERE "create_time" >= "2023-10-25" AND "create_time" <= "2024-10-25"
func filterRange(s *sql.Selector, _ *sql.Predicate, field, value string) *sql.Predicate {
var values []any
if err := json.Unmarshal([]byte(value), &values); err == nil {
if len(values) != 2 {
return nil
}
return sql.And(
sql.GTE(s.C(field), values[0]),
sql.LTE(s.C(field), values[1]),
)
}
return nil
}
// filterIsNull 为空 IS NULL操作
// SQL: WHERE name IS NULL
func filterIsNull(s *sql.Selector, p *sql.Predicate, field, _ string) *sql.Predicate {
return p.IsNull(s.C(field))
}
// filterIsNotNull 不为空 IS NOT NULL操作
// SQL: WHERE name IS NOT NULL
func filterIsNotNull(s *sql.Selector, p *sql.Predicate, field, _ string) *sql.Predicate {
return p.Not().IsNull(s.C(field))
}
// filterContains LIKE 前后模糊查询
// SQL: WHERE name LIKE '%L%';
func filterContains(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.Contains(s.C(field), value)
}
// filterInsensitiveContains ILIKE 前后模糊查询
// SQL: WHERE name ILIKE '%L%';
func filterInsensitiveContains(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.ContainsFold(s.C(field), value)
}
// filterStartsWith LIKE 前缀+模糊查询
// SQL: WHERE name LIKE 'La%';
func filterStartsWith(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.HasPrefix(s.C(field), value)
}
// filterInsensitiveStartsWith ILIKE 前缀+模糊查询
// SQL: WHERE name ILIKE 'La%';
func filterInsensitiveStartsWith(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.EqualFold(s.C(field), value+"%")
}
// filterEndsWith LIKE 后缀+模糊查询
// SQL: WHERE name LIKE '%a';
func filterEndsWith(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.HasSuffix(s.C(field), value)
}
// filterInsensitiveEndsWith ILIKE 后缀+模糊查询
// SQL: WHERE name ILIKE '%a';
func filterInsensitiveEndsWith(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.EqualFold(s.C(field), "%"+value)
}
// filterExact LIKE 操作 精确比对
// SQL: WHERE name LIKE 'a';
func filterExact(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.Like(s.C(field), value)
}
// filterInsensitiveExact ILIKE 操作 不区分大小写,精确比对
// SQL: WHERE name ILIKE 'a';
func filterInsensitiveExact(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
return p.EqualFold(s.C(field), value)
}
// filterRegex 正则查找
// MySQL: WHERE title REGEXP BINARY '^(An?|The) +'
// Oracle: WHERE REGEXP_LIKE(title, '^(An?|The) +', 'c');
// PostgreSQL: WHERE title ~ '^(An?|The) +';
// SQLite: WHERE title REGEXP '^(An?|The) +';
func filterRegex(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
p.Append(func(b *sql.Builder) {
switch s.Builder.Dialect() {
case dialect.Postgres:
b.Ident(s.C(field)).WriteString(" ~ ")
b.Arg(value)
break
case dialect.MySQL:
b.Ident(s.C(field)).WriteString(" REGEXP BINARY ")
b.Arg(value)
break
case dialect.SQLite:
b.Ident(s.C(field)).WriteString(" REGEXP ")
b.Arg(value)
break
case dialect.Gremlin:
break
}
})
return p
}
// filterInsensitiveRegex 正则查找 不区分大小写
// MySQL: WHERE title REGEXP '^(an?|the) +'
// Oracle: WHERE REGEXP_LIKE(title, '^(an?|the) +', 'i');
// PostgreSQL: WHERE title ~* '^(an?|the) +';
// SQLite: WHERE title REGEXP '(?i)^(an?|the) +';
func filterInsensitiveRegex(s *sql.Selector, p *sql.Predicate, field, value string) *sql.Predicate {
p.Append(func(b *sql.Builder) {
switch s.Builder.Dialect() {
case dialect.Postgres:
b.Ident(s.C(field)).WriteString(" ~* ")
b.Arg(strings.ToLower(value))
break
case dialect.MySQL:
b.Ident(s.C(field)).WriteString(" REGEXP ")
b.Arg(strings.ToLower(value))
break
case dialect.SQLite:
b.Ident(s.C(field)).WriteString(" REGEXP ")
if !strings.HasPrefix(value, "(?i)") {
value = "(?i)" + value
}
b.Arg(strings.ToLower(value))
break
case dialect.Gremlin:
break
}
})
return p
}
// filterSearch 全文搜索
// SQL:
func filterSearch(s *sql.Selector, p *sql.Predicate, _, _ string) *sql.Predicate {
p.Append(func(b *sql.Builder) {
switch s.Builder.Dialect() {
}
})
return p
}
// filterDatePart 时间戳提取日期
// SQL: select extract(quarter from timestamp '2018-08-15 12:10:10');
func filterDatePart(s *sql.Selector, p *sql.Predicate, datePart, field string) *sql.Predicate {
p.Append(func(b *sql.Builder) {
switch s.Builder.Dialect() {
case dialect.Postgres:
str := fmt.Sprintf("EXTRACT('%s' FROM %s)", strings.ToUpper(datePart), s.C(field))
b.WriteString(str)
//b.Arg(strings.ToLower(value))
break
case dialect.MySQL:
str := fmt.Sprintf("%s(%s)", strings.ToUpper(datePart), s.C(field))
b.WriteString(str)
//b.Arg(strings.ToLower(value))
break
}
})
return p
}
// filterDatePartField 日期
func filterDatePartField(s *sql.Selector, datePart, field string) string {
p := sql.P()
switch s.Builder.Dialect() {
case dialect.Postgres:
str := fmt.Sprintf("EXTRACT('%s' FROM %s)", strings.ToUpper(datePart), s.C(field))
p.WriteString(str)
break
case dialect.MySQL:
str := fmt.Sprintf("%s(%s)", strings.ToUpper(datePart), s.C(field))
p.WriteString(str)
break
}
return p.String()
}
// filterJsonb 提取JSONB字段
// Postgresql: WHERE ("app_profile"."preferences" ->> 'daily_email') = 'true'
func filterJsonb(s *sql.Selector, p *sql.Predicate, jsonbField, field string) *sql.Predicate {
field = stringcase.ToSnakeCase(field)
p.Append(func(b *sql.Builder) {
switch s.Builder.Dialect() {
case dialect.Postgres:
b.Ident(s.C(field)).WriteString(" ->> ").WriteString("'" + jsonbField + "'")
//b.Arg(strings.ToLower(value))
break
case dialect.MySQL:
str := fmt.Sprintf("JSON_EXTRACT(%s, '$.%s')", s.C(field), jsonbField)
b.WriteString(str)
//b.Arg(strings.ToLower(value))
break
}
})
return p
}
// filterJsonbField JSONB字段
func filterJsonbField(s *sql.Selector, jsonbField, field string) string {
field = stringcase.ToSnakeCase(field)
p := sql.P()
switch s.Builder.Dialect() {
case dialect.Postgres:
p.Ident(s.C(field)).WriteString(" ->> ").WriteString("'" + jsonbField + "'")
//b.Arg(strings.ToLower(value))
break
case dialect.MySQL:
str := fmt.Sprintf("JSON_EXTRACT(%s, '$.%s')", s.C(field), jsonbField)
p.WriteString(str)
//b.Arg(strings.ToLower(value))
break
}
return p.String()
}