From db2611fcad18f73572dd1b344e4197536086be53 Mon Sep 17 00:00:00 2001 From: ame Date: Sun, 15 Feb 2026 04:08:16 -0600 Subject: ssl server support, websocket upgrades, and net changes --- src/encode/base64.c | 10 +-- src/encode/base64.h | 2 + src/hash/sha01.c | 2 +- src/hash/sha01.h | 2 + src/lua.h | 4 + src/net.c | 223 +++++++++++++++++++++++++-------------------------- src/net/common.h | 14 +++- src/net/lua.c | 109 +++++++++++++++---------- src/net/lua.h | 1 + src/net/luai.c | 4 +- src/net/ssl.h | 2 + src/net/util.c | 98 +++++++---------------- src/net/util.h | 7 +- src/net/websocket.c | 225 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/net/websocket.h | 16 ++++ src/types/str.c | 4 + src/types/str.h | 1 + 17 files changed, 487 insertions(+), 237 deletions(-) create mode 100644 src/net/ssl.h create mode 100644 src/net/websocket.c create mode 100644 src/net/websocket.h (limited to 'src') diff --git a/src/encode/base64.c b/src/encode/base64.c index d6b3474..9f95cb6 100644 --- a/src/encode/base64.c +++ b/src/encode/base64.c @@ -48,10 +48,8 @@ int de_base64(char* in, char* out){ } return 0; } -int en_base64(char* in, char* out){ - int len = 0; - for(int i = 0; in[i]!='\0'; i++) len++; +int en_base64(char* in, uint64_t len, char* out){ //char out[(len+1)*3]; for(int i = 0; i < len; i+=3){ uint8_t f = i>len?0:in[i]; @@ -63,8 +61,8 @@ int en_base64(char* in, char* out){ uint8_t i3 = (uint8_t)(s<<4)>>2 | (t>>6); uint8_t i4 = t & 0x3f; - if(t==0)i4 = 64; - if(s==0)i3 = 64; + if(i+1>=len)i3 = 64; + if(i+2>=len)i4 = 64; sprintf(out,"%s%c%c%c%c",out,char_index(i1),char_index(i2), char_index(i3),char_index(i4)); } @@ -78,7 +76,7 @@ int l_base64encode(lua_State* L){ memcpy(a, _a, len); char* encode = calloc(len * 3,sizeof * encode); - en_base64(a, encode); + en_base64(a, len, encode); lua_pushstring(L, encode); free(a); diff --git a/src/encode/base64.h b/src/encode/base64.h index 216d5e2..a58f503 100644 --- a/src/encode/base64.h +++ b/src/encode/base64.h @@ -2,3 +2,5 @@ int l_base64encode(lua_State*); int l_base64decode(lua_State*); + +int en_base64(char* in, uint64_t, char* out); diff --git a/src/hash/sha01.c b/src/hash/sha01.c index aa403a2..809860a 100644 --- a/src/hash/sha01.c +++ b/src/hash/sha01.c @@ -136,7 +136,7 @@ void sha01_final(struct sha01_hash* hash, char* out_stream){ hash->buffer[63 - i] = (uint8_t) (lhhh >> (i * 8) & 0xFF); sha01_round(hash); - sprintf(out_stream,"%02x%02x%02x%02x%02x",hash->h0,hash->h1,hash->h2,hash->h3,hash->h4); + sprintf(out_stream,"%08x%08x%08x%08x%08x",hash->h0,hash->h1,hash->h2,hash->h3,hash->h4); memcpy(hash, &old_hash, sizeof * hash); memcpy(hash->buffer, old, bs); diff --git a/src/hash/sha01.h b/src/hash/sha01.h index bb343c2..76d81db 100644 --- a/src/hash/sha01.h +++ b/src/hash/sha01.h @@ -20,3 +20,5 @@ int l_sha0(lua_State*); int l_sha0_init(lua_State*); int l_sha0_update(lua_State*); int l_sha0_final(lua_State*); + +void sha1(uint8_t* a, size_t len, char* out_stream); diff --git a/src/lua.h b/src/lua.h index b4ea5f2..18ad8b3 100644 --- a/src/lua.h +++ b/src/lua.h @@ -74,6 +74,10 @@ int luaI_errtraceback(lua_State* L); lua_pushstring(L, K);\ lua_pushnil(L);\ lua_settable(L, Tidx); +#define luaI_tsettab(L, Tidx, K)\ + lua_pushstring(L, K);\ + lua_newtable(L);\ + lua_settable(L, Tidx); #define luaI_treplk(L, Tidx, K, nK){\ lua_pushstring(L, K);\ diff --git a/src/net.c b/src/net.c index d2f4e17..c8abff7 100644 --- a/src/net.c +++ b/src/net.c @@ -3,6 +3,7 @@ #include "net/lua.h" #include "net/luai.h" #include "types/str.h" +#include "net/websocket.h" #include @@ -28,6 +29,7 @@ void ssl_init(){ if(has_ssl_init == 0){ has_ssl_init = 1; SSL_library_init(); + OpenSSL_add_all_algorithms(); SSL_load_error_strings(); } } @@ -171,96 +173,29 @@ void free_url(struct url u){ if(u.port != NULL) str_free(u.port); } -#define BUFFER_LEN 4096 -struct wss_data { - SSL* ssl; - SSL_CTX* ctx; - int sock; - str* buffer; -}; - -struct ws_frame_info { - int fin; - int rsv1; - int rsv2; - int rsv3; - int opcode; - int mask; - int payload; -}; - int i_ws_read(lua_State* L){ lua_pushstring(L, "_"); lua_gettable(L, 1); - struct wss_data* data = lua_touserdata(L, -1); - char buffer[BUFFER_LEN] = {0}; - int total_len = 0; - int len; - - for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ - total_len += len; - if(total_len >= 2) break; - } - - if(len < 0) luaI_error(L, len, "SSL_read error"); - - uint64_t payload = 0; - struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1, - .rsv2 = (buffer[0] >> 5) & 1, .rsv3 = (buffer[0] >> 4) & 1, .opcode = buffer[0] & 0b1111, - .mask = (buffer[1] >> 7) & 1, .payload = buffer[1] & 0b1111111}; - //printf("fin: %i\npayload: %i\n", frame_info.fin, frame_info.payload); - memset(buffer, 0, total_len); - total_len = 0; - - if(frame_info.payload <= 125) payload = frame_info.payload; - else if(frame_info.payload == 126) { - for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ - total_len += len; - if(total_len >= 2) break; - } - - if(len < 0) luaI_error(L, len, "SSL_read error"); - - payload = (buffer[0] & 0xff) << 8 | (buffer[1] & 0xff); - } else { - for(; (len = SSL_read(data->ssl, buffer + total_len, 8 - total_len)) > 0;){ - total_len += len; - if(total_len >= 8) break; - } - - if(len < 0) luaI_error(L, -1, "SSL_read error"); - - - payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 | - ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff); - } - //printf("final payload: %lu\n", payload); - - str* message = str_init(""); - memset(buffer, 0, BUFFER_LEN); - for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){ - str_pushl(message, buffer, len); - memset(buffer, 0, len); - } - - if(len < 0) { - str_free(message); - luaI_error(L, len, "SSL_read error"); + struct net_data* data = lua_touserdata(L, -1); + struct ws_frame_info frame = {}; + if(ws_read(data, &frame) == -1){ + luaI_error(L, -1, "SSL_read error"); + str_clear(data->buffer); } lua_newtable(L); int idx = lua_gettop(L); - luaI_tsetsl(L, idx, "content", message->c, message->len); - luaI_tseti(L, idx, "opcode", frame_info.opcode); + luaI_tsetsl(L, idx, "content", data->buffer->c, data->buffer->len); + luaI_tseti(L, idx, "opcode", frame.opcode); - str_free(message); + str_clear(data->buffer); return 1; } int i_ws_write(lua_State* L){ lua_pushstring(L, "_"); lua_gettable(L, 1); - struct wss_data* data = lua_touserdata(L, -1); + struct net_data* data = lua_touserdata(L, -1); uint64_t clen; const char* content = luaL_tolstring(L, 2, &clen); @@ -289,7 +224,7 @@ int i_ws_write(lua_State* L){ int i_ws_close(lua_State* L){ lua_pushstring(L, "_"); lua_gettable(L, 1); - struct wss_data* data = lua_touserdata(L, -1); + struct net_data* data = lua_touserdata(L, -1); if(data != NULL && data->buffer != NULL){ str_free(data->buffer); @@ -310,7 +245,7 @@ int i_ws_close(lua_State* L){ return 0; } - +#define BUFFER_LEN 4096 int l_wss(lua_State* L){ uint64_t len = 0; @@ -331,8 +266,8 @@ int l_wss(lua_State* L){ SSL* ssl = ssl_connect(ctx, sock, awa.domain->c); char* request = calloc(512, sizeof * request); - sprintf(request, "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"\ - "Sec-Websocket-Key: aWxvdmVsb3ZlbG92ZXlvdQ==\r\nSec-Websocket-Version: 13\r\n\r\n", path, awa.domain->c); + sprintf(request, "GET %s HTTP/1.1\r\nhost: %s\r\nconnection: upgrade\r\nupgrade: websocket\r\n"\ + "sec-websocket-key: aWxvdmVsb3ZlbG92ZXlvdQ==\r\nsec-websocket-version: 13\r\n\r\n", path, awa.domain->c); int s = SSL_write(ssl, request, strlen(request)); free_url(awa); @@ -360,7 +295,7 @@ int l_wss(lua_State* L){ luaI_error(L, -1, "error with header formating"); } - struct wss_data *data = malloc(sizeof * data); + struct net_data *data = malloc(sizeof * data); data->ssl = ssl; data->ctx = ctx; data->sock = sock; @@ -537,7 +472,7 @@ int _request(lua_State* L, struct request_state* state){ lua_newtable(L); int header_idx = lua_gettop(L); - luaI_tsets(L, header_idx, "User-Agent", "lullaby/"MAJOR_VERSION); + luaI_tsets(L, header_idx, "user-agent", "lullaby/"MAJOR_VERSION); if(params >= 3){ lua_pushvalue(L, header_idx); @@ -564,7 +499,7 @@ int _request(lua_State* L, struct request_state* state){ //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n"; char* request = calloc(cont_len + header->len + strlen(host) + strlen(path) + 512, sizeof * request); - sprintf(request, "%s %s HTTP/1.1\r\nHost: %s\r\nConnection: Close%s\r\n\r\n%s", action, path, host, header->c, cont); + sprintf(request, "%s %s HTTP/1.1\r\nhost: %s\r\nconnection: close%s\r\n\r\n%s", action, path, host, header->c, cont); if(awa.path != NULL) str_free(awa.path); if(awa.domain != NULL) str_free(awa.domain); @@ -610,16 +545,16 @@ int _request(lua_State* L, struct request_state* state){ luaI_tsets(L, idx, (owo->P[i].key)->c, ((str*)owo->P[i].value)->c); } //done out of pure laziness, parse_header was meant for requests but works fine for responses, change this later - luaI_treplk(L, idx, "Path", "code"); - luaI_treplk(L, idx, "Request", "version"); - luaI_treplk(L, idx, "Version", "code-name"); + luaI_treplk(L, idx, "path", "code"); + luaI_treplk(L, idx, "request", "version"); + luaI_treplk(L, idx, "version", "code-name"); lua_pushstring(L, "code"); lua_gettable(L, idx); int code = atoi(lua_tostring(L, -1)); luaI_tseti(L, idx, "code", code); - void* encoding = parray_get(owo, "Transfer-Encoding"); + void* encoding = parray_get(owo, "transfer-encoding"); //struct _srequest_state *read_state = calloc(sizeof * read_state, 1); //read_state->ctx = state->ctx; @@ -710,8 +645,9 @@ void path_parse(struct net_path_t* path, str* raw){ _Atomic size_t threads = 0; void* handle_client(void *_arg){ + signal(SIGPIPE, SIG_IGN); thread_arg_struct* args = (thread_arg_struct*)_arg; - int client_fd = args->fd; + struct net_data* ctx = args->ctx; char* buffer; int header_eof = -1; lua_State* L = args->L; @@ -719,7 +655,7 @@ void* handle_client(void *_arg){ char* header = NULL; - int64_t bite = recv_header(client_fd, &buffer, &header); + int64_t bite = recv_header(ctx, &buffer, &header); header_eof = header - buffer; if(bite > 0){ @@ -728,13 +664,13 @@ void* handle_client(void *_arg){ //checks for a valid header int val = parse_header(buffer, header_eof + 2, &table); - if(val == -2) net_error(client_fd, 414); + if(val == -2) net_error(ctx, 414); if(val >= 0){ - str* path = (str*)parray_get(table, "Path"); - str* sR = (str*)parray_get(table, "Request"); - str* sT = (str*)parray_get(table, "Content-Type"); - str* sC = (str*)parray_get(table, "Cookie"); + str* path = (str*)parray_get(table, "path"); + str* sR = (str*)parray_get(table, "request"); + str* sT = (str*)parray_get(table, "content-type"); + str* sC = (str*)parray_get(table, "cookie"); struct file_parse* file_cont = calloc(1, sizeof * file_cont); lua_newtable(L); @@ -755,7 +691,7 @@ void* handle_client(void *_arg){ parray_t* v = NULL; if(decoded_err == 1 || args->paths == NULL){ - net_error(client_fd, 400); + net_error(ctx, 400); } else { str_push(aa, decoded_path->c); @@ -787,7 +723,7 @@ void* handle_client(void *_arg){ luaI_tsetv(L, req_idx, "cookies", lcookie); parray_clear(cookie, STR); - parray_remove(table, "Cookie", STR); + parray_remove(table, "cookie", STR); } if(parsed_path.query->len != 0){ @@ -805,7 +741,7 @@ void* handle_client(void *_arg){ int ld = lua_gettop(L); luaI_tsetv(L, req_idx, "_data", ld); luaI_tsetv(L, req_idx, "files", files_idx); - luaI_tsetv(L, req_idx, "Body", body_idx); + luaI_tsetv(L, req_idx, "body", body_idx); for(int i = 0; i != table->len; i+=1){ //printf("'%s' :: '%s'\n",table[i]->c, table[i+1]->c); @@ -817,29 +753,34 @@ void* handle_client(void *_arg){ luaI_tsets(L, req_idx, "rawpath", path->c); if(bite == -1){ - client_fd = -2; + args->ctx->sock = -2; } luaI_tseti(L, req_idx, "_bytes", bite - header_eof - 4); - luaI_tseti(L, req_idx, "client_fd", client_fd); + //luaI_tseti(L, req_idx, "client_fd", client_fd); luaI_tsetcf(L, req_idx, "roll", l_roll); + luaI_tsetv(L, res_idx, "req", req_idx); + luaI_tsetv(L, req_idx, "res", res_idx); + //functions luaI_tsetcf(L, res_idx, "send", l_send); luaI_tsetcf(L, res_idx, "sendfile", l_sendfile); luaI_tsetcf(L, res_idx, "write", l_write); luaI_tsetcf(L, res_idx, "close", l_close); luaI_tsetcf(L, res_idx, "stop", l_stop); + luaI_tsetcf(L, res_idx, "upgrade", l_connection_upgrade); //values - luaI_tseti(L, res_idx, "client_fd", client_fd); + //luaI_tseti(L, res_idx, "client_fd", client_fd); + luaI_tsetlud(L, res_idx, "_", ctx); luaI_tsets(L, res_idx, "_request", sR->c); //header table lua_newtable(L); int header_idx = lua_gettop(L); - luaI_tseti(L, header_idx, "Code", 200); + luaI_tseti(L, header_idx, "code", 200); luaI_tsetv(L, res_idx, "header", header_idx); @@ -885,7 +826,7 @@ void* handle_client(void *_arg){ if(lua_pcall(L, 2, 0, errtraceback_idx) != 0){ fprintf(stderr, "(net thread) %s\n", lua_tostring(L, -1)); //send an error message if send has not been called - if(client_fd >= 0) net_error(client_fd, 500); + if(args->ctx->sock >= 0) net_error(ctx, 500); goto net_end; } @@ -904,9 +845,9 @@ net_end: larray_clear(params); parray_lclear(owo); //dont free the rest - lua_pushstring(L, "client_fd"); - lua_gettable(L, res_idx); - client_fd = luaL_checkinteger(L, -1); + //lua_pushstring(L, "client_fd"); + //lua_gettable(L, res_idx); + //client_fd = luaL_checkinteger(L, -1); } @@ -922,9 +863,12 @@ net_end: parray_clear(table, STR); } - if(client_fd != -1){ - shutdown(client_fd, 2); - closesocket(client_fd); + if(args->ctx->sock != -1){ + if(args->ctx->ssl != NULL) SSL_shutdown(args->ctx->ssl); + else { + shutdown(args->ctx->sock, 2); + closesocket(args->ctx->sock); + } } free(args); @@ -941,12 +885,23 @@ int clean_lullaby_net(lua_State* L){ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* state){ parse_mimetypes(); + if(state->ssl) ssl_init(); //need these on windows for sockets (stupid) #ifdef _WIN32 WSADATA Data; WSAStartup(MAKEWORD(2, 2), &Data); #endif + SSL_CTX* server_ctx; + if(state->ssl){ + if(!(server_ctx = SSL_CTX_new(TLS_server_method()))) + luaI_error(L, -7, "SSL_CTX_new error"); + luaI_assert(L, SSL_CTX_use_certificate_file(server_ctx, state->ssl_crt->c, SSL_FILETYPE_PEM) > 0) + luaI_assert(L, SSL_CTX_use_PrivateKey_file(server_ctx, state->ssl_key->c, SSL_FILETYPE_PEM) > 0) + if (SSL_CTX_check_private_key(server_ctx) == -1) + luaI_error(L, -8, "key does not match crt"); + } + int server_fd; struct sockaddr_in server_addr; struct pollfd fds[2]; @@ -1008,7 +963,8 @@ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* if(threads >= max_con){ //deny request - net_error(*client_fd, 503); +#warning "need a better way to do this with ssl support" + //net_error(*client_fd, 503); close(*client_fd); free(client_fd); @@ -1016,8 +972,31 @@ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* } thread_arg_struct* args = malloc(sizeof * args); - - args->fd = *client_fd; + args->ctx = malloc(sizeof * args->ctx); + + args->ctx->ssl = NULL; + if(state->ssl){ + args->ctx->ssl = SSL_new(server_ctx); + if(!args->ctx->ssl){ + fprintf(stderr, "SSL_new fail\n"); + close(*client_fd); + free(client_fd); + free(args); + continue; + } + SSL_set_fd(args->ctx->ssl, *client_fd); + + int f; + if((f = SSL_accept(args->ctx->ssl)) <= 0){ + fprintf(stderr, "SSL_accept fail %i\n", f); + close(*client_fd); + free(client_fd); + SSL_free(args->ctx->ssl); + free(args); + continue; + } + } + args->ctx->sock = *client_fd; args->port = port; args->cli = client_addr; args->L = luaL_newstate(); @@ -1050,6 +1029,7 @@ net_end: close(server_fd); close(efd); free(state); + SSL_CTX_free(server_ctx); for(int i = 0; i != paths->len; i++){ struct sarray_t* path = paths->P[i].value; @@ -1157,6 +1137,9 @@ int l_listen(lua_State* L){ struct net_server_state *state = malloc(sizeof * state); state->event_fd = -1; + state->ssl = 0; + state->ssl_crt = str_init(""); + state->ssl_key = str_init(""); int port = luaL_checkinteger(L, 2); @@ -1173,6 +1156,8 @@ int l_listen(lua_State* L){ luaI_tsetcf(L, mt, "PATCH", l_PATCHq); luaI_tsetcf(L, mt, "all", l_allq); + luaI_tsettab(L, mt, "ssl"); + luaI_tsetcf(L, mt, "close", l_net_close); luaI_tsetv(L, mt, "port", 2); @@ -1187,6 +1172,22 @@ int l_listen(lua_State* L){ if(state->event_fd == -2) luaI_error(L, -2, "closed"); + lua_getfield(L, mt, "ssl"); + int ssl = lua_gettop(L); + if(!lua_isnil(L, -1) && lua_type(L, ssl) == LUA_TTABLE){ + lua_getfield(L, ssl, "key"); + if(!lua_isnil(L, -1)){ + str_push(state->ssl_key, luaL_checkstring(L, -1)); + lua_getfield(L, ssl, "crt"); + if(!lua_isnil(L, -1)){ + str_push(state->ssl_crt, luaL_checkstring(L, -1)); + lua_pop(L, 2); + state->ssl = 1; + } + } + } + lua_pop(L, 1); + return start_serv(L, port, paths, state); ; } diff --git a/src/net/common.h b/src/net/common.h index 6bb161c..6d2f625 100644 --- a/src/net/common.h +++ b/src/net/common.h @@ -17,6 +17,7 @@ #include "../types/str.h" #include "../types/parray.h" #include "../types/map.h" +#include "ssl.h" #define max_con 200 //2^42 @@ -35,8 +36,16 @@ struct file_parse { int dash_count, table_idx; }; +struct net_data { + SSL* ssl; + SSL_CTX* ctx; + int sock; + str* buffer; +}; + typedef struct { - int fd, ser; + struct net_data* ctx; + int ser; int port; lua_State* L; struct sockaddr_in cli; @@ -51,6 +60,9 @@ struct lchar { struct net_server_state { int event_fd; + int ssl; + str* ssl_key; + str* ssl_crt; }; struct sarray_t { diff --git a/src/net/lua.c b/src/net/lua.c index f005185..d2b0d35 100644 --- a/src/net/lua.c +++ b/src/net/lua.c @@ -1,6 +1,9 @@ #include "lua.h" #include "luai.h" +#include "util.h" #include "common.h" +#include "../hash/fnv.h" +#include "websocket.h" int l_write(lua_State* L){ int res_idx = 1; @@ -12,11 +15,12 @@ int l_write(lua_State* L){ int head = strcmp(luaL_checkstring(L, -1), "HEAD") == 0; lua_pushvalue(L, res_idx); - lua_pushstring(L, "client_fd"); + lua_pushstring(L, "_"); lua_gettable(L, res_idx); - int client_fd = luaL_checkinteger(L, -1); + struct net_data* ctx = lua_touserdata(L, -1); + - client_fd_errors(client_fd); + client_fd_errors(ctx->sock); size_t len; char* content = (char*)luaL_checklstring(L, 2, &len); @@ -42,7 +46,7 @@ int l_write(lua_State* L){ resp = str_init(content); } - send(client_fd, resp->c, resp->len, MSG_NOSIGNAL); + net_ctx_write(ctx, resp->c, resp->len); str_free(resp); return 0; @@ -51,11 +55,11 @@ int l_write(lua_State* L){ int l_send(lua_State* L){ int res_idx = 1; lua_pushvalue(L, res_idx); - lua_pushstring(L, "client_fd"); + lua_pushstring(L, "_"); lua_gettable(L, res_idx); - int client_fd = luaL_checkinteger(L, -1); + struct net_data* ctx = lua_touserdata(L, -1); - client_fd_errors(client_fd); + client_fd_errors(ctx->sock); size_t len; char* content = (char*)luaL_checklstring(L, 2, &len); @@ -75,13 +79,12 @@ int l_send(lua_State* L){ } else i_write_header(L, header, &resp, content, len); - send(client_fd, resp->c, resp->len, MSG_NOSIGNAL); + net_ctx_write(ctx, resp->c, resp->len); + + if(ctx->ssl != NULL) SSL_shutdown(ctx->ssl); + else closesocket(ctx->sock); + ctx->sock = -1; - // - lua_pushstring(L, "client_fd"); - lua_pushinteger(L, -1); - lua_settable(L, res_idx); - closesocket(client_fd); //printf("%i | %i\n'%s'\n%i\n",client_fd,a,resp->c,resp->len); str_free(resp); return 0; @@ -91,15 +94,14 @@ int l_close(lua_State* L){ int res_idx = 1; lua_pushvalue(L, res_idx); - lua_pushstring(L, "client_fd"); + lua_pushstring(L, "_"); lua_gettable(L, res_idx); - int client_fd = luaL_checkinteger(L, -1); - client_fd_errors(client_fd); + struct net_data* ctx = lua_touserdata(L, -1); + client_fd_errors(ctx->sock); - lua_pushstring(L, "client_fd"); - lua_pushinteger(L, -1); - lua_settable(L, res_idx); - closesocket(client_fd); + if(ctx->ssl != NULL) SSL_shutdown(ctx->ssl); + else closesocket(ctx->sock); + ctx->sock = -1; return 0; } @@ -127,7 +129,7 @@ int l_roll(lua_State* L){ lua_gettable(L, 1); int64_t bytes = luaL_checkinteger(L, -1); - lua_pushstring(L, "Content-Length"); + lua_pushstring(L, "content-length"); lua_gettable(L, 1); if(lua_type(L, -1) == LUA_TNIL) { lua_pushinteger(L, -1); @@ -139,30 +141,31 @@ int l_roll(lua_State* L){ struct file_parse* data = (void*)lua_topointer(L, -1); lua_pushvalue(L, 1); - lua_pushstring(L, "client_fd"); + lua_pushstring(L, "_"); lua_gettable(L, 1); - int client_fd = luaL_checkinteger(L, -1); - client_fd_errors(client_fd); + struct net_data* ctx = lua_touserdata(L, -1); + + client_fd_errors(ctx->sock); - fd_set rfd; - FD_ZERO(&rfd); - FD_SET(client_fd, &rfd); + //fd_set rfd; + //FD_ZERO(&rfd); + //FD_SET(client_fd, &rfd); //printf("* %li / %li\n", bytes, content_length); if(bytes >= content_length){ lua_pushinteger(L, -1); return 1; } - if(select(client_fd+1, &rfd, NULL, NULL, &((struct timeval){.tv_sec = 0, .tv_usec = 0})) == 0){ + /*if(select(client_fd+1, &rfd, NULL, NULL, &((struct timeval){.tv_sec = 0, .tv_usec = 0})) == 0){ lua_pushinteger(L, 0); return 1; - } + }*/ //time_start(recv) if(alen == -1) alen = content_length - bytes; char* buffer = malloc(alen * sizeof * buffer); - int r = recv(client_fd, buffer, alen, 0); + int r = net_ctx_read(ctx, buffer, alen); if(r <= 0){ lua_pushinteger(L, r - 1); return 1; @@ -173,7 +176,7 @@ int l_roll(lua_State* L){ lua_pushinteger(L, bytes + r); lua_settable(L, 1); - lua_pushstring(L, "Body"); + lua_pushstring(L, "body"); lua_gettable(L, 1); int body_idx = lua_gettop(L); @@ -183,7 +186,7 @@ int l_roll(lua_State* L){ //time_start(parse) rolling_file_parse(L, &files_idx, &body_idx, buffer, NULL, r, data); //time_end("parse", parse) - luaI_tsetv(L, 1, "Body", body_idx); + luaI_tsetv(L, 1, "body", body_idx); luaI_tsetv(L, 1, "files", files_idx); free(buffer); @@ -214,17 +217,17 @@ int l_sendfile(lua_State* L){ luaI_assert(L, !access(path, R_OK) /*missing permissions*/); lua_pushvalue(L, res_idx); - lua_pushstring(L, "client_fd"); + lua_pushstring(L, "_"); lua_gettable(L, res_idx); - int client_fd = luaL_checkinteger(L, -1); - client_fd_errors(client_fd); + struct net_data* ctx = lua_touserdata(L, -1); + client_fd_errors(ctx->sock); lua_pushvalue(L, res_idx); lua_pushstring(L, "header"); lua_gettable(L, -2); int header = lua_gettop(L); - lua_pushstring(L, "Content-Type"); + lua_pushstring(L, "content-type"); lua_gettable(L, header); char* ext = strrchr(path, '.'); @@ -232,7 +235,7 @@ int l_sendfile(lua_State* L){ char* content_type = map_get(mime_type, ext + 1); if(content_type) - {luaI_tsets(L, header, "Content-Type", content_type);} + {luaI_tsets(L, header, "content-type", content_type);} } char* buffer = calloc(sizeof* buffer, bsize + 1); @@ -243,24 +246,24 @@ int l_sendfile(lua_State* L){ char size[256]; sprintf(size, "%li", sz); - luaI_tsets(L, header, "Content-Length", size); + luaI_tsets(L, header, "content-length", size); if(attachment) { char disp[256]; sprintf(disp, "attachment; filename=\"%s\"", filename); - luaI_tsets(L, header, "Content-Disposition", disp); + luaI_tsets(L, header, "content-disposition", disp); } else { - luaI_tsets(L, header, "Content-Disposition", "inline;"); + luaI_tsets(L, header, "content-disposition", "inline;"); } str* r; i_write_header(L, header, &r, "", 0); - send(client_fd, r->c, r->len, MSG_NOSIGNAL); + net_ctx_write(ctx, r->c, r->len); str_free(r); for(size_t i = 0; i < sz; i += bsize){ fread(buffer, sizeof * buffer, bsize, fp); - if(send(client_fd, buffer, bsize > sz - i ? sz - i : bsize, MSG_NOSIGNAL) == -1) + if(net_ctx_write(ctx, buffer, bsize > sz - i ? sz - i : bsize) == -1) break; } @@ -269,3 +272,25 @@ int l_sendfile(lua_State* L){ return 0; } + +int l_connection_upgrade(lua_State* L){ + int res_idx = 1; + lua_getfield(L, res_idx, "req"); + int req_idx = 2; + + lua_getfield(L, req_idx, "upgrade"); + uint64_t hash, len; + uint8_t* s = (uint8_t*)luaL_checklstring(L, -1, &len); + hash = fnv_1(s, len, v_1); + + switch(hash){ + case 0xf042f81495060e72: //websocket + l_websocket_upgrade(L); + break; + default: + luaI_error(L, -1, "can't upgrade"); + break; + } + + return 0; +} diff --git a/src/net/lua.h b/src/net/lua.h index b1e96fc..492a999 100644 --- a/src/net/lua.h +++ b/src/net/lua.h @@ -6,3 +6,4 @@ int l_close(lua_State* L); int l_stop(lua_State* L); int l_roll(lua_State* L); int l_sendfile(lua_State* L); +int l_connection_upgrade(lua_State* L); diff --git a/src/net/luai.c b/src/net/luai.c index 0b63d18..dc39720 100644 --- a/src/net/luai.c +++ b/src/net/luai.c @@ -11,7 +11,7 @@ void i_write_header(lua_State* L, int header_top, str** _resp, char* content, si for(;lua_next(L, header_top) != 0;){ char* key = (char*)luaL_checklstring(L, -2, NULL); - if(strcmp(key, "Code") != 0){ + if(strcmp(key, "code") != 0){ str_push(header_vs, key); str_push(header_vs, ": "); str_push(header_vs, (char*)luaL_checklstring(L, -1, NULL)); @@ -21,7 +21,7 @@ void i_write_header(lua_State* L, int header_top, str** _resp, char* content, si } lua_pushvalue(L, header_top); - lua_pushstring(L, "Code"); + lua_pushstring(L, "code"); lua_gettable(L, header_top); int code = luaL_checkinteger(L, -1); diff --git a/src/net/ssl.h b/src/net/ssl.h new file mode 100644 index 0000000..4473ced --- /dev/null +++ b/src/net/ssl.h @@ -0,0 +1,2 @@ +#include +#include diff --git a/src/net/util.c b/src/net/util.c index 62b394c..cc478e4 100644 --- a/src/net/util.c +++ b/src/net/util.c @@ -1,7 +1,8 @@ #include "common.h" #include "util.h" +#include -int64_t recv_header(int client_fd, char** _buffer, char** header_eof){ +int64_t recv_header(struct net_data* data, char** _buffer, char** header_eof){ char* buffer = calloc(sizeof* buffer, BUFFER_SIZE); *_buffer = buffer; int64_t len = 0; @@ -9,7 +10,7 @@ int64_t recv_header(int client_fd, char** _buffer, char** header_eof){ *header_eof = 0; for(;;){ - n = recv(client_fd, buffer + len, BUFFER_SIZE, 0); + n = net_ctx_read(data, buffer + len, BUFFER_SIZE); if(n <= 0){ //printf("%s %i\n", strerror(errno), errno); @@ -34,70 +35,9 @@ int64_t recv_header(int client_fd, char** _buffer, char** header_eof){ } } -/** - * @brief calls recv into buffer until everything is read - * - */ -// deprecated!! replaced by recv_header (above) -int64_t recv_full_buffer(int client_fd, char** _buffer, int* header_eof, int* state){ - char* header, *buffer = malloc(BUFFER_SIZE * sizeof * buffer); - memset(buffer, 0, BUFFER_SIZE); - int64_t len = 0; - *header_eof = -1; - int n, content_len = -1; - uint64_t con_len_full = 0; - //printf("&_\n"); - for(;;){ - n = recv(client_fd, buffer + len, BUFFER_SIZE, 0); - if(n < 0){ - *_buffer = buffer; - printf("%s %i\n",strerror(errno),errno); - if(*header_eof == -1) return -2; //dont even try w/ request, no header to read - return -1; //well the header is fine atleast - - }; - if(*header_eof == -1 && (header = strstr(buffer, "\r\n\r\n")) != NULL){ - *header_eof = header - buffer; - char* cont_len_raw = strstr(buffer, "Content-Length: "); - - if(cont_len_raw == NULL) { - len += n; - *_buffer = buffer; - return len; - } - - str* cont_len_str = str_init(""); - if(cont_len_raw == NULL) abort(); - //i is length of 'Content-Length: ' - for(int i = 16; cont_len_raw[i] != '\r'; i++) str_pushl(cont_len_str, cont_len_raw + i, 1); - con_len_full = strtol(cont_len_str->c, NULL, 10); - //if(content_len < 0) p_fatal("idk"); - str_free(cont_len_str); - if(con_len_full > max_content_length) { - *_buffer = buffer; - *state = (len + n != con_len_full + *header_eof + 4); - return len + n; - } - content_len = 1; - buffer = realloc(buffer, con_len_full + *header_eof + 4 + BUFFER_SIZE); - if(buffer == NULL) p_fatal("unable to allocate"); - } - - len += n; - if(len >= MAX_HEADER_SIZE){ - *_buffer = buffer; - return -2;//p_fatal("too large"); - } - if(*header_eof == -1){ - buffer = realloc(buffer, len + BUFFER_SIZE + 1); - memset(buffer + len, 0, n + 1); - } - - - if(content_len != -1 && len - *header_eof - 4 >= con_len_full) break; - } - *_buffer = buffer; - return len; +void lowercase(char* c, uint64_t len){ + for(int i = 0; i != len; i++) + c[i] = tolower(c[i]); } #define max_uri_len 4096 @@ -116,8 +56,8 @@ int parse_header(char* buffer, int header_eof, parray_t** _table){ for(; oi != header_eof; oi++){ if(buffer[oi] == ' ' || buffer[oi] == '\n'){ if(buffer[oi] == '\n') current->c[current->len - 1] = 0; - if(item < 3) parray_set(table, item == 0 ? "Request" : - item == 1 ? "Path" : "Version", (void*)str_init(current->c)); + if(item < 3) parray_set(table, item == 0 ? "request" : + item == 1 ? "path" : "version", (void*)str_init(current->c)); str_clear(current); item++; if(buffer[oi] == '\n') break; @@ -152,6 +92,7 @@ int parse_header(char* buffer, int header_eof, parray_t** _table){ //todo: figure out system to handle this str* id = (str*)parray_get(table, sw->c); if(id != NULL) str_free(id); + lowercase(sw->c, sw->len); parray_set(table, sw->c, (void*)str_init(current->c)); str_clear(current); str_free(sw); @@ -159,7 +100,7 @@ int parse_header(char* buffer, int header_eof, parray_t** _table){ key = 1; } continue; - } else str_pushl(current, buffer + i, 1); + } else str_pushc(current, buffer[i]); } if(sw != NULL){ parray_set(table, sw->c, (void*)str_init(current->c)); @@ -299,6 +240,7 @@ int content_disposition(str* src, parray_t** _dest){ return 1; } +#warning "leak, last calloc" int match_param(char* path, char* match, parray_t* arr){ int pi, index, imatch, start, mi; mi = pi = imatch = start = 0; @@ -523,10 +465,10 @@ void _parse_mimetypes(){ } -int net_error(int fd, int code){ +int net_error(struct net_data* ctx, int code){ char out[512] = {0}; sprintf(out, "HTTP/1.1 %i %s\n\n", code, http_code(code)); - send(fd, out, strlen(out), MSG_NOSIGNAL); + net_ctx_write(ctx, out, strlen(out)); return 0; } @@ -557,3 +499,17 @@ int percent_decode(str* input, str** _output){ *_output = output; return 0; } + +int net_ctx_read(struct net_data* data, void* buffer, size_t c){ + if(data->ssl == NULL){ + return read(data->sock, buffer, c); + } + return SSL_read(data->ssl, buffer, c); +} + +int net_ctx_write(struct net_data* data, void* buffer, size_t c){ + if(data->ssl == NULL){ + return write(data->sock, buffer, c); + } + return SSL_write(data->ssl, buffer, c); +} diff --git a/src/net/util.h b/src/net/util.h index a54ac95..8b11e85 100644 --- a/src/net/util.h +++ b/src/net/util.h @@ -12,8 +12,7 @@ * @param {int*} pointer to an int, will be where the header ends * @return {int64_t} bytes read, -1 if the body was damaged, -2 if the header was */ -int64_t recv_full_buffer(int client_fd, char** _buffer, int* header_eof, int* state); -int64_t recv_header(int client_fd, char** _buffer, char** header_eof); +int64_t recv_header(struct net_data* ctx, char** _buffer, char** header_eof); /** * @brief converts the request buffer into a parray_t @@ -55,7 +54,9 @@ int match_param(char* path, char* match, parray_t* arr); void parse_mimetypes(); -int net_error(int fd, int code); +int net_error(struct net_data* data, int code); int percent_decode(str* input, str** _output); +int net_ctx_read(struct net_data* data, void* buffer, size_t c); +int net_ctx_write(struct net_data* data, void* buffer, size_t c); diff --git a/src/net/websocket.c b/src/net/websocket.c new file mode 100644 index 0000000..af37f32 --- /dev/null +++ b/src/net/websocket.c @@ -0,0 +1,225 @@ +#include "websocket.h" +#include "common.h" +#include "util.h" +#include "../lua.h" +#include "../hash/sha01.h" +#include "../encode/base64.h" +#include +#include + +//why????? +const char* magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +//calloc or zero {out} +void hex_decode(char* in, char* out, uint64_t* outlen){ + uint64_t len = 0; + for(int i = 0; in[i] != '\0'; i++, len++){ + int value = 0; + char c = in[i]; + if(c >= '0' && c <= '9') + value = c - '0'; + else if(c >= 'A' && c <= 'F') + value = 10 + (c - 'A'); + else if(c >= 'a' && c <= 'f') + value = 10 + (c - 'a'); + + out[(i/2)] += value << (((i + 1) % 2) * 4); + //out[i/2] += strtoul((char[]){in[i], in[i+1]}, 0, 16); + } + + *outlen = len/2; +} + +struct ws_frame_info ws_frame_decode(char* buffer){ + struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1, + .rsv2 = (buffer[0] >> 5) & 1, .rsv3 = (buffer[0] >> 4) & 1, .opcode = buffer[0] & 0b1111, + .mask = (buffer[1] >> 7) & 1, .payload = buffer[1] & 0b1111111}; + return frame_info; +} + +void ws_frame_build(struct ws_frame_info frame, uint64_t extended, char* mkey, str* out){ + uint8_t d = (frame.fin << 7) | (frame.rsv1 << 6) | (frame.rsv2 << 5) | (frame.rsv3 << 4) | frame.opcode; + str_pushc(out, d); + frame.payload = extended; + if(extended > 125) frame.payload = 126; + if(extended > UINT16_MAX) frame.payload = 127; + d = (frame.mask << 7) | frame.payload; + str_pushc(out, d); + + if(frame.payload == 126){ + uint16_t sh = extended; + for(int i = 2; i != 0; i--) + str_pushc(out, ((uint8_t*)&sh)[i - 1]); + } else if(frame.payload == 127){ + for(int i = 8; i != 0; i--) + str_pushc(out, ((uint8_t*)&extended)[i - 1]); + } + + if(frame.mask){ + str_pushl(out, mkey, 4); + } +} + +#define BUFFER_LEN 4096 + +int ws_read(struct net_data* data, struct ws_frame_info* frame_info){ + char buffer[BUFFER_LEN] = {0}; + int total_len = 0; + int len; + + for(; (len = net_ctx_read(data, buffer + total_len, 2 - total_len)) > 0;){ + total_len += len; + if(total_len >= 2) break; + } + + if(len < 0 || total_len <= 0) return -1; + + uint64_t payload = 0; + *frame_info = ws_frame_decode(buffer); + memset(buffer, 0, total_len); + total_len = 0; + + if(frame_info->payload <= 125) payload = frame_info->payload; + else if(frame_info->payload == 126) { + for(; (len = net_ctx_read(data, buffer + total_len, 2 - total_len)) > 0;){ + total_len += len; + if(total_len >= 2) break; + } + + if(len < 0) return -1; + + payload = (buffer[0] & 0xff) << 8 | (buffer[1] & 0xff); + } else { + for(; (len = net_ctx_read(data, buffer + total_len, 8 - total_len)) > 0;){ + total_len += len; + if(total_len >= 8) break; + } + + if(len < 0) return -1; + + payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 | + ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff); + } + + total_len = 0; + uint8_t mask[4] = {0}; + if(frame_info->mask){ + for(; (len = net_ctx_read(data, buffer + total_len, 4 - total_len)) > 0;){ + total_len += len; + if(total_len >= 4) break; + } + mask[0] = buffer[0]; + mask[1] = buffer[1]; + mask[2] = buffer[2]; + mask[3] = buffer[3]; + } + + uint64_t i = 0; + memset(buffer, 0, BUFFER_LEN); + for(; data->buffer->len != payload && (len = net_ctx_read(data, buffer, lesser(payload - data->buffer->len, BUFFER_LEN))) > 0;){ + if(frame_info->mask){ + for(int z = 0; z != len; z++,i++) + buffer[z] ^= mask[i % 4]; + } + str_pushl(data->buffer, buffer, len); + memset(buffer, 0, len); + } + + if(len < 0) return -1; + + return 1; +} + +int l_ws_read(lua_State* L){ + lua_getfield(L, 1, "_ws"); + struct net_data* data = lua_touserdata(L, -1); + struct ws_frame_info frame = {}; + int c = ws_read(data, &frame); + if(c == -1){ + luaI_error(L, -1, "SSL_read error"); + str_clear(data->buffer); + } + + lua_newtable(L); + int idx = lua_gettop(L); + luaI_tsetsl(L, idx, "content", data->buffer->c, data->buffer->len); + luaI_tseti(L, idx, "opcode", frame.opcode); + + str_clear(data->buffer); + return 1; +} + +int l_ws_write(lua_State* L){ + lua_getfield(L, 1, "_ws"); + struct net_data* data = lua_touserdata(L, -1); + struct ws_frame_info frame = {}; + frame.rsv1 = 0; + frame.fin = 1; + frame.mask = 0; + frame.opcode = 0b0001; + + uint64_t len; + const char* s = lua_tolstring(L, 2, &len); + + str* f = str_init(""); + ws_frame_build(frame, len, NULL, f); + str_pushl(f, s, len); + write(data->sock, f->c, f->len); + str_free(f); + + return 0; +} + +int l_websocket_upgrade(lua_State* L){ + int res_idx = 1; + int req_idx = 2; + + lua_getfield(L, res_idx, "_"); + struct net_data* ctx = lua_touserdata(L, -1); + + lua_getfield(L, req_idx, "sec-websocket-key"); + const char* wskey = luaL_checklstring(L, -1, NULL); + str* newkey = str_init(wskey); + str_push(newkey, magic); + char* sha = calloc(1512, sizeof * sha); + printf("%s\n", newkey->c); + sha1((uint8_t*)newkey->c, newkey->len, sha); + printf("%s\n", sha); + char* bin = calloc(512, sizeof * bin); + uint64_t len; + hex_decode(sha, bin, &len); + char* b64 = calloc(512 * 3, sizeof * b64); + en_base64(bin, len, b64); + + char* upgrade = calloc(8192, sizeof * upgrade); + sprintf(upgrade, "HTTP/1.1 101 Switching Protocols\r\n" + "upgrade: websocket\r\n" + "connection: upgrade\r\n" + "sec-websocket-accept: %s\r\n" + "\r\n", b64); + + printf("%s\n", upgrade); + net_ctx_write(ctx, upgrade, strlen(upgrade)); + free(upgrade); + free(sha); + free(bin); + free(b64); + str_free(newkey); + + struct net_data *data = calloc(1, sizeof * data); + data->sock = ctx->sock; + data->ssl = ctx->ssl; + data->ctx = ctx->ctx; + data->buffer = str_init(""); + + luaI_tsetnil(L, req_idx, "roll"); + luaI_tsetlud(L, res_idx, "_ws", data); +#warning "missing ws commands" + luaI_tsetcf(L, res_idx, "send", l_ws_write); + //luaI_tsetcf(L, res_idx, "write", ); + //luaI_tsetcf(L, res_idx, "sendfile", ); + luaI_tsetcf(L, res_idx, "read", l_ws_read); + //luaI_tsetcf(L, res_idx, "close", ); + + return 1; +} diff --git a/src/net/websocket.h b/src/net/websocket.h new file mode 100644 index 0000000..98ae1ae --- /dev/null +++ b/src/net/websocket.h @@ -0,0 +1,16 @@ +#include "../lua.h" +#include "common.h" + +struct ws_frame_info { + int fin; + int rsv1; + int rsv2; + int rsv3; + int opcode; + int mask; + int payload; +}; + +int ws_read(struct net_data* data, struct ws_frame_info* frame_info); +struct ws_frame_info ws_frame_decode(char* buffer); +int l_websocket_upgrade(lua_State* L); diff --git a/src/types/str.c b/src/types/str.c index e673473..b4c0b5e 100644 --- a/src/types/str.c +++ b/src/types/str.c @@ -56,6 +56,10 @@ void str_pushl(str* s, const char* insert, size_t l){ s->c[s->len] = '\0'; } +void str_pushc(str* s, char insert){ + str_pushl(s, &insert, 1); +} + void str_clear(str* s){ memset(s->c, 0, s->len); diff --git a/src/types/str.h b/src/types/str.h index e650542..131454c 100644 --- a/src/types/str.h +++ b/src/types/str.h @@ -19,6 +19,7 @@ str* str_init(const char*); void str_free(str*); void str_push(str*, const char*); void str_pushl(str*, const char*, size_t); +void str_pushc(str*, char); void str_clear(str*); void str_popf(str*, int); void str_popb(str*, int); -- cgit v1.2.3