query.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. package rest
  2. import (
  3. "fmt"
  4. "git.nspix.com/golang/micro/helper/utils"
  5. "gorm.io/gorm"
  6. "reflect"
  7. "strconv"
  8. "strings"
  9. "time"
  10. )
  11. type (
  12. Query struct {
  13. db *gorm.DB
  14. condition string
  15. fields []string
  16. params []interface{}
  17. table string
  18. joins []join
  19. orderBy []string
  20. groupBy []string
  21. limit int
  22. offset int
  23. }
  24. condition struct {
  25. Field string `json:"field"`
  26. Value interface{} `json:"value"`
  27. Operator string `json:"operator"`
  28. }
  29. join struct {
  30. Table string
  31. Direction string
  32. Conds []*condition
  33. }
  34. )
  35. func (query *Query) compile() (*gorm.DB, error) {
  36. db := query.db
  37. if query.condition != "" {
  38. db = db.Where(query.condition, query.params...)
  39. }
  40. if query.fields != nil {
  41. db = db.Select(strings.Join(query.fields, ","))
  42. }
  43. if query.table != "" {
  44. db = db.Table(query.table)
  45. }
  46. if query.joins != nil && len(query.joins) > 0 {
  47. for _, joinEntity := range query.joins {
  48. cs, ps := query.buildConditions("OR", false, joinEntity.Conds...)
  49. db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
  50. }
  51. }
  52. if query.orderBy != nil && len(query.orderBy) > 0 {
  53. db = db.Order(strings.Join(query.orderBy, ","))
  54. }
  55. if query.groupBy != nil && len(query.groupBy) > 0 {
  56. db = db.Group(strings.Join(query.groupBy, ","))
  57. }
  58. if query.offset > 0 {
  59. db = db.Offset(query.offset)
  60. }
  61. if query.limit > 0 {
  62. db = db.Limit(query.limit)
  63. }
  64. return db, nil
  65. }
  66. func (query *Query) decodeValue(v interface{}) string {
  67. refVal := reflect.Indirect(reflect.ValueOf(v))
  68. switch refVal.Kind() {
  69. case reflect.Bool:
  70. if refVal.Bool() {
  71. return "1"
  72. } else {
  73. return "0"
  74. }
  75. case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
  76. return strconv.FormatInt(refVal.Int(), 10)
  77. case reflect.Float32, reflect.Float64:
  78. return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
  79. case reflect.String:
  80. return "'" + refVal.String() + "'"
  81. case timeKind:
  82. if tm, ok := refVal.Interface().(time.Time); ok {
  83. return "'" + tm.Format("2006-01-02 15:04:05") + "'"
  84. }
  85. return fmt.Sprint(v)
  86. default:
  87. return fmt.Sprint(v)
  88. }
  89. }
  90. func (query *Query) buildConditions(operator string, filter bool, conds ...*condition) (str string, params []interface{}) {
  91. var (
  92. sb strings.Builder
  93. )
  94. params = make([]interface{}, 0)
  95. for _, cond := range conds {
  96. if filter {
  97. if utils.IsEmpty(cond.Value) {
  98. continue
  99. }
  100. }
  101. if cond.Operator == "" {
  102. cond.Operator = "="
  103. }
  104. switch strings.ToUpper(cond.Operator) {
  105. case "=", "<>", ">", "<", ">=", "<=", "!=":
  106. if sb.Len() > 0 {
  107. sb.WriteString(" " + operator + " ")
  108. }
  109. if cond.Operator == "=" && cond.Value == nil {
  110. sb.WriteString("`" + cond.Field + "` IS NULL")
  111. } else {
  112. sb.WriteString("`" + cond.Field + "` " + cond.Operator + " ?")
  113. params = append(params, cond.Value)
  114. }
  115. case "LIKE":
  116. if sb.Len() > 0 {
  117. sb.WriteString(" " + operator + " ")
  118. }
  119. cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
  120. sb.WriteString("`" + cond.Field + "` LIKE ?")
  121. params = append(params, cond.Value)
  122. case "IN":
  123. if sb.Len() > 0 {
  124. sb.WriteString(" " + operator + " ")
  125. }
  126. refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
  127. switch refVal.Kind() {
  128. case reflect.Slice:
  129. ss := make([]string, refVal.Len())
  130. for i := 0; i < refVal.Len(); i++ {
  131. ss[i] = query.decodeValue(refVal.Index(i))
  132. }
  133. sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
  134. case reflect.String:
  135. sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
  136. }
  137. case "BETWEEN":
  138. refVal := reflect.ValueOf(cond.Value)
  139. if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
  140. sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
  141. params = append(params, refVal.Index(0), refVal.Index(1))
  142. }
  143. }
  144. }
  145. str = sb.String()
  146. return
  147. }
  148. func (query *Query) Select(fields ...string) *Query {
  149. if query.fields == nil {
  150. query.fields = fields
  151. } else {
  152. query.fields = append(query.fields, fields...)
  153. }
  154. return query
  155. }
  156. func (query *Query) From(table string) *Query {
  157. query.table = table
  158. return query
  159. }
  160. //// Joins specify Joins conditions
  161. //// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
  162. //func (s *DB) Joins(query string, args ...interface{}) *DB {
  163. // return s.clone().search.Joins(query, args...).db
  164. //}
  165. func (query *Query) LeftJoin(table string, conds ...*condition) *Query {
  166. query.joins = append(query.joins, join{
  167. Table: table,
  168. Direction: "LEFT",
  169. Conds: conds,
  170. })
  171. return query
  172. }
  173. func (query *Query) RightJoin(table string, conds ...*condition) *Query {
  174. query.joins = append(query.joins, join{
  175. Table: table,
  176. Direction: "RIGHT",
  177. Conds: conds,
  178. })
  179. return query
  180. }
  181. func (query *Query) InnerJoin(table string, conds ...*condition) *Query {
  182. query.joins = append(query.joins, join{
  183. Table: table,
  184. Direction: "INNER",
  185. Conds: conds,
  186. })
  187. return query
  188. }
  189. func (query *Query) AndFilterWhere(conds ...*condition) *Query {
  190. length := len(conds)
  191. if length == 0 {
  192. return query
  193. }
  194. cs, ps := query.buildConditions("AND", true, conds...)
  195. if cs == "" {
  196. return query
  197. }
  198. query.params = append(query.params, ps...)
  199. if query.condition == "" {
  200. query.condition = cs
  201. } else {
  202. query.condition += " AND " + cs
  203. }
  204. return query
  205. }
  206. func (query *Query) AndWhere(conds ...*condition) *Query {
  207. length := len(conds)
  208. if length == 0 {
  209. return query
  210. }
  211. cs, ps := query.buildConditions("AND", false, conds...)
  212. if cs == "" {
  213. return query
  214. }
  215. query.params = append(query.params, ps...)
  216. if query.condition == "" {
  217. query.condition = cs
  218. } else {
  219. query.condition += " AND (" + cs + ")"
  220. }
  221. return query
  222. }
  223. func (query *Query) OrFilterWhere(conds ...*condition) *Query {
  224. length := len(conds)
  225. if length == 0 {
  226. return query
  227. }
  228. cs, ps := query.buildConditions("OR", true, conds...)
  229. if cs == "" {
  230. return query
  231. }
  232. query.params = append(query.params, ps...)
  233. if query.condition == "" {
  234. query.condition = cs
  235. } else {
  236. query.condition += " AND (" + cs + ")"
  237. }
  238. return query
  239. }
  240. func (query *Query) OrWhere(conds ...*condition) *Query {
  241. length := len(conds)
  242. if length == 0 {
  243. return query
  244. }
  245. cs, ps := query.buildConditions("OR", false, conds...)
  246. if cs == "" {
  247. return query
  248. }
  249. query.params = append(query.params, ps...)
  250. if query.condition == "" {
  251. query.condition = cs
  252. } else {
  253. query.condition += " AND (" + cs + ")"
  254. }
  255. return query
  256. }
  257. func (query *Query) GroupBy(cols ...string) *Query {
  258. query.groupBy = append(query.groupBy, cols...)
  259. return query
  260. }
  261. func (query *Query) OrderBy(col, direction string) *Query {
  262. direction = strings.ToUpper(direction)
  263. if direction == "" || !(direction == "ASC" || direction == "DESC") {
  264. direction = "ASC"
  265. }
  266. query.orderBy = append(query.orderBy, col+" "+direction)
  267. return query
  268. }
  269. func (query *Query) Offset(i int) *Query {
  270. query.offset = i
  271. return query
  272. }
  273. func (query *Query) Limit(i int) *Query {
  274. query.limit = i
  275. return query
  276. }
  277. func (query *Query) Count(v interface{}) (i int64) {
  278. var (
  279. db *gorm.DB
  280. err error
  281. )
  282. if db, err = query.compile(); err != nil {
  283. return
  284. } else {
  285. err = db.Model(v).Count(&i).Error
  286. }
  287. return
  288. }
  289. func (query *Query) One(v interface{}) (err error) {
  290. var (
  291. db *gorm.DB
  292. )
  293. if db, err = query.compile(); err != nil {
  294. return
  295. } else {
  296. err = db.First(v).Error
  297. }
  298. return
  299. }
  300. func (query *Query) All(v interface{}) (err error) {
  301. var (
  302. db *gorm.DB
  303. )
  304. if db, err = query.compile(); err != nil {
  305. return
  306. } else {
  307. err = db.Find(v).Error
  308. }
  309. return
  310. }
  311. func NewQuery(db *gorm.DB) *Query {
  312. return &Query{
  313. db: db,
  314. params: make([]interface{}, 0),
  315. orderBy: make([]string, 0),
  316. groupBy: make([]string, 0),
  317. joins: make([]join, 0),
  318. }
  319. }
  320. func NewQueryCondition(field string, value interface{}) *condition {
  321. return &condition{
  322. Field: field,
  323. Value: value,
  324. Operator: "=",
  325. }
  326. }
  327. func NewQueryConditionWithOperator(operator, field string, value interface{}) *condition {
  328. cond := &condition{
  329. Field: field,
  330. Value: value,
  331. Operator: operator,
  332. }
  333. return cond
  334. }