diff --git a/database/clickhouse/client.go b/database/clickhouse/client.go index c91e175..8edc595 100644 --- a/database/clickhouse/client.go +++ b/database/clickhouse/client.go @@ -8,13 +8,11 @@ import ( "net/url" "reflect" "strings" - "time" clickhouseV2 "github.com/ClickHouse/clickhouse-go/v2" driverV2 "github.com/ClickHouse/clickhouse-go/v2/lib/driver" "github.com/go-kratos/kratos/v2/log" - conf "github.com/tx7do/kratos-bootstrap/api/gen/go/conf/v1" "github.com/tx7do/kratos-bootstrap/utils" ) @@ -328,13 +326,14 @@ func (c *Client) prepareInsertData(data any) (string, string, []any, error) { val = val.Elem() typ := val.Type() - var columns []string - var placeholders []string - var values []any + columns := make([]string, 0, typ.NumField()) + placeholders := make([]string, 0, typ.NumField()) + values := make([]any, 0, typ.NumField()) + + values = structToValueArray(data) for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) - value := val.Field(i).Interface() // 优先获取 `cn` 标签,其次获取 `json` 标签,最后使用字段名 columnName := field.Tag.Get("cn") @@ -347,58 +346,6 @@ func (c *Client) prepareInsertData(data any) (string, string, []any, error) { columns = append(columns, columnName) placeholders = append(placeholders, "?") - - switch v := value.(type) { - case *sql.NullString: - if v.Valid { - values = append(values, v.String) - } else { - values = append(values, nil) - } - case *sql.NullInt64: - if v.Valid { - values = append(values, v.Int64) - } else { - values = append(values, nil) - } - case *sql.NullFloat64: - if v.Valid { - values = append(values, v.Float64) - } else { - values = append(values, nil) - } - case *sql.NullBool: - if v.Valid { - values = append(values, v.Bool) - } else { - values = append(values, nil) - } - - case *sql.NullTime: - if v != nil && v.Valid { - values = append(values, v.Time.Format("2006-01-02 15:04:05.000000000")) - } else { - values = append(values, nil) - } - - case *time.Time: - if v != nil { - values = append(values, v.Format("2006-01-02 15:04:05.000000000")) - } else { - values = append(values, nil) - } - - case time.Time: - // 处理 time.Time 类型 - if !v.IsZero() { - values = append(values, v.Format("2006-01-02 15:04:05.000000000")) - } else { - values = append(values, nil) // 如果时间为零值,插入 NULL - } - - default: - values = append(values, v) - } } return strings.Join(columns, ", "), strings.Join(placeholders, ", "), values, nil @@ -483,7 +430,33 @@ func (c *Client) InsertMany(ctx context.Context, tableName string, data []any) e } // AsyncInsert 异步插入数据 -func (c *Client) AsyncInsert(ctx context.Context, query string, wait bool, args ...any) error { +func (c *Client) AsyncInsert(ctx context.Context, tableName string, data any, wait bool) error { + if c.conn == nil { + c.log.Error("clickhouse client is not initialized") + return ErrClientNotInitialized + } + + // 准备插入数据 + columns, placeholders, values, err := c.prepareInsertData(data) + if err != nil { + c.log.Errorf("prepare insert data failed: %v", err) + return ErrPrepareInsertDataFailed + } + + // 构造 SQL 语句 + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, columns, placeholders) + + // 执行异步插入 + if err := c.asyncInsert(ctx, query, wait, values...); err != nil { + c.log.Errorf("async insert failed: %v", err) + return ErrAsyncInsertFailed + } + + return nil +} + +// asyncInsert 异步插入数据 +func (c *Client) asyncInsert(ctx context.Context, query string, wait bool, args ...any) error { if c.conn == nil { c.log.Error("clickhouse client is not initialized") return ErrClientNotInitialized @@ -497,8 +470,104 @@ func (c *Client) AsyncInsert(ctx context.Context, query string, wait bool, args return nil } +// AsyncInsertMany 批量异步插入数据 +func (c *Client) AsyncInsertMany(ctx context.Context, tableName string, data []any, wait bool) error { + if c.conn == nil { + c.log.Error("clickhouse client is not initialized") + return ErrClientNotInitialized + } + + if len(data) == 0 { + c.log.Error("data slice is empty") + return ErrInvalidColumnData + } + + // 准备插入数据的列名和占位符 + var columns string + var placeholders []string + var values []any + + for _, item := range data { + itemColumns, itemPlaceholders, itemValues, err := c.prepareInsertData(item) + if err != nil { + c.log.Errorf("prepare insert data failed: %v", err) + return ErrPrepareInsertDataFailed + } + + if columns == "" { + columns = itemColumns + } else if columns != itemColumns { + c.log.Error("data items have inconsistent columns") + return ErrInvalidColumnData + } + + placeholders = append(placeholders, fmt.Sprintf("(%s)", itemPlaceholders)) + values = append(values, itemValues...) + } + + // 构造 SQL 语句 + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", + tableName, + columns, + strings.Join(placeholders, ", "), + ) + + // 执行异步插入操作 + if err := c.asyncInsert(ctx, query, wait, values...); err != nil { + c.log.Errorf("batch insert failed: %v", err) + return err + } + + return nil +} + // BatchInsert 批量插入数据 -func (c *Client) BatchInsert(ctx context.Context, query string, data [][]any) error { +func (c *Client) BatchInsert(ctx context.Context, tableName string, data []any) error { + if c.conn == nil { + c.log.Error("clickhouse client is not initialized") + return ErrClientNotInitialized + } + + if len(data) == 0 { + c.log.Error("data slice is empty") + return ErrInvalidColumnData + } + + // 准备插入数据的列名和占位符 + var columns string + var values [][]any + + for _, item := range data { + itemColumns, _, itemValues, err := c.prepareInsertData(item) + if err != nil { + c.log.Errorf("prepare insert data failed: %v", err) + return ErrPrepareInsertDataFailed + } + + if columns == "" { + columns = itemColumns + } else if columns != itemColumns { + c.log.Error("data items have inconsistent columns") + return ErrInvalidColumnData + } + + values = append(values, itemValues) + } + + // 构造 SQL 语句 + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES", tableName, columns) + + // 调用 batchExec 方法执行批量插入 + if err := c.batchExec(ctx, query, values); err != nil { + c.log.Errorf("batch insert failed: %v", err) + return ErrBatchInsertFailed + } + + return nil +} + +// batchExec 执行批量操作 +func (c *Client) batchExec(ctx context.Context, query string, data [][]any) error { batch, err := c.conn.PrepareBatch(ctx, query) if err != nil { c.log.Errorf("failed to prepare batch: %v", err) @@ -520,8 +589,8 @@ func (c *Client) BatchInsert(ctx context.Context, query string, data [][]any) er return nil } -// BatchInsertStructs 批量插入结构体数据 -func (c *Client) BatchInsertStructs(ctx context.Context, query string, data []any) error { +// BatchStructs 批量插入结构体数据 +func (c *Client) BatchStructs(ctx context.Context, query string, data []any) error { if c.conn == nil { c.log.Error("clickhouse client is not initialized") return ErrClientNotInitialized diff --git a/database/clickhouse/client_test.go b/database/clickhouse/client_test.go index a6d1f1c..307c8f8 100644 --- a/database/clickhouse/client_test.go +++ b/database/clickhouse/client_test.go @@ -10,10 +10,6 @@ import ( conf "github.com/tx7do/kratos-bootstrap/api/gen/go/conf/v1" ) -func Ptr[T any](v T) *T { - return &v -} - type Candle struct { Timestamp *time.Time `json:"timestamp" ch:"timestamp"` Symbol *string `json:"symbol" ch:"symbol"` @@ -84,9 +80,8 @@ func TestInsertCandlesTable(t *testing.T) { createCandlesTable(client) // 测试数据 - now := time.Now() candle := &Candle{ - Timestamp: &now, + Timestamp: Ptr(time.Now()), Symbol: Ptr("AAPL"), Open: Ptr(100.5), High: Ptr(105.0), @@ -107,10 +102,9 @@ func TestInsertManyCandlesTable(t *testing.T) { createCandlesTable(client) // 测试数据 - now := time.Now() data := []any{ &Candle{ - Timestamp: &now, + Timestamp: Ptr(time.Now()), Symbol: Ptr("AAPL"), Open: Ptr(100.5), High: Ptr(105.0), @@ -119,7 +113,7 @@ func TestInsertManyCandlesTable(t *testing.T) { Volume: Ptr(1500.0), }, &Candle{ - Timestamp: &now, + Timestamp: Ptr(time.Now()), Symbol: Ptr("GOOG"), Open: Ptr(200.5), High: Ptr(205.0), @@ -134,7 +128,93 @@ func TestInsertManyCandlesTable(t *testing.T) { assert.NoError(t, err, "InsertManyCandlesTable 应该成功执行") } -func TestBatchInsertCandlesTable(t *testing.T) { +func TestAsyncInsertCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 测试数据 + candle := &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("BTC/USD"), + Open: Ptr(30000.0), + High: Ptr(31000.0), + Low: Ptr(29000.0), + Close: Ptr(30500.0), + Volume: Ptr(500.0), + } + + // 异步插入数据 + err := client.AsyncInsert(context.Background(), "candles", candle, true) + assert.NoError(t, err, "AsyncInsert 方法应该成功执行") + + // 验证插入结果 + query := ` + SELECT timestamp, symbol, open, high, low, close, volume + FROM candles + WHERE symbol = ? + ` + var result Candle + err = client.QueryRow(context.Background(), &result, query, "BTC/USD") + assert.NoError(t, err, "QueryRow 应该成功执行") + assert.Equal(t, "BTC/USD", *result.Symbol, "symbol 列值应该为 BTC/USD") + assert.Equal(t, 30500.0, *result.Close, "close 列值应该为 30500.0") + assert.Equal(t, 500.0, *result.Volume, "volume 列值应该为 500.0") +} + +func TestAsyncInsertManyCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 测试数据 + data := []any{ + &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("AAPL"), + Open: Ptr(100.5), + High: Ptr(105.0), + Low: Ptr(99.5), + Close: Ptr(102.0), + Volume: Ptr(1500.0), + }, + &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("GOOG"), + Open: Ptr(200.5), + High: Ptr(205.0), + Low: Ptr(199.5), + Close: Ptr(202.0), + Volume: Ptr(2500.0), + }, + &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("MSFT"), + Open: Ptr(300.5), + High: Ptr(305.0), + Low: Ptr(299.5), + Close: Ptr(302.0), + Volume: Ptr(3500.0), + }, + } + + // 批量插入数据 + err := client.AsyncInsertMany(context.Background(), "candles", data, true) + assert.NoError(t, err, "AsyncInsertMany 方法应该成功执行") + + // 验证插入结果 + query := ` + SELECT timestamp, symbol, open, high, low, close, volume + FROM candles + ` + var results []Candle + err = client.Select(context.Background(), &results, query) + assert.NoError(t, err, "查询数据应该成功执行") +} + +func TestInternalBatchExecCandlesTable(t *testing.T) { client := createTestClient() assert.NotNil(t, client) @@ -154,11 +234,62 @@ func TestBatchInsertCandlesTable(t *testing.T) { } // 批量插入数据 - err := client.BatchInsert(context.Background(), insertQuery, data) - assert.NoError(t, err, "BatchInsertCandlesTable 应该成功执行") + err := client.batchExec(context.Background(), insertQuery, data) + assert.NoError(t, err, "batchExec 应该成功执行") } -func TestBatchInsertStructsCandlesTable(t *testing.T) { +func TestBatchInsertCandlesTable(t *testing.T) { + client := createTestClient() + assert.NotNil(t, client) + + createCandlesTable(client) + + // 测试数据 + data := []any{ + &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("AAPL"), + Open: Ptr(100.5), + High: Ptr(105.0), + Low: Ptr(99.5), + Close: Ptr(102.0), + Volume: Ptr(1500.0), + }, + &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("GOOG"), + Open: Ptr(200.5), + High: Ptr(205.0), + Low: Ptr(199.5), + Close: Ptr(202.0), + Volume: Ptr(2500.0), + }, + &Candle{ + Timestamp: Ptr(time.Now()), + Symbol: Ptr("MSFT"), + Open: Ptr(300.5), + High: Ptr(305.0), + Low: Ptr(299.5), + Close: Ptr(302.0), + Volume: Ptr(3500.0), + }, + } + + // 批量插入数据 + err := client.BatchInsert(context.Background(), "candles", data) + assert.NoError(t, err, "BatchInsert 方法应该成功执行") + + // 验证插入结果 + query := ` + SELECT timestamp, symbol, open, high, low, close, volume + FROM candles + ` + var results []Candle + err = client.Select(context.Background(), &results, query) + assert.NoError(t, err, "查询数据应该成功执行") +} + +func TestBatchStructsCandlesTable(t *testing.T) { client := createTestClient() assert.NotNil(t, client) @@ -171,10 +302,9 @@ func TestBatchInsertStructsCandlesTable(t *testing.T) { ` // 测试数据 - now := time.Now() data := []any{ &Candle{ - Timestamp: &now, + Timestamp: Ptr(time.Now()), Symbol: Ptr("AAPL"), Open: Ptr(100.5), High: Ptr(105.0), @@ -183,7 +313,7 @@ func TestBatchInsertStructsCandlesTable(t *testing.T) { Volume: Ptr(1500.0), }, &Candle{ - Timestamp: &now, + Timestamp: Ptr(time.Now()), Symbol: Ptr("GOOG"), Open: Ptr(200.5), High: Ptr(205.0), @@ -194,11 +324,11 @@ func TestBatchInsertStructsCandlesTable(t *testing.T) { } // 批量插入数据 - err := client.BatchInsertStructs(context.Background(), insertQuery, data) - assert.NoError(t, err, "BatchInsertStructsCandlesTable 应该成功执行") + err := client.BatchStructs(context.Background(), insertQuery, data) + assert.NoError(t, err, "BatchStructsCandlesTable 应该成功执行") } -func TestAsyncInsertIntoCandlesTable(t *testing.T) { +func TestInternalAsyncInsertIntoCandlesTable(t *testing.T) { client := createTestClient() assert.NotNil(t, client) @@ -211,7 +341,7 @@ func TestAsyncInsertIntoCandlesTable(t *testing.T) { ` // 测试数据 - err := client.AsyncInsert(context.Background(), insertQuery, true, + err := client.asyncInsert(context.Background(), insertQuery, true, "2023-10-01 12:00:00", "AAPL", 100.5, 105.0, 99.5, 102.0, 1500.0) assert.NoError(t, err, "InsertIntoCandlesTable 应该成功执行") } @@ -300,7 +430,7 @@ func TestQueryRow(t *testing.T) { INSERT INTO candles (timestamp, symbol, open, high, low, close, volume) VALUES (?, ?, ?, ?, ?, ?, ?) ` - err := client.AsyncInsert(context.Background(), insertQuery, true, + err := client.asyncInsert(context.Background(), insertQuery, true, "2023-10-01 12:00:00", "AAPL", 100.5, 105.0, 99.5, 102.0, 1500.0) assert.NoError(t, err, "数据插入失败") diff --git a/database/clickhouse/errors.go b/database/clickhouse/errors.go index 3560277..2e24bd7 100644 --- a/database/clickhouse/errors.go +++ b/database/clickhouse/errors.go @@ -72,6 +72,9 @@ var ( // ErrBatchAppendFailed is returned when appending to a batch fails. ErrBatchAppendFailed = errors.InternalServer("BATCH_APPEND_FAILED", "batch append operation failed") + // ErrBatchInsertFailed is returned when a batch insert operation fails. + ErrBatchInsertFailed = errors.InternalServer("BATCH_INSERT_FAILED", "batch insert operation failed") + // ErrInvalidDSN is returned when the data source name (DSN) is invalid. ErrInvalidDSN = errors.InternalServer("INVALID_DSN", "invalid data source name") diff --git a/database/clickhouse/utils.go b/database/clickhouse/utils.go new file mode 100644 index 0000000..090d130 --- /dev/null +++ b/database/clickhouse/utils.go @@ -0,0 +1,118 @@ +package clickhouse + +import ( + "database/sql" + "reflect" + "time" + + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + timeFormat = "2006-01-02 15:04:05.000000000" +) + +func structToValueArray(input any) []any { + // 检查是否是指针类型,如果是则解引用 + val := reflect.ValueOf(input) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + // 确保输入是结构体 + if val.Kind() != reflect.Struct { + return nil + } + + var values []any + for i := 0; i < val.NumField(); i++ { + value := val.Field(i).Interface() + + switch v := value.(type) { + case *sql.NullString: + if v.Valid { + values = append(values, v.String) + } else { + values = append(values, nil) + } + case *sql.NullInt64: + if v.Valid { + values = append(values, v.Int64) + } else { + values = append(values, nil) + } + case *sql.NullFloat64: + if v.Valid { + values = append(values, v.Float64) + } else { + values = append(values, nil) + } + case *sql.NullBool: + if v.Valid { + values = append(values, v.Bool) + } else { + values = append(values, nil) + } + + case *sql.NullTime: + if v != nil && v.Valid { + values = append(values, v.Time.Format(timeFormat)) + } else { + values = append(values, nil) + } + + case *time.Time: + if v != nil { + values = append(values, v.Format(timeFormat)) + } else { + values = append(values, nil) + } + + case time.Time: + // 处理 time.Time 类型 + if !v.IsZero() { + values = append(values, v.Format(timeFormat)) + } else { + values = append(values, nil) // 如果时间为零值,插入 NULL + } + + case timestamppb.Timestamp: + // 处理 timestamppb.Timestamp 类型 + if !v.IsValid() { + values = append(values, v.AsTime().Format(timeFormat)) + } else { + values = append(values, nil) // 如果时间为零值,插入 NULL + } + + case *timestamppb.Timestamp: + // 处理 *timestamppb.Timestamp 类型 + if v != nil && v.IsValid() { + values = append(values, v.AsTime().Format(timeFormat)) + } else { + values = append(values, nil) // 如果时间为零值,插入 NULL + } + + case durationpb.Duration: + // 处理 timestamppb.Duration 类型 + if v.AsDuration() != 0 { + values = append(values, v.AsDuration().String()) + } else { + values = append(values, nil) // 如果时间为零值,插入 NULL + } + + case *durationpb.Duration: + // 处理 *timestamppb.Duration 类型 + if v != nil && v.AsDuration() != 0 { + values = append(values, v.AsDuration().String()) + } else { + values = append(values, nil) // 如果时间为零值,插入 NULL + } + + default: + values = append(values, v) + } + } + + return values +} diff --git a/database/clickhouse/utils_test.go b/database/clickhouse/utils_test.go new file mode 100644 index 0000000..c56abc8 --- /dev/null +++ b/database/clickhouse/utils_test.go @@ -0,0 +1,83 @@ +package clickhouse + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// Ptr returns a pointer to the provided value. +func Ptr[T any](v T) *T { + return &v +} + +func TestStructToValueArrayWithCandle(t *testing.T) { + now := time.Now() + + candle := Candle{ + Timestamp: Ptr(now), + Symbol: Ptr("AAPL"), + Open: Ptr(100.5), + High: Ptr(105.0), + Low: Ptr(99.5), + Close: Ptr(102.0), + Volume: Ptr(1500.0), + } + + values := structToValueArray(candle) + assert.NotNil(t, values, "Values should not be nil") + assert.Len(t, values, 7, "Expected 7 fields in the Candle struct") + assert.Equal(t, now.Format(timeFormat), values[0].(string), "Timestamp should match") + assert.Equal(t, *candle.Symbol, *(values[1].(*string)), "Symbol should match") + assert.Equal(t, *candle.Open, *values[2].(*float64), "Open price should match") + assert.Equal(t, *candle.High, *values[3].(*float64), "High price should match") + assert.Equal(t, *candle.Low, *values[4].(*float64), "Low price should match") + assert.Equal(t, *candle.Close, *values[5].(*float64), "Close price should match") + assert.Equal(t, *candle.Volume, *values[6].(*float64), "Volume should match") + + t.Logf("QueryRow Result: [%v] Candle: %s, Open: %f, High: %f, Low: %f, Close: %f, Volume: %f\n", + values[0], + *(values[1].(*string)), + *values[2].(*float64), + *values[3].(*float64), + *values[4].(*float64), + *values[5].(*float64), + *values[6].(*float64), + ) +} + +func TestStructToValueArrayWithCandlePtr(t *testing.T) { + now := time.Now() + + candle := &Candle{ + Timestamp: Ptr(now), + Symbol: Ptr("AAPL"), + Open: Ptr(100.5), + High: Ptr(105.0), + Low: Ptr(99.5), + Close: Ptr(102.0), + Volume: Ptr(1500.0), + } + + values := structToValueArray(candle) + assert.NotNil(t, values, "Values should not be nil") + assert.Len(t, values, 7, "Expected 7 fields in the Candle struct") + assert.Equal(t, now.Format(timeFormat), values[0].(string), "Timestamp should match") + assert.Equal(t, *candle.Symbol, *(values[1].(*string)), "Symbol should match") + assert.Equal(t, *candle.Open, *values[2].(*float64), "Open price should match") + assert.Equal(t, *candle.High, *values[3].(*float64), "High price should match") + assert.Equal(t, *candle.Low, *values[4].(*float64), "Low price should match") + assert.Equal(t, *candle.Close, *values[5].(*float64), "Close price should match") + assert.Equal(t, *candle.Volume, *values[6].(*float64), "Volume should match") + + t.Logf("QueryRow Result: [%v] Candle: %s, Open: %f, High: %f, Low: %f, Close: %f, Volume: %f\n", + values[0], + *(values[1].(*string)), + *values[2].(*float64), + *values[3].(*float64), + *values[4].(*float64), + *values[5].(*float64), + *values[6].(*float64), + ) +} diff --git a/tag.bat b/tag.bat index c02a2ef..600e15c 100644 --- a/tag.bat +++ b/tag.bat @@ -12,7 +12,7 @@ git tag database/ent/v0.0.10 --force git tag database/gorm/v0.0.10 --force git tag database/mongodb/v0.0.12 --force git tag database/influxdb/v0.0.12 --force -git tag database/clickhouse/v0.0.13 --force +git tag database/clickhouse/v0.0.14 --force git tag database/elasticsearch/v0.0.12 --force git tag database/cassandra/v0.0.10 --force