summaryrefslogtreecommitdiff
path: root/dns
diff options
context:
space:
mode:
authorsinanmohd <sinan@sinanmohd.com>2024-07-07 15:04:52 +0530
committersinanmohd <sinan@sinanmohd.com>2024-07-07 15:07:34 +0530
commit60aa1b7adbd133a5b1679ab2feac4ec17e3b48b5 (patch)
tree082f64ac3ef625fb364227c138a99ba5c571016b /dns
parent38bf443cb2f52bd4cc822e69f9339094d846396f (diff)
dns: block blacklisted domains in the database
Diffstat (limited to 'dns')
-rw-r--r--dns/main.go56
1 files changed, 48 insertions, 8 deletions
diff --git a/dns/main.go b/dns/main.go
index b3adfca..a9be24e 100644
--- a/dns/main.go
+++ b/dns/main.go
@@ -1,43 +1,71 @@
package dns
import (
+ "context"
"log"
"net"
+ "sync"
"github.com/miekg/dns"
+ "sinanmohd.com/redq/db"
)
+type DnsBlackList struct {
+ data map[string]bool
+ mutex sync.RWMutex
+}
+
type Dns struct {
- server dns.Server
- config *dns.ClientConfig
+ server dns.Server
+ config *dns.ClientConfig
+ queries *db.Queries
+ ctxDb context.Context
+ blackList DnsBlackList
}
func (d *Dns) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
var resp *dns.Msg
var err error
- client := new(dns.Client)
- req.RecursionDesired = true;
+ d.blackList.mutex.RLock()
+ for _, qustion := range req.Question {
+ _, ok := d.blackList.data[qustion.Name]
+ if ok == false {
+ continue
+ }
+
+ resp = new(dns.Msg)
+ resp.SetReply(req)
+ w.WriteMsg(resp)
+ return
+ }
+ d.blackList.mutex.RUnlock()
+
+ client := new(dns.Client)
+ req.RecursionDesired = true
for _, upstream := range d.config.Servers {
resp, _, err = client.Exchange(req, net.JoinHostPort(upstream, d.config.Port))
if err == nil {
break
}
- log.Printf("dns query: %s", err)
+ log.Printf("dns resolving: %s", err)
+ }
+ if err != nil {
+ return
}
w.WriteMsg(resp)
}
-func New() (*Dns, error) {
+func New(queries *db.Queries, ctxDb context.Context) (*Dns, error) {
var d Dns
var err error
d.server = dns.Server{
- Net: "udp",
+ Net: "udp",
ReusePort: true,
- Handler: &d,
+ Handler: &d,
}
d.config, err = dns.ClientConfigFromFile("/etc/resolv.conf")
@@ -46,6 +74,18 @@ func New() (*Dns, error) {
return nil, err
}
+ d.queries = queries
+ d.ctxDb = ctxDb
+ d.blackList.data = make(map[string]bool)
+ blackList, err := d.queries.GetDnsBlackList(d.ctxDb)
+ if err != nil {
+ log.Printf("reading dns blacklist database: %s", err)
+ return nil, err
+ }
+ for _, entry := range blackList {
+ d.blackList.data[entry] = true
+ }
+
return &d, nil
}