Bläddra i källkod

add gateway support direct connect

fancl 1 år sedan
förälder
incheckning
dc88ceb73d
8 ändrade filer med 121 tillägg och 38 borttagningar
  1. 1 0
      cmd/main.go
  2. 26 12
      entry/gateway.go
  3. 4 4
      entry/http/context.go
  4. 22 7
      entry/http/server.go
  5. 30 14
      options.go
  6. 8 0
      service.go
  7. 22 1
      util/env/env.go
  8. 8 0
      util/fetch/fetch.go

+ 1 - 0
cmd/main.go

@@ -29,6 +29,7 @@ func main() {
 	svr := kos.Init(
 		kos.WithName("git.nspix.com/golang/test", "0.0.1"),
 		kos.WithServer(&subServer{}),
+		kos.WithDirectHttp(),
 	)
 	svr.Run()
 }

+ 26 - 12
entry/gateway.go

@@ -37,20 +37,21 @@ type (
 		state      *State
 		waitGroup  conc.WaitGroup
 		listeners  []*listenerEntity
+		direct     *Listener
 		exitFlag   int32
 	}
 )
 
 func (gw *Gateway) handle(conn net.Conn) {
 	var (
-		n         int
-		err       error
-		successed int32
-		feature   = make([]byte, minFeatureLength)
+		n       int
+		err     error
+		success int32
+		feature = make([]byte, minFeatureLength)
 	)
 	atomic.AddInt32(&gw.state.Concurrency, 1)
 	defer func() {
-		if atomic.LoadInt32(&successed) != 1 {
+		if atomic.LoadInt32(&success) != 1 {
 			atomic.AddInt32(&gw.state.Concurrency, -1)
 			atomic.AddInt64(&gw.state.Request.Discarded, 1)
 			_ = conn.Close()
@@ -70,7 +71,7 @@ func (gw *Gateway) handle(conn net.Conn) {
 	}
 	for _, l := range gw.listeners {
 		if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
-			atomic.StoreInt32(&successed, 1)
+			atomic.StoreInt32(&success, 1)
 			l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
 			return
 		}
@@ -86,11 +87,16 @@ func (gw *Gateway) accept() {
 		if conn, err := gw.l.Accept(); err != nil {
 			break
 		} else {
-			select {
-			case gw.ch <- conn:
-				atomic.AddInt64(&gw.state.Request.Total, 1)
-			case <-gw.ctx.Done():
-				return
+			//give direct listener
+			if gw.direct != nil {
+				gw.direct.Receive(conn)
+			} else {
+				select {
+				case gw.ch <- conn:
+					atomic.AddInt64(&gw.state.Request.Total, 1)
+				case <-gw.ctx.Done():
+					return
+				}
 			}
 		}
 	}
@@ -113,6 +119,12 @@ func (gw *Gateway) worker() {
 	}
 }
 
+func (gw *Gateway) Direct(l net.Listener) {
+	if ls, ok := l.(*Listener); ok {
+		gw.direct = ls
+	}
+}
+
 func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
 	var (
 		ok bool
@@ -165,7 +177,9 @@ func (gw *Gateway) Start(ctx context.Context) (err error) {
 	if gw.l, err = net.Listen("tcp", gw.address); err != nil {
 		return
 	}
-	gw.waitGroup.Go(gw.worker)
+	for i := 0; i < 2; i++ {
+		gw.waitGroup.Go(gw.worker)
+	}
 	gw.waitGroup.Go(gw.accept)
 	return
 }

+ 4 - 4
entry/http/context.go

@@ -27,18 +27,18 @@ func (ctx *Context) reset(req *http.Request, res http.ResponseWriter, ps map[str
 	ctx.req, ctx.res, ctx.params = req, res, ps
 }
 
-func (c *Context) RealIp() string {
-	if ip := c.Request().Header.Get("X-Forwarded-For"); ip != "" {
+func (ctx *Context) RealIp() string {
+	if ip := ctx.Request().Header.Get("X-Forwarded-For"); ip != "" {
 		i := strings.IndexAny(ip, ",")
 		if i > 0 {
 			return strings.TrimSpace(ip[:i])
 		}
 		return ip
 	}
-	if ip := c.Request().Header.Get("X-Real-IP"); ip != "" {
+	if ip := ctx.Request().Header.Get("X-Real-IP"); ip != "" {
 		return ip
 	}
-	ra, _, _ := net.SplitHostPort(c.Request().RemoteAddr)
+	ra, _, _ := net.SplitHostPort(ctx.Request().RemoteAddr)
 	return ra
 }
 

+ 22 - 7
entry/http/server.go

@@ -16,10 +16,11 @@ var (
 )
 
 type Server struct {
-	ctx        context.Context
-	serve      *http.Server
-	router     *router.Router
-	middleware []Middleware
+	ctx         context.Context
+	serve       *http.Server
+	router      *router.Router
+	middleware  []Middleware
+	anyRequests map[string]http.Handler
 }
 
 func (svr *Server) applyContext() *Context {
@@ -62,6 +63,13 @@ func (svr *Server) Use(middleware ...Middleware) {
 	svr.middleware = append(svr.middleware, middleware...)
 }
 
+func (svr *Server) Any(prefix string, handle http.Handler) {
+	if !strings.HasSuffix(prefix, "/") {
+		prefix = "/" + prefix
+	}
+	svr.anyRequests[prefix] = handle
+}
+
 func (svr *Server) Handle(method string, path string, cb HandleFunc, middleware ...Middleware) {
 	if method == "" {
 		method = http.MethodPost
@@ -149,6 +157,12 @@ func (svr *Server) handleRequest(res http.ResponseWriter, req *http.Request) {
 }
 
 func (svr *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+	for prefix, handle := range svr.anyRequests {
+		if strings.HasPrefix(request.URL.Path, prefix) {
+			handle.ServeHTTP(writer, request)
+			return
+		}
+	}
 	switch request.Method {
 	case http.MethodOptions:
 		svr.handleOption(writer, request)
@@ -175,9 +189,10 @@ func (svr *Server) Shutdown() (err error) {
 
 func New(ctx context.Context) *Server {
 	svr := &Server{
-		ctx:        ctx,
-		router:     router.New(),
-		middleware: make([]Middleware, 0, 10),
+		ctx:         ctx,
+		router:      router.New(),
+		anyRequests: make(map[string]http.Handler),
+		middleware:  make([]Middleware, 0, 10),
 	}
 	return svr
 }

+ 30 - 14
options.go

@@ -12,19 +12,21 @@ import (
 
 type (
 	Options struct {
-		Name            string
-		Version         string
-		Address         string
-		Port            int
-		EnableDebug     bool              //开启调试模式
-		DisableHttp     bool              //禁用HTTP入口
-		DisableCommand  bool              //禁用命令行入口
-		DisableStateApi bool              //禁用系统状态接口
-		Metadata        map[string]string //原数据
-		Context         context.Context
-		Signals         []os.Signal
-		server          Server
-		shortName       string
+		Name                string
+		Version             string
+		Address             string
+		Port                int
+		EnableDebug         bool              //开启调试模式
+		DisableHttp         bool              //禁用HTTP入口
+		EnableDirectHttp    bool              //启用HTTP直连模式
+		DisableCommand      bool              //禁用命令行入口
+		EnableDirectCommand bool              //启用命令行直连模式
+		DisableStateApi     bool              //禁用系统状态接口
+		Metadata            map[string]string //原数据
+		Context             context.Context
+		Signals             []os.Signal
+		server              Server
+		shortName           string
 	}
 
 	Option func(o *Options)
@@ -67,6 +69,20 @@ func WithDebug() Option {
 	}
 }
 
+func WithDirectHttp() Option {
+	return func(o *Options) {
+		o.DisableCommand = true
+		o.EnableDirectHttp = true
+	}
+}
+
+func WithDirectCommand() Option {
+	return func(o *Options) {
+		o.DisableHttp = true
+		o.EnableDirectCommand = true
+	}
+}
+
 func NewOptions() *Options {
 	opts := &Options{
 		Name:     env.Get(EnvAppName, sys.Hostname()),
@@ -75,7 +91,7 @@ func NewOptions() *Options {
 		Metadata: make(map[string]string),
 		Signals:  []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL},
 	}
-	opts.Port = int(env.Integer(EnvAppPort, 80))
+	opts.Port = int(env.Integer(18080, EnvAppPort, "HTTP_PORT", "KOS_PORT"))
 	opts.Address = env.Get(EnvAppAddress, ip.Internal())
 	return opts
 }

+ 8 - 0
service.go

@@ -8,6 +8,7 @@ import (
 	"git.nspix.com/golang/kos/entry"
 	"git.nspix.com/golang/kos/entry/cli"
 	"git.nspix.com/golang/kos/entry/http"
+	_ "git.nspix.com/golang/kos/pkg/cache"
 	"git.nspix.com/golang/kos/pkg/log"
 	"git.nspix.com/golang/kos/util/env"
 	"github.com/sourcegraph/conc"
@@ -127,6 +128,9 @@ func (app *application) httpServe() (err error) {
 	select {
 	case err = <-errChan:
 	case <-timer.C:
+		if app.opts.EnableDirectHttp {
+			app.gateway.Direct(l)
+		}
 	}
 	return
 }
@@ -152,6 +156,9 @@ func (app *application) commandServe() (err error) {
 	select {
 	case err = <-errChan:
 	case <-timer.C:
+		if app.opts.EnableDirectCommand {
+			app.gateway.Direct(l)
+		}
 	}
 	return
 }
@@ -220,6 +227,7 @@ func (app *application) preStart() (err error) {
 			return
 		}
 	}
+
 	app.plugins.Range(func(key, value any) bool {
 		if plugin, ok := value.(Plugin); ok {
 			if err = plugin.BeforeStart(); err != nil {

+ 22 - 1
util/env/env.go

@@ -15,7 +15,19 @@ func Get(name string, val string) string {
 	}
 }
 
-func Integer(name string, val int64) int64 {
+func Getter(val string, names ...string) string {
+	var (
+		value string
+	)
+	for _, name := range names {
+		if value = strings.TrimSpace(os.Getenv(name)); value != "" {
+			return value
+		}
+	}
+	return val
+}
+
+func Int(name string, val int64) int64 {
 	value := Get(name, "")
 	if n, err := strconv.ParseInt(value, 10, 64); err == nil {
 		return n
@@ -24,6 +36,15 @@ func Integer(name string, val int64) int64 {
 	}
 }
 
+func Integer(val int64, names ...string) int64 {
+	value := Getter("", names...)
+	if n, err := strconv.ParseInt(value, 10, 64); err == nil {
+		return n
+	} else {
+		return val
+	}
+}
+
 func Float(name string, val float64) float64 {
 	value := Get(name, "")
 	if n, err := strconv.ParseFloat(value, 64); err == nil {

+ 8 - 0
util/fetch/fetch.go

@@ -192,6 +192,14 @@ func Request(ctx context.Context, urlString string, response any, cbs ...Option)
 	return
 }
 
+func Do(ctx context.Context, req *http.Request, cbs ...Option) (res *http.Response, err error) {
+	opts := newOptions()
+	for _, cb := range cbs {
+		cb(opts)
+	}
+	return do(ctx, req, opts)
+}
+
 func do(ctx context.Context, req *http.Request, opts *Options) (res *http.Response, err error) {
 	if opts.Human {
 		if req.Header.Get("User-Agent") == "" {