package rest import ( "context" "fmt" "git.nspix.com/golang/micro/gateway/http" "git.nspix.com/golang/micro/log" "git.nspix.com/golang/rest/v2/pkg/logger" "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" "reflect" "sync" ) type ( CRUD struct { db *gorm.DB entities sync.Map httpSvr *http.Server httpMiddleware []http.Middleware callback *Callback } treeValue struct { Label string `json:"label"` Value string `json:"value"` Children []*treeValue `json:"children"` } contextKey struct{} ) var ( ctxKey = contextKey{} ) func (t *treeValue) Append(v *treeValue) { if t.Children == nil { t.Children = make([]*treeValue, 0) } t.Children = append(t.Children, v) } //DB 获取当前数据库实例 func (crud *CRUD) DB() *gorm.DB { return crud.db } //WithDB 设置当前数据库实例 func (crud *CRUD) WithDB(db *gorm.DB) { crud.db = db } //Callback 获取回调管理器 func (crud *CRUD) Callback() *Callback { return crud.callback } //handleQueryCrudModules 处理 func (crud *CRUD) handleQueryCrudModules(ctx *http.Context) (err error) { ts := make([]*treeValue, 0) crud.entities.Range(func(key, value interface{}) bool { e := value.(*Entity) for _, tv := range ts { if tv.Value == e.model.ModuleName() { tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName()}) return true } } tv := &treeValue{Label: e.model.ModuleName(), Value: e.model.ModuleName()} tv.Append(&treeValue{Value: e.model.ModuleName() + "-" + e.model.TableName(), Label: e.model.TableName()}) ts = append(ts, tv) return true }) return ctx.Success(ts) } //handleQuerySchema 处理http查询schema请求 func (crud *CRUD) handleQuerySchema(ctx *http.Context) (err error) { var ( schemas []*Schema ) if schemas, err = getSchemasNoCache(crud.db, ctx.ParamValue(NamespaceVariable), ctx.ParamValue("module"), ctx.ParamValue("table")); err == nil { return ctx.Success(schemas) } return ctx.Error(10011, err.Error()) } //handleSaveSchema 保存schema func (crud *CRUD) handleSaveSchema(ctx *http.Context) (err error) { schemas := make([]*Schema, 0) if err = ctx.Bind(&schemas); err != nil { return ctx.Error(HttpInvalidPayload, err.Error()) } if err = crud.db.Transaction(func(tx *gorm.DB) error { for _, scm := range schemas { if err2 := tx.Save(scm).Error; err2 != nil { return err2 } } return nil }); err == nil { invalidCache(ctx.ParamValue(NamespaceVariable), ctx.ParamValue("module"), ctx.ParamValue("table")) return ctx.Success(map[string]interface{}{ "count": len(schemas), "state": "success", }) } return ctx.Error(10014, err.Error()) } //handleDeleteSchema 删除表的schema func (crud *CRUD) handleDeleteSchema(ctx *http.Context) (err error) { id := ctx.ParamValue("id") model := &Schema{} if err = crud.db.Where("id=?", id).First(model).Error; err == nil { invalidCache(ctx.ParamValue(NamespaceVariable), model.Module, model.Table) if err = crud.db.Where("id=?", id).Delete(&Schema{}).Error; err == nil { return ctx.Success(map[string]string{ "id": id, "state": "success", }) } else { return ctx.Error(10012, err.Error()) } } else { return ctx.Error(10012, err.Error()) } } //router 绑定路由 func (crud *CRUD) router() { if crud.httpSvr == nil { return } //获取注册上来的模块 crud.httpSvr.Handle("GET", "/crud/modules", crud.handleQueryCrudModules, crud.httpMiddleware...) //获取所有schema crud.httpSvr.Handle("GET", "/schema/:module/:table", crud.handleQuerySchema, crud.httpMiddleware...) //更新schema crud.httpSvr.Handle("POST", "/schema/:module/:table", crud.handleSaveSchema, crud.httpMiddleware...) //删除schema crud.httpSvr.Handle("DELETE", "/schema/:id", crud.handleDeleteSchema, crud.httpMiddleware...) } // Attach 附加一个模型数据 func (crud *CRUD) Attach(model Model, ops ...Option) (err error) { opts := NewOptions() for _, op := range ops { op(opts) } if opts.DB == nil { opts.DB = crud.db } //auto migrate database struct if err = crud.db.AutoMigrate(model); err != nil { return } //migrate table schema if err = migrateUp("", model); err != nil { return } scenarios := model.Scenario() entity := newEntity(model, opts) entity.callback = crud.callback if len(scenarios) == 0 { entity.scenarios = []string{ScenarioList, ScenarioCreate, ScenarioUpdate, ScenarioDelete, ScenarioExport, ScenarioView} } else { entity.scenarios = model.Scenario() } crud.entities.Store(entity.ID(), entity) return } // Detach 删除一个模型数据 func (crud *CRUD) Detach(model Model) { id := model.TableName() + "@" + model.ModuleName() crud.entities.Delete(id) } //Routes 生成增删改查的路由 func (crud *CRUD) Routes(svr *http.Server, ms ...http.Middleware) { var ( method string uri string ) if crud.httpSvr == nil { crud.httpSvr = svr } if crud.httpSvr == nil { return } crud.httpMiddleware = ms crud.entities.Range(func(key, value interface{}) bool { entity := value.(*Entity) if entity.opts.Middleware == nil || len(entity.opts.Middleware) == 0 { entity.opts.Middleware = crud.httpMiddleware } for _, scenario := range entity.scenarios { method = entity.getScenarioMethod(scenario) uri = entity.getScenarioUrl(scenario) log.Debugf("CRUD: register %s %s", method, uri) crud.httpSvr.Handle( method, uri, entity.getScenarioHandle(scenario), entity.opts.Middleware..., ) } //注册获取键值对数据格式 if entity.isKvMapping() { crud.httpSvr.Handle("GET", entity.getScenarioUrl("mapping"), entity.getScenarioHandle("mapping"), entity.opts.Middleware...) } return true }) crud.router() } //Schemas 获取一个模型的显示字段 func (crud *CRUD) Schemas(namespace, moduleName, tableName, scenario string) []*Schema { if v, ok := crud.entities.Load(tableName + "@" + moduleName); ok { e := v.(*Entity) return visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), scenario) } return nil } // IsNewRecord 判断该模型是不是一条新记录 func (crud *CRUD) IsNewRecord(value reflect.Value, stmt *gorm.Statement) bool { for _, pf := range stmt.Schema.PrimaryFields { if _, isZero := pf.ValueOf(value); isZero { return true } } return false } //NewCRUD 创建一个新的CRUD模型 func NewCRUD(db *gorm.DB, svr *http.Server) (crud *CRUD, err error) { if err = initSchema(db); err != nil { return } crud = &CRUD{ db: db, callback: newCallback(), httpSvr: svr, } return } func openDatabase(cfg *Config) (db *gorm.DB, err error) { dbCfg := &gorm.Config{Logger: logger.NewGormLogger()} switch cfg.Driver { case "sqlite3": db, err = gorm.Open(sqlite.Open(cfg.ParseDSN()), dbCfg) case "mysql": db, err = gorm.Open(mysql.Open(cfg.ParseDSN()), dbCfg) default: err = fmt.Errorf("unsupported sql driver %s", cfg.Driver) } return } //Dialer 连接一个数据库 func Dialer(cfg *Config) (crud *CRUD, err error) { var ( db *gorm.DB ) if db, err = openDatabase(cfg); err != nil { return } if err = initSchema(db); err != nil { return } crud = &CRUD{ db: db, callback: newCallback(), } return } func FromContext(ctx context.Context) *CRUD { if v := ctx.Value(ctxKey); v != nil { return v.(*CRUD) } return nil } func WithContext(ctx context.Context, r *CRUD) context.Context { return context.WithValue(ctx, ctxKey, r) }