Browse Source

修改排序和定义label功能

fancl 3 years ago
parent
commit
092a845c6f
5 changed files with 130 additions and 43 deletions
  1. 63 17
      crud.go
  2. 34 17
      entity.go
  3. 4 0
      model.go
  4. 26 2
      options.go
  5. 3 7
      schema.go

+ 63 - 17
crud.go

@@ -10,21 +10,31 @@ import (
 	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"
 	"reflect"
+	"sort"
 	"sync"
+	"sync/atomic"
+)
+
+const (
+	TypeModule = "module"
+	TypeTable  = "table"
 )
 
 type (
 	CRUD struct {
+		index          int32
 		db             *gorm.DB
 		entities       sync.Map
 		httpSvr        *http.Server
 		httpMiddleware []http.Middleware
 		callback       *Callback
+		moduleLabels   map[string]string
 	}
 
 	treeValue struct {
-		Label    string       `json:"label"`
-		Value    string       `json:"value"`
+		Label    string       `json:"label"` //标签
+		Value    string       `json:"value"` //值
+		Type     string       `json:"type"`  //类型
 		Children []*treeValue `json:"children"`
 	}
 
@@ -57,22 +67,44 @@ func (crud *CRUD) Callback() *Callback {
 	return crud.callback
 }
 
-//handleQueryCrudModules 处理
+//handleQueryCrudModules 获取模块的信息
 func (crud *CRUD) handleQueryCrudModules(ctx *http.Context) (err error) {
 	ts := make([]*treeValue, 0)
+	entities := make(Entities, 0)
 	crud.entities.Range(func(key, value interface{}) bool {
-		e := value.(*Entity)
+		entities = append(entities, value.(*Entity))
+		return true
+	})
+	sort.Sort(entities)
+	isHandled := false
+	for _, e := range entities {
+		isHandled = false
 		for _, tv := range ts {
 			if tv.Value == e.model.ModuleName() {
-				tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName()})
-				return true
+				if viewer, ok := e.model.(ModelViewer); ok {
+					tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: viewer.TableLabel(), Type: TypeTable})
+				} else {
+					tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName(), Type: TypeTable})
+				}
+				isHandled = true
+				break
 			}
 		}
-		tv := &treeValue{Label: e.model.ModuleName(), Value: e.model.ModuleName()}
-		tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName()})
+		if isHandled {
+			continue
+		}
+		moduleLabel := crud.moduleLabels[e.model.ModuleName()]
+		if moduleLabel == "" {
+			moduleLabel = e.model.ModuleName()
+		}
+		tv := &treeValue{Label: moduleLabel, Value: e.model.ModuleName(), Type: TypeModule}
+		if viewer, ok := e.model.(ModelViewer); ok {
+			tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: viewer.TableLabel(), Type: TypeTable})
+		} else {
+			tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName(), Type: TypeTable})
+		}
 		ts = append(ts, tv)
-		return true
-	})
+	}
 	return ctx.Success(ts)
 }
 
@@ -208,11 +240,13 @@ func (crud *CRUD) Attach(model Model, ops ...Option) (err error) {
 		return
 	}
 	//migrate table schema
-	if err = migrateUp("", model); err != nil {
+	if err = migrateUp(opts.Namespace, model, opts.MigrateOptions); err != nil {
 		return
 	}
 	scenarios := model.Scenario()
+	atomic.AddInt32(&crud.index, 1)
 	entity := newEntity(model, opts)
+	entity.index = crud.index
 	entity.callback = crud.callback
 	if len(scenarios) == 0 {
 		entity.scenarios = []string{ScenarioList, ScenarioCreate, ScenarioUpdate, ScenarioDelete, ScenarioExport, ScenarioView}
@@ -260,7 +294,9 @@ func (crud *CRUD) Routes(svr *http.Server, ms ...http.Middleware) {
 		}
 		//注册获取键值对数据格式
 		if entity.isKvMapping() {
-			crud.httpSvr.Handle("GET", entity.getScenarioUrl("mapping"), entity.getScenarioHandle("mapping"), entity.opts.Middleware...)
+			uri = entity.getScenarioUrl("mapping")
+			log.Debugf("CRUD: register %s %s", "GET", uri)
+			crud.httpSvr.Handle("GET", uri, entity.getScenarioHandle("mapping"), entity.opts.Middleware...)
 		}
 		return true
 	})
@@ -286,15 +322,24 @@ func (crud *CRUD) IsNewRecord(value reflect.Value, stmt *gorm.Statement) bool {
 	return false
 }
 
+// SetModuleLabel 设置模块标签
+func (crud *CRUD) SetModuleLabel(moduleName string, moduleLabel string) {
+	if crud.moduleLabels == nil {
+		crud.moduleLabels = make(map[string]string)
+	}
+	crud.moduleLabels[moduleName] = moduleLabel
+}
+
 //NewCRUD 创建一个新的CRUD模型
 func NewCRUD(db *gorm.DB, svr *http.Server) (crud *CRUD, err error) {
 	if err = initSchema(db); err != nil {
 		return
 	}
 	crud = &CRUD{
-		db:       db,
-		callback: newCallback(),
-		httpSvr:  svr,
+		db:           db,
+		callback:     newCallback(),
+		httpSvr:      svr,
+		moduleLabels: make(map[string]string),
 	}
 	return
 }
@@ -324,8 +369,9 @@ func Dialer(cfg *Config) (crud *CRUD, err error) {
 		return
 	}
 	crud = &CRUD{
-		db:       db,
-		callback: newCallback(),
+		db:           db,
+		moduleLabels: make(map[string]string),
+		callback:     newCallback(),
 	}
 	return
 }

+ 34 - 17
entity.go

@@ -40,22 +40,39 @@ type DiffAttr struct {
 	NewValue interface{} `json:"new_value"`
 }
 
-type Entity struct {
-	opts                 *Options
-	model                Model
-	primaryKey           string
-	reflectValue         reflect.Value
-	reflectType          reflect.Type
-	statement            *gorm.Statement
-	isImplementKvMapping bool
-	mappingLabelField    string
-	mappingValueField    string
-	singularName         string
-	pluralizeName        string
-	scenarios            []string
-	callback             *Callback
-	lruCache             *lru.Cache
-	createdAt            time.Time
+type (
+	Entity struct {
+		index                int32
+		opts                 *Options
+		model                Model
+		primaryKey           string
+		reflectValue         reflect.Value
+		reflectType          reflect.Type
+		statement            *gorm.Statement
+		isImplementKvMapping bool
+		mappingLabelField    string
+		mappingValueField    string
+		singularName         string
+		pluralizeName        string
+		scenarios            []string
+		callback             *Callback
+		lruCache             *lru.Cache
+		createdAt            time.Time
+	}
+
+	Entities []*Entity
+)
+
+func (e Entities) Len() int {
+	return len(e)
+}
+
+func (e Entities) Less(i, j int) bool {
+	return e[i].index < e[j].index
+}
+
+func (e Entities) Swap(i, j int) {
+	e[i], e[j] = e[j], e[j]
 }
 
 func (e *Entity) ID() string {
@@ -199,7 +216,7 @@ func (e *Entity) getScenarioUrl(scenario string) string {
 	case ScenarioExport:
 		uri = e.opts.Prefix + "/" + e.model.ModuleName() + "/" + e.singularName + "-export"
 	case ScenarioMapping:
-		uri = e.opts.Prefix + "/" + e.model.ModuleName() + "/" + e.singularName + "-mapping"
+		uri = e.opts.Prefix + "/" + e.model.ModuleName() + "/" + e.singularName + "-pairs"
 	}
 	return path.Clean(uri)
 }

+ 4 - 0
model.go

@@ -12,6 +12,10 @@ type (
 		Scenario() []string
 	}
 
+	ModelViewer interface {
+		TableLabel() string
+	}
+
 	Migrating interface {
 		Defaults() []MigrateValue
 	}

+ 26 - 2
options.go

@@ -8,25 +8,47 @@ import (
 
 type Options struct {
 	EnableNamespace   bool
+	Namespace         string
 	DB                *gorm.DB
 	Prefix            string
 	TablePrefixes     []string
 	RemoveTablePrefix bool
 	Formatter         *Formatter
 	Middleware        []http.Middleware
+	MigrateOptions    *MigrateOptions
 }
 
 type Option func(o *Options)
 
+func WithMigration(ops ...MigrateOption) Option {
+	return func(o *Options) {
+		if o.MigrateOptions == nil {
+			o.MigrateOptions = &MigrateOptions{}
+		}
+		for _, f := range ops {
+			f(o.MigrateOptions)
+		}
+	}
+}
+
 func WithPrefix(prefix string) Option {
 	return func(o *Options) {
 		if !strings.HasPrefix(prefix, "/") {
 			prefix = "/" + prefix
 		}
+		if strings.HasSuffix(prefix, "/") {
+			prefix = strings.TrimSuffix(prefix, "/")
+		}
 		o.Prefix = prefix
 	}
 }
 
+func WithNamespace(namespace string) Option {
+	return func(o *Options) {
+		o.Namespace = namespace
+	}
+}
+
 func WithDB(db *gorm.DB) Option {
 	return func(o *Options) {
 		o.DB = db
@@ -35,7 +57,9 @@ func WithDB(db *gorm.DB) Option {
 
 func NewOptions() *Options {
 	return &Options{
-		Formatter:  DefaultFormatter,
-		Middleware: make([]http.Middleware, 0),
+		Namespace:      DefaultNamespace,
+		Formatter:      DefaultFormatter,
+		Middleware:     make([]http.Middleware, 0),
+		MigrateOptions: &MigrateOptions{},
 	}
 }

+ 3 - 7
schema.go

@@ -404,7 +404,7 @@ func generateFieldScenario(field *schema.Field) string {
 }
 
 // migrate 合并数据表结构
-func migrateUp(namespace string, value interface{}, opts ...MigrateOption) (err error) {
+func migrateUp(namespace string, value interface{}, opts *MigrateOptions) (err error) {
 	var (
 		pos            int
 		ok             bool
@@ -420,10 +420,6 @@ func migrateUp(namespace string, value interface{}, opts ...MigrateOption) (err
 		columnName     string
 		columnLabel    string
 	)
-	opt := &MigrateOptions{}
-	for _, o := range opts {
-		o(opt)
-	}
 	if schemaDB == nil {
 		return errors.New("call initSchema first")
 	}
@@ -499,8 +495,8 @@ func migrateUp(namespace string, value interface{}, opts ...MigrateOption) (err
 			Properties: generateFieldProperties(field),
 			Position:   pv,
 		}
-		if opt.Callback != nil {
-			if err = opt.Callback(schemaModel); err == nil {
+		if opts.Callback != nil {
+			if err = opts.Callback(schemaModel); err == nil {
 				values = append(values, schemaModel)
 				pos++
 			}