feat: database.
This commit is contained in:
198
database/clickhouse/batch.go
Normal file
198
database/clickhouse/batch.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
clickhouseV2 "github.com/ClickHouse/clickhouse-go/v2"
|
||||
driverV2 "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
|
||||
)
|
||||
|
||||
// BatchInserter 批量插入器
|
||||
type BatchInserter struct {
|
||||
conn clickhouseV2.Conn
|
||||
tableName string
|
||||
columns []string
|
||||
batchSize int
|
||||
rows []interface{}
|
||||
insertStmt string
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewBatchInserter 创建新的批量插入器
|
||||
func NewBatchInserter(
|
||||
ctx context.Context,
|
||||
conn clickhouseV2.Conn,
|
||||
tableName string,
|
||||
batchSize int,
|
||||
columns []string,
|
||||
) (*BatchInserter, error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1000 // 默认批量大小
|
||||
}
|
||||
|
||||
if len(columns) == 0 {
|
||||
return nil, errors.New("必须指定列名")
|
||||
}
|
||||
|
||||
// 构建INSERT语句
|
||||
placeholders := make([]string, len(columns))
|
||||
for i := range placeholders {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
|
||||
insertStmt := fmt.Sprintf(
|
||||
"INSERT INTO %s (%s) VALUES (%s)",
|
||||
tableName,
|
||||
strings.Join(columns, ", "),
|
||||
strings.Join(placeholders, ", "),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &BatchInserter{
|
||||
conn: conn,
|
||||
tableName: tableName,
|
||||
columns: columns,
|
||||
batchSize: batchSize,
|
||||
rows: make([]interface{}, 0, batchSize),
|
||||
insertStmt: insertStmt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add 添加数据行
|
||||
func (bi *BatchInserter) Add(row interface{}) error {
|
||||
bi.mu.Lock()
|
||||
defer bi.mu.Unlock()
|
||||
|
||||
// 检查上下文是否已取消
|
||||
if bi.ctx.Err() != nil {
|
||||
return bi.ctx.Err()
|
||||
}
|
||||
|
||||
bi.rows = append(bi.rows, row)
|
||||
|
||||
// 达到批量大小时自动提交
|
||||
if len(bi.rows) >= bi.batchSize {
|
||||
return bi.flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush 强制提交当前批次
|
||||
func (bi *BatchInserter) Flush() error {
|
||||
bi.mu.Lock()
|
||||
defer bi.mu.Unlock()
|
||||
|
||||
return bi.flush()
|
||||
}
|
||||
|
||||
// Close 关闭插入器并提交剩余数据
|
||||
func (bi *BatchInserter) Close() error {
|
||||
defer bi.cancel()
|
||||
|
||||
bi.mu.Lock()
|
||||
defer bi.mu.Unlock()
|
||||
|
||||
return bi.flush()
|
||||
}
|
||||
|
||||
// flush 内部提交方法
|
||||
func (bi *BatchInserter) flush() error {
|
||||
if len(bi.rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建批量
|
||||
batch, err := bi.conn.PrepareBatch(bi.ctx, bi.insertStmt)
|
||||
if err != nil {
|
||||
return ErrBatchPrepareFailed
|
||||
}
|
||||
|
||||
// 添加所有行
|
||||
for _, row := range bi.rows {
|
||||
// 使用反射获取字段值
|
||||
if err = appendStructToBatch(batch, row, bi.columns); err != nil {
|
||||
return ErrBatchAppendFailed
|
||||
}
|
||||
}
|
||||
|
||||
// 提交批量
|
||||
if err = batch.Send(); err != nil {
|
||||
return ErrBatchSendFailed
|
||||
}
|
||||
|
||||
// 清空批次
|
||||
bi.rows = bi.rows[:0]
|
||||
return nil
|
||||
}
|
||||
|
||||
// appendStructToBatch 使用反射将结构体字段添加到批次
|
||||
func appendStructToBatch(batch driverV2.Batch, obj interface{}, columns []string) error {
|
||||
v := reflect.ValueOf(obj)
|
||||
|
||||
// 如果是指针,获取指针指向的值
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return errors.New("nil指针")
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
// 必须是结构体
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("期望结构体类型,得到 %v", v.Kind())
|
||||
}
|
||||
|
||||
// 获取结构体类型
|
||||
t := v.Type()
|
||||
|
||||
// 准备参数值
|
||||
values := make([]interface{}, len(columns))
|
||||
|
||||
// 映射列名到结构体字段
|
||||
for i, col := range columns {
|
||||
// 查找匹配的字段
|
||||
found := false
|
||||
for j := 0; j < v.NumField(); j++ {
|
||||
field := t.Field(j)
|
||||
|
||||
// 检查ch标签
|
||||
if tag := field.Tag.Get("ch"); tag == col {
|
||||
values[i] = v.Field(j).Interface()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
// 检查json标签
|
||||
if tag := field.Tag.Get("json"); tag == col {
|
||||
values[i] = v.Field(j).Interface()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
// 检查字段名
|
||||
if field.Name == col {
|
||||
values[i] = v.Field(j).Interface()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf("未找到列 %s 对应的结构体字段", col)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加到批次
|
||||
return batch.Append(values...)
|
||||
}
|
||||
Reference in New Issue
Block a user