premailer.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. package premailer
  2. import (
  3. "fmt"
  4. "regexp"
  5. "sort"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "github.com/PuerkitoBio/goquery"
  10. "github.com/vanng822/css"
  11. "golang.org/x/net/html"
  12. )
  13. // Premailer is the inteface of Premailer
  14. type Premailer interface {
  15. // Transform process and inlining css
  16. // It start to collect the rules in the document style tags
  17. // Calculate specificity and sort the rules based on that
  18. // It then collects the affected elements
  19. // And applies the rules on those
  20. // The leftover rules will put back into a style element
  21. Transform() (string, error)
  22. }
  23. var unmergableSelector = regexp.MustCompile("(?i)\\:{1,2}(visited|active|hover|focus|link|root|in-range|invalid|valid|after|before|selection|target|first\\-(line|letter))|^\\@")
  24. var notSupportedSelector = regexp.MustCompile("(?i)\\:(checked|disabled|enabled|lang)")
  25. type premailer struct {
  26. doc *goquery.Document
  27. elIdAttr string
  28. elements map[string]*elementRules
  29. rules []*styleRule
  30. leftover []*css.CSSRule
  31. allRules [][]*css.CSSRule
  32. elementId int
  33. processed bool
  34. options *Options
  35. }
  36. // NewPremailer return a new instance of Premailer
  37. // It take a Document as argument and it shouldn't be nil
  38. func NewPremailer(doc *goquery.Document, options *Options) Premailer {
  39. pr := premailer{}
  40. pr.doc = doc
  41. pr.rules = make([]*styleRule, 0)
  42. pr.allRules = make([][]*css.CSSRule, 0)
  43. pr.leftover = make([]*css.CSSRule, 0)
  44. pr.elements = make(map[string]*elementRules)
  45. pr.elIdAttr = "pr-el-id"
  46. if options == nil {
  47. options = NewOptions()
  48. }
  49. pr.options = options
  50. return &pr
  51. }
  52. func (pr *premailer) sortRules() {
  53. ruleIndex := 0
  54. for ruleSetIndex, rules := range pr.allRules {
  55. if rules == nil {
  56. continue
  57. }
  58. for _, rule := range rules {
  59. if rule.Type != css.STYLE_RULE {
  60. pr.leftover = append(pr.leftover, rule)
  61. continue
  62. }
  63. normalStyles := make([]*css.CSSStyleDeclaration, 0)
  64. importantStyles := make([]*css.CSSStyleDeclaration, 0)
  65. for _, s := range rule.Style.Styles {
  66. if s.Important {
  67. importantStyles = append(importantStyles, s)
  68. } else {
  69. normalStyles = append(normalStyles, s)
  70. }
  71. }
  72. selectors := strings.Split(rule.Style.Selector.Text(), ",")
  73. for _, selector := range selectors {
  74. if unmergableSelector.MatchString(selector) || notSupportedSelector.MatchString(selector) {
  75. // cause longer css
  76. pr.leftover = append(pr.leftover, copyRule(selector, rule))
  77. continue
  78. }
  79. if strings.Contains(selector, "*") {
  80. // keep this?
  81. pr.leftover = append(pr.leftover, copyRule(selector, rule))
  82. continue
  83. }
  84. if len(normalStyles) > 0 {
  85. pr.rules = append(pr.rules, &styleRule{makeSpecificity(0, ruleSetIndex, ruleIndex, selector), selector, normalStyles})
  86. ruleIndex += 1
  87. }
  88. if len(importantStyles) > 0 {
  89. pr.rules = append(pr.rules, &styleRule{makeSpecificity(1, ruleSetIndex, ruleIndex, selector), selector, importantStyles})
  90. ruleIndex += 1
  91. }
  92. }
  93. }
  94. }
  95. sort.Sort(bySpecificity(pr.rules))
  96. }
  97. func (pr *premailer) collectRules() {
  98. var wg sync.WaitGroup
  99. pr.doc.Find("style:not([data-premailer='ignore'])").Each(func(_ int, s *goquery.Selection) {
  100. if media, exist := s.Attr("media"); exist && media != "all" {
  101. return
  102. }
  103. wg.Add(1)
  104. pr.allRules = append(pr.allRules, nil)
  105. go func(ruleSetIndex int) {
  106. defer wg.Done()
  107. ss := css.Parse(s.Text())
  108. pr.allRules[ruleSetIndex] = ss.GetCSSRuleList()
  109. s.ReplaceWithHtml("")
  110. }(len(pr.allRules) - 1)
  111. })
  112. wg.Wait()
  113. }
  114. func (pr *premailer) collectElements() {
  115. for _, rule := range pr.rules {
  116. pr.doc.Find(rule.selector).Each(func(_ int, s *goquery.Selection) {
  117. if id, exist := s.Attr(pr.elIdAttr); exist {
  118. pr.elements[id].rules = append(pr.elements[id].rules, rule)
  119. } else {
  120. id := strconv.Itoa(pr.elementId)
  121. s.SetAttr(pr.elIdAttr, id)
  122. rules := make([]*styleRule, 0)
  123. rules = append(rules, rule)
  124. pr.elements[id] = &elementRules{
  125. element: s,
  126. rules: rules,
  127. cssToAttributes: pr.options.CssToAttributes,
  128. keepBangImportant: pr.options.KeepBangImportant,
  129. }
  130. pr.elementId += 1
  131. }
  132. })
  133. }
  134. }
  135. func (pr *premailer) applyInline() {
  136. for _, element := range pr.elements {
  137. element.inline()
  138. element.element.RemoveAttr(pr.elIdAttr)
  139. if pr.options.RemoveClasses {
  140. element.element.RemoveAttr("class")
  141. }
  142. }
  143. }
  144. func (pr *premailer) addLeftover() {
  145. if len(pr.leftover) > 0 {
  146. headNode := pr.doc.Find("head")
  147. styleNode := &html.Node{}
  148. styleNode.Type = html.ElementNode
  149. styleNode.Data = "style"
  150. styleNode.Attr = []html.Attribute{{Key: "type", Val: "text/css"}}
  151. cssNode := &html.Node{}
  152. cssData := make([]string, 0, len(pr.leftover))
  153. for _, rule := range pr.leftover {
  154. if rule.Type == css.MEDIA_RULE {
  155. mcssData := make([]string, 0, len(rule.Rules))
  156. for _, mrule := range rule.Rules {
  157. mcssData = append(mcssData, makeRuleImportant(mrule))
  158. }
  159. cssData = append(cssData, fmt.Sprintf("%s %s{\n%s\n}\n",
  160. rule.Type.Text(),
  161. rule.Style.Selector.Text(),
  162. strings.Join(mcssData, "\n")))
  163. } else {
  164. cssData = append(cssData, makeRuleImportant(rule))
  165. }
  166. }
  167. cssNode.Data = strings.Join(cssData, "")
  168. cssNode.Type = html.TextNode
  169. styleNode.AppendChild(cssNode)
  170. headNode.AppendNodes(styleNode)
  171. }
  172. }
  173. // Transform process and inlining css
  174. // It start to collect the rules in the document style tags
  175. // Calculate specificity and sort the rules based on that
  176. // It then collects the affected elements
  177. // And applies the rules on those
  178. // The leftover rules will put back into a style element
  179. func (pr *premailer) Transform() (string, error) {
  180. if !pr.processed {
  181. pr.collectRules()
  182. pr.sortRules()
  183. pr.collectElements()
  184. pr.applyInline()
  185. pr.addLeftover()
  186. pr.processed = true
  187. }
  188. return pr.doc.Html()
  189. }