xfrm_policy_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package netlink
  2. import (
  3. "bytes"
  4. "net"
  5. "testing"
  6. )
  // zeroCIDR is the all-zero IPv4 network in CIDR notation. compareIPNet
  // treats a nil *net.IPNet as equal to a network whose String() is this value.
  7. const zeroCIDR = "0.0.0.0/0"
  8. func TestXfrmPolicyAddUpdateDel(t *testing.T) {
  9. tearDown := setUpNetlinkTest(t)
  10. defer tearDown()
  11. policy := getPolicy()
  12. if err := XfrmPolicyAdd(policy); err != nil {
  13. t.Fatal(err)
  14. }
  15. policies, err := XfrmPolicyList(FAMILY_ALL)
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. if len(policies) != 1 {
  20. t.Fatal("Policy not added properly")
  21. }
  22. if !comparePolicies(policy, &policies[0]) {
  23. t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", policy, policies[0])
  24. }
  25. // Look for a specific policy
  26. sp, err := XfrmPolicyGet(policy)
  27. if err != nil {
  28. t.Fatal(err)
  29. }
  30. if !comparePolicies(policy, sp) {
  31. t.Fatalf("unexpected policy returned")
  32. }
  33. // Modify the policy
  34. policy.Priority = 100
  35. if err := XfrmPolicyUpdate(policy); err != nil {
  36. t.Fatal(err)
  37. }
  38. sp, err = XfrmPolicyGet(policy)
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. if sp.Priority != 100 {
  43. t.Fatalf("failed to modify the policy")
  44. }
  45. if err = XfrmPolicyDel(policy); err != nil {
  46. t.Fatal(err)
  47. }
  48. policies, err = XfrmPolicyList(FAMILY_ALL)
  49. if err != nil {
  50. t.Fatal(err)
  51. }
  52. if len(policies) != 0 {
  53. t.Fatal("Policy not removed properly")
  54. }
  55. // Src and dst are not mandatory field. Creation should succeed
  56. policy.Src = nil
  57. policy.Dst = nil
  58. if err = XfrmPolicyAdd(policy); err != nil {
  59. t.Fatal(err)
  60. }
  61. sp, err = XfrmPolicyGet(policy)
  62. if err != nil {
  63. t.Fatal(err)
  64. }
  65. if !comparePolicies(policy, sp) {
  66. t.Fatalf("unexpected policy returned")
  67. }
  68. if err = XfrmPolicyDel(policy); err != nil {
  69. t.Fatal(err)
  70. }
  71. if _, err := XfrmPolicyGet(policy); err == nil {
  72. t.Fatalf("Unexpected success")
  73. }
  74. }
  75. func TestXfrmPolicyFlush(t *testing.T) {
  76. setUpNetlinkTest(t)()
  77. p1 := getPolicy()
  78. if err := XfrmPolicyAdd(p1); err != nil {
  79. t.Fatal(err)
  80. }
  81. p1.Dir = XFRM_DIR_IN
  82. s := p1.Src
  83. p1.Src = p1.Dst
  84. p1.Dst = s
  85. if err := XfrmPolicyAdd(p1); err != nil {
  86. t.Fatal(err)
  87. }
  88. policies, err := XfrmPolicyList(FAMILY_ALL)
  89. if err != nil {
  90. t.Fatal(err)
  91. }
  92. if len(policies) != 2 {
  93. t.Fatalf("unexpected number of policies: %d", len(policies))
  94. }
  95. if err := XfrmPolicyFlush(); err != nil {
  96. t.Fatal(err)
  97. }
  98. policies, err = XfrmPolicyList(FAMILY_ALL)
  99. if err != nil {
  100. t.Fatal(err)
  101. }
  102. if len(policies) != 0 {
  103. t.Fatalf("unexpected number of policies: %d", len(policies))
  104. }
  105. }
  106. func comparePolicies(a, b *XfrmPolicy) bool {
  107. if a == b {
  108. return true
  109. }
  110. if a == nil || b == nil {
  111. return false
  112. }
  113. // Do not check Index which is assigned by kernel
  114. return a.Dir == b.Dir && a.Priority == b.Priority &&
  115. compareIPNet(a.Src, b.Src) && compareIPNet(a.Dst, b.Dst) &&
  116. a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask &&
  117. compareTemplates(a.Tmpls, b.Tmpls)
  118. }
  119. func compareTemplates(a, b []XfrmPolicyTmpl) bool {
  120. if len(a) != len(b) {
  121. return false
  122. }
  123. for i, ta := range a {
  124. tb := b[i]
  125. if !ta.Dst.Equal(tb.Dst) || !ta.Src.Equal(tb.Src) || ta.Spi != tb.Spi ||
  126. ta.Mode != tb.Mode || ta.Reqid != tb.Reqid || ta.Proto != tb.Proto {
  127. return false
  128. }
  129. }
  130. return true
  131. }
  132. func compareIPNet(a, b *net.IPNet) bool {
  133. if a == b {
  134. return true
  135. }
  136. // For unspecified src/dst parseXfrmPolicy would set the zero address cidr
  137. if (a == nil && b.String() == zeroCIDR) || (b == nil && a.String() == zeroCIDR) {
  138. return true
  139. }
  140. if a == nil || b == nil {
  141. return false
  142. }
  143. return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
  144. }
  145. func getPolicy() *XfrmPolicy {
  146. src, _ := ParseIPNet("127.1.1.1/32")
  147. dst, _ := ParseIPNet("127.1.1.2/32")
  148. policy := &XfrmPolicy{
  149. Src: src,
  150. Dst: dst,
  151. Proto: 17,
  152. DstPort: 1234,
  153. SrcPort: 5678,
  154. Dir: XFRM_DIR_OUT,
  155. Mark: &XfrmMark{
  156. Value: 0xabff22,
  157. Mask: 0xffffffff,
  158. },
  159. Priority: 10,
  160. }
  161. tmpl := XfrmPolicyTmpl{
  162. Src: net.ParseIP("127.0.0.1"),
  163. Dst: net.ParseIP("127.0.0.2"),
  164. Proto: XFRM_PROTO_ESP,
  165. Mode: XFRM_MODE_TUNNEL,
  166. Spi: 0xabcdef99,
  167. }
  168. policy.Tmpls = append(policy.Tmpls, tmpl)
  169. return policy
  170. }