#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifndef BACKLOG #define BACKLOG 5 #endif #ifndef TIMEOUT #define TIMEOUT 5 #endif #ifndef MAXCONN #define MAXCONN 5 #endif #ifndef SLOWDOWN #define SLOWDOWN 10 #endif #define ISASCIILOWER(c) ((c) >= 'a' && (c) <= 'z') #define ISASCIIUPPER(c) ((c) >= 'A' && (c) <= 'Z') #define ISASCIIDIGIT(c) ((c) >= '0' && (c) <= '9') #define ISASCIIXDIGIT(c) (ISASCIIDIGIT(c) || \ ((c) >= 'A' && (c) <= 'F') || ((c) >= 'a' && (c) <= 'f')) #define ISASCIIALPHA(c) (ISASCIILOWER(c) || ISASCIIUPPER(c)) #define ISASCIIALNUM(c) (ISASCIIALPHA(c) || ISASCIIDIGIT(c)) #define DEBASE(c) \ ((c) - (ISASCIIDIGIT(c) ? '0' : \ (ISASCIIUPPER(c) ? 'A' : 'a') - 10)) #define GEMINI_DEFAULT_PORT "1965" #define GEMINI_URL_MAX 1024 #define GEMINI_META_MAX 1024 #define NIO_R(i, r, f, c, b, s) \ do { \ while ((i) < (s) && \ ((r) = (f)((c), (b) + (i), (s) - (i))) > 0) \ (i) += (r); \ } while (0) #define NIO(i, r, f, c, b, s) \ do { \ (i) = 0; \ (r) = 0; \ NIO_R((i), (r), (f), (c), (b), (s)); \ } while (0) enum { CONN_STATE_LISTEN, CONN_STATE_HANDSHAKE, CONN_STATE_RECV, CONN_STATE_SEND, CONN_STATE_CLOSE }; struct connection { char request[GEMINI_URL_MAX + 2]; char response[GEMINI_URL_MAX + 32]; time_t ctime; struct tls *tls; size_t reqi; size_t reql; size_t resi; size_t resl; int socket; int state; }; #define CONNECTIONS_INITIALIZER \ { NULL, NULL, NULL, "echo? ", NULL, NULL, \ 0, SIZE_MAX, 0, SIZE_MAX, BACKLOG, TIMEOUT } struct connections { struct tls *server; struct connection *list; struct pollfd *pfds; const char *prompt; void (*oninput)(struct connections *, size_t); void *ctx; size_t count; size_t slowdown; size_t curmax; size_t max; int backlog; int timeout; }; static void exitsig(int); static void gracefulexitsig(int); static void rmpidfile(void); static uintmax_t strtou(const char *, uintmax_t, uintmax_t); static void oninput(struct connections *, size_t); static int addaddr(struct connections *, const char *, const char *); static int addconn(struct connections *, int, int); static void remconn(struct connections *, size_t); static int evconn(struct connections *, size_t); static size_t depct(char *, size_t, const char *, size_t); static size_t pctmatch(const char *, size_t, const char *); static int validurl(const char *, size_t); static int sigexit = 0; static int siggracefulexit = 0; static const char *pidfile = NULL; int main(int argc, char *argv[]) { struct connections conns = CONNECTIONS_INITIALIZER; time_t now; struct tls *tls; struct tls_config *config; const char *port = GEMINI_DEFAULT_PORT; const char *keyfile = NULL; const char *certfile = NULL; FILE *fp; size_t max = MAXCONN; size_t slow = SLOWDOWN; size_t i; int timeout = INFTIM; int c; #ifdef __OpenBSD__ if (pledge("stdio rpath wpath cpath inet proc unveil", NULL) != 0) err(EX_OSERR, "Could not pledge(2)"); #endif while ((c = getopt(argc, argv, "D:P:b:c:k:m:p:s:t:")) != -1) switch (c) { case 'D': pidfile = optarg; break; case 'P': if (strlen(optarg) > GEMINI_META_MAX) errx(EX_USAGE, "The prompt is too big"); conns.prompt = optarg; break; case 'b': conns.backlog = (int)strtou(optarg, 0, INT_MAX); case 'c': certfile = optarg; break; case 'k': keyfile = optarg; break; case 's': slow = (size_t)strtou(optarg, 0, SIZE_MAX); break; case 'm': max = (size_t)strtou(optarg, 0, SIZE_MAX); break; case 'p': port = optarg; break; case 't': conns.timeout = (int)strtou(optarg, 0, INT_MAX); break; default: goto usage; } argc -= optind; argv += optind; if (keyfile == NULL || certfile == NULL) goto usage; #ifdef __OpenBSD__ if (unveil(certfile, "r") != 0) err(EX_NOINPUT, "Failed to unveil: %s", certfile); if (unveil(keyfile, "r") != 0) err(EX_NOINPUT, "Failed to unveil: %s", keyfile); if (pidfile != NULL && unveil(pidfile, "wc") != 0) err(EX_NOINPUT, "Failed to unveil: %s", pidfile); if (pledge("stdio rpath wpath cpath inet proc", NULL) != 0) err(EX_OSERR, "Could not pledge(2)"); #endif if ((conns.server = tls_server()) == NULL || (config = tls_config_new()) == NULL) err(EX_OSERR, "Could not create the TLS structures"); if (tls_config_set_key_file(config, keyfile) != 0 || tls_config_set_cert_file(config, certfile) != 0) errx(EX_SOFTWARE, "TLS-Config Error: %s", tls_config_error(config)); if (tls_configure(conns.server, config) != 0) errx(EX_SOFTWARE, "TLS Error: %s", tls_error(tls)); if (*argv == NULL && addaddr(&conns, NULL, port) <= 0) err(EX_NOHOST, "Could not listen on: *:%s", port); for (; *argv != NULL; argv++) if (addaddr(&conns, *argv, port) <= 0) warn("Could not listen on: %s:%s", *argv, port); if (conns.count == 0) return (EX_NOHOST); if (max > 0 && SIZE_MAX - conns.count >= max) conns.max = max + conns.count; conns.slowdown = conns.max; if (slow > 0) conns.slowdown -= (conns.max + (slow - 1)) / slow; if (pidfile != NULL) { if (daemon(1, 1) != 0) err(EX_OSERR, "Failed to become a daemon"); if ((fp = fopen(pidfile, "w")) == NULL) err(EX_CANTCREAT, "Failed to write the PID file"); (void)fprintf(fp, "%u\n", getpid()); (void)fclose(fp); (void)atexit(&rmpidfile); } (void)signal(SIGINT, &gracefulexitsig); (void)signal(SIGTERM, &exitsig); #ifdef __OpenBSD__ if (pledge("stdio rpath cpath inet", NULL) != 0) err(EX_OSERR, "Could not pledge(2)"); #endif conns.oninput = &oninput; while ((c = poll(conns.pfds, conns.count, timeout)) != -1 || errno == EINTR) { if (sigexit || (siggracefulexit && timeout == INFTIM)) { while (conns.count > 0) remconn(&conns, conns.count - 1); free(conns.list); free(conns.pfds); return (0); } /* Remove all connections that timed-out */ now = time(NULL); for (i = 0; i < conns.count; i++) while (i < conns.count && conns.list[i].state != CONN_STATE_LISTEN && difftime(now, conns.list[i].ctime) >= (double)conns.timeout) remconn(&conns, i); if (c <= 0) continue; /* Remove dead connections */ for (i = 0; i < conns.count; i++) while (i < conns.count && (conns.pfds[i].revents & (POLLNVAL | POLLHUP | POLLERR))) remconn(&conns, i); /* Check for events */ for (i = 0; i < conns.count; i++) while (i < conns.count && (conns.pfds[i].revents & (POLLIN | POLLOUT)) != 0 && evconn(&conns, i) != 0) remconn(&conns, i); /* * Prevent poll(2) from timeing out when there * are no active connections. */ for (i = 0; i < conns.count && conns.list[i].state == CONN_STATE_LISTEN; i++) /* do nothing */; timeout = i < conns.count && conns.timeout > 0 ? conns.timeout : INFTIM; } err(EX_OSERR, "poll(2) error"); usage: (void)fprintf(stderr, "usage: gemini-echo-server [-D pid-file] [-b backlog]\n" " [-m maximum-connections] " "[-p port]\n" " [-s slowdown-divisor] " "[-t timeout]\n" " -c cert-file -k key-file " "[host ...]\n"); return (EX_USAGE); } void exitsig(int sig) { sigexit = 1; } void gracefulexitsig(int sig) { siggracefulexit = 1; } void rmpidfile(void) { if (pidfile != NULL) (void)remove(pidfile); } uintmax_t strtou(const char *str, uintmax_t n, uintmax_t m) { char *p = NULL; uintmax_t um; errno = 0; um = strtoumax(str, &p, 0); if (errno == 0 && (p == NULL || *p != '\0')) errno = EINVAL; if (errno == 0 && (um < n || um > m)) errno = ERANGE; if (errno != 0) err(EX_USAGE, "Bad number: %s", str); return (um); } void oninput(struct connections *conns, size_t i) { /* sizeof(conns->list[i].request) < INT_MAX */ (void)printf("%.*s\n", (int)(conns->list[i].reql - conns->list[i].reqi), conns->list[i].request + conns->list[i].reqi); } int addaddr(struct connections *conns, const char *addr, const char *serv) { struct addrinfo *ai; struct addrinfo *a; size_t old = conns->count; int fd; if (getaddrinfo(addr, serv, NULL, &ai) != 0) return (-1); for (a = ai; a != NULL; a = a->ai_next) { /* Gemini uses TCP */ if (a->ai_protocol != IPPROTO_TCP) continue; if ((fd = socket(a->ai_family, SOCK_STREAM, IPPROTO_TCP)) < 0) { warn("Failed to open a socket"); continue; } if (bind(fd, a->ai_addr, a->ai_addrlen) < 0) { close(fd); warn("Failed to bind(2)"); continue; } if (listen(fd, conns->backlog) != 0) { close(fd); warn("Failed to listen(2)"); continue; } if (addconn(conns, fd, CONN_STATE_LISTEN) != 0) { close(fd); warn("Failed to add a listening connection"); continue; } } return (conns->count - old); } int addconn(struct connections *conns, int fd, int state) { struct connection *conn; void *p; size_t n; size_t i; int f; if (conns->count >= conns->curmax) { if (conns->curmax >= conns->max) return (1); /* Grow the lists */ n = conns->count + 1; if ((p = reallocarray(conns->list, n, sizeof(conns->list[0]))) == NULL) return (-1); conns->list = p; if ((p = reallocarray(conns->pfds, n, sizeof(conns->pfds[0]))) == NULL) return (-1); conns->pfds = p; conns->curmax = n; } /* Initialize the connection structure */ conns->list[conns->count].ctime = time(NULL); conns->list[conns->count].tls = NULL; conns->list[conns->count].reqi = 0; conns->list[conns->count].reql = 0; conns->list[conns->count].resi = 0; conns->list[conns->count].resl = 0; conns->list[conns->count].socket = fd; conns->list[conns->count].state = state; conns->pfds[conns->count].fd = fd; conns->pfds[conns->count].events = POLLIN; conns->pfds[conns->count].revents = 0; /* Set O_NONBLOCK */ if ((f = fcntl(fd, F_GETFL)) < 0 || fcntl(fd, F_SETFL, f | O_NONBLOCK) < 0) return (-1); if (state == CONN_STATE_HANDSHAKE) { /* Initialize TLS */ if (tls_accept_socket(conns->server, &conns->list[conns->count].tls, fd) != 0) return (-2); conns->pfds[conns->count].events |= POLLOUT; } if (++conns->count >= conns->max) { /* Stop allowing new connections */ for (i = 0; i < conns->count; i++) if (conns->list[i].state == CONN_STATE_LISTEN) conns->pfds[i].events = 0; } return (0); } void remconn(struct connections *conns, size_t i) { size_t j; if (i >= conns->count) return; if (conns->count == conns->max) { /* Allow new connections */ for (j = 0; j < conns->count; j++) if (conns->list[j].state == CONN_STATE_LISTEN) conns->pfds[j].events = POLLIN; } if (conns->list[i].tls != NULL) tls_free(conns->list[i].tls); if (conns->list[i].socket >= 0) (void)close(conns->list[i].socket); if (i + 1 == conns->count) { conns->count--; return; } /* There can not be overflows */ assert(sizeof(conns->list[0]) <= SIZE_MAX / conns->count - i - 1); assert(sizeof(conns->pfds[0]) <= SIZE_MAX / conns->count - i - 1); /* Shift the lists down one place */ (void)memmove(conns->list + i, conns->list + i + 1, sizeof(conns->list[0]) * (conns->count - i - 1)); (void)memmove(conns->pfds + i, conns->pfds + i + 1, sizeof(conns->pfds[0]) * (conns->count - i - 1)); conns->count--; } int evconn(struct connections *conns, size_t i) { struct connection *conn; const char *response; ssize_t r; int e; assert(i < conns->count); conn = &conns->list[i]; switch (conn->state) { case CONN_STATE_LISTEN: /* Try to accept a new connection */ if ((e = accept(conn->socket, NULL, NULL)) >= 0 && addconn(conns, e, CONN_STATE_HANDSHAKE) == 0 && evconn(conns, conns->count - 1) != 0) remconn(conns, conns->count - 1); break; case CONN_STATE_HANDSHAKE: /* Try to complete the TLS handshake */ switch(e = tls_handshake(conn->tls)) { case TLS_WANT_POLLIN: conns->pfds[i].events = POLLIN; break; case TLS_WANT_POLLOUT: conns->pfds[i].events = POLLOUT; break; case -1: return (-1); case 0: if (conns->count >= conns->slowdown) { /* Status 44: Slowdown */ conn->resl = snprintf(conn->response, sizeof(conn->response), "44\t%d\r\n", conns->timeout); conn->state = CONN_STATE_SEND; } else conn->state = CONN_STATE_RECV; return (evconn(conns, i)); } break; case CONN_STATE_RECV: /* Try to receive the request */ NIO_R(conn->reql, r, tls_read, conn->tls, conn->request, sizeof(conn->request)); if (r == -1 /* Error */) return (-1); if (r == TLS_WANT_POLLIN) conns->pfds[i].events = POLLIN; else if (r == TLS_WANT_POLLOUT) conns->pfds[i].events = POLLOUT; if (conn->reql > 0 && conn->request[conn->reql - 1] == '\n') { /* * The request is complete, with no known * extraneous characters. */ /* Strip off the /\r?\n/ */ conn->reql--; if (conn->reql > 0 && conn->request[conn->reql - 1] == '\r') conn->reql--; if (!validurl(conn->request, conn->reql)) { response = "59\tInvalid URL\r\n"; goto response; } /* Strip off the fragment */ for (conn->reqi = 0; conn->reqi < conn->reql && conn->request[conn->reqi] != '#'; conn->reqi++) /* do nothing */; conn->reql = conn->reqi; /* Find the query */ for (conn->reqi = 0; conn->reqi < conn->reql && conn->request[conn->reqi] != '?'; conn->reqi++) /* do nothing */; /* If there is no query, ask for one */ if (conn->reqi == conn->reql) { /* Status 10: Input */ conn->resl = snprintf(conn->response, sizeof(conn->response), "10\t%s\r\n", conns->prompt); conn->state = CONN_STATE_SEND; return (evconn(conns, i)); } conn->reqi++; if (conns->oninput != NULL) conns->oninput(conns, i); /* Echo the un-escaped query as the response. */ /* Status 20: OK */ /* * sizeof(conn->response) >= GEMINI_URL_MAX + * strlen("20\ttext/plain; charset=utf-8\r\n") */ (void)strncpy(conn->response, "20\ttext/plain; charset=utf-8\r\n", sizeof(conn->response)); r = strnlen(conn->response, sizeof(conn->response)); r += depct(conn->response + r, sizeof(conn->response) - r, conn->request + conn->reqi, conn->reql - conn->reqi); assert(r < sizeof(conn->response)); /* Add a new-line it needed */ if (conn->response[r - 1] != '\n') conn->response[r++] = '\n'; conn->resl = r; conn->state = CONN_STATE_SEND; return (evconn(conns, i)); } if (conn->reql == sizeof(conn->request)) { response = "59\tRequest Too Long\r\n"; goto response; } break; case CONN_STATE_SEND: /* Try to send the response */ NIO_R(conn->resi, r, tls_write, conn->tls, conn->response, conn->resl); if (r == -1 /* Error */) return (-1); if (r == TLS_WANT_POLLIN) conns->pfds[i].events = POLLIN; else if (r == TLS_WANT_POLLOUT) conns->pfds[i].events = POLLOUT; if (conn->resi == conn->resl) { conn->state = CONN_STATE_CLOSE; return (evconn(conns, i)); } break; case CONN_STATE_CLOSE: /* Try close the connection */ switch (tls_close(conn->tls)) { case TLS_WANT_POLLIN: conns->pfds[i].events = POLLIN; break; case TLS_WANT_POLLOUT: conns->pfds[i].events = POLLOUT; break; case 0: return (1); default: return (-1); } break; default: abort(); /* Unreachable */ } return (0); response: assert(response != NULL && strlen(response) < sizeof(conn->response)); conn->resl = strlen(response); (void)memcpy(conn->response, response, conn->resl); conn->state = CONN_STATE_SEND; return (evconn(conns, i)); } size_t depct(char *dst, size_t dl, const char *src, size_t sl) { size_t i; size_t j; for (i = 0, j = 0; i < sl; i++, j++) { if (src[i] == '%' && i + 2 < sl && ISASCIIXDIGIT(src[i + 1]) && ISASCIIXDIGIT(src[i + 2])) { if (j < dl) dst[j] = DEBASE(src[i + 1]) << 4 | DEBASE(src[i + 2]); i += 2; } else if (j < dl) dst[j] = src[i]; } return (j); } size_t pctmatch(const char *s, size_t l, const char *c) { size_t i; if (c == NULL) c = ""; for (i = 0; i < l; i++) { if (l - i >= 3 && s[i] == '%' && ISASCIIXDIGIT(s[i + 1]) && ISASCIIXDIGIT(s[i + 2])) i += 2; else if (!ISASCIIALNUM(s[i]) && strchr(c, s[i]) == NULL) break; } return (i); } int validurl(const char *req, size_t l) { size_t i = 0; size_t j; if (l == 0) return (0); /* Optionally skip the `scheme://' or the `//' */ /* * The `//' is treated as part of the `scheme' because * `hostname:port' should be preferred over `scheme:hostname'. */ j = i; if (ISASCIIALPHA(req[j])) { for (j++; j < l && (ISASCIIALNUM(req[j]) || req[j] == '+' || req[j] == '-' || req[j] == '.'); j++) /* do nothing */; if (l - j >= 3 && req[j] == ':' && req[j + 1] == '/' && req[j + 2] == '/') { if (strncasecmp(req + i, "gemini", j - i) != 0) return (0); i = j + 3; } } else if (l - i >= 2 && req[i] == '/' && req[i + 1] == '/') i += 2; /* Optionally skip the `userinfo@' */ j = pctmatch(req + i, l - i, "~$&'()*+,-.:;=_~"); if (i + j < l && req[i + j] == '@') i += j + 1; /* * Process any IP-literal. * It is not really worth it to actually parse IPv6 strings * just to extract the hostname for tls_connect(3). * Otherwise process the `reg-name', of which IPv4 addresses * are a subset. */ if (i < l && req[i] == '[') { for (i++; i < l && (ISASCIIALNUM(req[i]) || strchr("~$&'()*+,-.:;=_~", req[i]) != NULL); i++) /* do nothing */; if (i >= l || req[i] != ']') return (0); i++; } else if ((j = i + pctmatch(req + i, l - i, "~$&'()*+,-.;=_~")) > i) { i = j; } else { return (0); } if (i < l && req[i] == ':') for (i++; i < l && ISASCIIDIGIT(req[i]); i++) /* do nothing */; if (i < l && req[i] == '/') i += 1 + pctmatch(req + i + 1, l - i - 1, "!$&'()*+,-./:;=@_~"); if (i < l && req[i] == '?') i += 1 + pctmatch(req + i + 1, l - i - 1, "!$&'()*+,-./:;=?@_~"); if (i < l && req[i] == '#') i += 1 + pctmatch(req + i + 1, l - i - 1, "!$&'()*+,-./:;=?@_~"); return (i == l); }