ipnet.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package ip
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. )
// IP4 is an IPv4 address packed into a single uint32.
//
// Values are produced from raw bytes by FromBytes and split back into their
// four address bytes by Octets.
type IP4 uint32
  9. func FromBytes(ip []byte) IP4 {
  10. if NativelyLittle() {
  11. return IP4(uint32(ip[3]) |
  12. (uint32(ip[2]) << 8) |
  13. (uint32(ip[1]) << 16) |
  14. (uint32(ip[0]) << 24))
  15. } else {
  16. return IP4(uint32(ip[0]) |
  17. (uint32(ip[1]) << 8) |
  18. (uint32(ip[2]) << 16) |
  19. (uint32(ip[3]) << 24))
  20. }
  21. }
  22. func FromIP(ip net.IP) IP4 {
  23. return FromBytes(ip.To4())
  24. }
  25. func ParseIP4(s string) (IP4, error) {
  26. ip := net.ParseIP(s)
  27. if ip == nil {
  28. return IP4(0), errors.New("Invalid IP address format")
  29. }
  30. return FromIP(ip), nil
  31. }
  32. func (ip IP4) Octets() (a, b, c, d byte) {
  33. if NativelyLittle() {
  34. a, b, c, d = byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)
  35. } else {
  36. a, b, c, d = byte(ip), byte(ip>>8), byte(ip>>16), byte(ip>>24)
  37. }
  38. return
  39. }
  40. func (ip IP4) ToIP() net.IP {
  41. return net.IPv4(ip.Octets())
  42. }
  43. func (ip IP4) NetworkOrder() uint32 {
  44. if NativelyLittle() {
  45. a, b, c, d := byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)
  46. return uint32(a) | (uint32(b) << 8) | (uint32(c) << 16) | (uint32(d) << 24)
  47. } else {
  48. return uint32(ip)
  49. }
  50. }
  51. func (ip IP4) String() string {
  52. return ip.ToIP().String()
  53. }
  54. func (ip IP4) StringSep(sep string) string {
  55. a, b, c, d := ip.Octets()
  56. return fmt.Sprintf("%d%s%d%s%d%s%d", a, sep, b, sep, c, sep, d)
  57. }
  58. // json.Marshaler impl
  59. func (ip IP4) MarshalJSON() ([]byte, error) {
  60. return []byte(fmt.Sprintf(`"%s"`, ip)), nil
  61. }
  62. // json.Unmarshaler impl
  63. func (ip *IP4) UnmarshalJSON(j []byte) error {
  64. j = bytes.Trim(j, "\"")
  65. if val, err := ParseIP4(string(j)); err != nil {
  66. return err
  67. } else {
  68. *ip = val
  69. return nil
  70. }
  71. }
// IP4Net is an IPv4 address paired with a prefix length. It is similar to
// net.IPNet but uses an integer-based representation: the address is an IP4
// and the mask is stored as a bit count rather than a byte slice.
type IP4Net struct {
	// IP is the address part. It is not necessarily the network's base
	// address; Network() returns a copy with IP masked by Mask().
	IP IP4
	// PrefixLen is the prefix length in bits; Mask() derives the mask by
	// shifting 0xFFFFFFFF left by 32 - PrefixLen.
	PrefixLen uint
}
  77. func (n IP4Net) String() string {
  78. return fmt.Sprintf("%s/%d", n.IP.String(), n.PrefixLen)
  79. }
  80. func (n IP4Net) StringSep(octetSep, prefixSep string) string {
  81. return fmt.Sprintf("%s%s%d", n.IP.StringSep(octetSep), prefixSep, n.PrefixLen)
  82. }
  83. func (n IP4Net) Network() IP4Net {
  84. return IP4Net{
  85. n.IP & IP4(n.Mask()),
  86. n.PrefixLen,
  87. }
  88. }
  89. func (n IP4Net) Next() IP4Net {
  90. return IP4Net{
  91. n.IP + (1 << (32 - n.PrefixLen)),
  92. n.PrefixLen,
  93. }
  94. }
  95. func FromIPNet(n *net.IPNet) IP4Net {
  96. prefixLen, _ := n.Mask.Size()
  97. return IP4Net{
  98. FromIP(n.IP),
  99. uint(prefixLen),
  100. }
  101. }
  102. func (n IP4Net) ToIPNet() *net.IPNet {
  103. return &net.IPNet{
  104. IP: n.IP.ToIP(),
  105. Mask: net.CIDRMask(int(n.PrefixLen), 32),
  106. }
  107. }
  108. func (n IP4Net) Overlaps(other IP4Net) bool {
  109. var mask uint32
  110. if n.PrefixLen < other.PrefixLen {
  111. mask = n.Mask()
  112. } else {
  113. mask = other.Mask()
  114. }
  115. return (uint32(n.IP) & mask) == (uint32(other.IP) & mask)
  116. }
  117. func (n IP4Net) Equal(other IP4Net) bool {
  118. return n.IP == other.IP && n.PrefixLen == other.PrefixLen
  119. }
  120. func (n IP4Net) Mask() uint32 {
  121. var ones uint32 = 0xFFFFFFFF
  122. return ones << (32 - n.PrefixLen)
  123. }
  124. func (n IP4Net) Contains(ip IP4) bool {
  125. return (uint32(n.IP) & n.Mask()) == (uint32(ip) & n.Mask())
  126. }
  127. // json.Marshaler impl
  128. func (n IP4Net) MarshalJSON() ([]byte, error) {
  129. return []byte(fmt.Sprintf(`"%s"`, n)), nil
  130. }
  131. // json.Unmarshaler impl
  132. func (n *IP4Net) UnmarshalJSON(j []byte) error {
  133. j = bytes.Trim(j, "\"")
  134. if _, val, err := net.ParseCIDR(string(j)); err != nil {
  135. fmt.Println(err)
  136. return err
  137. } else {
  138. *n = FromIPNet(val)
  139. return nil
  140. }
  141. }