package rest import ( "fmt" "git.nspix.com/golang/micro/helper/utils" "gorm.io/gorm" "reflect" "strconv" "strings" "time" ) type ( Query struct { db *gorm.DB condition string fields []string params []interface{} table string joins []join orderBy []string groupBy []string limit int offset int } condition struct { Field string `json:"field"` Value interface{} `json:"value"` Operator string `json:"operator"` } join struct { Table string Direction string Conds []*condition } ) func (query *Query) compile() (*gorm.DB, error) { db := query.db if query.condition != "" { db = db.Where(query.condition, query.params...) } if query.fields != nil { db = db.Select(strings.Join(query.fields, ",")) } if query.table != "" { db = db.Table(query.table) } if query.joins != nil && len(query.joins) > 0 { for _, joinEntity := range query.joins { cs, ps := query.buildConditions("OR", false, joinEntity.Conds...) db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...) } } if query.orderBy != nil && len(query.orderBy) > 0 { db = db.Order(strings.Join(query.orderBy, ",")) } if query.groupBy != nil && len(query.groupBy) > 0 { db = db.Group(strings.Join(query.groupBy, ",")) } if query.offset > 0 { db = db.Offset(query.offset) } if query.limit > 0 { db = db.Limit(query.limit) } return db, nil } func (query *Query) decodeValue(v interface{}) string { refVal := reflect.Indirect(reflect.ValueOf(v)) switch refVal.Kind() { case reflect.Bool: if refVal.Bool() { return "1" } else { return "0" } case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: return strconv.FormatInt(refVal.Int(), 10) case reflect.Float32, reflect.Float64: return strconv.FormatFloat(refVal.Float(), 'f', -1, 64) case reflect.String: return "'" + refVal.String() + "'" case timeKind: if tm, ok := refVal.Interface().(time.Time); ok { return "'" + tm.Format("2006-01-02 15:04:05") + "'" } return fmt.Sprint(v) default: return fmt.Sprint(v) } } func (query *Query) buildConditions(operator string, filter bool, conds ...*condition) (str string, params []interface{}) { var ( sb strings.Builder ) params = make([]interface{}, 0) for _, cond := range conds { if filter { if utils.IsEmpty(cond.Value) { continue } } if cond.Operator == "" { cond.Operator = "=" } switch strings.ToUpper(cond.Operator) { case "=", "<>", ">", "<", ">=", "<=", "!=": if sb.Len() > 0 { sb.WriteString(" " + operator + " ") } if cond.Operator == "=" && cond.Value == nil { sb.WriteString("`" + cond.Field + "` IS NULL") } else { sb.WriteString("`" + cond.Field + "` " + cond.Operator + " ?") params = append(params, cond.Value) } case "LIKE": if sb.Len() > 0 { sb.WriteString(" " + operator + " ") } cond.Value = fmt.Sprintf("%%%s%%", cond.Value) sb.WriteString("`" + cond.Field + "` LIKE ?") params = append(params, cond.Value) case "IN": if sb.Len() > 0 { sb.WriteString(" " + operator + " ") } refVal := reflect.Indirect(reflect.ValueOf(cond.Value)) switch refVal.Kind() { case reflect.Slice: ss := make([]string, refVal.Len()) for i := 0; i < refVal.Len(); i++ { ss[i] = query.decodeValue(refVal.Index(i)) } sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")") case reflect.String: sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")") } case "BETWEEN": refVal := reflect.ValueOf(cond.Value) if refVal.Kind() == reflect.Slice && refVal.Len() == 2 { sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?") params = append(params, refVal.Index(0), refVal.Index(1)) } } } str = sb.String() return } func (query *Query) Select(fields ...string) *Query { if query.fields == nil { query.fields = fields } else { query.fields = append(query.fields, fields...) } return query } func (query *Query) From(table string) *Query { query.table = table return query } //// Joins specify Joins conditions //// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) //func (s *DB) Joins(query string, args ...interface{}) *DB { // return s.clone().search.Joins(query, args...).db //} func (query *Query) LeftJoin(table string, conds ...*condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "LEFT", Conds: conds, }) return query } func (query *Query) RightJoin(table string, conds ...*condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "RIGHT", Conds: conds, }) return query } func (query *Query) InnerJoin(table string, conds ...*condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "INNER", Conds: conds, }) return query } func (query *Query) AndFilterWhere(conds ...*condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := query.buildConditions("AND", true, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND " + cs } return query } func (query *Query) AndWhere(conds ...*condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := query.buildConditions("AND", false, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) OrFilterWhere(conds ...*condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := query.buildConditions("OR", true, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) OrWhere(conds ...*condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := query.buildConditions("OR", false, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) GroupBy(cols ...string) *Query { query.groupBy = append(query.groupBy, cols...) return query } func (query *Query) OrderBy(col, direction string) *Query { direction = strings.ToUpper(direction) if direction == "" || !(direction == "ASC" || direction == "DESC") { direction = "ASC" } query.orderBy = append(query.orderBy, col+" "+direction) return query } func (query *Query) Offset(i int) *Query { query.offset = i return query } func (query *Query) Limit(i int) *Query { query.limit = i return query } func (query *Query) Count(v interface{}) (i int64) { var ( db *gorm.DB err error ) if db, err = query.compile(); err != nil { return } else { err = db.Model(v).Count(&i).Error } return } func (query *Query) One(v interface{}) (err error) { var ( db *gorm.DB ) if db, err = query.compile(); err != nil { return } else { err = db.First(v).Error } return } func (query *Query) All(v interface{}) (err error) { var ( db *gorm.DB ) if db, err = query.compile(); err != nil { return } else { err = db.Find(v).Error } return } func NewQuery(db *gorm.DB) *Query { return &Query{ db: db, params: make([]interface{}, 0), orderBy: make([]string, 0), groupBy: make([]string, 0), joins: make([]join, 0), } } func NewQueryCondition(field string, value interface{}) *condition { return &condition{ Field: field, Value: value, Operator: "=", } } func NewQueryConditionWithOperator(operator, field string, value interface{}) *condition { cond := &condition{ Field: field, Value: value, Operator: operator, } return cond }