client.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package rpc
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  11. var (
  12. DefaultTimeout = time.Second * 5
  13. )
  14. type (
  15. Client struct {
  16. conn net.Conn
  17. seq uint16
  18. once sync.Once
  19. isConnected int32
  20. transactionLocker sync.RWMutex
  21. transaction map[uint16]*transaction
  22. exitFlag int32
  23. exitChan chan struct{}
  24. network string
  25. address string
  26. connLock sync.Mutex
  27. pintAt time.Time
  28. Timeout time.Duration
  29. }
  30. transaction struct {
  31. sequence uint16
  32. response *Response
  33. isCanceled bool
  34. ch chan *transaction
  35. }
  36. )
  37. func (t *transaction) Cancel() {
  38. t.isCanceled = true
  39. close(t.ch)
  40. }
  41. func (t *transaction) Done(r *Response) {
  42. t.response = r
  43. if t.ch != nil && !t.isCanceled {
  44. select {
  45. case t.ch <- t:
  46. default:
  47. }
  48. }
  49. }
  50. func (c *Client) commit(seq uint16) *transaction {
  51. c.transactionLocker.Lock()
  52. trans := &transaction{
  53. sequence: seq,
  54. isCanceled: false,
  55. ch: make(chan *transaction),
  56. }
  57. c.transaction[seq] = trans
  58. c.transactionLocker.Unlock()
  59. return trans
  60. }
  61. func (c *Client) eventLoop() {
  62. ticker := time.NewTicker(time.Second * 10)
  63. defer ticker.Stop()
  64. for {
  65. select {
  66. case <-c.exitChan:
  67. return
  68. case <-ticker.C:
  69. if atomic.LoadInt32(&c.isConnected) == 1 {
  70. _ = writeFrame(c.conn, &Frame{
  71. Func: FuncPing,
  72. })
  73. }
  74. }
  75. }
  76. }
  77. func (c *Client) rdyLoop() {
  78. defer atomic.StoreInt32(&c.isConnected, 0)
  79. for {
  80. if frame, err := readFrame(c.conn); err == nil {
  81. if frame.Func == FuncResponse {
  82. c.transactionLocker.RLock()
  83. ch, ok := c.transaction[frame.Sequence]
  84. c.transactionLocker.RUnlock()
  85. if ok {
  86. if res, err := ReadResponse(frame.Data); err == nil {
  87. ch.Done(res)
  88. } else {
  89. ch.Cancel()
  90. }
  91. }
  92. } else if frame.Func == FuncPing {
  93. c.pintAt = time.Now()
  94. }
  95. } else {
  96. break
  97. }
  98. }
  99. }
  100. func (c *Client) DialerContext(ctx context.Context, network string, addr string) (err error) {
  101. var (
  102. ok bool
  103. deadline time.Time
  104. )
  105. if deadline, ok = ctx.Deadline(); !ok {
  106. deadline = time.Now().Add(c.Timeout)
  107. }
  108. c.network = network
  109. c.address = addr
  110. c.once.Do(func() {
  111. go c.eventLoop()
  112. })
  113. return c.dialer(deadline.Sub(time.Now()))
  114. }
  115. func (c *Client) Dialer(network string, addr string) (err error) {
  116. c.network = network
  117. c.address = addr
  118. c.once.Do(func() {
  119. go c.eventLoop()
  120. })
  121. return c.dialer(c.Timeout)
  122. }
  123. func (c *Client) dialer(timeout time.Duration) (err error) {
  124. c.connLock.Lock()
  125. defer c.connLock.Unlock()
  126. if atomic.LoadInt32(&c.isConnected) == 1 {
  127. return
  128. }
  129. if c.conn, err = net.DialTimeout(c.network, c.address, timeout); err != nil {
  130. return
  131. } else {
  132. atomic.StoreInt32(&c.isConnected, 1)
  133. go c.rdyLoop()
  134. }
  135. return
  136. }
  137. func (c *Client) Do(ctx context.Context, req *Request) (res *Response, err error) {
  138. if atomic.LoadInt32(&c.isConnected) == 0 {
  139. var (
  140. ok bool
  141. deadline time.Time
  142. )
  143. if deadline, ok = ctx.Deadline(); !ok {
  144. deadline = time.Now().Add(c.Timeout)
  145. }
  146. if err = c.dialer(deadline.Sub(time.Now())); err != nil {
  147. err = io.ErrClosedPipe
  148. return
  149. }
  150. }
  151. c.seq++
  152. seq := c.seq
  153. if err = writeFrame(c.conn, &Frame{
  154. Func: FuncRequest,
  155. Sequence: seq,
  156. Data: req.Bytes(),
  157. }); err != nil {
  158. return
  159. }
  160. trans := c.commit(seq)
  161. select {
  162. case t, ok := <-trans.ch:
  163. if ok {
  164. res = t.response
  165. } else {
  166. err = io.ErrClosedPipe
  167. }
  168. case <-ctx.Done():
  169. trans.Cancel()
  170. err = errors.New("Client.Timeout exceeded while awaiting response")
  171. }
  172. return
  173. }
  174. func (c *Client) Close() (err error) {
  175. if atomic.CompareAndSwapInt32(&c.exitFlag, 0, 1) {
  176. if c.conn != nil {
  177. err = c.conn.Close()
  178. }
  179. c.isConnected = 0
  180. close(c.exitChan)
  181. }
  182. return
  183. }
  184. func NewClient() *Client {
  185. return &Client{
  186. Timeout: DefaultTimeout,
  187. exitChan: make(chan struct{}),
  188. transaction: make(map[uint16]*transaction),
  189. }
  190. }