From db2611fcad18f73572dd1b344e4197536086be53 Mon Sep 17 00:00:00 2001 From: ame Date: Sun, 15 Feb 2026 04:08:16 -0600 Subject: ssl server support, websocket upgrades, and net changes --- src/net.c | 223 +++++++++++++++++++++++++++++++------------------------------- 1 file changed, 112 insertions(+), 111 deletions(-) (limited to 'src/net.c') diff --git a/src/net.c b/src/net.c index d2f4e17..c8abff7 100644 --- a/src/net.c +++ b/src/net.c @@ -3,6 +3,7 @@ #include "net/lua.h" #include "net/luai.h" #include "types/str.h" +#include "net/websocket.h" #include @@ -28,6 +29,7 @@ void ssl_init(){ if(has_ssl_init == 0){ has_ssl_init = 1; SSL_library_init(); + OpenSSL_add_all_algorithms(); SSL_load_error_strings(); } } @@ -171,96 +173,29 @@ void free_url(struct url u){ if(u.port != NULL) str_free(u.port); } -#define BUFFER_LEN 4096 -struct wss_data { - SSL* ssl; - SSL_CTX* ctx; - int sock; - str* buffer; -}; - -struct ws_frame_info { - int fin; - int rsv1; - int rsv2; - int rsv3; - int opcode; - int mask; - int payload; -}; - int i_ws_read(lua_State* L){ lua_pushstring(L, "_"); lua_gettable(L, 1); - struct wss_data* data = lua_touserdata(L, -1); - char buffer[BUFFER_LEN] = {0}; - int total_len = 0; - int len; - - for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ - total_len += len; - if(total_len >= 2) break; - } - - if(len < 0) luaI_error(L, len, "SSL_read error"); - - uint64_t payload = 0; - struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1, - .rsv2 = (buffer[0] >> 5) & 1, .rsv3 = (buffer[0] >> 4) & 1, .opcode = buffer[0] & 0b1111, - .mask = (buffer[1] >> 7) & 1, .payload = buffer[1] & 0b1111111}; - //printf("fin: %i\npayload: %i\n", frame_info.fin, frame_info.payload); - memset(buffer, 0, total_len); - total_len = 0; - - if(frame_info.payload <= 125) payload = frame_info.payload; - else if(frame_info.payload == 126) { - for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ - total_len += len; - if(total_len >= 2) break; - } - - if(len < 0) luaI_error(L, len, "SSL_read error"); - - payload = (buffer[0] & 0xff) << 8 | (buffer[1] & 0xff); - } else { - for(; (len = SSL_read(data->ssl, buffer + total_len, 8 - total_len)) > 0;){ - total_len += len; - if(total_len >= 8) break; - } - - if(len < 0) luaI_error(L, -1, "SSL_read error"); - - - payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 | - ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff); - } - //printf("final payload: %lu\n", payload); - - str* message = str_init(""); - memset(buffer, 0, BUFFER_LEN); - for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){ - str_pushl(message, buffer, len); - memset(buffer, 0, len); - } - - if(len < 0) { - str_free(message); - luaI_error(L, len, "SSL_read error"); + struct net_data* data = lua_touserdata(L, -1); + struct ws_frame_info frame = {}; + if(ws_read(data, &frame) == -1){ + luaI_error(L, -1, "SSL_read error"); + str_clear(data->buffer); } lua_newtable(L); int idx = lua_gettop(L); - luaI_tsetsl(L, idx, "content", message->c, message->len); - luaI_tseti(L, idx, "opcode", frame_info.opcode); + luaI_tsetsl(L, idx, "content", data->buffer->c, data->buffer->len); + luaI_tseti(L, idx, "opcode", frame.opcode); - str_free(message); + str_clear(data->buffer); return 1; } int i_ws_write(lua_State* L){ lua_pushstring(L, "_"); lua_gettable(L, 1); - struct wss_data* data = lua_touserdata(L, -1); + struct net_data* data = lua_touserdata(L, -1); uint64_t clen; const char* content = luaL_tolstring(L, 2, &clen); @@ -289,7 +224,7 @@ int i_ws_write(lua_State* L){ int i_ws_close(lua_State* L){ lua_pushstring(L, "_"); lua_gettable(L, 1); - struct wss_data* data = lua_touserdata(L, -1); + struct net_data* data = lua_touserdata(L, -1); if(data != NULL && data->buffer != NULL){ str_free(data->buffer); @@ -310,7 +245,7 @@ int i_ws_close(lua_State* L){ return 0; } - +#define BUFFER_LEN 4096 int l_wss(lua_State* L){ uint64_t len = 0; @@ -331,8 +266,8 @@ int l_wss(lua_State* L){ SSL* ssl = ssl_connect(ctx, sock, awa.domain->c); char* request = calloc(512, sizeof * request); - sprintf(request, "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"\ - "Sec-Websocket-Key: aWxvdmVsb3ZlbG92ZXlvdQ==\r\nSec-Websocket-Version: 13\r\n\r\n", path, awa.domain->c); + sprintf(request, "GET %s HTTP/1.1\r\nhost: %s\r\nconnection: upgrade\r\nupgrade: websocket\r\n"\ + "sec-websocket-key: aWxvdmVsb3ZlbG92ZXlvdQ==\r\nsec-websocket-version: 13\r\n\r\n", path, awa.domain->c); int s = SSL_write(ssl, request, strlen(request)); free_url(awa); @@ -360,7 +295,7 @@ int l_wss(lua_State* L){ luaI_error(L, -1, "error with header formating"); } - struct wss_data *data = malloc(sizeof * data); + struct net_data *data = malloc(sizeof * data); data->ssl = ssl; data->ctx = ctx; data->sock = sock; @@ -537,7 +472,7 @@ int _request(lua_State* L, struct request_state* state){ lua_newtable(L); int header_idx = lua_gettop(L); - luaI_tsets(L, header_idx, "User-Agent", "lullaby/"MAJOR_VERSION); + luaI_tsets(L, header_idx, "user-agent", "lullaby/"MAJOR_VERSION); if(params >= 3){ lua_pushvalue(L, header_idx); @@ -564,7 +499,7 @@ int _request(lua_State* L, struct request_state* state){ //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n"; char* request = calloc(cont_len + header->len + strlen(host) + strlen(path) + 512, sizeof * request); - sprintf(request, "%s %s HTTP/1.1\r\nHost: %s\r\nConnection: Close%s\r\n\r\n%s", action, path, host, header->c, cont); + sprintf(request, "%s %s HTTP/1.1\r\nhost: %s\r\nconnection: close%s\r\n\r\n%s", action, path, host, header->c, cont); if(awa.path != NULL) str_free(awa.path); if(awa.domain != NULL) str_free(awa.domain); @@ -610,16 +545,16 @@ int _request(lua_State* L, struct request_state* state){ luaI_tsets(L, idx, (owo->P[i].key)->c, ((str*)owo->P[i].value)->c); } //done out of pure laziness, parse_header was meant for requests but works fine for responses, change this later - luaI_treplk(L, idx, "Path", "code"); - luaI_treplk(L, idx, "Request", "version"); - luaI_treplk(L, idx, "Version", "code-name"); + luaI_treplk(L, idx, "path", "code"); + luaI_treplk(L, idx, "request", "version"); + luaI_treplk(L, idx, "version", "code-name"); lua_pushstring(L, "code"); lua_gettable(L, idx); int code = atoi(lua_tostring(L, -1)); luaI_tseti(L, idx, "code", code); - void* encoding = parray_get(owo, "Transfer-Encoding"); + void* encoding = parray_get(owo, "transfer-encoding"); //struct _srequest_state *read_state = calloc(sizeof * read_state, 1); //read_state->ctx = state->ctx; @@ -710,8 +645,9 @@ void path_parse(struct net_path_t* path, str* raw){ _Atomic size_t threads = 0; void* handle_client(void *_arg){ + signal(SIGPIPE, SIG_IGN); thread_arg_struct* args = (thread_arg_struct*)_arg; - int client_fd = args->fd; + struct net_data* ctx = args->ctx; char* buffer; int header_eof = -1; lua_State* L = args->L; @@ -719,7 +655,7 @@ void* handle_client(void *_arg){ char* header = NULL; - int64_t bite = recv_header(client_fd, &buffer, &header); + int64_t bite = recv_header(ctx, &buffer, &header); header_eof = header - buffer; if(bite > 0){ @@ -728,13 +664,13 @@ void* handle_client(void *_arg){ //checks for a valid header int val = parse_header(buffer, header_eof + 2, &table); - if(val == -2) net_error(client_fd, 414); + if(val == -2) net_error(ctx, 414); if(val >= 0){ - str* path = (str*)parray_get(table, "Path"); - str* sR = (str*)parray_get(table, "Request"); - str* sT = (str*)parray_get(table, "Content-Type"); - str* sC = (str*)parray_get(table, "Cookie"); + str* path = (str*)parray_get(table, "path"); + str* sR = (str*)parray_get(table, "request"); + str* sT = (str*)parray_get(table, "content-type"); + str* sC = (str*)parray_get(table, "cookie"); struct file_parse* file_cont = calloc(1, sizeof * file_cont); lua_newtable(L); @@ -755,7 +691,7 @@ void* handle_client(void *_arg){ parray_t* v = NULL; if(decoded_err == 1 || args->paths == NULL){ - net_error(client_fd, 400); + net_error(ctx, 400); } else { str_push(aa, decoded_path->c); @@ -787,7 +723,7 @@ void* handle_client(void *_arg){ luaI_tsetv(L, req_idx, "cookies", lcookie); parray_clear(cookie, STR); - parray_remove(table, "Cookie", STR); + parray_remove(table, "cookie", STR); } if(parsed_path.query->len != 0){ @@ -805,7 +741,7 @@ void* handle_client(void *_arg){ int ld = lua_gettop(L); luaI_tsetv(L, req_idx, "_data", ld); luaI_tsetv(L, req_idx, "files", files_idx); - luaI_tsetv(L, req_idx, "Body", body_idx); + luaI_tsetv(L, req_idx, "body", body_idx); for(int i = 0; i != table->len; i+=1){ //printf("'%s' :: '%s'\n",table[i]->c, table[i+1]->c); @@ -817,29 +753,34 @@ void* handle_client(void *_arg){ luaI_tsets(L, req_idx, "rawpath", path->c); if(bite == -1){ - client_fd = -2; + args->ctx->sock = -2; } luaI_tseti(L, req_idx, "_bytes", bite - header_eof - 4); - luaI_tseti(L, req_idx, "client_fd", client_fd); + //luaI_tseti(L, req_idx, "client_fd", client_fd); luaI_tsetcf(L, req_idx, "roll", l_roll); + luaI_tsetv(L, res_idx, "req", req_idx); + luaI_tsetv(L, req_idx, "res", res_idx); + //functions luaI_tsetcf(L, res_idx, "send", l_send); luaI_tsetcf(L, res_idx, "sendfile", l_sendfile); luaI_tsetcf(L, res_idx, "write", l_write); luaI_tsetcf(L, res_idx, "close", l_close); luaI_tsetcf(L, res_idx, "stop", l_stop); + luaI_tsetcf(L, res_idx, "upgrade", l_connection_upgrade); //values - luaI_tseti(L, res_idx, "client_fd", client_fd); + //luaI_tseti(L, res_idx, "client_fd", client_fd); + luaI_tsetlud(L, res_idx, "_", ctx); luaI_tsets(L, res_idx, "_request", sR->c); //header table lua_newtable(L); int header_idx = lua_gettop(L); - luaI_tseti(L, header_idx, "Code", 200); + luaI_tseti(L, header_idx, "code", 200); luaI_tsetv(L, res_idx, "header", header_idx); @@ -885,7 +826,7 @@ void* handle_client(void *_arg){ if(lua_pcall(L, 2, 0, errtraceback_idx) != 0){ fprintf(stderr, "(net thread) %s\n", lua_tostring(L, -1)); //send an error message if send has not been called - if(client_fd >= 0) net_error(client_fd, 500); + if(args->ctx->sock >= 0) net_error(ctx, 500); goto net_end; } @@ -904,9 +845,9 @@ net_end: larray_clear(params); parray_lclear(owo); //dont free the rest - lua_pushstring(L, "client_fd"); - lua_gettable(L, res_idx); - client_fd = luaL_checkinteger(L, -1); + //lua_pushstring(L, "client_fd"); + //lua_gettable(L, res_idx); + //client_fd = luaL_checkinteger(L, -1); } @@ -922,9 +863,12 @@ net_end: parray_clear(table, STR); } - if(client_fd != -1){ - shutdown(client_fd, 2); - closesocket(client_fd); + if(args->ctx->sock != -1){ + if(args->ctx->ssl != NULL) SSL_shutdown(args->ctx->ssl); + else { + shutdown(args->ctx->sock, 2); + closesocket(args->ctx->sock); + } } free(args); @@ -941,12 +885,23 @@ int clean_lullaby_net(lua_State* L){ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* state){ parse_mimetypes(); + if(state->ssl) ssl_init(); //need these on windows for sockets (stupid) #ifdef _WIN32 WSADATA Data; WSAStartup(MAKEWORD(2, 2), &Data); #endif + SSL_CTX* server_ctx; + if(state->ssl){ + if(!(server_ctx = SSL_CTX_new(TLS_server_method()))) + luaI_error(L, -7, "SSL_CTX_new error"); + luaI_assert(L, SSL_CTX_use_certificate_file(server_ctx, state->ssl_crt->c, SSL_FILETYPE_PEM) > 0) + luaI_assert(L, SSL_CTX_use_PrivateKey_file(server_ctx, state->ssl_key->c, SSL_FILETYPE_PEM) > 0) + if (SSL_CTX_check_private_key(server_ctx) == -1) + luaI_error(L, -8, "key does not match crt"); + } + int server_fd; struct sockaddr_in server_addr; struct pollfd fds[2]; @@ -1008,7 +963,8 @@ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* if(threads >= max_con){ //deny request - net_error(*client_fd, 503); +#warning "need a better way to do this with ssl support" + //net_error(*client_fd, 503); close(*client_fd); free(client_fd); @@ -1016,8 +972,31 @@ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* } thread_arg_struct* args = malloc(sizeof * args); - - args->fd = *client_fd; + args->ctx = malloc(sizeof * args->ctx); + + args->ctx->ssl = NULL; + if(state->ssl){ + args->ctx->ssl = SSL_new(server_ctx); + if(!args->ctx->ssl){ + fprintf(stderr, "SSL_new fail\n"); + close(*client_fd); + free(client_fd); + free(args); + continue; + } + SSL_set_fd(args->ctx->ssl, *client_fd); + + int f; + if((f = SSL_accept(args->ctx->ssl)) <= 0){ + fprintf(stderr, "SSL_accept fail %i\n", f); + close(*client_fd); + free(client_fd); + SSL_free(args->ctx->ssl); + free(args); + continue; + } + } + args->ctx->sock = *client_fd; args->port = port; args->cli = client_addr; args->L = luaL_newstate(); @@ -1050,6 +1029,7 @@ net_end: close(server_fd); close(efd); free(state); + SSL_CTX_free(server_ctx); for(int i = 0; i != paths->len; i++){ struct sarray_t* path = paths->P[i].value; @@ -1157,6 +1137,9 @@ int l_listen(lua_State* L){ struct net_server_state *state = malloc(sizeof * state); state->event_fd = -1; + state->ssl = 0; + state->ssl_crt = str_init(""); + state->ssl_key = str_init(""); int port = luaL_checkinteger(L, 2); @@ -1173,6 +1156,8 @@ int l_listen(lua_State* L){ luaI_tsetcf(L, mt, "PATCH", l_PATCHq); luaI_tsetcf(L, mt, "all", l_allq); + luaI_tsettab(L, mt, "ssl"); + luaI_tsetcf(L, mt, "close", l_net_close); luaI_tsetv(L, mt, "port", 2); @@ -1187,6 +1172,22 @@ int l_listen(lua_State* L){ if(state->event_fd == -2) luaI_error(L, -2, "closed"); + lua_getfield(L, mt, "ssl"); + int ssl = lua_gettop(L); + if(!lua_isnil(L, -1) && lua_type(L, ssl) == LUA_TTABLE){ + lua_getfield(L, ssl, "key"); + if(!lua_isnil(L, -1)){ + str_push(state->ssl_key, luaL_checkstring(L, -1)); + lua_getfield(L, ssl, "crt"); + if(!lua_isnil(L, -1)){ + str_push(state->ssl_crt, luaL_checkstring(L, -1)); + lua_pop(L, 2); + state->ssl = 1; + } + } + } + lua_pop(L, 1); + return start_serv(L, port, paths, state); ; } -- cgit v1.2.3