diff options
-rw-r--r-- | cmd/main.go | 4 | ||||
-rw-r--r-- | db/models.go | 4 | ||||
-rw-r--r-- | db/query.sql | 14 | ||||
-rw-r--r-- | db/query.sql.go | 41 | ||||
-rw-r--r-- | db/schema.sql | 6 | ||||
-rw-r--r-- | dns/main.go | 56 |
6 files changed, 112 insertions, 13 deletions
diff --git a/cmd/main.go b/cmd/main.go index 6597749..540f9da 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -11,8 +11,8 @@ import ( "github.com/jackc/pgx/v5" "sinanmohd.com/redq/api" "sinanmohd.com/redq/db" - "sinanmohd.com/redq/usage" "sinanmohd.com/redq/dns" + "sinanmohd.com/redq/usage" ) func main() { @@ -29,7 +29,7 @@ func main() { defer conn.Close(ctx) queries := db.New(conn) - d, err := dns.New() + d, err := dns.New(queries, ctx) if err != nil { os.Exit(0) } diff --git a/db/models.go b/db/models.go index a6da6bf..1ea2797 100644 --- a/db/models.go +++ b/db/models.go @@ -8,6 +8,10 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +type Dnsblacklist struct { + Name string +} + type Usage struct { Hardwareaddr int64 Starttime pgtype.Timestamp diff --git a/db/query.sql b/db/query.sql index cfea3f1..9ba172f 100644 --- a/db/query.sql +++ b/db/query.sql @@ -6,4 +6,16 @@ INSERT INTO Usage ( ); -- name: GetUsage :one -SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress FROM Usage; +SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress +FROM Usage; + +-- name: EnterDnsBlackList :exec +INSERT INTO DnsBlackList ( + Name +) VALUES ( + $1 +); + +-- name: GetDnsBlackList :many +SELECT * +FROM DnsBlackList; diff --git a/db/query.sql.go b/db/query.sql.go index d6304a7..80c6469 100644 --- a/db/query.sql.go +++ b/db/query.sql.go @@ -11,6 +11,19 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const enterDnsBlackList = `-- name: EnterDnsBlackList :exec +INSERT INTO DnsBlackList ( + Name +) VALUES ( + $1 +) +` + +func (q *Queries) EnterDnsBlackList(ctx context.Context, name string) error { + _, err := q.db.Exec(ctx, enterDnsBlackList, name) + return err +} + const enterUsage = `-- name: EnterUsage :exec INSERT INTO Usage ( HardwareAddr, StartTime, StopTime, Egress, Ingress @@ -38,8 +51,34 @@ func (q *Queries) EnterUsage(ctx context.Context, arg EnterUsageParams) error { return err } +const getDnsBlackList = `-- name: GetDnsBlackList :many +SELECT name +FROM DnsBlackList +` + +func (q *Queries) GetDnsBlackList(ctx context.Context) ([]string, error) { + rows, err := q.db.Query(ctx, getDnsBlackList) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getUsage = `-- name: GetUsage :one -SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress FROM Usage +SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress +FROM Usage ` type GetUsageRow struct { diff --git a/db/schema.sql b/db/schema.sql index aece061..c8c35a1 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -1,7 +1,11 @@ -CREATE TABLE Usage ( +CREATE TABLE IF NOT EXISTS Usage ( HardwareAddr BIGINT NOT NULL, StartTime TIMESTAMP NOT NULL, StopTime TIMESTAMP NOT NULL, Egress BIGINT NOT NULL, Ingress BIGINT NOT NULL ); + +CREATE TABLE IF NOT EXISTS DnsBlackList ( + Name TEXT NOT NULL UNIQUE +); diff --git a/dns/main.go b/dns/main.go index b3adfca..a9be24e 100644 --- a/dns/main.go +++ b/dns/main.go @@ -1,43 +1,71 @@ package dns import ( + "context" "log" "net" + "sync" "github.com/miekg/dns" + "sinanmohd.com/redq/db" ) +type DnsBlackList struct { + data map[string]bool + mutex sync.RWMutex +} + type Dns struct { - server dns.Server - config *dns.ClientConfig + server dns.Server + config *dns.ClientConfig + queries *db.Queries + ctxDb context.Context + blackList DnsBlackList } func (d *Dns) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { var resp *dns.Msg var err error - client := new(dns.Client) - req.RecursionDesired = true; + d.blackList.mutex.RLock() + for _, qustion := range req.Question { + _, ok := d.blackList.data[qustion.Name] + if ok == false { + continue + } + + resp = new(dns.Msg) + resp.SetReply(req) + w.WriteMsg(resp) + return + } + d.blackList.mutex.RUnlock() + + client := new(dns.Client) + req.RecursionDesired = true for _, upstream := range d.config.Servers { resp, _, err = client.Exchange(req, net.JoinHostPort(upstream, d.config.Port)) if err == nil { break } - log.Printf("dns query: %s", err) + log.Printf("dns resolving: %s", err) + } + if err != nil { + return } w.WriteMsg(resp) } -func New() (*Dns, error) { +func New(queries *db.Queries, ctxDb context.Context) (*Dns, error) { var d Dns var err error d.server = dns.Server{ - Net: "udp", + Net: "udp", ReusePort: true, - Handler: &d, + Handler: &d, } d.config, err = dns.ClientConfigFromFile("/etc/resolv.conf") @@ -46,6 +74,18 @@ func New() (*Dns, error) { return nil, err } + d.queries = queries + d.ctxDb = ctxDb + d.blackList.data = make(map[string]bool) + blackList, err := d.queries.GetDnsBlackList(d.ctxDb) + if err != nil { + log.Printf("reading dns blacklist database: %s", err) + return nil, err + } + for _, entry := range blackList { + d.blackList.data[entry] = true + } + return &d, nil } |