Initial perseus rewrite
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
7b95d6b80d
commit
a6826055bf
|
@ -0,0 +1,56 @@
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/password"
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addToken will generate a random token, add it to database and
|
||||||
|
// return the token.
|
||||||
|
func (u *User) addToken(db *storage.DB) error {
|
||||||
|
u.Token = password.RandStr(64)
|
||||||
|
|
||||||
|
// Set user id from username.
|
||||||
|
err := u.GetID(db)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/addtoken.go: %s\n",
|
||||||
|
"failed to get id from username")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire write lock on the database.
|
||||||
|
db.Mu.Lock()
|
||||||
|
defer db.Mu.Unlock()
|
||||||
|
|
||||||
|
// Start the transaction
|
||||||
|
tx, err := db.Conn.Begin()
|
||||||
|
defer tx.Rollback()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/addtoken.go: %s\n",
|
||||||
|
"failed to begin transaction")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := db.Conn.Prepare(`
|
||||||
|
INSERT INTO access(id, token, genTime) values(?, ?, ?)`)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/addtoken.go: %s\n",
|
||||||
|
"failed to prepare statement")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
_, err = stmt.Exec(u.ID, u.Token, time.Now().UTC())
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/addtoken.go: %s\n",
|
||||||
|
"failed to execute statement")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
return err
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addUser adds the user to record.
|
||||||
|
func (u *User) addUser(db *storage.DB) error {
|
||||||
|
// Acquire write lock on the database.
|
||||||
|
db.Mu.Lock()
|
||||||
|
defer db.Mu.Unlock()
|
||||||
|
|
||||||
|
// Start the transaction
|
||||||
|
tx, err := db.Conn.Begin()
|
||||||
|
defer tx.Rollback()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/adduser.go: %s\n",
|
||||||
|
"failed to begin transaction")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := db.Conn.Prepare(`
|
||||||
|
INSERT INTO accounts(id, username, hash, regTime) values(?, ?, ?, ?)`)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/adduser.go: %s\n",
|
||||||
|
"failed to prepare statement")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
_, err = stmt.Exec(u.ID, u.Username, u.Hash, time.Now().UTC())
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/adduser.go: %s\n",
|
||||||
|
"failed to execute statement")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetID returns id from username.
|
||||||
|
func (u *User) GetID(db *storage.DB) error {
|
||||||
|
// Acquire read lock on database.
|
||||||
|
db.Mu.RLock()
|
||||||
|
defer db.Mu.RUnlock()
|
||||||
|
|
||||||
|
// Get password for this user from the database.
|
||||||
|
stmt, err := db.Conn.Prepare("SELECT id FROM accounts WHERE username = ?")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/getid.go: %s\n",
|
||||||
|
"failed to prepare statement")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
var id string
|
||||||
|
err = stmt.QueryRow(u.Username).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/getid.go: %s\n",
|
||||||
|
"query failed")
|
||||||
|
}
|
||||||
|
u.ID = id
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/password"
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Login takes in login details and returns an error. If error doesn't
|
||||||
|
// equal nil then consider login failed. It will also set the u.Token
|
||||||
|
// field.
|
||||||
|
func (u *User) Login(db *storage.DB) error {
|
||||||
|
// Acquire read lock on the database.
|
||||||
|
db.Mu.RLock()
|
||||||
|
|
||||||
|
// Get password for this user from the database.
|
||||||
|
stmt, err := db.Conn.Prepare("SELECT hash FROM accounts WHERE username = ?")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/login.go: %s\n",
|
||||||
|
"failed to prepare statement")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
var hash string
|
||||||
|
err = stmt.QueryRow(u.Username).Scan(&hash)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/login.go: %s\n",
|
||||||
|
"query failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
u.Hash = hash
|
||||||
|
|
||||||
|
// Check user's password.
|
||||||
|
err = password.Check(u.Password, u.Hash)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/login.go: %s%s\n",
|
||||||
|
"user login failed, username: ", u.Username)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
db.Mu.RUnlock()
|
||||||
|
|
||||||
|
err = u.addToken(db)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/login.go: %s\n",
|
||||||
|
"addtoken failed")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/password"
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Register takes in registration details and returns an error. If
|
||||||
|
// error doesn't equal nil then the registration was unsuccessful.
|
||||||
|
func (u User) Register(db *storage.DB) error {
|
||||||
|
var err error
|
||||||
|
u.ID = password.RandStr(64)
|
||||||
|
u.Username = strings.ToLower(u.Username)
|
||||||
|
|
||||||
|
// Validate username. It must be alphanumeric and less than
|
||||||
|
// 128 characters.
|
||||||
|
re := regexp.MustCompile("^[a-zA-Z0-9]*$")
|
||||||
|
if !re.MatchString(u.Username) {
|
||||||
|
return errors.New("account/register.go: invalid username")
|
||||||
|
}
|
||||||
|
if len(u.Username) > 128 {
|
||||||
|
return errors.New("account/register.go: username too long")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate password
|
||||||
|
if len(u.Password) < 8 {
|
||||||
|
return errors.New("account/register.go: password too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Hash, err = password.Hash(u.Password)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/register.go: %s\n",
|
||||||
|
"password.Hash func failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = u.addUser(db)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("account/register.go: %s\n",
|
||||||
|
"addUser func failed")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package user
|
package account
|
||||||
|
|
||||||
// User holds information about the user.
|
// User holds information about the user.
|
||||||
type User struct {
|
type User struct {
|
||||||
|
@ -6,4 +6,5 @@ type User struct {
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
Hash string
|
Hash string
|
||||||
|
Token string
|
||||||
}
|
}
|
|
@ -1,14 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
)
|
|
||||||
|
|
||||||
// genID generates a random id string of length n. Don't forget to
|
|
||||||
// seed the random number generator otherwise it won't be random.
|
|
||||||
func genID(n int) string {
|
|
||||||
b := make([]byte, n/2)
|
|
||||||
rand.Read(b)
|
|
||||||
return base64.StdEncoding.EncodeToString(b)
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// hashPass takes a string as input and returns the hash of the
|
|
||||||
// password.
|
|
||||||
func hashPass(password string) (string, error) {
|
|
||||||
// 10 is the default cost.
|
|
||||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), 10)
|
|
||||||
return string(bytes), err
|
|
||||||
}
|
|
|
@ -6,12 +6,12 @@ steps:
|
||||||
- name: vet
|
- name: vet
|
||||||
image: golang:1.13
|
image: golang:1.13
|
||||||
commands:
|
commands:
|
||||||
- go vet ./...
|
- go vet ./...
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
image: golang:1.13
|
image: golang:1.13
|
||||||
commands:
|
commands:
|
||||||
- go test -v ./auth
|
- go test -v ./password
|
||||||
|
|
||||||
---
|
---
|
||||||
kind: pipeline
|
kind: pipeline
|
||||||
|
@ -24,4 +24,4 @@ steps:
|
||||||
GOARCH: amd64
|
GOARCH: amd64
|
||||||
GOOS: openbsd
|
GOOS: openbsd
|
||||||
commands:
|
commands:
|
||||||
- go build ./cmd/perseus
|
- go build ./cmd/perseus
|
||||||
|
|
|
@ -15,25 +15,24 @@ func main() {
|
||||||
db := storage.Init()
|
db := storage.Init()
|
||||||
defer db.Conn.Close()
|
defer db.Conn.Close()
|
||||||
|
|
||||||
envPort, exists := os.LookupEnv("PERSEUS_PORT")
|
envPort := os.Getenv("PERSEUS_PORT")
|
||||||
if !exists {
|
if envPort == "" {
|
||||||
envPort = "8080"
|
envPort = "8080"
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf("127.0.0.1:%s", envPort)
|
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: addr,
|
Addr: fmt.Sprintf("127.0.0.1:%s", envPort),
|
||||||
WriteTimeout: 8 * time.Second,
|
WriteTimeout: 8 * time.Second,
|
||||||
ReadTimeout: 8 * time.Second,
|
ReadTimeout: 8 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
http.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) {
|
||||||
web.HandleRegister(w, r, db)
|
web.RegisterHandler(w, r, db)
|
||||||
})
|
})
|
||||||
http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
web.HandleLogin(w, r, db)
|
web.LoginHandler(w, r, db)
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Printf("main/main.go: listening on port %s...", envPort)
|
log.Printf("perseus: listening on port %s...", envPort)
|
||||||
log.Fatal(srv.ListenAndServe())
|
log.Fatal(srv.ListenAndServe())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
package core
|
|
||||||
|
|
||||||
// Version will return the current version.
|
|
||||||
func Version() string {
|
|
||||||
return "v0.1.0"
|
|
||||||
}
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/account"
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoginHandler handles login.
|
||||||
|
func LoginHandler(w http.ResponseWriter, r *http.Request, db *storage.DB) {
|
||||||
|
p := Page{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
t, err := template.ParseFiles("web/templates/login.html")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("web/login.go: 500 Internal Server Error :: %s", err.Error())
|
||||||
|
http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
t.Execute(w, p)
|
||||||
|
|
||||||
|
case http.MethodPost:
|
||||||
|
if err = r.ParseForm(); err != nil {
|
||||||
|
log.Printf("web/login.go: 400 Bad Request :: %s", err.Error())
|
||||||
|
http.Error(w, "400 Bad Request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get form values
|
||||||
|
u := account.User{}
|
||||||
|
u.Username = r.FormValue("username")
|
||||||
|
u.Password = r.FormValue("password")
|
||||||
|
|
||||||
|
// Perform login
|
||||||
|
err = u.Login(db)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("web/login.go: %s :: %s",
|
||||||
|
"login failed",
|
||||||
|
err.Error())
|
||||||
|
|
||||||
|
error := []string{}
|
||||||
|
error = append(error,
|
||||||
|
fmt.Sprintf("Login failed"))
|
||||||
|
|
||||||
|
p.Error = error
|
||||||
|
t.Execute(w, p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login successful, set token
|
||||||
|
cookie := http.Cookie{
|
||||||
|
Name: "token",
|
||||||
|
Value: u.Token,
|
||||||
|
// Expire the cookie after 16 days from
|
||||||
|
// current UTC time.
|
||||||
|
Expires: time.Now().UTC().Add(16 * 24 * time.Hour),
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
http.SetCookie(w, &cookie)
|
||||||
|
success := []string{}
|
||||||
|
success = append(success,
|
||||||
|
fmt.Sprintf("Login successful"))
|
||||||
|
p.Success = success
|
||||||
|
t.Execute(w, p)
|
||||||
|
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
log.Printf("web/login.go: %v not allowed on %v", r.Method, r.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,11 +1,8 @@
|
||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import "html/template"
|
||||||
"html/template"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Page holds page information that is sent to all webpages rendered
|
// Page holds page information.
|
||||||
// by perseus.
|
|
||||||
type Page struct {
|
type Page struct {
|
||||||
SafeList []template.HTML
|
SafeList []template.HTML
|
||||||
List []string
|
List []string
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"tildegit.org/andinus/perseus/account"
|
||||||
|
"tildegit.org/andinus/perseus/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterHandler handles registration.
|
||||||
|
func RegisterHandler(w http.ResponseWriter, r *http.Request, db *storage.DB) {
|
||||||
|
p := Page{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
t, err := template.ParseFiles("web/templates/register.html")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("web/register.go: 500 Internal Server Error :: %s", err.Error())
|
||||||
|
http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Notice = []string{
|
||||||
|
"Only [a-z] & [0-9] allowed for username",
|
||||||
|
"Password length must be greater than 8 characters",
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
t.Execute(w, p)
|
||||||
|
|
||||||
|
case http.MethodPost:
|
||||||
|
if err = r.ParseForm(); err != nil {
|
||||||
|
log.Printf("web/register.go: 400 Bad Request :: %s", err.Error())
|
||||||
|
http.Error(w, "400 Bad Request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get form values
|
||||||
|
u := account.User{}
|
||||||
|
u.Username = r.FormValue("username")
|
||||||
|
u.Password = r.FormValue("password")
|
||||||
|
|
||||||
|
// Perform registration
|
||||||
|
err = u.Register(db)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("web/register.go: %s :: %s",
|
||||||
|
"registration failed",
|
||||||
|
err.Error())
|
||||||
|
|
||||||
|
error := []string{}
|
||||||
|
error = append(error,
|
||||||
|
fmt.Sprintf("Registration failed"))
|
||||||
|
|
||||||
|
// Check if the error was because of username
|
||||||
|
// not being unique.
|
||||||
|
if strings.HasPrefix(err.Error(), "UNIQUE constraint failed") {
|
||||||
|
error = append(error,
|
||||||
|
fmt.Sprintf("Username not unique"))
|
||||||
|
}
|
||||||
|
p.Error = error
|
||||||
|
} else {
|
||||||
|
success := []string{}
|
||||||
|
success = append(success,
|
||||||
|
fmt.Sprintf("Registration successful"))
|
||||||
|
p.Success = success
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Execute(w, p)
|
||||||
|
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
log.Printf("web/register.go: %v not allowed on %v", r.Method, r.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,14 +1,13 @@
|
||||||
package auth
|
// Password package contains functions related to passwords.
|
||||||
|
package password
|
||||||
|
|
||||||
import (
|
import "golang.org/x/crypto/bcrypt"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// checkPass takes a string and hash as input and returns an error. If
|
// Check takes a string and hash as input and returns an error. If
|
||||||
// the error is not nil then the consider the password wrong. We're
|
// the error is not nil then the consider the password wrong. We're
|
||||||
// returning error instead of a bool so that we can print failed
|
// returning error instead of a bool so that we can print failed
|
||||||
// logins to log and logging shouldn't happen here.
|
// logins to log and logging shouldn't happen here.
|
||||||
func checkPass(password, hash string) error {
|
func Check(password, hash string) error {
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
|
@ -1,9 +1,9 @@
|
||||||
package auth
|
package password
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
// TestCheckPass tests the checkPass function.
|
// TestCheck tests the Check function.
|
||||||
func TestCheckPass(t *testing.T) {
|
func TestCheck(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
passhash := make(map[string]string)
|
passhash := make(map[string]string)
|
||||||
|
|
||||||
|
@ -13,24 +13,22 @@ func TestCheckPass(t *testing.T) {
|
||||||
passhash["Z1S/kQ=="] = "$2a$10$fZ05kKmb7bh4vBLebpK1u.3bUNQ6eeX5ghT/GZaekgS.5bx4.Ru1e"
|
passhash["Z1S/kQ=="] = "$2a$10$fZ05kKmb7bh4vBLebpK1u.3bUNQ6eeX5ghT/GZaekgS.5bx4.Ru1e"
|
||||||
passhash["J861dQ=="] = "$2a$10$nXb6Btn6n3AWMAUkDh9bFObvQw5V9FLKhfX.E1EzRWgVDuqIp99u2"
|
passhash["J861dQ=="] = "$2a$10$nXb6Btn6n3AWMAUkDh9bFObvQw5V9FLKhfX.E1EzRWgVDuqIp99u2"
|
||||||
|
|
||||||
// We also check with values generated with hashPass, this may
|
// We also check with values generated with Hash, this may
|
||||||
// fail if hashPass itself fails in that case it's not
|
// fail if Hash itself fails in that case it's not Check error
|
||||||
// checkPass error so the test shouldn't fail but warning
|
// so the test shouldn't fail but warning should be sent. We
|
||||||
// should be sent. We use genID func to generate random inputs
|
// use genID func to generate random inputs for this test.
|
||||||
// for this test.
|
|
||||||
for i := 1; i <= 4; i++ {
|
for i := 1; i <= 4; i++ {
|
||||||
p := genID(8)
|
p := RandStr(8)
|
||||||
passhash[p], err = hashPass(p)
|
passhash[p], err = Hash(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("hashPass func failed")
|
t.Log("hashPass func failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We test the checkPass func by ranging over all values of
|
// We test the Check func by ranging over all values of
|
||||||
// passhash. We assume that hashPass func returns correct
|
// passhash. We assume that Hash func returns correct hashes.
|
||||||
// hashes.
|
|
||||||
for p, h := range passhash {
|
for p, h := range passhash {
|
||||||
err = checkPass(p, h)
|
err = Check(p, h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("password: %s, hash: %s didn't match.",
|
t.Errorf("password: %s, hash: %s didn't match.",
|
||||||
p, h)
|
p, h)
|
|
@ -0,0 +1,11 @@
|
||||||
|
package password
|
||||||
|
|
||||||
|
import "golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
// Hash takes a string as input and returns the hash of the
|
||||||
|
// password.
|
||||||
|
func Hash(password string) (string, error) {
|
||||||
|
// 10 is the default cost.
|
||||||
|
out, err := bcrypt.GenerateFromPassword([]byte(password), 10)
|
||||||
|
return string(out), err
|
||||||
|
}
|
|
@ -1,21 +1,21 @@
|
||||||
package auth
|
package password
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
// TestHashPass tests the checkPass function.
|
// TestHash tests the Hash function.
|
||||||
func TestHashPass(t *testing.T) {
|
func TestHash(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
passhash := make(map[string]string)
|
passhash := make(map[string]string)
|
||||||
|
|
||||||
// We generate random hashes with hashPass, random string is
|
// We generate random hashes with Hash, random string is
|
||||||
// generate by genID func.
|
// generate by RandStr func.
|
||||||
for i := 1; i <= 8; i++ {
|
for i := 1; i <= 8; i++ {
|
||||||
p := genID(8)
|
p := RandStr(8)
|
||||||
passhash[p], err = hashPass(p)
|
passhash[p], err = Hash(p)
|
||||||
|
|
||||||
// Here we test if the hashPass func runs sucessfully.
|
// Here we test if the hashPass func runs sucessfully.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("hashPass func failed for password: %s",
|
t.Errorf("Hash func failed for password: %s",
|
||||||
p)
|
p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,7 @@ func TestHashPass(t *testing.T) {
|
||||||
// hashes. We assume that checkPass func returns correct
|
// hashes. We assume that checkPass func returns correct
|
||||||
// values.
|
// values.
|
||||||
for p, h := range passhash {
|
for p, h := range passhash {
|
||||||
err = checkPass(p, h)
|
err = Check(p, h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("password: %s, hash: %s didn't match.",
|
t.Errorf("password: %s, hash: %s didn't match.",
|
||||||
p, h)
|
p, h)
|
|
@ -0,0 +1,13 @@
|
||||||
|
package password
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandStr will return a random base64 encoded string of length n.
|
||||||
|
func RandStr(n int) string {
|
||||||
|
b := make([]byte, n/2)
|
||||||
|
rand.Read(b)
|
||||||
|
return base64.StdEncoding.EncodeToString(b)
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sqlite3
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
@ -17,8 +17,7 @@ func initErr(db *DB, err error) {
|
||||||
log.Fatalf("Initialization Error :: %s", err.Error())
|
log.Fatalf("Initialization Error :: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes a sqlite3 database.
|
func initDB(db *DB) {
|
||||||
func Init(db *DB) {
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// We set the database path, first the environment variable
|
// We set the database path, first the environment variable
|
||||||
|
@ -36,7 +35,7 @@ func Init(db *DB) {
|
||||||
db.Conn, err = sql.Open("sqlite3", db.Path)
|
db.Conn, err = sql.Open("sqlite3", db.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("sqlite3/init.go: %s\n",
|
log.Printf("sqlite3/init.go: %s\n",
|
||||||
"Failed to open database connection")
|
"failed to open database connection")
|
||||||
initErr(db, err)
|
initErr(db, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,11 +49,11 @@ func Init(db *DB) {
|
||||||
token TEXT NOT NULL,
|
token TEXT NOT NULL,
|
||||||
genTime TEXT NOT NULL);`,
|
genTime TEXT NOT NULL);`,
|
||||||
|
|
||||||
`CREATE TABLE IF NOT EXISTS users (
|
`CREATE TABLE IF NOT EXISTS accounts (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
type TEXT NOT NULL DEFAULT user,
|
type TEXT NOT NULL DEFAULT user,
|
||||||
username VARCHAR(128) NOT NULL UNIQUE,
|
username VARCHAR(128) NOT NULL UNIQUE,
|
||||||
password TEXT NOT NULL,
|
hash TEXT NOT NULL,
|
||||||
regTime TEXT NOT NULL);`,
|
regTime TEXT NOT NULL);`,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +66,7 @@ func Init(db *DB) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("sqlite3/init.go: %s\n",
|
log.Printf("sqlite3/init.go: %s\n",
|
||||||
"Failed to prepare statement")
|
"failed to prepare statement")
|
||||||
log.Println(s)
|
log.Println(s)
|
||||||
initErr(db, err)
|
initErr(db, err)
|
||||||
}
|
}
|
||||||
|
@ -76,7 +75,7 @@ func Init(db *DB) {
|
||||||
stmt.Close()
|
stmt.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("sqlite3/init.go: %s\n",
|
log.Printf("sqlite3/init.go: %s\n",
|
||||||
"Failed to execute statement")
|
"failed to execute statement")
|
||||||
log.Println(s)
|
log.Println(s)
|
||||||
initErr(db, err)
|
initErr(db, err)
|
||||||
}
|
}
|
|
@ -1,13 +0,0 @@
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DB holds the database connection, mutex & path.
|
|
||||||
type DB struct {
|
|
||||||
Path string
|
|
||||||
Mu *sync.RWMutex
|
|
||||||
Conn *sql.DB
|
|
||||||
}
|
|
|
@ -1,17 +1,23 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"tildegit.org/andinus/perseus/storage/sqlite3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DB holds the database connection, mutex & path.
|
||||||
|
type DB struct {
|
||||||
|
Path string
|
||||||
|
Mu *sync.RWMutex
|
||||||
|
Conn *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
// Init initializes the database.
|
// Init initializes the database.
|
||||||
func Init() *sqlite3.DB {
|
func Init() *DB {
|
||||||
var db sqlite3.DB = sqlite3.DB{
|
db := DB{
|
||||||
Mu: new(sync.RWMutex),
|
Mu: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlite3.Init(&db)
|
initDB(&db)
|
||||||
return &db
|
return &db
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
package user
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"tildegit.org/andinus/perseus/storage/sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddUser adds the user to record.
|
|
||||||
func (u User) AddUser(db *sqlite3.DB) error {
|
|
||||||
// Acquire write lock on the database.
|
|
||||||
db.Mu.Lock()
|
|
||||||
defer db.Mu.Unlock()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := db.Conn.Begin()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("user/adduser.go: %s\n",
|
|
||||||
"failed to begin transaction")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
usrStmt, err := db.Conn.Prepare(`
|
|
||||||
INSERT INTO users(id, username, password, regTime) values(?, ?, ?, ?)`)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("user/adduser.go: %s\n",
|
|
||||||
"failed to prepare statement")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer usrStmt.Close()
|
|
||||||
|
|
||||||
_, err = usrStmt.Exec(u.ID, u.Username, u.Password, time.Now().UTC())
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("user/adduser.go: %s\n",
|
|
||||||
"failed to execute statement")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Commit()
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,29 +0,0 @@
|
||||||
package user
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"tildegit.org/andinus/perseus/storage/sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetID returns id from username.
|
|
||||||
func (u *User) GetID(db *sqlite3.DB) error {
|
|
||||||
// Get password for this user from the database.
|
|
||||||
stmt, err := db.Conn.Prepare("SELECT id FROM users WHERE username = ?")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("user/getid.go: %s\n",
|
|
||||||
"failed to prepare statement")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
var id string
|
|
||||||
err = stmt.QueryRow(u.Username).Scan(&id)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("user/getid.go: %s\n",
|
|
||||||
"query failed")
|
|
||||||
}
|
|
||||||
u.ID = id
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -40,7 +40,7 @@
|
||||||
/
|
/
|
||||||
<a href="https://andinus.nand.sh/perseus">Perseus</a>
|
<a href="https://andinus.nand.sh/perseus">Perseus</a>
|
||||||
<span style="float:right">
|
<span style="float:right">
|
||||||
Perseus {{ .Version }}
|
Perseus {{ if .Version}} {{ . }} {{ end }}
|
||||||
/
|
/
|
||||||
<a href="https://tildegit.org/andinus/perseus">
|
<a href="https://tildegit.org/andinus/perseus">
|
||||||
Source Code
|
Source Code
|
|
@ -40,7 +40,7 @@
|
||||||
/
|
/
|
||||||
<a href="https://andinus.nand.sh/perseus">Perseus</a>
|
<a href="https://andinus.nand.sh/perseus">Perseus</a>
|
||||||
<span style="float:right">
|
<span style="float:right">
|
||||||
Perseus {{ .Version }}
|
Perseus {{ if .Version}} {{ . }} {{ end }}
|
||||||
/
|
/
|
||||||
<a href="https://tildegit.org/andinus/perseus">
|
<a href="https://tildegit.org/andinus/perseus">
|
||||||
Source Code
|
Source Code
|
Loading…
Reference in New Issue