package query import ( "gorm.io/gorm" "strings" ) type Query struct { db *gorm.DB condition string fields []string params []interface{} limit int offset int table string orderBy []string groupBy []string joins []join } type join struct { Table string Direction string Conds []*Condition } func (query *Query) Select(fields ...string) *Query { if query.fields == nil { query.fields = fields } else { query.fields = append(query.fields, fields...) } return query } func (query *Query) From(table string) *Query { query.table = table return query } //// Joins specify Joins conditions //// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) //func (s *DB) Joins(query string, args ...interface{}) *DB { // return s.clone().search.Joins(query, args...).db //} func (query *Query) LeftJoin(table string, conds ...*Condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "LEFT", Conds: conds, }) return query } func (query *Query) RightJoin(table string, conds ...*Condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "RIGHT", Conds: conds, }) return query } func (query *Query) InnerJoin(table string, conds ...*Condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "INNER", Conds: conds, }) return query } func (query *Query) AndFilterWhere(conds ...*Condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := buildConditions("AND", true, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND " + cs } return query } func (query *Query) AndWhere(conds ...*Condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := buildConditions("AND", false, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND " + cs } return query } func (query *Query) OrFilterWhere(conds ...*Condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := buildConditions("OR", true, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) OrWhere(conds ...*Condition) *Query { length := len(conds) if length == 0 { return query } cs, ps := buildConditions("OR", false, conds...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) GroupBy(cols ...string) *Query { query.groupBy = append(query.groupBy, cols...) return query } func (query *Query) OrderBy(col, direction string) *Query { direction = strings.ToUpper(direction) if direction == "" || !(direction == "ASC" || direction == "DESC") { direction = "ASC" } query.orderBy = append(query.orderBy, col+" "+direction) return query } func (query *Query) Offset(i int) *Query { query.offset = i return query } func (query *Query) Limit(i int) *Query { query.limit = i return query } func (query *Query) prepare() (*gorm.DB, error) { db := query.db if query.condition != "" { db = db.Where(query.condition, query.params...) } if query.fields != nil { db = db.Select(strings.Join(query.fields, ",")) } if query.table != "" { db = db.Table(query.table) } if query.joins != nil && len(query.joins) > 0 { for _, joinEntity := range query.joins { cs, ps := buildConditions("OR", false, joinEntity.Conds...) db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...) } } if query.orderBy != nil && len(query.orderBy) > 0 { db = db.Order(strings.Join(query.orderBy, ",")) } if query.groupBy != nil && len(query.groupBy) > 0 { db = db.Group(strings.Join(query.groupBy, ",")) } if query.offset > 0 { db = db.Offset(query.offset) } if query.limit > 0 { db = db.Limit(query.limit) } return db, nil } func (query *Query) Count(v interface{}) (i int64) { var ( db *gorm.DB err error ) if db, err = query.prepare(); err != nil { return } else { err = db.Model(v).Count(&i).Error } return } func (query *Query) One(v interface{}) (err error) { var ( db *gorm.DB ) if db, err = query.prepare(); err != nil { return } else { err = db.First(v).Error } return } func (query *Query) All(v interface{}) (err error) { var ( db *gorm.DB ) if db, err = query.prepare(); err != nil { return } else { err = db.Find(v).Error } return } func New(db *gorm.DB) *Query { return &Query{ db: db, params: make([]interface{}, 0), orderBy: make([]string, 0), groupBy: make([]string, 0), joins: make([]join, 0), } }