Browse Source

规范导出函数

lxg 3 năm trước cách đây
mục cha
commit
ef65fc6bef
4 tập tin đã thay đổi với 30 bổ sung22 xóa
  1. 4 2
      cmd/main.go
  2. 15 0
      crud.go
  3. 11 6
      plugins/validate.go
  4. 0 14
      schema.go

+ 4 - 2
cmd/main.go

@@ -1,6 +1,7 @@
 package main
 package main
 
 
 import (
 import (
+	"fmt"
 	"time"
 	"time"
 
 
 	"git.nspix.com/golang/rest/v2"
 	"git.nspix.com/golang/rest/v2"
@@ -31,14 +32,15 @@ func main() {
 		crud *rest.CRUD
 		crud *rest.CRUD
 		err  error
 		err  error
 	)
 	)
-	dsn := "root:root@tcp(sv2:43699)/rest?charset=utf8mb4&parseTime=True&loc=Local"
+	dsn := "root:root@tcp(192.168.9.199:3306)/rest?charset=utf8mb4&parseTime=True&loc=Local"
 	if db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
 	if db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
 		return
 		return
 	}
 	}
 	if crud, err = rest.NewCRUD(db, nil); err != nil {
 	if crud, err = rest.NewCRUD(db, nil); err != nil {
+		fmt.Println(err)
 		return
 		return
 	}
 	}
-	crud.Attach(&User{})
+	fmt.Println(crud.Attach(&User{}))
 	crud.Routes(nil)
 	crud.Routes(nil)
 
 
 	time.Sleep(time.Second * 5)
 	time.Sleep(time.Second * 5)

+ 15 - 0
crud.go

@@ -1,6 +1,7 @@
 package rest
 package rest
 
 
 import (
 import (
+	"reflect"
 	"sync"
 	"sync"
 
 
 	"git.nspix.com/golang/micro/gateway/http"
 	"git.nspix.com/golang/micro/gateway/http"
@@ -31,6 +32,10 @@ func (t *treeValue) Append(v *treeValue) {
 	t.Children = append(t.Children, v)
 	t.Children = append(t.Children, v)
 }
 }
 
 
+func (crud *CRUD) DB() *gorm.DB {
+	return crud.db
+}
+
 //Use 附加一个钩子函数
 //Use 附加一个钩子函数
 func (crud *CRUD) Use(h Hook) {
 func (crud *CRUD) Use(h Hook) {
 	crud.hooks = append(crud.hooks, h)
 	crud.hooks = append(crud.hooks, h)
@@ -202,6 +207,16 @@ func (crud *CRUD) Schemas(namespace, moduleName, tableName, scenario string) []*
 	return nil
 	return nil
 }
 }
 
 
+// IsNewRecord 判断该模型是不是一条新记录
+func (crud *CRUD) IsNewRecord(value reflect.Value, stmt *gorm.Statement) bool {
+	for _, pf := range stmt.Schema.PrimaryFields {
+		if _, isZero := pf.ValueOf(value); isZero {
+			return true
+		}
+	}
+	return false
+}
+
 //NewCRUD 创建一个新的CRUD模型
 //NewCRUD 创建一个新的CRUD模型
 func NewCRUD(db *gorm.DB, svr *http.Server) (crud *CRUD, err error) {
 func NewCRUD(db *gorm.DB, svr *http.Server) (crud *CRUD, err error) {
 	if err = initSchema(db); err != nil {
 	if err = initSchema(db); err != nil {

+ 11 - 6
plugins/validate.go

@@ -30,6 +30,10 @@ type (
 		Column string
 		Column string
 		Model  interface{}
 		Model  interface{}
 	}
 	}
+
+	validation struct {
+		curd *rest.CRUD
+	}
 )
 )
 
 
 var (
 var (
@@ -144,7 +148,7 @@ func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
 	return s
 	return s
 }
 }
 
 
-func validation(db *gorm.DB) {
+func (vv *validation) validation(db *gorm.DB) {
 	if result, ok := db.Get(SkipValidations); ok && result.(bool) {
 	if result, ok := db.Get(SkipValidations); ok && result.(bool) {
 		return
 		return
 	}
 	}
@@ -174,7 +178,7 @@ func validation(db *gorm.DB) {
 			break
 			break
 		}
 		}
 	}
 	}
-	schemas = rest.VisibleSchemas(stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
+	schemas = vv.curd.Schemas(stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
 	for _, scm := range schemas {
 	for _, scm := range schemas {
 		if scm.Rules == "" {
 		if scm.Rules == "" {
 			continue
 			continue
@@ -235,12 +239,13 @@ func validation(db *gorm.DB) {
 	}
 	}
 }
 }
 
 
-func RegisterValidationCallback(db *gorm.DB) {
-	callback := db.Callback()
+func RegisterValidationCallback(curd *rest.CRUD) {
+	callback := curd.DB().Callback()
+	vv := &validation{curd: curd}
 	if callback.Create().Get("validations:validate") == nil {
 	if callback.Create().Get("validations:validate") == nil {
-		_ = callback.Create().Before("gorm:before_create").Register("validations:validate", validation)
+		_ = callback.Create().Before("gorm:before_create").Register("validations:validate", vv.validation)
 	}
 	}
 	if callback.Update().Get("validations:validate") == nil {
 	if callback.Update().Get("validations:validate") == nil {
-		_ = callback.Update().Before("gorm:before_update").Register("validations:validate", validation)
+		_ = callback.Update().Before("gorm:before_update").Register("validations:validate", vv.validation)
 	}
 	}
 }
 }

+ 0 - 14
schema.go

@@ -209,10 +209,6 @@ func createStatement(db *gorm.DB) *gorm.Statement {
 	}
 	}
 }
 }
 
 
-func VisibleSchemas(namespace, modelName, tableName, scenario string) (schemas []*Schema) {
-	return visibleSchemas(namespace, modelName, tableName, scenario)
-}
-
 // visibleSchemas 获取某个场景下面的字段
 // visibleSchemas 获取某个场景下面的字段
 func visibleSchemas(namespace, modelName, tableName, scenario string) (schemas []*Schema) {
 func visibleSchemas(namespace, modelName, tableName, scenario string) (schemas []*Schema) {
 	schemas, _ = getSchemas(namespace, modelName, tableName)
 	schemas, _ = getSchemas(namespace, modelName, tableName)
@@ -266,16 +262,6 @@ func invalidCache(namespace, moduleName, tableName string) {
 	schemaCache.Remove(cacheKey)
 	schemaCache.Remove(cacheKey)
 }
 }
 
 
-// IsNewRecord 判断该模型是不是一条新记录
-func IsNewRecord(value reflect.Value, stmt *gorm.Statement) bool {
-	for _, pf := range stmt.Schema.PrimaryFields {
-		if _, isZero := pf.ValueOf(value); isZero {
-			return true
-		}
-	}
-	return false
-}
-
 // extraSize 提取模型定义的长度大小
 // extraSize 提取模型定义的长度大小
 func extraSize(str string) int {
 func extraSize(str string) int {
 	var (
 	var (