add hostname config thing

This commit is contained in:
Hedy Li 2021-07-26 10:41:31 +08:00
parent 6c71b360ac
commit adb339c0b3
Signed by: hedy
GPG Key ID: B51B5A8D1B176372
2 changed files with 30 additions and 21 deletions

View File

@ -8,27 +8,29 @@ import (
) )
type Config struct { type Config struct {
Port int Port int
Hostname string Hostname string
RootDir string RootDir string
UserDirEnable bool UserDirEnable bool
UserDir string UserDir string
DirlistEnable bool DirlistEnable bool
DirlistReverse bool DirlistReverse bool
DirlistSort string DirlistSort string
DirlistTitles bool DirlistTitles bool
RestrictHostname string
} }
var defaultConf = &Config{ var defaultConf = &Config{
Port: 300, Port: 300,
Hostname: "localhost", Hostname: "localhost",
RootDir: "/var/spartan/", RootDir: "/var/spartan/",
DirlistEnable: true, DirlistEnable: true,
DirlistReverse: false, DirlistReverse: false,
DirlistSort: "name", DirlistSort: "name",
DirlistTitles: true, DirlistTitles: true,
UserDirEnable: true, UserDirEnable: true,
UserDir: "public_spartan", UserDir: "public_spartan",
RestrictHostname: "",
} }
func LoadConfig(path string) (*Config, error) { func LoadConfig(path string) (*Config, error) {

View File

@ -94,11 +94,18 @@ func handleConnection(conn io.ReadWriteCloser, conf *Config) {
// Parse incoming request URL. // Parse incoming request URL.
request := s.Text() request := s.Text()
reqPath, _, err := parseRequest(request) host, reqPath, _, err := parseRequest(request)
if err != nil { if err != nil {
sendResponseHeader(conn, statusClientError, "Bad request") sendResponseHeader(conn, statusClientError, "Bad request")
return return
} }
if conf.RestrictHostname != "" {
if conf.RestrictHostname != host {
log.Println("Request host does not match conf.RestrictHostname, returning client error.")
sendResponseHeader(conn, statusClientError, "No proxying to other hosts!")
return
}
}
log.Println("Handling request:", request) log.Println("Handling request:", request)
if strings.Contains(reqPath, "..") { if strings.Contains(reqPath, "..") {
sendResponseHeader(conn, statusClientError, "Stop it with your directory traversal technique!") sendResponseHeader(conn, statusClientError, "Stop it with your directory traversal technique!")
@ -231,13 +238,13 @@ func sendResponseContent(conn io.ReadWriteCloser, content []byte) {
} }
} }
func parseRequest(r string) (path string, contentLength int, err error) { func parseRequest(r string) (host, path string, contentLength int, err error) {
parts := strings.Split(r, " ") parts := strings.Split(r, " ")
if len(parts) != 3 { if len(parts) != 3 {
err = errors.New("Bad request") err = errors.New("Bad request")
return return
} }
_, path, contentLengthString := parts[0], parts[1], parts[2] host, path, contentLengthString := parts[0], parts[1], parts[2]
contentLength, err = strconv.Atoi(contentLengthString) contentLength, err = strconv.Atoi(contentLengthString)
if err != nil { if err != nil {
return return