diff options
-rw-r--r-- | api/dns.go | 64 | ||||
-rw-r--r-- | api/main.go | 15 | ||||
-rw-r--r-- | cmd/main.go | 2 | ||||
-rw-r--r-- | db/query.sql | 10 | ||||
-rw-r--r-- | db/query.sql.go | 16 | ||||
-rw-r--r-- | dns/main.go | 29 |
6 files changed, 121 insertions, 15 deletions
diff --git a/api/dns.go b/api/dns.go new file mode 100644 index 0000000..f8d345c --- /dev/null +++ b/api/dns.go @@ -0,0 +1,64 @@ +package api + +import ( + "encoding/json" + "log" + "net" + + "sinanmohd.com/redq/dns" +) + +type DnsResp map[string]string + +func handleDnsBlock(conn net.Conn, d *dns.Dns, domains []string) { + resp := make(DnsResp) + + for _, domain := range domains { + err := d.Block(domain) + if err != nil { + resp[domain] = err.Error() + } else { + resp[domain] = "blocked" + } + } + + buf, err := json.Marshal(resp) + if err != nil { + log.Printf("marshaling json: %s", err) + return + } + + conn.Write(buf) +} + +func handleDnsUnblock(conn net.Conn, d *dns.Dns, domains []string) { + resp := make(DnsResp) + + for _, domain := range domains { + err := d.Unblock(domain) + if err != nil { + resp[domain] = err.Error() + } else { + resp[domain] = "unblocked" + } + } + + buf, err := json.Marshal(resp) + if err != nil { + log.Printf("marshaling json: %s", err) + return + } + + conn.Write(buf) +} + +func handleDns(conn net.Conn, d *dns.Dns, domains []string, action string) { + switch action { + case "block": + handleDnsBlock(conn, d, domains) + case "unblock": + handleDnsUnblock(conn, d, domains) + default: + log.Printf("handling dns: invalid action '%s'", action) + } +} diff --git a/api/main.go b/api/main.go index 32f7d08..eb8715f 100644 --- a/api/main.go +++ b/api/main.go @@ -7,6 +7,7 @@ import ( "net" "sinanmohd.com/redq/db" + "sinanmohd.com/redq/dns" "sinanmohd.com/redq/usage" ) @@ -16,9 +17,9 @@ const ( ) type ApiReq struct { - Type string `json:"type"` - Action string `json:"action"` - Arg string `json:"arg"` + Type string `json:"type"` + Action string `json:"action"` + Arg []string `json:"arg"` } type Api struct { @@ -42,7 +43,7 @@ func New() (*Api, error) { return &a, nil } -func (a *Api) Run(u *usage.Usage, queries *db.Queries, ctxDb context.Context) { +func (a *Api) Run(u *usage.Usage, d *dns.Dns, queries *db.Queries, ctxDb context.Context) { for { conn, err := a.sock.Accept() if err != nil { @@ -50,11 +51,11 @@ func (a *Api) Run(u *usage.Usage, queries *db.Queries, ctxDb context.Context) { continue } - go handleConn(conn, u, queries, ctxDb) + go handleConn(conn, u, d, queries, ctxDb) } } -func handleConn(conn net.Conn, u *usage.Usage, queries *db.Queries, ctxDb context.Context) { +func handleConn(conn net.Conn, u *usage.Usage, d *dns.Dns, queries *db.Queries, ctxDb context.Context) { defer conn.Close() var req ApiReq buf := make([]byte, bufSize) @@ -76,6 +77,8 @@ func handleConn(conn net.Conn, u *usage.Usage, queries *db.Queries, ctxDb contex handleBandwidth(conn, u) case "usage": handleUsage(conn, u, queries, ctxDb) + case "dns": + handleDns(conn, d, req.Arg, req.Action) default: log.Printf("invalid request type: %s", req.Type) } diff --git a/cmd/main.go b/cmd/main.go index 540f9da..db9cc13 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -54,5 +54,5 @@ func main() { go u.Run(iface, queries, ctx) go d.Run() - a.Run(u, queries, ctx) + a.Run(u, d, queries, ctx) } diff --git a/db/query.sql b/db/query.sql index 9ba172f..f4b17dd 100644 --- a/db/query.sql +++ b/db/query.sql @@ -6,8 +6,7 @@ INSERT INTO Usage ( ); -- name: GetUsage :one -SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress -FROM Usage; +SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress FROM Usage; -- name: EnterDnsBlackList :exec INSERT INTO DnsBlackList ( @@ -16,6 +15,9 @@ INSERT INTO DnsBlackList ( $1 ); +-- name: DeleteDnsBlackList :exec +DELETE FROM DnsBlackList +WHERE Name = $1; + -- name: GetDnsBlackList :many -SELECT * -FROM DnsBlackList; +SELECT * FROM DnsBlackList; diff --git a/db/query.sql.go b/db/query.sql.go index 80c6469..7c28723 100644 --- a/db/query.sql.go +++ b/db/query.sql.go @@ -11,6 +11,16 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const deleteDnsBlackList = `-- name: DeleteDnsBlackList :exec +DELETE FROM DnsBlackList +WHERE Name = $1 +` + +func (q *Queries) DeleteDnsBlackList(ctx context.Context, name string) error { + _, err := q.db.Exec(ctx, deleteDnsBlackList, name) + return err +} + const enterDnsBlackList = `-- name: EnterDnsBlackList :exec INSERT INTO DnsBlackList ( Name @@ -52,8 +62,7 @@ func (q *Queries) EnterUsage(ctx context.Context, arg EnterUsageParams) error { } const getDnsBlackList = `-- name: GetDnsBlackList :many -SELECT name -FROM DnsBlackList +SELECT name FROM DnsBlackList ` func (q *Queries) GetDnsBlackList(ctx context.Context) ([]string, error) { @@ -77,8 +86,7 @@ func (q *Queries) GetDnsBlackList(ctx context.Context) ([]string, error) { } const getUsage = `-- name: GetUsage :one -SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress -FROM Usage +SELECT SUM(Ingress) AS Ingress, SUM(Egress) AS Egress FROM Usage ` type GetUsageRow struct { diff --git a/dns/main.go b/dns/main.go index a9be24e..3daa879 100644 --- a/dns/main.go +++ b/dns/main.go @@ -37,6 +37,7 @@ func (d *Dns) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { resp = new(dns.Msg) resp.SetReply(req) w.WriteMsg(resp) + d.blackList.mutex.RUnlock() return } d.blackList.mutex.RUnlock() @@ -92,3 +93,31 @@ func New(queries *db.Queries, ctxDb context.Context) (*Dns, error) { func (d *Dns) Run() { d.server.ListenAndServe() } + +func (d *Dns) Block(domain string) error { + err := d.queries.EnterDnsBlackList(d.ctxDb, domain) + if err != nil { + log.Printf("adding dns blacklist entry: %s", err) + return err + } + + d.blackList.mutex.Lock() + d.blackList.data[domain] = true + d.blackList.mutex.Unlock() + + return nil +} + +func (d *Dns) Unblock(domain string) error { + err := d.queries.DeleteDnsBlackList(d.ctxDb, domain) + if err != nil { + log.Printf("deleting dns blacklist entry: %s", err) + return err + } + + d.blackList.mutex.Lock() + delete(d.blackList.data, domain) + d.blackList.mutex.Unlock() + + return nil +} |