package rest import ( "context" "fmt" "git.nspix.com/golang/micro/gateway/http" "git.nspix.com/golang/micro/log" "git.nspix.com/golang/rest/v2/pkg/logger" "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" "reflect" "sort" "sync" "sync/atomic" ) const ( TypeModule = "module" TypeTable = "table" ) type ( CRUD struct { index int32 db *gorm.DB entities sync.Map httpSvr *http.Server httpMiddleware []http.Middleware callback *Callback moduleLabels map[string]string } treeValue struct { Label string `json:"label"` //标签 Value string `json:"value"` //值 Type string `json:"type"` //类型 Children []*treeValue `json:"children"` } contextKey struct{} ) var ( ctxKey = contextKey{} ) func (t *treeValue) Append(v *treeValue) { if t.Children == nil { t.Children = make([]*treeValue, 0) } t.Children = append(t.Children, v) } //DB 获取当前数据库实例 func (crud *CRUD) DB() *gorm.DB { return crud.db } //WithDB 设置当前数据库实例 func (crud *CRUD) WithDB(db *gorm.DB) { crud.db = db } //Callback 获取回调管理器 func (crud *CRUD) Callback() *Callback { return crud.callback } //handleQueryCrudModules 获取模块的信息 func (crud *CRUD) handleQueryCrudModules(ctx *http.Context) (err error) { ts := make([]*treeValue, 0) entities := make(Entities, 0) crud.entities.Range(func(key, value interface{}) bool { entities = append(entities, value.(*Entity)) return true }) sort.Sort(entities) isHandled := false for _, e := range entities { isHandled = false for _, tv := range ts { if tv.Value == e.model.ModuleName() { if viewer, ok := e.model.(ModelViewer); ok { tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: viewer.TableLabel(), Type: TypeTable}) } else { tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName(), Type: TypeTable}) } isHandled = true break } } if isHandled { continue } moduleLabel := crud.moduleLabels[e.model.ModuleName()] if moduleLabel == "" { moduleLabel = e.model.ModuleName() } tv := &treeValue{Label: moduleLabel, Value: e.model.ModuleName(), Type: TypeModule} if viewer, ok := e.model.(ModelViewer); ok { tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: viewer.TableLabel(), Type: TypeTable}) } else { tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName(), Type: TypeTable}) } ts = append(ts, tv) } return ctx.Success(ts) } //handleQuerySchema 处理http查询schema请求 func (crud *CRUD) handleQuerySchema(ctx *http.Context) (err error) { var ( schemas []*Schema modelEntity *Entity ) moduleName := ctx.ParamValue("module") tableName := ctx.ParamValue("table") crud.entities.Range(func(key, value interface{}) bool { entity := value.(*Entity) if entity.model.ModuleName() == moduleName { if entity.model.TableName() == tableName { modelEntity = entity return false } if entity.opts.RemoveTablePrefix { for _, prefix := range entity.opts.TablePrefixes { if prefix+tableName == entity.model.TableName() { modelEntity = entity return false } } } } return true }) if modelEntity == nil { return ctx.Error(10011, fmt.Sprintf("module %s table %s schema not found", moduleName, tableName)) } if schemas, err = getSchemasNoCache(crud.db, ctx.ParamValue(NamespaceVariable), modelEntity.model.ModuleName(), modelEntity.model.TableName()); err == nil { return ctx.Success(schemas) } return ctx.Error(10011, err.Error()) } //handleSaveSchema 保存schema func (crud *CRUD) handleSaveSchema(ctx *http.Context) (err error) { var ( modelEntity *Entity ) moduleName := ctx.ParamValue("module") tableName := ctx.ParamValue("table") crud.entities.Range(func(key, value interface{}) bool { entity := value.(*Entity) if entity.model.ModuleName() == moduleName { if entity.model.TableName() == tableName { modelEntity = entity return false } if entity.opts.RemoveTablePrefix { for _, prefix := range entity.opts.TablePrefixes { if prefix+tableName == entity.model.TableName() { modelEntity = entity return false } } } } return true }) if modelEntity == nil { return ctx.Error(10011, fmt.Sprintf("module %s table %s schema not found", moduleName, tableName)) } schemas := make([]*Schema, 0) if err = ctx.Bind(&schemas); err != nil { return ctx.Error(HttpInvalidPayload, err.Error()) } if err = crud.db.Transaction(func(tx *gorm.DB) error { for _, scm := range schemas { if err2 := tx.Save(scm).Error; err2 != nil { return err2 } } return nil }); err == nil { invalidCache(ctx.ParamValue(NamespaceVariable), modelEntity.model.ModuleName(), modelEntity.model.TableName()) return ctx.Success(map[string]interface{}{ "count": len(schemas), "state": "success", }) } return ctx.Error(10014, err.Error()) } //handleDeleteSchema 删除表的schema func (crud *CRUD) handleDeleteSchema(ctx *http.Context) (err error) { id := ctx.ParamValue("id") model := &Schema{} if err = crud.db.Where("id=?", id).First(model).Error; err == nil { invalidCache(ctx.ParamValue(NamespaceVariable), model.Module, model.Table) if err = crud.db.Where("id=?", id).Delete(&Schema{}).Error; err == nil { return ctx.Success(map[string]string{ "id": id, "state": "success", }) } else { return ctx.Error(10012, err.Error()) } } else { return ctx.Error(10012, err.Error()) } } //router 绑定路由 func (crud *CRUD) router() { if crud.httpSvr == nil { return } //获取注册上来的模块 crud.httpSvr.Handle("GET", "/crud/modules", crud.handleQueryCrudModules, crud.httpMiddleware...) //获取所有schema crud.httpSvr.Handle("GET", "/schema/:module/:table", crud.handleQuerySchema, crud.httpMiddleware...) //更新schema crud.httpSvr.Handle("POST", "/schema/:module/:table", crud.handleSaveSchema, crud.httpMiddleware...) //删除schema crud.httpSvr.Handle("DELETE", "/schema/:id", crud.handleDeleteSchema, crud.httpMiddleware...) } // Attach 附加一个模型数据 func (crud *CRUD) Attach(model Model, ops ...Option) (err error) { opts := NewOptions() for _, op := range ops { op(opts) } if opts.DB == nil { opts.DB = crud.db } //auto migrate database struct tx := opts.DB.Session(&gorm.Session{NewDB: true}) if opts.MigrateOptions != nil && opts.MigrateOptions.TableOptions != "" { tx.Set("gorm:table_options", opts.MigrateOptions.TableOptions) } if err = tx.AutoMigrate(model); err != nil { return } //migrate table schema if err = migrateUp(opts.Namespace, model, opts.MigrateOptions); err != nil { return } scenarios := model.Scenario() entity := newEntity(atomic.AddInt32(&crud.index, 1), model, opts) entity.callback = crud.callback if len(scenarios) == 0 { entity.scenarios = []string{ScenarioList, ScenarioCreate, ScenarioUpdate, ScenarioDelete, ScenarioExport, ScenarioView} } else { entity.scenarios = model.Scenario() } crud.entities.Store(entity.ID(), entity) return } // Detach 删除一个模型数据 func (crud *CRUD) Detach(model Model) { id := model.TableName() + "@" + model.ModuleName() crud.entities.Delete(id) } //Routes 生成增删改查的路由 func (crud *CRUD) Routes(svr *http.Server, ms ...http.Middleware) { var ( method string uri string ) if crud.httpSvr == nil { crud.httpSvr = svr } if crud.httpSvr == nil { return } crud.httpMiddleware = ms crud.entities.Range(func(key, value interface{}) bool { entity := value.(*Entity) if entity.opts.Middleware == nil || len(entity.opts.Middleware) == 0 { entity.opts.Middleware = crud.httpMiddleware } for _, scenario := range entity.scenarios { method = entity.getScenarioMethod(scenario) uri = entity.getScenarioUrl(scenario) log.Debugf("CRUD: register %s %s", method, uri) crud.httpSvr.Handle( method, uri, entity.getScenarioHandle(scenario), entity.opts.Middleware..., ) } //注册获取键值对数据格式 if entity.isKvMapping() { uri = entity.getScenarioUrl("mapping") log.Debugf("CRUD: register %s %s", "GET", uri) crud.httpSvr.Handle("GET", uri, entity.getScenarioHandle("mapping"), entity.opts.Middleware...) } return true }) crud.router() } //Schemas 获取一个模型的显示字段 func (crud *CRUD) Schemas(namespace, moduleName, tableName, scenario string) []*Schema { if v, ok := crud.entities.Load(tableName + "@" + moduleName); ok { e := v.(*Entity) return visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), scenario) } return nil } // IsNewRecord 判断该模型是不是一条新记录 func (crud *CRUD) IsNewRecord(value reflect.Value, stmt *gorm.Statement) bool { for _, pf := range stmt.Schema.PrimaryFields { if _, isZero := pf.ValueOf(value); isZero { return true } } return false } // SetModuleLabel 设置模块标签 func (crud *CRUD) SetModuleLabel(moduleName string, moduleLabel string) { if crud.moduleLabels == nil { crud.moduleLabels = make(map[string]string) } crud.moduleLabels[moduleName] = moduleLabel } //NewCRUD 创建一个新的CRUD模型 func NewCRUD(db *gorm.DB, svr *http.Server) (crud *CRUD, err error) { if err = initSchema(db); err != nil { return } crud = &CRUD{ db: db, callback: newCallback(), httpSvr: svr, moduleLabels: make(map[string]string), } return } func openDatabase(cfg *Config) (db *gorm.DB, err error) { dbCfg := &gorm.Config{Logger: logger.NewGormLogger()} switch cfg.Driver { case "sqlite3": db, err = gorm.Open(sqlite.Open(cfg.ParseDSN()), dbCfg) case "mysql": db, err = gorm.Open(mysql.Open(cfg.ParseDSN()), dbCfg) default: err = fmt.Errorf("unsupported sql driver %s", cfg.Driver) } return } //Dialer 连接一个数据库 func Dialer(cfg *Config) (crud *CRUD, err error) { var ( db *gorm.DB ) if db, err = openDatabase(cfg); err != nil { return } if err = initSchema(db); err != nil { return } crud = &CRUD{ db: db, moduleLabels: make(map[string]string), callback: newCallback(), } return } func FromContext(ctx context.Context) *CRUD { if v := ctx.Value(ctxKey); v != nil { return v.(*CRUD) } return nil } func WithContext(ctx context.Context, r *CRUD) context.Context { return context.WithValue(ctx, ctxKey, r) }