xfrm_state_test.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package netlink
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "net"
  6. "testing"
  7. )
  8. func TestXfrmStateAddGetDel(t *testing.T) {
  9. for _, s := range []*XfrmState{getBaseState(), getAeadState()} {
  10. testXfrmStateAddGetDel(t, s)
  11. }
  12. }
  13. func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
  14. tearDown := setUpNetlinkTest(t)
  15. defer tearDown()
  16. if err := XfrmStateAdd(state); err != nil {
  17. t.Fatal(err)
  18. }
  19. states, err := XfrmStateList(FAMILY_ALL)
  20. if err != nil {
  21. t.Fatal(err)
  22. }
  23. if len(states) != 1 {
  24. t.Fatal("State not added properly")
  25. }
  26. if !compareStates(state, &states[0]) {
  27. t.Fatalf("unexpected states returned")
  28. }
  29. // Get specific state
  30. sa, err := XfrmStateGet(state)
  31. if err != nil {
  32. t.Fatal(err)
  33. }
  34. if !compareStates(state, sa) {
  35. t.Fatalf("unexpected state returned")
  36. }
  37. if err = XfrmStateDel(state); err != nil {
  38. t.Fatal(err)
  39. }
  40. states, err = XfrmStateList(FAMILY_ALL)
  41. if err != nil {
  42. t.Fatal(err)
  43. }
  44. if len(states) != 0 {
  45. t.Fatal("State not removed properly")
  46. }
  47. if _, err := XfrmStateGet(state); err == nil {
  48. t.Fatalf("Unexpected success")
  49. }
  50. }
  51. func TestXfrmStateFlush(t *testing.T) {
  52. setUpNetlinkTest(t)()
  53. state1 := getBaseState()
  54. state2 := getBaseState()
  55. state2.Src = net.ParseIP("127.1.0.1")
  56. state2.Dst = net.ParseIP("127.1.0.2")
  57. state2.Proto = XFRM_PROTO_AH
  58. state2.Mode = XFRM_MODE_TUNNEL
  59. state2.Spi = 20
  60. state2.Mark = nil
  61. state2.Crypt = nil
  62. if err := XfrmStateAdd(state1); err != nil {
  63. t.Fatal(err)
  64. }
  65. if err := XfrmStateAdd(state2); err != nil {
  66. t.Fatal(err)
  67. }
  68. // flushing proto for which no state is present should return silently
  69. if err := XfrmStateFlush(XFRM_PROTO_COMP); err != nil {
  70. t.Fatal(err)
  71. }
  72. if err := XfrmStateFlush(XFRM_PROTO_AH); err != nil {
  73. t.Fatal(err)
  74. }
  75. if _, err := XfrmStateGet(state2); err == nil {
  76. t.Fatalf("Unexpected success")
  77. }
  78. if err := XfrmStateAdd(state2); err != nil {
  79. t.Fatal(err)
  80. }
  81. if err := XfrmStateFlush(0); err != nil {
  82. t.Fatal(err)
  83. }
  84. states, err := XfrmStateList(FAMILY_ALL)
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. if len(states) != 0 {
  89. t.Fatal("State not flushed properly")
  90. }
  91. }
  92. func TestXfrmStateUpdateLimits(t *testing.T) {
  93. setUpNetlinkTest(t)()
  94. // Program state with limits
  95. state := getBaseState()
  96. state.Limits.TimeHard = 3600
  97. state.Limits.TimeSoft = 60
  98. state.Limits.PacketHard = 1000
  99. state.Limits.PacketSoft = 50
  100. state.Limits.ByteHard = 1000000
  101. state.Limits.ByteSoft = 50000
  102. state.Limits.TimeUseHard = 3000
  103. state.Limits.TimeUseSoft = 1500
  104. if err := XfrmStateAdd(state); err != nil {
  105. t.Fatal(err)
  106. }
  107. // Verify limits
  108. s, err := XfrmStateGet(state)
  109. if err != nil {
  110. t.Fatal(err)
  111. }
  112. if !compareLimits(state, s) {
  113. t.Fatalf("Incorrect time hard/soft retrieved: %s", s.Print(true))
  114. }
  115. // Update limits
  116. state.Limits.TimeHard = 1800
  117. state.Limits.TimeSoft = 30
  118. state.Limits.PacketHard = 500
  119. state.Limits.PacketSoft = 25
  120. state.Limits.ByteHard = 500000
  121. state.Limits.ByteSoft = 25000
  122. state.Limits.TimeUseHard = 2000
  123. state.Limits.TimeUseSoft = 1000
  124. if err := XfrmStateUpdate(state); err != nil {
  125. t.Fatal(err)
  126. }
  127. // Verify new limits
  128. s, err = XfrmStateGet(state)
  129. if err != nil {
  130. t.Fatal(err)
  131. }
  132. if s.Limits.TimeHard != 1800 || s.Limits.TimeSoft != 30 {
  133. t.Fatalf("Incorrect time hard retrieved: (%d, %d)", s.Limits.TimeHard, s.Limits.TimeSoft)
  134. }
  135. }
  136. func getBaseState() *XfrmState {
  137. return &XfrmState{
  138. Src: net.ParseIP("127.0.0.1"),
  139. Dst: net.ParseIP("127.0.0.2"),
  140. Proto: XFRM_PROTO_ESP,
  141. Mode: XFRM_MODE_TUNNEL,
  142. Spi: 1,
  143. Auth: &XfrmStateAlgo{
  144. Name: "hmac(sha256)",
  145. Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
  146. },
  147. Crypt: &XfrmStateAlgo{
  148. Name: "cbc(aes)",
  149. Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
  150. },
  151. Mark: &XfrmMark{
  152. Value: 0x12340000,
  153. Mask: 0xffff0000,
  154. },
  155. }
  156. }
  157. func getAeadState() *XfrmState {
  158. // 128 key bits + 32 salt bits
  159. k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
  160. return &XfrmState{
  161. Src: net.ParseIP("192.168.1.1"),
  162. Dst: net.ParseIP("192.168.2.2"),
  163. Proto: XFRM_PROTO_ESP,
  164. Mode: XFRM_MODE_TUNNEL,
  165. Spi: 2,
  166. Aead: &XfrmStateAlgo{
  167. Name: "rfc4106(gcm(aes))",
  168. Key: k,
  169. ICVLen: 64,
  170. },
  171. }
  172. }
  173. func compareStates(a, b *XfrmState) bool {
  174. if a == b {
  175. return true
  176. }
  177. if a == nil || b == nil {
  178. return false
  179. }
  180. return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
  181. a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
  182. compareAlgo(a.Auth, b.Auth) &&
  183. compareAlgo(a.Crypt, b.Crypt) &&
  184. compareAlgo(a.Aead, b.Aead) &&
  185. compareMarks(a.Mark, b.Mark)
  186. }
  187. func compareLimits(a, b *XfrmState) bool {
  188. return a.Limits.TimeHard == b.Limits.TimeHard &&
  189. a.Limits.TimeSoft == b.Limits.TimeSoft &&
  190. a.Limits.PacketHard == b.Limits.PacketHard &&
  191. a.Limits.PacketSoft == b.Limits.PacketSoft &&
  192. a.Limits.ByteHard == b.Limits.ByteHard &&
  193. a.Limits.ByteSoft == b.Limits.ByteSoft &&
  194. a.Limits.TimeUseHard == b.Limits.TimeUseHard &&
  195. a.Limits.TimeUseSoft == b.Limits.TimeUseSoft
  196. }
  197. func compareAlgo(a, b *XfrmStateAlgo) bool {
  198. if a == b {
  199. return true
  200. }
  201. if a == nil || b == nil {
  202. return false
  203. }
  204. return a.Name == b.Name && bytes.Equal(a.Key, b.Key) &&
  205. (a.TruncateLen == 0 || a.TruncateLen == b.TruncateLen) &&
  206. (a.ICVLen == 0 || a.ICVLen == b.ICVLen)
  207. }
  208. func compareMarks(a, b *XfrmMark) bool {
  209. if a == b {
  210. return true
  211. }
  212. if a == nil || b == nil {
  213. return false
  214. }
  215. return a.Value == b.Value && a.Mask == b.Mask
  216. }