200 lines
3.9 KiB
Go
200 lines
3.9 KiB
Go
package clickhouse
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
|
|
clickhouseV2 "github.com/ClickHouse/clickhouse-go/v2"
|
|
driverV2 "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
|
|
)
|
|
|
|
// BatchInserter 批量插入器
|
|
type BatchInserter struct {
|
|
conn clickhouseV2.Conn
|
|
tableName string
|
|
columns []string
|
|
batchSize int
|
|
rows []interface{}
|
|
insertStmt string
|
|
mu sync.Mutex
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewBatchInserter 创建新的批量插入器
|
|
func NewBatchInserter(
|
|
ctx context.Context,
|
|
conn clickhouseV2.Conn,
|
|
tableName string,
|
|
batchSize int,
|
|
columns []string,
|
|
) (*BatchInserter, error) {
|
|
if batchSize <= 0 {
|
|
batchSize = 1000 // 默认批量大小
|
|
}
|
|
|
|
if len(columns) == 0 {
|
|
return nil, errors.New("必须指定列名")
|
|
}
|
|
|
|
// 构建INSERT语句
|
|
placeholders := make([]string, len(columns))
|
|
for i := range placeholders {
|
|
placeholders[i] = "?"
|
|
}
|
|
|
|
insertStmt := fmt.Sprintf(
|
|
"INSERT INTO %s (%s) VALUES (%s)",
|
|
tableName,
|
|
strings.Join(columns, ", "),
|
|
strings.Join(placeholders, ", "),
|
|
)
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
return &BatchInserter{
|
|
conn: conn,
|
|
tableName: tableName,
|
|
columns: columns,
|
|
batchSize: batchSize,
|
|
rows: make([]interface{}, 0, batchSize),
|
|
insertStmt: insertStmt,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}, nil
|
|
}
|
|
|
|
// Add 添加数据行
|
|
func (bi *BatchInserter) Add(row interface{}) error {
|
|
bi.mu.Lock()
|
|
defer bi.mu.Unlock()
|
|
|
|
// 检查上下文是否已取消
|
|
if bi.ctx.Err() != nil {
|
|
return bi.ctx.Err()
|
|
}
|
|
|
|
bi.rows = append(bi.rows, row)
|
|
|
|
// 达到批量大小时自动提交
|
|
if len(bi.rows) >= bi.batchSize {
|
|
return bi.flush()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Flush 强制提交当前批次
|
|
func (bi *BatchInserter) Flush() error {
|
|
bi.mu.Lock()
|
|
defer bi.mu.Unlock()
|
|
|
|
return bi.flush()
|
|
}
|
|
|
|
// Close 关闭插入器并提交剩余数据
|
|
func (bi *BatchInserter) Close() error {
|
|
defer bi.cancel()
|
|
|
|
bi.mu.Lock()
|
|
defer bi.mu.Unlock()
|
|
|
|
return bi.flush()
|
|
}
|
|
|
|
// flush 内部提交方法
|
|
func (bi *BatchInserter) flush() error {
|
|
if len(bi.rows) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// 创建批量
|
|
batch, err := bi.conn.PrepareBatch(bi.ctx, bi.insertStmt)
|
|
if err != nil {
|
|
return ErrBatchPrepareFailed
|
|
}
|
|
|
|
// 添加所有行
|
|
for _, row := range bi.rows {
|
|
// 使用反射获取字段值
|
|
if err = appendStructToBatch(batch, row, bi.columns); err != nil {
|
|
return ErrBatchAppendFailed
|
|
}
|
|
}
|
|
|
|
// 提交批量
|
|
if err = batch.Send(); err != nil {
|
|
return ErrBatchSendFailed
|
|
}
|
|
|
|
// 清空批次
|
|
bi.rows = bi.rows[:0]
|
|
return nil
|
|
}
|
|
|
|
// appendStructToBatch 使用反射将结构体字段添加到批次
|
|
func appendStructToBatch(batch driverV2.Batch, obj interface{}, columns []string) error {
|
|
v := reflect.ValueOf(obj)
|
|
|
|
// 如果是指针,获取指针指向的值
|
|
if v.Kind() == reflect.Ptr {
|
|
if v.IsNil() {
|
|
return errors.New("nil指针")
|
|
}
|
|
v = v.Elem()
|
|
}
|
|
|
|
// 必须是结构体
|
|
if v.Kind() != reflect.Struct {
|
|
return fmt.Errorf("期望结构体类型,得到 %v", v.Kind())
|
|
}
|
|
|
|
// 获取结构体类型
|
|
t := v.Type()
|
|
|
|
// 准备参数值
|
|
values := make([]interface{}, len(columns))
|
|
|
|
// 映射列名到结构体字段
|
|
for i, col := range columns {
|
|
// 查找匹配的字段
|
|
found := false
|
|
for j := 0; j < v.NumField(); j++ {
|
|
field := t.Field(j)
|
|
|
|
// 检查ch标签
|
|
if tag := field.Tag.Get("ch"); strings.TrimSpace(tag) == col {
|
|
values[i] = v.Field(j).Interface()
|
|
found = true
|
|
break
|
|
}
|
|
|
|
// 检查json标签
|
|
jsonTags := strings.Split(field.Tag.Get("json"), ",")
|
|
if len(jsonTags) > 0 && strings.TrimSpace(jsonTags[0]) == col {
|
|
values[i] = v.Field(j).Interface()
|
|
found = true
|
|
break
|
|
}
|
|
|
|
// 检查字段名
|
|
if field.Name == col {
|
|
values[i] = v.Field(j).Interface()
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
return fmt.Errorf("未找到列 %s 对应的结构体字段", col)
|
|
}
|
|
}
|
|
|
|
// 添加到批次
|
|
return batch.Append(values...)
|
|
}
|