Browse Source

添加多类型查询

fancl 2 years ago
parent
commit
3c0c4bf6ca
5 changed files with 129 additions and 42 deletions
  1. 9 2
      crud.go
  2. 8 0
      options.go
  3. 9 5
      plugins/snowflake_id.go
  4. 2 1
      plugins/validate.go
  5. 101 34
      rest.go

+ 9 - 2
crud.go

@@ -65,6 +65,10 @@ func (r *CRUD) createStatement(db *gorm.DB) *gorm.Statement {
 	}
 }
 
+func (r *CRUD) DB() *gorm.DB {
+	return r.db
+}
+
 //DisableCache 禁用缓存
 func (r *CRUD) DisableCache() *CRUD {
 	r.enableCache = false
@@ -103,8 +107,11 @@ func (r *CRUD) Attach(ctx context.Context, model Model, cbs ...Option) (err erro
 	if err = r.db.AutoMigrate(model); err != nil {
 		return
 	}
-	if err = r.migrateSchema(ctx, DefaultNamespace, model); err != nil {
-		return
+	//不启用schema
+	if opts.Schema {
+		if err = r.migrateSchema(ctx, DefaultNamespace, model); err != nil {
+			return
+		}
 	}
 	opts.Delegate = r.delegate
 	opts.LookupFunc = r.VisibleSchemas

+ 8 - 0
options.go

@@ -13,6 +13,7 @@ type Options struct {
 	DB              *gorm.DB
 	Formatter       *Formatter
 	Delegate        *delegate
+	Schema          bool
 	Context         context.Context
 	LookupFunc      func(ctx context.Context, ns string, moduleName string, tableName string, scene string) []*Schema
 }
@@ -25,6 +26,12 @@ func WithContext(ctx context.Context) Option {
 	}
 }
 
+func WithoutSchema(ctx context.Context) Option {
+	return func(o *Options) {
+		o.Schema = false
+	}
+}
+
 func WithDB(db *gorm.DB) Option {
 	return func(o *Options) {
 		o.DB = db
@@ -52,6 +59,7 @@ func WithTablePrefix(prefix string) Option {
 
 func newOptions() *Options {
 	return &Options{
+		Schema:    true,
 		Namespace: DefaultNamespace,
 		Formatter: DefaultFormatter,
 	}

+ 9 - 5
plugins/snowflake_id.go

@@ -2,6 +2,7 @@ package plugins
 
 import (
 	"context"
+	"git.nspix.com/golang/rest/v3"
 	"github.com/bwmarrin/snowflake"
 	"gorm.io/gorm"
 	"gorm.io/gorm/schema"
@@ -16,7 +17,7 @@ var (
 
 func init() {
 	var err error
-	no, _ := strconv.ParseInt(os.Getenv("CC_NODE"), 10, 64)
+	no, _ := strconv.ParseInt(os.Getenv("REST_SNOWFLAKE_NODE"), 10, 64)
 	if no == 0 {
 		no = 1
 	}
@@ -26,8 +27,10 @@ func init() {
 }
 
 //SnowflakeID 自动生成主键ID
-func SnowflakeID(db *gorm.DB) {
-	var err error
+func snowflakeID(db *gorm.DB) {
+	var (
+		err error
+	)
 	if db.Statement.Schema != nil {
 		if field := db.Statement.Schema.LookUpField("ID"); field != nil {
 			if field.DataType == schema.String {
@@ -49,6 +52,7 @@ func SnowflakeID(db *gorm.DB) {
 	}
 }
 
-func RegisterSnowflakeIDCallback(db *gorm.DB) (err error) {
-	return db.Callback().Create().Before("gorm:create").Register("snowflake_id", SnowflakeID)
+func RegisterSnowflakeIDCallback(ri *rest.CRUD) (err error) {
+	db := ri.DB()
+	return db.Callback().Create().Before("gorm:create").Register("snowflake_id", snowflakeID)
 }

+ 2 - 1
plugins/validate.go

@@ -258,7 +258,8 @@ func (vv *validation) validation(db *gorm.DB) {
 	}
 }
 
-func RegisterValidationCallback(db *gorm.DB, crud *rest.CRUD) (err error) {
+func RegisterValidationCallback(crud *rest.CRUD) (err error) {
+	db := crud.DB()
 	callback := db.Callback()
 	vv := &validation{crud: crud}
 	if callback.Create().Get("validations:validate") == nil {

+ 101 - 34
rest.go

@@ -10,6 +10,7 @@ import (
 	"git.nspix.com/golang/rest/v3/inflector"
 	"gorm.io/gorm"
 	"gorm.io/gorm/clause"
+	httppkg "net/http"
 	"path"
 	"reflect"
 	"strconv"
@@ -34,6 +35,16 @@ const (
 	ErrorAccessDeniedMessage = "access denied"
 )
 
+const (
+	OperatorEqual        = "eq"
+	OperatorGreaterThan  = "gt"
+	OperatorGreaterEqual = "ge"
+	OperatorLessThan     = "lt"
+	OperatorLessEqual    = "le"
+	OperatorLike         = "like"
+	OperatorBetween      = "between"
+)
+
 type (
 	DiffAttr struct {
 		Column   string      `json:"column"`
@@ -42,6 +53,13 @@ type (
 		NewValue interface{} `json:"new_value"`
 	}
 
+	Condition struct {
+		Column string        `json:"column"`
+		Expr   string        `json:"expr"`
+		Value  interface{}   `json:"value,omitempty"`
+		Values []interface{} `json:"values,omitempty"`
+	}
+
 	Restful struct {
 		model         Model
 		opts          *Options
@@ -151,6 +169,15 @@ func (r *Restful) hasScenario(s string) bool {
 	return true
 }
 
+func (r *Restful) findCondition(schema *Schema, conditions []*Condition) *Condition {
+	for _, cond := range conditions {
+		if cond.Column == schema.Column {
+			return cond
+		}
+	}
+	return nil
+}
+
 func (r *Restful) prepareConditions(ctx context.Context, requestCtx *http.Context, query *Query, schemas []*Schema) (err error) {
 	var (
 		ok          bool
@@ -170,50 +197,90 @@ func (r *Restful) prepareConditions(ctx context.Context, requestCtx *http.Contex
 			return
 		}
 	}
-	//处理默认的搜索
-	for _, schema := range schemas {
-		skip = false
-		if skip {
-			continue
-		}
-		if schema.Native == 0 {
-			continue
+	if requestCtx.Request().Method == httppkg.MethodPut || requestCtx.Request().Method == httppkg.MethodPost {
+		conditions := make([]*Condition, 0)
+		if err = requestCtx.Bind(&conditions); err != nil {
+			return
 		}
-		formValue = requestCtx.FormValue(schema.Column)
-		switch schema.Format {
-		case FormatString, FormatText:
-			if schema.Attribute.Match == MatchExactly {
-				query.AndFilterWhere(NewCond(schema.Column, formValue))
-			} else {
-				query.AndFilterWhere(NewCond(schema.Column, formValue).WithExpr("LIKE"))
-			}
-		case FormatTime, FormatDate, FormatDatetime, FormatTimestamp:
-			var sep string
-			seps := []byte{',', '/'}
-			for _, s := range seps {
-				if strings.IndexByte(formValue, s) > -1 {
-					sep = string(s)
+		for _, schema := range schemas {
+			cond := r.findCondition(schema, conditions)
+			if cond == nil {
+				continue
+			}
+			switch schema.Format {
+			case FormatInteger, FormatFloat, FormatTimestamp, FormatDatetime, FormatDate, FormatTime:
+				switch cond.Expr {
+				case OperatorBetween:
+					if len(cond.Values) == 2 {
+						query.AndFilterWhere(NewCond(schema.Column, cond.Values[0]).WithExpr(">="))
+						query.AndFilterWhere(NewCond(schema.Column, cond.Values[1]).WithExpr("<="))
+					}
+				case OperatorGreaterThan:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value).WithExpr(">"))
+				case OperatorGreaterEqual:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value).WithExpr(">="))
+				case OperatorLessThan:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value).WithExpr("<"))
+				case OperatorLessEqual:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value).WithExpr("<="))
+				default:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value))
+				}
+			default:
+				switch cond.Expr {
+				case OperatorLike:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value).WithExpr("LIKE"))
+				default:
+					query.AndFilterWhere(NewCond(schema.Column, cond.Value))
 				}
 			}
-			if ss := strings.Split(formValue, sep); len(ss) == 2 {
-				query.AndFilterWhere(
-					NewCond(schema.Column, strings.TrimSpace(ss[0])).WithExpr(">="),
-					NewCond(schema.Column, strings.TrimSpace(ss[1])).WithExpr("<="),
-				)
-			} else {
-				query.AndFilterWhere(NewCond(schema.Column, formValue))
+		}
+	} else {
+		//处理默认的搜索
+		for _, schema := range schemas {
+			skip = false
+			if skip {
+				continue
+			}
+			if schema.Native == 0 {
+				continue
 			}
-		case FormatInteger, FormatFloat:
-			query.AndFilterWhere(NewCond(schema.Column, formValue))
-		default:
-			if schema.Type == TypeString {
+			formValue = requestCtx.FormValue(schema.Column)
+			switch schema.Format {
+			case FormatString, FormatText:
 				if schema.Attribute.Match == MatchExactly {
 					query.AndFilterWhere(NewCond(schema.Column, formValue))
 				} else {
 					query.AndFilterWhere(NewCond(schema.Column, formValue).WithExpr("LIKE"))
 				}
-			} else {
+			case FormatTime, FormatDate, FormatDatetime, FormatTimestamp:
+				var sep string
+				seps := []byte{',', '/'}
+				for _, s := range seps {
+					if strings.IndexByte(formValue, s) > -1 {
+						sep = string(s)
+					}
+				}
+				if ss := strings.Split(formValue, sep); len(ss) == 2 {
+					query.AndFilterWhere(
+						NewCond(schema.Column, strings.TrimSpace(ss[0])).WithExpr(">="),
+						NewCond(schema.Column, strings.TrimSpace(ss[1])).WithExpr("<="),
+					)
+				} else {
+					query.AndFilterWhere(NewCond(schema.Column, formValue))
+				}
+			case FormatInteger, FormatFloat:
 				query.AndFilterWhere(NewCond(schema.Column, formValue))
+			default:
+				if schema.Type == TypeString {
+					if schema.Attribute.Match == MatchExactly {
+						query.AndFilterWhere(NewCond(schema.Column, formValue))
+					} else {
+						query.AndFilterWhere(NewCond(schema.Column, formValue).WithExpr("LIKE"))
+					}
+				} else {
+					query.AndFilterWhere(NewCond(schema.Column, formValue))
+				}
 			}
 		}
 	}