Skip to content

Commit

Permalink
clean up code and return correct exit codes
Browse files Browse the repository at this point in the history
  • Loading branch information
1lann committed Dec 13, 2021
1 parent 0b7e58e commit 3b812e6
Showing 1 changed file with 50 additions and 21 deletions.
71 changes: 50 additions & 21 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (
"log"
"os"
"path"
"path/filepath"
"strings"
"sync"
"sync/atomic"

"github.com/fatih/color"
"github.com/karrick/godirwalk"
Expand All @@ -21,28 +21,32 @@ import (
var printMutex = new(sync.Mutex)

var mode = flag.String("mode", "report", "the output mode, either \"report\" (every jar pretty printed) or \"list\" (list of potentially vulnerable files)")
var includeZip = flag.Bool("include-zip", false, "include zip files in the scan")

func main() {
// Parse the arguments and flags provided to the program.
flag.Parse()

stderr := log.New(os.Stderr, "", 0)

if flag.Arg(0) == "" {
fmt.Println("Usage: log4shelldetect [options] <path>")
fmt.Println("Scans a file or folder recursively for jar files that may be")
fmt.Println("vulnerable to Log4Shell (CVE-2021-44228) by inspecting")
fmt.Println("the class paths inside the Jar")
fmt.Println("")
fmt.Println("Options:")
stderr.Println("Usage: log4shelldetect [options] <path>")
stderr.Println("Scans a file or folder recursively for jar files that may be")
stderr.Println("vulnerable to Log4Shell (CVE-2021-44228) by inspecting")
stderr.Println("the class paths inside the Jar")
stderr.Println("")
stderr.Println("Options:")
flag.PrintDefaults()
os.Exit(1)
os.Exit(2)
}

target := flag.Arg(0)

// Identify if the provided path is a file or a folder.
f, err := os.Stat(target)
if err != nil {
panic(err)
stderr.Println("Error accessing target path:", err)
os.Exit(1)
}

if !f.IsDir() {
Expand All @@ -55,17 +59,29 @@ func main() {
// for concurrent scanning of jars.
pool := make(chan struct{}, 8)

var hasNotableResults uint32

// Scan through the directory provided recursively.
err = godirwalk.Walk(target, &godirwalk.Options{
Callback: func(osPathname string, de *godirwalk.Dirent) error {
// For each file in the directory, check if it ends in ".jar"
ext := strings.ToLower(filepath.Ext(osPathname))
if ext == ".jar" || ext == ".war" || ext == ".ear" || ext == ".zip" {
if shouldCheck(osPathname) {
pool <- struct{}{}
// If it is, take a goroutine (thread) from the thread pool
// and check the jar.
go func() {
status, desc := checkJar(osPathname, nil, 0, 0)
if *mode == "list" {
switch status {
case StatusVulnerable, StatusMaybe:
atomic.StoreUint32(&hasNotableResults, 1)
}
} else {
switch status {
case StatusVulnerable, StatusMaybe, StatusPatched:
atomic.StoreUint32(&hasNotableResults, 1)
}
}
// Print the result of the check.
printStatus(osPathname, status, desc)
<-pool
Expand All @@ -76,21 +92,41 @@ func main() {
},
ErrorCallback: func(osPathname string, err error) godirwalk.ErrorAction {
// On directory traversal error, print a warning.
printMutex.Lock()
defer printMutex.Unlock()
log.Printf("skipping %q: %v", osPathname, err)
return godirwalk.SkipNode
},
Unsorted: true,
})
if err != nil {
panic(err)
stderr.Println("Error scanning target path:", err)
os.Exit(1)
}

// Wait for all goroutines (threads) to complete their work.
for i := 0; i < cap(pool); i++ {
pool <- struct{}{}
}

os.Exit(found)
if hasNotableResults != 0 {
os.Exit(3)
}
}

func shouldCheck(filename string) bool {
ext := strings.ToLower(path.Ext(filename))
switch ext {
case ".zip":
if !*includeZip {
return false
}
return true
case ".jar", ".war", ".ear":
return true
}

return false
}

// checkJar checks a given jar file and returns a status and description for whether
Expand Down Expand Up @@ -195,8 +231,7 @@ func checkJar(pathToFile string, rd io.ReaderAt, size int64, depth int) (status
}

// If there is a jar in the jar, recurse into it.
ext := strings.ToLower(path.Ext(file.Name))
if ext == ".jar" || ext == ".war" || ext == ".ear" || ext == ".zip" {
if shouldCheck(file.Name) {
var subStatus Status
var subDesc string
// If the jar is larger than 500 MB, this can be dangerous
Expand Down Expand Up @@ -287,8 +322,6 @@ const (
StatusVulnerable
)

var found = 0

// printStatus takes in the path to the file, status and description, and
// prints the result out to stdout.
func printStatus(fileName string, status Status, desc string) {
Expand All @@ -299,7 +332,6 @@ func printStatus(fileName string, status Status, desc string) {
if *mode == "list" {
if status == StatusVulnerable || status == StatusMaybe {
fmt.Println(fileName)
found = 3
}

return
Expand All @@ -314,15 +346,12 @@ func printStatus(fileName string, status Status, desc string) {
case StatusPatched:
c = color.New(color.FgGreen)
c.Print("PATCHED ")
found = 3
case StatusVulnerable:
c = color.New(color.FgRed)
c.Print("VULNRBL ")
found = 3
case StatusMaybe:
c = color.New(color.FgRed)
c.Print("MAYBE ")
found = 3
case StatusUnknown:
c = color.New(color.FgYellow)
c.Print("UNKNOWN ")
Expand Down

0 comments on commit 3b812e6

Please sign in to comment.