socket_classic.go 7.6 KB


  1. // Copyright 2012 Google Inc. All rights reserved.
  2. // Use of this source code is governed by the Apache 2.0
  3. // license that can be found in the LICENSE file.
  4. // +build appengine
  5. package socket
  6. import (
  7. "fmt"
  8. "io"
  9. "net"
  10. "strconv"
  11. "time"
  12. "github.com/golang/protobuf/proto"
  13. "golang.org/x/net/context"
  14. "google.golang.org/appengine/internal"
  15. pb "google.golang.org/appengine/internal/socket"
  16. )
  17. // Dial connects to the address addr on the network protocol.
  18. // The address format is host:port, where host may be a hostname or an IP address.
  19. // Known protocols are "tcp" and "udp".
  20. // The returned connection satisfies net.Conn, and is valid while ctx is valid;
  21. // if the connection is to be used after ctx becomes invalid, invoke SetContext
  22. // with the new context.
  23. func Dial(ctx context.Context, protocol, addr string) (*Conn, error) {
  24. return DialTimeout(ctx, protocol, addr, 0)
  25. }
  26. var ipFamilies = []pb.CreateSocketRequest_SocketFamily{
  27. pb.CreateSocketRequest_IPv4,
  28. pb.CreateSocketRequest_IPv6,
  29. }
  // DialTimeout is like Dial but takes a timeout.
// The timeout covers name resolution (if required) and the CreateSocket call;
// a timeout <= 0 adds no limit beyond whatever deadline ctx already carries.
// The timeout applies only to dialing: the returned *Conn keeps ctx itself,
// not the timeout-derived context.
//
// addr must be "host:port" with a numeric port in the range 0-65535.
// Only the first address returned by resolution is used.
func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) {
	// dialCtx is used for dialing and name resolution, but is not stored in the *Conn.
	dialCtx := ctx
	if timeout > 0 {
		var cancel context.CancelFunc
		dialCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err)
	}
	// Atoi accepts any int. Reject non-port values here, otherwise the int32
	// conversion below would silently truncate them to some other port.
	if port < 0 || port > 65535 {
		return nil, fmt.Errorf("socket: bad port %q: out of range", portStr)
	}

	var prot pb.CreateSocketRequest_SocketProtocol
	switch protocol {
	case "tcp":
		prot = pb.CreateSocketRequest_TCP
	case "udp":
		prot = pb.CreateSocketRequest_UDP
	default:
		return nil, fmt.Errorf("socket: unknown protocol %q", protocol)
	}

	packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host)
	if err != nil {
		return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
	}
	if len(packedAddrs) == 0 {
		return nil, fmt.Errorf("socket: no addresses for %q", host)
	}

	// Use the first address; its byte length decides the socket family.
	packedAddr := packedAddrs[0]
	fam := pb.CreateSocketRequest_IPv4
	if len(packedAddr) == net.IPv6len {
		fam = pb.CreateSocketRequest_IPv6
	}

	req := &pb.CreateSocketRequest{
		Family:   fam.Enum(),
		Protocol: prot.Enum(),
		RemoteIp: &pb.AddressPort{
			Port:          proto.Int32(int32(port)),
			PackedAddress: packedAddr,
		},
	}
	// Pass the hostname along only when it was looked up (not for IP literals).
	if resolved {
		req.RemoteIp.HostnameHint = &host
	}

	res := &pb.CreateSocketReply{}
	if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil {
		return nil, err
	}
	return &Conn{
		ctx:    ctx, // the caller's ctx, so the dial timeout does not apply to later I/O
		desc:   res.GetSocketDescriptor(),
		prot:   prot,
		local:  res.ProxyExternalIp,
		remote: req.RemoteIp,
	}, nil
}
  // LookupIP returns the IP addresses of host.
// A literal IP address is returned directly without a remote call. Any other
// host is looked up with the remote_socket Resolve call, which is asked for
// both IPv4 and IPv6 addresses. Resolution failures are wrapped with the host name.
func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) {
	packed, _, resolveErr := resolve(ctx, ipFamilies, host)
	if resolveErr != nil {
		return nil, fmt.Errorf("socket: failed resolving %q: %v", host, resolveErr)
	}
	addrs = make([]net.IP, 0, len(packed))
	for _, raw := range packed {
		addrs = append(addrs, net.IP(raw))
	}
	return addrs, nil
}
  // resolve converts host into a list of packed IP addresses.
//
// If host parses as a literal IP address, no remote call is made. The result
// is that single address, shortened to 4 bytes when To4 can represent it,
// and the boolean is false; fams is ignored in this case.
// Otherwise the remote_socket Resolve call is made with fams as the requested
// address families, its PackedAddress list is returned, and the boolean is true.
func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) {
	if literal := net.ParseIP(host); literal != nil {
		if v4 := literal.To4(); v4 != nil {
			literal = v4
		}
		return [][]byte{literal}, false, nil
	}

	req := &pb.ResolveRequest{Name: &host, AddressFamilies: fams}
	reply := &pb.ResolveReply{}
	if err := internal.Call(ctx, "remote_socket", "Resolve", req, reply); err != nil {
		// XXX: need to map to pb.ResolveReply_ErrorCode?
		return nil, false, err
	}
	return reply.PackedAddress, true, nil
}
  // withDeadline behaves like context.WithDeadline, except that a zero deadline
// means "no deadline". In that case parent itself is returned together with a
// cancel function that does nothing.
func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
	if !deadline.IsZero() {
		return context.WithDeadline(parent, deadline)
	}
	noop := func() {}
	return parent, noop
}
  129. // Conn represents a socket connection.
  130. // It implements net.Conn.
  131. type Conn struct {
  132. ctx context.Context
  133. desc string
  134. offset int64
  135. prot pb.CreateSocketRequest_SocketProtocol
  136. local, remote *pb.AddressPort
  137. readDeadline, writeDeadline time.Time // optional
  138. }
  // SetContext replaces the context used by cn for all of its subsequent API calls.
// It is typically needed only when a Conn created under one context (for
// example while handling a warmup request) is later used while serving a
// different user request.
func (cn *Conn) SetContext(ctx context.Context) {
	cn.ctx = ctx
}
  // Read reads up to len(b) bytes from the connection via the remote_socket
// Receive call, requesting at most 1 MiB per call.
//
// A zero-length b returns (0, nil) without contacting the service, as the
// io.Reader contract expects. For a non-empty b, an empty reply is reported as
// io.EOF. If a read deadline is set, it is both sent to the service as
// TimeoutSeconds and applied to the context of the call.
func (cn *Conn) Read(b []byte) (n int, err error) {
	const maxRead = 1 << 20
	if len(b) == 0 {
		// Requesting DataSize 0 would presumably produce empty data, which the
		// EOF check below would misreport as end of stream.
		return 0, nil
	}
	if len(b) > maxRead {
		b = b[:maxRead]
	}

	req := &pb.ReceiveRequest{
		SocketDescriptor: &cn.desc,
		DataSize:         proto.Int32(int32(len(b))),
	}
	res := &pb.ReceiveReply{}
	if !cn.readDeadline.IsZero() {
		req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds())
	}
	ctx, cancel := withDeadline(cn.ctx, cn.readDeadline)
	defer cancel()
	if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil {
		return 0, err
	}

	if len(res.Data) == 0 {
		return 0, io.EOF
	}
	// The service must never return more than was requested.
	if len(res.Data) > len(b) {
		return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b))
	}
	return copy(b, res.Data), nil
}
  // Write sends b over the connection using remote_socket Send calls of at most
// 1 MiB each. After each call, both the returned count and the stream offset
// advance by the number of bytes the service reports as sent.
// It returns the total number of bytes written and the first error encountered.
// A failed call is assumed to have sent nothing.
// If a write deadline is set, it is sent as TimeoutSeconds and applied to the
// context of each call.
func (cn *Conn) Write(b []byte) (n int, err error) {
	const lim = 1 << 20 // max bytes per Send call
	// NOTE(review): if the service ever reported 0 bytes sent without an error,
	// this loop would keep re-sending; confirm that cannot happen.
	for n < len(b) {
		chunk := b[n:]
		if len(chunk) > lim {
			chunk = chunk[:lim]
		}
		req := &pb.SendRequest{
			SocketDescriptor: &cn.desc,
			Data:             chunk,
			StreamOffset:     &cn.offset,
		}
		res := &pb.SendReply{}
		if !cn.writeDeadline.IsZero() {
			req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds())
		}

		// Release each per-call context as soon as its call returns. A deferred
		// cancel would only run when Write returns, accumulating one pending
		// cancel (and deadline timer) per chunk.
		ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline)
		err = internal.Call(ctx, "remote_socket", "Send", req, res)
		cancel()
		if err != nil {
			// Assume zero bytes were sent in this RPC.
			break
		}

		sent := res.GetDataSent()
		n += int(sent)
		cn.offset += int64(sent)
	}
	return n, err
}
  199. func (cn *Conn) Close() error {
  200. req := &pb.CloseRequest{
  201. SocketDescriptor: &cn.desc,
  202. }
  203. res := &pb.CloseReply{}
  204. if err := internal.Call(cn.ctx, "remote_socket", "Close", req, res); err != nil {
  205. return err
  206. }
  207. cn.desc = "CLOSED"
  208. return nil
  209. }
  // addr converts ap into the net.Addr type matching prot: *net.TCPAddr for TCP
// or *net.UDPAddr for UDP. It returns nil when ap is nil and panics for any
// other protocol value.
func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr {
	if ap == nil {
		return nil
	}
	// The generated getters are nil-safe: a missing port yields 0 rather than a
	// nil pointer dereference.
	ip := net.IP(ap.GetPackedAddress())
	port := int(ap.GetPort())
	switch prot {
	case pb.CreateSocketRequest_TCP:
		return &net.TCPAddr{IP: ip, Port: port}
	case pb.CreateSocketRequest_UDP:
		return &net.UDPAddr{IP: ip, Port: port}
	}
	panic("unknown protocol " + prot.String())
}
  // LocalAddr returns the connection's local endpoint, or nil if none is known.
func (cn *Conn) LocalAddr() net.Addr {
	return addr(cn.prot, cn.local)
}

// RemoteAddr returns the endpoint the connection was dialed to, or nil if none is known.
func (cn *Conn) RemoteAddr() net.Addr {
	return addr(cn.prot, cn.remote)
}
  // SetDeadline sets both the read and the write deadline to t.
// A zero t clears them. It never fails and always returns nil.
func (cn *Conn) SetDeadline(t time.Time) error {
	cn.readDeadline, cn.writeDeadline = t, t
	return nil
}

// SetReadDeadline sets the deadline applied to subsequent Read calls.
// A zero t means no deadline. It always returns nil.
func (cn *Conn) SetReadDeadline(t time.Time) error {
	cn.readDeadline = t
	return nil
}

// SetWriteDeadline sets the deadline applied to subsequent Write calls.
// A zero t means no deadline. It always returns nil.
func (cn *Conn) SetWriteDeadline(t time.Time) error {
	cn.writeDeadline = t
	return nil
}
  243. // KeepAlive signals that the connection is still in use.
  244. // It may be called to prevent the socket being closed due to inactivity.
  245. func (cn *Conn) KeepAlive() error {
  246. req := &pb.GetSocketNameRequest{
  247. SocketDescriptor: &cn.desc,
  248. }
  249. res := &pb.GetSocketNameReply{}
  250. return internal.Call(cn.ctx, "remote_socket", "GetSocketName", req, res)
  251. }
  252. func init() {
  253. internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name)
  254. }