diff --git a/project/db.go b/project/db.go index cfc7a1d..a999995 100644 --- a/project/db.go +++ b/project/db.go @@ -2,7 +2,6 @@ package main import ( "database/sql" - "log" "path" "modernc.org/ql" @@ -15,13 +14,13 @@ func openDB() *sql.DB { db, err0 := sql.Open("ql", dbPath) if err0 != nil { - log.Fatalf("failed to open db: %s", err0) + panicf("failed to open db: %s", err0) } err1 := db.Ping() if err1 != nil { - log.Fatalf("failed to ping db: %s", err1) + panicf("failed to ping db: %s", err1) } migrateUp(db) @@ -30,33 +29,42 @@ func openDB() *sql.DB { } // Database is a database object. Currently using modernc.org/ql -var Database = openDB() +var database *sql.DB + +func getDB() *sql.DB { + if database != nil { + return database + } + + database = openDB() + return database +} // SetConfig will write a key/value pair to the `configs` // table func SetConfig(key string, value []byte) { - tx, err := Database.Begin() + tx, err := getDB().Begin() if err != nil { - log.Fatalf("Failed to SetConfig (0): %s", err) + panicf("Failed to SetConfig (0): %s", err) } _, err2 := tx.Exec("INSERT INTO configs(key, value) VALUES(?1, ?2)", key, string(value)) if err2 != nil { - log.Fatalf("Failed to SetConfig (1): %s", err2) + panicf("Failed to SetConfig (1): %s", err2) } err1 := tx.Commit() if err1 != nil { - log.Fatalf("Failed to SetConfig (2): %s", err) + panicf("Failed to SetConfig (2): %s", err) } } // GetConfig retrieves a key/value pair from the database. func GetConfig(key string) []byte { var result string - row := Database.QueryRow("SELECT value FROM configs WHERE key=$1", key) + row := getDB().QueryRow("SELECT value FROM configs WHERE key=$1", key) err := row.Scan(&result) if err != nil { if err == sql.ErrNoRows { - log.Fatalf("CONFIG MISSING: %s", key) + panicf("CONFIG MISSING: %s", key) } else { panic(err) } diff --git a/project/db_test.go b/project/db_test.go index 54173e0..db0f0a8 100644 --- a/project/db_test.go +++ b/project/db_test.go @@ -1,40 +1,38 @@ package main import ( - "log" "testing" ) func resetDB() { - tx, err := Database.Begin() + tx, err := getDB().Begin() if err != nil { - log.Fatalf("Failed to start transaction: %s", err) + panicf("Failed to start transaction: %s", err) } for i := len(migrations) - 1; i >= 0; i-- { _, err := tx.Exec(migrations[i].down) if err != nil { - log.Fatalf("Migration failure: %s", err) + panicf("Migration failure: %s", err) } } for _, migration := range migrations { _, err := tx.Exec(migration.up) if err != nil { - log.Fatalf("Migration failure: %s", err) + panicf("Migration failure: %s", err) } } if tx.Commit() != nil { - log.Fatal(err) + panic(err) } } func TestSetUpTeardown(t *testing.T) { resetDB() - db := Database - err := db.Ping() + err := getDB().Ping() if err != nil { t.Fatalf("Test setup failed: %s", err) } diff --git a/project/decoders.go b/project/decoders.go index ea8f0fb..9f6258b 100644 --- a/project/decoders.go +++ b/project/decoders.go @@ -1,7 +1,7 @@ package main import ( - "log" + "fmt" "strings" ) @@ -15,7 +15,7 @@ type testCase struct { func B32Decode(input string) []byte { output, error := encoder.DecodeString(input) if error != nil { - log.Fatalf("Error decoding Base32 string %s", input) + panic(fmt.Sprintf("Error decoding Base32 string %s", input)) } return output @@ -24,13 +24,13 @@ func B32Decode(input string) []byte { func validateMhash(input string) string { arry := strings.Split(input, ".") if len(arry) != 2 { - log.Fatalf("Expected '%s' to be an mHash", input) + panicf("Expected '%s' to be an mHash", input) } switch arry[0] + "." { case BlobSigil, MessageSigil, PeerSigil: return input } msg := "Expected left side of Mhash dot to be one of %s, %s, %s. Got: %s" - log.Fatalf(msg, BlobSigil, MessageSigil, PeerSigil, arry[0]) + panicf(msg, BlobSigil, MessageSigil, PeerSigil, arry[0]) return input } diff --git a/project/filesystem.go b/project/filesystem.go index a8f3afc..b7dfcdb 100644 --- a/project/filesystem.go +++ b/project/filesystem.go @@ -1,7 +1,6 @@ package main import ( - "log" "os" "path" @@ -16,7 +15,7 @@ func maybeSetupPigeonDir() string { } else { home, err := homedir.Dir() if err != nil { - log.Fatalf("Home directory resolution error %s", err) + panicf("Home directory resolution error %s", err) } pigeonDataDir = path.Join(home, ".pigeon") } diff --git a/project/migrations.go b/project/migrations.go index 95a5a5c..5858bcb 100644 --- a/project/migrations.go +++ b/project/migrations.go @@ -2,7 +2,6 @@ package main import ( "database/sql" - "log" ) type migration struct { @@ -35,17 +34,17 @@ func migrateUp(db *sql.DB) { tx, err := db.Begin() if err != nil { - log.Fatalf("Failed to start transaction: %s", err) + panicf("Failed to start transaction: %s", err) } for i, migration := range migrations { _, err := tx.Exec(migration.up) if err != nil { - log.Fatalf("Migration failure(%d): %s", i, err) + panicf("Migration failure(%d): %s", i, err) } } if tx.Commit() != nil { - log.Fatal(err) + panic(err) } } diff --git a/project/peers.go b/project/peers.go index 68c6340..fd13ec0 100644 --- a/project/peers.go +++ b/project/peers.go @@ -2,7 +2,7 @@ package main import ( "database/sql" - "log" + "fmt" ) // PeerStatus represents a known state of a peer, such as @@ -20,30 +20,30 @@ const findPeerByStatus = "SELECT status FROM peers WHERE mhash=$1;" func getPeerStatus(mHash string) PeerStatus { var status PeerStatus - row := Database.QueryRow(findPeerByStatus, mHash) + row := getDB().QueryRow(findPeerByStatus, mHash) switch err := row.Scan(&status); err { case sql.ErrNoRows: return "unknown" case nil: return status default: - log.Fatalf("getPeerStatus failure: %s", err) + panicf("getPeerStatus failure: %s", err) panic(err) } } func addPeer(mHash string, status PeerStatus) { - tx, err := Database.Begin() + tx, err := getDB().Begin() if err != nil { - log.Fatalf("Failed to begin addPeer trx (0): %s", err) + panicf("Failed to begin addPeer trx (0): %s", err) } _, err2 := tx.Exec(createPeer, mHash, status) if err2 != nil { - log.Fatalf("Failure. Possible duplicate peer?: %s", err2) + panic(fmt.Sprintf("Failure. Possible duplicate peer?: %s", err2)) } err1 := tx.Commit() if err1 != nil { - log.Fatalf("Failed to commit peer (2): %s", err) + panicf("Failed to commit peer (2): %s", err) } } diff --git a/project/peers_test.go b/project/peers_test.go index b337e1d..7925185 100644 --- a/project/peers_test.go +++ b/project/peers_test.go @@ -22,3 +22,17 @@ func TestGetPeerStatus(t *testing.T) { t.Fatalf("Expected `following`, got %s", status) } } + +func TestAddPeer(t *testing.T) { + defer func() { + r := recover() + if r != nil { + err := "Should not be able to block and follow at the same time" + t.Errorf(err) + } + }() + resetDB() + mHash := "USER.GM84FEYKRQ1QFCZY68YDCRPG8HKXQPQCQSMDQKGTGX8ZY8KFSFJR" + addPeer(mHash, following) + addPeer(mHash, blocked) +} diff --git a/project/util.go b/project/util.go index 0ba867a..2cf67b5 100644 --- a/project/util.go +++ b/project/util.go @@ -2,7 +2,7 @@ package main import ( "crypto/ed25519" - "log" + "fmt" ) func showIdentity() string { @@ -31,9 +31,13 @@ func createOrShowIdentity() string { func CreateIdentity() (ed25519.PublicKey, ed25519.PrivateKey) { pub, priv, err := ed25519.GenerateKey(nil) if err != nil { - log.Fatalf("Keypair creation error %s", err) + panicf("Keypair creation error %s", err) } SetConfig("public_key", pub) SetConfig("private_key", priv) return pub, priv } + +func panicf(tpl string, args ...interface{}) { + panic(fmt.Sprintf(tpl, args...)) +}