From e7cafd86b947ad654c8081e238faba4df5bd3c33 Mon Sep 17 00:00:00 2001 From: ame Date: Wed, 22 Jan 2025 02:45:23 -0600 Subject: work on websockets --- src/net.c | 297 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 285 insertions(+), 12 deletions(-) (limited to 'src/net.c') diff --git a/src/net.c b/src/net.c index d1944b5..ba853e2 100644 --- a/src/net.c +++ b/src/net.c @@ -15,6 +15,9 @@ #include #include +#include +#include + #define pab(M) {printf(M);abort();} SSL* ssl_connect(SSL_CTX* ctx, int sockfd, const char* hostname){ @@ -103,6 +106,273 @@ int chunked_encoding_round(char* input, int length, struct chunked_encoding_stat return 0; } +struct url { + str* proto, + * domain, + * port, + * path; +}; + +struct url parse_url(char* url, int len){ + struct url awa = {0}; + str* buffer = str_init(""); + int read = 0; + + for(int i = 0; i != len; i++){ + if(url[i] == ':'){ + if(awa.proto == NULL && i + 2 < len && url[i + 1] == '/' && url[i + 2] == '/'){ + awa.proto = buffer; + buffer = str_init(""); + i += 2; + } else if (awa.port == NULL){ + awa.domain = buffer; + buffer = str_init(""); + read = 1; + } + } else if(read != 2 && url[i] == '/'){ + if(read == 1){ + awa.port = buffer; + } else { + awa.domain = buffer; + } + buffer = str_init(""); + read = 2; + i--; + } else { + str_pushl(buffer, url + i, 1); + } + } + + if(read == 0) awa.domain = buffer; + else if(read == 1) awa.port = buffer; + else awa.path = buffer; + + return awa; +} + +void free_url(struct url u){ + if(u.proto != NULL) str_free(u.proto); + if(u.domain != NULL) str_free(u.domain); + if(u.path != NULL) str_free(u.path); + if(u.port != NULL) str_free(u.port); +} + +#define BUFFER_LEN 4096 +struct wss_data { + SSL* ssl; + SSL_CTX* ctx; + int sock; + str* buffer; +}; + +struct ws_frame_info { + int fin; + int rsv1; + int rsv2; + int rsv3; + int opcode; + int mask; + int payload; +}; + +int i_ws_read(lua_State* L){ + lua_pushstring(L, "_"); + lua_gettable(L, 1); + struct wss_data* data = lua_touserdata(L, -1); + char buffer[BUFFER_LEN] = {0}; + int total_len = 0; + int len; + + for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ + if(len < 0){ + lua_pushinteger(L, len); + return 1; + } + total_len += len; + if(total_len >= 2) break; + } + + uint64_t payload = 0; + struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1, + .rsv2 = (buffer[0] >> 5) & 1, .rsv3 = (buffer[0] >> 4) & 1, .opcode = buffer[0] & 0b1111, + .mask = (buffer[1] >> 7) & 1, .payload = buffer[1] & 0b1111111}; + //printf("fin: %i\npayload: %i\n", frame_info.fin, frame_info.payload); + memset(buffer, 0, total_len); + total_len = 0; + + if(frame_info.payload <= 125) payload = frame_info.payload; + else if(frame_info.payload == 126) { + for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ + if(len < 0){ + lua_pushinteger(L, len); + return 1; + } + total_len += len; + if(total_len >= 2) break; + } + + payload = (buffer[0] & 0xff) << 8 | buffer[1] & 0xff; + } else { + for(; (len = SSL_read(data->ssl, buffer + total_len, 8 - total_len)) > 0;){ + if(len < 0){ + lua_pushinteger(L, len); + return 1; + } + total_len += len; + if(total_len >= 8) break; + } + + payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 | + ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff); + } + //printf("final payload: %lu\n", payload); + + str* message = str_init(""); + memset(buffer, 0, BUFFER_LEN); + for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){ + if(len < 0){ + str_free(message); + lua_pushinteger(L, len); + return 1; + } + str_pushl(message, buffer, len); + memset(buffer, 0, len); + } + + lua_newtable(L); + int idx = lua_gettop(L); + luaI_tsetsl(L, idx, "content", message->c, message->len); + luaI_tseti(L, idx, "opcode", frame_info.opcode); + + str_free(message); + return 1; +} + +int i_ws_write(lua_State* L){ + lua_pushstring(L, "_"); + lua_gettable(L, 1); + struct wss_data* data = lua_touserdata(L, -1); + + uint64_t clen; + const char* content = luaL_tolstring(L, 2, &clen); + str* send_data = str_init(""); + + str_pushl(send_data, (const char[]){0b10000001}, 1); + if(clen <= 125) str_pushl(send_data, (const char[]){(0x1 << 7) | clen}, 1); + else if(clen <= 65535) + str_pushl(send_data, (const char[]){(0x1 << 7) | 126, (clen >> 8) & 0xff, clen & 0xff}, 3); + else + str_pushl(send_data, (const char[]){(0x1 << 7) | 127, (clen >> 56) & 0xff, (clen >> 48) & 0xff, + (clen >> 40) & 0xff, (clen >> 32) & 0xff, (clen >> 24) & 0xff, (clen >> 16) & 0xff, + (clen >> 8) & 0xff, clen & 0xff}, 9); + str_pushl(send_data, (const char[]){0, 0, 0, 0}, 4); + str_pushl(send_data, content, clen); + + int s = SSL_write(data->ssl, send_data->c, send_data->len); + lua_pushinteger(L, 1); + + str_free(send_data); + return 1; +} + +int i_ws_close(lua_State* L){ + printf("free\n"); + lua_pushstring(L, "_"); + lua_gettable(L, 1); + struct wss_data* data = lua_touserdata(L, -1); + + str_free(data->buffer); + + SSL_set_shutdown(data->ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); + SSL_shutdown(data->ssl); + SSL_free(data->ssl); + SSL_CTX_free(data->ctx); + + close(data->sock); + + free(data); + return 0; +} + +int l_wss(lua_State* L){ + uint64_t len = 0; + char* request_url = (char*)lua_tolstring(L, 1, &len); + struct url awa = parse_url(request_url, len); + if(awa.proto != NULL && strcmp(awa.proto->c, "ws") == 0){ + //send to l_ws, todo + abort(); + } + + char* port = awa.port == NULL ? "443" : awa.port->c; + char* path = awa.path == NULL ? "/" : awa.path->c; + int sock = get_host(awa.domain->c, port); + int set = 1; + signal(SIGPIPE, SIG_IGN); + + SSL_library_init(); + SSL_load_error_strings(); + SSL_CTX* ctx = SSL_CTX_new(SSLv23_client_method()); + SSL* ssl = ssl_connect(ctx, sock, awa.domain->c); + + char* request = calloc(512, sizeof * request); + sprintf(request, "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"\ + "Sec-Websocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-Websocket-Version: 13\r\n\r\n", path, awa.domain->c); + + SSL_write(ssl, request, strlen(request)); + free_url(awa); + free(request); + + char buffer[BUFFER_LEN]; + int extra_len = len = 0; + str* a = str_init(""); + char* header_eof = NULL; + for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ + str_pushl(a, buffer, len); + if((header_eof = memmem(a->c, a->len, "\r\n\r\n", 4)) != NULL){ + extra_len = a->len - (header_eof - a->c); + break; + } + memset(buffer, 0, BUFFER_LEN); + } + + if(header_eof == NULL){ + printf("header error\n"); + lua_pushinteger(L, -1); + return 1; + } + + struct wss_data *data = malloc(sizeof * data); + data->ssl = ssl; + data->ctx = ctx; + data->sock = sock; + data->buffer = str_init("");//str_initl(header_eof, extra_len); + str_free(a); + + lua_newtable(L); + int idx = lua_gettop(L); + + luaI_tsetlud(L, idx, "_", data); + luaI_tsetcf(L, idx, "read", i_ws_read); + luaI_tsetcf(L, idx, "write", i_ws_write); + luaI_tsetcf(L, idx, "close", i_ws_close); + + lua_newtable(L); + int meta_idx = lua_gettop(L); + + luaI_tsetcf(L, meta_idx, "__gc", i_ws_close); + + lua_pushvalue(L, meta_idx); + lua_setmetatable(L, idx); + + lua_pushvalue(L, idx); + + //verify stuff here + //parray_t* owo = NULL; + //parse_header(a->c, header_eof - a->c, &owo); + + return 1; +} + int l_srequest(lua_State* L){ int params = lua_gettop(L); @@ -166,26 +436,29 @@ int l_srequest(lua_State* L){ //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n"; - char* request = calloc(cont_len + header->len + 256, sizeof * request); + char* request = calloc(cont_len + header->len + 512, sizeof * request); sprintf(request, "%s %s HTTP/1.1\r\nHost: %s\r\nConnection: Close%s\r\n\r\n%s", action, path, host, header->c, cont); + //printf("%s\n", request); str_free(header); SSL_write(ssl, request, strlen(request)); free(request); str* a = str_init(""); - char buffer[512]; + char buffer[BUFFER_LEN]; int len = 0; int extra_len = 0; char* header_eof = NULL; - for(; (len = SSL_read(ssl, buffer, 511)) > 0;){ + for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ + int blen = a->len; str_pushl(a, buffer, len); - if((header_eof = memmem(a->c, a->len, "\r\n\r\n", 4)) != NULL){ + int offset = blen >= 4 ? 4 : blen; + if((header_eof = memmem(a->c + blen - offset, len + offset, "\r\n\r\n", 4)) != NULL){ extra_len = a->len - (header_eof - a->c); break; } - memset(buffer, 0, 512); + memset(buffer, 0, BUFFER_LEN); } if(header_eof != NULL){ @@ -193,7 +466,7 @@ int l_srequest(lua_State* L){ int idx = lua_gettop(L); parray_t* owo = NULL; - parse_header(a->c, header_eof - a->c, &owo); + int err = parse_header(a->c, header_eof - a->c, &owo); for(int i = 0; i != owo->len; i++){ luaI_tsets(L, idx, (owo->P[i].key)->c, ((str*)owo->P[i].value)->c); @@ -213,22 +486,22 @@ int l_srequest(lua_State* L){ .content = content }; chunked_encoding_round(header_eof + 4, extra_len - 4, &state); - memset(buffer, 0, 512); + memset(buffer, 0, BUFFER_LEN); - for(; (len = SSL_read(ssl, buffer, 511)) > 0;){ + for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ chunked_encoding_round(buffer, len, &state); - memset(buffer, 0, 512); + memset(buffer, 0, BUFFER_LEN); } str_free(state.buffer); content = state.content; } } else { - memset(buffer, 0, 512); + memset(buffer, 0, BUFFER_LEN); - for(; (len = SSL_read(ssl, buffer, 511)) > 0;){ + for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ str_pushl(content, buffer, len); - memset(buffer, 0, 512); + memset(buffer, 0, BUFFER_LEN); } } -- cgit v1.2.3