netroute.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package netroute
  2. import (
  3. "regexp"
  4. "net"
  5. "strconv"
  6. "strings"
  7. "bufio"
  8. "bytes"
  9. ps "github.com/bhendo/go-powershell"
  10. psbe "github.com/bhendo/go-powershell/backend"
  11. "fmt"
  12. "math/big"
  13. )
  14. // Interface is an injectable interface for running MSFT_NetRoute commands. Implementations must be goroutine-safe.
  15. type Interface interface {
  16. // Get all net routes on the host
  17. GetNetRoutesAll() ([]Route, error)
  18. // Get net routes by link and destination subnet
  19. GetNetRoutes(linkIndex int, destinationSubnet *net.IPNet) ([]Route, error)
  20. // Create a new route
  21. NewNetRoute(linkIndex int, destinationSubnet *net.IPNet, gatewayAddress net.IP) error
  22. // Remove an existing route
  23. RemoveNetRoute(linkIndex int, destinationSubnet *net.IPNet, gatewayAddress net.IP) error
  24. // exit the shell
  25. Exit()
  26. }
  27. type Route struct {
  28. LinkIndex int
  29. DestinationSubnet *net.IPNet
  30. GatewayAddress net.IP
  31. RouteMetric int
  32. IfMetric int
  33. }
  34. type shell struct {
  35. shellInstance ps.Shell
  36. }
  37. func New() Interface {
  38. s, _ := ps.New(&psbe.Local{})
  39. runner := &shell{
  40. shellInstance: s,
  41. }
  42. return runner
  43. }
  44. func (shell *shell) Exit() {
  45. shell.shellInstance.Exit()
  46. shell.shellInstance = nil
  47. }
  48. func (shell *shell) GetNetRoutesAll() ([]Route, error) {
  49. getRouteCmdLine := "get-netroute -erroraction Ignore"
  50. stdout, err := shell.runScript(getRouteCmdLine)
  51. if err != nil {
  52. return nil, err
  53. }
  54. return parseRoutesList(stdout), nil
  55. }
  // GetNetRoutes returns the routes on interface linkIndex whose destination
// prefix is destinationSubnet, as reported by `get-netroute`.
//
// The command runs with `-erroraction Ignore`, so PowerShell-level errors
// from get-netroute are suppressed. An error is returned if destinationSubnet
// is nil (checked before anything is run), or if runScript fails.
func (shell *shell) GetNetRoutes(linkIndex int, destinationSubnet *net.IPNet) ([]Route, error) {
	// A nil *net.IPNet stringifies to "<nil>", which would otherwise be sent
	// to PowerShell verbatim and fail with an unrelated parser error.
	if destinationSubnet == nil {
		return nil, fmt.Errorf("get-netroute on interface %d: destination subnet is nil", linkIndex)
	}
	cmd := fmt.Sprintf("get-netroute -InterfaceIndex %v -DestinationPrefix %v -erroraction Ignore", linkIndex, destinationSubnet.String())
	stdout, err := shell.runScript(cmd)
	if err != nil {
		return nil, err
	}
	return parseRoutesList(stdout), nil
}
  // RemoveNetRoute deletes the route to destinationSubnet via gatewayAddress on
// interface linkIndex by running `remove-netroute` with confirmation disabled.
//
// It returns an error without running anything when destinationSubnet is nil
// or gatewayAddress is not a valid IP. Otherwise it returns the error, if any,
// from runScript.
func (shell *shell) RemoveNetRoute(linkIndex int, destinationSubnet *net.IPNet, gatewayAddress net.IP) error {
	if destinationSubnet == nil {
		return fmt.Errorf("remove-netroute on interface %d: destination subnet is nil", linkIndex)
	}
	// net.IP.String renders a nil/empty IP as "<nil>" and a malformed one as
	// "?<hex>"; neither is a usable -NextHop argument.
	if gatewayAddress.To16() == nil {
		return fmt.Errorf("remove-netroute %v on interface %d: invalid gateway address %v", destinationSubnet, linkIndex, gatewayAddress)
	}
	cmd := fmt.Sprintf("remove-netroute -InterfaceIndex %v -DestinationPrefix %v -NextHop %v -Verbose -Confirm:$false", linkIndex, destinationSubnet.String(), gatewayAddress.String())
	_, err := shell.runScript(cmd)
	return err
}
  // NewNetRoute adds a route to destinationSubnet via gatewayAddress on
// interface linkIndex by running `new-netroute`.
//
// It returns an error without running anything when destinationSubnet is nil
// or gatewayAddress is not a valid IP. Otherwise it returns the error, if any,
// from runScript.
func (shell *shell) NewNetRoute(linkIndex int, destinationSubnet *net.IPNet, gatewayAddress net.IP) error {
	if destinationSubnet == nil {
		return fmt.Errorf("new-netroute on interface %d: destination subnet is nil", linkIndex)
	}
	// net.IP.String renders a nil/empty IP as "<nil>" and a malformed one as
	// "?<hex>"; neither is a usable -NextHop argument.
	if gatewayAddress.To16() == nil {
		return fmt.Errorf("new-netroute %v on interface %d: invalid gateway address %v", destinationSubnet, linkIndex, gatewayAddress)
	}
	cmd := fmt.Sprintf("new-netroute -InterfaceIndex %v -DestinationPrefix %v -NextHop %v -Verbose", linkIndex, destinationSubnet.String(), gatewayAddress.String())
	_, err := shell.runScript(cmd)
	return err
}
  74. func parseRoutesList(stdout string) []Route {
  75. internalWhitespaceRegEx := regexp.MustCompile(`[\s\p{Zs}]{2,}`)
  76. scanner := bufio.NewScanner(strings.NewReader(stdout))
  77. var routes []Route
  78. for scanner.Scan() {
  79. line := internalWhitespaceRegEx.ReplaceAllString(scanner.Text(), "|")
  80. if strings.HasPrefix(line, "ifIndex") || strings.HasPrefix(line, "----") {
  81. continue
  82. }
  83. parts := strings.Split(line, "|")
  84. if len(parts) != 5 {
  85. continue
  86. }
  87. linkIndex, err := strconv.Atoi(parts[0])
  88. if err != nil {
  89. continue
  90. }
  91. gatewayAddress := net.ParseIP(parts[2])
  92. if gatewayAddress == nil {
  93. continue
  94. }
  95. _, destinationSubnet, err := net.ParseCIDR(parts[1])
  96. if err != nil {
  97. continue
  98. }
  99. route := Route{
  100. DestinationSubnet: destinationSubnet,
  101. GatewayAddress: gatewayAddress,
  102. LinkIndex: linkIndex,
  103. }
  104. routes = append(routes, route)
  105. }
  106. return routes
  107. }
  // Equal reports whether r and route have the same gateway address and the
// same destination subnet (network address and mask bytes).
//
// LinkIndex, RouteMetric and IfMetric are not compared. A nil
// DestinationSubnet is equal only to another nil DestinationSubnet; previously
// a nil subnet on either side caused a nil-pointer panic.
func (r *Route) Equal(route Route) bool {
	if !r.GatewayAddress.Equal(route.GatewayAddress) {
		return false
	}
	mine, theirs := r.DestinationSubnet, route.DestinationSubnet
	if mine == nil || theirs == nil {
		return mine == theirs
	}
	return mine.IP.Equal(theirs.IP) && bytes.Equal(mine.Mask, theirs.Mask)
}
  // runScript executes cmdLine in the PowerShell session and returns its
// standard output. Standard error is not returned to the caller; any error
// from Execute is passed through unchanged.
//
// If there is no session, it returns an error instead of panicking on a nil
// interface. This happens after Exit, and presumably also when New failed to
// start PowerShell, since New discards that error.
func (shell *shell) runScript(cmdLine string) (string, error) {
	if shell.shellInstance == nil {
		return "", fmt.Errorf("netroute: powershell session is not available; cannot run %q", cmdLine)
	}
	stdout, _, err := shell.shellInstance.Execute(cmdLine)
	if err != nil {
		return "", err
	}
	return stdout, nil
}
  // IpToInt interprets ip as a big-endian unsigned integer.
//
// Addresses with a 4-byte form (plain IPv4 or IPv4-mapped IPv6) are converted
// from those 4 bytes. Other addresses use their 16-byte form. A nil or
// malformed IP yields 0.
func IpToInt(ip net.IP) *big.Int {
	raw := ip.To4()
	if raw == nil {
		raw = ip.To16()
	}
	return new(big.Int).SetBytes(raw)
}
  // IntToIP converts an integer (such as one produced by IpToInt) back into a
// net.IP.
//
// big.Int.Bytes drops leading zero bytes, so the result is left-padded with
// zeros. Values that fit in 32 bits become a 4-byte IPv4 address; larger
// values up to 128 bits become a 16-byte IPv6 address. Previously 0.0.0.0,
// any 0.x.x.x address and small IPv6 values came back as invalid short IPs.
//
// The integer carries no address family, so a small IPv6 value such as ::1
// comes back as the IPv4 address 0.0.0.1. The sign of i is ignored (Bytes
// returns the absolute value). It returns nil if i is nil or does not fit in
// 128 bits.
func IntToIP(i *big.Int) net.IP {
	if i == nil {
		return nil
	}
	raw := i.Bytes()
	size := net.IPv4len
	if len(raw) > net.IPv4len {
		size = net.IPv6len
	}
	if len(raw) > size {
		return nil
	}
	ip := make(net.IP, size)
	copy(ip[size-len(raw):], raw)
	return ip
}