ソースを参照

修复校验规则

lxg 4 年 前
コミット
4b4e9bbbda
3 ファイル変更56 行追加30 行削除
  1. 19 18
      crud/entity.go
  2. 34 11
      orm/validator/callback.go
  3. 3 1
      reset_test.go

+ 19 - 18
crud/entity.go

@@ -5,6 +5,12 @@ import (
 	"encoding/csv"
 	"encoding/json"
 	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+	"sync/atomic"
+	"time"
+
 	"git.nspix.com/golang/micro/gateway/http"
 	"git.nspix.com/golang/rest/orm/query"
 	"git.nspix.com/golang/rest/orm/schema"
@@ -12,11 +18,6 @@ import (
 	"git.nspix.com/golang/rest/scenario"
 	"gorm.io/gorm"
 	"gorm.io/gorm/clause"
-	"reflect"
-	"strconv"
-	"strings"
-	"sync/atomic"
-	"time"
 )
 
 var seq int64
@@ -204,7 +205,7 @@ func (e *Entity) actionIndex(c *http.Context) (err error) {
 	search := query.New(e.db)
 	if e.Callbacks.BeforeQuery != nil {
 		if err = e.Callbacks.BeforeQuery(e.Value, search, c); err != nil {
-			return c.Error(8004, err.Error())
+			return c.Error(8001, err.Error())
 		}
 	}
 	searchSchemas := schema.VisibleField(e.Module, e.stmt.Table, scenario.Search)
@@ -214,7 +215,7 @@ func (e *Entity) actionIndex(c *http.Context) (err error) {
 	search.Offset(pv * pagesize).Limit(pagesize)
 	//添加排序支持
 	if err = search.All(models.Interface()); err != nil {
-		return c.Error(8004, err.Error())
+		return c.Error(8002, err.Error())
 	}
 	return c.Success(map[string]interface{}{
 		"page":       page,
@@ -227,16 +228,16 @@ func (e *Entity) actionIndex(c *http.Context) (err error) {
 func (e *Entity) actionCreate(c *http.Context) (err error) {
 	val := reflect.New(e.refType).Interface()
 	if err = c.Bind(val); err != nil {
-		return c.Error(8002, err.Error())
+		return c.Error(1005, err.Error())
 	}
 	if e.Callbacks.BeforeInsert != nil {
 		if err = e.Callbacks.BeforeInsert(val, c); err != nil {
-			return c.Error(8004, err.Error())
+			return c.Error(8003, err.Error())
 		}
 	}
 	if e.Callbacks.BeforeSave != nil {
 		if err = e.Callbacks.BeforeSave(val, c); err != nil {
-			return c.Error(8004, err.Error())
+			return c.Error(8005, err.Error())
 		}
 	}
 	sess := e.db.Create(val)
@@ -244,7 +245,7 @@ func (e *Entity) actionCreate(c *http.Context) (err error) {
 		if err, ok := sess.Error.(*validator.StructError); ok {
 			return c.Error(1001, err.Error())
 		} else {
-			return c.Error(8004, sess.Error.Error())
+			return c.Error(8006, sess.Error.Error())
 		}
 	}
 	if e.Callbacks.AfterInsert != nil {
@@ -266,16 +267,16 @@ func (e *Entity) actionUpdate(c *http.Context) (err error) {
 		return c.Error(8004, err.Error())
 	}
 	if err = c.Bind(val); err != nil {
-		return c.Error(8002, err.Error())
+		return c.Error(1005, err.Error())
 	}
 	if e.Callbacks.BeforeUpdate != nil {
 		if err = e.Callbacks.BeforeUpdate(val, c); err != nil {
-			return c.Error(8004, err.Error())
+			return c.Error(8007, err.Error())
 		}
 	}
 	if e.Callbacks.BeforeSave != nil {
 		if err = e.Callbacks.BeforeSave(val, c); err != nil {
-			return c.Error(8004, err.Error())
+			return c.Error(8005, err.Error())
 		}
 	}
 	sess := e.db.Model(val).Updates(val)
@@ -283,7 +284,7 @@ func (e *Entity) actionUpdate(c *http.Context) (err error) {
 		if err, ok := sess.Error.(*validator.StructError); ok {
 			return c.Error(1001, err.Error())
 		} else {
-			return c.Error(8004, sess.Error.Error())
+			return c.Error(8008, sess.Error.Error())
 		}
 	}
 	if e.Callbacks.AfterUpdate != nil {
@@ -306,7 +307,7 @@ func (e *Entity) actionDelete(c *http.Context) (err error) {
 	}
 	if e.Callbacks.BeforeDelete != nil {
 		if err = e.Callbacks.BeforeDelete(val, c); err != nil {
-			return c.Error(8004, err.Error())
+			return c.Error(8009, err.Error())
 		}
 	}
 	if err = e.db.Where(e.primaryKey()+"=?", c.ParamValue("id")).Model(val).Delete(val).Error; err == nil {
@@ -318,7 +319,7 @@ func (e *Entity) actionDelete(c *http.Context) (err error) {
 			"id":    idStr,
 		})
 	} else {
-		return c.Error(8009, err.Error())
+		return c.Error(8010, err.Error())
 	}
 }
 
@@ -350,7 +351,7 @@ func (e *Entity) actionExport(c *http.Context) (err error) {
 	//添加排序支持
 	e.buildSortable(c, search)
 	if err = search.All(models.Interface()); err != nil {
-		return c.Error(8004, err.Error())
+		return c.Error(8002, err.Error())
 	}
 	c.Response().Header().Set("Content-Type", "text/csv")
 	c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", e.naming.plural+"-"+time.Now().Format("20060102")))

+ 34 - 11
orm/validator/callback.go

@@ -4,21 +4,22 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"reflect"
+	"regexp"
+	"strconv"
+	"strings"
+
 	"git.nspix.com/golang/rest/internal/empty"
 	"git.nspix.com/golang/rest/orm/schema"
 	"git.nspix.com/golang/rest/scenario"
 	"github.com/go-playground/locales/en"
 	"github.com/go-playground/locales/zh"
-	"github.com/go-playground/universal-translator"
+	ut "github.com/go-playground/universal-translator"
 	"github.com/go-playground/validator/v10"
 	translation "github.com/go-playground/validator/v10/translations/zh"
 	"gorm.io/gorm"
 	"gorm.io/gorm/clause"
 	shm "gorm.io/gorm/schema"
-	"reflect"
-	"regexp"
-	"strconv"
-	"strings"
 )
 
 type validateScope struct{}
@@ -166,10 +167,11 @@ func generateTag(scene string, rule schema.Rule) string {
 
 func validation(db *gorm.DB) {
 	var (
-		err   error
-		tag   string
-		rule  schema.Rule
-		scene = scenario.Create
+		err          error
+		tag          string
+		skipValidate bool
+		rule         schema.Rule
+		scene        = scenario.Create
 	)
 	if _, ok := db.Get("gorm:update_column"); ok {
 		return
@@ -198,8 +200,29 @@ func validation(db *gorm.DB) {
 				if !fieldValue.IsValid() {
 					continue
 				}
-				//如果没有必填 并且值为空跳过验证
-				if !strings.Contains(tag, "required") && empty.Is(fieldValue.Interface()) {
+				skipValidate = false
+				if strings.Contains(tag, "required") {
+					//如果数值为整形,小数,Bool跳过验证
+					if fieldValue.Interface() != nil {
+						vType := reflect.ValueOf(fieldValue.Interface())
+						switch vType.Kind() {
+						case reflect.Bool:
+							skipValidate = true
+						case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+							skipValidate = true
+						case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+							skipValidate = true
+						case reflect.Float32, reflect.Float64:
+							skipValidate = true
+						}
+					}
+				} else {
+					//如果没有必填 并且值为空跳过验证
+					if empty.Is(fieldValue.Interface()) {
+						continue
+					}
+				}
+				if skipValidate {
 					continue
 				}
 				ctx := context.WithValue(context.Background(), validateScopeKey, &scope{

+ 3 - 1
reset_test.go

@@ -2,10 +2,11 @@ package rest
 
 import (
 	"fmt"
+	"testing"
+
 	"git.nspix.com/golang/rest/orm/schema"
 	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"
-	"testing"
 )
 
 type TestModel struct {
@@ -33,6 +34,7 @@ func TestRemoteDriver(t *testing.T) {
 	if err != nil {
 		t.Fatal(err.Error())
 	}
+
 	schema.SetDriver(&schema.RemoteDriver{Url: "https://schema.nspix.com/schema"})
 	Initialize(nil, db)
 	if err = RegisterModel("test", &TestModel{}); err != nil {