123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- // Copyright 2012 Google Inc. All rights reserved.
- // Use of this source code is governed by the Apache 2.0
- // license that can be found in the LICENSE file.
- // +build appengine
- package socket
- import (
- "fmt"
- "io"
- "net"
- "strconv"
- "time"
- "github.com/golang/protobuf/proto"
- "golang.org/x/net/context"
- "google.golang.org/appengine/internal"
- pb "google.golang.org/appengine/internal/socket"
- )
- // Dial connects to the address addr on the network protocol.
- // The address format is host:port, where host may be a hostname or an IP address.
- // Known protocols are "tcp" and "udp".
- // The returned connection satisfies net.Conn, and is valid while ctx is valid;
- // if the connection is to be used after ctx becomes invalid, invoke SetContext
- // with the new context.
- func Dial(ctx context.Context, protocol, addr string) (*Conn, error) {
- return DialTimeout(ctx, protocol, addr, 0)
- }
- var ipFamilies = []pb.CreateSocketRequest_SocketFamily{
- pb.CreateSocketRequest_IPv4,
- pb.CreateSocketRequest_IPv6,
- }
- // DialTimeout is like Dial but takes a timeout.
- // The timeout includes name resolution, if required.
- func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) {
- dialCtx := ctx // Used for dialing and name resolution, but not stored in the *Conn.
- if timeout > 0 {
- var cancel context.CancelFunc
- dialCtx, cancel = context.WithTimeout(ctx, timeout)
- defer cancel()
- }
- host, portStr, err := net.SplitHostPort(addr)
- if err != nil {
- return nil, err
- }
- port, err := strconv.Atoi(portStr)
- if err != nil {
- return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err)
- }
- var prot pb.CreateSocketRequest_SocketProtocol
- switch protocol {
- case "tcp":
- prot = pb.CreateSocketRequest_TCP
- case "udp":
- prot = pb.CreateSocketRequest_UDP
- default:
- return nil, fmt.Errorf("socket: unknown protocol %q", protocol)
- }
- packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host)
- if err != nil {
- return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
- }
- if len(packedAddrs) == 0 {
- return nil, fmt.Errorf("no addresses for %q", host)
- }
- packedAddr := packedAddrs[0] // use first address
- fam := pb.CreateSocketRequest_IPv4
- if len(packedAddr) == net.IPv6len {
- fam = pb.CreateSocketRequest_IPv6
- }
- req := &pb.CreateSocketRequest{
- Family: fam.Enum(),
- Protocol: prot.Enum(),
- RemoteIp: &pb.AddressPort{
- Port: proto.Int32(int32(port)),
- PackedAddress: packedAddr,
- },
- }
- if resolved {
- req.RemoteIp.HostnameHint = &host
- }
- res := &pb.CreateSocketReply{}
- if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil {
- return nil, err
- }
- return &Conn{
- ctx: ctx,
- desc: res.GetSocketDescriptor(),
- prot: prot,
- local: res.ProxyExternalIp,
- remote: req.RemoteIp,
- }, nil
- }
- // LookupIP returns the given host's IP addresses.
- func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) {
- packedAddrs, _, err := resolve(ctx, ipFamilies, host)
- if err != nil {
- return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
- }
- addrs = make([]net.IP, len(packedAddrs))
- for i, pa := range packedAddrs {
- addrs[i] = net.IP(pa)
- }
- return addrs, nil
- }
- func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) {
- // Check if it's an IP address.
- if ip := net.ParseIP(host); ip != nil {
- if ip := ip.To4(); ip != nil {
- return [][]byte{ip}, false, nil
- }
- return [][]byte{ip}, false, nil
- }
- req := &pb.ResolveRequest{
- Name: &host,
- AddressFamilies: fams,
- }
- res := &pb.ResolveReply{}
- if err := internal.Call(ctx, "remote_socket", "Resolve", req, res); err != nil {
- // XXX: need to map to pb.ResolveReply_ErrorCode?
- return nil, false, err
- }
- return res.PackedAddress, true, nil
- }
// withDeadline is like context.WithDeadline, except it ignores the zero deadline:
// for a zero deadline it returns parent itself together with a no-op cancel
// function, so callers can always defer the returned cancel.
func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
	if !deadline.IsZero() {
		return context.WithDeadline(parent, deadline)
	}
	noop := func() {}
	return parent, noop
}
- // Conn represents a socket connection.
- // It implements net.Conn.
- type Conn struct {
- ctx context.Context
- desc string
- offset int64
- prot pb.CreateSocketRequest_SocketProtocol
- local, remote *pb.AddressPort
- readDeadline, writeDeadline time.Time // optional
- }
- // SetContext sets the context that is used by this Conn.
- // It is usually used only when using a Conn that was created in a different context,
- // such as when a connection is created during a warmup request but used while
- // servicing a user request.
- func (cn *Conn) SetContext(ctx context.Context) {
- cn.ctx = ctx
- }
- func (cn *Conn) Read(b []byte) (n int, err error) {
- const maxRead = 1 << 20
- if len(b) > maxRead {
- b = b[:maxRead]
- }
- req := &pb.ReceiveRequest{
- SocketDescriptor: &cn.desc,
- DataSize: proto.Int32(int32(len(b))),
- }
- res := &pb.ReceiveReply{}
- if !cn.readDeadline.IsZero() {
- req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds())
- }
- ctx, cancel := withDeadline(cn.ctx, cn.readDeadline)
- defer cancel()
- if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil {
- return 0, err
- }
- if len(res.Data) == 0 {
- return 0, io.EOF
- }
- if len(res.Data) > len(b) {
- return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b))
- }
- return copy(b, res.Data), nil
- }
- func (cn *Conn) Write(b []byte) (n int, err error) {
- const lim = 1 << 20 // max per chunk
- for n < len(b) {
- chunk := b[n:]
- if len(chunk) > lim {
- chunk = chunk[:lim]
- }
- req := &pb.SendRequest{
- SocketDescriptor: &cn.desc,
- Data: chunk,
- StreamOffset: &cn.offset,
- }
- res := &pb.SendReply{}
- if !cn.writeDeadline.IsZero() {
- req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds())
- }
- ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline)
- defer cancel()
- if err = internal.Call(ctx, "remote_socket", "Send", req, res); err != nil {
- // assume zero bytes were sent in this RPC
- break
- }
- n += int(res.GetDataSent())
- cn.offset += int64(res.GetDataSent())
- }
- return
- }
- func (cn *Conn) Close() error {
- req := &pb.CloseRequest{
- SocketDescriptor: &cn.desc,
- }
- res := &pb.CloseReply{}
- if err := internal.Call(cn.ctx, "remote_socket", "Close", req, res); err != nil {
- return err
- }
- cn.desc = "CLOSED"
- return nil
- }
- func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr {
- if ap == nil {
- return nil
- }
- switch prot {
- case pb.CreateSocketRequest_TCP:
- return &net.TCPAddr{
- IP: net.IP(ap.PackedAddress),
- Port: int(*ap.Port),
- }
- case pb.CreateSocketRequest_UDP:
- return &net.UDPAddr{
- IP: net.IP(ap.PackedAddress),
- Port: int(*ap.Port),
- }
- }
- panic("unknown protocol " + prot.String())
- }
- func (cn *Conn) LocalAddr() net.Addr { return addr(cn.prot, cn.local) }
- func (cn *Conn) RemoteAddr() net.Addr { return addr(cn.prot, cn.remote) }
- func (cn *Conn) SetDeadline(t time.Time) error {
- cn.readDeadline = t
- cn.writeDeadline = t
- return nil
- }
- func (cn *Conn) SetReadDeadline(t time.Time) error {
- cn.readDeadline = t
- return nil
- }
- func (cn *Conn) SetWriteDeadline(t time.Time) error {
- cn.writeDeadline = t
- return nil
- }
- // KeepAlive signals that the connection is still in use.
- // It may be called to prevent the socket being closed due to inactivity.
- func (cn *Conn) KeepAlive() error {
- req := &pb.GetSocketNameRequest{
- SocketDescriptor: &cn.desc,
- }
- res := &pb.GetSocketNameReply{}
- return internal.Call(cn.ctx, "remote_socket", "GetSocketName", req, res)
- }
- func init() {
- internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name)
- }
|