12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- package plugins
- import (
- "context"
- "git.nspix.com/golang/rest/v3"
- "github.com/bwmarrin/snowflake"
- "gorm.io/gorm"
- "gorm.io/gorm/schema"
- "os"
- "reflect"
- "strconv"
- )
- var (
- sf *snowflake.Node
- )
- func init() {
- var err error
- no, _ := strconv.ParseInt(os.Getenv("REST_SNOWFLAKE_NODE"), 10, 64)
- if no == 0 {
- no = 1
- }
- if sf, err = snowflake.NewNode(no); err != nil {
- panic(err)
- }
- }
- //SnowflakeID 自动生成主键ID
- func snowflakeID(db *gorm.DB) {
- var (
- err error
- )
- if db.Statement.Schema != nil {
- if field := db.Statement.Schema.LookUpField("ID"); field != nil {
- if field.DataType == schema.String {
- if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice {
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- if _, zero := field.ValueOf(context.Background(), db.Statement.ReflectValue.Index(i)); zero {
- if err = field.Set(context.Background(), db.Statement.ReflectValue.Index(i), sf.Generate().String()); err != nil {
- _ = db.AddError(err)
- }
- }
- }
- } else {
- if _, zero := field.ValueOf(context.Background(), db.Statement.ReflectValue); zero {
- db.Statement.SetColumn("ID", sf.Generate().String())
- }
- }
- }
- }
- }
- }
- func RegisterSnowflakeIDCallback(ri *rest.CRUD) (err error) {
- db := ri.DB()
- return db.Callback().Create().Before("gorm:create").Register("snowflake_id", snowflakeID)
- }
|