From 8e7257aac8b30aaa57577770fd636e784361e35d Mon Sep 17 00:00:00 2001 From: ame Date: Thu, 12 Jun 2025 02:10:10 -0500 Subject: fix some net code, add streaming to some requests --- src/net.c | 150 +++++++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 114 insertions(+), 36 deletions(-) (limited to 'src/net.c') diff --git a/src/net.c b/src/net.c index 9089f49..fc68a91 100644 --- a/src/net.c +++ b/src/net.c @@ -84,6 +84,7 @@ struct chunked_encoding_state { str* content; }; +//remove this eventually int chunked_encoding_round(char* input, int length, struct chunked_encoding_state* state){ //printf("'%s'\n", input); for(int i = 0; i < length; i++){ @@ -390,6 +391,96 @@ int l_wss(lua_State* L){ return 1; } +struct _srequest_state { + SSL* ssl; + SSL_CTX* ctx; + int sock; + str* buffer; //anything pre-existing + struct chunked_encoding_state* state; +}; + +int _srequest_free(void** _state){ + struct _srequest_state* state = *((struct _srequest_state**)_state); + SSL_set_shutdown(state->ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); + SSL_shutdown(state->ssl); + SSL_free(state->ssl); + SSL_CTX_free(state->ctx); + close(state->sock); + + if(state->state != NULL){ + str_free(state->state->buffer); + str_free(state->state->content); + free(state->state); + } + + str_free(state->buffer); + free(state); + return 0; +} + +int _srequest_read(uint64_t reqlen, str** _output, void** _state){ + struct _srequest_state* state = *((struct _srequest_state**)_state); + str* output = *_output; + //states using chunked encoding should skip this + if(state->buffer != NULL){ + str_pushl(output, state->buffer->c, state->buffer->len); + str_free(state->buffer); + state->buffer = NULL; + } + + char buffer[BUFFER_LEN]; + memset(buffer, 0, BUFFER_LEN); + uint64_t len; + for(; (len = SSL_read(state->ssl, buffer, BUFFER_LEN)) > 0;){ + if(state->state != NULL){ + chunked_encoding_round(buffer, len, state->state); + } else { + str_pushl(output, buffer, len); + } + memset(buffer, 0, BUFFER_LEN); + } + + if(state->state == NULL){ + str_pushl(output, state->state->content->c, state->state->content->len); + str_clear(state->state->content); + } + + *_output = output; + return 1; +} + +int _srequest_chunked_encoding(char* input, int length, struct chunked_encoding_state* state){ + //printf("'%s'\n", input); + for(int i = 0; i < length; i++){ + //printf("%i/%i\n", i, length); + if(state->reading_length){ + str_pushl(state->buffer, input + i, 1); + + if(state->buffer->len >= 2 && memmem(state->buffer->c + state->buffer->len - 2, 2, "\r\n", 2)){ + + str_popb(state->buffer, 2); + state->chunk_length = strtoll(state->buffer->c, NULL, 16); + str_clear(state->buffer); + state->reading_length = 0; + } + } else { + int len = lesser(state->chunk_length - state->buffer->len, length - i); + str_pushl(state->buffer, input + i, len); + i += len; + + if(state->buffer->len >= state->chunk_length){ + state->reading_length = 1; + str_pushl(state->content, state->buffer->c, state->buffer->len); + str_clear(state->buffer); + } + } + } + + //printf("buffer '%s'\n", state->buffer->c); + + return 0; +} + int l_srequest(lua_State* L){ int params = lua_gettop(L); @@ -452,7 +543,6 @@ int l_srequest(lua_State* L){ action = (char*)lua_tostring(L, check); } - //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n"; char* request = calloc(cont_len + header->len + 512, sizeof * request); @@ -484,7 +574,6 @@ int l_srequest(lua_State* L){ if(len < 0) luaI_error(L, len, "SSL_read error"); - if(header_eof != NULL){ lua_newtable(L); int idx = lua_gettop(L); @@ -502,55 +591,45 @@ int l_srequest(lua_State* L){ luaI_treplk(L, idx, "Request", "version"); luaI_treplk(L, idx, "Version", "code-name"); - str* content = str_init(""); void* encoding = parray_get(owo, "Transfer-Encoding"); + + struct _srequest_state *read_state = calloc(sizeof * read_state, 1); + read_state->ctx = ctx; + read_state->ssl = ssl; + read_state->sock = sock; if(encoding != NULL){ if(strcmp(((str*)encoding)->c, "chunked") == 0){ - struct chunked_encoding_state state = { - .reading_length = 1, - .buffer = str_init(""), - .content = content - }; - chunked_encoding_round(header_eof + 4, extra_len - 4, &state); - memset(buffer, 0, BUFFER_LEN); - - for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ - chunked_encoding_round(buffer, len, &state); - memset(buffer, 0, BUFFER_LEN); - } - - if(len < 0) luaI_error(L, len, "SSL_read error"); - - str_free(state.buffer); - content = state.content; - } - } else { - str_pushl(content, header_eof + 4, extra_len - 4); - memset(buffer, 0, BUFFER_LEN); + struct chunked_encoding_state* state = calloc(sizeof * state, 1); + state->reading_length = 1; + state->buffer = str_init(""); + state->content = str_init(""); - for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){ - str_pushl(content, buffer, len); + chunked_encoding_round(header_eof + 4, extra_len - 4, state); memset(buffer, 0, BUFFER_LEN); + + read_state->buffer = str_init(""); + read_state->state = state; } - - if(len < 0) luaI_error(L, len, "SSL_read error"); + } else { + read_state->buffer = str_initl(header_eof + 4, extra_len - 4); } - parray_clear(owo, STR); - luaI_tsetsl(L, idx, "content", content->c, content->len); - str_free(content); + luaI_newstream(L, _srequest_read, _srequest_free, read_state); + int v = lua_gettop(L); + luaI_tsetv(L, idx, "content", v); + lua_pushvalue(L, idx); } else { - lua_pushstring(L, a->c); + luaI_error(L, -1, "error with header"); } str_free(a); - SSL_set_shutdown(ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); + /*SSL_set_shutdown(ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); SSL_shutdown(ssl); SSL_free(ssl); SSL_CTX_free(ctx); - close(sock); + close(sock);*/ return 1; } @@ -651,8 +730,7 @@ void* handle_client(void *_arg){ if(bytes_received == -2) net_error(client_fd, 431); */ - //ignore if header is just fucked - if(bite >= -1){ + if(bite > 0){ parray_t* table; //checks for a valid header -- cgit v1.2.3