|
@@ -29,11 +29,15 @@ import (
|
|
|
// Adds the output of stderr to exec.ExitError
|
|
|
type Error struct {
|
|
|
exec.ExitError
|
|
|
- cmd exec.Cmd
|
|
|
- msg string
|
|
|
+ cmd exec.Cmd
|
|
|
+ msg string
|
|
|
+ exitStatus *int //for overriding
|
|
|
}
|
|
|
|
|
|
func (e *Error) ExitStatus() int {
|
|
|
+ if e.exitStatus != nil {
|
|
|
+ return *e.exitStatus
|
|
|
+ }
|
|
|
return e.Sys().(syscall.WaitStatus).ExitStatus()
|
|
|
}
|
|
|
|
|
@@ -41,6 +45,13 @@ func (e *Error) Error() string {
|
|
|
return fmt.Sprintf("running %v: exit status %v: %v", e.cmd.Args, e.ExitStatus(), e.msg)
|
|
|
}
|
|
|
|
|
|
+// IsNotExist returns true if the error is due to the chain or rule not existing
|
|
|
+func (e *Error) IsNotExist() bool {
|
|
|
+ return e.ExitStatus() == 1 &&
|
|
|
+ (e.msg == "iptables: Bad rule (does a matching rule exist in that chain?).\n" ||
|
|
|
+ e.msg == "iptables: No chain/target/match by that name.\n")
|
|
|
+}
|
|
|
+
|
|
|
// Protocol to differentiate between IPv4 and IPv6
|
|
|
type Protocol byte
|
|
|
|
|
@@ -50,10 +61,15 @@ const (
|
|
|
)
|
|
|
|
|
|
type IPTables struct {
|
|
|
- path string
|
|
|
- proto Protocol
|
|
|
- hasCheck bool
|
|
|
- hasWait bool
|
|
|
+ path string
|
|
|
+ proto Protocol
|
|
|
+ hasCheck bool
|
|
|
+ hasWait bool
|
|
|
+ hasRandomFully bool
|
|
|
+ v1 int
|
|
|
+ v2 int
|
|
|
+ v3 int
|
|
|
+ mode string // the underlying iptables operating mode, e.g. nf_tables
|
|
|
}
|
|
|
|
|
|
// New creates a new IPTables.
|
|
@@ -69,15 +85,21 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- checkPresent, waitPresent, err := getIptablesCommandSupport(path)
|
|
|
- if err != nil {
|
|
|
- return nil, fmt.Errorf("error checking iptables version: %v", err)
|
|
|
- }
|
|
|
+ vstring, err := getIptablesVersionString(path)
|
|
|
+ v1, v2, v3, mode, err := extractIptablesVersion(vstring)
|
|
|
+
|
|
|
+ checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
|
|
|
+
|
|
|
ipt := IPTables{
|
|
|
- path: path,
|
|
|
- proto: proto,
|
|
|
- hasCheck: checkPresent,
|
|
|
- hasWait: waitPresent,
|
|
|
+ path: path,
|
|
|
+ proto: proto,
|
|
|
+ hasCheck: checkPresent,
|
|
|
+ hasWait: waitPresent,
|
|
|
+ hasRandomFully: randomFullyPresent,
|
|
|
+ v1: v1,
|
|
|
+ v2: v2,
|
|
|
+ v3: v3,
|
|
|
+ mode: mode,
|
|
|
}
|
|
|
return &ipt, nil
|
|
|
}
|
|
@@ -248,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
|
|
|
}
|
|
|
|
|
|
rules := strings.Split(stdout.String(), "\n")
|
|
|
+
|
|
|
+ // strip trailing newline
|
|
|
if len(rules) > 0 && rules[len(rules)-1] == "" {
|
|
|
rules = rules[:len(rules)-1]
|
|
|
}
|
|
|
|
|
|
+ // nftables mode doesn't return an error code when listing a non-existent
|
|
|
+ // chain. Patch that up.
|
|
|
+ if len(rules) == 0 && ipt.mode == "nf_tables" {
|
|
|
+ v := 1
|
|
|
+ return nil, &Error{
|
|
|
+ cmd: exec.Cmd{Args: args},
|
|
|
+ msg: "iptables: No chain/target/match by that name.",
|
|
|
+ exitStatus: &v,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for i, rule := range rules {
|
|
|
+ rules[i] = filterRuleOutput(rule)
|
|
|
+ }
|
|
|
+
|
|
|
return rules, nil
|
|
|
}
|
|
|
|
|
@@ -266,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
|
|
|
func (ipt *IPTables) ClearChain(table, chain string) error {
|
|
|
err := ipt.NewChain(table, chain)
|
|
|
|
|
|
+ // the exit code for "this table already exists" is different for
|
|
|
+ // different iptables modes
|
|
|
+ existsErr := 1
|
|
|
+ if ipt.mode == "nf_tables" {
|
|
|
+ existsErr = 4
|
|
|
+ }
|
|
|
+
|
|
|
eerr, eok := err.(*Error)
|
|
|
switch {
|
|
|
case err == nil:
|
|
|
return nil
|
|
|
- case eok && eerr.ExitStatus() == 1:
|
|
|
+ case eok && eerr.ExitStatus() == existsErr:
|
|
|
// chain already exists. Flush (clear) it.
|
|
|
return ipt.run("-t", table, "-F", chain)
|
|
|
default:
|
|
@@ -289,6 +335,21 @@ func (ipt *IPTables) DeleteChain(table, chain string) error {
|
|
|
return ipt.run("-t", table, "-X", chain)
|
|
|
}
|
|
|
|
|
|
+// ChangePolicy changes policy on chain to target
|
|
|
+func (ipt *IPTables) ChangePolicy(table, chain, target string) error {
|
|
|
+ return ipt.run("-t", table, "-P", chain, target)
|
|
|
+}
|
|
|
+
|
|
|
+// Check if the underlying iptables command supports the --random-fully flag
|
|
|
+func (ipt *IPTables) HasRandomFully() bool {
|
|
|
+ return ipt.hasRandomFully
|
|
|
+}
|
|
|
+
|
|
|
+// Return version components of the underlying iptables command
|
|
|
+func (ipt *IPTables) GetIptablesVersion() (int, int, int) {
|
|
|
+ return ipt.v1, ipt.v2, ipt.v3
|
|
|
+}
|
|
|
+
|
|
|
// run runs an iptables command with the given arguments, ignoring
|
|
|
// any stdout output
|
|
|
func (ipt *IPTables) run(args ...string) error {
|
|
@@ -324,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
|
|
|
if err := cmd.Run(); err != nil {
|
|
|
switch e := err.(type) {
|
|
|
case *exec.ExitError:
|
|
|
- return &Error{*e, cmd, stderr.String()}
|
|
|
+ return &Error{*e, cmd, stderr.String(), nil}
|
|
|
default:
|
|
|
return err
|
|
|
}
|
|
@@ -343,45 +404,40 @@ func getIptablesCommand(proto Protocol) string {
|
|
|
}
|
|
|
|
|
|
// Checks if iptables has the "-C" and "--wait" flag
|
|
|
-func getIptablesCommandSupport(path string) (bool, bool, error) {
|
|
|
- vstring, err := getIptablesVersionString(path)
|
|
|
- if err != nil {
|
|
|
- return false, false, err
|
|
|
- }
|
|
|
-
|
|
|
- v1, v2, v3, err := extractIptablesVersion(vstring)
|
|
|
- if err != nil {
|
|
|
- return false, false, err
|
|
|
- }
|
|
|
-
|
|
|
- return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), nil
|
|
|
+func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
|
|
|
+ return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
|
|
|
}
|
|
|
|
|
|
-// getIptablesVersion returns the first three components of the iptables version.
|
|
|
-// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
|
|
|
-func extractIptablesVersion(str string) (int, int, int, error) {
|
|
|
- versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
|
|
|
+// getIptablesVersion returns the first three components of the iptables version
|
|
|
+// and the operating mode (e.g. nf_tables or legacy)
|
|
|
+// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
|
|
|
+func extractIptablesVersion(str string) (int, int, int, string, error) {
|
|
|
+ versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
|
|
|
result := versionMatcher.FindStringSubmatch(str)
|
|
|
if result == nil {
|
|
|
- return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
|
|
|
+ return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
|
|
|
}
|
|
|
|
|
|
v1, err := strconv.Atoi(result[1])
|
|
|
if err != nil {
|
|
|
- return 0, 0, 0, err
|
|
|
+ return 0, 0, 0, "", err
|
|
|
}
|
|
|
|
|
|
v2, err := strconv.Atoi(result[2])
|
|
|
if err != nil {
|
|
|
- return 0, 0, 0, err
|
|
|
+ return 0, 0, 0, "", err
|
|
|
}
|
|
|
|
|
|
v3, err := strconv.Atoi(result[3])
|
|
|
if err != nil {
|
|
|
- return 0, 0, 0, err
|
|
|
+ return 0, 0, 0, "", err
|
|
|
}
|
|
|
|
|
|
- return v1, v2, v3, nil
|
|
|
+ mode := "legacy"
|
|
|
+ if result[4] != "" {
|
|
|
+ mode = result[4]
|
|
|
+ }
|
|
|
+ return v1, v2, v3, mode, nil
|
|
|
}
|
|
|
|
|
|
// Runs "iptables --version" to get the version string
|
|
@@ -424,6 +480,20 @@ func iptablesHasWaitCommand(v1 int, v2 int, v3 int) bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
+// Checks if an iptables version is after 1.6.2, when --random-fully was added
|
|
|
+func iptablesHasRandomFully(v1 int, v2 int, v3 int) bool {
|
|
|
+ if v1 > 1 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ if v1 == 1 && v2 > 6 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ if v1 == 1 && v2 == 6 && v3 >= 2 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
// Checks if a rule specification exists for a table
|
|
|
func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string) (bool, error) {
|
|
|
rs := strings.Join(append([]string{"-A", chain}, rulespec...), " ")
|
|
@@ -435,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
|
|
|
}
|
|
|
return strings.Contains(stdout.String(), rs), nil
|
|
|
}
|
|
|
+
|
|
|
+// counterRegex is the regex used to detect nftables counter format
|
|
|
+var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)
|
|
|
+
|
|
|
+// filterRuleOutput works around some inconsistencies in output.
|
|
|
+// For example, when iptables is in legacy vs. nftables mode, it produces
|
|
|
+// different results.
|
|
|
+func filterRuleOutput(rule string) string {
|
|
|
+ out := rule
|
|
|
+
|
|
|
+ // work around an output difference in nftables mode where counters
|
|
|
+ // are output in iptables-save format, rather than iptables -S format
|
|
|
+ // The string begins with "[0:0]"
|
|
|
+ //
|
|
|
+ // Fixes #49
|
|
|
+ if groups := counterRegex.FindStringSubmatch(out); groups != nil {
|
|
|
+ // drop the brackets
|
|
|
+ out = out[len(groups[0]):]
|
|
|
+ out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
|
|
|
+ }
|
|
|
+
|
|
|
+ return out
|
|
|
+}
|