snowflake_id.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. package plugins
  2. import (
  3. "context"
  4. "git.nspix.com/golang/rest/v3"
  5. "github.com/bwmarrin/snowflake"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/schema"
  8. "os"
  9. "reflect"
  10. "strconv"
  11. )
  12. var (
  13. sf *snowflake.Node
  14. )
  15. func init() {
  16. var err error
  17. no, _ := strconv.ParseInt(os.Getenv("REST_SNOWFLAKE_NODE"), 10, 64)
  18. if no == 0 {
  19. no = 1
  20. }
  21. if sf, err = snowflake.NewNode(no); err != nil {
  22. panic(err)
  23. }
  24. }
  25. //SnowflakeID 自动生成主键ID
  26. func snowflakeID(db *gorm.DB) {
  27. var (
  28. err error
  29. )
  30. if db.Statement.Schema != nil {
  31. if field := db.Statement.Schema.LookUpField("ID"); field != nil {
  32. if field.DataType == schema.String {
  33. if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice {
  34. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  35. if _, zero := field.ValueOf(context.Background(), db.Statement.ReflectValue.Index(i)); zero {
  36. if err = field.Set(context.Background(), db.Statement.ReflectValue.Index(i), sf.Generate().String()); err != nil {
  37. _ = db.AddError(err)
  38. }
  39. }
  40. }
  41. } else {
  42. if _, zero := field.ValueOf(context.Background(), db.Statement.ReflectValue); zero {
  43. db.Statement.SetColumn("ID", sf.Generate().String())
  44. }
  45. }
  46. }
  47. }
  48. }
  49. }
  50. func RegisterSnowflakeIDCallback(ri *rest.CRUD) (err error) {
  51. db := ri.DB()
  52. return db.Callback().Create().Before("gorm:create").Register("snowflake_id", snowflakeID)
  53. }