xfrm_state_linux.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package netlink
  2. import (
  3. "fmt"
  4. "syscall"
  5. "unsafe"
  6. "github.com/vishvananda/netlink/nl"
  7. )
  8. func writeStateAlgo(a *XfrmStateAlgo) []byte {
  9. algo := nl.XfrmAlgo{
  10. AlgKeyLen: uint32(len(a.Key) * 8),
  11. AlgKey: a.Key,
  12. }
  13. end := len(a.Name)
  14. if end > 64 {
  15. end = 64
  16. }
  17. copy(algo.AlgName[:end], a.Name)
  18. return algo.Serialize()
  19. }
  20. func writeStateAlgoAuth(a *XfrmStateAlgo) []byte {
  21. algo := nl.XfrmAlgoAuth{
  22. AlgKeyLen: uint32(len(a.Key) * 8),
  23. AlgTruncLen: uint32(a.TruncateLen),
  24. AlgKey: a.Key,
  25. }
  26. end := len(a.Name)
  27. if end > 64 {
  28. end = 64
  29. }
  30. copy(algo.AlgName[:end], a.Name)
  31. return algo.Serialize()
  32. }
  33. func writeMark(m *XfrmMark) []byte {
  34. mark := &nl.XfrmMark{
  35. Value: m.Value,
  36. Mask: m.Mask,
  37. }
  38. if mark.Mask == 0 {
  39. mark.Mask = ^uint32(0)
  40. }
  41. return mark.Serialize()
  42. }
  43. // XfrmStateAdd will add an xfrm state to the system.
  44. // Equivalent to: `ip xfrm state add $state`
  45. func XfrmStateAdd(state *XfrmState) error {
  46. return pkgHandle.XfrmStateAdd(state)
  47. }
  48. // XfrmStateAdd will add an xfrm state to the system.
  49. // Equivalent to: `ip xfrm state add $state`
  50. func (h *Handle) XfrmStateAdd(state *XfrmState) error {
  51. return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_NEWSA)
  52. }
  53. // XfrmStateUpdate will update an xfrm state to the system.
  54. // Equivalent to: `ip xfrm state update $state`
  55. func XfrmStateUpdate(state *XfrmState) error {
  56. return pkgHandle.XfrmStateUpdate(state)
  57. }
  58. // XfrmStateUpdate will update an xfrm state to the system.
  59. // Equivalent to: `ip xfrm state update $state`
  60. func (h *Handle) XfrmStateUpdate(state *XfrmState) error {
  61. return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_UPDSA)
  62. }
  63. func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error {
  64. // A state with spi 0 can't be deleted so don't allow it to be set
  65. if state.Spi == 0 {
  66. return fmt.Errorf("Spi must be set when adding xfrm state.")
  67. }
  68. req := h.newNetlinkRequest(nlProto, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
  69. msg := &nl.XfrmUsersaInfo{}
  70. msg.Family = uint16(nl.GetIPFamily(state.Dst))
  71. msg.Id.Daddr.FromIP(state.Dst)
  72. msg.Saddr.FromIP(state.Src)
  73. msg.Id.Proto = uint8(state.Proto)
  74. msg.Mode = uint8(state.Mode)
  75. msg.Id.Spi = nl.Swap32(uint32(state.Spi))
  76. msg.Reqid = uint32(state.Reqid)
  77. msg.ReplayWindow = uint8(state.ReplayWindow)
  78. limitsToLft(state.Limits, &msg.Lft)
  79. req.AddData(msg)
  80. if state.Auth != nil {
  81. out := nl.NewRtAttr(nl.XFRMA_ALG_AUTH_TRUNC, writeStateAlgoAuth(state.Auth))
  82. req.AddData(out)
  83. }
  84. if state.Crypt != nil {
  85. out := nl.NewRtAttr(nl.XFRMA_ALG_CRYPT, writeStateAlgo(state.Crypt))
  86. req.AddData(out)
  87. }
  88. if state.Encap != nil {
  89. encapData := make([]byte, nl.SizeofXfrmEncapTmpl)
  90. encap := nl.DeserializeXfrmEncapTmpl(encapData)
  91. encap.EncapType = uint16(state.Encap.Type)
  92. encap.EncapSport = nl.Swap16(uint16(state.Encap.SrcPort))
  93. encap.EncapDport = nl.Swap16(uint16(state.Encap.DstPort))
  94. encap.EncapOa.FromIP(state.Encap.OriginalAddress)
  95. out := nl.NewRtAttr(nl.XFRMA_ENCAP, encapData)
  96. req.AddData(out)
  97. }
  98. if state.Mark != nil {
  99. out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
  100. req.AddData(out)
  101. }
  102. _, err := req.Execute(syscall.NETLINK_XFRM, 0)
  103. return err
  104. }
  105. // XfrmStateDel will delete an xfrm state from the system. Note that
  106. // the Algos are ignored when matching the state to delete.
  107. // Equivalent to: `ip xfrm state del $state`
  108. func XfrmStateDel(state *XfrmState) error {
  109. return pkgHandle.XfrmStateDel(state)
  110. }
  111. // XfrmStateDel will delete an xfrm state from the system. Note that
  112. // the Algos are ignored when matching the state to delete.
  113. // Equivalent to: `ip xfrm state del $state`
  114. func (h *Handle) XfrmStateDel(state *XfrmState) error {
  115. _, err := h.xfrmStateGetOrDelete(state, nl.XFRM_MSG_DELSA)
  116. return err
  117. }
  118. // XfrmStateList gets a list of xfrm states in the system.
  119. // Equivalent to: `ip [-4|-6] xfrm state show`.
  120. // The list can be filtered by ip family.
  121. func XfrmStateList(family int) ([]XfrmState, error) {
  122. return pkgHandle.XfrmStateList(family)
  123. }
  124. // XfrmStateList gets a list of xfrm states in the system.
  125. // Equivalent to: `ip xfrm state show`.
  126. // The list can be filtered by ip family.
  127. func (h *Handle) XfrmStateList(family int) ([]XfrmState, error) {
  128. req := h.newNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
  129. msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWSA)
  130. if err != nil {
  131. return nil, err
  132. }
  133. var res []XfrmState
  134. for _, m := range msgs {
  135. if state, err := parseXfrmState(m, family); err == nil {
  136. res = append(res, *state)
  137. } else if err == familyError {
  138. continue
  139. } else {
  140. return nil, err
  141. }
  142. }
  143. return res, nil
  144. }
  145. // XfrmStateGet gets the xfrm state described by the ID, if found.
  146. // Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
  147. // Only the fields which constitue the SA ID must be filled in:
  148. // ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
  149. // mark is optional
  150. func XfrmStateGet(state *XfrmState) (*XfrmState, error) {
  151. return pkgHandle.XfrmStateGet(state)
  152. }
  153. // XfrmStateGet gets the xfrm state described by the ID, if found.
  154. // Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
  155. // Only the fields which constitue the SA ID must be filled in:
  156. // ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
  157. // mark is optional
  158. func (h *Handle) XfrmStateGet(state *XfrmState) (*XfrmState, error) {
  159. return h.xfrmStateGetOrDelete(state, nl.XFRM_MSG_GETSA)
  160. }
  161. func (h *Handle) xfrmStateGetOrDelete(state *XfrmState, nlProto int) (*XfrmState, error) {
  162. req := h.newNetlinkRequest(nlProto, syscall.NLM_F_ACK)
  163. msg := &nl.XfrmUsersaId{}
  164. msg.Family = uint16(nl.GetIPFamily(state.Dst))
  165. msg.Daddr.FromIP(state.Dst)
  166. msg.Proto = uint8(state.Proto)
  167. msg.Spi = nl.Swap32(uint32(state.Spi))
  168. req.AddData(msg)
  169. if state.Mark != nil {
  170. out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
  171. req.AddData(out)
  172. }
  173. if state.Src != nil {
  174. out := nl.NewRtAttr(nl.XFRMA_SRCADDR, state.Src)
  175. req.AddData(out)
  176. }
  177. resType := nl.XFRM_MSG_NEWSA
  178. if nlProto == nl.XFRM_MSG_DELSA {
  179. resType = 0
  180. }
  181. msgs, err := req.Execute(syscall.NETLINK_XFRM, uint16(resType))
  182. if err != nil {
  183. return nil, err
  184. }
  185. if nlProto == nl.XFRM_MSG_DELSA {
  186. return nil, nil
  187. }
  188. s, err := parseXfrmState(msgs[0], FAMILY_ALL)
  189. if err != nil {
  190. return nil, err
  191. }
  192. return s, nil
  193. }
  194. var familyError = fmt.Errorf("family error")
  195. func parseXfrmState(m []byte, family int) (*XfrmState, error) {
  196. msg := nl.DeserializeXfrmUsersaInfo(m)
  197. // This is mainly for the state dump
  198. if family != FAMILY_ALL && family != int(msg.Family) {
  199. return nil, familyError
  200. }
  201. var state XfrmState
  202. state.Dst = msg.Id.Daddr.ToIP()
  203. state.Src = msg.Saddr.ToIP()
  204. state.Proto = Proto(msg.Id.Proto)
  205. state.Mode = Mode(msg.Mode)
  206. state.Spi = int(nl.Swap32(msg.Id.Spi))
  207. state.Reqid = int(msg.Reqid)
  208. state.ReplayWindow = int(msg.ReplayWindow)
  209. lftToLimits(&msg.Lft, &state.Limits)
  210. attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
  211. if err != nil {
  212. return nil, err
  213. }
  214. for _, attr := range attrs {
  215. switch attr.Attr.Type {
  216. case nl.XFRMA_ALG_AUTH, nl.XFRMA_ALG_CRYPT:
  217. var resAlgo *XfrmStateAlgo
  218. if attr.Attr.Type == nl.XFRMA_ALG_AUTH {
  219. if state.Auth == nil {
  220. state.Auth = new(XfrmStateAlgo)
  221. }
  222. resAlgo = state.Auth
  223. } else {
  224. state.Crypt = new(XfrmStateAlgo)
  225. resAlgo = state.Crypt
  226. }
  227. algo := nl.DeserializeXfrmAlgo(attr.Value[:])
  228. (*resAlgo).Name = nl.BytesToString(algo.AlgName[:])
  229. (*resAlgo).Key = algo.AlgKey
  230. case nl.XFRMA_ALG_AUTH_TRUNC:
  231. if state.Auth == nil {
  232. state.Auth = new(XfrmStateAlgo)
  233. }
  234. algo := nl.DeserializeXfrmAlgoAuth(attr.Value[:])
  235. state.Auth.Name = nl.BytesToString(algo.AlgName[:])
  236. state.Auth.Key = algo.AlgKey
  237. state.Auth.TruncateLen = int(algo.AlgTruncLen)
  238. case nl.XFRMA_ENCAP:
  239. encap := nl.DeserializeXfrmEncapTmpl(attr.Value[:])
  240. state.Encap = new(XfrmStateEncap)
  241. state.Encap.Type = EncapType(encap.EncapType)
  242. state.Encap.SrcPort = int(nl.Swap16(encap.EncapSport))
  243. state.Encap.DstPort = int(nl.Swap16(encap.EncapDport))
  244. state.Encap.OriginalAddress = encap.EncapOa.ToIP()
  245. case nl.XFRMA_MARK:
  246. mark := nl.DeserializeXfrmMark(attr.Value[:])
  247. state.Mark = new(XfrmMark)
  248. state.Mark.Value = mark.Value
  249. state.Mark.Mask = mark.Mask
  250. }
  251. }
  252. return &state, nil
  253. }
  254. // XfrmStateFlush will flush the xfrm state on the system.
  255. // proto = 0 means any transformation protocols
  256. // Equivalent to: `ip xfrm state flush [ proto XFRM-PROTO ]`
  257. func XfrmStateFlush(proto Proto) error {
  258. return pkgHandle.XfrmStateFlush(proto)
  259. }
  260. // XfrmStateFlush will flush the xfrm state on the system.
  261. // proto = 0 means any transformation protocols
  262. // Equivalent to: `ip xfrm state flush [ proto XFRM-PROTO ]`
  263. func (h *Handle) XfrmStateFlush(proto Proto) error {
  264. req := h.newNetlinkRequest(nl.XFRM_MSG_FLUSHSA, syscall.NLM_F_ACK)
  265. req.AddData(&nl.XfrmUsersaFlush{Proto: uint8(proto)})
  266. _, err := req.Execute(syscall.NETLINK_XFRM, 0)
  267. if err != nil {
  268. return err
  269. }
  270. return nil
  271. }
  272. func limitsToLft(lmts XfrmStateLimits, lft *nl.XfrmLifetimeCfg) {
  273. if lmts.ByteSoft != 0 {
  274. lft.SoftByteLimit = lmts.ByteSoft
  275. } else {
  276. lft.SoftByteLimit = nl.XFRM_INF
  277. }
  278. if lmts.ByteHard != 0 {
  279. lft.HardByteLimit = lmts.ByteHard
  280. } else {
  281. lft.HardByteLimit = nl.XFRM_INF
  282. }
  283. if lmts.PacketSoft != 0 {
  284. lft.SoftPacketLimit = lmts.PacketSoft
  285. } else {
  286. lft.SoftPacketLimit = nl.XFRM_INF
  287. }
  288. if lmts.PacketHard != 0 {
  289. lft.HardPacketLimit = lmts.PacketHard
  290. } else {
  291. lft.HardPacketLimit = nl.XFRM_INF
  292. }
  293. lft.SoftAddExpiresSeconds = lmts.TimeSoft
  294. lft.HardAddExpiresSeconds = lmts.TimeHard
  295. lft.SoftUseExpiresSeconds = lmts.TimeUseSoft
  296. lft.HardUseExpiresSeconds = lmts.TimeUseHard
  297. }
  298. func lftToLimits(lft *nl.XfrmLifetimeCfg, lmts *XfrmStateLimits) {
  299. *lmts = *(*XfrmStateLimits)(unsafe.Pointer(lft))
  300. }