package rpc

import (
	"net"
	"sync"
	"sync/atomic"

	"git.nspix.com/golang/micro/log"
)

// HandleFunc is the signature of a method handler registered with
// Server.Handle. The handler reads the request from ctx and fills in
// ctx.Response(). If it returns a non-nil error, the caller receives a
// response with code 500 and the error text as its message instead.
//
// The *Context is recycled once the response has been written, so a handler
// must not keep a reference to ctx after it returns.
type HandleFunc func(ctx *Context) error

// Server dispatches RPC request frames received on accepted connections to
// the HandleFunc registered for the request's method.
//
// Requests from all connections are queued on a single buffered channel and
// executed one at a time by a single worker goroutine (wrkLoop).
type Server struct {
	listener   net.Listener  // set by Serve
	ch         chan *Request // queued requests, consumed by wrkLoop
	ctxPool    sync.Pool     // recycled *Context values
	serviceMap sync.Map      // map[string]HandleFunc
	sessions   sync.Map      // map[net.Conn]net.Conn, live connections
	exitFlag   int32         // switched 0 -> 1 by the first Close call
	exitChan   chan struct{} // closed by Close to stop wrkLoop
}

// getContext returns a pooled *Context, or a freshly allocated one when the
// pool is empty.
func (svr *Server) getContext() *Context {
	if v := svr.ctxPool.Get(); v != nil {
		return v.(*Context)
	}
	return &Context{}
}

// putContext hands c back to the pool for reuse by a later request.
func (svr *Server) putContext(c *Context) {
	svr.ctxPool.Put(c)
}

// wrkLoop executes queued requests one at a time until Close is called.
func (svr *Server) wrkLoop() {
	for {
		select {
		case req, ok := <-svr.ch:
			if !ok {
				// A closed queue can never deliver again; stop instead of spinning.
				return
			}
			svr.handleRequest(req)
		case <-svr.exitChan:
			return
		}
	}
}

// handleRequest runs the handler registered for req.Method and writes exactly
// one response frame for req.Sequence back on the request's connection:
//   - unknown method: code 404, message "not found";
//   - handler error:  code 500, message err.Error();
//   - otherwise:      whatever the handler put into ctx.Response().
func (svr *Server) handleRequest(req *Request) {
	ctx := svr.getContext()
	// Recycle the context once the response is on the wire; before, contexts
	// were taken from the pool but never returned to it.
	defer svr.putContext(ctx)

	if val, ok := svr.serviceMap.Load(req.Method); ok {
		ctx.Reset(req, NewResponse())
		if err := val.(HandleFunc)(ctx); err != nil {
			log.Warnf("RPC: handle method %s sequence(%d) failed cause by: %s", req.Method, req.Sequence, err.Error())
			// Still answer, otherwise the caller waits forever for this sequence.
			resp := NewResponse()
			resp.code = 500
			resp.message = err.Error()
			ctx.Reset(req, resp)
		}
	} else {
		resp := NewResponse()
		resp.code = 404
		resp.message = "not found"
		ctx.Reset(req, resp)
	}

	if err := writeFrame(req.conn, &Frame{
		Func:     FuncResponse,
		Sequence: req.Sequence,
		Data:     ctx.Response().Bytes(),
	}); err != nil {
		log.Warnf("RPC: write sequence(%d) response failed cause by: %s", req.Sequence, err.Error())
	}
}

// process serves one accepted connection until reading a frame fails. It
// answers ping frames directly and queues request frames for wrkLoop. The
// connection is registered in sessions for its lifetime so Close can reach
// it, and is closed when process returns.
func (svr *Server) process(conn net.Conn) {
	log.Infof("RPC: connection %s connecting", conn.RemoteAddr())
	// Key by the connection itself: all accepted connections share the
	// listener's LocalAddr, so an address key made sessions overwrite and
	// delete one another.
	svr.sessions.Store(conn, conn)
	defer func() {
		log.Infof("RPC: connection %s closed", conn.RemoteAddr())
		svr.sessions.Delete(conn)
		_ = conn.Close()
	}()

	// Close may have swept sessions between Accept and the Store above; do not
	// keep serving a connection it could not see.
	if atomic.LoadInt32(&svr.exitFlag) == 1 {
		return
	}

	for {
		frame, err := readFrame(conn)
		if err != nil {
			return
		}
		switch frame.Func {
		case FuncPing:
			if err = writeFrame(conn, &Frame{Func: FuncPing}); err != nil {
				log.Warnf("RPC: write ping frame error: %s", err.Error())
			}
		case FuncRequest:
			req, rerr := ReadRequest(frame.Data)
			if rerr != nil {
				log.Warnf("RPC: read request error: %s", rerr.Error())
				continue
			}
			req.reset(frame.Sequence, conn)
			select {
			case svr.ch <- req:
			default:
				// Queue full: the request is dropped and gets no response.
				log.Warnf("RPC: request queue full, drop sequence(%d)", frame.Sequence)
			}
		}
	}
}

// Handle registers f as the handler for method, replacing any handler
// previously registered under that name. The registry is a sync.Map, so
// Handle may be called while the server is serving.
func (svr *Server) Handle(method string, f HandleFunc) {
	svr.serviceMap.Store(method, f)
}

// Serve starts the request worker and then accepts connections on l,
// serving each in its own goroutine. It blocks until Accept fails (for
// example because Close closed l) and returns that error.
func (svr *Server) Serve(l net.Listener) error {
	svr.listener = l
	go svr.wrkLoop()
	for {
		conn, err := l.Accept()
		if err != nil {
			return err
		}
		go svr.process(conn)
	}
}

// Close shuts the server down. It performs these steps:
//   - closes the listener (if Serve was called);
//   - closes every live connection;
//   - stops the worker goroutine.
//
// It returns the listener's Close error. Only the first call has any
// effect; later calls return nil.
func (svr *Server) Close() (err error) {
	if !atomic.CompareAndSwapInt32(&svr.exitFlag, 0, 1) {
		return nil
	}
	// Close before Serve used to dereference a nil listener.
	if svr.listener != nil {
		err = svr.listener.Close()
	}
	svr.sessions.Range(func(_, value interface{}) bool {
		// Best effort: process also closes its connection when it returns.
		_ = value.(net.Conn).Close()
		return true
	})
	close(svr.exitChan)
	return err
}

// NewServer returns a Server whose request queue holds up to 10 pending
// requests. Register handlers with Handle, then call Serve.
func NewServer() *Server {
	return &Server{
		ch:       make(chan *Request, 10),
		exitChan: make(chan struct{}),
	}
}