summaryrefslogtreecommitdiff
path: root/db
diff options
context:
space:
mode:
Diffstat (limited to 'db')
-rw-r--r--db/account.go112
-rw-r--r--db/bearer.go127
-rw-r--r--db/main.go74
-rw-r--r--db/utils.go24
4 files changed, 337 insertions, 0 deletions
diff --git a/db/account.go b/db/account.go
new file mode 100644
index 0000000..2c99045
--- /dev/null
+++ b/db/account.go
@@ -0,0 +1,112 @@
+package db
+
+import "errors"
+
+type Account struct {
+ Email string
+ PassHash string
+
+ Info *Login
+}
+
+type Login struct {
+ id uint
+ Level uint
+ FirstName, LastName string
+ Bearer *Bearer
+}
+
+func (ac *Account) CreateAccount(safe *SafeDB) error {
+ const sqlStatement string = `
+ INSERT INTO Accounts (
+ id,
+ Email,
+ PassHash,
+ Level,
+ FirstName,
+ LastName
+ )
+ VALUES (NULL, ?, ?, ?, ?, ?);
+ `
+
+ safe.mu.Lock()
+ defer safe.mu.Unlock()
+
+ _, err := safe.db.Exec(
+ sqlStatement,
+ ac.Email,
+ ToBlake3(ac.PassHash),
+
+ ac.Info.FirstName,
+ ac.Info.LastName,
+ ac.Info.Level,
+ )
+
+ return err
+}
+
+func (ac *Account) Login(safe *SafeDB) error {
+ const sqlStatementQuery string = `
+ SELECT id, PassHash, Level, FirstName, LastName
+ FROM Accounts
+ WHERE Accounts.Email = ?
+ `
+
+ ac.Info = &Login{}
+ ac.Info.Bearer = &Bearer{}
+ safe.mu.Lock()
+ row := safe.db.QueryRow(sqlStatementQuery, ac.Email)
+ safe.mu.Unlock()
+
+ var PassHash string
+ err := row.Scan(
+ &ac.Info.id,
+ &PassHash,
+ &ac.Info.FirstName,
+ &ac.Info.LastName,
+ &ac.Info.Level,
+ )
+ if err != nil {
+ return err
+ }
+ if PassHash != ac.PassHash {
+ return errors.New("Auth failed")
+ }
+
+ err = ac.Info.Bearer.Generate(safe, ac.Info)
+ if err != nil {
+ return err
+ }
+
+ return err
+}
+
+func (ac *Account) fromBearer(safe *SafeDB, b *Bearer) error {
+ const sqlStatementAccount string = `
+ SELECT Email, PassHash, Level, FirstName, LastName
+ FROM Accounts
+ WHERE Accounts.id = ?
+ `
+
+ safe.mu.Lock()
+ row := safe.db.QueryRow(sqlStatementAccount, b.accountId)
+ safe.mu.Unlock()
+
+ ac.Info = &Login{}
+ ac.Info.id = b.accountId
+ ac.Info.Bearer = b
+ err := row.Scan(
+ &ac.Email,
+ &ac.PassHash,
+
+ &ac.Info.FirstName,
+ &ac.Info.LastName,
+ &ac.Info.Level,
+ )
+ if err != nil {
+ return err
+ }
+ ac.Info.Bearer = b
+
+ return err
+}
diff --git a/db/bearer.go b/db/bearer.go
new file mode 100644
index 0000000..b16d506
--- /dev/null
+++ b/db/bearer.go
@@ -0,0 +1,127 @@
+package db
+
+import (
+ "errors"
+ "time"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+type Bearer struct {
+ id, accountId uint
+ Token string
+ ValidUpTo time.Time
+}
+
+func (b *Bearer) FromToken(safe *SafeDB, Token string) error {
+ const sqlStatementBearer string = `
+ SELECT id, ValidUpTo, accountId
+ FROM Bearer
+ WHERE Bearer.Token = ?
+ `
+
+ b.Token = Token
+ var ValidUpToString string
+ safe.mu.Lock()
+ row := safe.db.QueryRow(sqlStatementBearer, Token)
+ safe.mu.Unlock()
+
+ err := row.Scan(
+ &b.id,
+ &ValidUpToString,
+ &b.accountId,
+ )
+ if err != nil {
+ return err
+ }
+
+ layout := "2006-01-02 15:04:05.999999999-07:00"
+ b.ValidUpTo, err = time.Parse(layout, ValidUpToString)
+ if err != nil {
+ return err
+ }
+
+ timeNow := time.Now()
+ if timeNow.After(b.ValidUpTo) {
+ return errors.New("Outdated Bearer Token")
+ }
+
+ return err
+}
+
+func (b *Bearer) Update(safe *SafeDB) error {
+ const sqlStatementBearer string = `
+ UPDATE Bearer
+ SET ValidUpTo = ?
+ WHERE id = ?
+ `
+
+ validUpTo := time.Now().Add(time.Hour * 24)
+ safe.mu.Lock()
+ _, err := safe.db.Exec(sqlStatementBearer, validUpTo, b.id)
+ safe.mu.Unlock()
+ if err != nil {
+ return err
+ }
+ b.ValidUpTo = validUpTo
+
+ return nil
+}
+
+func (b *Bearer) VerifyAndUpdate(safe *SafeDB, token string) error {
+ err := b.FromToken(safe, token)
+ if err != nil {
+ return err
+ }
+
+ err = b.Update(safe)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (b *Bearer) Generate(safe *SafeDB, lg *Login) error {
+ const sqlGenBearer string = `
+ INSERT INTO Bearer (
+ id,
+ Token,
+ ValidUpTo,
+ accountId
+ )
+ VALUES (NULL, ?, ?, ?);
+ `
+
+ Token, err := GenRandomString(128)
+ if err != nil {
+ return err
+ }
+
+ timeNow := time.Now()
+ ValidUpTo := timeNow.Add(time.Hour * 24)
+ safe.mu.Lock()
+ res, err := safe.db.Exec(
+ sqlGenBearer,
+ Token,
+ ValidUpTo,
+ lg.id,
+ )
+ safe.mu.Unlock()
+ if err != nil {
+ return err
+ }
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return err
+ }
+
+ b.id = uint(id)
+ b.accountId = lg.id
+ b.Token = Token
+ b.ValidUpTo = ValidUpTo
+ lg.Bearer = b
+
+ return err
+}
diff --git a/db/main.go b/db/main.go
new file mode 100644
index 0000000..c78ea3a
--- /dev/null
+++ b/db/main.go
@@ -0,0 +1,74 @@
+package db
+
+import (
+ "database/sql"
+ "os"
+ "path/filepath"
+ "sync"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+type SafeDB struct {
+ mu sync.Mutex
+
+ path string
+ db *sql.DB
+}
+
+func (safe *SafeDB) setupPath() error {
+ const path string = "/var/lib/redq/"
+ const name string = "redq.sqlite3"
+
+ err := os.MkdirAll(path, os.ModeDir)
+ if err != nil {
+ return err
+ }
+
+ safe.path = filepath.Join(path, name)
+ return nil
+}
+
+func NewSafeDB() (*SafeDB, error) {
+ const create string = `
+ CREATE TABLE IF NOT EXISTS Accounts(
+ id INTEGER PRIMARY KEY,
+ Email CHAR(64) NOT NULL UNIQUE,
+ PassHash CHAR(128) NOT NULL,
+
+ Level INTEGER NOT NULL,
+ FirstName CHAR(32) NOT NULL,
+ LastName CHAR(32) NOT NULL
+ );
+
+ CREATE TABLE IF NOT EXISTS Bearer(
+ id INTEGER PRIMARY KEY,
+ Token CHAR(128) NOT NULL UNIQUE,
+ ValidUpTo TIME NOT NULL,
+ accountId INTEGER NOT NULL,
+
+ FOREIGN KEY (accountId)
+ REFERENCES Accounts (id)
+ );
+ `
+ safe := &SafeDB{}
+ err := safe.setupPath()
+ if err != nil {
+ return nil, err
+ }
+
+ safe.mu.Lock()
+ defer safe.mu.Unlock()
+
+ safe.db, err = sql.Open("sqlite3", safe.path)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = safe.db.Exec(create)
+ if err != nil {
+ return nil, err
+ }
+
+ return safe, nil
+}
diff --git a/db/utils.go b/db/utils.go
new file mode 100644
index 0000000..0b0f1cb
--- /dev/null
+++ b/db/utils.go
@@ -0,0 +1,24 @@
+package db
+
+import (
+ "encoding/base64"
+ "lukechampine.com/blake3"
+ "math/rand"
+)
+
+func ToBlake3(pass string) string {
+ hash := blake3.Sum512([]byte(pass))
+ hash64b := base64.StdEncoding.EncodeToString(hash[:])
+
+ return "blake3-" + hash64b
+}
+
+func GenRandomString(n int) (string, error) {
+ b := make([]byte, n)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+
+ return base64.URLEncoding.EncodeToString(b)[:n], nil
+}