From 29a8782662a753557664b5b61b2ce596043cd861 Mon Sep 17 00:00:00 2001 From: Bobo Date: Sun, 29 Jun 2025 09:29:47 +0800 Subject: [PATCH] feat: database. --- database/clickhouse/README.md | 18 ++ database/clickhouse/batch.go | 198 +++++++++++++++++ database/clickhouse/client.go | 345 +++++++++++++++++++++++++++-- database/clickhouse/client_test.go | 256 +++++++++++++++++++++ database/clickhouse/errors.go | 80 +++++++ database/clickhouse/go.mod | 17 +- database/clickhouse/go.sum | 34 ++- database/clickhouse/query.go | 246 ++++++++++++++++++++ database/clickhouse/query_test.go | 120 ++++++++++ database/elasticsearch/README.md | 39 ---- database/elasticsearch/client.go | 10 +- database/elasticsearch/go.mod | 5 +- database/elasticsearch/go.sum | 6 +- database/influxdb/client.go | 36 ++- database/influxdb/go.mod | 2 +- database/influxdb/utils.go | 37 ++++ database/influxdb/utils_test.go | 77 +++++++ database/mongodb/consts.go | 19 +- database/mongodb/go.mod | 2 +- database/mongodb/query.go | 16 +- database/mongodb/query_test.go | 19 ++ tag.bat | 8 +- 22 files changed, 1481 insertions(+), 109 deletions(-) create mode 100644 database/clickhouse/README.md create mode 100644 database/clickhouse/batch.go create mode 100644 database/clickhouse/client_test.go create mode 100644 database/clickhouse/errors.go create mode 100644 database/clickhouse/query.go create mode 100644 database/clickhouse/query_test.go diff --git a/database/clickhouse/README.md b/database/clickhouse/README.md new file mode 100644 index 0000000..7d2bc21 --- /dev/null +++ b/database/clickhouse/README.md @@ -0,0 +1,18 @@ +# ClickHouse + +## Docker部署 + +```bash +docker pull bitnami/clickhouse:latest + +docker run -itd \ + --name clickhouse-server \ + --network=app-tier \ + -p 8123:8123 \ + -p 9000:9000 \ + -p 9004:9004 \ + -e ALLOW_EMPTY_PASSWORD=no \ + -e CLICKHOUSE_ADMIN_USER=default \ + -e CLICKHOUSE_ADMIN_PASSWORD=123456 \ + bitnami/clickhouse:latest +``` diff --git a/database/clickhouse/batch.go b/database/clickhouse/batch.go new file mode 100644 index 0000000..226f3bc --- /dev/null +++ b/database/clickhouse/batch.go @@ -0,0 +1,198 @@ +package clickhouse + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "sync" + + clickhouseV2 "github.com/ClickHouse/clickhouse-go/v2" + driverV2 "github.com/ClickHouse/clickhouse-go/v2/lib/driver" +) + +// BatchInserter 批量插入器 +type BatchInserter struct { + conn clickhouseV2.Conn + tableName string + columns []string + batchSize int + rows []interface{} + insertStmt string + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +// NewBatchInserter 创建新的批量插入器 +func NewBatchInserter( + ctx context.Context, + conn clickhouseV2.Conn, + tableName string, + batchSize int, + columns []string, +) (*BatchInserter, error) { + if batchSize <= 0 { + batchSize = 1000 // 默认批量大小 + } + + if len(columns) == 0 { + return nil, errors.New("必须指定列名") + } + + // 构建INSERT语句 + placeholders := make([]string, len(columns)) + for i := range placeholders { + placeholders[i] = "?" 
+ } + + insertStmt := fmt.Sprintf( + "INSERT INTO %s (%s) VALUES (%s)", + tableName, + strings.Join(columns, ", "), + strings.Join(placeholders, ", "), + ) + + ctx, cancel := context.WithCancel(ctx) + + return &BatchInserter{ + conn: conn, + tableName: tableName, + columns: columns, + batchSize: batchSize, + rows: make([]interface{}, 0, batchSize), + insertStmt: insertStmt, + ctx: ctx, + cancel: cancel, + }, nil +} + +// Add 添加数据行 +func (bi *BatchInserter) Add(row interface{}) error { + bi.mu.Lock() + defer bi.mu.Unlock() + + // 检查上下文是否已取消 + if bi.ctx.Err() != nil { + return bi.ctx.Err() + } + + bi.rows = append(bi.rows, row) + + // 达到批量大小时自动提交 + if len(bi.rows) >= bi.batchSize { + return bi.flush() + } + + return nil +} + +// Flush 强制提交当前批次 +func (bi *BatchInserter) Flush() error { + bi.mu.Lock() + defer bi.mu.Unlock() + + return bi.flush() +} + +// Close 关闭插入器并提交剩余数据 +func (bi *BatchInserter) Close() error { + defer bi.cancel() + + bi.mu.Lock() + defer bi.mu.Unlock() + + return bi.flush() +} + +// flush 内部提交方法 +func (bi *BatchInserter) flush() error { + if len(bi.rows) == 0 { + return nil + } + + // 创建批量 + batch, err := bi.conn.PrepareBatch(bi.ctx, bi.insertStmt) + if err != nil { + return ErrBatchPrepareFailed + } + + // 添加所有行 + for _, row := range bi.rows { + // 使用反射获取字段值 + if err = appendStructToBatch(batch, row, bi.columns); err != nil { + return ErrBatchAppendFailed + } + } + + // 提交批量 + if err = batch.Send(); err != nil { + return ErrBatchSendFailed + } + + // 清空批次 + bi.rows = bi.rows[:0] + return nil +} + +// appendStructToBatch 使用反射将结构体字段添加到批次 +func appendStructToBatch(batch driverV2.Batch, obj interface{}, columns []string) error { + v := reflect.ValueOf(obj) + + // 如果是指针,获取指针指向的值 + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return errors.New("nil指针") + } + v = v.Elem() + } + + // 必须是结构体 + if v.Kind() != reflect.Struct { + return fmt.Errorf("期望结构体类型,得到 %v", v.Kind()) + } + + // 获取结构体类型 + t := v.Type() + + // 准备参数值 + values := make([]interface{}, len(columns)) + + // 映射列名到结构体字段 + for i, col := range columns { + // 查找匹配的字段 + found := false + for j := 0; j < v.NumField(); j++ { + field := t.Field(j) + + // 检查ch标签 + if tag := field.Tag.Get("ch"); tag == col { + values[i] = v.Field(j).Interface() + found = true + break + } + + // 检查json标签 + if tag := field.Tag.Get("json"); tag == col { + values[i] = v.Field(j).Interface() + found = true + break + } + + // 检查字段名 + if field.Name == col { + values[i] = v.Field(j).Interface() + found = true + break + } + } + + if !found { + return fmt.Errorf("未找到列 %s 对应的结构体字段", col) + } + } + + // 添加到批次 + return batch.Append(values...) 
+}
diff --git a/database/clickhouse/client.go b/database/clickhouse/client.go
index 61f0f04..ee09c06 100644
--- a/database/clickhouse/client.go
+++ b/database/clickhouse/client.go
@@ -1,9 +1,13 @@
 package clickhouse
 
 import (
+	"context"
 	"crypto/tls"
+	"database/sql"
+	"net/url"
 
-	"github.com/ClickHouse/clickhouse-go/v2"
+	clickhouseV2 "github.com/ClickHouse/clickhouse-go/v2"
+	driverV2 "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
 
 	"github.com/go-kratos/kratos/v2/log"
 
@@ -11,45 +15,340 @@ import (
 	"github.com/tx7do/kratos-bootstrap/utils"
 )
 
-func NewClickHouseClient(cfg *conf.Bootstrap, l *log.Helper) clickhouse.Conn {
+type Creator func() any
+
+var compressionMap = map[string]clickhouseV2.CompressionMethod{
+	"none":    clickhouseV2.CompressionNone,
+	"zstd":    clickhouseV2.CompressionZSTD,
+	"lz4":     clickhouseV2.CompressionLZ4,
+	"lz4hc":   clickhouseV2.CompressionLZ4HC,
+	"gzip":    clickhouseV2.CompressionGZIP,
+	"deflate": clickhouseV2.CompressionDeflate,
+	"br":      clickhouseV2.CompressionBrotli,
+}
+
+type Client struct {
+	log *log.Helper
+
+	conn clickhouseV2.Conn
+	db   *sql.DB
+}
+
+func NewClient(logger log.Logger, cfg *conf.Bootstrap) (*Client, error) {
+	c := &Client{
+		log: log.NewHelper(log.With(logger, "module", "clickhouse-client")),
+	}
+
+	if err := c.createClickHouseClient(cfg); err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
+
+// createClickHouseClient 创建ClickHouse客户端
+func (c *Client) createClickHouseClient(cfg *conf.Bootstrap) error {
 	if cfg.Data == nil || cfg.Data.Clickhouse == nil {
-		l.Warn("ClickHouse config is nil")
+		c.log.Warn("ClickHouse config is nil")
 		return nil
 	}
 
-	options := &clickhouse.Options{
-		Addr: []string{cfg.Data.Clickhouse.Address},
-		Auth: clickhouse.Auth{
-			Database: cfg.Data.Clickhouse.Database,
-			Username: cfg.Data.Clickhouse.Username,
-			Password: cfg.Data.Clickhouse.Password,
-		},
-		Debug:           cfg.Data.Clickhouse.Debug,
-		DialTimeout:     cfg.Data.Clickhouse.DialTimeout.AsDuration(),
-		MaxOpenConns:    int(cfg.Data.Clickhouse.MaxOpenConns),
-		MaxIdleConns:    int(cfg.Data.Clickhouse.MaxIdleConns),
-		ConnMaxLifetime: cfg.Data.Clickhouse.ConnMaxLifeTime.AsDuration(),
+	opts := &clickhouseV2.Options{}
+
+	if cfg.Data.Clickhouse.Dsn != nil {
+		tmp, err := clickhouseV2.ParseDSN(cfg.Data.Clickhouse.GetDsn())
+		if err != nil {
+			c.log.Errorf("failed to parse clickhouse DSN: %v", err)
+			return ErrInvalidDSN
+		}
+		opts = tmp
+	}
+
+	if cfg.Data.Clickhouse.Addresses != nil {
+		opts.Addr = cfg.Data.Clickhouse.GetAddresses()
+	}
+
+	if cfg.Data.Clickhouse.Database != nil ||
+		cfg.Data.Clickhouse.Username != nil ||
+		cfg.Data.Clickhouse.Password != nil {
+		opts.Auth = clickhouseV2.Auth{}
+
+		if cfg.Data.Clickhouse.Database != nil {
+			opts.Auth.Database = cfg.Data.Clickhouse.GetDatabase()
+		}
+		if cfg.Data.Clickhouse.Username != nil {
+			opts.Auth.Username = cfg.Data.Clickhouse.GetUsername()
+		}
+		if cfg.Data.Clickhouse.Password != nil {
+			opts.Auth.Password = cfg.Data.Clickhouse.GetPassword()
+		}
+	}
+
+	if cfg.Data.Clickhouse.Debug != nil {
+		opts.Debug = cfg.Data.Clickhouse.GetDebug()
+	}
+
+	if cfg.Data.Clickhouse.MaxOpenConns != nil {
+		opts.MaxOpenConns = int(cfg.Data.Clickhouse.GetMaxOpenConns())
+	}
+	if cfg.Data.Clickhouse.MaxIdleConns != nil {
+		opts.MaxIdleConns = int(cfg.Data.Clickhouse.GetMaxIdleConns())
 	}
 
-	// 设置ssl
 	if cfg.Data.Clickhouse.Tls != nil {
 		var tlsCfg *tls.Config
 		var err error
 		if tlsCfg, err = utils.LoadServerTlsConfig(cfg.Data.Clickhouse.Tls); err != nil {
 			panic(err)
 		}
 		if tlsCfg != nil {
-			options.TLS = 
tlsCfg + opts.TLS = tlsCfg } } - conn, err := clickhouse.Open(options) - if err != nil { - l.Fatalf("failed opening connection to clickhouse: %v", err) - return nil + if cfg.Data.Clickhouse.CompressionMethod != nil || cfg.Data.Clickhouse.CompressionLevel != nil { + opts.Compression = &clickhouseV2.Compression{} + + if cfg.Data.Clickhouse.GetCompressionMethod() != "" { + opts.Compression.Method = compressionMap[cfg.Data.Clickhouse.GetCompressionMethod()] + } + if cfg.Data.Clickhouse.CompressionLevel != nil { + opts.Compression.Level = int(cfg.Data.Clickhouse.GetCompressionLevel()) + } + } + if cfg.Data.Clickhouse.MaxCompressionBuffer != nil { + opts.MaxCompressionBuffer = int(cfg.Data.Clickhouse.GetMaxCompressionBuffer()) } - return conn + if cfg.Data.Clickhouse.DialTimeout != nil { + opts.DialTimeout = cfg.Data.Clickhouse.GetDialTimeout().AsDuration() + } + if cfg.Data.Clickhouse.ReadTimeout != nil { + opts.ReadTimeout = cfg.Data.Clickhouse.GetReadTimeout().AsDuration() + } + if cfg.Data.Clickhouse.ConnMaxLifetime != nil { + opts.ConnMaxLifetime = cfg.Data.Clickhouse.GetConnMaxLifetime().AsDuration() + } + + if cfg.Data.Clickhouse.HttpProxy != nil { + proxyURL, err := url.Parse(cfg.Data.Clickhouse.GetHttpProxy()) + if err != nil { + c.log.Errorf("failed to parse HTTP proxy URL: %v", err) + return ErrInvalidProxyURL + } + + opts.HTTPProxyURL = proxyURL + } + + if cfg.Data.Clickhouse.ConnectionOpenStrategy != nil { + strategy := clickhouseV2.ConnOpenInOrder + switch cfg.Data.Clickhouse.GetConnectionOpenStrategy() { + case "in_order": + strategy = clickhouseV2.ConnOpenInOrder + case "round_robin": + strategy = clickhouseV2.ConnOpenRoundRobin + case "random": + strategy = clickhouseV2.ConnOpenRandom + } + opts.ConnOpenStrategy = strategy + } + + if cfg.Data.Clickhouse.Scheme != nil { + switch cfg.Data.Clickhouse.GetScheme() { + case "http": + opts.Protocol = clickhouseV2.HTTP + case "https": + opts.Protocol = clickhouseV2.HTTP + default: + opts.Protocol = clickhouseV2.Native + } + } + + if cfg.Data.Clickhouse.BlockBufferSize != nil { + opts.BlockBufferSize = uint8(cfg.Data.Clickhouse.GetBlockBufferSize()) + } + + // 创建ClickHouse连接 + conn, err := clickhouseV2.Open(opts) + if err != nil { + c.log.Errorf("failed to create clickhouse client: %v", err) + return ErrConnectionFailed + } + + c.conn = conn + + return nil +} + +// Close 关闭ClickHouse客户端连接 +func (c *Client) Close() { + if c.conn == nil { + c.log.Warn("clickhouse client is already closed or not initialized") + return + } + + if err := c.conn.Close(); err != nil { + c.log.Errorf("failed to close clickhouse client: %v", err) + } else { + c.log.Info("clickhouse client closed successfully") + } +} + +// GetServerVersion 获取ClickHouse服务器版本 +func (c *Client) GetServerVersion() string { + if c.conn == nil { + c.log.Error("clickhouse client is not initialized") + return "" + } + + version, err := c.conn.ServerVersion() + if err != nil { + c.log.Errorf("failed to get server version: %v", err) + return "" + } else { + c.log.Infof("ClickHouse server version: %s", version) + return version.String() + } +} + +// CheckConnection 检查ClickHouse客户端连接是否正常 +func (c *Client) CheckConnection(ctx context.Context) error { + if c.conn == nil { + c.log.Error("clickhouse client is not initialized") + return ErrClientNotInitialized + } + + if err := c.conn.Ping(ctx); err != nil { + c.log.Errorf("ping failed: %v", err) + return ErrPingFailed + } + + c.log.Info("clickhouse client connection is healthy") + return nil +} + +// Query 执行查询并返回结果 +func (c *Client) Query(ctx 
context.Context, creator Creator, results *[]any, query string, args ...interface{}) error {
+	if c.conn == nil {
+		c.log.Error("clickhouse client is not initialized")
+		return ErrClientNotInitialized
+	}
+	if creator == nil {
+		c.log.Error("creator function cannot be nil")
+		return ErrCreatorFunctionNil
+	}
+
+	rows, err := c.conn.Query(ctx, query, args...)
+	if err != nil {
+		c.log.Errorf("query failed: %v", err)
+		return ErrQueryExecutionFailed
+	}
+	defer func(rows driverV2.Rows) {
+		if err = rows.Close(); err != nil {
+			c.log.Errorf("failed to close rows: %v", err)
+		}
+	}(rows)
+
+	for rows.Next() {
+		row := creator()
+		if err = rows.ScanStruct(row); err != nil {
+			c.log.Errorf("failed to scan row: %v", err)
+			return ErrRowScanFailed
+		}
+		*results = append(*results, row)
+	}
+
+	// 检查是否有未处理的错误
+	if rows.Err() != nil {
+		c.log.Errorf("Rows iteration error: %v", rows.Err())
+		return ErrRowsIterationError
+	}
+
+	return nil
+}
+
+// QueryRow 执行查询并返回单行结果
+func (c *Client) QueryRow(ctx context.Context, dest any, query string, args ...interface{}) error {
+	if c.conn == nil {
+		c.log.Error("clickhouse client is not initialized")
+		return ErrClientNotInitialized
+	}
+
+	row := c.conn.QueryRow(ctx, query, args...)
+	if row == nil {
+		c.log.Error("query row returned nil")
+		return ErrRowNotFound
+	}
+
+	if err := row.ScanStruct(dest); err != nil {
+		c.log.Errorf("failed to scan row: %v", err)
+		return ErrRowScanFailed
+	}
+
+	return nil
+}
+
+// Select 封装 SELECT 子句
+func (c *Client) Select(ctx context.Context, dest any, query string, args ...interface{}) error {
+	if c.conn == nil {
+		c.log.Error("clickhouse client is not initialized")
+		return ErrClientNotInitialized
+	}
+
+	err := c.conn.Select(ctx, dest, query, args...)
+	if err != nil {
+		c.log.Errorf("select failed: %v", err)
+		return ErrQueryExecutionFailed
+	}
+
+	return nil
+}
+
+// Exec 执行非查询语句
+func (c *Client) Exec(ctx context.Context, query string, args ...interface{}) error {
+	if c.conn == nil {
+		c.log.Error("clickhouse client is not initialized")
+		return ErrClientNotInitialized
+	}
+
+	if err := c.conn.Exec(ctx, query, args...); err != nil {
+		c.log.Errorf("exec failed: %v", err)
+		return ErrExecutionFailed
+	}
+
+	return nil
+}
+
+// AsyncInsert 异步插入数据
+func (c *Client) AsyncInsert(ctx context.Context, query string, wait bool, args ...interface{}) error {
+	if c.conn == nil {
+		c.log.Error("clickhouse client is not initialized")
+		return ErrClientNotInitialized
+	}
+
+	if err := c.conn.AsyncInsert(ctx, query, wait, args...); err != nil {
+		c.log.Errorf("async insert failed: %v", err)
+		return ErrAsyncInsertFailed
+	}
+
+	return nil
+}
+
+// BatchInsert 批量插入数据
+func (c *Client) BatchInsert(ctx context.Context, query string, data [][]interface{}) error {
+	if c.conn == nil {
+		c.log.Error("clickhouse client is not initialized")
+		return ErrClientNotInitialized
+	}
+
+	batch, err := c.conn.PrepareBatch(ctx, query)
+	if err != nil {
+		c.log.Errorf("failed to prepare batch: %v", err)
+		return ErrBatchPrepareFailed
+	}
+
+	for _, row := range data {
+		if err := batch.Append(row...); err != nil {
+			c.log.Errorf("failed to append data: %v", err)
+			return ErrBatchAppendFailed
+		}
+	}
+
+	if err = batch.Send(); err != nil {
+		c.log.Errorf("failed to send batch: %v", err)
+		return ErrBatchSendFailed
+	}
+
+	return nil
+}
diff --git a/database/clickhouse/client_test.go b/database/clickhouse/client_test.go
new file mode 100644
index 0000000..e90f6f4
--- /dev/null
+++ b/database/clickhouse/client_test.go
@@ -0,0 +1,256 @@
+package clickhouse
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/go-kratos/kratos/v2/log"
+	"github.com/stretchr/testify/assert"
+	conf "github.com/tx7do/kratos-bootstrap/api/gen/go/conf/v1"
+)
+
+type Candle struct {
+	Symbol    string    `json:"symbol" ch:"symbol"`
+	
Open float64 `json:"open" ch:"open"` + High float64 `json:"high" ch:"high"` + Low float64 `json:"low" ch:"low"` + Close float64 `json:"close" ch:"close"` + Volume float64 `json:"volume" ch:"volume"` + Timestamp time.Time `json:"timestamp" ch:"timestamp"` +} + +func createTestClient() *Client { + database := "finances" + username := "default" + password := "*Abcd123456" + cli, _ := NewClient( + log.DefaultLogger, + &conf.Bootstrap{ + Data: &conf.Data{ + Clickhouse: &conf.Data_ClickHouse{ + Addresses: []string{"localhost:9000"}, + Database: &database, + Username: &username, + Password: &password, + }, + }, + }, + ) + return cli +} + +func createCandlesTable(client *Client) { + // 创建表的 SQL 语句 + createTableQuery := ` + CREATE TABLE IF NOT EXISTS candles ( + timestamp DateTime, + symbol String, + open Float64, + high Float64, + low Float64, + close Float64, + volume Float64 + ) ENGINE = MergeTree() + ORDER BY timestamp + ` + err := client.Exec(context.Background(), createTableQuery) + if err != nil { + log.Errorf("Failed to create candles table: %v", err) + return + } +} + +func TestNewClient(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + // 测试 CheckConnection + err := client.CheckConnection(context.Background()) + assert.NoError(t, err, "CheckConnection 应该成功执行") + + // 测试 GetServerVersion + version := client.GetServerVersion() + assert.NotEmpty(t, version, "GetServerVersion 应该返回非空值") + + createCandlesTable(client) +} + +func TestAsyncInsert(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + // 测试异步插入 + err := client.AsyncInsert(context.Background(), "INSERT INTO test_table (id, name) VALUES (?, ?)", true, 1, "example") + assert.NoError(t, err, "AsyncInsert 应该成功执行") +} + +func TestBatchInsert(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + // 测试数据 + data := [][]interface{}{ + {1, "example1"}, + {2, "example2"}, + {3, "example3"}, + } + + // 测试批量插入 + err := client.BatchInsert(context.Background(), "INSERT INTO test_table (id, name) VALUES (?, ?)", data) + assert.NoError(t, err, "BatchInsert 应该成功执行") +} + +func TestInsertIntoCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 插入数据的 SQL 语句 + insertQuery := ` + INSERT INTO candles (timestamp, symbol, open, high, low, close, volume) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ ` + + // 测试数据 + err := client.AsyncInsert(context.Background(), insertQuery, true, + "2023-10-01 12:00:00", "AAPL", 100.5, 105.0, 99.5, 102.0, 1500.0) + assert.NoError(t, err, "InsertIntoCandlesTable 应该成功执行") +} + +func TestQueryCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 查询数据的 SQL 语句 + query := ` + SELECT timestamp, symbol, open, high, low, close, volume + FROM candles + ` + + // 定义结果集 + var results []any + + // 执行查询 + err := client.Query(context.Background(), func() interface{} { return &Candle{} }, &results, query) + assert.NoError(t, err, "QueryCandlesTable 应该成功执行") + assert.NotEmpty(t, results, "QueryCandlesTable 应该返回结果") +} + +func TestSelectCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 查询数据的 SQL 语句 + query := ` + SELECT timestamp, symbol, open, high, low, close, volume + FROM candles + ` + + // 定义结果集 + var results []Candle + + // 执行查询 + err := client.Select(context.Background(), &results, query) + assert.NoError(t, err, "QueryCandlesTable 应该成功执行") + assert.NotEmpty(t, results, "QueryCandlesTable 应该返回结果") +} + +func TestQueryRow(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 插入测试数据 + insertQuery := ` + INSERT INTO candles (timestamp, symbol, open, high, low, close, volume) + VALUES (?, ?, ?, ?, ?, ?, ?) + ` + err := client.AsyncInsert(context.Background(), insertQuery, true, + "2023-10-01 12:00:00", "AAPL", 100.5, 105.0, 99.5, 102.0, 1500.0) + assert.NoError(t, err, "数据插入失败") + + // 查询单行数据 + query := ` + SELECT timestamp, symbol, open, high, low, close, volume + FROM candles + WHERE symbol = ? + ` + var result Candle + + err = client.QueryRow(context.Background(), &result, query, "AAPL") + assert.NoError(t, err, "QueryRow 应该成功执行") + assert.Equal(t, "AAPL", result.Symbol, "symbol 列值应该为 AAPL") + assert.Equal(t, 100.5, result.Open, "open 列值应该为 100.5") + assert.Equal(t, 1500.0, result.Volume, "volume 列值应该为 1500.0") +} + +func TestDropCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + // 删除表的 SQL 语句 + dropTableQuery := `DROP TABLE IF EXISTS candles` + + // 执行删除表操作 + err := client.Exec(context.Background(), dropTableQuery) + assert.NoError(t, err, "DropCandlesTable 应该成功执行") +} + +func TestAggregateCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 聚合查询的 SQL 语句 + query := ` + SELECT symbol, + MAX(high) AS max_high, + MIN(low) AS min_low, + AVG(close) AS avg_close, + SUM(volume) AS total_volume + FROM candles + GROUP BY symbol + ` + + // 定义结果集 + var results []struct { + Symbol string `ch:"symbol"` + MaxHigh float64 `ch:"max_high"` + MinLow float64 `ch:"min_low"` + AvgClose float64 `ch:"avg_close"` + TotalVolume float64 `ch:"total_volume"` + } + + // 执行查询 + err := client.Select(context.Background(), &results, query) + assert.NoError(t, err, "AggregateCandlesTable 应该成功执行") + assert.NotEmpty(t, results, "AggregateCandlesTable 应该返回结果") +} + +func TestBatchInsertCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 测试数据 + data := [][]interface{}{ + {"2023-10-01 12:00:00", "AAPL", 100.5, 105.0, 99.5, 102.0, 1500.0}, + {"2023-10-01 12:01:00", "GOOG", 200.5, 205.0, 199.5, 202.0, 2500.0}, + {"2023-10-01 12:02:00", "MSFT", 300.5, 305.0, 299.5, 302.0, 3500.0}, + } + + // 批量插入数据 + err := 
client.BatchInsert(context.Background(), ` + INSERT INTO candles (timestamp, symbol, open, high, low, close, volume) + VALUES (?, ?, ?, ?, ?, ?, ?)`, data) + assert.NoError(t, err, "BatchInsertCandlesTable 应该成功执行") +} diff --git a/database/clickhouse/errors.go b/database/clickhouse/errors.go new file mode 100644 index 0000000..ff2dd7c --- /dev/null +++ b/database/clickhouse/errors.go @@ -0,0 +1,80 @@ +package clickhouse + +import "github.com/go-kratos/kratos/v2/errors" + +var ( + // ErrInvalidColumnName is returned when an invalid column name is used. + ErrInvalidColumnName = errors.InternalServer("INVALID_COLUMN_NAME", "invalid column name") + + // ErrInvalidTableName is returned when an invalid table name is used. + ErrInvalidTableName = errors.InternalServer("INVALID_TABLE_NAME", "invalid table name") + + // ErrInvalidCondition is returned when an invalid condition is used in a query. + ErrInvalidCondition = errors.InternalServer("INVALID_CONDITION", "invalid condition in query") + + // ErrQueryExecutionFailed is returned when a query execution fails. + ErrQueryExecutionFailed = errors.InternalServer("QUERY_EXECUTION_FAILED", "query execution failed") + + // ErrExecutionFailed is returned when a general execution fails. + ErrExecutionFailed = errors.InternalServer("EXECUTION_FAILED", "execution failed") + + // ErrAsyncInsertFailed is returned when an asynchronous insert operation fails. + ErrAsyncInsertFailed = errors.InternalServer("ASYNC_INSERT_FAILED", "async insert operation failed") + + // ErrRowScanFailed is returned when scanning rows from a query result fails. + ErrRowScanFailed = errors.InternalServer("ROW_SCAN_FAILED", "row scan failed") + + // ErrRowsIterationError is returned when there is an error iterating over rows. + ErrRowsIterationError = errors.InternalServer("ROWS_ITERATION_ERROR", "rows iteration error") + + // ErrRowNotFound is returned when a specific row is not found in the result set. + ErrRowNotFound = errors.InternalServer("ROW_NOT_FOUND", "row not found") + + // ErrConnectionFailed is returned when the connection to ClickHouse fails. + ErrConnectionFailed = errors.InternalServer("CONNECTION_FAILED", "failed to connect to ClickHouse") + + // ErrDatabaseNotFound is returned when the specified database is not found. + ErrDatabaseNotFound = errors.InternalServer("DATABASE_NOT_FOUND", "specified database not found") + + // ErrTableNotFound is returned when the specified table is not found. + ErrTableNotFound = errors.InternalServer("TABLE_NOT_FOUND", "specified table not found") + + // ErrInsertFailed is returned when an insert operation fails. + ErrInsertFailed = errors.InternalServer("INSERT_FAILED", "insert operation failed") + + // ErrUpdateFailed is returned when an update operation fails. + ErrUpdateFailed = errors.InternalServer("UPDATE_FAILED", "update operation failed") + + // ErrDeleteFailed is returned when a delete operation fails. + ErrDeleteFailed = errors.InternalServer("DELETE_FAILED", "delete operation failed") + + // ErrTransactionFailed is returned when a transaction fails. + ErrTransactionFailed = errors.InternalServer("TRANSACTION_FAILED", "transaction failed") + + // ErrClientNotInitialized is returned when the ClickHouse client is not initialized. + ErrClientNotInitialized = errors.InternalServer("CLIENT_NOT_INITIALIZED", "clickhouse client not initialized") + + // ErrGetServerVersionFailed is returned when getting the server version fails. 
+ ErrGetServerVersionFailed = errors.InternalServer("GET_SERVER_VERSION_FAILED", "failed to get server version") + + // ErrPingFailed is returned when a ping to the ClickHouse server fails. + ErrPingFailed = errors.InternalServer("PING_FAILED", "ping to ClickHouse server failed") + + // ErrCreatorFunctionNil is returned when the creator function is nil. + ErrCreatorFunctionNil = errors.InternalServer("CREATOR_FUNCTION_NIL", "creator function cannot be nil") + + // ErrBatchPrepareFailed is returned when a batch prepare operation fails. + ErrBatchPrepareFailed = errors.InternalServer("BATCH_PREPARE_FAILED", "batch prepare operation failed") + + // ErrBatchSendFailed is returned when a batch send operation fails. + ErrBatchSendFailed = errors.InternalServer("BATCH_SEND_FAILED", "batch send operation failed") + + // ErrBatchAppendFailed is returned when appending to a batch fails. + ErrBatchAppendFailed = errors.InternalServer("BATCH_APPEND_FAILED", "batch append operation failed") + + // ErrInvalidDSN is returned when the data source name (DSN) is invalid. + ErrInvalidDSN = errors.InternalServer("INVALID_DSN", "invalid data source name") + + // ErrInvalidProxyURL is returned when the proxy URL is invalid. + ErrInvalidProxyURL = errors.InternalServer("INVALID_PROXY_URL", "invalid proxy URL") +) diff --git a/database/clickhouse/go.mod b/database/clickhouse/go.mod index b5d4aed..00e6788 100644 --- a/database/clickhouse/go.mod +++ b/database/clickhouse/go.mod @@ -7,27 +7,32 @@ toolchain go1.23.3 replace github.com/tx7do/kratos-bootstrap/api => ../../api require ( - github.com/ClickHouse/clickhouse-go/v2 v2.35.0 + github.com/ClickHouse/clickhouse-go/v2 v2.37.2 github.com/go-kratos/kratos/v2 v2.8.4 - github.com/tx7do/kratos-bootstrap/api v0.0.21 + github.com/stretchr/testify v1.10.0 + github.com/tx7do/kratos-bootstrap/api v0.0.27 github.com/tx7do/kratos-bootstrap/utils v0.1.3 ) require ( - github.com/ClickHouse/ch-go v0.66.0 // indirect - github.com/andybalholm/brotli v1.1.1 // indirect + github.com/ClickHouse/ch-go v0.66.1 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/paulmach/orb v0.11.1 // indirect github.com/pierrec/lz4/v4 v4.1.22 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect - go.opentelemetry.io/otel v1.36.0 // indirect - go.opentelemetry.io/otel/trace v1.36.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect golang.org/x/sys v0.33.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect + google.golang.org/grpc v1.73.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/database/clickhouse/go.sum b/database/clickhouse/go.sum index 43e4ec0..f9d0a2e 100644 --- a/database/clickhouse/go.sum +++ b/database/clickhouse/go.sum @@ -1,9 +1,9 @@ -github.com/ClickHouse/ch-go v0.66.0 h1:hLslxxAVb2PHpbHr4n0d6aP8CEIpUYGMVT1Yj/Q5Img= -github.com/ClickHouse/ch-go v0.66.0/go.mod h1:noiHWyLMJAZ5wYuq3R/K0TcRhrNA8h7o1AqHX0klEhM= -github.com/ClickHouse/clickhouse-go/v2 v2.35.0 h1:ZMLZqxu+NiW55f4JS32kzyEbMb7CthGn3ziCcULOvSE= 
-github.com/ClickHouse/clickhouse-go/v2 v2.35.0/go.mod h1:O2FFT/rugdpGEW2VKyEGyMUWyQU0ahmenY9/emxLPxs= -github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= -github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/ClickHouse/ch-go v0.66.1 h1:LQHFslfVYZsISOY0dnOYOXGkOUvpv376CCm8g7W74A4= +github.com/ClickHouse/ch-go v0.66.1/go.mod h1:NEYcg3aOFv2EmTJfo4m2WF7sHB/YFbLUuIWv9iq76xY= +github.com/ClickHouse/clickhouse-go/v2 v2.37.2 h1:wRLNKoynvHQEN4znnVHNLaYnrqVc9sGJmGYg+GGCfto= +github.com/ClickHouse/clickhouse-go/v2 v2.37.2/go.mod h1:pH2zrBGp5Y438DMwAxXMm1neSXPPjSI7tD4MURVULw8= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -15,6 +15,8 @@ github.com/go-kratos/kratos/v2 v2.8.4 h1:eIJLE9Qq9WSoKx+Buy2uPyrahtF/lPh+Xf4MTpx github.com/go-kratos/kratos/v2 v2.8.4/go.mod h1:mq62W2101a5uYyRxe+7IdWubu7gZCGYqSNKwGFiiRcw= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -65,10 +67,10 @@ github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7Jul github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= -go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= -go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= -go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= -go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -80,12 +82,14 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net 
v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -99,6 +103,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -107,6 +113,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= diff --git a/database/clickhouse/query.go b/database/clickhouse/query.go new file mode 100644 index 0000000..eed53f7 --- /dev/null +++ b/database/clickhouse/query.go @@ -0,0 +1,246 @@ 
+package clickhouse + +import ( + "fmt" + "regexp" + "strings" + + "github.com/go-kratos/kratos/v2/log" +) + +type QueryBuilder struct { + table string + columns []string + distinct bool + conditions []string + orderBy []string + groupBy []string + having []string + joins []string + with []string + union []string + limit int + offset int + params []interface{} // 用于存储参数 + useIndex string // 索引提示 + cacheResult bool // 是否缓存查询结果 + debug bool // 是否启用调试 + log *log.Helper +} + +// NewQueryBuilder 创建一个新的 QueryBuilder 实例 +func NewQueryBuilder(table string, log *log.Helper) *QueryBuilder { + return &QueryBuilder{ + log: log, + table: table, + params: []interface{}{}, + } +} + +// EnableDebug 启用调试模式 +func (qb *QueryBuilder) EnableDebug() *QueryBuilder { + qb.debug = true + return qb +} + +// logDebug 打印调试信息 +func (qb *QueryBuilder) logDebug(message string) { + if qb.debug { + qb.log.Debug("[QueryBuilder Debug]:", message) + } +} + +// Select 设置查询的列 +func (qb *QueryBuilder) Select(columns ...string) *QueryBuilder { + for _, column := range columns { + if !isValidIdentifier(column) { + panic("Invalid column name") + } + } + + qb.columns = columns + return qb +} + +// Distinct 设置 DISTINCT 查询 +func (qb *QueryBuilder) Distinct() *QueryBuilder { + qb.distinct = true + return qb +} + +// Where 添加查询条件并支持参数化 +func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuilder { + if !isValidCondition(condition) { + panic("Invalid condition") + } + + qb.conditions = append(qb.conditions, condition) + qb.params = append(qb.params, args...) + return qb +} + +// OrderBy 设置排序条件 +func (qb *QueryBuilder) OrderBy(order string) *QueryBuilder { + qb.orderBy = append(qb.orderBy, order) + return qb +} + +// GroupBy 设置分组条件 +func (qb *QueryBuilder) GroupBy(columns ...string) *QueryBuilder { + qb.groupBy = append(qb.groupBy, columns...) + return qb +} + +// Having 添加分组后的过滤条件并支持参数化 +func (qb *QueryBuilder) Having(condition string, args ...interface{}) *QueryBuilder { + qb.having = append(qb.having, condition) + qb.params = append(qb.params, args...) 
+ return qb +} + +// Join 添加 JOIN 操作 +func (qb *QueryBuilder) Join(joinType, table, onCondition string) *QueryBuilder { + join := fmt.Sprintf("%s JOIN %s ON %s", joinType, table, onCondition) + qb.joins = append(qb.joins, join) + return qb +} + +// With 添加 WITH 子句 +func (qb *QueryBuilder) With(expression string) *QueryBuilder { + qb.with = append(qb.with, expression) + return qb +} + +// Union 添加 UNION 操作 +func (qb *QueryBuilder) Union(query string) *QueryBuilder { + qb.union = append(qb.union, query) + return qb +} + +// Limit 设置查询结果的限制数量 +func (qb *QueryBuilder) Limit(limit int) *QueryBuilder { + qb.limit = limit + return qb +} + +// Offset 设置查询结果的偏移量 +func (qb *QueryBuilder) Offset(offset int) *QueryBuilder { + qb.offset = offset + return qb +} + +// UseIndex 设置索引提示 +func (qb *QueryBuilder) UseIndex(index string) *QueryBuilder { + qb.useIndex = index + return qb +} + +// CacheResult 启用查询结果缓存 +func (qb *QueryBuilder) CacheResult() *QueryBuilder { + qb.cacheResult = true + return qb +} + +// ArrayJoin 添加 ARRAY JOIN 子句 +func (qb *QueryBuilder) ArrayJoin(expression string) *QueryBuilder { + qb.joins = append(qb.joins, fmt.Sprintf("ARRAY JOIN %s", expression)) + return qb +} + +// Final 添加 FINAL 修饰符 +func (qb *QueryBuilder) Final() *QueryBuilder { + qb.table = fmt.Sprintf("%s FINAL", qb.table) + return qb +} + +// Sample 添加 SAMPLE 子句 +func (qb *QueryBuilder) Sample(sampleRate float64) *QueryBuilder { + qb.table = fmt.Sprintf("%s SAMPLE %f", qb.table, sampleRate) + return qb +} + +// LimitBy 添加 LIMIT BY 子句 +func (qb *QueryBuilder) LimitBy(limit int, columns ...string) *QueryBuilder { + qb.limit = limit + qb.orderBy = append(qb.orderBy, fmt.Sprintf("LIMIT BY %d (%s)", limit, strings.Join(columns, ", "))) + return qb +} + +// PreWhere 添加 PREWHERE 子句 +func (qb *QueryBuilder) PreWhere(condition string, args ...interface{}) *QueryBuilder { + qb.conditions = append([]string{condition}, qb.conditions...) + qb.params = append(args, qb.params...) 
+ return qb +} + +// Format 添加 FORMAT 子句 +func (qb *QueryBuilder) Format(format string) *QueryBuilder { + qb.union = append(qb.union, fmt.Sprintf("FORMAT %s", format)) + return qb +} + +// Build 构建最终的 SQL 查询 +func (qb *QueryBuilder) Build() (string, []interface{}) { + query := "" + + if qb.cacheResult { + query += "/* CACHE */ " + } + + query += "SELECT " + if qb.distinct { + query += "DISTINCT " + } + query += qb.buildColumns() + query += fmt.Sprintf(" FROM %s", qb.table) + + if qb.useIndex != "" { + query += fmt.Sprintf(" USE INDEX (%s)", qb.useIndex) + } + + if len(qb.conditions) > 0 { + query += fmt.Sprintf(" WHERE %s", strings.Join(qb.conditions, " AND ")) + } + + if len(qb.groupBy) > 0 { + query += fmt.Sprintf(" GROUP BY %s", strings.Join(qb.groupBy, ", ")) + } + + if len(qb.having) > 0 { + query += fmt.Sprintf(" HAVING %s", strings.Join(qb.having, " AND ")) + } + + if len(qb.orderBy) > 0 { + query += fmt.Sprintf(" ORDER BY %s", strings.Join(qb.orderBy, ", ")) + } + + if qb.limit > 0 { + query += fmt.Sprintf(" LIMIT %d", qb.limit) + } + + if qb.offset > 0 { + query += fmt.Sprintf(" OFFSET %d", qb.offset) + } + + return query, qb.params +} + +func (qb *QueryBuilder) buildColumns() string { + if len(qb.columns) == 0 { + return "*" + } + return strings.Join(qb.columns, ", ") +} + +// isValidIdentifier 验证表名或列名是否合法 +func isValidIdentifier(identifier string) bool { + // 仅允许字母、数字、下划线,且不能以数字开头 + matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, identifier) + return matched +} + +// isValidCondition 验证条件语句是否合法 +func isValidCondition(condition string) bool { + // 简单验证条件中是否包含危险字符 + return !strings.Contains(condition, ";") && !strings.Contains(condition, "--") +} diff --git a/database/clickhouse/query_test.go b/database/clickhouse/query_test.go new file mode 100644 index 0000000..3cc7181 --- /dev/null +++ b/database/clickhouse/query_test.go @@ -0,0 +1,120 @@ +package clickhouse + +import ( + "testing" + + "github.com/go-kratos/kratos/v2/log" + "github.com/stretchr/testify/assert" +) + +func TestQueryBuilder(t *testing.T) { + logger := log.NewHelper(log.DefaultLogger) + qb := NewQueryBuilder("test_table", logger) + + // 测试 Select 方法 + qb.Select("id", "name") + query, params := qb.Build() + assert.Contains(t, query, "SELECT id, name FROM test_table") + + // 测试 Distinct 方法 + qb.Distinct() + query, _ = qb.Build() + assert.Contains(t, query, "SELECT DISTINCT id, name FROM test_table") + + // 测试 Where 方法 + qb.Where("id > ?", 10).Where("name = ?", "example") + query, params = qb.Build() + assert.Contains(t, query, "WHERE id > ? 
AND name = ?") + assert.Equal(t, []interface{}{10, "example"}, params) + + // 测试 OrderBy 方法 + qb.OrderBy("name ASC") + query, _ = qb.Build() + assert.Contains(t, query, "ORDER BY name ASC") + + // 测试 GroupBy 方法 + qb.GroupBy("category") + query, _ = qb.Build() + assert.Contains(t, query, "GROUP BY category") + + // 测试 Having 方法 + qb.Having("COUNT(id) > ?", 5) + query, params = qb.Build() + assert.Contains(t, query, "HAVING COUNT(id) > ?") + assert.Equal(t, []interface{}{10, "example", 5}, params) + + // 测试 Join 方法 + qb.Join("INNER", "other_table", "test_table.id = other_table.id") + query, _ = qb.Build() + assert.Contains(t, query, "INNER JOIN other_table ON test_table.id = other_table.id") + + // 测试 With 方法 + qb.With("temp AS (SELECT id FROM another_table WHERE status = 'active')") + query, _ = qb.Build() + assert.Contains(t, query, "WITH temp AS (SELECT id FROM another_table WHERE status = 'active')") + + // 测试 Union 方法 + qb.Union("SELECT id FROM another_table") + query, _ = qb.Build() + assert.Contains(t, query, "UNION SELECT id FROM another_table") + + // 测试 Limit 和 Offset 方法 + qb.Limit(10).Offset(20) + query, _ = qb.Build() + assert.Contains(t, query, "LIMIT 10 OFFSET 20") + + // 测试 UseIndex 方法 + qb.UseIndex("idx_name") + query, _ = qb.Build() + assert.Contains(t, query, "USE INDEX (idx_name)") + + // 测试 CacheResult 方法 + qb.CacheResult() + query, _ = qb.Build() + assert.Contains(t, query, "/* CACHE */") + + // 测试 EnableDebug 方法 + qb.EnableDebug() + assert.True(t, qb.debug) + + // 测试 ArrayJoin 方法 + qb.ArrayJoin("array_column") + query, _ = qb.Build() + assert.Contains(t, query, "ARRAY JOIN array_column") + + // 测试 Final 方法 + qb.Final() + query, _ = qb.Build() + assert.Contains(t, query, "test_table FINAL") + + // 测试 Sample 方法 + qb.Sample(0.1) + query, _ = qb.Build() + assert.Contains(t, query, "test_table SAMPLE 0.100000") + + // 测试 LimitBy 方法 + qb.LimitBy(5, "name") + query, _ = qb.Build() + assert.Contains(t, query, "LIMIT BY 5 (name)") + + // 测试 PreWhere 方法 + qb.PreWhere("status = ?", "active") + query, params = qb.Build() + assert.Contains(t, query, "PREWHERE status = ?") + assert.Equal(t, []interface{}{"active"}, params) + + // 测试 Format 方法 + qb.Format("JSON") + query, _ = qb.Build() + assert.Contains(t, query, "FORMAT JSON") + + // 测试边界情况:空列名 + assert.Panics(t, func() { + qb.Select("") + }, "应该抛出异常:无效的列名") + + // 测试边界情况:无效条件 + assert.Panics(t, func() { + qb.Where("id = 1; DROP TABLE test_table") + }, "应该抛出异常:无效的条件") +} diff --git a/database/elasticsearch/README.md b/database/elasticsearch/README.md index 1758fae..7e2f319 100644 --- a/database/elasticsearch/README.md +++ b/database/elasticsearch/README.md @@ -14,42 +14,3 @@ - 动态映射(dynamic mapping) - 显式映射(explicit mapping) - 严格映射(strict mappings) - -## Docker部署 - -### 拉取镜像 - -```bash -docker pull bitnami/elasticsearch:latest -``` - -### 启动容器 - -```bash -docker run -itd \ - --name elasticsearch \ - -p 9200:9200 \ - -p 9300:9300 \ - -e ELASTICSEARCH_USERNAME=elastic \ - -e ELASTICSEARCH_PASSWORD=elastic \ - -e xpack.security.enabled=true \ - -e discovery.type=single-node \ - -e http.cors.enabled=true \ - -e http.cors.allow-origin=http://localhost:13580,http://127.0.0.1:13580 \ - -e http.cors.allow-headers=X-Requested-With,X-Auth-Token,Content-Type,Content-Length,Authorization \ - -e http.cors.allow-credentials=true \ - bitnami/elasticsearch:latest -``` - -安装管理工具: - -```bash -docker pull appbaseio/dejavu:latest - -docker run -itd \ - --name dejavu-test \ - -p 13580:1358 \ - appbaseio/dejavu:latest -``` - -访问管理工具: diff --git 
a/database/elasticsearch/client.go b/database/elasticsearch/client.go index d7ef87b..18f0306 100644 --- a/database/elasticsearch/client.go +++ b/database/elasticsearch/client.go @@ -6,8 +6,6 @@ import ( "encoding/json" "io" - "github.com/go-kratos/kratos/v2/encoding" - _ "github.com/go-kratos/kratos/v2/encoding/json" "github.com/go-kratos/kratos/v2/log" elasticsearchV9 "github.com/elastic/go-elasticsearch/v9" @@ -18,15 +16,13 @@ import ( ) type Client struct { - cli *elasticsearchV9.Client - log *log.Helper - codec encoding.Codec + cli *elasticsearchV9.Client + log *log.Helper } func NewClient(logger log.Logger, cfg *conf.Bootstrap) (*Client, error) { c := &Client{ - log: log.NewHelper(log.With(logger, "module", "elasticsearch-client")), - codec: encoding.GetCodec("json"), + log: log.NewHelper(log.With(logger, "module", "elasticsearch-client")), } if err := c.createESClient(cfg); err != nil { diff --git a/database/elasticsearch/go.mod b/database/elasticsearch/go.mod index cd5a959..dfce9f4 100644 --- a/database/elasticsearch/go.mod +++ b/database/elasticsearch/go.mod @@ -10,7 +10,7 @@ require ( github.com/elastic/go-elasticsearch/v9 v9.0.0 github.com/go-kratos/kratos/v2 v2.8.4 github.com/stretchr/testify v1.10.0 - github.com/tx7do/kratos-bootstrap/api v0.0.25 + github.com/tx7do/kratos-bootstrap/api v0.0.27 ) require ( @@ -19,12 +19,13 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/gnostic v0.7.0 // indirect - github.com/google/gnostic-models v0.6.9 // indirect + github.com/google/gnostic-models v0.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect + go.yaml.in/yaml/v3 v3.0.3 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect diff --git a/database/elasticsearch/go.sum b/database/elasticsearch/go.sum index 7543713..ead5129 100644 --- a/database/elasticsearch/go.sum +++ b/database/elasticsearch/go.sum @@ -722,8 +722,8 @@ github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6 github.com/google/gnostic v0.7.0 h1:d7EpuFp8vVdML+y0JJJYiKeOLjKTdH/GvVkLOBWqJpw= github.com/google/gnostic v0.7.0/go.mod h1:IAcUyMl6vtC95f60EZ8oXyqTsOersP6HbwjeG7EyDPM= github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49/go.mod h1:BkkQ4L1KS1xMt2aWSPStnn55ChGC0DPOn2FQYj+f25M= -github.com/google/gnostic-models v0.6.9 h1:MU/8wDLif2qCXZmzncUQ/BOfxWfthHi63KqpoNbWqVw= -github.com/google/gnostic-models v0.6.9/go.mod h1:CiWsm0s6BSQd1hRn8/QmxqB6BesYcbSZxsz9b0KuDBw= +github.com/google/gnostic-models v0.7.0 h1:qwTtogB15McXDaNqTZdzPJRHvaVJlAl+HVQnLmJEJxo= +github.com/google/gnostic-models v0.7.0/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -891,6 +891,8 @@ go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mx go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.15.0/go.mod 
h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= +go.yaml.in/yaml/v3 v3.0.3 h1:bXOww4E/J3f66rav3pX3m8w6jDE4knZjGOw8b5Y6iNE= +go.yaml.in/yaml/v3 v3.0.3/go.mod h1:tBHosrYAkRZjRAOREWbDnBXUf08JOwYq++0QNwQiWzI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/database/influxdb/client.go b/database/influxdb/client.go index 489e8ff..48afb58 100644 --- a/database/influxdb/client.go +++ b/database/influxdb/client.go @@ -3,11 +3,8 @@ package influxdb import ( "context" - "github.com/go-kratos/kratos/v2/encoding" - _ "github.com/go-kratos/kratos/v2/encoding/json" - "github.com/go-kratos/kratos/v2/log" - "github.com/InfluxCommunity/influxdb3-go/v2/influxdb3" + "github.com/go-kratos/kratos/v2/log" conf "github.com/tx7do/kratos-bootstrap/api/gen/go/conf/v1" ) @@ -15,14 +12,12 @@ import ( type Client struct { cli *influxdb3.Client - log *log.Helper - codec encoding.Codec + log *log.Helper } func NewClient(logger log.Logger, cfg *conf.Bootstrap) (*Client, error) { c := &Client{ - log: log.NewHelper(log.With(logger, "module", "influxdb-client")), - codec: encoding.GetCodec("json"), + log: log.NewHelper(log.With(logger, "module", "influxdb-client")), } if err := c.createInfluxdbClient(cfg); err != nil { @@ -87,6 +82,31 @@ func (c *Client) Query(ctx context.Context, query string) (*influxdb3.QueryItera return result, nil } +func (c *Client) QueryWithParams( + ctx context.Context, + table string, + filters map[string]interface{}, + operators map[string]string, + fields []string, +) (*influxdb3.QueryIterator, error) { + if c.cli == nil { + return nil, ErrInfluxDBClientNotInitialized + } + + query := BuildQueryWithParams(table, filters, operators, fields) + result, err := c.cli.Query( + ctx, + query, + influxdb3.WithQueryType(influxdb3.InfluxQL), + ) + if err != nil { + c.log.Errorf("failed to query data: %v", err) + return nil, ErrInfluxDBQueryFailed + } + + return result, nil +} + // Insert 插入数据 func (c *Client) Insert(ctx context.Context, point *influxdb3.Point) error { if c.cli == nil { diff --git a/database/influxdb/go.mod b/database/influxdb/go.mod index 7e37a26..690a81a 100644 --- a/database/influxdb/go.mod +++ b/database/influxdb/go.mod @@ -11,7 +11,7 @@ require ( github.com/go-kratos/kratos/v2 v2.8.4 github.com/stretchr/testify v1.10.0 github.com/tx7do/go-utils v1.1.29 - github.com/tx7do/kratos-bootstrap/api v0.0.25 + github.com/tx7do/kratos-bootstrap/api v0.0.27 google.golang.org/protobuf v1.36.6 ) diff --git a/database/influxdb/utils.go b/database/influxdb/utils.go index 137ffc1..809bde4 100644 --- a/database/influxdb/utils.go +++ b/database/influxdb/utils.go @@ -10,6 +10,43 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +func BuildQuery( + table string, + filters map[string]interface{}, + operators map[string]string, + fields []string, +) (string, []interface{}) { + var queryBuilder strings.Builder + args := make([]interface{}, 0) + + // 构建 SELECT 语句 + queryBuilder.WriteString("SELECT ") + if len(fields) > 0 { + queryBuilder.WriteString(strings.Join(fields, ", ")) + } else { + queryBuilder.WriteString("*") + } + queryBuilder.WriteString(fmt.Sprintf(" FROM %s", table)) + + // 构建 WHERE 条件 + if len(filters) > 0 { + queryBuilder.WriteString(" WHERE ") + var conditions []string + var operator 
string + for key, value := range filters { + operator = "=" // 默认操作符 + if op, exists := operators[key]; exists { + operator = op + } + conditions = append(conditions, fmt.Sprintf("%s %s ?", key, operator)) + args = append(args, value) + } + queryBuilder.WriteString(strings.Join(conditions, " AND ")) + } + + return queryBuilder.String(), args +} + func GetPointTag(point *influxdb3.Point, name string) *string { if point == nil { return nil diff --git a/database/influxdb/utils_test.go b/database/influxdb/utils_test.go index 8245a39..4af2633 100644 --- a/database/influxdb/utils_test.go +++ b/database/influxdb/utils_test.go @@ -1,9 +1,86 @@ package influxdb import ( + "reflect" "testing" ) +func TestBuildQuery(t *testing.T) { + tests := []struct { + name string + table string + filters map[string]interface{} + operators map[string]string + fields []string + expectedQuery string + expectedArgs []interface{} + }{ + { + name: "Basic query with filters and fields", + table: "candles", + filters: map[string]interface{}{"s": "AAPL", "o": 150.0}, + fields: []string{"s", "o", "h", "l", "c", "v"}, + expectedQuery: "SELECT s, o, h, l, c, v FROM candles WHERE s = ? AND o = ?", + expectedArgs: []interface{}{"AAPL", 150.0}, + }, + { + name: "Query with no filters", + table: "candles", + filters: map[string]interface{}{}, + fields: []string{"s", "o", "h"}, + expectedQuery: "SELECT s, o, h FROM candles", + expectedArgs: []interface{}{}, + }, + { + name: "Query with no fields", + table: "candles", + filters: map[string]interface{}{"s": "AAPL"}, + fields: []string{}, + expectedQuery: "SELECT * FROM candles WHERE s = ?", + expectedArgs: []interface{}{"AAPL"}, + }, + { + name: "Empty table name", + table: "", + filters: map[string]interface{}{"s": "AAPL"}, + fields: []string{"s", "o"}, + expectedQuery: "SELECT s, o FROM WHERE s = ?", + expectedArgs: []interface{}{"AAPL"}, + }, + { + name: "Special characters in filters", + table: "candles", + filters: map[string]interface{}{"name": "O'Reilly"}, + fields: []string{"name"}, + expectedQuery: "SELECT name FROM candles WHERE name = ?", + expectedArgs: []interface{}{"O'Reilly"}, + }, + { + name: "Query with interval filters", + table: "candles", + filters: map[string]interface{}{"time": "now() - interval '15 minutes'"}, + fields: []string{"*"}, + operators: map[string]string{"time": ">="}, + expectedQuery: "SELECT * FROM candles WHERE time >= ?", + expectedArgs: []interface{}{"now() - interval '15 minutes'"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query, args := BuildQuery(tt.table, tt.filters, tt.operators, tt.fields) + + if query != tt.expectedQuery { + t.Errorf("expected query %s, got %s", tt.expectedQuery, query) + } + + if !reflect.DeepEqual(args, tt.expectedArgs) { + t.Errorf("expected args %v, got %v", tt.expectedArgs, args) + } + }) + } +} + func TestBuildQueryWithParams(t *testing.T) { tests := []struct { name string diff --git a/database/mongodb/consts.go b/database/mongodb/consts.go index 41a78ef..aa9d108 100644 --- a/database/mongodb/consts.go +++ b/database/mongodb/consts.go @@ -50,8 +50,6 @@ const ( OperatorInc = "$inc" // 增加值 OperatorMul = "$mul" // 乘法 OperatorRename = "$rename" // 重命名字段 - OperatorMin = "$min" // 设置最小值 - OperatorMax = "$max" // 设置最大值 OperatorCurrentDate = "$currentDate" // 设置当前日期 OperatorAddToSet = "$addToSet" // 添加到集合 OperatorPop = "$pop" // 删除数组中的元素 @@ -81,6 +79,23 @@ const ( OperatorIndexStats = "$indexStats" // 索引统计 OperatorOut = "$out" // 输出 OperatorMerge = "$merge" // 合并 + OperatorSum = 
"$sum" // 求和 + OperatorAvg = "$avg" // 平均值 + OperatorMin = "$min" // 最小值 + OperatorMax = "$max" // 最大值 + OperatorFirst = "$first" // 第一个值 + OperatorLast = "$last" // 最后一个值 + OperatorStdDevPop = "$stdDevPop" // 总体标准差 + OperatorStdDevSamp = "$stdDevSamp" // 样本标准差 + + // 类型转换操作符 + + OperatorToLong = "$toLong" // 转换为 long 类型 + OperatorToDouble = "$toDouble" // 转换为 double 类型 + OperatorToDecimal = "$toDecimal" // 转换为 decimal 类型 + OperatorToString = "$toString" // 转换为 string 类型 + OperatorToDate = "$toDate" // 转换为 date 类型 + OperatorToInt = "$toInt" // 转换为 int 类型 // 地理空间操作符 diff --git a/database/mongodb/go.mod b/database/mongodb/go.mod index 40066c4..b5b7ae1 100644 --- a/database/mongodb/go.mod +++ b/database/mongodb/go.mod @@ -10,7 +10,7 @@ require ( github.com/go-kratos/kratos/v2 v2.8.4 github.com/stretchr/testify v1.10.0 github.com/tx7do/go-utils v1.1.29 - github.com/tx7do/kratos-bootstrap/api v0.0.26 + github.com/tx7do/kratos-bootstrap/api v0.0.27 go.mongodb.org/mongo-driver/v2 v2.2.2 google.golang.org/protobuf v1.36.6 ) diff --git a/database/mongodb/query.go b/database/mongodb/query.go index 7708fd4..c920620 100644 --- a/database/mongodb/query.go +++ b/database/mongodb/query.go @@ -6,8 +6,9 @@ import ( ) type QueryBuilder struct { - filter bsonV2.M - opts *optionsV2.FindOptions + filter bsonV2.M + opts *optionsV2.FindOptions + pipeline []bsonV2.M } func NewQuery() *QueryBuilder { @@ -211,6 +212,17 @@ func (qb *QueryBuilder) SetPage(page, size int64) *QueryBuilder { return qb } +// AddStage 添加聚合阶段到管道 +func (qb *QueryBuilder) AddStage(stage bsonV2.M) *QueryBuilder { + qb.pipeline = append(qb.pipeline, stage) + return qb +} + +// BuildPipeline 返回最终的聚合管道 +func (qb *QueryBuilder) BuildPipeline() []bsonV2.M { + return qb.pipeline +} + // Build 返回最终的过滤条件和查询选项 func (qb *QueryBuilder) Build() (bsonV2.M, *optionsV2.FindOptions) { return qb.filter, qb.opts diff --git a/database/mongodb/query_test.go b/database/mongodb/query_test.go index 5b4ab5c..8d036ce 100644 --- a/database/mongodb/query_test.go +++ b/database/mongodb/query_test.go @@ -217,3 +217,22 @@ func TestSetNearSphere(t *testing.T) { assert.Equal(t, expected, qb.filter[field]) } + +func TestQueryBuilderPipeline(t *testing.T) { + // 创建 QueryBuilder 实例 + qb := NewQuery() + + // 添加聚合阶段 + matchStage := bsonV2.M{OperatorMatch: bsonV2.M{"status": "active"}} + groupStage := bsonV2.M{OperatorGroup: bsonV2.M{"_id": "$category", "count": bsonV2.M{OperatorSum: 1}}} + sortStage := bsonV2.M{OperatorSortAgg: bsonV2.M{"count": -1}} + + qb.AddStage(matchStage).AddStage(groupStage).AddStage(sortStage) + + // 构建 Pipeline + pipeline := qb.BuildPipeline() + + // 验证 Pipeline + expectedPipeline := []bsonV2.M{matchStage, groupStage, sortStage} + assert.Equal(t, expectedPipeline, pipeline) +} diff --git a/tag.bat b/tag.bat index a656ef0..5df98ea 100644 --- a/tag.bat +++ b/tag.bat @@ -10,11 +10,11 @@ git tag tracer/v0.0.10 --force git tag database/ent/v0.0.10 --force git tag database/gorm/v0.0.10 --force -git tag database/mongodb/v0.0.11 --force -git tag database/influxdb/v0.0.11 --force +git tag database/mongodb/v0.0.12 --force +git tag database/influxdb/v0.0.12 --force +git tag database/clickhouse/v0.0.12 --force +git tag database/elasticsearch/v0.0.12 --force git tag database/cassandra/v0.0.10 --force -git tag database/clickhouse/v0.0.10 --force -git tag database/elasticsearch/v0.0.1 --force git tag registry/v0.1.0 --force git tag registry/consul/v0.1.0 --force