Files
kratos-bootstrap/database/mongodb/client.go
2025-06-26 21:53:55 +08:00

174 lines
5.0 KiB
Go

package mongodb
import (
"context"
"time"
"github.com/go-kratos/kratos/v2/log"
mongoV2 "go.mongodb.org/mongo-driver/v2/mongo"
optionsV2 "go.mongodb.org/mongo-driver/v2/mongo/options"
conf "github.com/tx7do/kratos-bootstrap/api/gen/go/conf/v1"
)
// Client wraps a MongoDB driver client bound to a single database. Each CRUD
// helper applies a per-operation timeout on top of the caller's context.
type Client struct {
	log      *log.Helper     // logger tagged with module=mongodb-client
	cli      *mongoV2.Client // underlying driver client; stays nil when no MongoDB config is present
	database string          // database name used by every CRUD helper
	timeout  time.Duration   // timeout applied via context.WithTimeout to each helper call
}
// NewClient constructs a Client from the bootstrap configuration, using a
// logger tagged with module=mongodb-client.
//
// If the configuration has no MongoDB section, the returned Client has no
// underlying driver client. Any error from building the driver client is
// returned unchanged, together with a nil Client.
func NewClient(logger log.Logger, cfg *conf.Bootstrap) (*Client, error) {
	helper := log.NewHelper(log.With(logger, "module", "mongodb-client"))
	client := &Client{log: helper}

	err := client.createMongodbClient(cfg)
	if err != nil {
		return nil, err
	}
	return client, nil
}
// createMongodbClient builds the MongoDB driver client from cfg.Data.Mongodb.
// It stores the client, the target database name and the per-operation
// timeout on c.
//
// When cfg, its Data section or its Mongodb section is absent, no driver
// client is created and nil is returned. In that case c.cli stays nil, but
// c.timeout is still set to the default. A connection error from the driver
// is logged and returned.
func (c *Client) createMongodbClient(cfg *conf.Bootstrap) error {
	// defaultTimeout bounds each helper call when no positive timeout is configured.
	const defaultTimeout = 10 * time.Second

	// Never leave the operation timeout at zero: context.WithTimeout(ctx, 0)
	// yields an already-expired context, so every helper call would fail.
	c.timeout = defaultTimeout

	// Protobuf getters are nil-safe, so this also tolerates cfg == nil.
	mcfg := cfg.GetData().GetMongodb()
	if mcfg == nil {
		return nil
	}

	var opts []*optionsV2.ClientOptions

	if uri := mcfg.GetUri(); uri != "" {
		opts = append(opts, optionsV2.Client().ApplyURI(uri))
	}

	// Credentials are applied only when both username and password are set.
	if mcfg.GetUsername() != "" && mcfg.GetPassword() != "" {
		opts = append(opts, optionsV2.Client().SetAuth(optionsV2.Credential{
			Username:    mcfg.GetUsername(),
			Password:    mcfg.GetPassword(),
			PasswordSet: true, // password is known to be non-empty in this branch
		}))
	}

	if d := mcfg.GetConnectTimeout(); d != nil {
		opts = append(opts, optionsV2.Client().SetConnectTimeout(d.AsDuration()))
	}
	if d := mcfg.GetServerSelectionTimeout(); d != nil {
		opts = append(opts, optionsV2.Client().SetServerSelectionTimeout(d.AsDuration()))
	}

	opTimeout := time.Duration(defaultTimeout)
	if d := mcfg.GetTimeout(); d != nil {
		opts = append(opts, optionsV2.Client().SetTimeout(d.AsDuration()))
		// Only a positive configured value replaces the default helper timeout.
		if td := d.AsDuration(); td > 0 {
			opTimeout = td
		}
	}

	// Let documents be (de)serialized using their `json` struct tags.
	opts = append(opts, optionsV2.Client().SetBSONOptions(&optionsV2.BSONOptions{
		UseJSONStructTags: true,
	}))

	cli, err := mongoV2.Connect(opts...)
	if err != nil {
		c.log.Errorf("failed to create mongodb client: %v", err)
		return err
	}

	c.database = mcfg.GetDatabase()
	c.timeout = opTimeout
	c.cli = cli
	return nil
}
// Close disconnects the underlying MongoDB client. It returns nothing: a
// disconnect failure is logged as an error, and success is logged at info level.
// If no driver client exists, it logs a warning and returns.
//
// NOTE(review): c.cli is not reset after Disconnect, so calling Close twice
// disconnects again instead of hitting the "already closed" warning.
func (c *Client) Close() {
	if c.cli == nil {
		c.log.Warn("mongodb client is already closed or not initialized")
		return
	}

	err := c.cli.Disconnect(context.Background())
	if err != nil {
		c.log.Errorf("failed to disconnect mongodb client: %v", err)
		return
	}
	c.log.Info("mongodb client disconnected successfully")
}
// CheckConnect pings the MongoDB server, bounded by the client's operation
// timeout, and logs the outcome. It does not return an error.
//
// When no driver client was created (MongoDB not configured), it logs a
// warning instead of dereferencing a nil client.
func (c *Client) CheckConnect() {
	if c.cli == nil {
		c.log.Warn("mongodb client is not initialized")
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
	defer cancel()

	if err := c.cli.Ping(ctx, nil); err != nil {
		c.log.Errorf("failed to ping mongodb: %v", err)
		return
	}
	c.log.Info("mongodb client is connected")
}
// InsertOne inserts a single document into the named collection of the
// configured database.
//
// The call is bounded by the client's operation timeout; an earlier deadline
// already on ctx still wins. It returns mongo.ErrClientDisconnected when no
// driver client was created (e.g. MongoDB not configured) rather than panicking.
func (c *Client) InsertOne(ctx context.Context, collection string, document interface{}) (*mongoV2.InsertOneResult, error) {
	if c.cli == nil {
		return nil, mongoV2.ErrClientDisconnected
	}

	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	return c.cli.Database(c.database).Collection(collection).InsertOne(ctx, document)
}
// InsertMany inserts the given documents into the named collection of the
// configured database.
//
// The call is bounded by the client's operation timeout; an earlier deadline
// already on ctx still wins. It returns mongo.ErrClientDisconnected when no
// driver client was created (e.g. MongoDB not configured) rather than panicking.
func (c *Client) InsertMany(ctx context.Context, collection string, documents []interface{}) (*mongoV2.InsertManyResult, error) {
	if c.cli == nil {
		return nil, mongoV2.ErrClientDisconnected
	}

	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	return c.cli.Database(c.database).Collection(collection).InsertMany(ctx, documents)
}
// FindOne finds the first document in the named collection that matches
// filter and decodes it into result. result should be a pointer.
//
// The call is bounded by the client's operation timeout; an earlier deadline
// already on ctx still wins. Errors from the query or from Decode are returned
// as-is. It returns mongo.ErrClientDisconnected when no driver client was
// created (e.g. MongoDB not configured) rather than panicking.
func (c *Client) FindOne(ctx context.Context, collection string, filter interface{}, result interface{}) error {
	if c.cli == nil {
		return mongoV2.ErrClientDisconnected
	}

	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	return c.cli.Database(c.database).Collection(collection).FindOne(ctx, filter).Decode(result)
}
// Find runs a query against the named collection and decodes every matching
// document into results, which should be a pointer to a slice.
//
// The call is bounded by the client's operation timeout; an earlier deadline
// already on ctx still wins. A query error is logged and returned. A cursor
// close failure is only logged. It returns mongo.ErrClientDisconnected when no
// driver client was created (e.g. MongoDB not configured) rather than panicking.
func (c *Client) Find(ctx context.Context, collection string, filter interface{}, results interface{}) error {
	if c.cli == nil {
		return mongoV2.ErrClientDisconnected
	}

	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	cursor, err := c.cli.Database(c.database).Collection(collection).Find(ctx, filter)
	if err != nil {
		c.log.Errorf("failed to find documents in collection %s: %v", collection, err)
		return err
	}

	// Cursor.All closes the cursor on its normal path, but it can return early
	// (e.g. when results is not a pointer) without doing so. Close it explicitly
	// so the server-side cursor is always released. This defer runs before
	// cancel(), so ctx is still live here.
	defer func() {
		if cerr := cursor.Close(ctx); cerr != nil {
			c.log.Errorf("failed to close cursor: %v", cerr)
		}
	}()

	return cursor.All(ctx, results)
}
// UpdateOne applies update to the first document in the named collection that
// matches filter.
//
// The call is bounded by the client's operation timeout; an earlier deadline
// already on ctx still wins. It returns mongo.ErrClientDisconnected when no
// driver client was created (e.g. MongoDB not configured) rather than panicking.
func (c *Client) UpdateOne(ctx context.Context, collection string, filter, update interface{}) (*mongoV2.UpdateResult, error) {
	if c.cli == nil {
		return nil, mongoV2.ErrClientDisconnected
	}

	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	return c.cli.Database(c.database).Collection(collection).UpdateOne(ctx, filter, update)
}
// DeleteOne deletes the first document in the named collection that matches
// filter.
//
// The call is bounded by the client's operation timeout; an earlier deadline
// already on ctx still wins. It returns mongo.ErrClientDisconnected when no
// driver client was created (e.g. MongoDB not configured) rather than panicking.
func (c *Client) DeleteOne(ctx context.Context, collection string, filter interface{}) (*mongoV2.DeleteResult, error) {
	if c.cli == nil {
		return nil, mongoV2.ErrClientDisconnected
	}

	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	return c.cli.Database(c.database).Collection(collection).DeleteOne(ctx, filter)
}