Files
kratos-bootstrap/database/clickhouse/batch.go
2025-06-29 18:36:43 +08:00

200 lines
3.9 KiB
Go

package clickhouse
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
clickhouseV2 "github.com/ClickHouse/clickhouse-go/v2"
driverV2 "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
)
// BatchInserter buffers rows in memory and writes them to a ClickHouse table
// in batches, using conn.PrepareBatch followed by Send. Rows are structs (or
// pointers to structs) whose fields are mapped onto the configured columns.
// The buffer is guarded by mu, so Add, Flush and Close may be called from
// multiple goroutines.
type BatchInserter struct {
	// Where rows go.
	conn      clickhouseV2.Conn // connection used to prepare and send each batch
	tableName string            // target table name as given to NewBatchInserter
	columns   []string          // column order used for the INSERT and for row mapping

	// Buffering.
	batchSize  int           // buffered row count that triggers an automatic flush
	rows       []interface{} // rows waiting to be sent
	insertStmt string        // INSERT statement prepared for every batch

	// Synchronization and lifetime.
	mu     sync.Mutex         // serializes access to rows and flushes
	ctx    context.Context    // context passed to PrepareBatch; cancelled by Close
	cancel context.CancelFunc // cancels ctx
}
// NewBatchInserter creates a BatchInserter that writes rows into tableName.
//
// columns lists the target columns in the order their values are sent. Each
// row passed to Add must be a struct (or pointer to struct) providing a field
// for every column. A non-positive batchSize falls back to 1000. The inserter
// derives a cancellable context from ctx; Close cancels it.
//
// An error is returned when conn is nil, tableName is empty or no columns are
// given.
func NewBatchInserter(
	ctx context.Context,
	conn clickhouseV2.Conn,
	tableName string,
	batchSize int,
	columns []string,
) (*BatchInserter, error) {
	if conn == nil {
		return nil, errors.New("必须指定数据库连接")
	}
	if strings.TrimSpace(tableName) == "" {
		return nil, errors.New("必须指定表名")
	}
	if len(columns) == 0 {
		return nil, errors.New("必须指定列名")
	}
	if batchSize <= 0 {
		batchSize = 1000 // default batch size
	}

	// Copy the column list. The INSERT statement is built from it once, here,
	// while the row-to-field mapping reads it again at every flush; sharing the
	// caller's slice would let a later mutation desynchronize the two.
	cols := append([]string(nil), columns...)

	// One "?" placeholder per column, e.g. "?, ?, ?".
	placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(cols)), ", ")
	insertStmt := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		tableName,
		strings.Join(cols, ", "),
		placeholders,
	)

	ctx, cancel := context.WithCancel(ctx)
	return &BatchInserter{
		conn:       conn,
		tableName:  tableName,
		columns:    cols,
		batchSize:  batchSize,
		rows:       make([]interface{}, 0, batchSize),
		insertStmt: insertStmt,
		ctx:        ctx,
		cancel:     cancel,
	}, nil
}
// Add buffers row for insertion and automatically flushes once batchSize rows
// are buffered. It is safe for concurrent use.
//
// row must be a struct or a non-nil pointer to a struct. Other values are
// rejected here instead of being buffered: flush keeps all rows buffered when
// it fails, so a row that can never be mapped would otherwise make every
// later flush fail.
//
// Add returns the context's error once the inserter has been closed or its
// parent context is done. If the automatic flush fails, its error is returned
// and the buffered rows (including row) remain for a later Flush or Close.
func (bi *BatchInserter) Add(row interface{}) error {
	bi.mu.Lock()
	defer bi.mu.Unlock()

	if err := bi.ctx.Err(); err != nil {
		return err
	}

	// Reject rows that appendStructToBatch can never handle.
	v := reflect.ValueOf(row)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return errors.New("nil指针")
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("期望结构体类型,得到 %v", v.Kind())
	}

	bi.rows = append(bi.rows, row)
	if len(bi.rows) >= bi.batchSize {
		return bi.flush()
	}
	return nil
}
// Flush immediately sends whatever rows are currently buffered, even when
// fewer than batchSize rows have been added. It takes the inserter's lock and
// is therefore safe to call concurrently with Add. If sending fails, the rows
// stay buffered and the error is returned.
func (bi *BatchInserter) Flush() (err error) {
	bi.mu.Lock()
	defer bi.mu.Unlock()

	err = bi.flush()
	return err
}
// Close flushes any buffered rows and then cancels the inserter's context, so
// that subsequent Add calls fail with the context's error. It returns the
// error from the final flush, if any.
func (bi *BatchInserter) Close() error {
	bi.mu.Lock()
	defer bi.mu.Unlock()
	// Deferred calls run LIFO, so cancel happens before the unlock. Cancelling
	// while still holding the lock prevents a concurrent Add from buffering a
	// row between the final flush and the cancellation; such a row would never
	// be written.
	defer bi.cancel()
	return bi.flush()
}
// flush sends all buffered rows to ClickHouse as a single batch. The caller
// must hold bi.mu. The buffer is emptied only after a successful Send, so on
// any error every row stays buffered and a later call retries them.
//
// Returned errors wrap ErrBatchPrepareFailed, ErrBatchAppendFailed or
// ErrBatchSendFailed together with the underlying cause. Callers can still
// match the sentinel with errors.Is and also see why the step failed.
//
// NOTE(review): a row that can never be appended (e.g. a column with no
// matching field) fails again on every retry; callers may need to discard the
// inserter in that case.
func (bi *BatchInserter) flush() error {
	if len(bi.rows) == 0 {
		return nil
	}

	batch, err := bi.conn.PrepareBatch(bi.ctx, bi.insertStmt)
	if err != nil {
		return fmt.Errorf("%w: %w", ErrBatchPrepareFailed, err)
	}

	for i, row := range bi.rows {
		if err = appendStructToBatch(batch, row, bi.columns); err != nil {
			// This batch will never be sent. Abort it so the resources it
			// holds (a pooled connection in clickhouse-go's native
			// implementation) are released.
			_ = batch.Abort() // best effort: the append error is the one to report
			return fmt.Errorf("%w: row %d: %w", ErrBatchAppendFailed, i, err)
		}
	}

	if err = batch.Send(); err != nil {
		return fmt.Errorf("%w: %w", ErrBatchSendFailed, err)
	}

	// Drop references to the sent rows so they can be garbage collected, but
	// keep the backing array for the next batch.
	for i := range bi.rows {
		bi.rows[i] = nil
	}
	bi.rows = bi.rows[:0]
	return nil
}
// appendStructToBatch appends one row to batch. The values for columns are
// read from the fields of obj, which must be a struct or a non-nil pointer to
// a struct.
//
// Each column is mapped to an exported field using, in order of precedence:
//  1. a `ch:"<column>"` tag,
//  2. the name part of a `json:"<column>,..."` tag,
//  3. the Go field name itself.
//
// The precedence applies across all fields, so an explicit ch tag wins even
// over a json tag or field name declared earlier in the struct. Unexported
// fields are skipped, because reflect.Value.Interface panics on them. An
// error is returned when obj is not a struct or a column has no field.
func appendStructToBatch(batch driverV2.Batch, obj interface{}, columns []string) error {
	v := reflect.ValueOf(obj)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return errors.New("nil指针")
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("期望结构体类型,得到 %v", v.Kind())
	}

	t := v.Type()
	values := make([]interface{}, len(columns))
	for i, col := range columns {
		byTag, byJSON, byName := -1, -1, -1
		// A ch-tag match is final, so stop scanning as soon as one is found.
		for j := 0; j < t.NumField() && byTag < 0; j++ {
			field := t.Field(j)
			if !field.IsExported() {
				continue
			}
			if tag := strings.TrimSpace(field.Tag.Get("ch")); tag != "" && tag == col {
				byTag = j
				continue
			}
			if byJSON < 0 {
				name, _, _ := strings.Cut(field.Tag.Get("json"), ",")
				if name = strings.TrimSpace(name); name != "" && name == col {
					byJSON = j
				}
			}
			if byName < 0 && field.Name == col {
				byName = j
			}
		}

		idx := byTag
		if idx < 0 {
			idx = byJSON
		}
		if idx < 0 {
			idx = byName
		}
		if idx < 0 {
			return fmt.Errorf("未找到列 %s 对应的结构体字段", col)
		}
		values[i] = v.Field(idx).Interface()
	}

	return batch.Append(values...)
}