From e364ce97803ac65b6a1d4a3a00f3a8b9665c478d Mon Sep 17 00:00:00 2001 From: prx Date: Wed, 12 Oct 2022 21:43:44 +0200 Subject: [PATCH] listen on localhost, both ipv4 and ipv6 if available. Use kqueue to handle many connections --- main.c | 165 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 100 insertions(+), 65 deletions(-) diff --git a/main.c b/main.c index 8c63a2f..1e252b5 100644 --- a/main.c +++ b/main.c @@ -1,3 +1,5 @@ +#include +#include #include #include @@ -6,23 +8,23 @@ #include #include -#include #include #include #include #include #include -#define PORT "2507" -#define BACKLOG 42 #define DEFAULT_TABLE "iblocked" +#define PORT "2507" +#define MAXSOCK 2 /* ipv4 + ipv6 */ +#define BACKLOG 10 static void *get_in_addr(struct sockaddr *); static void runcmd(const char*, const char**); -static void sigchld(int unused); static void usage(void); +/* return printable ip from sockaddr */ static void *get_in_addr(struct sockaddr *sa) { if (sa->sa_family == AF_INET) @@ -31,6 +33,7 @@ static void *get_in_addr(struct sockaddr *sa) return &(((struct sockaddr_in6*)sa)->sin6_addr); } +/* run cmd in execv() after fork() */ static void runcmd(const char* cmd, const char** arg_list) { pid_t pid = fork(); @@ -42,18 +45,11 @@ static void runcmd(const char* cmd, const char** arg_list) /* if this is reached, then exec failed */ syslog(LOG_DAEMON, "execv error"); err(1,"execv"); + } else { /* parent */ + waitpid(pid, NULL, WNOHANG); } } -void -sigchld(int unused) -{ - (void)unused; - if (signal(SIGCHLD, sigchld) == SIG_ERR) - err(1, "can't install SIGCHLD handler:"); - while (waitpid(WAIT_ANY, NULL, WNOHANG) > 0); -} - static void usage(void) { fprintf(stderr, "usage: %s [table]\n", getprogname()); @@ -65,12 +61,12 @@ main(int argc, char *argv[]) { char ip[INET6_ADDRSTRLEN] = {'\0'}; const char *table = DEFAULT_TABLE; - int sockfd = 0; + const char *err_cause = NULL; int new_fd = 0; - int retval = 0; + int nsock = 0; + int kq = 0; socklen_t sin_size = 0; - struct addrinfo hints, *servinfo, *p; - struct sockaddr_storage client_addr; + int s[MAXSOCK] = {0}; const char *bancmd[] = { "/usr/bin/doas", "-n", "/sbin/pfctl", "-t", table, "-T", "add", ip, @@ -79,10 +75,16 @@ main(int argc, char *argv[]) "/sbin/pfctl", "-k", ip, NULL }; + struct kevent ev[MAXSOCK] = {0}; + struct addrinfo hints, *servinfo, *res; + struct sockaddr_storage client_addr; - + /* safety first */ if (unveil("/usr/bin/doas", "rx") != 0) err(1, "unveil"); + /* necessary to resolve localhost with getaddrinfo() */ + if (unveil("/etc/hosts", "r") != 0) + err(1, "unveil"); if (pledge("stdio inet exec proc rpath", NULL) != 0) err(1, "pledge"); @@ -100,72 +102,105 @@ main(int argc, char *argv[]) hints.ai_socktype = SOCK_STREAM; hints.ai_flags = AI_PASSIVE; - if ((retval = getaddrinfo(NULL, PORT, &hints, &servinfo)) != 0) { + /* get ips for localhost */ + int retval = getaddrinfo("localhost", PORT, &hints, &servinfo); + if (retval != 0) { syslog(LOG_DAEMON, "getaddrinfo failed"); err(1, "getaddrinfo :%s", gai_strerror(retval)); } - /* get a socket and bind */ - for (p = servinfo; p != NULL; p = p->ai_next) { - if ((sockfd = socket(p->ai_family, - p->ai_socktype, - p->ai_protocol)) == -1) { + /* create sockets and bind for each local ip, store them in s[] */ + for (res = servinfo; res && nsock < MAXSOCK; res = res->ai_next) { + + s[nsock] = socket(res->ai_family, + res->ai_socktype, + res->ai_protocol); + if (s[nsock] == -1) { + err_cause = "socket"; + continue; + } + /* make sure PORT can be reused by second IP */ + int yes = 1; + if (setsockopt(s[nsock], SOL_SOCKET, SO_REUSEPORT, &yes, + sizeof(int)) == -1) + err(1, "setsockopt"); + + if (bind(s[nsock], res->ai_addr, res->ai_addrlen) == -1) { + close(s[nsock]); + err_cause = "bind()"; continue; } - if (bind(sockfd, p->ai_addr, p->ai_addrlen) == -1) { - close(sockfd); - continue; - } + if (listen(s[nsock], BACKLOG) == -1) + err_cause = "listen"; - break; + /* log the obtained ip */ + inet_ntop(res->ai_family, + get_in_addr((struct sockaddr *)res->ai_addr), + ip, sizeof(ip)); + syslog(LOG_DAEMON, "listening on %s port %s, muahaha :>", + ip, + PORT); + + nsock++; } + /* clean up no longer used servinfo */ freeaddrinfo(servinfo); - if (p == NULL) { - syslog(LOG_DAEMON, "Failed to bind"); - err(1, "Failed to bind"); - } + if (nsock == 0) + err(1, "Error when calling %s", err_cause); - if (listen(sockfd, BACKLOG) == -1) { - syslog(LOG_DAEMON, "listen failed"); - err(1, "listen"); - } + /* configure events */ + kq = kqueue(); - sigchld(0); + /* add event for each IP */ + for (int i = 0; i <= nsock; i++) + EV_SET(&(ev[i]), s[i], EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, 0); - syslog(LOG_DAEMON, "ready to reap on port %s, muhahaha :>", PORT); + /* register event */ + if (kevent(kq, ev, MAXSOCK, NULL, 0, NULL) == -1) + err(1, "kevent"); - while (1) { - sin_size = sizeof(client_addr); - new_fd = accept(sockfd, - (struct sockaddr*)&client_addr, - &sin_size); + /* infinite loop to wait for connections */ + for (;;) { + int nevents = kevent(kq, NULL, 0, ev, MAXSOCK, NULL); + if (nevents == -1) + err(1, "kevent"); - if (new_fd == -1) - continue; + /* loop for events */ + for (int i = 0; i < nevents; i++) { - /* get client ip */ - inet_ntop(client_addr.ss_family, - get_in_addr((struct sockaddr *)&client_addr), - ip, sizeof(ip)); + if (ev[i].filter & EVFILT_READ) { - close(new_fd); /* no longer needed */ + /* get client ip */ + sin_size = sizeof(client_addr); + new_fd = accept(ev[i].ident, + (struct sockaddr*)&client_addr, + &sin_size); + if (new_fd == -1) + continue; + inet_ntop(client_addr.ss_family, + get_in_addr((struct sockaddr *)&client_addr), + ip, sizeof(ip)); - pid_t id = fork(); - if (id == -1) { - syslog(LOG_DAEMON, "fork error"); - err(1, "fork"); - } else if (id == 0) { /* child process */ - syslog(LOG_DAEMON, "blocking %s", ip); - runcmd(bancmd[0], bancmd); - syslog(LOG_DAEMON, "kill states for %s", ip); - runcmd(killstatecmd[0], killstatecmd); - close(sockfd); - exit(0); - } - } - close(sockfd); + close(new_fd); /* no longer required */ + + /* ban this ip */ + syslog(LOG_DAEMON, "blocking %s", ip); + runcmd(bancmd[0], bancmd); + syslog(LOG_DAEMON, "kill states for %s", ip); + runcmd(killstatecmd[0], killstatecmd); + } + if (ev[i].filter & EVFILT_SIGNAL) { + break; + } + } /* events loop */ + } /* infinite loop */ + + /* probably never reached */ + close(kq); + for (int i = 0; i <= nsock; i++) + close(s[i]); return 0; }