snowflake_id.go 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. package plugins
  2. import (
  3. "context"
  4. "github.com/bwmarrin/snowflake"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/schema"
  7. "os"
  8. "reflect"
  9. "strconv"
  10. )
  11. var (
  12. sf *snowflake.Node
  13. )
  14. func init() {
  15. var err error
  16. no, _ := strconv.ParseInt(os.Getenv("CC_NODE"), 10, 64)
  17. if no == 0 {
  18. no = 1
  19. }
  20. if sf, err = snowflake.NewNode(no); err != nil {
  21. panic(err)
  22. }
  23. }
  24. //SnowflakeID 自动生成主键ID
  25. func SnowflakeID(db *gorm.DB) {
  26. var err error
  27. if db.Statement.Schema != nil {
  28. if field := db.Statement.Schema.LookUpField("ID"); field != nil {
  29. if field.DataType == schema.String {
  30. if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice {
  31. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  32. if _, zero := field.ValueOf(context.Background(), db.Statement.ReflectValue.Index(i)); zero {
  33. if err = field.Set(context.Background(), db.Statement.ReflectValue.Index(i), sf.Generate().String()); err != nil {
  34. _ = db.AddError(err)
  35. }
  36. }
  37. }
  38. } else {
  39. if _, zero := field.ValueOf(context.Background(), db.Statement.ReflectValue); zero {
  40. db.Statement.SetColumn("ID", sf.Generate().String())
  41. }
  42. }
  43. }
  44. }
  45. }
  46. }
  47. func RegisterSnowflakeIDCallback(db *gorm.DB) (err error) {
  48. return db.Callback().Create().Before("gorm:create").Register("snowflake_id", SnowflakeID)
  49. }