gemini-echo-server/gemini-echo-server.c

727 lines
18 KiB
C

#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <signal.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <unistd.h>
#include <err.h>
#include <sysexits.h>
#include <tls.h>
/*
 * Compile-time defaults.  Each one can be overridden with -D at build
 * time and, where main() provides an option, at run time.
 */
/* Default backlog handed to listen(2). */
#ifndef BACKLOG
#define BACKLOG 5
#endif
/* Default per-connection timeout; main() compares it against difftime(3). */
#ifndef TIMEOUT
#define TIMEOUT 5
#endif
/* Default number of client connections allowed at once. */
#ifndef MAXCONN
#define MAXCONN 5
#endif
/* Default divisor used by main() to derive the slow-down threshold. */
#ifndef SLOWDOWN
#define SLOWDOWN 10
#endif
/*
 * ASCII-only character classification.  Unlike <ctype.h> these do not
 * consult the locale.  Every macro evaluates its argument more than
 * once, so the argument must be free of side effects.
 */
#define ISASCIILOWER(c) ((c) >= 'a' && (c) <= 'z')
#define ISASCIIUPPER(c) ((c) >= 'A' && (c) <= 'Z')
#define ISASCIIDIGIT(c) ((c) >= '0' && (c) <= '9')
#define ISASCIIXDIGIT(c) (ISASCIIDIGIT(c) || \
((c) >= 'A' && (c) <= 'F') || ((c) >= 'a' && (c) <= 'f'))
#define ISASCIIALPHA(c) (ISASCIILOWER(c) || ISASCIIUPPER(c))
#define ISASCIIALNUM(c) (ISASCIIALPHA(c) || ISASCIIDIGIT(c))
/*
 * Numeric value of the digit character `c': '0'..'9' map to 0..9,
 * upper-case letters to 10 and up from 'A', anything else is treated
 * as a lower-case letter (10 and up from 'a').  The macro does not
 * validate `c'; callers check ISASCIIXDIGIT() first.
 */
#define DEBASE(c) \
((c) - (ISASCIIDIGIT(c) ? '0' : \
(ISASCIIUPPER(c) ? 'A' : 'a') - 10))
/* Service name/port used when -p is not given. */
#define GEMINI_DEFAULT_PORT "1965"
/* Size limits used for the request buffer and for the -P prompt. */
#define GEMINI_URL_MAX 1024
#define GEMINI_META_MAX 1024
/*
 * Resume a partial transfer: while fewer than `s' bytes have been
 * handled, call f(c, b + i, s - i) and advance `i' by each positive
 * result.  The loop stops on the first result <= 0, which is left in
 * `r' for the caller to inspect.  `r' is not assigned when `i' is
 * already >= `s' on entry.
 */
#define NIO_R(i, r, f, c, b, s) \
do { \
while ((i) < (s) && \
((r) = (f)((c), (b) + (i), (s) - (i))) > 0) \
(i) += (r); \
} while (0)
/* Like NIO_R(), but start from scratch by zeroing `i' and `r' first. */
#define NIO(i, r, f, c, b, s) \
do { \
(i) = 0; \
(r) = 0; \
NIO_R((i), (r), (f), (c), (b), (s)); \
} while (0)
/*
 * States of a connection slot, as stored in `struct connection.state'
 * and dispatched on by evconn().
 */
enum {
/* A listening socket; readable means accept(2) can be tried. */
CONN_STATE_LISTEN,
/* An accepted client whose TLS handshake is still in progress. */
CONN_STATE_HANDSHAKE,
/* Handshake done; the request line is being read. */
CONN_STATE_RECV,
/* The response buffer is being written to the client. */
CONN_STATE_SEND,
/* The response was sent; the TLS session is being closed. */
CONN_STATE_CLOSE
};
/*
 * One entry of the connection table: either a listening socket or a
 * single client.  Entries are copied with memmove(3) by remconn(), so
 * the structure must stay trivially relocatable.
 */
struct connection {
/* Raw request bytes as received from the client. */
char request[GEMINI_URL_MAX + 2];
/* Response bytes (status line plus any body) to be sent. */
char response[GEMINI_URL_MAX + 32];
/* Creation time, set by addconn(); used by main() to expire clients. */
time_t ctime;
/* Per-client TLS context; NULL for listening sockets. */
struct tls *tls;
/* After parsing: offset of the first byte of the query in `request'. */
size_t reqi;
/* Bytes received so far; after parsing, end offset of the query. */
size_t reql;
/* Number of response bytes already written. */
size_t resi;
/* Total number of response bytes to write. */
size_t resl;
/* Underlying socket descriptor; closed by remconn(). */
int socket;
/* One of the CONN_STATE_* values. */
int state;
};
/*
 * Positional initializer for `struct connections'; the values follow
 * the member order of the structure below, so the two must be kept in
 * sync.  `slowdown' and `max' start out at SIZE_MAX (no limit) until
 * main() computes the real values.
 */
#define CONNECTIONS_INITIALIZER \
{ NULL, NULL, NULL, "echo? ", NULL, NULL, \
0, SIZE_MAX, 0, SIZE_MAX, BACKLOG, TIMEOUT }
/*
 * The server state: the TLS server context plus two parallel arrays,
 * `list' and `pfds', that describe every listening socket and client.
 * Index n of one array always refers to the same socket as index n of
 * the other.
 */
struct connections {
/* TLS server context that client contexts are accepted from. */
struct tls *server;
/* Connection table; `count' entries in use, `curmax' allocated. */
struct connection *list;
/* poll(2) descriptors, parallel to `list'. */
struct pollfd *pfds;
/* Text sent with status 10 when a request carries no query. */
const char *prompt;
/* Optional callback invoked by evconn() once a query was received. */
void (*oninput)(struct connections *, size_t);
/* Opaque user pointer; not used by the code visible in this file. */
void *ctx;
/* Number of entries in use (listeners included). */
size_t count;
/* Once `count' reaches this value new clients get status 44. */
size_t slowdown;
/* Number of entries currently allocated in `list' and `pfds'. */
size_t curmax;
/* Upper bound for `count'; listeners stop polling when it is reached. */
size_t max;
/* Backlog handed to listen(2). */
int backlog;
/* Client lifetime limit; also sent as the meta of status 44. */
int timeout;
};
/* Signal handlers and process clean-up. */
static void exitsig(int);
static void gracefulexitsig(int);
static void rmpidfile(void);
/* Option parsing helper. */
static uintmax_t strtou(const char *, uintmax_t, uintmax_t);
/* Default `oninput' callback. */
static void oninput(struct connections *, size_t);
/* Connection table management and the per-connection state machine. */
static int addaddr(struct connections *, const char *,
    const char *);
static int addconn(struct connections *, int, int);
static void remconn(struct connections *, size_t);
static int evconn(struct connections *, size_t);
/* URL helpers. */
static size_t depct(char *, size_t, const char *, size_t);
static size_t pctmatch(const char *, size_t, const char *);
static int validurl(const char *, size_t);
/*
 * Flags raised by the signal handlers and polled by the main loop.
 * They are written asynchronously, so they must be of type
 * `volatile sig_atomic_t' for the accesses to be well defined.
 */
static volatile sig_atomic_t sigexit = 0;
static volatile sig_atomic_t siggracefulexit = 0;
/* Path of the PID file (-D), or NULL when none was requested. */
static const char *pidfile = NULL;
/*
 * Entry point.
 *
 * Parses the command line, configures libtls with the given key and
 * certificate, opens the listening sockets, optionally daemonizes and
 * writes a PID file, and then runs the poll(2) loop until a signal
 * requests termination.
 *
 * Options:
 *   -D file   become a daemon and write the PID to `file'
 *   -P text   prompt sent with status 10 (at most GEMINI_META_MAX bytes)
 *   -b n      listen(2) backlog
 *   -c file   certificate file (required)
 *   -k file   key file (required)
 *   -m n      maximum number of client connections
 *   -p port   port or service name to listen on
 *   -s n      slow-down divisor
 *   -t n      connection timeout
 * Any remaining arguments are host names/addresses to listen on.
 *
 * Returns 0 after a signal-triggered shutdown, EX_USAGE on bad usage
 * and EX_NOHOST when no listening socket could be opened; other
 * failures terminate through err(3)/errx(3).
 */
int
main(int argc, char *argv[])
{
	struct connections conns = CONNECTIONS_INITIALIZER;
	time_t now;
	struct tls_config *config;
	const char *port = GEMINI_DEFAULT_PORT;
	const char *keyfile = NULL;
	const char *certfile = NULL;
	FILE *fp;
	size_t max = MAXCONN;
	size_t slow = SLOWDOWN;
	size_t i;
	int timeout = INFTIM;
	int c;

#ifdef __OpenBSD__
	if (pledge("stdio rpath wpath cpath inet proc unveil", NULL)
	    != 0)
		err(EX_OSERR, "Could not pledge(2)");
#endif
	while ((c = getopt(argc, argv, "D:P:b:c:k:m:p:s:t:")) != -1)
		switch (c) {
		case 'D':
			pidfile = optarg;
			break;
		case 'P':
			if (strlen(optarg) > GEMINI_META_MAX)
				errx(EX_USAGE, "The prompt is too big");
			conns.prompt = optarg;
			break;
		case 'b':
			conns.backlog = (int)strtou(optarg, 0, INT_MAX);
			/* Must not fall through: -b is not a certificate. */
			break;
		case 'c':
			certfile = optarg;
			break;
		case 'k':
			keyfile = optarg;
			break;
		case 's':
			slow = (size_t)strtou(optarg, 0, SIZE_MAX);
			break;
		case 'm':
			max = (size_t)strtou(optarg, 0, SIZE_MAX);
			break;
		case 'p':
			port = optarg;
			break;
		case 't':
			conns.timeout = (int)strtou(optarg, 0, INT_MAX);
			break;
		default:
			goto usage;
		}
	argc -= optind;
	argv += optind;
	if (keyfile == NULL || certfile == NULL)
		goto usage;
#ifdef __OpenBSD__
	if (unveil(certfile, "r") != 0)
		err(EX_NOINPUT, "Failed to unveil: %s", certfile);
	if (unveil(keyfile, "r") != 0)
		err(EX_NOINPUT, "Failed to unveil: %s", keyfile);
	if (pidfile != NULL && unveil(pidfile, "wc") != 0)
		err(EX_NOINPUT, "Failed to unveil: %s", pidfile);
	if (pledge("stdio rpath wpath cpath inet proc", NULL) != 0)
		err(EX_OSERR, "Could not pledge(2)");
#endif
	if ((conns.server = tls_server()) == NULL ||
	    (config = tls_config_new()) == NULL)
		err(EX_OSERR, "Could not create the TLS structures");
	if (tls_config_set_key_file(config, keyfile) != 0 ||
	    tls_config_set_cert_file(config, certfile) != 0)
		errx(EX_SOFTWARE, "TLS-Config Error: %s",
		    tls_config_error(config));
	/* The error belongs to the context that was being configured. */
	if (tls_configure(conns.server, config) != 0)
		errx(EX_SOFTWARE, "TLS Error: %s",
		    tls_error(conns.server));
	/* Without host arguments listen on whatever a NULL host yields. */
	if (*argv == NULL && addaddr(&conns, NULL, port) <= 0)
		err(EX_NOHOST, "Could not listen on: *:%s", port);
	for (; *argv != NULL; argv++)
		if (addaddr(&conns, *argv, port) <= 0)
			warn("Could not listen on: %s:%s", *argv, port);
	if (conns.count == 0)
		return (EX_NOHOST);
	/*
	 * The table holds listeners and clients alike, so the limit is
	 * the client maximum plus the number of listeners.  A maximum
	 * of 0, or one that would overflow, leaves the table unbounded.
	 */
	if (max > 0 && SIZE_MAX - conns.count >= max)
		conns.max = max + conns.count;
	/* Start refusing with status 44 in the last 1/slow of the table. */
	conns.slowdown = conns.max;
	if (slow > 0)
		conns.slowdown -= (conns.max + (slow - 1)) / slow;
	if (pidfile != NULL) {
		if (daemon(1, 1) != 0)
			err(EX_OSERR, "Failed to become a daemon");
		if ((fp = fopen(pidfile, "w")) == NULL)
			err(EX_CANTCREAT,
			    "Failed to write the PID file");
		/* pid_t is a signed type; print it as such. */
		(void)fprintf(fp, "%ld\n", (long)getpid());
		(void)fclose(fp);
		(void)atexit(&rmpidfile);
	}
	(void)signal(SIGINT, &gracefulexitsig);
	(void)signal(SIGTERM, &exitsig);
#ifdef __OpenBSD__
	if (pledge("stdio rpath cpath inet", NULL) != 0)
		err(EX_OSERR, "Could not pledge(2)");
#endif
	conns.oninput = &oninput;
	/* An interrupted poll(2) is not an error; it lets us see signals. */
	while ((c = poll(conns.pfds, conns.count, timeout)) != -1 ||
	    errno == EINTR) {
		/*
		 * SIGTERM exits at once; SIGINT waits until only the
		 * listeners are left (timeout == INFTIM, see below).
		 */
		if (sigexit || (siggracefulexit && timeout == INFTIM)) {
			while (conns.count > 0)
				remconn(&conns, conns.count - 1);
			free(conns.list);
			free(conns.pfds);
			return (0);
		}
		/*
		 * Remove all connections that timed-out.  remconn()
		 * shifts the following entries down, hence the inner
		 * `while' that re-examines the same index.
		 */
		now = time(NULL);
		for (i = 0; i < conns.count; i++)
			while (i < conns.count &&
			    conns.list[i].state != CONN_STATE_LISTEN &&
			    difftime(now, conns.list[i].ctime)
			    >= (double)conns.timeout)
				remconn(&conns, i);
		/* Timeout or EINTR: `revents' holds nothing of interest. */
		if (c <= 0)
			continue;
		/* Remove dead connections */
		for (i = 0; i < conns.count; i++)
			while (i < conns.count &&
			    (conns.pfds[i].revents &
			    (POLLNVAL | POLLHUP | POLLERR)))
				remconn(&conns, i);
		/* Check for events */
		for (i = 0; i < conns.count; i++)
			while (i < conns.count &&
			    (conns.pfds[i].revents &
			    (POLLIN | POLLOUT)) != 0 &&
			    evconn(&conns, i) != 0)
				remconn(&conns, i);
		/*
		 * Prevent poll(2) from timing out when there
		 * are no active connections.
		 */
		for (i = 0; i < conns.count &&
		    conns.list[i].state == CONN_STATE_LISTEN; i++)
			/* do nothing */;
		/*
		 * NOTE(review): poll(2) takes milliseconds whereas
		 * conns.timeout is compared as seconds above, so the
		 * loop wakes up far more often than needed -- confirm
		 * whether a conversion was intended.
		 */
		timeout = i < conns.count && conns.timeout > 0 ?
		    conns.timeout : INFTIM;
	}
	err(EX_OSERR, "poll(2) error");
usage:
	/* NOTE(review): -P is accepted above but not listed here. */
	(void)fprintf(stderr,
	    "usage: gemini-echo-server [-D pid-file] [-b backlog]\n"
	    " [-m maximum-connections] "
	    "[-p port]\n"
	    " [-s slowdown-divisor] "
	    "[-t timeout]\n"
	    " -c cert-file -k key-file "
	    "[host ...]\n");
	return (EX_USAGE);
}
/*
 * Signal handler: ask the main loop to shut down right away by
 * raising `sigexit'.  The signal number itself is of no interest.
 */
void
exitsig(int sig)
{
	(void)sig;
	sigexit = 1;
}
/*
 * Signal handler: ask the main loop to shut down once it is idle by
 * raising `siggracefulexit'.  The signal number itself is ignored.
 */
void
gracefulexitsig(int sig)
{
	(void)sig;
	siggracefulexit = 1;
}
/*
 * atexit(3) hook: delete the PID file, if one was configured.
 * A failure to remove it is deliberately ignored.
 */
void
rmpidfile(void)
{
	if (pidfile == NULL)
		return;
	(void)remove(pidfile);
}
/*
 * Parse `str' as an unsigned number within [n, m].
 *
 * The base is detected from the prefix (decimal, 0 octal, 0x hex).
 * The whole string must be consumed, at least one digit must be
 * present and a minus sign is not accepted: strtoumax(3) would
 * silently negate the value and wrap it to a huge number.
 *
 * Does not return on failure: the process exits with EX_USAGE
 * through err(3), with errno set to EINVAL or ERANGE.
 */
uintmax_t
strtou(const char *str, uintmax_t n, uintmax_t m)
{
	const char *s = str;
	char *end = NULL;
	uintmax_t um;

	/* Find the optional sign behind the white space strtoumax skips */
	while (*s == ' ' || (*s >= '\t' && *s <= '\r'))
		s++;
	errno = 0;
	um = strtoumax(str, &end, 0);
	/* end == str: nothing was converted (e.g. an empty string) */
	if (errno == 0 && (end == NULL || end == str || *end != '\0' ||
	    *s == '-'))
		errno = EINVAL;
	if (errno == 0 && (um < n || um > m))
		errno = ERANGE;
	if (errno != 0)
		err(EX_USAGE, "Bad number: %s", str);
	return (um);
}
void
oninput(struct connections *conns, size_t i)
{
/* sizeof(conns->list[i].request) < INT_MAX */
(void)printf("%.*s\n",
(int)(conns->list[i].reql - conns->list[i].reqi),
conns->list[i].request + conns->list[i].reqi);
}
/*
 * Open a listening TCP socket for every address that `addr' and
 * `serv' resolve to and add each one to the connection table.
 *
 * addr: host name or address; NULL selects the wildcard address
 *       (AI_PASSIVE).
 * serv: port number or service name.
 *
 * Returns the number of listening sockets that were added (possibly
 * 0), or -1 when the name could not be resolved.  Failures affecting
 * a single address are reported with warn(3) and skipped.
 */
int
addaddr(struct connections *conns, const char *addr, const char *serv)
{
	struct addrinfo hints;
	struct addrinfo *ai;
	struct addrinfo *a;
	size_t old = conns->count;
	int fd;

	/* Gemini uses TCP; AI_PASSIVE makes a NULL host mean "any" */
	(void)memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;
	hints.ai_flags = AI_PASSIVE;
	if (getaddrinfo(addr, serv, &hints, &ai) != 0)
		return (-1);
	for (a = ai; a != NULL; a = a->ai_next) {
		if (a->ai_protocol != IPPROTO_TCP)
			continue;
		if ((fd = socket(a->ai_family, SOCK_STREAM,
		    IPPROTO_TCP)) < 0) {
			warn("Failed to open a socket");
			continue;
		}
		/* warn(3) first: close(2) could overwrite errno */
		if (bind(fd, a->ai_addr, a->ai_addrlen) < 0) {
			warn("Failed to bind(2)");
			(void)close(fd);
			continue;
		}
		if (listen(fd, conns->backlog) != 0) {
			warn("Failed to listen(2)");
			(void)close(fd);
			continue;
		}
		if (addconn(conns, fd, CONN_STATE_LISTEN) != 0) {
			warn("Failed to add a listening connection");
			(void)close(fd);
			continue;
		}
	}
	/* The address list is no longer needed once the sockets exist */
	freeaddrinfo(ai);
	return ((int)(conns->count - old));
}
int
addconn(struct connections *conns, int fd, int state)
{
struct connection *conn;
void *p;
size_t n;
size_t i;
int f;
if (conns->count >= conns->curmax) {
if (conns->curmax >= conns->max)
return (1);
/* Grow the lists */
n = conns->count + 1;
if ((p = reallocarray(conns->list, n,
sizeof(conns->list[0]))) == NULL)
return (-1);
conns->list = p;
if ((p = reallocarray(conns->pfds, n,
sizeof(conns->pfds[0]))) == NULL)
return (-1);
conns->pfds = p;
conns->curmax = n;
}
/* Initialize the connection structure */
conns->list[conns->count].ctime = time(NULL);
conns->list[conns->count].tls = NULL;
conns->list[conns->count].reqi = 0;
conns->list[conns->count].reql = 0;
conns->list[conns->count].resi = 0;
conns->list[conns->count].resl = 0;
conns->list[conns->count].socket = fd;
conns->list[conns->count].state = state;
conns->pfds[conns->count].fd = fd;
conns->pfds[conns->count].events = POLLIN;
conns->pfds[conns->count].revents = 0;
/* Set O_NONBLOCK */
if ((f = fcntl(fd, F_GETFL)) < 0 ||
fcntl(fd, F_SETFL, f | O_NONBLOCK) < 0)
return (-1);
if (state == CONN_STATE_HANDSHAKE) {
/* Initialize TLS */
if (tls_accept_socket(conns->server,
&conns->list[conns->count].tls, fd) != 0)
return (-2);
conns->pfds[conns->count].events |= POLLOUT;
}
if (++conns->count >= conns->max) {
/* Stop allowing new connections */
for (i = 0; i < conns->count; i++)
if (conns->list[i].state == CONN_STATE_LISTEN)
conns->pfds[i].events = 0;
}
return (0);
}
void
remconn(struct connections *conns, size_t i)
{
size_t j;
if (i >= conns->count)
return;
if (conns->count == conns->max) {
/* Allow new connections */
for (j = 0; j < conns->count; j++)
if (conns->list[j].state == CONN_STATE_LISTEN)
conns->pfds[j].events = POLLIN;
}
if (conns->list[i].tls != NULL)
tls_free(conns->list[i].tls);
if (conns->list[i].socket >= 0)
(void)close(conns->list[i].socket);
if (i + 1 == conns->count) {
conns->count--;
return;
}
/* There can not be overflows */
assert(sizeof(conns->list[0])
<= SIZE_MAX / conns->count - i - 1);
assert(sizeof(conns->pfds[0])
<= SIZE_MAX / conns->count - i - 1);
/* Shift the lists down one place */
(void)memmove(conns->list + i, conns->list + i + 1,
sizeof(conns->list[0]) * (conns->count - i - 1));
(void)memmove(conns->pfds + i, conns->pfds + i + 1,
sizeof(conns->pfds[0]) * (conns->count - i - 1));
conns->count--;
}
/*
 * Advance the state machine of connection `i' as far as possible
 * without blocking.
 *
 *   LISTEN    accept a client and start its handshake
 *   HANDSHAKE finish the TLS handshake; answer with status 44 when
 *             the server is at or above its slow-down threshold
 *   RECV      read the request line, validate it and build the
 *             response (59 invalid, 10 prompt or 20 echo)
 *   SEND      write the response
 *   CLOSE     shut the TLS session down
 *
 * Whenever libtls asks for it, the poll(2) events of the entry are
 * switched to POLLIN or POLLOUT.
 *
 * Returns 0 when the connection must be kept, and non-zero when the
 * caller has to remove it with remconn(): 1 after an orderly close,
 * -1 after an error or when the peer went away.
 */
int
evconn(struct connections *conns, size_t i)
{
	struct connection *conn;
	const char *response = NULL;
	ssize_t r = 0;
	int e;

	assert(i < conns->count);
	conn = &conns->list[i];
	switch (conn->state) {
	case CONN_STATE_LISTEN:
		/* Try to accept a new connection */
		if ((e = accept(conn->socket, NULL, NULL)) < 0)
			break;
		/*
		 * addconn() may move the arrays, so `conn' must not be
		 * used below.  On failure the socket is still ours and
		 * has to be closed here, or it would leak.
		 */
		if (addconn(conns, e, CONN_STATE_HANDSHAKE) != 0) {
			(void)close(e);
			break;
		}
		/* The new client is the last entry; get it started */
		if (evconn(conns, conns->count - 1) != 0)
			remconn(conns, conns->count - 1);
		break;
	case CONN_STATE_HANDSHAKE:
		/* Try to complete the TLS handshake */
		switch (e = tls_handshake(conn->tls)) {
		case TLS_WANT_POLLIN:
			conns->pfds[i].events = POLLIN;
			break;
		case TLS_WANT_POLLOUT:
			conns->pfds[i].events = POLLOUT;
			break;
		case -1:
			return (-1);
		case 0:
			if (conns->count >= conns->slowdown) {
				/* Status 44: Slowdown */
				conn->resl = snprintf(conn->response,
				    sizeof(conn->response),
				    "44\t%d\r\n", conns->timeout);
				conn->state = CONN_STATE_SEND;
			} else
				conn->state = CONN_STATE_RECV;
			return (evconn(conns, i));
		}
		break;
	case CONN_STATE_RECV:
		/* Try to receive the request */
		NIO_R(conn->reql, r, tls_read, conn->tls,
		    conn->request, sizeof(conn->request));
		if (r == -1 /* Error */)
			return (-1);
		if (r == TLS_WANT_POLLIN)
			conns->pfds[i].events = POLLIN;
		else if (r == TLS_WANT_POLLOUT)
			conns->pfds[i].events = POLLOUT;
		if (conn->reql > 0 &&
		    conn->request[conn->reql - 1] == '\n') {
			/*
			 * The request is complete, with no known
			 * extraneous characters.
			 */
			/* Strip off the /\r?\n/ */
			conn->reql--;
			if (conn->reql > 0 &&
			    conn->request[conn->reql - 1] == '\r')
				conn->reql--;
			if (!validurl(conn->request, conn->reql)) {
				response = "59\tInvalid URL\r\n";
				goto response;
			}
			/* Strip off the fragment */
			for (conn->reqi = 0; conn->reqi < conn->reql &&
			    conn->request[conn->reqi] != '#';
			    conn->reqi++)
				/* do nothing */;
			conn->reql = conn->reqi;
			/* Find the query */
			for (conn->reqi = 0; conn->reqi < conn->reql &&
			    conn->request[conn->reqi] != '?';
			    conn->reqi++)
				/* do nothing */;
			/* If there is no query, ask for one */
			if (conn->reqi == conn->reql) {
				/* Status 10: Input */
				conn->resl = snprintf(conn->response,
				    sizeof(conn->response),
				    "10\t%s\r\n", conns->prompt);
				conn->state = CONN_STATE_SEND;
				return (evconn(conns, i));
			}
			/* Skip the `?' itself */
			conn->reqi++;
			if (conns->oninput != NULL)
				conns->oninput(conns, i);
			/* Echo the un-escaped query as the response. */
			/* Status 20: OK */
			/*
			 * sizeof(conn->response) >= GEMINI_URL_MAX +
			 * strlen("20\ttext/plain; charset=utf-8\r\n")
			 */
			(void)strncpy(conn->response,
			    "20\ttext/plain; charset=utf-8\r\n",
			    sizeof(conn->response));
			r = strnlen(conn->response,
			    sizeof(conn->response));
			r += depct(conn->response + r,
			    sizeof(conn->response) - r,
			    conn->request + conn->reqi,
			    conn->reql - conn->reqi);
			assert(r < sizeof(conn->response));
			/* Add a new-line if needed */
			if (conn->response[r - 1] != '\n')
				conn->response[r++] = '\n';
			conn->resl = r;
			conn->state = CONN_STATE_SEND;
			return (evconn(conns, i));
		}
		if (conn->reql == sizeof(conn->request)) {
			response = "59\tRequest Too Long\r\n";
			goto response;
		}
		/*
		 * tls_read() returning 0 means the peer closed the
		 * session.  The request can never be completed, so
		 * drop the connection instead of polling a socket
		 * that stays readable until the timeout strikes.
		 */
		if (r == 0)
			return (-1);
		break;
	case CONN_STATE_SEND:
		/* Try to send the response */
		NIO_R(conn->resi, r, tls_write, conn->tls,
		    conn->response, conn->resl);
		if (r == -1 /* Error */)
			return (-1);
		if (r == TLS_WANT_POLLIN)
			conns->pfds[i].events = POLLIN;
		else if (r == TLS_WANT_POLLOUT)
			conns->pfds[i].events = POLLOUT;
		if (conn->resi == conn->resl) {
			conn->state = CONN_STATE_CLOSE;
			return (evconn(conns, i));
		}
		break;
	case CONN_STATE_CLOSE:
		/* Try to close the connection */
		switch (tls_close(conn->tls)) {
		case TLS_WANT_POLLIN:
			conns->pfds[i].events = POLLIN;
			break;
		case TLS_WANT_POLLOUT:
			conns->pfds[i].events = POLLOUT;
			break;
		case 0:
			/* Done; let the caller remove the entry */
			return (1);
		default:
			return (-1);
		}
		break;
	default:
		abort(); /* Unreachable */
	}
	return (0);
response:
	/* Send a fixed status line, e.g. for an error */
	assert(response != NULL &&
	    strlen(response) < sizeof(conn->response));
	conn->resl = strlen(response);
	(void)memcpy(conn->response, response, conn->resl);
	conn->state = CONN_STATE_SEND;
	return (evconn(conns, i));
}
/*
 * Percent-decode the `sl' bytes at `src' into `dst', which has room
 * for `dl' bytes.  A `%' that is not followed by two hexadecimal
 * digits is copied literally.  No terminating NUL is written.
 *
 * Returns the length of the complete decoded text.  As with
 * snprintf(3), this can exceed `dl', in which case the output was
 * truncated to `dl' bytes.
 */
size_t
depct(char *dst, size_t dl, const char *src, size_t sl)
{
	size_t in = 0;
	size_t out = 0;
	char ch;

	while (in < sl) {
		if (src[in] == '%' && in + 2 < sl &&
		    ISASCIIXDIGIT(src[in + 1]) &&
		    ISASCIIXDIGIT(src[in + 2])) {
			/* Two hex digits make up one byte */
			ch = DEBASE(src[in + 1]) << 4 |
			    DEBASE(src[in + 2]);
			in += 3;
		} else {
			ch = src[in];
			in++;
		}
		/* Keep counting even when there is no room left */
		if (out < dl)
			dst[out] = ch;
		out++;
	}
	return (out);
}
/*
 * Measure the leading part of the `l' bytes at `s' that consists of
 * ASCII letters and digits, bytes listed in the string `c' (NULL is
 * taken as the empty set) and complete percent-escapes (`%' followed
 * by two hexadecimal digits).
 *
 * Returns the length of that prefix, which is `l' if everything
 * matched.
 */
size_t
pctmatch(const char *s, size_t l, const char *c)
{
	size_t i;

	if (c == NULL)
		c = "";
	for (i = 0; i < l; i++) {
		if (l - i >= 3 && s[i] == '%' &&
		    ISASCIIXDIGIT(s[i + 1]) &&
		    ISASCIIXDIGIT(s[i + 2]))
			i += 2;
		/*
		 * A NUL byte needs an explicit test: strchr(3) finds
		 * the terminator of `c' and would let it through.
		 */
		else if (!ISASCIIALNUM(s[i]) &&
		    (s[i] == '\0' || strchr(c, s[i]) == NULL))
			break;
	}
	return (i);
}
/*
 * Check that the `l' bytes at `req' form an acceptable Gemini URL:
 *
 *   [ "gemini://" | "//" ] [ userinfo "@" ] host [ ":" port ]
 *   [ "/" path ] [ "?" query ] [ "#" fragment ]
 *
 * A scheme, when present, must be exactly `gemini' (in any case).
 * The host is either a bracketed IP-literal or a non-empty reg-name.
 *
 * Returns non-zero when the whole input matched, 0 otherwise.
 */
int
validurl(const char *req, size_t l)
{
	static const char scheme[] = "gemini";
	size_t i = 0;
	size_t j;

	if (l == 0)
		return (0);
	/* Optionally skip the `scheme://' or the `//' */
	/*
	 * The `//' is treated as part of the `scheme' because
	 * `hostname:port' should be preferred over `scheme:hostname'.
	 */
	j = i;
	if (ISASCIIALPHA(req[j])) {
		for (j++; j < l && (ISASCIIALNUM(req[j]) ||
		    req[j] == '+' || req[j] == '-' ||
		    req[j] == '.'); j++)
			/* do nothing */;
		if (l - j >= 3 && req[j] == ':' && req[j + 1] == '/' &&
		    req[j + 2] == '/') {
			/*
			 * The lengths must agree as well, or a mere
			 * prefix such as `gem://' would be accepted.
			 */
			if (j - i != sizeof(scheme) - 1 ||
			    strncasecmp(req + i, scheme, j - i) != 0)
				return (0);
			i = j + 3;
		}
	} else if (l - i >= 2 && req[i] == '/' && req[i + 1] == '/')
		i += 2;
	/* Optionally skip the `userinfo@' */
	j = pctmatch(req + i, l - i, "~$&'()*+,-.:;=_~");
	if (i + j < l && req[i + j] == '@')
		i += j + 1;
	/*
	 * Process any IP-literal.
	 * It is not really worth it to actually parse IPv6 strings
	 * just to extract the hostname for tls_connect(3).
	 * Otherwise process the `reg-name', of which IPv4 addresses
	 * are a subset.
	 */
	if (i < l && req[i] == '[') {
		/* strchr(3) matches NUL, so exclude it explicitly */
		for (i++; i < l && (ISASCIIALNUM(req[i]) ||
		    (req[i] != '\0' &&
		    strchr("~$&'()*+,-.:;=_~", req[i]) != NULL)); i++)
			/* do nothing */;
		if (i >= l || req[i] != ']')
			return (0);
		i++;
	} else if ((j = i + pctmatch(req + i, l - i,
	    "~$&'()*+,-.;=_~")) > i) {
		i = j;
	} else {
		/* An empty host is not acceptable */
		return (0);
	}
	/* Optional port: a colon followed by any number of digits */
	if (i < l && req[i] == ':')
		for (i++; i < l && ISASCIIDIGIT(req[i]); i++)
			/* do nothing */;
	/* Optional path */
	if (i < l && req[i] == '/')
		i += 1 + pctmatch(req + i + 1, l - i - 1,
		    "!$&'()*+,-./:;=@_~");
	/* Optional query */
	if (i < l && req[i] == '?')
		i += 1 + pctmatch(req + i + 1, l - i - 1,
		    "!$&'()*+,-./:;=?@_~");
	/* Optional fragment */
	if (i < l && req[i] == '#')
		i += 1 + pctmatch(req + i + 1, l - i - 1,
		    "!$&'()*+,-./:;=?@_~");
	/* Anything left over means the URL is malformed */
	return (i == l);
}