From 8e7257aac8b30aaa57577770fd636e784361e35d Mon Sep 17 00:00:00 2001 From: ame Date: Thu, 12 Jun 2025 02:10:10 -0500 Subject: fix some net code, add streaming to some requests --- src/io.c | 42 +--------------- src/io.h | 2 - src/lua.c | 118 ++++++++++++++++++++++++++++++++++++++++++++ src/lua.h | 6 +++ src/net.c | 150 ++++++++++++++++++++++++++++++++++++++++++-------------- src/net/luai.c | 6 ++- src/net/util.c | 16 +++--- src/test.c | 12 +++++ src/test.h | 2 + src/types/str.c | 1 + 10 files changed, 267 insertions(+), 88 deletions(-) (limited to 'src') diff --git a/src/io.c b/src/io.c index 0f19933..389d407 100644 --- a/src/io.c +++ b/src/io.c @@ -295,45 +295,5 @@ int l_json_parse(lua_State* L){ return 1; } -int l_arg_handle(lua_State* L){ - luaL_checktype(L, 1, LUA_TTABLE); - luaL_checktype(L, 2, LUA_TTABLE); - - size_t len = lua_objlen(L,1); - for(size_t i = 0; i <= len - 1; i++){ - lua_pushnumber(L,i + 1); - lua_gettable(L,1); - - lua_pushnumber(L,1); - lua_gettable(L,-2); - size_t inner_len = lua_objlen(L,-1); - size_t inner_idx = lua_gettop(L); - - lua_pushnumber(L, 2); - lua_gettable(L, -3); - - size_t function_idx = lua_gettop(L); - - for(int ii = 1; ii <= inner_len; ii++){ - lua_pushnumber(L, ii); - lua_gettable(L, inner_idx); - - const char* key = lua_tostring(L, -1); - - size_t input_len = lua_objlen(L, 2); - - for(int iii = 1; iii <= input_len; iii++){ - lua_pushnumber(L, iii); - lua_gettable(L, 2); - if(strcmp(lua_tostring(L, -1), key) == 0){ - lua_pushvalue(L, function_idx); - lua_pcall(L, 0, 0, 0); - ii = inner_len + 1; - break; - } - } - - } - } - return 0; +int resolve_path(lua_State* L){ } diff --git a/src/io.h b/src/io.h index 36702f9..b959bbd 100644 --- a/src/io.h +++ b/src/io.h @@ -25,7 +25,6 @@ int l_log(lua_State*); int l_warn(lua_State*); int l_error(lua_State*); int l_pprint(lua_State*); -int l_arg_handle(lua_State*); int l_json_parse(lua_State*); @@ -56,6 +55,5 @@ static const luaL_Reg io_function_list [] = { {"error",l_error}, {"pprint",l_pprint}, {"json_parse",l_json_parse}, - {"arg_handle",l_arg_handle}, {NULL,NULL} }; diff --git a/src/lua.c b/src/lua.c index d122b59..ebddd03 100644 --- a/src/lua.c +++ b/src/lua.c @@ -22,6 +22,124 @@ void __free_(void* p){ return (free)(p); } +int _stream_read(lua_State* L){ + uint64_t len = 0; + if(lua_gettop(L) > 1){ + len = lua_tointeger(L, 2); + } + + lua_pushstring(L, "_read"); + lua_gettable(L, 1); + stream_read_function rf = lua_touserdata(L, -1); + + lua_pushstring(L, "_state"); + lua_gettable(L, 1); + void* state = lua_touserdata(L, -1); + + str* cont = str_init(""); + int ret = rf(len, &cont, &state); + + if(ret < 0){ + luaI_error(L, ret, "read error"); + } + + if(ret == 0){ + luaI_tsetb(L, 1, "more", 0); + } + + lua_pushlstring(L, cont->c, cont->len); + free(cont); + return 1; +} + +int _stream_file(lua_State* L){ + const int CHUNK_SIZE = 4096; + uint64_t maxlen = 0; + uint64_t totallen = 0; + if(lua_gettop(L) > 2){ + maxlen = lua_tointeger(L, 3); + } + + lua_pushstring(L, "_read"); + lua_gettable(L, 1); + stream_read_function rf = lua_touserdata(L, -1); + + lua_pushstring(L, "_state"); + lua_gettable(L, 1); + void* state = lua_touserdata(L, -1); + + const char* filename = lua_tostring(L, 2); + FILE *f; + f = fopen(filename, "w"); + if(f == NULL){ + luaI_error(L, -1, "unable to open file"); + } + + str* cont = str_init(""); + for(;;){ + int ret = rf(CHUNK_SIZE, &cont, &state); + //printf("%s\n", cont->c); + + if(ret < 0){ + fclose(f); + luaI_error(L, ret, "read error"); + } + + fwrite(cont->c, sizeof * cont->c, cont->len, f); + totallen += cont->len; + str_clear(cont); + + if(ret == 0 || totallen >= maxlen){ + if(ret == 0) {luaI_tsetb(L, 1, "more", 0);} + break; + } + } + + fclose(f); + return 0; +} + +int _stream_free(lua_State* L){ + lua_pushstring(L, "_free"); + lua_gettable(L, 1); + void* rf = lua_touserdata(L, -1); + + lua_pushstring(L, "_state"); + lua_gettable(L, 1); + void* state = lua_touserdata(L, -1); + + printf("call free\n"); + if(rf != NULL){ + printf("run free\n"); + ((stream_free_function)rf)(&state); + } + return 0; +} + +void luaI_newstream(lua_State* L, stream_read_function readf, stream_free_function freef, void* state){ + lua_newtable(L); + int tidx = lua_gettop(L); + + luaI_tsetlud(L, tidx, "_read", readf); + luaI_tsetlud(L, tidx, "_free", freef); + luaI_tsetlud(L, tidx, "_state", state); + luaI_tsetcf(L, tidx, "read", _stream_read); + luaI_tsetcf(L, tidx, "close", _stream_free); + luaI_tsetb(L, tidx, "more", 1); + luaI_tsetcf(L, tidx, "file", _stream_file); + + lua_newtable(L); + int midx = lua_gettop(L); + + luaI_tsetcf(L, midx, "__gc", _stream_free); + + lua_pushvalue(L, midx); + lua_setmetatable(L, tidx); + + lua_pushvalue(L, tidx); +} + + int writer(lua_State *L, const void* p, size_t sz, void* ud){ char o[2] = {0}; for (int i =0; i #include #include +#include +#include "types/str.h" #ifndef __lua_h #define __lua_h @@ -31,6 +33,10 @@ void luaI_copyvars(lua_State* src, lua_State* dest); void lua_upvalue_key_table(lua_State* L, int fidx); int lua_assign_upvalues(lua_State* L, int fidx); +typedef int (*stream_read_function)(uint64_t, str**, void**); +typedef int (*stream_free_function)(void**); +void luaI_newstream(lua_State* L, stream_read_function, stream_free_function, void*); + //generic macro that takes other macros (see below) #define _tset_b(L, Tidx, K, V, F)\ lua_pushstring(L, K);\ diff --git a/src/net.c b/src/net.c index 9089f49..fc68a91 100644 --- a/src/net.c +++ b/src/net.c @@ -84,6 +84,7 @@ struct chunked_encoding_state { str* content; }; +//remove this eventually int chunked_encoding_round(char* input, int length, struct chunked_encoding_state* state){ //printf("'%s'\n", input); for(int i = 0; i < length; i++){ @@ -390,6 +391,96 @@ int l_wss(lua_State* L){ return 1; } +struct _srequest_state { + SSL* ssl; + SSL_CTX* ctx; + int sock; + str* buffer; //anything pre-existing + struct chunked_encoding_state* state; +}; + +int _srequest_free(void** _state){ + struct _srequest_state* state = *((struct _srequest_state**)_state); + SSL_set_shutdown(state->ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); + SSL_shutdown(state->ssl); + SSL_free(state->ssl); + SSL_CTX_free(state->ctx); + close(state->sock); + + if(state->state != NULL){ + str_free(state->state->buffer); + str_free(state->state->content); + free(state->state); + } + + str_free(state->buffer); + free(state); + return 0; +} + +int _srequest_read(uint64_t reqlen, str** _output, void** _state){ + struct _srequest_state* state = *((struct _srequest_state**)_state); + str* output = *_output; + //states using chunked encoding should skip this + if(state->buffer != NULL){ + str_pushl(output, state->buffer->c, state->buffer->len); + str_free(state->buffer); + state->buffer = NULL; + } + + char buffer[BUFFER_LEN]; + memset(buffer, 0, BUFFER_LEN); + uint64_t len; + for(; (len = SSL_read(state->ssl, buffer, BUFFER_LEN)) > 0;){ + if(state->state != NULL){ + chunked_encoding_round(buffer, len, state->state); + } else { + str_pushl(output, buffer, len); + } + memset(buffer, 0, BUFFER_LEN); + } + + if(state->state == NULL){ + str_pushl(output, state->state->content->c, state->state->content->len); + str_clear(state->state->content); + } + + *_output = output; + return 1; +} + +int _srequest_chunked_encoding(char* input, int length, struct chunked_encoding_state* state){ + //printf("'%s'\n", input); + for(int i = 0; i < length; i++){ + //printf("%i/%i\n", i, length); + if(state->reading_length){ + str_pushl(state->buffer, input + i, 1); + + if(state->buffer->len >= 2 && memmem(state->buffer->c + state->buffer->len - 2, 2, "\r\n", 2)){ + + str_popb(state->buffer, 2); + state->chunk_length = strtoll(state->buffer->c, NULL, 16); + str_clear(state->buffer); + state->reading_length = 0; + } + } else { + int len = lesser(state->chunk_length - state->buffer->len, length - i); + str_pushl(state->buffer, input + i, len); + i += len; + + if(state->buffer->len >= state->chunk_length){ + state->reading_length = 1; + str_pushl(state->content, state->buffer->c, state->buffer->len); + str_clear(state->buffer); + } + } + } + + //printf("buffer '%s'\n", state->buffer->c); + + return 0; +} + int l_srequest(lua_State* L){ int params = lua_gettop(L); @@ -452,7 +543,6 @@ int l_srequest(lua_State* L){ action = (char*)lua_tostring(L, check); } - //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n"; char* request = calloc(cont_len + header->len + 512, sizeof * request); @@ -484,7 +574,6 @@ int l_srequest(lua_State* L){ if(len < 0) luaI_error(L, len, "SSL_read error"); - if(header_eof != NULL){ lua_newtable(L); int idx = lua_gettop(L); @@ -502,55 +591,45 @@ int l_srequest(lua_State* L){ luaI_treplk(L, idx, "Request", "version"); luaI_treplk(L, idx, "Version", "code-name"); - str* content = str_init(""); void* encoding = parray_get(owo, "Transfer-Encoding"); + + struct _srequest_state *read_state = calloc(sizeof * read_state, 1); + read_state->ctx = ctx; + read_state->ssl = ssl; + read_state->sock = sock; if(encoding != NULL){ if(strcmp(((str*)encoding)->c, "chunked") == 0){ - struct chunked_encoding_state state = { - .reading_length = 1, - .buffer = str_init(""), - .content = content - }; - chunked_encoding_round(header_eof + 4, extra_len - 4, &state); - memset(buffer, 0, BUFFER_LEN); - - for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ - chunked_encoding_round(buffer, len, &state); - memset(buffer, 0, BUFFER_LEN); - } - - if(len < 0) luaI_error(L, len, "SSL_read error"); - - str_free(state.buffer); - content = state.content; - } - } else { - str_pushl(content, header_eof + 4, extra_len - 4); - memset(buffer, 0, BUFFER_LEN); + struct chunked_encoding_state* state = calloc(sizeof * state, 1); + state->reading_length = 1; + state->buffer = str_init(""); + state->content = str_init(""); - for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ - str_pushl(content, buffer, len); + chunked_encoding_round(header_eof + 4, extra_len - 4, state); memset(buffer, 0, BUFFER_LEN); + + read_state->buffer = str_init(""); + read_state->state = state; } - - if(len < 0) luaI_error(L, len, "SSL_read error"); + } else { + read_state->buffer = str_initl(header_eof + 4, extra_len - 4); } - parray_clear(owo, STR); - luaI_tsetsl(L, idx, "content", content->c, content->len); - str_free(content); + luaI_newstream(L, _srequest_read, _srequest_free, read_state); + int v = lua_gettop(L); + luaI_tsetv(L, idx, "content", v); + lua_pushvalue(L, idx); } else { - lua_pushstring(L, a->c); + luaI_error(L, -1, "error with header"); } str_free(a); - SSL_set_shutdown(ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); + /*SSL_set_shutdown(ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); SSL_shutdown(ssl); SSL_free(ssl); SSL_CTX_free(ctx); - close(sock); + close(sock);*/ return 1; } @@ -651,8 +730,7 @@ void* handle_client(void *_arg){ if(bytes_received == -2) net_error(client_fd, 431); */ - //ignore if header is just fucked - if(bite >= -1){ + if(bite > 0){ parray_t* table; //checks for a valid header diff --git a/src/net/luai.c b/src/net/luai.c index 3e16dcb..205f217 100644 --- a/src/net/luai.c +++ b/src/net/luai.c @@ -75,7 +75,11 @@ int rolling_file_parse(lua_State* L, int* files_idx, int* body_idx, char* buffer //parray_set(content, "_current", (void*)(current)); content.boundary_id = str_init(""); - str_popb(content.boundary, 4); + + //quick fix? + //str_popb(content.boundary, 4); + if(content.boundary->len >= 4) str_popb(content.boundary, 4); + //parray_set(content, "_boundary", (void*)boundary); //parray_set(content, "_boundary_id", (void*)boundary_id); diff --git a/src/net/util.c b/src/net/util.c index 76ec30e..92d3bce 100644 --- a/src/net/util.c +++ b/src/net/util.c @@ -11,21 +11,21 @@ int64_t recv_header(int client_fd, char** _buffer, char** header_eof){ for(;;){ n = recv(client_fd, buffer + len, BUFFER_SIZE, 0); - if(n < 0){ - printf("%s %i\n", strerror(errno), errno); + if(n <= 0){ + //printf("%s %i\n", strerror(errno), errno); return -1; } - + + if((len += n) >= MAX_HEADER_SIZE){ + return -2; + } + // search the last 4 characters too if they exist // this could probably be changed to 3 int64_t start_len = len - 4 > 0 ? len - 4 : 0; int64_t search_end = len - 4 > 0 ? n + 4 : n; if((*header_eof = memmem(buffer + start_len, search_end, "\r\n\r\n", 4)) != NULL){ - return len + n; - } - - if((len += n) >= MAX_HEADER_SIZE){ - return -2; + return len; } buffer = realloc(buffer, sizeof* buffer * (len + BUFFER_SIZE + 1)); diff --git a/src/test.c b/src/test.c index 4403e60..9ae5e5b 100644 --- a/src/test.c +++ b/src/test.c @@ -55,3 +55,15 @@ int l_upvalue_key_table(lua_State* L){ lua_upvalue_key_table(L, 1); return 1; } + +int rea(uint64_t len, str** _output, void** v){ + str* output = *_output; + str_push(output, "awa!!!! test\n"); + *_output = output; + return 1; +} + +int l_stream_test(lua_State* L){ + luaI_newstream(L, rea, NULL, 0); + return 1; +} diff --git a/src/test.h b/src/test.h index 2eceaeb..bb39cc8 100644 --- a/src/test.h +++ b/src/test.h @@ -4,11 +4,13 @@ int ld_match(lua_State*); int l_stack_dump(lua_State*); int l_upvalue_key_table(lua_State* L); +int l_stream_test(lua_State* L); static const luaL_Reg test_function_list [] = { {"_match", ld_match}, {"stack_dump", l_stack_dump}, {"upvalue_key_table", l_upvalue_key_table}, + {"stream", l_stream_test}, {NULL,NULL} }; diff --git a/src/types/str.c b/src/types/str.c index 4b257cb..0c8d63a 100644 --- a/src/types/str.c +++ b/src/types/str.c @@ -56,6 +56,7 @@ void str_popf(str* s, int len){ void str_popb(str* s, int len){ s->len -= len; + s->len = s->len > 0 ? s->len : 0; s->c[s->len] = 0; } -- cgit v1.2.3