|
@@ -1,7 +1,6 @@
|
|
|
package netlink
|
|
|
|
|
|
import (
|
|
|
- "sync/atomic"
|
|
|
"syscall"
|
|
|
|
|
|
"github.com/vishvananda/netlink/nl"
|
|
@@ -11,27 +10,34 @@ import (
|
|
|
|
|
|
var pkgHandle = &Handle{}
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
type Handle struct {
|
|
|
- seq uint32
|
|
|
- routeSocket *nl.NetlinkSocket
|
|
|
- xfrmSocket *nl.NetlinkSocket
|
|
|
+ sockets map[int]*nl.SocketHandle
|
|
|
lookupByDump bool
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+func (h *Handle) SupportsNetlinkFamily(nlFamily int) bool {
|
|
|
+ _, ok := h.sockets[nlFamily]
|
|
|
+ return ok
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
-func NewHandle() (*Handle, error) {
|
|
|
- return newHandle(netns.None(), netns.None())
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func NewHandle(nlFamilies ...int) (*Handle, error) {
|
|
|
+ return newHandle(netns.None(), netns.None(), nlFamilies...)
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-func NewHandleAt(ns netns.NsHandle) (*Handle, error) {
|
|
|
- return newHandle(ns, netns.None())
|
|
|
+func NewHandleAt(ns netns.NsHandle, nlFamilies ...int) (*Handle, error) {
|
|
|
+ return newHandle(ns, netns.None(), nlFamilies...)
|
|
|
}
|
|
|
|
|
|
|
|
@@ -40,37 +46,33 @@ func NewHandleAtFrom(newNs, curNs netns.NsHandle) (*Handle, error) {
|
|
|
return newHandle(newNs, curNs)
|
|
|
}
|
|
|
|
|
|
-func newHandle(newNs, curNs netns.NsHandle) (*Handle, error) {
|
|
|
- var (
|
|
|
- err error
|
|
|
- rSocket *nl.NetlinkSocket
|
|
|
- xSocket *nl.NetlinkSocket
|
|
|
- )
|
|
|
- rSocket, err = nl.GetNetlinkSocketAt(newNs, curNs, syscall.NETLINK_ROUTE)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+func newHandle(newNs, curNs netns.NsHandle, nlFamilies ...int) (*Handle, error) {
|
|
|
+ h := &Handle{sockets: map[int]*nl.SocketHandle{}}
|
|
|
+ fams := nl.SupportedNlFamilies
|
|
|
+ if len(nlFamilies) != 0 {
|
|
|
+ fams = nlFamilies
|
|
|
}
|
|
|
- xSocket, err = nl.GetNetlinkSocketAt(newNs, curNs, syscall.NETLINK_XFRM)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+ for _, f := range fams {
|
|
|
+ s, err := nl.GetNetlinkSocketAt(newNs, curNs, f)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ h.sockets[f] = &nl.SocketHandle{Socket: s}
|
|
|
}
|
|
|
- return &Handle{routeSocket: rSocket, xfrmSocket: xSocket}, nil
|
|
|
+ return h, nil
|
|
|
}
|
|
|
|
|
|
|
|
|
func (h *Handle) Delete() {
|
|
|
- if h.routeSocket != nil {
|
|
|
- h.routeSocket.Close()
|
|
|
- }
|
|
|
- if h.xfrmSocket != nil {
|
|
|
- h.xfrmSocket.Close()
|
|
|
+ for _, sh := range h.sockets {
|
|
|
+ sh.Close()
|
|
|
}
|
|
|
- h.routeSocket, h.xfrmSocket = nil, nil
|
|
|
+ h.sockets = nil
|
|
|
}
|
|
|
|
|
|
func (h *Handle) newNetlinkRequest(proto, flags int) *nl.NetlinkRequest {
|
|
|
|
|
|
- if h.routeSocket == nil {
|
|
|
+ if h.sockets == nil {
|
|
|
return nl.NewNetlinkRequest(proto, flags)
|
|
|
}
|
|
|
return &nl.NetlinkRequest{
|
|
@@ -78,9 +80,7 @@ func (h *Handle) newNetlinkRequest(proto, flags int) *nl.NetlinkRequest {
|
|
|
Len: uint32(syscall.SizeofNlMsghdr),
|
|
|
Type: uint16(proto),
|
|
|
Flags: syscall.NLM_F_REQUEST | uint16(flags),
|
|
|
- Seq: atomic.AddUint32(&h.seq, 1),
|
|
|
},
|
|
|
- RouteSocket: h.routeSocket,
|
|
|
- XfmrSocket: h.xfrmSocket,
|
|
|
+ Sockets: h.sockets,
|
|
|
}
|
|
|
}
|