gateway.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package gateway
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "io"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. )
  // MinFeatureLength is the number of leading bytes the gateway reads from each
// new connection to decide where to route it. Features shorter than this are
// rejected by Attach.
const MinFeatureLength = 3

// Sentinel errors returned when registering features.
var (
	// ErrShortFeature reports a feature shorter than MinFeatureLength.
	ErrShortFeature = errors.New("short feature")
	// ErrFeatureExists reports a feature that is already registered.
	ErrFeatureExists = errors.New("feature already exists")
)
  19. type (
  20. listener struct {
  21. feature []byte
  22. l *Listener
  23. }
  24. Gateway struct {
  25. listeners []*listener
  26. l net.Listener
  27. ch chan net.Conn
  28. state *State
  29. }
  30. )
  31. func (g *Gateway) Attach(feature []byte, l *Listener) (err error) {
  32. //特征量不够大
  33. if len(feature) < MinFeatureLength {
  34. return ErrShortFeature
  35. }
  36. //判断重复
  37. for _, v := range g.listeners {
  38. if bytes.Equal(v.feature, feature) {
  39. return ErrFeatureExists
  40. }
  41. }
  42. g.listeners = append(g.listeners, &listener{
  43. feature: feature,
  44. l: l,
  45. })
  46. return
  47. }
  48. func (g *Gateway) Attaches(features [][]byte, l *Listener) (err error) {
  49. for _, b := range features {
  50. if err = g.Attach(b, l); err != nil {
  51. break
  52. }
  53. }
  54. return
  55. }
  56. func (g *Gateway) Detach(feature []byte) (err error) {
  57. for i, l := range g.listeners {
  58. if bytes.Equal(l.feature, feature) {
  59. g.listeners = append(g.listeners[:i], g.listeners[i+1:]...)
  60. break
  61. }
  62. }
  63. return
  64. }
  65. func (g *Gateway) process(conn net.Conn) {
  66. var (
  67. n int
  68. err error
  69. feature = make([]byte, MinFeatureLength)
  70. )
  71. atomic.AddInt32(&g.state.NumOfRequest, 1)
  72. //set deadline
  73. if err = conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
  74. return
  75. }
  76. if n, err = io.ReadFull(conn, feature); err != nil {
  77. _ = conn.Close()
  78. return
  79. }
  80. //reset deadline
  81. if err = conn.SetReadDeadline(time.Time{}); err != nil {
  82. return
  83. }
  84. for _, l := range g.listeners {
  85. if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
  86. l.l.Receive(wrapConn(conn, g.state, feature[:n]))
  87. break
  88. }
  89. }
  90. }
  91. func (g *Gateway) worker(ctx context.Context) {
  92. var (
  93. ok bool
  94. conn net.Conn
  95. )
  96. for {
  97. select {
  98. case conn, ok = <-g.ch:
  99. if ok {
  100. g.process(conn)
  101. }
  102. case <-ctx.Done():
  103. return
  104. }
  105. }
  106. }
  107. func (g *Gateway) schedule(ctx context.Context) {
  108. for {
  109. conn, err := g.l.Accept()
  110. if err != nil {
  111. return
  112. }
  113. select {
  114. case g.ch <- conn:
  115. case <-ctx.Done():
  116. return
  117. }
  118. }
  119. }
  120. func (g *Gateway) State() *State {
  121. return g.state
  122. }
  123. //Run 运行项目
  124. func (g *Gateway) Run(ctx context.Context) {
  125. var wg sync.WaitGroup
  126. wg.Add(2)
  127. go func() {
  128. g.worker(ctx)
  129. wg.Done()
  130. }()
  131. go func() {
  132. g.schedule(ctx)
  133. wg.Done()
  134. }()
  135. wg.Wait()
  136. for _, l := range g.listeners {
  137. _ = l.l.Close()
  138. }
  139. }
  140. func New(l net.Listener) *Gateway {
  141. return &Gateway{
  142. l: l,
  143. state: &State{},
  144. ch: make(chan net.Conn, 10),
  145. listeners: make([]*listener, 0),
  146. }
  147. }