diff --git a/hakrawler.go b/hakrawler.go index 1da86ed..1630997 100644 --- a/hakrawler.go +++ b/hakrawler.go @@ -45,6 +45,7 @@ func main() { proxy := flag.String(("proxy"), "", "Proxy URL. E.g. -proxy http://127.0.0.1:8080") timeout := flag.Int("timeout", -1, "Maximum time to crawl each URL from stdin, in seconds.") disableRedirects := flag.Bool("dr", false, "Disable following HTTP redirects.") + matchApex := flag.Bool("match-apex", false, "Match domain apex.") flag.Parse() @@ -79,6 +80,17 @@ func main() { continue } + apexDomain, err := getApexDomain(hostname) + if err != nil { + log.Println("Error getting apex domain:", err) + continue + } + + match := apexDomain + if !*matchApex { + match = "" + } + allowed_domains := []string{hostname} // if "Host" header is set, append it to allowed domains if headers != nil { @@ -92,7 +104,7 @@ func main() { // default user agent header colly.UserAgent("Mozilla/5.0 (X11; Linux x86_64; rv:78.0) Gecko/20100101 Firefox/78.0"), // set custom headers - colly.Headers(headers), + colly.Headers(headers) // limit crawling to the domain of the specified URL colly.AllowedDomains(allowed_domains...), // set MaxDepth to the specified depth @@ -127,19 +139,19 @@ func main() { abs_link := e.Request.AbsoluteURL(link) if strings.Contains(abs_link, url) || !*inside { - printResult(link, "href", *showSource, *showWhere, *showJson, results, e) + printResult(link, "href", *showSource, *showWhere, *showJson, results, e, match) e.Request.Visit(link) } }) // find and print all the JavaScript files c.OnHTML("script[src]", func(e *colly.HTMLElement) { - printResult(e.Attr("src"), "script", *showSource, *showWhere, *showJson, results, e) + printResult(e.Attr("src"), "script", *showSource, *showWhere, *showJson, results, e, match) }) // find and print all the form action URLs c.OnHTML("form[action]", func(e *colly.HTMLElement) { - printResult(e.Attr("action"), "form", *showSource, *showWhere, *showJson, results, e) + printResult(e.Attr("action"), "form", *showSource, *showWhere, *showJson, results, e, match) }) // add the custom headers @@ -248,10 +260,19 @@ func extractHostname(urlString string) (string, error) { } // print result constructs output lines and sends them to the results chan -func printResult(link string, sourceName string, showSource bool, showWhere bool, showJson bool, results chan string, e *colly.HTMLElement) { +func printResult(link string, sourceName string, showSource bool, showWhere bool, showJson bool, results chan string, e *colly.HTMLElement, match string) { result := e.Request.AbsoluteURL(link) whereURL := e.Request.URL.String() if result != "" { + parsedUrl, err := url.Parse(result) + if err != nil { + log.Println("Error parsing URL:", err) + return + } + if match != "" && !strings.HasSuffix(parsedUrl.Hostname(), match) { + return + } + if showJson { where := "" if showWhere { @@ -290,3 +311,13 @@ func isUnique(url string) bool { sm.Store(url, true) return true } + +// getApexDomain returns the apex domain of a hostname +func getApexDomain(hostname string) (string, error) { + parts := strings.Split(hostname, ".") + if len(parts) < 2 { + return "", errors.New("Invalid hostname") + } + + return parts[len(parts)-2] + "." + parts[len(parts)-1], nil +}