From 5fe51d86214298637838539230cf43058d041533 Mon Sep 17 00:00:00 2001 From: 6543 <6543@obermui.de> Date: Sun, 5 Dec 2021 19:00:57 +0100 Subject: [PATCH] rm certDB helper and build in --- server/certificates/certificates.go | 33 +++++++++++++++---------- server/certificates/mock.go | 4 +++- server/database/helpers.go | 37 ----------------------------- server/database/interface.go | 9 ++++--- server/database/setup.go | 29 +++++++++++++++++----- 5 files changed, 53 insertions(+), 59 deletions(-) delete mode 100644 server/database/helpers.go diff --git a/server/certificates/certificates.go b/server/certificates/certificates.go index bb6a3c3..fa76538 100644 --- a/server/certificates/certificates.go +++ b/server/certificates/certificates.go @@ -188,8 +188,11 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error { func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, keyDatabase database.CertDB) (tls.Certificate, bool) { // parse certificate from database - res := &certificate.Resource{} - if !database.PogrebGet(keyDatabase, sni, res) { + res, err := keyDatabase.Get(sni) + if err != nil { + panic(err) // TODO: no panic + } + if res == nil { return tls.Certificate{}, false } @@ -294,7 +297,9 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) { // avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10)) - database.PogrebPut(keyDatabase, []byte(name), renew) + if err := keyDatabase.Put(name, renew); err != nil { + return mockCert(domains[0], err.Error(), string(mainDomainSuffix), keyDatabase), err + } return tlsCertificate, nil } } @@ -302,7 +307,9 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re } log.Printf("Obtained certificate for %v", domains) - database.PogrebPut(keyDatabase, []byte(name), res) + if err := keyDatabase.Put(name, res); err != nil { + return tls.Certificate{}, err + } tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) if err != nil { return tls.Certificate{}, err @@ -447,12 +454,12 @@ func SetupCertificates(mainDomainSuffix []byte, dnsProvider string, acmeConfig * } } -func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, keyDatabase database.CertDB) { +func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) { for { // clean up expired certs now := time.Now() expiredCertCount := 0 - keyDatabaseIterator := keyDatabase.Items() + keyDatabaseIterator := certDB.Items() key, resBytes, err := keyDatabaseIterator.Next() for err == nil { if !bytes.Equal(key, mainDomainSuffix) { @@ -466,7 +473,7 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) if err != nil || !tlsCertificates[0].NotAfter.After(now) { - err := keyDatabase.Delete(key) + err := certDB.Delete(key) if err != nil { log.Printf("[ERROR] Deleting expired certificate for %s failed: %s", string(key), err) } else { @@ -479,7 +486,7 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi log.Printf("[INFO] Removed %d expired certificates from the database", expiredCertCount) // compact the database - result, err := keyDatabase.Compact() + result, err := certDB.Compact() if err != nil { log.Printf("[ERROR] Compacting key database failed: %s", err) } else { @@ -487,16 +494,18 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi } // update main cert - res := &certificate.Resource{} - if !database.PogrebGet(keyDatabase, mainDomainSuffix, res) { - log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted") + res, err := certDB.Get(mainDomainSuffix) + if err != nil { + log.Err(err).Msgf("could not get cert for domain '%s'", mainDomainSuffix) + } else if res == nil { + log.Error().Msgf("Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted") } else { tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) // renew main certificate 30 days before it expires if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) { go (func() { - _, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, keyDatabase) + _, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB) if err != nil { log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err) } diff --git a/server/certificates/mock.go b/server/certificates/mock.go index 22d5470..0e87e6e 100644 --- a/server/certificates/mock.go +++ b/server/certificates/mock.go @@ -74,7 +74,9 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB) if domain == "*"+mainDomainSuffix || domain == mainDomainSuffix[1:] { databaseName = mainDomainSuffix } - database.PogrebPut(keyDatabase, []byte(databaseName), res) + if err := keyDatabase.Put(databaseName, res); err != nil { + panic(err) + } tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) if err != nil { diff --git a/server/database/helpers.go b/server/database/helpers.go deleted file mode 100644 index ea3e899..0000000 --- a/server/database/helpers.go +++ /dev/null @@ -1,37 +0,0 @@ -package database - -import ( - "bytes" - "encoding/gob" -) - -func PogrebPut(db CertDB, name []byte, obj interface{}) { - var resGob bytes.Buffer - resEnc := gob.NewEncoder(&resGob) - err := resEnc.Encode(obj) - if err != nil { - panic(err) - } - err = db.Put(name, resGob.Bytes()) - if err != nil { - panic(err) - } -} - -func PogrebGet(db CertDB, name []byte, obj interface{}) bool { - resBytes, err := db.Get(name) - if err != nil { - panic(err) - } - if resBytes == nil { - return false - } - - resGob := bytes.NewBuffer(resBytes) - resDec := gob.NewDecoder(resGob) - err = resDec.Decode(obj) - if err != nil { - panic(err) - } - return true -} diff --git a/server/database/interface.go b/server/database/interface.go index 80d74d3..01b9872 100644 --- a/server/database/interface.go +++ b/server/database/interface.go @@ -1,11 +1,14 @@ package database -import "github.com/akrylysov/pogreb" +import ( + "github.com/akrylysov/pogreb" + "github.com/go-acme/lego/v4/certificate" +) type CertDB interface { Close() error - Put(key []byte, value []byte) error - Get(key []byte) ([]byte, error) + Put(name string, cert *certificate.Resource) error + Get(name []byte) (*certificate.Resource, error) Delete(key []byte) error Compact() (pogreb.CompactionResult, error) Items() *pogreb.ItemIterator diff --git a/server/database/setup.go b/server/database/setup.go index f7eeafc..f3cac16 100644 --- a/server/database/setup.go +++ b/server/database/setup.go @@ -1,14 +1,16 @@ package database import ( + "bytes" "context" + "encoding/gob" "fmt" "time" - "github.com/rs/zerolog/log" - "github.com/akrylysov/pogreb" "github.com/akrylysov/pogreb/fs" + "github.com/go-acme/lego/v4/certificate" + "github.com/rs/zerolog/log" ) type aDB struct { @@ -23,12 +25,27 @@ func (p aDB) Close() error { return p.intern.Sync() } -func (p aDB) Put(key []byte, value []byte) error { - return p.intern.Put(key, value) +func (p aDB) Put(name string, cert *certificate.Resource) error { + var resGob bytes.Buffer + if err := gob.NewEncoder(&resGob).Encode(cert); err != nil { + return err + } + return p.intern.Put([]byte(name), resGob.Bytes()) } -func (p aDB) Get(key []byte) ([]byte, error) { - return p.intern.Get(key) +func (p aDB) Get(name []byte) (*certificate.Resource, error) { + cert := &certificate.Resource{} + resBytes, err := p.intern.Get(name) + if err != nil { + return nil, err + } + if resBytes == nil { + return nil, nil + } + if err = gob.NewDecoder(bytes.NewBuffer(resBytes)).Decode(cert); err != nil { + return nil, err + } + return cert, nil } func (p aDB) Delete(key []byte) error {