gateway.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package entry
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "github.com/sourcegraph/conc"
  7. "io"
  8. "net"
  9. "sync/atomic"
  10. "time"
  11. )
  12. const (
  13. minFeatureLength = 3
  14. )
  15. var (
  16. ErrShortFeature = errors.New("short feature")
  17. ErrInvalidListener = errors.New("invalid listener")
  18. )
  19. type (
  20. Feature []byte
  21. listenerEntity struct {
  22. feature Feature
  23. listener *Listener
  24. }
  25. Gateway struct {
  26. ctx context.Context
  27. cancelFunc context.CancelCauseFunc
  28. l net.Listener
  29. ch chan net.Conn
  30. address string
  31. state *State
  32. waitGroup conc.WaitGroup
  33. listeners []*listenerEntity
  34. exitFlag int32
  35. }
  36. )
  37. func (gw *Gateway) handle(conn net.Conn) {
  38. var (
  39. n int
  40. err error
  41. successed int32
  42. feature = make([]byte, minFeatureLength)
  43. )
  44. atomic.AddInt32(&gw.state.Concurrency, 1)
  45. defer func() {
  46. if atomic.LoadInt32(&successed) != 1 {
  47. atomic.AddInt32(&gw.state.Concurrency, -1)
  48. atomic.AddInt64(&gw.state.Request.Discarded, 1)
  49. _ = conn.Close()
  50. }
  51. }()
  52. //set deadline
  53. if err = conn.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil {
  54. return
  55. }
  56. //read feature
  57. if n, err = io.ReadFull(conn, feature); err != nil {
  58. return
  59. }
  60. //reset deadline
  61. if err = conn.SetReadDeadline(time.Time{}); err != nil {
  62. return
  63. }
  64. for _, l := range gw.listeners {
  65. if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
  66. atomic.StoreInt32(&successed, 1)
  67. l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
  68. return
  69. }
  70. }
  71. }
  72. func (gw *Gateway) accept() {
  73. atomic.StoreInt32(&gw.state.Accepting, 1)
  74. defer func() {
  75. atomic.StoreInt32(&gw.state.Accepting, 0)
  76. }()
  77. for {
  78. if conn, err := gw.l.Accept(); err != nil {
  79. break
  80. } else {
  81. select {
  82. case gw.ch <- conn:
  83. atomic.AddInt64(&gw.state.Request.Total, 1)
  84. case <-gw.ctx.Done():
  85. return
  86. }
  87. }
  88. }
  89. }
  90. func (gw *Gateway) worker() {
  91. atomic.StoreInt32(&gw.state.Processing, 1)
  92. defer func() {
  93. atomic.StoreInt32(&gw.state.Processing, 0)
  94. }()
  95. for {
  96. select {
  97. case <-gw.ctx.Done():
  98. return
  99. case conn, ok := <-gw.ch:
  100. if ok {
  101. gw.handle(conn)
  102. }
  103. }
  104. }
  105. }
  106. func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
  107. var (
  108. ok bool
  109. ls *Listener
  110. )
  111. if len(feature) < minFeatureLength {
  112. return ErrShortFeature
  113. }
  114. if ls, ok = listener.(*Listener); !ok {
  115. return ErrInvalidListener
  116. }
  117. for _, l := range gw.listeners {
  118. if bytes.Compare(l.feature, feature) == 0 {
  119. l.listener = ls
  120. return
  121. }
  122. }
  123. gw.listeners = append(gw.listeners, &listenerEntity{
  124. feature: feature,
  125. listener: ls,
  126. })
  127. return
  128. }
  129. func (gw *Gateway) Apply(feature ...Feature) (listener net.Listener, err error) {
  130. listener = newListener(gw.l.Addr())
  131. for _, code := range feature {
  132. if len(code) < minFeatureLength {
  133. continue
  134. }
  135. err = gw.Bind(code, listener)
  136. }
  137. return listener, nil
  138. }
  139. func (gw *Gateway) Release(feature Feature) {
  140. for i, l := range gw.listeners {
  141. if bytes.Compare(l.feature, feature) == 0 {
  142. gw.listeners = append(gw.listeners[:i], gw.listeners[i+1:]...)
  143. }
  144. }
  145. }
  146. func (gw *Gateway) State() *State {
  147. return gw.state
  148. }
  149. func (gw *Gateway) Start(ctx context.Context) (err error) {
  150. gw.ctx, gw.cancelFunc = context.WithCancelCause(ctx)
  151. if gw.l, err = net.Listen("tcp", gw.address); err != nil {
  152. return
  153. }
  154. gw.waitGroup.Go(gw.worker)
  155. gw.waitGroup.Go(gw.accept)
  156. return
  157. }
  158. func (gw *Gateway) Stop() (err error) {
  159. if !atomic.CompareAndSwapInt32(&gw.exitFlag, 0, 1) {
  160. return
  161. }
  162. gw.cancelFunc(io.ErrClosedPipe)
  163. err = gw.l.Close()
  164. gw.waitGroup.Wait()
  165. close(gw.ch)
  166. return
  167. }
  168. func New(address string) *Gateway {
  169. gw := &Gateway{
  170. address: address,
  171. state: &State{},
  172. ch: make(chan net.Conn, 10),
  173. }
  174. return gw
  175. }