xfrm_state_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package netlink
  2. import (
  3. "bytes"
  4. "net"
  5. "testing"
  6. )
  7. func TestXfrmStateAddGetDel(t *testing.T) {
  8. tearDown := setUpNetlinkTest(t)
  9. defer tearDown()
  10. state := getBaseState()
  11. if err := XfrmStateAdd(state); err != nil {
  12. t.Fatal(err)
  13. }
  14. states, err := XfrmStateList(FAMILY_ALL)
  15. if err != nil {
  16. t.Fatal(err)
  17. }
  18. if len(states) != 1 {
  19. t.Fatal("State not added properly")
  20. }
  21. if !compareStates(state, &states[0]) {
  22. t.Fatalf("unexpected states returned")
  23. }
  24. // Get specific state
  25. sa, err := XfrmStateGet(state)
  26. if err != nil {
  27. t.Fatal(err)
  28. }
  29. if !compareStates(state, sa) {
  30. t.Fatalf("unexpected state returned")
  31. }
  32. if err = XfrmStateDel(state); err != nil {
  33. t.Fatal(err)
  34. }
  35. states, err = XfrmStateList(FAMILY_ALL)
  36. if err != nil {
  37. t.Fatal(err)
  38. }
  39. if len(states) != 0 {
  40. t.Fatal("State not removed properly")
  41. }
  42. if _, err := XfrmStateGet(state); err == nil {
  43. t.Fatalf("Unexpected success")
  44. }
  45. }
  46. func TestXfrmStateFlush(t *testing.T) {
  47. setUpNetlinkTest(t)()
  48. state1 := getBaseState()
  49. state2 := getBaseState()
  50. state2.Src = net.ParseIP("127.1.0.1")
  51. state2.Dst = net.ParseIP("127.1.0.2")
  52. state2.Proto = XFRM_PROTO_AH
  53. state2.Mode = XFRM_MODE_TUNNEL
  54. state2.Spi = 20
  55. state2.Mark = nil
  56. state2.Crypt = nil
  57. if err := XfrmStateAdd(state1); err != nil {
  58. t.Fatal(err)
  59. }
  60. if err := XfrmStateAdd(state2); err != nil {
  61. t.Fatal(err)
  62. }
  63. // flushing proto for which no state is present should return silently
  64. if err := XfrmStateFlush(XFRM_PROTO_COMP); err != nil {
  65. t.Fatal(err)
  66. }
  67. if err := XfrmStateFlush(XFRM_PROTO_AH); err != nil {
  68. t.Fatal(err)
  69. }
  70. if _, err := XfrmStateGet(state2); err == nil {
  71. t.Fatalf("Unexpected success")
  72. }
  73. if err := XfrmStateAdd(state2); err != nil {
  74. t.Fatal(err)
  75. }
  76. if err := XfrmStateFlush(0); err != nil {
  77. t.Fatal(err)
  78. }
  79. states, err := XfrmStateList(FAMILY_ALL)
  80. if err != nil {
  81. t.Fatal(err)
  82. }
  83. if len(states) != 0 {
  84. t.Fatal("State not flushed properly")
  85. }
  86. }
  87. func TestXfrmStateUpdateLimits(t *testing.T) {
  88. setUpNetlinkTest(t)()
  89. // Program state with limits
  90. state := getBaseState()
  91. state.Limits.TimeHard = 3600
  92. state.Limits.TimeSoft = 60
  93. state.Limits.PacketHard = 1000
  94. state.Limits.PacketSoft = 50
  95. state.Limits.ByteHard = 1000000
  96. state.Limits.ByteSoft = 50000
  97. state.Limits.TimeUseHard = 3000
  98. state.Limits.TimeUseSoft = 1500
  99. if err := XfrmStateAdd(state); err != nil {
  100. t.Fatal(err)
  101. }
  102. // Verify limits
  103. s, err := XfrmStateGet(state)
  104. if err != nil {
  105. t.Fatal(err)
  106. }
  107. if !compareLimits(state, s) {
  108. t.Fatalf("Incorrect time hard/soft retrieved: %s", s.Print(true))
  109. }
  110. // Update limits
  111. state.Limits.TimeHard = 1800
  112. state.Limits.TimeSoft = 30
  113. state.Limits.PacketHard = 500
  114. state.Limits.PacketSoft = 25
  115. state.Limits.ByteHard = 500000
  116. state.Limits.ByteSoft = 25000
  117. state.Limits.TimeUseHard = 2000
  118. state.Limits.TimeUseSoft = 1000
  119. if err := XfrmStateUpdate(state); err != nil {
  120. t.Fatal(err)
  121. }
  122. // Verify new limits
  123. s, err = XfrmStateGet(state)
  124. if err != nil {
  125. t.Fatal(err)
  126. }
  127. if s.Limits.TimeHard != 1800 || s.Limits.TimeSoft != 30 {
  128. t.Fatalf("Incorrect time hard retrieved: (%d, %d)", s.Limits.TimeHard, s.Limits.TimeSoft)
  129. }
  130. }
  131. func getBaseState() *XfrmState {
  132. return &XfrmState{
  133. Src: net.ParseIP("127.0.0.1"),
  134. Dst: net.ParseIP("127.0.0.2"),
  135. Proto: XFRM_PROTO_ESP,
  136. Mode: XFRM_MODE_TUNNEL,
  137. Spi: 1,
  138. Auth: &XfrmStateAlgo{
  139. Name: "hmac(sha256)",
  140. Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
  141. },
  142. Crypt: &XfrmStateAlgo{
  143. Name: "cbc(aes)",
  144. Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
  145. },
  146. Mark: &XfrmMark{
  147. Value: 0x12340000,
  148. Mask: 0xffff0000,
  149. },
  150. }
  151. }
  152. func compareStates(a, b *XfrmState) bool {
  153. if a == b {
  154. return true
  155. }
  156. if a == nil || b == nil {
  157. return false
  158. }
  159. return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
  160. a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
  161. a.Auth.Name == b.Auth.Name && bytes.Equal(a.Auth.Key, b.Auth.Key) &&
  162. a.Crypt.Name == b.Crypt.Name && bytes.Equal(a.Crypt.Key, b.Crypt.Key) &&
  163. a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask
  164. }
  165. func compareLimits(a, b *XfrmState) bool {
  166. return a.Limits.TimeHard == b.Limits.TimeHard &&
  167. a.Limits.TimeSoft == b.Limits.TimeSoft &&
  168. a.Limits.PacketHard == b.Limits.PacketHard &&
  169. a.Limits.PacketSoft == b.Limits.PacketSoft &&
  170. a.Limits.ByteHard == b.Limits.ByteHard &&
  171. a.Limits.ByteSoft == b.Limits.ByteSoft &&
  172. a.Limits.TimeUseHard == b.Limits.TimeUseHard &&
  173. a.Limits.TimeUseSoft == b.Limits.TimeUseSoft
  174. }