package rpc import ( "context" "errors" "io" "net" "sync" "sync/atomic" "time" "git.nspix.com/golang/micro/log" ) var ( DefaultTimeout = time.Second * 5 ) type ( Client struct { conn net.Conn seq int32 once sync.Once isConnected int32 transactionLocker sync.RWMutex transaction map[uint16]*transaction exitFlag int32 exitChan chan struct{} network string address string connLock sync.Mutex pintAt time.Time Timeout time.Duration } transaction struct { sequence uint16 response *Response canceledFlag int32 ch chan *transaction } ) func (t *transaction) Cancel() { if atomic.CompareAndSwapInt32(&t.canceledFlag, 0, 1) { close(t.ch) } } func (t *transaction) Done(r *Response) { t.response = r if t.ch != nil && atomic.LoadInt32(&t.canceledFlag) == 0 { select { case t.ch <- t: default: } } } func (c *Client) commit(seq uint16) *transaction { c.transactionLocker.Lock() trans := &transaction{ sequence: seq, ch: make(chan *transaction), } c.transaction[seq] = trans c.transactionLocker.Unlock() return trans } func (c *Client) eventLoop() { ticker := time.NewTicker(time.Second * 10) defer ticker.Stop() for { select { case <-c.exitChan: return case <-ticker.C: if atomic.LoadInt32(&c.isConnected) == 1 { _ = writeFrame(c.conn, &Frame{ Func: FuncPing, }) } } } } func (c *Client) rdyLoop() { defer atomic.StoreInt32(&c.isConnected, 0) for { if frame, err := readFrame(c.conn); err == nil { switch frame.Func { case FuncPing: c.pintAt = time.Now() case FuncResponse: c.transactionLocker.RLock() ch, ok := c.transaction[frame.Sequence] c.transactionLocker.RUnlock() if ok { if res, err := ReadResponse(frame.Data); err == nil { ch.Done(res) } else { ch.Cancel() } } else { log.Warnf("RPC: connection %s response %d dropped", c.conn.LocalAddr(), frame.Sequence) } } } else { log.Infof("RPC: connection %s closed", c.conn.LocalAddr()) break } } } func (c *Client) DialerContext(ctx context.Context, network string, addr string) (err error) { var ( ok bool deadline time.Time ) if deadline, ok = ctx.Deadline(); !ok { deadline = time.Now().Add(c.Timeout) } c.network = network c.address = addr c.once.Do(func() { go c.eventLoop() }) return c.dialer(deadline.Sub(time.Now())) } func (c *Client) Dialer(network string, addr string) (err error) { c.network = network c.address = addr c.once.Do(func() { go c.eventLoop() }) return c.dialer(c.Timeout) } func (c *Client) dialer(timeout time.Duration) (err error) { c.connLock.Lock() defer c.connLock.Unlock() if atomic.LoadInt32(&c.isConnected) == 1 { return } if c.conn, err = net.DialTimeout(c.network, c.address, timeout); err != nil { return } else { atomic.StoreInt32(&c.isConnected, 1) go c.rdyLoop() } return } func (c *Client) Do(ctx context.Context, req *Request) (res *Response, err error) { if atomic.LoadInt32(&c.isConnected) == 0 { var ( ok bool deadline time.Time ) if deadline, ok = ctx.Deadline(); !ok { deadline = time.Now().Add(c.Timeout) } if err = c.dialer(deadline.Sub(time.Now())); err != nil { err = io.ErrClosedPipe return } } seq := uint16(atomic.AddInt32(&c.seq, 1)) if err = writeFrame(c.conn, &Frame{ Func: FuncRequest, Sequence: seq, Data: req.Bytes(), }); err != nil { return } trans := c.commit(seq) select { case t, ok := <-trans.ch: if ok { res = t.response } else { //canceled err = io.ErrClosedPipe } case <-c.exitChan: err = io.ErrClosedPipe case <-ctx.Done(): trans.Cancel() err = errors.New("Client.Timeout exceeded while awaiting response") } return } func (c *Client) Close() (err error) { if atomic.CompareAndSwapInt32(&c.exitFlag, 0, 1) { c.transactionLocker.Lock() for _, t := range c.transaction { t.Cancel() } c.transactionLocker.Unlock() if c.conn != nil { err = c.conn.Close() } c.isConnected = 0 close(c.exitChan) } return } func NewClient() *Client { return &Client{ Timeout: DefaultTimeout, exitChan: make(chan struct{}), transaction: make(map[uint16]*transaction), } }