From 875cd9d0445fd8a882354e33ccd1d916ad4aee06 Mon Sep 17 00:00:00 2001 From: Daniel Ledda Date: Sun, 30 Nov 2025 20:01:06 +0100 Subject: [PATCH] working chat app example with some rough edges --- app.c | 257 +++++++++++++++++++++++++++++++++++------------------ core.c | 2 +- core.h | 11 +-- os.h | 39 +++++++- os_linux.c | 130 +++++++++++++++++++++++++-- 5 files changed, 340 insertions(+), 99 deletions(-) diff --git a/app.c b/app.c index f568001..8cbf0e3 100644 --- a/app.c +++ b/app.c @@ -1,121 +1,208 @@ +#include "core.h" #define DJSTD_BASIC_ENTRY #include "core.c" #include "signal.h" -#include "stdlib.h" -Server *server = NULL; +Server *openServer = NULL; +SocketList *openSockets = NULL; void handleSigint(int dummy) { - if (server) { + if (openServer) { println(""); println("Closing server socket."); - serverClose(server); + serverClose(openServer); + println("Success."); + } + if (openSockets && openSockets->length) { + println(""); + println("Closing open sockets."); + for (EachEl(*openSockets, Socket, socket)) { + socketClose(socket); + } println("Success."); } signal(SIGINT, SIG_DFL); raise(SIGINT); } -int djstd_entry(Arena *arena, StringList args) { - signal(SIGINT, &handleSigint); +typedef struct ChatClient ChatClient; +struct ChatClient { + Socket *socket; + string nickname; +}; +DefineList(ChatClient, ChatClient); + +void startServer(Arena *arena, int32 port) { + println("Starting server..."); + Server server = serverInit((ServerInitInfo){ + .concurrentClients=2, + .port=port, + .memory=Megabytes(64), + .maxEvents=64, + }); + openServer = &server; + + serverListen(&server); + if (server.listening) { + println("Listening on port %d", port); + } - bool isServer = strEql(args.data[0], s("server")); - bool isClient = strEql(args.data[0], s("client")); + Arena *serverLoopArena = arenaAlloc(Megabytes(64)); + ChatClientList chatClients = PushListZero(arena, ChatClientList, 256); - if (!isServer && !isClient || args.length < 2) { - println("Usage: [type] [port] ([remote_address])"); - println("[type] is either 'server' or 'client'"); - println("[remote_address] can be given if a client app, default is loopback"); - return 0; - } + Forever { + ServerEvent *nextEvent; - int port = 8080; - Int32Result portParsed = parsePositiveInt(args.data[1]); - if (portParsed.valid) { - port = portParsed.result; - } + do { + nextEvent = serverGetNextEvent(&server); + switch (nextEvent->type) { + case ServerEventType_AcceptClient: { + Socket *client = serverAccept(&server); + if (client != NULL) { + println("New client connected from %d", client->address); + } + break; + }; + case ServerEventType_ClientMessage: { + StringResult clientMsg = socketReadStr(serverLoopArena, nextEvent->tClientMessage.client); + ChatClient *chatClient = NULL; + if (clientMsg.valid) { + if (strStartsWith(clientMsg.result, s("hello-"))) { + StringList nickSplit = strSplit(serverLoopArena, s("-"), clientMsg.result); + if (nickSplit.length == 2 && nickSplit.data[1].length > 0) { + string newNick = PushString(arena, nickSplit.data[1].length); + newNick.length = nickSplit.data[1].length; + memcpy(newNick.str, nickSplit.data[1].str, nickSplit.data[1].length); + ChatClient newChatClient = (ChatClient){ + .socket=nextEvent->tClientMessage.client, + .nickname=newNick, + }; + AppendList(&chatClients, newChatClient); + println("Client from %d calls themselves \"%S\"", newChatClient.socket->address, newChatClient.nickname); + } + } else { + for (EachEl(chatClients, ChatClient, maybeChatClient)) { + if (maybeChatClient->socket->handle == nextEvent->tClientMessage.client->handle) { + chatClient = maybeChatClient; + } + } + if (chatClient != NULL) { + if (strStartsWith(clientMsg.result, s("say-"))) { + StringList saySplit = strSplit(arena, s("-"), clientMsg.result); + if (saySplit.length == 2 && saySplit.data[1].length > 0) { + string broadcast = strPrintf(serverLoopArena, "%S says:\n%S", chatClient->nickname, saySplit.data[1]); + for (EachEl(server.clients, Socket, client)) { + socketWriteStr(client, broadcast); + } + } + } else { + // Invalid client message + } + } + } + } + break; + }; + case ServerEventType_None: { + break; + }; + default: + break; + } + } while (nextEvent != NULL); - string addr = s("::1"); - if (!isServer && args.length > 2) { - if (args.data[2].length > 0) { - addr = args.data[2]; - } + arenaFreeFrom(serverLoopArena, 0); } - if (isServer) { - println("Starting server on port %d", port); - Server myserver = serverInit((ServerInitInfo){ - .concurrentClients=2, - .port=port, - .memory=Megabytes(64), - }); + println("Shutting down chat."); + serverClose(&server); +} - server = &myserver; +void clearStdInLn() { + print("\r"); + print(ANSI_INSTRUCTION(J)); +} - serverListen(&myserver); - Socket *client1 = serverAccept(&myserver); - Socket *client2 = serverAccept(&myserver); +void clearStdInLnAfterInput() { + print("\r"); + print(ANSI_INSTRUCTION(A)); + print(ANSI_INSTRUCTION(J)); +} +void startClient(Arena *arena, string addr, int32 port, string nickname) { + fcntl(0, F_SETFL, fcntl(0, F_GETFL) | O_NONBLOCK); + + println("Connecting to server at [%S]:%d with nickname \"%S\"", addr, port, nickname); + Socket server = socketConnect(arena, (SocketConnectInfo){ .address=addr, .port=port }); + if (server.closed) { + println("Connection error. Closing."); + } else { + println("Connected successfully"); + string message = strPrintf(arena, "hello-%S", nickname); + CharList inputBuf = PushList(arena, CharList, 512); + socketWriteStr(&server, message); + + print("(you)> "); Forever { - string message = s("Hello. You are client 1.\n"); - socketWrite(client1, message.str, message.length); - message = s("Hello. You are client 2.\n"); - socketWrite(client2, message.str, message.length); - - string buf = PushStringFill(arena, 256, 0); - uint64 bytesRead = socketRead(client1, buf.str, buf.length - 1); - - if (bytesRead > 0) { - buf.length = bytesRead; - string message = strSplit(arena, s("\n"), buf).data[0]; - message = strPrintf(arena, "Client 1 said: %S\n", message); - println("%S", message); - socketWrite(client2, message.str, message.length); - - println("Saying goodbye to everyone"); - message = s("Goodbye\n"); - socketWrite(client1, message.str, message.length); - socketWrite(client2, message.str, message.length); - break; + Scratch scratch = scratchStart(&arena, 1); + + int32 numRead = read(0, inputBuf.data + inputBuf.length, inputBuf.capacity - inputBuf.length); + if (numRead >= 0) { + inputBuf.length += numRead; + if (inputBuf.data[inputBuf.length - 1] == '\n') { + clearStdInLnAfterInput(); + socketWriteStr(&server, strPrintf(scratch.arena, "say-%S", (string){.str=inputBuf.data,.length=inputBuf.length})); + inputBuf.length = 0; + } } - } - socketClose(client1); - socketClose(client2); - serverClose(&myserver); - } else { - println("Connecting to socket at %S on port %d", addr, port); - Socket sock = socketConnect(arena, (SocketConnectInfo){ .address=addr, .port=port }); - if (sock.closed) { - println("Connection error. Closing."); - } else { - string message; - uint64 bytesWritten; - - string buf = PushStringFill(arena, 256, 0); - - socketRead(&sock, buf.str, buf.length); - string messageReceived = strSplit(arena, s("\n"), buf).data[0]; - println("%S", strPrintf(arena, "Server said: %S", messageReceived)); - - if (strEql(messageReceived, s("Hello. You are client 1."))) { - string broadcast = s("HELLO WORLD!!!!\n"); - socketWrite(&sock, broadcast.str, broadcast.length); + StringResult serverMsg = socketReadStr(scratch.arena, &server); + if (serverMsg.valid && serverMsg.result.length > 0) { + clearStdInLn(); + println("%S", serverMsg.result); + print("(you)> %S", (string){.str=inputBuf.data, inputBuf.length}); } - Forever { - socketRead(&sock, buf.str, buf.length); - messageReceived = strSplit(arena, s("\n"), buf).data[0]; - println("Server said: %S", messageReceived); - if (strEql(messageReceived, s("Goodbye"))) { - println("Quitting"); - break; + scratchEnd(scratch); + } + } + socketClose(&server); +} + +int djstd_entry(Arena *arena, StringList args) { + signal(SIGINT, &handleSigint); + + bool argumentErr = true; + bool isServer = strEql(args.data[0], s("server")); + bool isClient = strEql(args.data[0], s("client")); + + if (isServer) { + Int32Result portParsed = parsePositiveInt(args.data[1]); + if (portParsed.valid) { + startServer(arena, portParsed.result); + argumentErr = false; + } + } else if (isClient) { + if (args.length == 3) { + StringList split = strSplit(arena, s("]:"), args.data[1]); + if (split.length == 2) { + Int32Result portParsed = parsePositiveInt(split.data[1]); + string addr = strSlice(split.data[0], 1, split.data[0].length); + string nickname = args.data[2]; + if (portParsed.valid && addr.length > 0 && nickname.length > 0) { + startClient(arena, addr, portParsed.result, nickname); + argumentErr = false; } } } + } - socketClose(&sock); + if (argumentErr) { + println("Usage:"); + println("server [PORT]"); + println("OR"); + println("client [REMOTE_ADDRESS:PORT] [NICKNAME]"); } return 0; diff --git a/core.c b/core.c index b68ee7c..d2e72e5 100644 --- a/core.c +++ b/core.c @@ -211,7 +211,7 @@ StringList strSplit(Arena *arena, string splitStr, string inputStr) { splitString->str = inputStr.str + start; splitString->length = c - start; splitCount++; - start = c + 1; + start = c + splitStr.length; } c++; } diff --git a/core.h b/core.h index b76e5bf..66df1bc 100644 --- a/core.h +++ b/core.h @@ -181,8 +181,8 @@ DefineList(string, String); #define ListRemove(list, index)\ if ((index) >= 0 && (index) < (list)->length) {\ - memcpy((list)->data + (index), (list)->data + (index) + 1, (parentNode->children.length - (i + 1))*sizeof(*((list)->data)));\ - parentNode->children.length -= 1;\ + memcpy((list)->data + (index), (list)->data + (index) + 1, ((list)->length - (i + 1))*sizeof(*((list)->data)));\ + (list)->length -= 1;\ } // ### Strings ### @@ -255,9 +255,10 @@ typedef enum { StdStream_stderr, } StdStream; -#define ANSI_INSTRUCTION_FROM_ENUM(ansiCodeEnum) ANSI_INSTRUCTION(ansiCodeEnum) -#define ANSI_INSTRUCTION(ansiCode) "\x1b[" #ansiCode "m" -#define ANSI_INSTRUCTION_STR(ansiCodeStr) "\x1b[" ansiCodeStr "m" +#define ANSI_INSTRUCTION(ansiCode) "\x1b[" #ansiCode +#define ANSI_INSTRUCTION_STR(ansiCodeStr) "\x1b[" ansiCodeStr +#define ANSI_GRAPHIC_INSTRUCTION(ansiCode) "\x1b[" #ansiCode "m" +#define ANSI_GRAPHIC_INSTRUCTION_STR(ansiCodeStr) "\x1b[" ansiCodeStr "m" #define ANSI_RESET ANSI_INSTRUCTION(0) #define ANSI_fg_black 30 diff --git a/os.h b/os.h index f4d5b20..e6063a5 100644 --- a/os.h +++ b/os.h @@ -57,9 +57,10 @@ struct Server { typedef struct ServerInitInfo ServerInitInfo; struct ServerInitInfo { - uint16 port; - uint32 concurrentClients; - uint64 memory; + int16 port; + int32 concurrentClients; + int64 memory; + int32 maxEvents; }; typedef struct SocketConnectInfo SocketConnectInfo; @@ -71,14 +72,46 @@ struct SocketConnectInfo { // Server/Client interface Server serverInit(ServerInitInfo info); + void serverListen(Server *s); + Socket *serverAccept(Server *s); + void serverClose(Server *s); +enum ServerEventType { + ServerEventType_AcceptClient, + ServerEventType_ClientMessage, + ServerEventType_None, + ServerEventType_COUNT, +}; + +typedef struct ServerEvent ServerEvent; +struct ServerEvent { + enum ServerEventType type; + union { + struct {} tAcceptClient; + struct { + int32 clientId; + Socket *client; + } tClientMessage; + }; +}; + +ServerEvent *serverGetNextEvent(Server *s); + // Generic socket interface Socket socketConnect(Arena *arena, SocketConnectInfo info); + int64 socketRead(Socket *s, byte *dest, uint64 numBytes); + +DefineResult(string, String); +StringResult socketReadStr(Arena *arena, Socket *s); + int64 socketWrite(Socket *s, byte *source, uint64 numBytes); + +int64 socketWriteStr(Socket *socket, string data); + void socketClose(Socket *s); #endif diff --git a/os_linux.c b/os_linux.c index 2fbbbac..55bae27 100644 --- a/os_linux.c +++ b/os_linux.c @@ -143,9 +143,15 @@ OS_Thread os_createThread(void *(*entry)(void *), void *ctx) { return (OS_Thread){ .id=handle }; } +DefineList(ServerEvent, ServerEvent); typedef struct EPollServerEvents EPollServerEvents; struct EPollServerEvents { int epollFd; + int32 maxEvents; + int32 numEvents; + struct epoll_event *events; + bool err; + ServerEventList userEvents; }; Server serverInit(ServerInitInfo info) { @@ -153,6 +159,10 @@ Server serverInit(ServerInitInfo info) { EPollServerEvents *events = PushStructZero(arena, EPollServerEvents); events->epollFd = epoll_create1(0); + events->events = PushArrayZero(arena, struct epoll_event, info.maxEvents); + events->maxEvents = info.maxEvents; + events->numEvents = 0; + events->userEvents = PushListZero(arena, ServerEventList, info.maxEvents); struct sockaddr_in6 *serverAddr = PushStructZero(arena, struct sockaddr_in6); serverAddr->sin6_family = AF_INET6; @@ -171,6 +181,12 @@ Server serverInit(ServerInitInfo info) { fcntl((uint64)server.handle, F_SETFL, fcntl((uint64)server.handle, F_GETFL, 0) | O_NONBLOCK); + struct epoll_event event = { + .data.fd=(uint64)server.handle, + .events=EPOLLIN | EPOLLET, + }; + epoll_ctl(events->epollFd, EPOLL_CTL_ADD, (int64)server.handle, &event); + int bindErr = bind((uint64)server.handle, (struct sockaddr *)serverAddr, sizeof(*serverAddr)); if (bindErr == -1) { // TODO(dledda): handle err @@ -182,7 +198,9 @@ Server serverInit(ServerInitInfo info) { void serverListen(Server *s) { int listenErr = listen((uint64)s->handle, s->clients.capacity); if (listenErr == -1) { - // TODO(dledda): handle err ? + s->listening = false; + } else { + s->listening = true; } } @@ -193,10 +211,18 @@ Socket *serverAccept(Server *s) { uint64 clientSockHandle = accept((int)(uint64)s->handle, (struct sockaddr *)clientAddr, &clientAddrLen); if (clientSockHandle == -1) { clientSockHandle = (uint64)NULL; + println("ERR server accept"); + perror("accept"); } else { fcntl((uint64)clientSockHandle, F_SETFL, fcntl((uint64)clientSockHandle, F_GETFL, 0) | O_NONBLOCK); } + struct epoll_event event = { + .data.fd=clientSockHandle, + .events=EPOLLIN | EPOLLET, + }; + epoll_ctl(((EPollServerEvents *)s->events)->epollFd, EPOLL_CTL_ADD, clientSockHandle, &event); + if (s->clients.length < s->clients.capacity) { AppendList(&s->clients, ((Socket){ .handle=(SocketHandle *)(uint64)clientSockHandle, @@ -208,6 +234,58 @@ Socket *serverAccept(Server *s) { } } +ServerEvent *serverGetNextEvent(Server *s) { + EPollServerEvents *serverEvents = ((EPollServerEvents *)s->events); + + if (serverEvents->userEvents.length == 0) { + serverEvents->numEvents = epoll_wait(serverEvents->epollFd, serverEvents->events, serverEvents->maxEvents, -1); + if (serverEvents->numEvents == -1) { + serverEvents->err = true; + serverEvents->numEvents = 0; + } else { + serverEvents->userEvents.length = serverEvents->numEvents; + for (int32 i = 0; i < serverEvents->numEvents; i++) { + struct epoll_event *ev = &serverEvents->events[i]; + if ((ev->events & EPOLLIN) && ev->data.fd == (int)(int64)s->handle) { + serverEvents->userEvents.data[i] = (ServerEvent){ + .type=ServerEventType_AcceptClient, + .tAcceptClient={}, + }; + } else if (ev->events & EPOLLIN) { + int64 fd = serverEvents->events[i].data.fd; + Socket *client = NULL; + ServerEvent serverEv = { + .type=ServerEventType_ClientMessage, + }; + + for (EachIn(s->clients, j)) { + if ((int64)s->clients.data[j].handle == fd) { + client = &s->clients.data[j]; + serverEv.tClientMessage.client = client; + serverEv.tClientMessage.clientId = j; + } + } + + if (client == NULL) { + serverEv.type = ServerEventType_None; + } else { + serverEvents->userEvents.data[i] = serverEv; + } + } + } + } + + } + + if (serverEvents->userEvents.length == 0) { + AppendList(&serverEvents->userEvents, (ServerEvent){ .type=ServerEventType_None }); + } + + // Pop next event + serverEvents->userEvents.length--; + return &serverEvents->userEvents.data[serverEvents->userEvents.length]; +} + int64 socketRead(Socket *socket, byte *dest, uint64 numBytes) { int64 bytesRead = read((uint64)socket->handle, dest, numBytes); if (bytesRead == -1) { @@ -216,30 +294,66 @@ int64 socketRead(Socket *socket, byte *dest, uint64 numBytes) { return bytesRead; } +StringResult socketReadStr(Arena *arena, Socket *socket) { + byte *dest = PushArray(arena, byte, Kilobytes(256)); + int64 bytesRead = read((uint64)socket->handle, dest, Kilobytes(1024)); + bool err = bytesRead == -1 || bytesRead == 0; + if (err) { + arenaPopTo(arena, dest); + } else { + arenaPopTo(arena, dest + bytesRead); + } + return (StringResult){ + .valid=!err, + .result=(string){ + .str=dest, + .length=err ? 0 : bytesRead, + }, + }; +} + void serverClose(Server *s) { close((int)(uint64)s->handle); } +bool serverHangupClient(Server *s, Socket *client) { + struct epoll_event eventUnsubscribe = { + .data.fd=(int)(int64)client->handle, + .events=EPOLLIN | EPOLLET, + }; + int err = epoll_ctl(((EPollServerEvents *)s->events)->epollFd, EPOLL_CTL_DEL, (int)(int64)client->handle, &eventUnsubscribe); + if (err == 0) { + for (EachIn(s->clients, i)) { + if (s->clients.data[i].handle == client->handle) { + ListRemove(&s->clients, i); + return true; + } + } + } + return false; +} + void socketClose(Socket *s) { close((int)(uint64)s->handle); } Socket socketConnect(Arena *arena, SocketConnectInfo info) { int socketFd = socket(AF_INET6, SOCK_STREAM, 0 /* IPPROTO_TCP */); + fcntl(socketFd, F_SETFL, fcntl(socketFd, F_GETFL, 0) | O_NONBLOCK); + struct sockaddr_in6 *remoteAddr = PushStructZero(arena, struct sockaddr_in6); remoteAddr->sin6_family = AF_INET6; inet_pton(AF_INET6, cstring(arena, info.address), &remoteAddr->sin6_addr); remoteAddr->sin6_port = htons(info.port); int connectErr = connect(socketFd, (struct sockaddr *)remoteAddr, sizeof(*remoteAddr)); + Socket result = { .handle=(SocketHandle *)(uint64)socketFd, .address=(Address *)remoteAddr, .closed=false, + //.closed=connectErr == -1, }; - if (connectErr == -1) { - // TODO(dledda): handle err - result.closed = true; - } + return result; } @@ -249,4 +363,10 @@ int64 socketWrite(Socket *socket, byte *source, uint64 numBytes) { return written; } +int64 socketWriteStr(Socket *socket, string data) { + int64 written = send((uint64)socket->handle, data.str, data.length, MSG_NOSIGNAL); + if (written == -1) socket->closed = true; + return written; +} + #endif