config, userdirs, refactor

This commit is contained in:
Hedy Li 2021-07-19 12:58:41 +08:00
parent b9ec9f24b9
commit 7006d7e6a4
Signed by: hedy
GPG Key ID: B51B5A8D1B176372
5 changed files with 162 additions and 46 deletions

62
config.go Normal file
View File

@ -0,0 +1,62 @@
package main
import (
"github.com/BurntSushi/toml"
"io/ioutil"
"os"
"fmt"
)
type Config struct {
Port int
Hostname string
RootDir string
UserDirEnable bool
UserDir string
// UserSlug string
DirlistReverse bool
DirlistSort string
DirlistTitles bool
}
var defaultConf = &Config{
Port: 300,
Hostname: "localhost",
RootDir: "/var/spartan/",
DirlistReverse: false,
DirlistSort: "name",
DirlistTitles: true,
UserDirEnable: false,
UserDir: "public_spartan",
}
func LoadConfig(path string) (*Config, error) {
var err error
var conf Config
// Defaults
conf = *defaultConf
_, err = os.Stat(path)
if os.IsNotExist(err) {
fmt.Println(path, "does not exist, using default configuration values")
return &conf, nil
}
f, err := os.Open(path)
if err == nil {
defer f.Close()
contents, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
if _, err = toml.Decode(string(contents), &conf); err != nil {
return nil, err
}
}
if conf.DirlistSort != "name" && conf.DirlistSort != "time" && conf.DirlistSort != "size" {
fmt.Println("Warning: DirlistSort config option is not one of name/time/size, defaulting to name.")
conf.DirlistSort = "name"
}
return &conf, nil
}

View File

@ -11,9 +11,9 @@ import (
"strings" "strings"
) )
func generateDirectoryListing(reqPath, path string) ([]byte, error) { func generateDirectoryListing(reqPath, path string, conf *Config) ([]byte, error) {
dirSort := "time" dirSort := conf.DirlistSort
dirReverse := false dirReverse := conf.DirlistReverse
var listing string var listing string
files, err := ioutil.ReadDir(path) files, err := ioutil.ReadDir(path)
if err != nil { if err != nil {
@ -22,6 +22,7 @@ func generateDirectoryListing(reqPath, path string) ([]byte, error) {
listing = "# Directory listing\n\n" listing = "# Directory listing\n\n"
// TODO: custom dirlist header in config // TODO: custom dirlist header in config
// Do "up" link first // Do "up" link first
reqPath = strings.ReplaceAll(reqPath, "/.", "")
if reqPath != "/" { if reqPath != "/" {
if strings.HasSuffix(reqPath, "/") { if strings.HasSuffix(reqPath, "/") {
reqPath = reqPath[:len(reqPath)-1] reqPath = reqPath[:len(reqPath)-1]
@ -59,13 +60,13 @@ func generateDirectoryListing(reqPath, path string) ([]byte, error) {
if file.IsDir() { if file.IsDir() {
relativeUrl += "/" relativeUrl += "/"
} }
listing += fmt.Sprintf("=> %s %s\n", relativeUrl, generatePrettyFileLabel(file, path)) listing += fmt.Sprintf("=> %s %s\n", relativeUrl, generatePrettyFileLabel(file, path, conf))
} }
return []byte(listing), nil return []byte(listing), nil
} }
func generatePrettyFileLabel(info os.FileInfo, path string) string { func generatePrettyFileLabel(info os.FileInfo, path string, conf *Config) string {
dirTitles := true // TODO: config dirTitles := conf.DirlistTitles
var size string var size string
if info.IsDir() { if info.IsDir() {
size = " " size = " "

5
go.mod
View File

@ -1,3 +1,8 @@
module spsrv module spsrv
go 1.15 go 1.15
require (
github.com/BurntSushi/toml v0.3.1
github.com/spf13/pflag v1.0.5
)

4
go.sum Normal file
View File

@ -0,0 +1,4 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=

124
spsrv.go
View File

@ -3,7 +3,6 @@ package main
import ( import (
"bufio" "bufio"
"errors" "errors"
"flag"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -14,6 +13,8 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
flag "github.com/spf13/pflag"
) )
const ( const (
@ -24,38 +25,58 @@ const (
) )
var ( var (
hostname = flag.String("h", "localhost", "hostname") hostname = flag.StringP("hostname", "h", defaultConf.Hostname, "Hostname")
contentDir = flag.String("d", "/var/spartan", "content directory") port = flag.IntP("port", "p", defaultConf.Port, "Port to listen to")
port = flag.Int("p", 300, "port number") rootDir = flag.StringP("dir", "d", defaultConf.RootDir, "Root content directory")
confPath = flag.StringP("config", "c", "/etc/spsrv.conf", "Path to config file")
) )
func main() { func main() {
flag.Parse() flag.Parse()
conf, err := LoadConfig(*confPath)
if err != nil {
fmt.Println("Error loading config")
fmt.Println(err.Error())
return
}
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) // This allows users overriding values in config via the CLI
if *hostname != defaultConf.Hostname {
conf.Hostname = *hostname
}
if *port != defaultConf.Port {
conf.Port = *port
}
if *rootDir != defaultConf.RootDir {
conf.RootDir = *rootDir
}
// TODO: do something with conf.Hostname (b(like restricting to ipv4/6 etc)
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", conf.Port))
if err != nil { if err != nil {
log.Fatalf("Unable to listen: %s", err) log.Fatalf("Unable to listen: %s", err)
} }
log.Println("✨ You are now running on spsrv ✨") log.Println("✨ You are now running on spsrv ✨")
log.Printf("Listening for connections on port: %d", *port) log.Printf("Listening for connections on port: %d", conf.Port)
serveSpartan(listener) serveSpartan(listener, conf)
} }
// serveSpartan accepts connections and returns content // serveSpartan accepts connections and returns content
func serveSpartan(listener net.Listener) { func serveSpartan(listener net.Listener, conf *Config) {
for { for {
// Blocking until request received
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
continue continue
} }
log.Println("Accepted connection") log.Println("Accepted connection from", conn.RemoteAddr())
go handleConnection(conn) go handleConnection(conn, conf)
} }
} }
// handleConnection handles a request and does the reponse // handleConnection handles a request and does the response
func handleConnection(conn io.ReadWriteCloser) { func handleConnection(conn io.ReadWriteCloser, conf *Config) {
defer conn.Close() defer conn.Close()
// Check the size of the request buffer. // Check the size of the request buffer.
@ -73,59 +94,82 @@ func handleConnection(conn io.ReadWriteCloser) {
// Parse incoming request URL. // Parse incoming request URL.
request := s.Text() request := s.Text()
path, _, err := parseRequest(request) reqPath, _, err := parseRequest(request)
if err != nil { if err != nil {
sendResponseHeader(conn, statusClientError, "Bad request") sendResponseHeader(conn, statusClientError, "Bad request")
return return
} }
log.Println("Handling request:", request) log.Println("Handling request:", request)
if strings.Contains(reqPath, ".."){
sendResponseHeader(conn, statusClientError, "Stop it with your directory traversal technique!")
return
}
// Time to fetch the files! // Time to fetch the files!
serveFile(conn, path) path := resolvePath(reqPath, conf)
serveFile(conn, reqPath, path, conf)
log.Println("Closed connection") log.Println("Closed connection")
} }
// serveFile serves opens the requested path and returns the file content func resolvePath(reqPath string, conf *Config) (path string) {
func serveFile(conn io.ReadWriteCloser, reqPath string) { // Handle tildes
if strings.HasPrefix(reqPath, "/~") {
bits := strings.Split(reqPath, "/")
username := bits[1][1:]
new_prefix := filepath.Join("/home/", username, conf.UserDir)
path = filepath.Clean(strings.Replace(reqPath, bits[1], new_prefix, 1))
if strings.HasSuffix(reqPath, "/") {
path = filepath.Join(path, "index.gmi")
}
return
}
path = reqPath
// TODO: [config] default index file for a directory is index.gmi // TODO: [config] default index file for a directory is index.gmi
path := reqPath
if strings.HasSuffix(reqPath, "/") || reqPath == "" { if strings.HasSuffix(reqPath, "/") || reqPath == "" {
path = filepath.Join(reqPath, "index.gmi") path = filepath.Join(reqPath, "index.gmi")
} }
cleanPath := filepath.Clean(path) path = filepath.Clean(filepath.Join(conf.RootDir, path))
return
}
// serveFile serves opens the requested path and returns the file content
func serveFile(conn io.ReadWriteCloser, reqPath, path string, conf *Config) {
// If the content directory is not specified as an absolute path, make it absolute. // If the content directory is not specified as an absolute path, make it absolute.
prefixDir := "" // prefixDir := ""
var rootDir http.Dir // var rootDir http.Dir
if !strings.HasPrefix(*contentDir, "/") { // if !strings.HasPrefix(conf.RootDir, "/") {
prefixDir, _ = os.Getwd() // prefixDir, _ = os.Getwd()
} // }
// Avoid directory traversal type attacks. // Avoid directory traversal type attacks.
rootDir = http.Dir(prefixDir + strings.Replace(*contentDir, ".", "", -1)) // rootDir = http.Dir(prefixDir + strings.Replace(conf.RootDir, ".", "", -1))
// Open the requested resource. // Open the requested resource.
var content []byte var content []byte
log.Printf("Fetching: %s", cleanPath) log.Printf("Fetching: %s", path)
f, err := rootDir.Open(cleanPath) f, err := os.Open(path)
if err != nil { if err != nil {
// not putting the /folder to /folder/ redirect here because folder can still // not putting the /folder to /folder/ redirect here because folder can still
// be opened without errors // be opened without errors
// Directory listing // Directory listing
if strings.HasSuffix(cleanPath, "index.gmi") { if strings.HasSuffix(path, "index.gmi") {
fullPath := filepath.Join(fmt.Sprint(rootDir), cleanPath) // fullPath := filepath.Join(fmt.Sprint(rootDir), path)
fullPath := path
if _, err := os.Stat(fullPath); os.IsNotExist(err) { if _, err := os.Stat(fullPath); os.IsNotExist(err) {
// If and only if the path is index.gmi AND index.gmi does not exist // If and only if the path is index.gmi AND index.gmi does not exist
fullPath = strings.TrimSuffix(fullPath, "index.gmi") fullPath = strings.TrimSuffix(fullPath, "index.gmi")
log.Println("Generating directory listing:", fullPath) if _, err := os.Stat(fullPath); err == nil {
content, err = generateDirectoryListing(reqPath, fullPath) // If the directly exists
if err != nil { log.Println("Generating directory listing:", fullPath)
log.Println(err) content, err = generateDirectoryListing(reqPath, fullPath, conf)
sendResponseHeader(conn, statusServerError, "Error generating directory listing") if err != nil {
log.Println(err)
sendResponseHeader(conn, statusServerError, "Error generating directory listing")
return
}
path += ".gmi" // OOF, this is just to have the text/gemini meta later lol
serveContent(conn, content, path)
return return
} }
cleanPath += ".gmi" // OOF, this is just to have the text/gemini meta later lol
serveContent(conn, content, cleanPath)
return
} }
} }
log.Println(err) log.Println(err)
@ -141,16 +185,16 @@ func serveFile(conn io.ReadWriteCloser, reqPath string) {
// I wish I could check if err is a "path/to/dir" is a directory error // I wish I could check if err is a "path/to/dir" is a directory error
// but I couldn't figure out how, so this check below is the best I // but I couldn't figure out how, so this check below is the best I
// can come up with I guess // can come up with I guess
if _, err := os.Stat(filepath.Join(fmt.Sprint(rootDir), cleanPath+"/")); !os.IsNotExist(err) { if _, err := os.Stat(path + "/"); !os.IsNotExist(err) {
log.Println("Redirecting", cleanPath, "to", cleanPath+"/") log.Println("Redirecting", path, "to", reqPath+"/")
sendResponseHeader(conn, statusRedirect, cleanPath+"/") sendResponseHeader(conn, statusRedirect, reqPath+"/")
return return
} }
log.Println(err) log.Println(err)
sendResponseHeader(conn, statusServerError, "Resource could not be read") sendResponseHeader(conn, statusServerError, "Resource could not be read")
return return
} }
serveContent(conn, content, cleanPath) serveContent(conn, content, path)
} }
func serveContent(conn io.ReadWriteCloser, content []byte, path string) { func serveContent(conn io.ReadWriteCloser, content []byte, path string) {