diff options
author | sinanmohd <sinan@sinanmohd.com> | 2024-07-07 15:04:52 +0530 |
---|---|---|
committer | sinanmohd <sinan@sinanmohd.com> | 2024-07-07 15:07:34 +0530 |
commit | 60aa1b7adbd133a5b1679ab2feac4ec17e3b48b5 (patch) | |
tree | 082f64ac3ef625fb364227c138a99ba5c571016b /dns | |
parent | 38bf443cb2f52bd4cc822e69f9339094d846396f (diff) |
dns: block blacklisted domains in the database
Diffstat (limited to 'dns')
-rw-r--r-- | dns/main.go | 56 |
1 files changed, 48 insertions, 8 deletions
diff --git a/dns/main.go b/dns/main.go index b3adfca..a9be24e 100644 --- a/dns/main.go +++ b/dns/main.go @@ -1,43 +1,71 @@ package dns import ( + "context" "log" "net" + "sync" "github.com/miekg/dns" + "sinanmohd.com/redq/db" ) +type DnsBlackList struct { + data map[string]bool + mutex sync.RWMutex +} + type Dns struct { - server dns.Server - config *dns.ClientConfig + server dns.Server + config *dns.ClientConfig + queries *db.Queries + ctxDb context.Context + blackList DnsBlackList } func (d *Dns) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { var resp *dns.Msg var err error - client := new(dns.Client) - req.RecursionDesired = true; + d.blackList.mutex.RLock() + for _, qustion := range req.Question { + _, ok := d.blackList.data[qustion.Name] + if ok == false { + continue + } + + resp = new(dns.Msg) + resp.SetReply(req) + w.WriteMsg(resp) + return + } + d.blackList.mutex.RUnlock() + + client := new(dns.Client) + req.RecursionDesired = true for _, upstream := range d.config.Servers { resp, _, err = client.Exchange(req, net.JoinHostPort(upstream, d.config.Port)) if err == nil { break } - log.Printf("dns query: %s", err) + log.Printf("dns resolving: %s", err) + } + if err != nil { + return } w.WriteMsg(resp) } -func New() (*Dns, error) { +func New(queries *db.Queries, ctxDb context.Context) (*Dns, error) { var d Dns var err error d.server = dns.Server{ - Net: "udp", + Net: "udp", ReusePort: true, - Handler: &d, + Handler: &d, } d.config, err = dns.ClientConfigFromFile("/etc/resolv.conf") @@ -46,6 +74,18 @@ func New() (*Dns, error) { return nil, err } + d.queries = queries + d.ctxDb = ctxDb + d.blackList.data = make(map[string]bool) + blackList, err := d.queries.GetDnsBlackList(d.ctxDb) + if err != nil { + log.Printf("reading dns blacklist database: %s", err) + return nil, err + } + for _, entry := range blackList { + d.blackList.data[entry] = true + } + return &d, nil } |