// download.go
  1. package s3manager
  2. import (
  3. "fmt"
  4. "io"
  5. "strconv"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/aws/aws-sdk-go/aws/awsutil"
  10. "github.com/aws/aws-sdk-go/service/s3"
  11. )
  12. // The default range of bytes to get at a time when using Download().
  13. var DefaultDownloadPartSize int64 = 1024 * 1024 * 5
  14. // The default number of goroutines to spin up when using Download().
  15. var DefaultDownloadConcurrency = 5
  16. // The default set of options used when opts is nil in Download().
  17. var DefaultDownloadOptions = &DownloadOptions{
  18. PartSize: DefaultDownloadPartSize,
  19. Concurrency: DefaultDownloadConcurrency,
  20. }
// DownloadOptions keeps track of extra options to pass to a Download() call.
type DownloadOptions struct {
	// The size (in bytes) of each ranged GET issued for the object. If this
	// value is set to zero, the DefaultDownloadPartSize value will be used.
	// NOTE(review): the original comment mentioned a 5MB minimum and
	// "sending parts to S3" — that is upload wording; no minimum is
	// enforced anywhere in this file.
	PartSize int64

	// The number of goroutines to spin up in parallel when downloading
	// parts. If this is set to zero, the DefaultDownloadConcurrency value
	// will be used.
	Concurrency int

	// An S3 client to use when performing downloads. Leave this as nil to
	// use a default client (created via s3.New(nil) in init).
	S3 *s3.S3
}
  34. // NewDownloader creates a new Downloader structure that downloads an object
  35. // from S3 in concurrent chunks. Pass in an optional DownloadOptions struct
  36. // to customize the downloader behavior.
  37. func NewDownloader(opts *DownloadOptions) *Downloader {
  38. if opts == nil {
  39. opts = DefaultDownloadOptions
  40. }
  41. return &Downloader{opts: opts}
  42. }
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
type Downloader struct {
	// opts holds the configured download options. NewDownloader guarantees
	// this is non-nil; a zero-value Downloader would panic in Download().
	opts *DownloadOptions
}
  48. // Download downloads an object in S3 and writes the payload into w using
  49. // concurrent GET requests.
  50. //
  51. // It is safe to call this method for multiple objects and across concurrent
  52. // goroutines.
  53. //
  54. // The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
  55. // downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
  56. func (d *Downloader) Download(w io.WriterAt, input *s3.GetObjectInput) (n int64, err error) {
  57. impl := downloader{w: w, in: input, opts: *d.opts}
  58. return impl.download()
  59. }
// downloader is the implementation structure used internally by Downloader.
// A fresh value is created per Download() call and never reused.
type downloader struct {
	opts DownloadOptions    // per-call copy of the options; defaults filled in by init()
	in   *s3.GetObjectInput // caller's request template, copied per chunk
	w    io.WriterAt        // destination for the ranged writes

	wg sync.WaitGroup // tracks the downloadPart worker goroutines
	m  sync.Mutex     // guards totalBytes, written, and err

	pos        int64 // next byte offset to queue; touched only by the download() goroutine
	totalBytes int64 // total object size; -1 until the first response resolves it
	written    int64 // running count of bytes written to w
	err        error // first error observed; read/written via geterr/seterr
}
  72. // init initializes the downloader with default options.
  73. func (d *downloader) init() {
  74. d.totalBytes = -1
  75. if d.opts.Concurrency == 0 {
  76. d.opts.Concurrency = DefaultDownloadConcurrency
  77. }
  78. if d.opts.PartSize == 0 {
  79. d.opts.PartSize = DefaultDownloadPartSize
  80. }
  81. if d.opts.S3 == nil {
  82. d.opts.S3 = s3.New(nil)
  83. }
  84. }
  85. // download performs the implementation of the object download across ranged
  86. // GETs.
  87. func (d *downloader) download() (n int64, err error) {
  88. d.init()
  89. // Spin up workers
  90. ch := make(chan dlchunk, d.opts.Concurrency)
  91. for i := 0; i < d.opts.Concurrency; i++ {
  92. d.wg.Add(1)
  93. go d.downloadPart(ch)
  94. }
  95. // Assign work
  96. for d.geterr() == nil {
  97. if d.pos != 0 {
  98. // This is not the first chunk, let's wait until we know the total
  99. // size of the payload so we can see if we have read the entire
  100. // object.
  101. total := d.getTotalBytes()
  102. if total < 0 {
  103. // Total has not yet been set, so sleep and loop around while
  104. // waiting for our first worker to resolve this value.
  105. time.Sleep(10 * time.Millisecond)
  106. continue
  107. } else if d.pos >= total {
  108. break // We're finished queueing chunks
  109. }
  110. }
  111. // Queue the next range of bytes to read.
  112. ch <- dlchunk{w: d.w, start: d.pos, size: d.opts.PartSize}
  113. d.pos += d.opts.PartSize
  114. }
  115. // Wait for completion
  116. close(ch)
  117. d.wg.Wait()
  118. // Return error
  119. return d.written, d.err
  120. }
  121. // downloadPart is an individual goroutine worker reading from the ch channel
  122. // and performing a GetObject request on the data with a given byte range.
  123. //
  124. // If this is the first worker, this operation also resolves the total number
  125. // of bytes to be read so that the worker manager knows when it is finished.
  126. func (d *downloader) downloadPart(ch chan dlchunk) {
  127. defer d.wg.Done()
  128. for {
  129. chunk, ok := <-ch
  130. if !ok {
  131. break
  132. }
  133. if d.geterr() == nil {
  134. // Get the next byte range of data
  135. in := &s3.GetObjectInput{}
  136. awsutil.Copy(in, d.in)
  137. rng := fmt.Sprintf("bytes=%d-%d",
  138. chunk.start, chunk.start+chunk.size-1)
  139. in.Range = &rng
  140. resp, err := d.opts.S3.GetObject(in)
  141. if err != nil {
  142. d.seterr(err)
  143. } else {
  144. d.setTotalBytes(resp) // Set total if not yet set.
  145. n, err := io.Copy(&chunk, resp.Body)
  146. resp.Body.Close()
  147. if err != nil {
  148. d.seterr(err)
  149. }
  150. d.incrwritten(n)
  151. }
  152. }
  153. }
  154. }
  155. // getTotalBytes is a thread-safe getter for retrieving the total byte status.
  156. func (d *downloader) getTotalBytes() int64 {
  157. d.m.Lock()
  158. defer d.m.Unlock()
  159. return d.totalBytes
  160. }
  161. // getTotalBytes is a thread-safe setter for setting the total byte status.
  162. func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
  163. d.m.Lock()
  164. defer d.m.Unlock()
  165. if d.totalBytes >= 0 {
  166. return
  167. }
  168. parts := strings.Split(*resp.ContentRange, "/")
  169. total, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
  170. if err != nil {
  171. d.err = err
  172. return
  173. }
  174. d.totalBytes = total
  175. }
  176. func (d *downloader) incrwritten(n int64) {
  177. d.m.Lock()
  178. defer d.m.Unlock()
  179. d.written += n
  180. }
  181. // geterr is a thread-safe getter for the error object
  182. func (d *downloader) geterr() error {
  183. d.m.Lock()
  184. defer d.m.Unlock()
  185. return d.err
  186. }
  187. // seterr is a thread-safe setter for the error object
  188. func (d *downloader) seterr(e error) {
  189. d.m.Lock()
  190. defer d.m.Unlock()
  191. d.err = e
  192. }
  193. // dlchunk represents a single chunk of data to write by the worker routine.
  194. // This structure also implements an io.SectionReader style interface for
  195. // io.WriterAt, effectively making it an io.SectionWriter (which does not
  196. // exist).
  197. type dlchunk struct {
  198. w io.WriterAt
  199. start int64
  200. size int64
  201. cur int64
  202. }
  203. // Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
  204. // position to its end (or EOF).
  205. func (c *dlchunk) Write(p []byte) (n int, err error) {
  206. if c.cur >= c.size {
  207. return 0, io.EOF
  208. }
  209. n, err = c.w.WriteAt(p, c.start+c.cur)
  210. c.cur += int64(n)
  211. return
  212. }