package rpc

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"sync/atomic"

	"git.nspix.com/golang/micro/log"
)

// HandleFunc handles a single RPC request. It reads the request from ctx and
// fills in ctx.Response(). A non-nil error (or a panic) makes the server answer
// with 503 Service Unavailable instead of the handler's response.
//
// The *Context is recycled once the response has been written, so a handler
// must not retain it after returning.
type HandleFunc func(ctx *Context) error

// Server is a frame-based RPC server. Serve accepts connections, each
// connection is read by its own goroutine, and decoded requests are queued for
// a single worker goroutine that runs the registered HandleFuncs.
type Server struct {
	// connSeq hands out unique session keys. It is the first field so that the
	// 64-bit atomic operations on it stay 8-byte aligned on 32-bit platforms.
	connSeq uint64

	ctx        context.Context
	mu         sync.Mutex // guards listener
	listener   net.Listener
	ch         chan *Request
	ctxPool    sync.Pool
	serviceMap sync.Map // map[string]HandleFunc
	sessions   sync.Map // map[uint64]net.Conn, keyed by connSeq
	exitFlag   int32
	exitChan   chan struct{}
}

// getContext takes a *Context from the pool, allocating a new one when the pool
// is empty. The caller must Reset it before use.
func (svr *Server) getContext() *Context {
	if v := svr.ctxPool.Get(); v != nil {
		return v.(*Context)
	}
	return &Context{}
}

// putContext returns c to the pool for reuse by later requests.
func (svr *Server) putContext(c *Context) {
	svr.ctxPool.Put(c)
}

// wrkLoop serially dispatches queued requests until Close is called or the
// request channel is closed.
func (svr *Server) wrkLoop() {
	for {
		select {
		case req, ok := <-svr.ch:
			if !ok {
				// A closed channel would otherwise make this loop spin forever.
				return
			}
			svr.handleRequest(req)
		case <-svr.exitChan:
			return
		}
	}
}

// handleRequest runs the handler registered for req.Method and writes its
// response back on req's connection. Unknown methods are answered with
// 404 Not Found; handlers that return an error or panic are answered with
// 503 Service Unavailable.
func (svr *Server) handleRequest(req *Request) {
	val, ok := svr.serviceMap.Load(req.Method)
	if !ok {
		svr.rejectRequest(req, http.StatusNotFound)
		return
	}
	ctx := svr.getContext()
	// Recycle the context once the response has been written.
	defer svr.putContext(ctx)
	ctx.Reset(req, NewResponse())
	if err := callHandler(val.(HandleFunc), ctx); err != nil {
		log.Warnf("RPC: handle request(%s@%d) error: %s", req.Method, req.Sequence, err.Error())
		svr.replyStatus(ctx, req, http.StatusServiceUnavailable)
		return
	}
	svr.writeResponse(req, ctx)
}

// callHandler invokes cb, converting a panic inside the handler into an error
// so that one faulty handler cannot crash the whole server.
func callHandler(cb HandleFunc, ctx *Context) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic: %v", r)
		}
	}()
	return cb(ctx)
}

// rejectRequest answers req with the given HTTP status code without running
// any handler.
func (svr *Server) rejectRequest(req *Request, code int) {
	ctx := svr.getContext()
	defer svr.putContext(ctx)
	svr.replyStatus(ctx, req, code)
}

// replyStatus resets ctx to an empty response carrying code and its standard
// status text, then sends it to the client.
func (svr *Server) replyStatus(ctx *Context, req *Request, code int) {
	resp := NewResponse()
	resp.code = code
	resp.message = http.StatusText(code)
	ctx.Reset(req, resp)
	svr.writeResponse(req, ctx)
}

// writeResponse sends ctx's response as a FuncResponse frame tagged with the
// request's sequence number. Write failures are logged, not returned.
func (svr *Server) writeResponse(req *Request, ctx *Context) {
	if err := writeFrame(req.conn, &Frame{
		Func:     FuncResponse,
		Sequence: req.Sequence,
		Data:     ctx.Response().Bytes(),
	}); err != nil {
		log.Warnf("RPC: write request(%s@%d) response error: %s", req.Method, req.Sequence, err.Error())
	}
}

// process serves one connection: it reads frames until reading fails, answers
// pings inline and queues requests for the worker goroutine. While it runs the
// connection is registered in sessions so that Close can shut it down.
func (svr *Server) process(conn net.Conn) {
	// LocalAddr is the listener's address and is shared by every accepted
	// connection, so it cannot identify a session; use a unique counter.
	key := atomic.AddUint64(&svr.connSeq, 1)
	svr.sessions.Store(key, conn)
	defer func() {
		svr.sessions.Delete(key)
		_ = conn.Close()
	}()
	// Close may have swept sessions before this connection was registered;
	// do not start serving after shutdown.
	if atomic.LoadInt32(&svr.exitFlag) == 1 {
		return
	}
	for {
		frame, err := readFrame(conn)
		if err != nil {
			return
		}
		switch frame.Func {
		case FuncPing:
			if err = writeFrame(conn, &Frame{Func: FuncPing}); err != nil {
				return
			}
		case FuncRequest:
			req, decodeErr := ReadRequest(frame.Data)
			if decodeErr != nil {
				log.Warnf("RPC: decode request(@%d) error: %s", frame.Sequence, decodeErr.Error())
				continue
			}
			req.reset(frame.Sequence, conn)
			select {
			case svr.ch <- req:
			default:
				// The worker is saturated: shed the request, but tell the client
				// rather than leaving it waiting for a response that never comes.
				log.Warnf("RPC: request(%s@%d) dropped: queue full", req.Method, req.Sequence)
				svr.rejectRequest(req, http.StatusServiceUnavailable)
			}
		}
	}
}

// Handle registers f as the handler for method, replacing any previous
// registration for the same method.
func (svr *Server) Handle(method string, f HandleFunc) {
	svr.serviceMap.Store(method, f)
}

// Serve accepts connections on l and serves each in its own goroutine, running
// handlers on a single worker goroutine. It blocks until Accept fails (for
// example because Close closed the listener) and returns that error.
func (svr *Server) Serve(l net.Listener) error {
	svr.mu.Lock()
	svr.listener = l
	svr.mu.Unlock()
	// If Close already ran it may not have seen l; shut it so Accept fails.
	if atomic.LoadInt32(&svr.exitFlag) == 1 {
		_ = l.Close()
	}
	go svr.wrkLoop()
	for {
		conn, err := l.Accept()
		if err != nil {
			return err
		}
		go svr.process(conn)
	}
}

// Close stops the server. It closes the listener (making Serve return), closes
// every open session and stops the worker goroutine. Only the first call has
// an effect; it returns the error from closing the listener, if any.
func (svr *Server) Close() (err error) {
	if !atomic.CompareAndSwapInt32(&svr.exitFlag, 0, 1) {
		return nil
	}
	// Stop accepting first so that no new session appears during the sweep.
	svr.mu.Lock()
	l := svr.listener
	svr.mu.Unlock()
	if l != nil { // Close may be called before Serve.
		err = l.Close()
	}
	svr.sessions.Range(func(_, value interface{}) bool {
		_ = value.(net.Conn).Close()
		return true
	})
	close(svr.exitChan)
	return err
}

// New creates a Server that stores ctx and buffers up to 10 pending requests;
// requests arriving while the buffer is full are rejected with 503.
func New(ctx context.Context) *Server {
	return &Server{
		ctx:      ctx,
		ch:       make(chan *Request, 10),
		exitChan: make(chan struct{}),
	}
}