123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552 |
- // Package nl has low level primitives for making Netlink calls.
- package nl
- import (
- "bytes"
- "encoding/binary"
- "fmt"
- "net"
- "runtime"
- "sync"
- "sync/atomic"
- "syscall"
- "unsafe"
- "github.com/vishvananda/netns"
- )
// FAMILY_V4 is the address family value used for IPv4 (AF_INET).
const FAMILY_V4 = syscall.AF_INET

// FAMILY_V6 is the address family value used for IPv6 (AF_INET6).
const FAMILY_V6 = syscall.AF_INET6
var (
	// SupportedNlFamilies contains the list of netlink families this
	// netlink package supports.
	SupportedNlFamilies = []int{syscall.NETLINK_ROUTE, syscall.NETLINK_XFRM}

	// nextSeqNr is the package-wide counter from which NewNetlinkRequest
	// draws sequence numbers; it is only advanced atomically.
	nextSeqNr uint32
)
- // GetIPFamily returns the family type of a net.IP.
- func GetIPFamily(ip net.IP) int {
- if len(ip) <= net.IPv4len {
- return FAMILY_V4
- }
- if ip.To4() != nil {
- return FAMILY_V4
- }
- return FAMILY_V6
- }
var (
	// nativeEndian caches the detected host byte order; it is written
	// exactly once, under nativeEndianOnce.
	nativeEndian binary.ByteOrder
	// nativeEndianOnce guards the one-time detection so that concurrent
	// callers never race on nativeEndian.
	nativeEndianOnce sync.Once
)

// NativeEndian returns the native byte order of the system.
//
// The order is detected on first use by inspecting the first byte in memory
// of a known 32-bit value, and cached afterwards. Detection runs under a
// sync.Once, so the function is safe to call from multiple goroutines.
func NativeEndian() binary.ByteOrder {
	nativeEndianOnce.Do(func() {
		var x uint32 = 0x01020304
		// On a big-endian host the most significant byte is stored first.
		if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
			nativeEndian = binary.BigEndian
		} else {
			nativeEndian = binary.LittleEndian
		}
	})
	return nativeEndian
}
- // Byte swap a 16 bit value if we aren't big endian
- func Swap16(i uint16) uint16 {
- if NativeEndian() == binary.BigEndian {
- return i
- }
- return (i&0xff00)>>8 | (i&0xff)<<8
- }
- // Byte swap a 32 bit value if aren't big endian
- func Swap32(i uint32) uint32 {
- if NativeEndian() == binary.BigEndian {
- return i
- }
- return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
- }
// NetlinkRequestData is implemented by every element that can be attached to
// a netlink request or nested inside an RtAttr.
type NetlinkRequestData interface {
	// Serialize returns the element encoded as bytes.
	Serialize() []byte
	// Len returns the element's length in bytes.
	Len() int
}
// IfInfomsg wraps syscall.IfInfomsg. It is related to links, but it is used
// for list requests as well.
type IfInfomsg struct {
	syscall.IfInfomsg
}

// NewIfInfomsg creates an IfInfomsg with the given family; every other field
// is left at its zero value.
func NewIfInfomsg(family int) *IfInfomsg {
	msg := new(IfInfomsg)
	msg.Family = uint8(family)
	return msg
}

// DeserializeIfInfomsg reinterprets the first syscall.SizeofIfInfomsg bytes
// of b as an IfInfomsg. No copy is made: the result points into b's memory.
// It panics if b cannot be resliced to that length.
func DeserializeIfInfomsg(b []byte) *IfInfomsg {
	hdr := b[:syscall.SizeofIfInfomsg]
	return (*IfInfomsg)(unsafe.Pointer(&hdr[0]))
}

// Serialize returns the message as a byte slice that aliases msg's memory.
func (msg *IfInfomsg) Serialize() []byte {
	raw := (*[syscall.SizeofIfInfomsg]byte)(unsafe.Pointer(msg))
	return raw[:]
}

// Len returns the fixed size of the message, syscall.SizeofIfInfomsg.
func (msg *IfInfomsg) Len() int {
	return syscall.SizeofIfInfomsg
}
// rtaAlignOf rounds attrlen up to the next multiple of syscall.RTA_ALIGNTO.
func rtaAlignOf(attrlen int) int {
	const mask = syscall.RTA_ALIGNTO - 1
	return (attrlen + mask) &^ mask
}
- func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
- msg := NewIfInfomsg(family)
- parent.children = append(parent.children, msg)
- return msg
- }
- // Extend RtAttr to handle data and children
- type RtAttr struct {
- syscall.RtAttr
- Data []byte
- children []NetlinkRequestData
- }
- // Create a new Extended RtAttr object
- func NewRtAttr(attrType int, data []byte) *RtAttr {
- return &RtAttr{
- RtAttr: syscall.RtAttr{
- Type: uint16(attrType),
- },
- children: []NetlinkRequestData{},
- Data: data,
- }
- }
- // Create a new RtAttr obj anc add it as a child of an existing object
- func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
- attr := NewRtAttr(attrType, data)
- parent.children = append(parent.children, attr)
- return attr
- }
- func (a *RtAttr) Len() int {
- if len(a.children) == 0 {
- return (syscall.SizeofRtAttr + len(a.Data))
- }
- l := 0
- for _, child := range a.children {
- l += rtaAlignOf(child.Len())
- }
- l += syscall.SizeofRtAttr
- return rtaAlignOf(l + len(a.Data))
- }
- // Serialize the RtAttr into a byte array
- // This can't just unsafe.cast because it must iterate through children.
- func (a *RtAttr) Serialize() []byte {
- native := NativeEndian()
- length := a.Len()
- buf := make([]byte, rtaAlignOf(length))
- next := 4
- if a.Data != nil {
- copy(buf[next:], a.Data)
- next += rtaAlignOf(len(a.Data))
- }
- if len(a.children) > 0 {
- for _, child := range a.children {
- childBuf := child.Serialize()
- copy(buf[next:], childBuf)
- next += rtaAlignOf(len(childBuf))
- }
- }
- if l := uint16(length); l != 0 {
- native.PutUint16(buf[0:2], l)
- }
- native.PutUint16(buf[2:4], a.Type)
- return buf
- }
- type NetlinkRequest struct {
- syscall.NlMsghdr
- Data []NetlinkRequestData
- Sockets map[int]*SocketHandle
- }
- // Serialize the Netlink Request into a byte array
- func (req *NetlinkRequest) Serialize() []byte {
- length := syscall.SizeofNlMsghdr
- dataBytes := make([][]byte, len(req.Data))
- for i, data := range req.Data {
- dataBytes[i] = data.Serialize()
- length = length + len(dataBytes[i])
- }
- req.Len = uint32(length)
- b := make([]byte, length)
- hdr := (*(*[syscall.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
- next := syscall.SizeofNlMsghdr
- copy(b[0:next], hdr)
- for _, data := range dataBytes {
- for _, dataByte := range data {
- b[next] = dataByte
- next = next + 1
- }
- }
- return b
- }
- func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
- if data != nil {
- req.Data = append(req.Data, data)
- }
- }
- // Execute the request against a the given sockType.
- // Returns a list of netlink messages in serialized format, optionally filtered
- // by resType.
- func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
- var (
- s *NetlinkSocket
- err error
- )
- if req.Sockets != nil {
- if sh, ok := req.Sockets[sockType]; ok {
- s = sh.Socket
- req.Seq = atomic.AddUint32(&sh.Seq, 1)
- }
- }
- sharedSocket := s != nil
- if s == nil {
- s, err = getNetlinkSocket(sockType)
- if err != nil {
- return nil, err
- }
- defer s.Close()
- } else {
- s.Lock()
- defer s.Unlock()
- }
- if err := s.Send(req); err != nil {
- return nil, err
- }
- pid, err := s.GetPid()
- if err != nil {
- return nil, err
- }
- var res [][]byte
- done:
- for {
- msgs, err := s.Receive()
- if err != nil {
- return nil, err
- }
- for _, m := range msgs {
- if m.Header.Seq != req.Seq {
- if sharedSocket {
- continue
- }
- return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
- }
- if m.Header.Pid != pid {
- return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
- }
- if m.Header.Type == syscall.NLMSG_DONE {
- break done
- }
- if m.Header.Type == syscall.NLMSG_ERROR {
- native := NativeEndian()
- error := int32(native.Uint32(m.Data[0:4]))
- if error == 0 {
- break done
- }
- return nil, syscall.Errno(-error)
- }
- if resType != 0 && m.Header.Type != resType {
- continue
- }
- res = append(res, m.Data)
- if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
- break done
- }
- }
- }
- return res, nil
- }
- // Create a new netlink request from proto and flags
- // Note the Len value will be inaccurate once data is added until
- // the message is serialized
- func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
- return &NetlinkRequest{
- NlMsghdr: syscall.NlMsghdr{
- Len: uint32(syscall.SizeofNlMsghdr),
- Type: uint16(proto),
- Flags: syscall.NLM_F_REQUEST | uint16(flags),
- Seq: atomic.AddUint32(&nextSeqNr, 1),
- },
- }
- }
// NetlinkSocket is a netlink socket file descriptor together with the local
// address it was bound with. The embedded mutex lets callers serialize use
// of a shared socket.
type NetlinkSocket struct {
	fd  int
	lsa syscall.SockaddrNetlink
	sync.Mutex
}

// getNetlinkSocket opens a raw netlink socket for the given protocol and
// binds it. The descriptor is closed again if binding fails.
func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
	if err != nil {
		return nil, err
	}

	sock := &NetlinkSocket{
		fd:  fd,
		lsa: syscall.SockaddrNetlink{Family: syscall.AF_NETLINK},
	}
	if bindErr := syscall.Bind(fd, &sock.lsa); bindErr != nil {
		syscall.Close(fd)
		return nil, bindErr
	}
	return sock, nil
}
- // GetNetlinkSocketAt opens a netlink socket in the network namespace newNs
- // and positions the thread back into the network namespace specified by curNs,
- // when done. If curNs is close, the function derives the current namespace and
- // moves back into it when done. If newNs is close, the socket will be opened
- // in the current network namespace.
- func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
- c, err := executeInNetns(newNs, curNs)
- if err != nil {
- return nil, err
- }
- defer c()
- return getNetlinkSocket(protocol)
- }
// executeInNetns sets execution of the code following this call to the
// network namespace newNs, then moves the thread back to curNs if open,
// otherwise to the current netns at the time the function was invoked.
// In case of success, the caller is expected to execute the returned function
// at the end of the code that needs to be executed in the network namespace.
// If newNs is not open, nothing is switched and the returned function does
// nothing.
// Example:
// func jobAt(...) error {
//      d, err := executeInNetns(...)
//      if err != nil { return err}
//      defer d()
//      < code which needs to be executed in specific netns>
//  }
// TODO: this function probably belongs to netns pkg.
func executeInNetns(newNs, curNs netns.NsHandle) (func(), error) {
	var (
		err       error
		moveBack  func(netns.NsHandle) error
		closeNs   func() error
		unlockThd func()
	)
	// restore undoes only the steps that were actually taken: each hook
	// below stays nil until its corresponding step has succeeded.
	restore := func() {
		// order matters
		if moveBack != nil {
			moveBack(curNs)
		}
		if closeNs != nil {
			closeNs()
		}
		if unlockThd != nil {
			unlockThd()
		}
	}
	if newNs.IsOpen() {
		// Pin the goroutine to its OS thread for as long as the thread
		// sits in the other namespace.
		runtime.LockOSThread()
		unlockThd = runtime.UnlockOSThread
		if !curNs.IsOpen() {
			// No namespace to return to was supplied: fetch the current
			// one and remember to close the handle obtained here.
			if curNs, err = netns.Get(); err != nil {
				restore()
				return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
			}
			closeNs = curNs.Close
		}
		if err := netns.Set(newNs); err != nil {
			restore()
			return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
		}
		// Only arm the move-back once the switch has succeeded.
		moveBack = netns.Set
	}
	return restore, nil
}
- // Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
- // and subscribe it to multicast groups passed in variable argument list.
- // Returns the netlink socket on which Receive() method can be called
- // to retrieve the messages from the kernel.
- func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
- fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
- if err != nil {
- return nil, err
- }
- s := &NetlinkSocket{
- fd: fd,
- }
- s.lsa.Family = syscall.AF_NETLINK
- for _, g := range groups {
- s.lsa.Groups |= (1 << (g - 1))
- }
- if err := syscall.Bind(fd, &s.lsa); err != nil {
- syscall.Close(fd)
- return nil, err
- }
- return s, nil
- }
- // SubscribeAt works like Subscribe plus let's the caller choose the network
- // namespace in which the socket would be opened (newNs). Then control goes back
- // to curNs if open, otherwise to the netns at the time this function was called.
- func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*NetlinkSocket, error) {
- c, err := executeInNetns(newNs, curNs)
- if err != nil {
- return nil, err
- }
- defer c()
- return Subscribe(protocol, groups...)
- }
- func (s *NetlinkSocket) Close() {
- syscall.Close(s.fd)
- s.fd = -1
- }
- func (s *NetlinkSocket) GetFd() int {
- return s.fd
- }
- func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
- if s.fd < 0 {
- return fmt.Errorf("Send called on a closed socket")
- }
- if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil {
- return err
- }
- return nil
- }
- func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
- if s.fd < 0 {
- return nil, fmt.Errorf("Receive called on a closed socket")
- }
- rb := make([]byte, syscall.Getpagesize())
- nr, _, err := syscall.Recvfrom(s.fd, rb, 0)
- if err != nil {
- return nil, err
- }
- if nr < syscall.NLMSG_HDRLEN {
- return nil, fmt.Errorf("Got short response from netlink")
- }
- rb = rb[:nr]
- return syscall.ParseNetlinkMessage(rb)
- }
- func (s *NetlinkSocket) GetPid() (uint32, error) {
- lsa, err := syscall.Getsockname(s.fd)
- if err != nil {
- return 0, err
- }
- switch v := lsa.(type) {
- case *syscall.SockaddrNetlink:
- return v.Pid, nil
- }
- return 0, fmt.Errorf("Wrong socket type")
- }
// ZeroTerminated returns the bytes of s followed by a single NUL byte.
func ZeroTerminated(s string) []byte {
	// make zero-fills the buffer, so the extra trailing byte is already NUL.
	out := make([]byte, len(s)+1)
	copy(out, s)
	return out
}
// NonZeroTerminated returns a copy of the bytes of s with no terminator.
func NonZeroTerminated(s string) []byte {
	out := make([]byte, len(s))
	copy(out, s)
	return out
}
// BytesToString converts a NUL-terminated byte buffer to a string, dropping
// the first NUL byte and everything after it. A buffer that contains no NUL
// byte is converted in full.
func BytesToString(b []byte) string {
	if n := bytes.IndexByte(b, 0); n >= 0 {
		return string(b[:n])
	}
	// No terminator: slicing with the -1 index would panic, so use it all.
	return string(b)
}
// Uint8Attr returns v as a one-byte attribute payload.
func Uint8Attr(v uint8) []byte {
	out := make([]byte, 1)
	out[0] = v
	return out
}
- func Uint16Attr(v uint16) []byte {
- native := NativeEndian()
- bytes := make([]byte, 2)
- native.PutUint16(bytes, v)
- return bytes
- }
- func Uint32Attr(v uint32) []byte {
- native := NativeEndian()
- bytes := make([]byte, 4)
- native.PutUint32(bytes, v)
- return bytes
- }
- func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
- var attrs []syscall.NetlinkRouteAttr
- for len(b) >= syscall.SizeofRtAttr {
- a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
- if err != nil {
- return nil, err
- }
- ra := syscall.NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-syscall.SizeofRtAttr]}
- attrs = append(attrs, ra)
- b = b[alen:]
- }
- return attrs, nil
- }
- func netlinkRouteAttrAndValue(b []byte) (*syscall.RtAttr, []byte, int, error) {
- a := (*syscall.RtAttr)(unsafe.Pointer(&b[0]))
- if int(a.Len) < syscall.SizeofRtAttr || int(a.Len) > len(b) {
- return nil, nil, 0, syscall.EINVAL
- }
- return a, b[syscall.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
- }
- // SocketHandle contains the netlink socket and the associated
- // sequence counter for a specific netlink family
- type SocketHandle struct {
- Seq uint32
- Socket *NetlinkSocket
- }
- // Close closes the netlink socket
- func (sh *SocketHandle) Close() {
- if sh.Socket != nil {
- sh.Socket.Close()
- }
- }