add hostname config thing

This commit is contained in:
Hedy Li 2021-07-26 10:41:31 +08:00
parent 6c71b360ac
commit adb339c0b3
Signed by: hedy
GPG Key ID: B51B5A8D1B176372
2 changed files with 30 additions and 21 deletions

View File

@ -8,27 +8,29 @@ import (
)
type Config struct {
Port int
Hostname string
RootDir string
UserDirEnable bool
UserDir string
DirlistEnable bool
DirlistReverse bool
DirlistSort string
DirlistTitles bool
Port int
Hostname string
RootDir string
UserDirEnable bool
UserDir string
DirlistEnable bool
DirlistReverse bool
DirlistSort string
DirlistTitles bool
RestrictHostname string
}
var defaultConf = &Config{
Port: 300,
Hostname: "localhost",
RootDir: "/var/spartan/",
DirlistEnable: true,
DirlistReverse: false,
DirlistSort: "name",
DirlistTitles: true,
UserDirEnable: true,
UserDir: "public_spartan",
Port: 300,
Hostname: "localhost",
RootDir: "/var/spartan/",
DirlistEnable: true,
DirlistReverse: false,
DirlistSort: "name",
DirlistTitles: true,
UserDirEnable: true,
UserDir: "public_spartan",
RestrictHostname: "",
}
func LoadConfig(path string) (*Config, error) {

View File

@ -94,11 +94,18 @@ func handleConnection(conn io.ReadWriteCloser, conf *Config) {
// Parse incoming request URL.
request := s.Text()
reqPath, _, err := parseRequest(request)
host, reqPath, _, err := parseRequest(request)
if err != nil {
sendResponseHeader(conn, statusClientError, "Bad request")
return
}
if conf.RestrictHostname != "" {
if conf.RestrictHostname != host {
log.Println("Request host does not match conf.RestrictHostname, returning client error.")
sendResponseHeader(conn, statusClientError, "No proxying to other hosts!")
return
}
}
log.Println("Handling request:", request)
if strings.Contains(reqPath, "..") {
sendResponseHeader(conn, statusClientError, "Stop it with your directory traversal technique!")
@ -231,13 +238,13 @@ func sendResponseContent(conn io.ReadWriteCloser, content []byte) {
}
}
func parseRequest(r string) (path string, contentLength int, err error) {
func parseRequest(r string) (host, path string, contentLength int, err error) {
parts := strings.Split(r, " ")
if len(parts) != 3 {
err = errors.New("Bad request")
return
}
_, path, contentLengthString := parts[0], parts[1], parts[2]
host, path, contentLengthString := parts[0], parts[1], parts[2]
contentLength, err = strconv.Atoi(contentLengthString)
if err != nil {
return