Browse Source

add request lib

fancl 1 year ago
parent
commit
f3b532ec67
5 changed files with 411 additions and 2 deletions
  1. 1 1
      entry/http/server.go
  2. 1 1
      pkg/cache/memcache.go
  3. 27 0
      pkg/request/auth.go
  4. 152 0
      pkg/request/client.go
  5. 230 0
      pkg/request/request.go

+ 1 - 1
entry/http/server.go

@@ -64,7 +64,7 @@ func (svr *Server) Use(middleware ...Middleware) {
 }
 
 func (svr *Server) Any(prefix string, handle http.Handler) {
-	if !strings.HasSuffix(prefix, "/") {
+	if !strings.HasPrefix(prefix, "/") {
 		prefix = "/" + prefix
 	}
 	svr.anyRequests[prefix] = handle

+ 1 - 1
pkg/cache/memcache.go

@@ -28,6 +28,6 @@ func (cache *MemCache) Del(ctx context.Context, key string) {
 
 func NewMemCache() *MemCache {
 	return &MemCache{
-		engine: cache.New(time.Hour, time.Minute*90),
+		engine: cache.New(time.Hour, time.Minute*10),
 	}
 }

+ 27 - 0
pkg/request/auth.go

@@ -0,0 +1,27 @@
+package request
+
+import (
+	"encoding/base64"
+	"fmt"
+)
+
+type Authorization interface {
+	Token() string
+}
+
+type BasicAuth struct {
+	Username string
+	Password string
+}
+
+type BearerAuth struct {
+	AccessToken string
+}
+
+func (auth *BasicAuth) Token() string {
+	return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(auth.Username+":"+auth.Password)))
+}
+
+func (auth *BearerAuth) Token() string {
+	return fmt.Sprintf("Bearer %s", auth.AccessToken)
+}

+ 152 - 0
pkg/request/client.go

@@ -0,0 +1,152 @@
+package request
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"net/http/cookiejar"
+	"strings"
+)
+
+type (
+	BeforeRequest func(req *http.Request) (err error)
+	AfterRequest  func(req *http.Request, res *http.Response) (err error)
+
+	Client struct {
+		baseUrl             string
+		Authorization       Authorization
+		client              *http.Client
+		cookieJar           *cookiejar.Jar
+		interceptorRequest  []BeforeRequest
+		interceptorResponse []AfterRequest
+	}
+)
+
+func (client *Client) stashUri(urlPath string) string {
+	var (
+		pos int
+	)
+	if len(urlPath) == 0 {
+		return client.baseUrl
+	}
+	if pos = strings.Index(urlPath, "//"); pos == -1 {
+		if client.baseUrl != "" {
+			if urlPath[0] != '/' {
+				urlPath = "/" + urlPath
+			}
+			return client.baseUrl + urlPath
+		}
+	}
+	return urlPath
+}
+
+func (client *Client) BeforeRequest(cb BeforeRequest) *Client {
+	client.interceptorRequest = append(client.interceptorRequest, cb)
+	return client
+}
+
+func (client *Client) AfterRequest(cb AfterRequest) *Client {
+	client.interceptorResponse = append(client.interceptorResponse, cb)
+	return client
+}
+
+func (client *Client) SetBaseUrl(s string) *Client {
+	client.baseUrl = strings.TrimSuffix(s, "/")
+	return client
+}
+
+func (client *Client) SetCookieJar(cookieJar *cookiejar.Jar) *Client {
+	client.client.Jar = cookieJar
+	return client
+}
+
+func (client *Client) SetClient(httpClient *http.Client) *Client {
+	client.client = httpClient
+	if client.cookieJar != nil {
+		client.client.Jar = client.cookieJar
+	}
+	return client
+}
+
+func (client *Client) SetTransport(transport http.RoundTripper) *Client {
+	client.client.Transport = transport
+	return client
+}
+
+func (client *Client) Get(urlPath string) *Request {
+	return newRequest(http.MethodGet, client.stashUri(urlPath), client)
+}
+
+func (client *Client) Put(urlPath string) *Request {
+	return newRequest(http.MethodPut, client.stashUri(urlPath), client)
+}
+
+func (client *Client) Post(urlPath string) *Request {
+	return newRequest(http.MethodPost, client.stashUri(urlPath), client)
+}
+
+func (client *Client) Delete(urlPath string) *Request {
+	return newRequest(http.MethodDelete, client.stashUri(urlPath), client)
+}
+
+func (client *Client) execute(r *Request) (res *http.Response, err error) {
+	var (
+		n      int
+		buf    []byte
+		reader io.Reader
+	)
+	if r.contentType == "" && r.body != nil {
+		r.contentType = r.detectContentType(r.body)
+	}
+	if r.body != nil {
+		if buf, err = r.readRequestBody(r.contentType, r.body); err != nil {
+			return
+		}
+		reader = bytes.NewReader(buf)
+	}
+	if r.rawRequest, err = http.NewRequest(r.method, r.uri, reader); err != nil {
+		return
+	}
+	for k, vs := range r.header {
+		for _, v := range vs {
+			r.rawRequest.Header.Add(k, v)
+		}
+	}
+	if r.contentType != "" {
+		r.rawRequest.Header.Set("Content-Type", r.contentType)
+	}
+	if client.Authorization != nil {
+		r.rawRequest.Header.Set("Authorization", client.Authorization.Token())
+	}
+	if r.context != nil {
+		r.rawRequest = r.rawRequest.WithContext(r.context)
+	}
+	n = len(client.interceptorRequest)
+	for i := n - 1; i >= 0; i-- {
+		if err = client.interceptorRequest[i](r.rawRequest); err != nil {
+			return
+		}
+	}
+	if r.rawResponse, err = client.client.Do(r.rawRequest); err != nil {
+		return nil, err
+	}
+	n = len(client.interceptorResponse)
+	for i := n - 1; i >= 0; i-- {
+		if err = client.interceptorResponse[i](r.rawRequest, r.rawResponse); err != nil {
+			_ = r.rawResponse.Body.Close()
+			return
+		}
+	}
+	return r.rawResponse, err
+}
+
+func New() *Client {
+	client := &Client{
+		client:              http.DefaultClient,
+		interceptorRequest:  make([]BeforeRequest, 0, 10),
+		interceptorResponse: make([]AfterRequest, 0, 10),
+	}
+	client.cookieJar, _ = cookiejar.New(nil)
+	client.client.Jar = client.cookieJar
+	return client
+}

+ 230 - 0
pkg/request/request.go

@@ -0,0 +1,230 @@
+package request
+
+import (
+	"context"
+	"encoding/json"
+	"encoding/xml"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"os"
+	"path"
+	"reflect"
+	"regexp"
+	"strings"
+)
+
+const (
+	JSON = "application/json"
+	XML  = "application/xml"
+
+	plainTextType   = "text/plain; charset=utf-8"
+	jsonContentType = "application/json"
+	formContentType = "application/x-www-form-urlencoded"
+)
+
+var (
+	jsonCheck = regexp.MustCompile(`(?i:(application|text)/(json|.*\+json|json\-.*)(;|$))`)
+	xmlCheck  = regexp.MustCompile(`(?i:(application|text)/(xml|.*\+xml)(;|$))`)
+)
+
+type Request struct {
+	context       context.Context
+	method        string
+	uri           string
+	url           *url.URL
+	body          any
+	query         url.Values
+	formData      url.Values
+	header        http.Header
+	contentType   string
+	authorization Authorization
+	client        *Client
+	rawRequest    *http.Request
+	rawResponse   *http.Response
+}
+
+func (r *Request) detectContentType(body interface{}) string {
+	contentType := plainTextType
+	kind := reflect.Indirect(reflect.ValueOf(body)).Type().Kind()
+	switch kind {
+	case reflect.Struct, reflect.Map:
+		contentType = jsonContentType
+	case reflect.String:
+		contentType = plainTextType
+	default:
+		if b, ok := body.([]byte); ok {
+			contentType = http.DetectContentType(b)
+		} else if kind == reflect.Slice {
+			contentType = jsonContentType
+		}
+	}
+	return contentType
+}
+
+func (r *Request) readRequestBody(contentType string, body any) (buf []byte, err error) {
+	var (
+		ok     bool
+		s      string
+		reader io.Reader
+	)
+	kind := reflect.Indirect(reflect.ValueOf(body)).Type().Kind()
+	if reader, ok = r.body.(io.Reader); ok {
+		buf, err = io.ReadAll(reader)
+		goto __end
+	}
+	if buf, ok = r.body.([]byte); ok {
+		goto __end
+	}
+	if s, ok = r.body.(string); ok {
+		buf = []byte(s)
+		goto __end
+	}
+	if jsonCheck.MatchString(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
+		buf, err = json.Marshal(r.body)
+		goto __end
+	}
+	if xmlCheck.MatchString(contentType) && (kind == reflect.Struct) {
+		buf, err = xml.Marshal(r.body)
+		goto __end
+	}
+	err = fmt.Errorf("unmarshal content type %s", contentType)
+__end:
+	return
+}
+
+func (r *Request) SetContext(ctx context.Context) *Request {
+	r.context = ctx
+	return r
+}
+
+func (r *Request) AddQuery(k, v string) *Request {
+	r.query.Add(k, v)
+	return r
+}
+
+func (r *Request) SetQuery(vs map[string]string) *Request {
+	for k, v := range vs {
+		r.query.Set(k, v)
+	}
+	return r
+}
+
+func (r *Request) AddFormData(k, v string) *Request {
+	r.contentType = formContentType
+	r.formData.Add(k, v)
+	return r
+}
+
+func (r *Request) SetFormData(vs map[string]string) *Request {
+	r.contentType = formContentType
+	for k, v := range vs {
+		r.formData.Set(k, v)
+	}
+	return r
+}
+
+func (r *Request) SetBody(v any) *Request {
+	r.body = v
+	return r
+}
+
+func (r *Request) SetContentType(v string) *Request {
+	r.contentType = v
+	return r
+}
+
+func (r *Request) AddHeader(k, v string) *Request {
+	r.header.Add(k, v)
+	return r
+}
+
+func (r *Request) SetHeader(h http.Header) *Request {
+	r.header = h
+	return r
+}
+
+func (r *Request) Do() (res *http.Response, err error) {
+	var s string
+	s = r.formData.Encode()
+	if len(s) > 0 {
+		r.body = s
+	}
+	r.url.RawQuery = r.query.Encode()
+	r.uri = r.url.String()
+	return r.client.execute(r)
+}
+
+func (r *Request) Response(v any) (err error) {
+	var (
+		res         *http.Response
+		buf         []byte
+		contentType string
+	)
+	if res, err = r.Do(); err != nil {
+		return
+	}
+	defer func() {
+		_ = res.Body.Close()
+	}()
+	if res.StatusCode/100 != 2 {
+		if buf, err = io.ReadAll(res.Body); err == nil && len(buf) > 0 {
+			err = fmt.Errorf("http response %s(%d): %s", res.Status, res.StatusCode, string(buf))
+		} else {
+			err = fmt.Errorf("http response %d: %s", res.StatusCode, res.Status)
+		}
+		return
+	}
+	contentType = strings.ToLower(res.Header.Get("Content-Type"))
+	extName := path.Ext(r.rawRequest.URL.String())
+	if strings.Contains(contentType, JSON) || extName == ".json" {
+		err = json.NewDecoder(res.Body).Decode(v)
+	} else if strings.Contains(contentType, XML) || extName == ".xml" {
+		err = xml.NewDecoder(res.Body).Decode(v)
+	} else {
+		err = fmt.Errorf("unsupported content type: %s", contentType)
+	}
+	return
+}
+
+func (r *Request) Download(s string) (err error) {
+	var (
+		fp  *os.File
+		res *http.Response
+	)
+	if res, err = r.Do(); err != nil {
+		return
+	}
+	defer func() {
+		_ = res.Body.Close()
+	}()
+	if fp, err = os.OpenFile(s, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
+		return
+	}
+	defer func() {
+		_ = fp.Close()
+	}()
+	_, err = io.Copy(fp, res.Body)
+	return
+}
+
+func newRequest(method string, uri string, client *Client) *Request {
+	var (
+		err error
+	)
+	r := &Request{
+		context:  context.Background(),
+		method:   method,
+		uri:      uri,
+		header:   make(http.Header),
+		formData: make(url.Values),
+		client:   client,
+	}
+	if r.url, err = url.Parse(uri); err == nil {
+		r.query = r.url.Query()
+	} else {
+		r.query = make(url.Values)
+	}
+	return r
+}