// gateway.go (2.5 KB)
  1. package gateway
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "io"
  7. "net"
  8. "sync"
  9. "time"
  10. )
  // MinFeatureLength is the number of leading bytes read from every incoming
// connection to choose its listener. Attach rejects features shorter than this.
const MinFeatureLength = 3

// ErrShortFeature is returned by Attach when a feature is shorter than
// MinFeatureLength.
var ErrShortFeature = errors.New("short feature")

// ErrFeatureExists is returned by Attach when an identical feature is already
// attached.
var ErrFeatureExists = errors.New("feature already exists")
  18. type (
  19. listener struct {
  20. feature []byte
  21. l *Listener
  22. }
  23. Gateway struct {
  24. listeners []*listener
  25. l net.Listener
  26. ch chan net.Conn
  27. enableStats bool
  28. }
  29. )
  30. func (g *Gateway) Attach(feature []byte, l *Listener) (err error) {
  31. //特征量不够大
  32. if len(feature) < MinFeatureLength {
  33. return ErrShortFeature
  34. }
  35. //判断重复
  36. for _, v := range g.listeners {
  37. if bytes.Equal(v.feature, feature) {
  38. return ErrFeatureExists
  39. }
  40. }
  41. g.listeners = append(g.listeners, &listener{
  42. feature: feature,
  43. l: l,
  44. })
  45. return
  46. }
  47. func (g *Gateway) Attaches(features [][]byte, l *Listener) (err error) {
  48. for _, b := range features {
  49. if err = g.Attach(b, l); err != nil {
  50. break
  51. }
  52. }
  53. return
  54. }
  55. func (g *Gateway) Detach(feature []byte) (err error) {
  56. for i, l := range g.listeners {
  57. if bytes.Equal(l.feature, feature) {
  58. g.listeners = append(g.listeners[:i], g.listeners[i+1:]...)
  59. break
  60. }
  61. }
  62. return
  63. }
  64. func (g *Gateway) process(conn net.Conn) {
  65. var (
  66. n int
  67. err error
  68. feature = make([]byte, MinFeatureLength)
  69. )
  70. //set deadline
  71. if err = conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
  72. return
  73. }
  74. if n, err = io.ReadFull(conn, feature); err != nil {
  75. _ = conn.Close()
  76. return
  77. }
  78. //reset deadline
  79. if err = conn.SetReadDeadline(time.Time{}); err != nil {
  80. return
  81. }
  82. for _, l := range g.listeners {
  83. if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
  84. l.l.Receive(wrapConn(conn, feature[:n]))
  85. break
  86. }
  87. }
  88. }
  89. func (g *Gateway) worker(ctx context.Context) {
  90. var (
  91. ok bool
  92. conn net.Conn
  93. )
  94. for {
  95. select {
  96. case conn, ok = <-g.ch:
  97. if ok {
  98. g.process(conn)
  99. }
  100. case <-ctx.Done():
  101. return
  102. }
  103. }
  104. }
  105. func (g *Gateway) schedule(ctx context.Context) {
  106. for {
  107. conn, err := g.l.Accept()
  108. if err != nil {
  109. return
  110. }
  111. select {
  112. case g.ch <- conn:
  113. case <-ctx.Done():
  114. return
  115. }
  116. }
  117. }
  118. //Run 运行项目
  119. func (g *Gateway) Run(ctx context.Context) {
  120. var wg sync.WaitGroup
  121. wg.Add(2)
  122. go func() {
  123. g.worker(ctx)
  124. wg.Done()
  125. }()
  126. go func() {
  127. g.schedule(ctx)
  128. wg.Done()
  129. }()
  130. wg.Wait()
  131. for _, l := range g.listeners {
  132. _ = l.l.Close()
  133. }
  134. }
  135. func New(l net.Listener, stats bool) *Gateway {
  136. return &Gateway{
  137. l: l,
  138. enableStats: stats,
  139. ch: make(chan net.Conn, 10),
  140. listeners: make([]*listener, 0),
  141. }
  142. }