nl_linux.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. // Package nl has low level primitives for making Netlink calls.
  2. package nl
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "fmt"
  7. "net"
  8. "runtime"
  9. "sync"
  10. "sync/atomic"
  11. "syscall"
  12. "unsafe"
  13. "github.com/vishvananda/netns"
  14. )
  // Family type definitions: address-family selectors used throughout the
  // package, expressed in terms of the kernel's AF_* constants.
  const (
  	// FAMILY_ALL is the unspecified address family (AF_UNSPEC).
  	FAMILY_ALL = syscall.AF_UNSPEC
  	// FAMILY_V4 is the IPv4 address family (AF_INET).
  	FAMILY_V4 = syscall.AF_INET
  	// FAMILY_V6 is the IPv6 address family (AF_INET6).
  	FAMILY_V6 = syscall.AF_INET6
  )
  // SupportedNlFamilies contains the list of netlink families this netlink
  // package supports.
  var SupportedNlFamilies = []int{
  	syscall.NETLINK_ROUTE,
  	syscall.NETLINK_XFRM,
  }

  // nextSeqNr is the package-wide sequence counter. NewNetlinkRequest
  // increments it atomically to stamp each new request.
  var nextSeqNr uint32
  24. // GetIPFamily returns the family type of a net.IP.
  25. func GetIPFamily(ip net.IP) int {
  26. if len(ip) <= net.IPv4len {
  27. return FAMILY_V4
  28. }
  29. if ip.To4() != nil {
  30. return FAMILY_V4
  31. }
  32. return FAMILY_V6
  33. }
  var (
  	// nativeEndian caches the byte order detected by NativeEndian.
  	nativeEndian binary.ByteOrder
  	// nativeEndianOnce makes the detection run exactly once, so concurrent
  	// first calls to NativeEndian do not race on nativeEndian.
  	nativeEndianOnce sync.Once
  )

  // NativeEndian returns the byte order of the host system.
  //
  // The order is detected on first use by inspecting the first byte in memory
  // of a known 32-bit value, then cached. It is safe to call from multiple
  // goroutines.
  func NativeEndian() binary.ByteOrder {
  	nativeEndianOnce.Do(func() {
  		var x uint32 = 0x01020304
  		// On a big-endian host the most significant byte (0x01) is stored first.
  		if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
  			nativeEndian = binary.BigEndian
  		} else {
  			nativeEndian = binary.LittleEndian
  		}
  	})
  	return nativeEndian
  }
  47. // Byte swap a 16 bit value if we aren't big endian
  48. func Swap16(i uint16) uint16 {
  49. if NativeEndian() == binary.BigEndian {
  50. return i
  51. }
  52. return (i&0xff00)>>8 | (i&0xff)<<8
  53. }
  54. // Byte swap a 32 bit value if aren't big endian
  55. func Swap32(i uint32) uint32 {
  56. if NativeEndian() == binary.BigEndian {
  57. return i
  58. }
  59. return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
  60. }
  // NetlinkRequestData is implemented by every payload that can be attached to
  // a NetlinkRequest or nested under an RtAttr.
  type NetlinkRequestData interface {
  	// Serialize returns the payload's byte representation.
  	Serialize() []byte
  	// Len returns the payload's length in bytes.
  	Len() int
  }
  // IfInfomsg is related to links, but it is used for list requests as well.
  // It embeds syscall.IfInfomsg so the package can attach its own
  // Serialize and Len methods.
  type IfInfomsg struct {
  	syscall.IfInfomsg
  }
  69. // Create an IfInfomsg with family specified
  70. func NewIfInfomsg(family int) *IfInfomsg {
  71. return &IfInfomsg{
  72. IfInfomsg: syscall.IfInfomsg{
  73. Family: uint8(family),
  74. },
  75. }
  76. }
  77. func DeserializeIfInfomsg(b []byte) *IfInfomsg {
  78. return (*IfInfomsg)(unsafe.Pointer(&b[0:syscall.SizeofIfInfomsg][0]))
  79. }
  80. func (msg *IfInfomsg) Serialize() []byte {
  81. return (*(*[syscall.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:]
  82. }
  83. func (msg *IfInfomsg) Len() int {
  84. return syscall.SizeofIfInfomsg
  85. }
  // rtaAlignOf rounds attrlen up to the next multiple of syscall.RTA_ALIGNTO.
  func rtaAlignOf(attrlen int) int {
  	const mask = syscall.RTA_ALIGNTO - 1
  	return (attrlen + mask) &^ mask
  }
  89. func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
  90. msg := NewIfInfomsg(family)
  91. parent.children = append(parent.children, msg)
  92. return msg
  93. }
  94. // Extend RtAttr to handle data and children
  95. type RtAttr struct {
  96. syscall.RtAttr
  97. Data []byte
  98. children []NetlinkRequestData
  99. }
  100. // Create a new Extended RtAttr object
  101. func NewRtAttr(attrType int, data []byte) *RtAttr {
  102. return &RtAttr{
  103. RtAttr: syscall.RtAttr{
  104. Type: uint16(attrType),
  105. },
  106. children: []NetlinkRequestData{},
  107. Data: data,
  108. }
  109. }
  110. // Create a new RtAttr obj anc add it as a child of an existing object
  111. func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
  112. attr := NewRtAttr(attrType, data)
  113. parent.children = append(parent.children, attr)
  114. return attr
  115. }
  116. func (a *RtAttr) Len() int {
  117. if len(a.children) == 0 {
  118. return (syscall.SizeofRtAttr + len(a.Data))
  119. }
  120. l := 0
  121. for _, child := range a.children {
  122. l += rtaAlignOf(child.Len())
  123. }
  124. l += syscall.SizeofRtAttr
  125. return rtaAlignOf(l + len(a.Data))
  126. }
  127. // Serialize the RtAttr into a byte array
  128. // This can't just unsafe.cast because it must iterate through children.
  129. func (a *RtAttr) Serialize() []byte {
  130. native := NativeEndian()
  131. length := a.Len()
  132. buf := make([]byte, rtaAlignOf(length))
  133. next := 4
  134. if a.Data != nil {
  135. copy(buf[next:], a.Data)
  136. next += rtaAlignOf(len(a.Data))
  137. }
  138. if len(a.children) > 0 {
  139. for _, child := range a.children {
  140. childBuf := child.Serialize()
  141. copy(buf[next:], childBuf)
  142. next += rtaAlignOf(len(childBuf))
  143. }
  144. }
  145. if l := uint16(length); l != 0 {
  146. native.PutUint16(buf[0:2], l)
  147. }
  148. native.PutUint16(buf[2:4], a.Type)
  149. return buf
  150. }
  151. type NetlinkRequest struct {
  152. syscall.NlMsghdr
  153. Data []NetlinkRequestData
  154. Sockets map[int]*SocketHandle
  155. }
  156. // Serialize the Netlink Request into a byte array
  157. func (req *NetlinkRequest) Serialize() []byte {
  158. length := syscall.SizeofNlMsghdr
  159. dataBytes := make([][]byte, len(req.Data))
  160. for i, data := range req.Data {
  161. dataBytes[i] = data.Serialize()
  162. length = length + len(dataBytes[i])
  163. }
  164. req.Len = uint32(length)
  165. b := make([]byte, length)
  166. hdr := (*(*[syscall.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
  167. next := syscall.SizeofNlMsghdr
  168. copy(b[0:next], hdr)
  169. for _, data := range dataBytes {
  170. for _, dataByte := range data {
  171. b[next] = dataByte
  172. next = next + 1
  173. }
  174. }
  175. return b
  176. }
  177. func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
  178. if data != nil {
  179. req.Data = append(req.Data, data)
  180. }
  181. }
  182. // Execute the request against a the given sockType.
  183. // Returns a list of netlink messages in serialized format, optionally filtered
  184. // by resType.
  185. func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
  186. var (
  187. s *NetlinkSocket
  188. err error
  189. )
  190. if req.Sockets != nil {
  191. if sh, ok := req.Sockets[sockType]; ok {
  192. s = sh.Socket
  193. req.Seq = atomic.AddUint32(&sh.Seq, 1)
  194. }
  195. }
  196. sharedSocket := s != nil
  197. if s == nil {
  198. s, err = getNetlinkSocket(sockType)
  199. if err != nil {
  200. return nil, err
  201. }
  202. defer s.Close()
  203. } else {
  204. s.Lock()
  205. defer s.Unlock()
  206. }
  207. if err := s.Send(req); err != nil {
  208. return nil, err
  209. }
  210. pid, err := s.GetPid()
  211. if err != nil {
  212. return nil, err
  213. }
  214. var res [][]byte
  215. done:
  216. for {
  217. msgs, err := s.Receive()
  218. if err != nil {
  219. return nil, err
  220. }
  221. for _, m := range msgs {
  222. if m.Header.Seq != req.Seq {
  223. if sharedSocket {
  224. continue
  225. }
  226. return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
  227. }
  228. if m.Header.Pid != pid {
  229. return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
  230. }
  231. if m.Header.Type == syscall.NLMSG_DONE {
  232. break done
  233. }
  234. if m.Header.Type == syscall.NLMSG_ERROR {
  235. native := NativeEndian()
  236. error := int32(native.Uint32(m.Data[0:4]))
  237. if error == 0 {
  238. break done
  239. }
  240. return nil, syscall.Errno(-error)
  241. }
  242. if resType != 0 && m.Header.Type != resType {
  243. continue
  244. }
  245. res = append(res, m.Data)
  246. if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
  247. break done
  248. }
  249. }
  250. }
  251. return res, nil
  252. }
  253. // Create a new netlink request from proto and flags
  254. // Note the Len value will be inaccurate once data is added until
  255. // the message is serialized
  256. func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
  257. return &NetlinkRequest{
  258. NlMsghdr: syscall.NlMsghdr{
  259. Len: uint32(syscall.SizeofNlMsghdr),
  260. Type: uint16(proto),
  261. Flags: syscall.NLM_F_REQUEST | uint16(flags),
  262. Seq: atomic.AddUint32(&nextSeqNr, 1),
  263. },
  264. }
  265. }
  // NetlinkSocket wraps a netlink socket file descriptor and the local
  // address it was bound with. The embedded mutex is taken by Execute when the
  // socket is shared between requests.
  type NetlinkSocket struct {
  	// fd is the socket descriptor; Close sets it to -1.
  	fd int
  	// lsa is the address passed to Bind and used as the Sendto destination.
  	lsa syscall.SockaddrNetlink
  	sync.Mutex
  }
  271. func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
  272. fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
  273. if err != nil {
  274. return nil, err
  275. }
  276. s := &NetlinkSocket{
  277. fd: fd,
  278. }
  279. s.lsa.Family = syscall.AF_NETLINK
  280. if err := syscall.Bind(fd, &s.lsa); err != nil {
  281. syscall.Close(fd)
  282. return nil, err
  283. }
  284. return s, nil
  285. }
  286. // GetNetlinkSocketAt opens a netlink socket in the network namespace newNs
  287. // and positions the thread back into the network namespace specified by curNs,
  288. // when done. If curNs is close, the function derives the current namespace and
  289. // moves back into it when done. If newNs is close, the socket will be opened
  290. // in the current network namespace.
  291. func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
  292. c, err := executeInNetns(newNs, curNs)
  293. if err != nil {
  294. return nil, err
  295. }
  296. defer c()
  297. return getNetlinkSocket(protocol)
  298. }
  299. // executeInNetns sets execution of the code following this call to the
  300. // network namespace newNs, then moves the thread back to curNs if open,
  301. // otherwise to the current netns at the time the function was invoked
  302. // In case of success, the caller is expected to execute the returned function
  303. // at the end of the code that needs to be executed in the network namespace.
  304. // Example:
  305. // func jobAt(...) error {
  306. // d, err := executeInNetns(...)
  307. // if err != nil { return err}
  308. // defer d()
  309. // < code which needs to be executed in specific netns>
  310. // }
  311. // TODO: his function probably belongs to netns pkg.
  312. func executeInNetns(newNs, curNs netns.NsHandle) (func(), error) {
  313. var (
  314. err error
  315. moveBack func(netns.NsHandle) error
  316. closeNs func() error
  317. unlockThd func()
  318. )
  319. restore := func() {
  320. // order matters
  321. if moveBack != nil {
  322. moveBack(curNs)
  323. }
  324. if closeNs != nil {
  325. closeNs()
  326. }
  327. if unlockThd != nil {
  328. unlockThd()
  329. }
  330. }
  331. if newNs.IsOpen() {
  332. runtime.LockOSThread()
  333. unlockThd = runtime.UnlockOSThread
  334. if !curNs.IsOpen() {
  335. if curNs, err = netns.Get(); err != nil {
  336. restore()
  337. return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
  338. }
  339. closeNs = curNs.Close
  340. }
  341. if err := netns.Set(newNs); err != nil {
  342. restore()
  343. return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
  344. }
  345. moveBack = netns.Set
  346. }
  347. return restore, nil
  348. }
  349. // Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
  350. // and subscribe it to multicast groups passed in variable argument list.
  351. // Returns the netlink socket on which Receive() method can be called
  352. // to retrieve the messages from the kernel.
  353. func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
  354. fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
  355. if err != nil {
  356. return nil, err
  357. }
  358. s := &NetlinkSocket{
  359. fd: fd,
  360. }
  361. s.lsa.Family = syscall.AF_NETLINK
  362. for _, g := range groups {
  363. s.lsa.Groups |= (1 << (g - 1))
  364. }
  365. if err := syscall.Bind(fd, &s.lsa); err != nil {
  366. syscall.Close(fd)
  367. return nil, err
  368. }
  369. return s, nil
  370. }
  371. // SubscribeAt works like Subscribe plus let's the caller choose the network
  372. // namespace in which the socket would be opened (newNs). Then control goes back
  373. // to curNs if open, otherwise to the netns at the time this function was called.
  374. func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*NetlinkSocket, error) {
  375. c, err := executeInNetns(newNs, curNs)
  376. if err != nil {
  377. return nil, err
  378. }
  379. defer c()
  380. return Subscribe(protocol, groups...)
  381. }
  382. func (s *NetlinkSocket) Close() {
  383. syscall.Close(s.fd)
  384. s.fd = -1
  385. }
  386. func (s *NetlinkSocket) GetFd() int {
  387. return s.fd
  388. }
  389. func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
  390. if s.fd < 0 {
  391. return fmt.Errorf("Send called on a closed socket")
  392. }
  393. if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil {
  394. return err
  395. }
  396. return nil
  397. }
  398. func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
  399. if s.fd < 0 {
  400. return nil, fmt.Errorf("Receive called on a closed socket")
  401. }
  402. rb := make([]byte, syscall.Getpagesize())
  403. nr, _, err := syscall.Recvfrom(s.fd, rb, 0)
  404. if err != nil {
  405. return nil, err
  406. }
  407. if nr < syscall.NLMSG_HDRLEN {
  408. return nil, fmt.Errorf("Got short response from netlink")
  409. }
  410. rb = rb[:nr]
  411. return syscall.ParseNetlinkMessage(rb)
  412. }
  413. func (s *NetlinkSocket) GetPid() (uint32, error) {
  414. lsa, err := syscall.Getsockname(s.fd)
  415. if err != nil {
  416. return 0, err
  417. }
  418. switch v := lsa.(type) {
  419. case *syscall.SockaddrNetlink:
  420. return v.Pid, nil
  421. }
  422. return 0, fmt.Errorf("Wrong socket type")
  423. }
  // ZeroTerminated returns the bytes of s followed by a single NUL byte, in a
  // newly allocated slice of length len(s)+1.
  func ZeroTerminated(s string) []byte {
  	// make zero-fills the slice, so the final byte is already the terminator.
  	out := make([]byte, len(s)+1)
  	copy(out, s)
  	return out
  }
  // NonZeroTerminated returns the bytes of s, without a trailing NUL byte, in
  // a newly allocated slice of length len(s).
  func NonZeroTerminated(s string) []byte {
  	out := make([]byte, len(s))
  	copy(out, s)
  	return out
  }
  // BytesToString converts a NUL-terminated byte buffer to a string,
  // returning everything before the first NUL byte. If b contains no NUL byte
  // the whole slice is converted.
  func BytesToString(b []byte) string {
  	if n := bytes.IndexByte(b, 0); n >= 0 {
  		return string(b[:n])
  	}
  	// No terminator: slicing with the -1 index would panic, so take all of b.
  	return string(b)
  }
  // Uint8Attr returns v as a one-byte attribute payload.
  func Uint8Attr(v uint8) []byte {
  	out := make([]byte, 1)
  	out[0] = v
  	return out
  }
  446. func Uint16Attr(v uint16) []byte {
  447. native := NativeEndian()
  448. bytes := make([]byte, 2)
  449. native.PutUint16(bytes, v)
  450. return bytes
  451. }
  452. func Uint32Attr(v uint32) []byte {
  453. native := NativeEndian()
  454. bytes := make([]byte, 4)
  455. native.PutUint32(bytes, v)
  456. return bytes
  457. }
  458. func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
  459. var attrs []syscall.NetlinkRouteAttr
  460. for len(b) >= syscall.SizeofRtAttr {
  461. a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
  462. if err != nil {
  463. return nil, err
  464. }
  465. ra := syscall.NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-syscall.SizeofRtAttr]}
  466. attrs = append(attrs, ra)
  467. b = b[alen:]
  468. }
  469. return attrs, nil
  470. }
  471. func netlinkRouteAttrAndValue(b []byte) (*syscall.RtAttr, []byte, int, error) {
  472. a := (*syscall.RtAttr)(unsafe.Pointer(&b[0]))
  473. if int(a.Len) < syscall.SizeofRtAttr || int(a.Len) > len(b) {
  474. return nil, nil, 0, syscall.EINVAL
  475. }
  476. return a, b[syscall.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
  477. }
  478. // SocketHandle contains the netlink socket and the associated
  479. // sequence counter for a specific netlink family
  480. type SocketHandle struct {
  481. Seq uint32
  482. Socket *NetlinkSocket
  483. }
  484. // Close closes the netlink socket
  485. func (sh *SocketHandle) Close() {
  486. if sh.Socket != nil {
  487. sh.Socket.Close()
  488. }
  489. }