|
- package entry
- import (
- "bytes"
- "context"
- "errors"
- "io"
- "net"
- "sync/atomic"
- "time"
- "github.com/sourcegraph/conc"
- )
// minFeatureLength is the number of leading bytes handle reads from every
// incoming connection to pick a listener; Bind rejects shorter features.
const minFeatureLength = 3

// ErrShortFeature is returned by Bind when the feature has fewer than
// minFeatureLength bytes.
var ErrShortFeature = errors.New("short feature")

// ErrInvalidListener is returned by Bind when the supplied net.Listener is not
// a *Listener.
var ErrInvalidListener = errors.New("invalid listener")
- type (
- Feature []byte
- listenerEntity struct {
- feature Feature
- listener *Listener
- }
- Gateway struct {
- ctx context.Context
- cancelFunc context.CancelCauseFunc
- l net.Listener
- ch chan net.Conn
- address string
- state *State
- waitGroup conc.WaitGroup
- listeners []*listenerEntity
- direct *Listener
- exitFlag int32
- }
- )
- func (gw *Gateway) handle(conn net.Conn) {
- var (
- n int
- err error
- success int32
- feature = make([]byte, minFeatureLength)
- )
- atomic.AddInt32(&gw.state.Concurrency, 1)
- defer func() {
- if atomic.LoadInt32(&success) != 1 {
- atomic.AddInt32(&gw.state.Concurrency, -1)
- gw.state.IncRequestDiscarded(1)
- _ = conn.Close()
- }
- }()
-
- if err = conn.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil {
- return
- }
-
- if n, err = io.ReadFull(conn, feature); err != nil {
- return
- }
-
- if err = conn.SetReadDeadline(time.Time{}); err != nil {
- return
- }
- for _, l := range gw.listeners {
- if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
- atomic.StoreInt32(&success, 1)
- l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
- return
- }
- }
- }
- func (gw *Gateway) accept() {
- atomic.StoreInt32(&gw.state.Accepting, 1)
- defer func() {
- atomic.StoreInt32(&gw.state.Accepting, 0)
- }()
- for {
- if conn, err := gw.l.Accept(); err != nil {
- break
- } else {
-
- if gw.direct != nil {
- gw.direct.Receive(conn)
- } else {
- select {
- case gw.ch <- conn:
- gw.state.IncRequest(1)
- case <-gw.ctx.Done():
- return
- }
- }
- }
- }
- }
- func (gw *Gateway) worker() {
- atomic.StoreInt32(&gw.state.Processing, 1)
- defer func() {
- atomic.StoreInt32(&gw.state.Processing, 0)
- }()
- for {
- select {
- case <-gw.ctx.Done():
- return
- case conn, ok := <-gw.ch:
- if ok {
- gw.handle(conn)
- }
- }
- }
- }
- func (gw *Gateway) Direct(l net.Listener) {
- if ls, ok := l.(*Listener); ok {
- gw.direct = ls
- }
- }
- func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
- var (
- ok bool
- ls *Listener
- )
- if len(feature) < minFeatureLength {
- return ErrShortFeature
- }
- if ls, ok = listener.(*Listener); !ok {
- return ErrInvalidListener
- }
- for _, l := range gw.listeners {
- if bytes.Compare(l.feature, feature) == 0 {
- l.listener = ls
- return
- }
- }
- gw.listeners = append(gw.listeners, &listenerEntity{
- feature: feature,
- listener: ls,
- })
- return
- }
- func (gw *Gateway) Apply(feature ...Feature) (listener net.Listener, err error) {
- listener = newListener(gw.l.Addr())
- for _, code := range feature {
- if len(code) < minFeatureLength {
- continue
- }
- err = gw.Bind(code, listener)
- }
- return listener, nil
- }
- func (gw *Gateway) Release(feature Feature) {
- for i, l := range gw.listeners {
- if bytes.Compare(l.feature, feature) == 0 {
- gw.listeners = append(gw.listeners[:i], gw.listeners[i+1:]...)
- }
- }
- }
- func (gw *Gateway) State() *State {
- return gw.state
- }
- func (gw *Gateway) Start(ctx context.Context) (err error) {
- gw.ctx, gw.cancelFunc = context.WithCancelCause(ctx)
- if gw.l, err = net.Listen("tcp", gw.address); err != nil {
- return
- }
- for i := 0; i < 2; i++ {
- gw.waitGroup.Go(gw.worker)
- }
- gw.waitGroup.Go(gw.accept)
- return
- }
- func (gw *Gateway) Stop() (err error) {
- if !atomic.CompareAndSwapInt32(&gw.exitFlag, 0, 1) {
- return
- }
- gw.cancelFunc(io.ErrClosedPipe)
- err = gw.l.Close()
- gw.waitGroup.Wait()
- close(gw.ch)
- return
- }
- func New(address string) *Gateway {
- gw := &Gateway{
- address: address,
- state: &State{},
- ch: make(chan net.Conn, 10),
- }
- return gw
- }
|