Browse Source

修复插件里面获取schema在同一个会话导致数据获取失败

sugar 3 years ago
parent
commit
66f672e7c6
6 changed files with 145 additions and 94 deletions
  1. 5 1
      cmd/main.go
  2. 5 4
      crud.go
  3. 66 47
      entity.go
  4. 5 4
      options.go
  5. 6 5
      plugins/validate.go
  6. 58 33
      schema.go

+ 5 - 1
cmd/main.go

@@ -1,6 +1,8 @@
 package main
 
 import (
+	"time"
+
 	"git.nspix.com/golang/rest/v2"
 	"gorm.io/driver/mysql"
 	"gorm.io/gorm"
@@ -29,7 +31,7 @@ func main() {
 		crud *rest.CRUD
 		err  error
 	)
-	dsn := "root:root@tcp(192.168.9.199:3306)/rest?charset=utf8mb4&parseTime=True&loc=Local"
+	dsn := "root:root@tcp(sv2:43699)/rest?charset=utf8mb4&parseTime=True&loc=Local"
 	if db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
 		return
 	}
@@ -38,4 +40,6 @@ func main() {
 	}
 	crud.Attach(&User{})
 	crud.Routes(nil)
+
+	time.Sleep(time.Second * 5)
 }

+ 5 - 4
crud.go

@@ -1,10 +1,11 @@
 package rest
 
 import (
+	"sync"
+
 	"git.nspix.com/golang/micro/gateway/http"
 	"git.nspix.com/golang/micro/log"
 	"gorm.io/gorm"
-	"sync"
 )
 
 type (
@@ -132,7 +133,7 @@ func (crud *CRUD) Attach(model Model, ops ...Option) (err error) {
 		opts.DB = crud.db
 	}
 	//migrate table schema
-	if err = migrate(opts.DB, model); err != nil {
+	if err = migrateUp("", model); err != nil {
 		return
 	}
 	scenarios := model.Scenario()
@@ -192,14 +193,14 @@ func (crud *CRUD) Routes(ms ...http.Middleware) {
 func (crud *CRUD) Schemas(namespace, moduleName, tableName, scenario string) []*Schema {
 	if v, ok := crud.entities.Load(tableName + "@" + moduleName); ok {
 		e := v.(*Entity)
-		return visibleSchemas(crud.db, namespace, e.model.ModuleName(), e.model.TableName(), scenario)
+		return visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), scenario)
 	}
 	return nil
 }
 
 //NewCRUD 创建一个新的CRUD模型
 func NewCRUD(db *gorm.DB, svr *http.Server) (crud *CRUD, err error) {
-	if err = db.AutoMigrate(&Schema{}); err != nil {
+	if err = initSchema(db); err != nil {
 		return
 	}
 	crud = &CRUD{

+ 66 - 47
entity.go

@@ -5,16 +5,18 @@ import (
 	"encoding/csv"
 	"encoding/json"
 	"fmt"
+	"path"
+	"reflect"
+	"strconv"
+	"strings"
+	"time"
+
 	"git.nspix.com/golang/micro/gateway/http"
 	"git.nspix.com/golang/rest/v2/errors"
 	"git.nspix.com/golang/rest/v2/internal/inflector"
 	lru "github.com/hashicorp/golang-lru"
 	"gorm.io/gorm"
 	"gorm.io/gorm/clause"
-	"path"
-	"reflect"
-	"strconv"
-	"strings"
 )
 
 const (
@@ -52,6 +54,7 @@ type Entity struct {
 	scenarios            []string
 	hooks                []Hook
 	lruCache             *lru.Cache
+	createdAt            time.Time
 }
 
 func (e *Entity) ID() string {
@@ -194,7 +197,7 @@ func (e *Entity) getScenarioUrl(scenario string) string {
 		uri = e.opts.Prefix + "/" + e.model.ModuleName() + "/" + e.singularName + "/:id"
 	case ScenarioExport:
 		uri = e.opts.Prefix + "/" + e.model.ModuleName() + "/" + e.singularName + "-export"
-	case "mapping":
+	case ScenarioMapping:
 		uri = e.opts.Prefix + "/" + e.model.ModuleName() + "/" + e.singularName + "-mapping"
 	}
 	return path.Clean(uri)
@@ -216,7 +219,7 @@ func (e *Entity) getScenarioHandle(scenario string) http.HandleFunc {
 		handleFunc = e.actionDelete
 	case ScenarioExport:
 		handleFunc = e.actionExport
-	case "mapping":
+	case ScenarioMapping:
 		handleFunc = e.actionMapping
 	}
 	return handleFunc
@@ -232,11 +235,6 @@ func (e *Entity) prepareConditions(ctx *http.Context, query *Query, schemas []*S
 	)
 	model = reflect.New(e.reflectType).Interface()
 	activeModel, _ = model.(FilterColumnInterface)
-
-	namespace := ctx.ParamValue("@namespace")
-	if namespace != "" {
-		query.AndWhere(NewQueryCondition("namespace", namespace))
-	}
 	//处理默认的搜索
 	for _, schema := range schemas {
 		if activeModel != nil {
@@ -285,7 +283,6 @@ func (e *Entity) prepareConditions(ctx *http.Context, query *Query, schemas []*S
 			}
 		}
 	}
-
 	//处理排序
 	sortPar := ctx.FormValue("sort")
 	if sortPar != "" {
@@ -360,9 +357,12 @@ func (e *Entity) actionIndex(ctx *http.Context) (err error) {
 	models := reflect.New(sliceValue.Type())
 	models.Elem().Set(sliceValue)
 	query = NewQuery(e.opts.DB)
-	searchSchemas := visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioSearch)
-	indexSchemas := visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioList)
+	searchSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioSearch)
+	indexSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioList)
 	e.prepareConditions(ctx, query, searchSchemas)
+	if e.opts.EnableNamespace {
+		query.AndFilterWhere(NewQueryCondition("namespace", namespace))
+	}
 	query.Offset(pageIndex * pageSize).Limit(pageSize)
 	if err = query.All(models.Interface()); err != nil {
 		return ctx.Error(HttpDatabaseQueryFailed, err.Error())
@@ -392,16 +392,22 @@ func (e *Entity) actionView(ctx *http.Context) (err error) {
 	scenario := ctx.FormValue("scenario")
 	idStr := ctx.ParamValue("id")
 	model = reflect.New(e.reflectType).Interface()
-	if err = e.opts.DB.Where(e.primaryKey+"=? AND namespace=?", idStr, namespace).First(model).Error; err != nil {
+	conditions := map[string]interface{}{
+		e.primaryKey: idStr,
+	}
+	if e.opts.EnableNamespace {
+		conditions["namespace"] = namespace
+	}
+	if err = e.opts.DB.Where(conditions).First(model).Error; err != nil {
 		return ctx.Error(HttpDatabaseFindFailed, err.Error())
 	}
 	if ctx.FormValue("format") != "" {
 		//获取指定场景下面的字段进行渲染显示
 		var schemas []*Schema
 		if scenario == "" {
-			schemas = visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioView)
+			schemas = visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioView)
 		} else {
-			schemas = visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), scenario)
+			schemas = visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), scenario)
 		}
 		requestCtx := ctx.Request().Context()
 		if requestCtx == nil {
@@ -426,9 +432,12 @@ func (e *Entity) actionExport(ctx *http.Context) (err error) {
 	models := reflect.New(sliceValue.Type())
 	models.Elem().Set(sliceValue)
 	query = NewQuery(e.opts.DB)
-	searchSchemas := visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioSearch)
-	exportSchemas := visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioList)
+	searchSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioSearch)
+	exportSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioList)
 	e.prepareConditions(ctx, query, searchSchemas)
+	if e.opts.EnableNamespace {
+		query.AndFilterWhere(NewQueryCondition("namespace", namespace))
+	}
 	if err = query.All(models.Interface()); err != nil {
 		return ctx.Error(HttpDatabaseExportFailed, err.Error())
 	}
@@ -485,10 +494,10 @@ func (e *Entity) actionCreate(ctx *http.Context) (err error) {
 	if err = ctx.Bind(model); err != nil {
 		return ctx.Error(HttpInvalidPayload, err.Error())
 	}
-	schemas = visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioCreate)
+	schemas = visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioCreate)
 	//设置某个字段的值
 	e.setFieldValue(refModel, "namespace", namespace)
-	//设置默认的创建用户和更新用户信息
+	//global set field value
 	e.setFieldValue(refModel, "CreatedBy", ctx.ParamValue("@uid"))
 	e.setFieldValue(refModel, "CreatedDept", ctx.ParamValue("@department"))
 	e.setFieldValue(refModel, "UpdatedBy", ctx.ParamValue("@uid"))
@@ -538,18 +547,16 @@ func (e *Entity) actionCreate(ctx *http.Context) (err error) {
 			"table": e.model.TableName(),
 			"state": "created",
 		})
-	} else {
-		//数据校验不合法
-		if validateError, ok := err.(*errors.StructError); ok {
-			ctx.Response().Header().Set("Content-Type", "application/json")
-			return json.NewEncoder(ctx.Response()).Encode(map[string]interface{}{
-				"errno":  HttpValidateFailed,
-				"result": validateError,
-			})
-		} else {
-			return ctx.Error(HttpDatabaseCreateFailed, err.Error())
-		}
 	}
+	//form validation
+	if validateError, ok := err.(*errors.StructError); ok {
+		ctx.Response().Header().Set("Content-Type", "application/json")
+		return json.NewEncoder(ctx.Response()).Encode(map[string]interface{}{
+			"errno":  HttpValidateFailed,
+			"result": validateError,
+		})
+	}
+	return ctx.Error(HttpDatabaseCreateFailed, err.Error())
 }
 
 func (e *Entity) actionUpdate(ctx *http.Context) (err error) {
@@ -573,10 +580,16 @@ func (e *Entity) actionUpdate(ctx *http.Context) (err error) {
 	//默认设置更新用户
 	e.setFieldValue(refModel, "UpdatedBy", ctx.ParamValue("@uid"))
 	e.setFieldValue(refModel, "UpdatedDept", ctx.ParamValue("@department"))
-	if err = e.opts.DB.Where(e.primaryKey+"=? AND namespace=?", idStr, namespace).First(model).Error; err != nil {
+	conditions := map[string]interface{}{
+		e.primaryKey: idStr,
+	}
+	if e.opts.EnableNamespace {
+		conditions["namespace"] = namespace
+	}
+	if err = e.opts.DB.Where(conditions).First(model).Error; err != nil {
 		return ctx.Error(HttpDatabaseFindFailed, err.Error())
 	}
-	schemas = visibleSchemas(e.opts.DB, namespace, e.model.ModuleName(), e.model.TableName(), ScenarioUpdate)
+	schemas = visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioUpdate)
 	for _, scm := range schemas {
 		oldValues[scm.Column] = e.getFieldValue(refModel, scm.Column)
 	}
@@ -621,7 +634,7 @@ func (e *Entity) actionUpdate(ctx *http.Context) (err error) {
 		}
 		return errTx
 	}); err == nil {
-		e.invalidCache(ctx.ParamValue("@namespace"))
+		e.invalidCache(namespace)
 		pkVal := e.getPrimaryKeyValue(model)
 		if len(e.hooks) > 0 {
 			for _, hook := range e.hooks {
@@ -633,17 +646,16 @@ func (e *Entity) actionUpdate(ctx *http.Context) (err error) {
 			"table": e.model.TableName(),
 			"state": "updated",
 		})
-	} else {
-		if validateError, ok := err.(*errors.StructError); ok {
-			ctx.Response().Header().Set("Content-Type", "application/json")
-			return json.NewEncoder(ctx.Response()).Encode(map[string]interface{}{
-				"errno":  HttpValidateFailed,
-				"result": validateError,
-			})
-		} else {
-			return ctx.Error(HttpDatabaseUpdateFailed, err.Error())
-		}
 	}
+	//form validation
+	if validateError, ok := err.(*errors.StructError); ok {
+		ctx.Response().Header().Set("Content-Type", "application/json")
+		return json.NewEncoder(ctx.Response()).Encode(map[string]interface{}{
+			"errno":  HttpValidateFailed,
+			"result": validateError,
+		})
+	}
+	return ctx.Error(HttpDatabaseUpdateFailed, err.Error())
 }
 
 func (e *Entity) actionDelete(ctx *http.Context) (err error) {
@@ -658,7 +670,13 @@ func (e *Entity) actionDelete(ctx *http.Context) (err error) {
 	idStr := ctx.ParamValue("id")
 	namespace = ctx.ParamValue("@namespace")
 	model = reflect.New(e.reflectType).Interface()
-	if err = e.opts.DB.Where(e.primaryKey+"=? AND namespace=?", idStr, namespace).First(model).Error; err != nil {
+	conditions := map[string]interface{}{
+		e.primaryKey: idStr,
+	}
+	if e.opts.EnableNamespace {
+		conditions["namespace"] = namespace
+	}
+	if err = e.opts.DB.Where(conditions).First(model).Error; err != nil {
 		return ctx.Error(HttpDatabaseFindFailed, err.Error())
 	}
 	if err = e.opts.DB.Transaction(func(tx *gorm.DB) error {
@@ -676,7 +694,7 @@ func (e *Entity) actionDelete(ctx *http.Context) (err error) {
 		}
 		return errTx
 	}); err == nil {
-		e.invalidCache(ctx.ParamValue("@namespace"))
+		e.invalidCache(namespace)
 		return ctx.Success(map[string]interface{}{
 			"id":    e.getPrimaryKeyValue(model),
 			"table": e.model.TableName(),
@@ -696,6 +714,7 @@ func newEntity(model Model, opts *Options) *Entity {
 	entity := &Entity{
 		model:        model,
 		opts:         opts,
+		createdAt:    time.Now(),
 		reflectValue: reflect.Indirect(reflect.ValueOf(model)),
 	}
 	entity.lruCache, _ = lru.New(50)

+ 5 - 4
options.go

@@ -6,10 +6,11 @@ import (
 )
 
 type Options struct {
-	DB         *gorm.DB
-	Prefix     string
-	Formatter  *Formatter
-	Middleware []http.Middleware
+	EnableNamespace bool
+	DB              *gorm.DB
+	Prefix          string
+	Formatter       *Formatter
+	Middleware      []http.Middleware
 }
 
 type Option func(o *Options)

+ 6 - 5
plugins/validate.go

@@ -4,6 +4,11 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"reflect"
+	"regexp"
+	"strconv"
+	"strings"
+
 	"git.nspix.com/golang/micro/helper/utils"
 	"git.nspix.com/golang/rest/v2"
 	errors2 "git.nspix.com/golang/rest/v2/errors"
@@ -11,10 +16,6 @@ import (
 	"gorm.io/gorm"
 	"gorm.io/gorm/clause"
 	"gorm.io/gorm/schema"
-	"reflect"
-	"regexp"
-	"strconv"
-	"strings"
 )
 
 const (
@@ -173,7 +174,7 @@ func validation(db *gorm.DB) {
 			break
 		}
 	}
-	schemas = rest.VisibleSchemas(db, stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
+	schemas = rest.VisibleSchemas(stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
 	for _, scm := range schemas {
 		if scm.Rules == "" {
 			continue

+ 58 - 33
schema.go

@@ -3,15 +3,16 @@ package rest
 import (
 	"encoding/json"
 	"errors"
-	"git.nspix.com/golang/micro/helper/utils"
-	"github.com/hashicorp/golang-lru"
-	"gorm.io/gorm"
-	"gorm.io/gorm/clause"
-	"gorm.io/gorm/schema"
 	"reflect"
 	"strconv"
 	"strings"
 	"time"
+
+	"git.nspix.com/golang/micro/helper/utils"
+	lru "github.com/hashicorp/golang-lru"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+	"gorm.io/gorm/schema"
 )
 
 var (
@@ -21,16 +22,21 @@ var (
 	timePtrKind = reflect.TypeOf(&time.Time{}).Kind()
 
 	schemaCache, _ = lru.New(512)
+
+	DefaultNamespace = "default"
+
+	schemaDB *gorm.DB
 )
 
 const (
-	ScenarioCreate = "create"
-	ScenarioUpdate = "update"
-	ScenarioDelete = "delete"
-	ScenarioSearch = "search"
-	ScenarioExport = "export"
-	ScenarioList   = "list"
-	ScenarioView   = "view"
+	ScenarioCreate  = "create"
+	ScenarioUpdate  = "update"
+	ScenarioDelete  = "delete"
+	ScenarioSearch  = "search"
+	ScenarioExport  = "export"
+	ScenarioList    = "list"
+	ScenarioView    = "view"
+	ScenarioMapping = "mapping"
 
 	Basic    = "basic"
 	Advanced = "advanced"
@@ -121,6 +127,15 @@ func (r *Rule) String() string {
 	return string(buf)
 }
 
+// getProperties 获取属性
+func (schema *Schema) getProperties() *Properties {
+	if schema.properties == nil {
+		schema.properties = &Properties{}
+		_ = json.Unmarshal([]byte(schema.Properties), schema.properties)
+	}
+	return schema.properties
+}
+
 func dataTypeOf(field *schema.Field) string {
 	var dataType string
 	reflectType := field.FieldType
@@ -194,13 +209,13 @@ func createStatement(db *gorm.DB) *gorm.Statement {
 	}
 }
 
-func VisibleSchemas(db *gorm.DB, namespace, modelName, tableName, scenario string) (schemas []*Schema) {
-	return visibleSchemas(db, namespace, modelName, tableName, scenario)
+func VisibleSchemas(namespace, modelName, tableName, scenario string) (schemas []*Schema) {
+	return visibleSchemas(namespace, modelName, tableName, scenario)
 }
 
 // visibleSchemas 获取某个场景下面的字段
-func visibleSchemas(db *gorm.DB, namespace, modelName, tableName, scenario string) (schemas []*Schema) {
-	schemas, _ = getSchemas(db, namespace, modelName, tableName)
+func visibleSchemas(namespace, modelName, tableName, scenario string) (schemas []*Schema) {
+	schemas, _ = getSchemas(namespace, modelName, tableName)
 	values := make([]*Schema, 0)
 	for _, scm := range schemas {
 		if scm.Enable != 1 {
@@ -218,13 +233,16 @@ func visibleSchemas(db *gorm.DB, namespace, modelName, tableName, scenario strin
 }
 
 // getSchemas 获取某个模型下面所有的字段配置
-func getSchemas(db *gorm.DB, namespace, moduleName, tableName string) (schemas []*Schema, err error) {
+func getSchemas(namespace, moduleName, tableName string) (schemas []*Schema, err error) {
 	schemas = make([]*Schema, 0)
 	cacheKey := namespace + ":" + tableName + "@" + moduleName
 	if v, ok := schemaCache.Get(cacheKey); ok {
 		return v.([]*Schema), nil
 	}
-	if err = db.Where("`namespace`=? AND `module_name`=? AND `table_name`=?", namespace, moduleName, tableName).Order("position,id ASC").Find(&schemas).Error; err == nil {
+	if len(namespace) == 0 {
+		namespace = DefaultNamespace
+	}
+	if err = schemaDB.Where("`namespace`=? AND `module_name`=? AND `table_name`=?", namespace, moduleName, tableName).Order("position,id ASC").Find(&schemas).Error; err == nil {
 		//修改表结构缓存
 		if len(schemas) > 0 {
 			schemaCache.Add(cacheKey, schemas)
@@ -234,7 +252,11 @@ func getSchemas(db *gorm.DB, namespace, moduleName, tableName string) (schemas [
 }
 
 func getSchemasNoCache(db *gorm.DB, namespace, moduleName, tableName string) (schemas []*Schema, err error) {
-	err = db.Where("`namespace`=? AND `module_name`=? AND `table_name`=?", namespace, moduleName, tableName).Order("position,id ASC").Find(&schemas).Error
+	tx := db.Session(&gorm.Session{NewDB: true, SkipHooks: true})
+	if len(namespace) == 0 {
+		namespace = DefaultNamespace
+	}
+	err = tx.Where("`namespace`=? AND `module_name`=? AND `table_name`=?", namespace, moduleName, tableName).Order("position,id ASC").Find(&schemas).Error
 	return
 }
 
@@ -254,15 +276,6 @@ func IsNewRecord(value reflect.Value, stmt *gorm.Statement) bool {
 	return false
 }
 
-// getProperties 获取属性
-func (schema *Schema) getProperties() *Properties {
-	if schema.properties == nil {
-		schema.properties = &Properties{}
-		_ = json.Unmarshal([]byte(schema.Properties), schema.properties)
-	}
-	return schema.properties
-}
-
 // extraSize 提取模型定义的长度大小
 func extraSize(str string) int {
 	var (
@@ -422,7 +435,7 @@ func generateFieldScenario(field *schema.Field) string {
 }
 
 // migrate 合并数据表结构
-func migrate(db *gorm.DB, value interface{}) (err error) {
+func migrateUp(namespace string, value interface{}) (err error) {
 	var (
 		pos            int
 		ok             bool
@@ -438,7 +451,13 @@ func migrate(db *gorm.DB, value interface{}) (err error) {
 		columnName     string
 		columnLabel    string
 	)
-	stmt = createStatement(db)
+	if schemaDB == nil {
+		return errors.New("call initSchema first")
+	}
+	if len(namespace) == 0 {
+		namespace = DefaultNamespace
+	}
+	stmt = createStatement(schemaDB)
 	if value == nil {
 		return ErrInvalidModelInstance
 	}
@@ -450,7 +469,7 @@ func migrate(db *gorm.DB, value interface{}) (err error) {
 	if err = stmt.Parse(value); err != nil {
 		return err
 	}
-	if schemas, err = getSchemas(db, "default", moduleName, tableName); err != nil {
+	if schemas, err = getSchemas(namespace, moduleName, tableName); err != nil {
 		schemas = make([]*Schema, 0)
 	}
 	totalCount := len(stmt.Schema.Fields)
@@ -491,7 +510,7 @@ func migrate(db *gorm.DB, value interface{}) (err error) {
 			tag = Basic
 		}
 		values = append(values, &Schema{
-			Namespace:  "default",
+			Namespace:  namespace,
 			Module:     moduleName,
 			Table:      tableName,
 			Enable:     1,
@@ -511,7 +530,13 @@ func migrate(db *gorm.DB, value interface{}) (err error) {
 	}
 	if len(values) > 0 {
 		//batch save
-		err = db.Create(values).Error
+		err = schemaDB.Create(values).Error
 	}
 	return
 }
+
+func initSchema(db *gorm.DB) (err error) {
+	schemaDB = db.Session(&gorm.Session{NewDB: true})
+	err = schemaDB.AutoMigrate(&Schema{})
+	return
+}