gateway.go 3.8 KB


  1. package entry
import (
	"bytes"
	"context"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/sourcegraph/conc"
)
  12. const (
  13. minFeatureLength = 3
  14. )
  15. var (
  16. ErrShortFeature = errors.New("short feature")
  17. ErrInvalidListener = errors.New("invalid listener")
  18. )
  19. type (
  20. Feature []byte
  21. listenerEntity struct {
  22. feature Feature
  23. listener *Listener
  24. }
  25. Gateway struct {
  26. ctx context.Context
  27. cancelFunc context.CancelCauseFunc
  28. l net.Listener
  29. ch chan net.Conn
  30. address string
  31. state *State
  32. waitGroup conc.WaitGroup
  33. listeners []*listenerEntity
  34. direct *Listener
  35. exitFlag int32
  36. }
  37. )
  38. func (gw *Gateway) handle(conn net.Conn) {
  39. var (
  40. n int
  41. err error
  42. success int32
  43. feature = make([]byte, minFeatureLength)
  44. )
  45. atomic.AddInt32(&gw.state.Concurrency, 1)
  46. defer func() {
  47. if atomic.LoadInt32(&success) != 1 {
  48. atomic.AddInt32(&gw.state.Concurrency, -1)
  49. gw.state.IncRequestDiscarded(1)
  50. _ = conn.Close()
  51. }
  52. }()
  53. //set deadline
  54. if err = conn.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil {
  55. return
  56. }
  57. //read feature
  58. if n, err = io.ReadFull(conn, feature); err != nil {
  59. return
  60. }
  61. //reset deadline
  62. if err = conn.SetReadDeadline(time.Time{}); err != nil {
  63. return
  64. }
  65. for _, l := range gw.listeners {
  66. if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
  67. atomic.StoreInt32(&success, 1)
  68. l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
  69. return
  70. }
  71. }
  72. }
  73. func (gw *Gateway) accept() {
  74. atomic.StoreInt32(&gw.state.Accepting, 1)
  75. defer func() {
  76. atomic.StoreInt32(&gw.state.Accepting, 0)
  77. }()
  78. for {
  79. if conn, err := gw.l.Accept(); err != nil {
  80. break
  81. } else {
  82. //give direct listener
  83. if gw.direct != nil {
  84. gw.direct.Receive(conn)
  85. } else {
  86. select {
  87. case gw.ch <- conn:
  88. gw.state.IncRequest(1)
  89. case <-gw.ctx.Done():
  90. return
  91. }
  92. }
  93. }
  94. }
  95. }
  96. func (gw *Gateway) worker() {
  97. atomic.StoreInt32(&gw.state.Processing, 1)
  98. defer func() {
  99. atomic.StoreInt32(&gw.state.Processing, 0)
  100. }()
  101. for {
  102. select {
  103. case <-gw.ctx.Done():
  104. return
  105. case conn, ok := <-gw.ch:
  106. if ok {
  107. gw.handle(conn)
  108. }
  109. }
  110. }
  111. }
  112. func (gw *Gateway) Direct(l net.Listener) {
  113. if ls, ok := l.(*Listener); ok {
  114. gw.direct = ls
  115. }
  116. }
  117. func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
  118. var (
  119. ok bool
  120. ls *Listener
  121. )
  122. if len(feature) < minFeatureLength {
  123. return ErrShortFeature
  124. }
  125. if ls, ok = listener.(*Listener); !ok {
  126. return ErrInvalidListener
  127. }
  128. for _, l := range gw.listeners {
  129. if bytes.Compare(l.feature, feature) == 0 {
  130. l.listener = ls
  131. return
  132. }
  133. }
  134. gw.listeners = append(gw.listeners, &listenerEntity{
  135. feature: feature,
  136. listener: ls,
  137. })
  138. return
  139. }
  140. func (gw *Gateway) Apply(feature ...Feature) (listener net.Listener, err error) {
  141. listener = newListener(gw.l.Addr())
  142. for _, code := range feature {
  143. if len(code) < minFeatureLength {
  144. continue
  145. }
  146. err = gw.Bind(code, listener)
  147. }
  148. return listener, nil
  149. }
  150. func (gw *Gateway) Release(feature Feature) {
  151. for i, l := range gw.listeners {
  152. if bytes.Compare(l.feature, feature) == 0 {
  153. gw.listeners = append(gw.listeners[:i], gw.listeners[i+1:]...)
  154. }
  155. }
  156. }
  157. func (gw *Gateway) State() *State {
  158. return gw.state
  159. }
  160. func (gw *Gateway) Start(ctx context.Context) (err error) {
  161. gw.ctx, gw.cancelFunc = context.WithCancelCause(ctx)
  162. if gw.l, err = net.Listen("tcp", gw.address); err != nil {
  163. return
  164. }
  165. for i := 0; i < 2; i++ {
  166. gw.waitGroup.Go(gw.worker)
  167. }
  168. gw.waitGroup.Go(gw.accept)
  169. return
  170. }
  171. func (gw *Gateway) Stop() (err error) {
  172. if !atomic.CompareAndSwapInt32(&gw.exitFlag, 0, 1) {
  173. return
  174. }
  175. gw.cancelFunc(io.ErrClosedPipe)
  176. err = gw.l.Close()
  177. gw.waitGroup.Wait()
  178. close(gw.ch)
  179. return
  180. }
  181. func New(address string) *Gateway {
  182. gw := &Gateway{
  183. address: address,
  184. state: &State{},
  185. ch: make(chan net.Conn, 10),
  186. }
  187. return gw
  188. }