Files
kratos-bootstrap/rpc/rest.go
2024-11-13 23:48:48 +08:00

136 lines
4.1 KiB
Go

package rpc
import (
"crypto/tls"
"net/http/pprof"
"github.com/gorilla/handlers"
"github.com/go-kratos/aegis/ratelimit"
"github.com/go-kratos/aegis/ratelimit/bbr"
"github.com/go-kratos/kratos/v2/middleware"
midRateLimit "github.com/go-kratos/kratos/v2/middleware/ratelimit"
"github.com/go-kratos/kratos/v2/middleware/recovery"
"github.com/go-kratos/kratos/v2/middleware/tracing"
"github.com/go-kratos/kratos/v2/middleware/validate"
kratosRest "github.com/go-kratos/kratos/v2/transport/http"
conf "github.com/tx7do/kratos-bootstrap/api/gen/go/conf/v1"
"github.com/tx7do/kratos-bootstrap/utils"
)
// CreateRestServer builds the kratos HTTP (REST) server described by cfg.
//
// Server options are derived from cfg.Server.Rest by initRestConfig; mds are
// caller-supplied middlewares that are placed ahead of the configured ones.
// When the REST section enables pprof, the profiling endpoints are registered
// on the returned server.
//
// If cfg, cfg.Server or cfg.Server.Rest is nil, a server with kratos' default
// options is returned; in that case mds are not installed, because no option
// set is built at all.
func CreateRestServer(cfg *conf.Bootstrap, mds ...middleware.Middleware) *kratosRest.Server {
	// Guard the whole chain once so a nil bootstrap config does not panic on
	// the first field access.
	if cfg == nil || cfg.Server == nil || cfg.Server.Rest == nil {
		return kratosRest.NewServer()
	}

	srv := kratosRest.NewServer(initRestConfig(cfg, mds...)...)

	if cfg.Server.Rest.GetEnablePprof() {
		registerHttpPprof(srv)
	}

	return srv
}
// initRestConfig converts the REST section of cfg into kratos HTTP server
// options.
//
// The resulting options cover, in order: the CORS filter, the middleware
// chain (mds first, then the middlewares enabled in the configuration),
// network, address, timeout and TLS. It returns nil when cfg, cfg.Server or
// cfg.Server.Rest is nil.
//
// It panics if TLS is configured and the key/certificate material cannot be
// loaded, since the signature has no way to report the error.
func initRestConfig(cfg *conf.Bootstrap, mds ...middleware.Middleware) []kratosRest.ServerOption {
	if cfg == nil || cfg.Server == nil || cfg.Server.Rest == nil {
		return nil
	}
	rest := cfg.Server.Rest

	var options []kratosRest.ServerOption

	if cors := rest.Cors; cors != nil {
		options = append(options, kratosRest.Filter(handlers.CORS(
			handlers.AllowedHeaders(cors.Headers),
			handlers.AllowedMethods(cors.Methods),
			handlers.AllowedOrigins(cors.Origins),
		)))
	}

	// Caller-supplied middlewares come first, configured ones follow.
	var ms []middleware.Middleware
	ms = append(ms, mds...)

	if mw := rest.Middleware; mw != nil {
		if mw.GetEnableRecovery() {
			ms = append(ms, recovery.Recovery())
		}
		if mw.GetEnableTracing() {
			ms = append(ms, tracing.Server())
		}
		if mw.GetEnableValidate() {
			ms = append(ms, validate.Validator())
		}
		// The EnableCircuitBreaker flag has no effect here: no circuit-breaker
		// middleware is installed on the server side.

		if limiterCfg := mw.Limiter; limiterCfg != nil {
			var limiter ratelimit.Limiter
			switch limiterCfg.GetName() {
			case "bbr":
				limiter = bbr.NewLimiter()
			}
			if limiter == nil {
				// Unknown or empty limiter name. Passing a nil limiter to
				// WithLimiter would make the middleware dereference nil on
				// every request, so fall back to BBR instead.
				limiter = bbr.NewLimiter()
			}
			ms = append(ms, midRateLimit.Server(midRateLimit.WithLimiter(limiter)))
		}
	}
	options = append(options, kratosRest.Middleware(ms...))

	if rest.Network != "" {
		options = append(options, kratosRest.Network(rest.Network))
	}
	if rest.Addr != "" {
		options = append(options, kratosRest.Address(rest.Addr))
	}
	if rest.Timeout != nil {
		options = append(options, kratosRest.Timeout(rest.Timeout.AsDuration()))
	}

	if tlsConf := rest.Tls; tlsConf != nil {
		var tlsCfg *tls.Config
		var err error

		// File-based material takes precedence over inline PEM material.
		if tlsConf.File != nil {
			if tlsCfg, err = utils.LoadServerTlsConfigFile(
				tlsConf.File.GetKeyPath(),
				tlsConf.File.GetCertPath(),
				tlsConf.File.GetCaPath(),
				tlsConf.InsecureSkipVerify,
			); err != nil {
				panic(err)
			}
		}

		// Only consult the inline configuration when the file section did not
		// yield a TLS config.
		if tlsCfg == nil && tlsConf.Config != nil {
			if tlsCfg, err = utils.LoadServerTlsConfig(
				tlsConf.Config.GetKeyPem(),
				tlsConf.Config.GetCertPem(),
				tlsConf.Config.GetCaPem(),
				tlsConf.InsecureSkipVerify,
			); err != nil {
				panic(err)
			}
		}

		if tlsCfg != nil {
			options = append(options, kratosRest.TLSConfig(tlsCfg))
		}
	}

	return options
}
// registerHttpPprof registers the net/http/pprof handlers on s under the
// /debug/pprof path: the index, the dedicated cmdline/profile/symbol/trace
// handlers, and one handler per named runtime profile.
func registerHttpPprof(s *kratosRest.Server) {
	const prefix = "/debug/pprof"

	// Handlers that net/http/pprof exposes as dedicated functions.
	s.HandleFunc(prefix, pprof.Index)
	s.HandleFunc(prefix+"/cmdline", pprof.Cmdline)
	s.HandleFunc(prefix+"/profile", pprof.Profile)
	s.HandleFunc(prefix+"/symbol", pprof.Symbol)
	s.HandleFunc(prefix+"/trace", pprof.Trace)

	// Named profiles, each served through pprof.Handler.
	namedProfiles := []string{
		"allocs",
		"block",
		"goroutine",
		"heap",
		"mutex",
		"threadcreate",
	}
	for _, profile := range namedProfiles {
		s.HandleFunc(prefix+"/"+profile, pprof.Handler(profile).ServeHTTP)
	}
}