Swap fmt for log, and allow HTTP only operation.

This commit is contained in:
Solderpunk 2023-03-02 19:32:08 +01:00
parent 1b42ddadfe
commit 9a083fcbfd
1 changed files with 34 additions and 16 deletions

50
main.go
View File

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -19,9 +18,11 @@ func main() {
func main_body() int { func main_body() int {
var conf_file string var conf_file string
var http_only bool
// Parse args and read config // Parse args and read config
flag.StringVar(&conf_file, "c", "", "Path to config file") flag.StringVar(&conf_file, "c", "", "Path to config file")
flag.BoolVar(&http_only, "h", false, "HTTP only")
flag.Parse() flag.Parse()
if conf_file == "" { if conf_file == "" {
_, err := os.Stat("/etc/shizaru.conf") _, err := os.Stat("/etc/shizaru.conf")
@ -31,35 +32,50 @@ func main_body() int {
} }
config, err := getConfig(conf_file) config, err := getConfig(conf_file)
if err != nil { if err != nil {
fmt.Println("Error reading config file " + conf_file) log.Println("Error reading config file " + conf_file)
return 1 return 1
} }
https := ! http_only
// Open logfile // Open logfile
logfile, err := os.OpenFile(config.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) logfile, err := os.OpenFile(config.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
fmt.Println("Error opening log file " + config.LogPath + ".") log.Println("Error opening log file " + config.LogPath + ".")
return 2 return 2
} }
defer logfile.Close() defer logfile.Close()
errs := make(chan error, 2) // Configure HTTP and HTTPS servers
// Start the HTTP server, which redirect all incoming connections to HTTPS // By default, all the HTTP server does is redirect everything to HTTPS.
// Alternatively, serve *only* on HTTP, for use behind nginx or similar.
var http_server *http.Server
var https_server *http.Server
http.HandleFunc("/", LoggingWrapper(logfile, GetHandler(config))) http.HandleFunc("/", LoggingWrapper(logfile, GetHandler(config)))
http_server := &http.Server{Addr: ":"+strconv.Itoa(config.HttpPort), Handler: http.HandlerFunc(GetRedirectTLSHandler(config))} if(http_only) {
http_server = &http.Server{Addr: ":"+strconv.Itoa(config.HttpPort), Handler: nil}
} else {
http_server = &http.Server{Addr: ":"+strconv.Itoa(config.HttpPort), Handler: http.HandlerFunc(GetRedirectTLSHandler(config))}
tlscfg := &tls.Config{
MinVersion: tls.VersionTLS10,
}
https_server = &http.Server{Addr: ":"+strconv.Itoa(config.HttpsPort), Handler: nil, TLSConfig: tlscfg}
}
// Start HTTP server
errs := make(chan error, 2)
go func() { go func() {
errs <- http_server.ListenAndServe() errs <- http_server.ListenAndServe()
}() }()
tlscfg := &tls.Config{ // Start HTTPS server
MinVersion: tls.VersionTLS10, if(https) {
go func() {
errs <- https_server.ListenAndServe()
}()
log.Println("Listening on ports " + strconv.Itoa(config.HttpPort) + " and " + strconv.Itoa(config.HttpsPort) + "...")
} else {
log.Println("Listening on port " + strconv.Itoa(config.HttpPort) + "...")
} }
// Start the HTTPS server which actually handles most traffic.
https_server := &http.Server{Addr: ":"+strconv.Itoa(config.HttpsPort), Handler: nil, TLSConfig: tlscfg}
go func() {
errs <- https_server.ListenAndServeTLS(config.CertPath, config.KeyPath)
}()
fmt.Println("Listening on ports " + strconv.Itoa(config.HttpPort) + " and " + strconv.Itoa(config.HttpsPort) + "...")
// Listen for signals to gracefully shutdown // Listen for signals to gracefully shutdown
stop := make(chan os.Signal, 1) stop := make(chan os.Signal, 1)
@ -69,9 +85,11 @@ func main_body() int {
// Wait for a signal or an error // Wait for a signal or an error
select { select {
case <-stop: case <-stop:
fmt.Println("Shutting down!") log.Println("Shutting down!")
http_server.Shutdown(context.Background()) http_server.Shutdown(context.Background())
https_server.Shutdown(context.Background()) if(https) {
https_server.Shutdown(context.Background())
}
case err := <-errs: case err := <-errs:
log.Println("Fatal: " + err.Error()) log.Println("Fatal: " + err.Error())
} }