// © 2020 Erik Rigtorp <erik@rigtorp.se>
// SPDX-License-Identifier: CC0-1.0

// gcc -g -Wall -D_GNU_SOURCE epollserver.c

#include <errno.h>
#include <limits.h> // INT_MAX
#include <netdb.h>  // getaddrinfo, getnameinfo
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h> // epoll_*
#include <sys/socket.h>
#include <unistd.h> // close

// Resolves host/port with getaddrinfo() (AF_UNSPEC, SOCK_STREAM, AI_PASSIVE)
// and, for every address returned, creates a non-blocking, close-on-exec
// socket, binds it, starts listening and registers it with the epoll instance
// efd. An address that fails at any step is reported on stderr, its socket is
// closed, and the next address is tried.
//
// The listening descriptors are not returned to the caller: each one is stored
// in the epoll event data as (fd << 1) | 1, the low bit marking a listener.
//
// Returns 0 if at least one address ended up listening, -1 otherwise.
int xlisten(int efd, const char *host, const char *port) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_PASSIVE;

  struct addrinfo *res = NULL;
  int rc = getaddrinfo(host, port, &hints, &res);
  if (rc != 0) {
    // getaddrinfo() signals failure with a non-zero EAI_* code, not with -1.
    fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(rc));
    return -1;
  }

  int success = -1;
  for (struct addrinfo *rp = res; rp != NULL; rp = rp->ai_next) {
    int fd =
        socket(rp->ai_family, rp->ai_socktype | SOCK_NONBLOCK | SOCK_CLOEXEC,
               rp->ai_protocol);
    if (fd == -1) {
      perror("socket");
      continue;
    }

    int on = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
      perror("setsockopt");
      close(fd);
      continue;
    }

    // Restrict IPv6 sockets to IPv6 traffic so that a separate IPv4 socket
    // can be bound to the same port.
    if (rp->ai_family == AF_INET6 &&
        setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1) {
      perror("setsockopt");
      close(fd);
      continue;
    }

    if (bind(fd, rp->ai_addr, rp->ai_addrlen) == -1) {
      perror("bind");
      close(fd);
      continue;
    }

    // Start listening before registering with epoll, so the socket is never
    // polled while it is still in the unconnected, non-listening state.
    if (listen(fd, 16) == -1) {
      perror("listen");
      close(fd);
      continue;
    }

    struct epoll_event ev;
    memset(&ev, 0, sizeof(ev));
    ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
    ev.data.fd = (fd << 1) | 1; // low bit set: this entry is a listener
    if (epoll_ctl(efd, EPOLL_CTL_ADD, fd, &ev) == -1) {
      perror("epoll_ctl");
      close(fd);
      continue;
    }

    // Print the numeric address that was bound. The buffers are named so that
    // they do not shadow the host parameter.
    char hostbuf[NI_MAXHOST];
    char servbuf[NI_MAXSERV];
    memset(hostbuf, 0, sizeof(hostbuf));
    memset(servbuf, 0, sizeof(servbuf));
    rc = getnameinfo(rp->ai_addr, rp->ai_addrlen, hostbuf, sizeof(hostbuf),
                     servbuf, sizeof(servbuf), NI_NUMERICHOST | NI_NUMERICSERV);
    if (rc != 0) {
      fprintf(stderr, "getnameinfo: %s\n", gai_strerror(rc));
      close(fd); // closing the descriptor also drops it from the epoll set
      continue;
    }

    fprintf(stdout, "listening on %s:%s\n", hostbuf, servbuf);

    success = 0;
  }

  freeaddrinfo(res);

  return success;
}

// Accepts every pending connection on the non-blocking listening socket lfd
// and registers each new connection with the epoll instance epfd. The loop
// runs until accept4() reports that nothing is pending, because the listener
// is registered edge-triggered and will not be reported again for connections
// that are already queued.
//
// Each accepted descriptor is stored in the epoll event data as fd << 1; the
// clear low bit marks it as a connection rather than a listener.
//
// Returns 0 once the pending connections are drained, or -1 if accept4()
// fails with an error other than EAGAIN/EWOULDBLOCK, EINTR or ECONNABORTED.
int xaccept(int epfd, int lfd) {
  for (;;) {
    struct sockaddr_storage addr;
    socklen_t addrlen = sizeof(addr);
    int fd = accept4(lfd, (struct sockaddr *)&addr, &addrlen,
                     SOCK_NONBLOCK | SOCK_CLOEXEC);
    if (fd == -1) {
      // Both comparisons must be against errno: a bare EWOULDBLOCK is a
      // non-zero constant and would make this branch swallow every error.
      if (errno == EAGAIN || errno == EWOULDBLOCK) {
        break;
      }
      // Interrupted call or a connection aborted by the peer: try again.
      if (errno == EINTR || errno == ECONNABORTED) {
        continue;
      }
      perror("accept4");
      return -1;
    }

    struct epoll_event ev;
    memset(&ev, 0, sizeof(ev));
    ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
    ev.data.fd = fd << 1; // low bit clear: this entry is a connection
    if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
      perror("epoll_ctl");
      close(fd);
      continue;
    }

    // Print the peer's numeric address.
    char host[NI_MAXHOST];
    char serve[NI_MAXSERV];
    memset(host, 0, sizeof(host));
    memset(serve, 0, sizeof(serve));
    int rc = getnameinfo((struct sockaddr *)&addr, addrlen, host, sizeof(host),
                         serve, sizeof(serve), NI_NUMERICHOST | NI_NUMERICSERV);
    if (rc != 0) {
      fprintf(stderr, "getnameinfo: %s\n", gai_strerror(rc));
      close(fd); // closing the descriptor also drops it from the epoll set
      continue;
    }

    fprintf(stdout, "connection from %s:%s\n", host, serve);
  }
  return 0;
}

// Reads everything currently available on the non-blocking socket fd and
// copies it to stdout. The loop runs until recv() reports that it would
// block, which is what an edge-triggered epoll registration requires.
//
// Returns 0 when the socket has been drained. Returns -1 after closing fd if
// recv() fails with an error other than EAGAIN/EWOULDBLOCK or EINTR, or if
// the peer has closed the connection (recv() returned 0).
int xrecv(int fd) {
  char buf[4096];
  for (;;) {
    ssize_t n = recv(fd, buf, sizeof(buf), 0);
    if (n == -1) {
      // Both comparisons must be against errno: a bare EWOULDBLOCK is a
      // non-zero constant and would make this branch swallow every error.
      if (errno == EAGAIN || errno == EWOULDBLOCK) {
        break;
      }
      if (errno == EINTR) {
        continue;
      }
      perror("recv");
      close(fd);
      return -1;
    }
    if (n == 0) {
      // Orderly shutdown by the peer; reported through perror as ENOTCONN.
      errno = ENOTCONN;
      perror("recv");
      close(fd);
      return -1;
    }
    // buf is not NUL-terminated, so write exactly the n bytes received rather
    // than formatting it as a C string.
    fwrite(buf, 1, (size_t)n, stdout);
  }
  return 0;
}

// Entry point: creates an epoll instance, starts listening on port 6666 on
// the wildcard address(es), and then dispatches epoll events forever.
// Listener events go to xaccept() and connection events go to xrecv().
// Command-line arguments are not used.
int main(int argc, char *argv[]) {
  (void)argc;
  (void)argv;

  int efd = epoll_create1(EPOLL_CLOEXEC);
  if (efd == -1) {
    perror("epoll_create1");
    exit(EXIT_FAILURE);
  }

  // A NULL node combined with AI_PASSIVE (set inside xlisten) is the portable
  // way to request the wildcard address; "*" is not a portable host name.
  int rc = xlisten(efd, NULL, "6666");
  if (rc == -1) {
    fprintf(stderr, "failed to listen\n");
    exit(EXIT_FAILURE);
  }

  enum { MAX_EVENTS = 16 };
  struct epoll_event events[MAX_EVENTS];
  for (;;) {
    // maxevents is a count of array elements, not a size in bytes.
    int nfds = epoll_wait(efd, events, MAX_EVENTS, -1);
    if (nfds == -1) {
      if (errno == EINTR) {
        continue;
      }
      perror("epoll_wait");
      exit(EXIT_FAILURE);
    }

    for (int i = 0; i < nfds; ++i) {
      // The event data holds (fd << 1) with the low bit set for listeners.
      int tagged = events[i].data.fd;
      int fd = tagged >> 1;
      if (tagged & 1) {
        if (xaccept(efd, fd) == -1) {
          exit(EXIT_FAILURE);
        }
      } else if (xrecv(fd) == -1) {
        // xrecv has already closed the descriptor.
        fprintf(stdout, "disconnected\n");
      }
    }
  }

  return 0;
}