From 8c46301da70740ee40f707e88dd170243c6dd46a Mon Sep 17 00:00:00 2001 From: amelia squires Date: Mon, 24 Feb 2025 13:23:19 -0600 Subject: error handling and parse_url on srequest --- src/net.c | 94 +++++++++++++++++++++++++++++++++++---------------------------- 1 file changed, 53 insertions(+), 41 deletions(-) (limited to 'src/net.c') diff --git a/src/net.c b/src/net.c index 334acc8..f62cfc6 100644 --- a/src/net.c +++ b/src/net.c @@ -193,14 +193,12 @@ int i_ws_read(lua_State* L){ int total_len = 0; int len; - for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ - if(len < 0){ - lua_pushinteger(L, len); - return 1; - } + for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ total_len += len; if(total_len >= 2) break; } + + if(len < 0) luaI_error(L, len, "SSL_read error"); uint64_t payload = 0; struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1, @@ -212,25 +210,22 @@ int i_ws_read(lua_State* L){ if(frame_info.payload <= 125) payload = frame_info.payload; else if(frame_info.payload == 126) { - for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ - if(len < 0){ - lua_pushinteger(L, len); - return 1; - } + for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){ total_len += len; if(total_len >= 2) break; } + if(len < 0) luaI_error(L, len, "SSL_read error"); + payload = (buffer[0] & 0xff) << 8 | buffer[1] & 0xff; } else { for(; (len = SSL_read(data->ssl, buffer + total_len, 8 - total_len)) > 0;){ - if(len < 0){ - lua_pushinteger(L, len); - return 1; - } total_len += len; if(total_len >= 8) break; - } + } + + if(len < 0) luaI_error(L, -1, "SSL_read error"); + payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 | ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff); @@ -239,15 +234,16 @@ int i_ws_read(lua_State* L){ str* message = str_init(""); memset(buffer, 0, BUFFER_LEN); - for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){ - if(len < 0){ - str_free(message); - lua_pushinteger(L, len); - return 1; - } + for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){ str_pushl(message, buffer, len); memset(buffer, 0, len); } + + if(len < 0) { + str_free(message); + luaI_error(L, len, "SSL_read error"); + } + lua_newtable(L); int idx = lua_gettop(L); @@ -279,6 +275,7 @@ int i_ws_write(lua_State* L){ str_pushl(send_data, content, clen); int s = SSL_write(data->ssl, send_data->c, send_data->len); + if(s <= 0) luaI_error(L, s, "SSL_write error"); lua_pushinteger(L, 1); str_free(send_data); @@ -290,7 +287,7 @@ int i_ws_close(lua_State* L){ lua_gettable(L, 1); struct wss_data* data = lua_touserdata(L, -1); - if(data != NULL){ + if(data != NULL && data->buffer != NULL){ str_free(data->buffer); SSL_set_shutdown(data->ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN); @@ -303,7 +300,9 @@ int i_ws_close(lua_State* L){ free(data); } - data = NULL; + lua_pushstring(L, "_"); + lua_pushnil(L); + lua_settable(L, 1); return 0; } @@ -330,10 +329,14 @@ int l_wss(lua_State* L){ sprintf(request, "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"\ "Sec-Websocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-Websocket-Version: 13\r\n\r\n", path, awa.domain->c); - SSL_write(ssl, request, strlen(request)); + int s = SSL_write(ssl, request, strlen(request)); free_url(awa); free(request); + if(s <= 0){ + luaI_error(L, s, "SSL_write error"); + } + char buffer[BUFFER_LEN]; int extra_len = len = 0; str* a = str_init(""); @@ -347,10 +350,10 @@ int l_wss(lua_State* L){ memset(buffer, 0, BUFFER_LEN); } + if(len < 0) luaI_error(L, len, "SSL_read error"); + if(header_eof == NULL){ - printf("header error\n"); - lua_pushinteger(L, -1); - return 1; + luaI_error(L, -1, "error with header formating"); } struct wss_data *data = malloc(sizeof * data); @@ -389,29 +392,29 @@ int l_srequest(lua_State* L){ int params = lua_gettop(L); int check = 1; - const char* host = luaL_checkstring(L, check); - const char* port = "443"; - if(lua_isnumber(L, check + 1)){ - check++; - port = lua_tostring(L, check); + uint64_t ilen = 0; + char* request_url = (char*)lua_tolstring(L, 1, &ilen); + struct url awa = parse_url(request_url, ilen); + if(awa.proto != NULL && strcmp(awa.proto->c, "http") == 0){ + //send to l_request, todo + abort(); } + const char* host = awa.domain == NULL ? request_url : awa.domain->c; + const char* port = awa.port == NULL ? "443" : awa.port->c; + char* path = awa.path == NULL ? "/" : awa.path->c; int sock = get_host((char*)host, (char*)port); if(sock == -1){ - p_fatal("could not resolve address"); - abort(); + //p_fatal("could not resolve address"); + //abort(); + luaI_error(L, -1, "error resolving address"); } ssl_init(); SSL_CTX* ctx = SSL_CTX_new(SSLv23_client_method()); SSL* ssl = ssl_connect(ctx, sock, host); + if(ssl == NULL) luaI_error(L, -1, "ssl_connect error"); - char* path = "/"; - if(params >= check + 1){ - check++; - path = (char*)luaL_checkstring(L, check); - } - char* cont = ""; size_t cont_len = 0; if(params >= check + 1){ @@ -455,9 +458,11 @@ int l_srequest(lua_State* L){ //printf("%s\n", request); str_free(header); - SSL_write(ssl, request, strlen(request)); + int s = SSL_write(ssl, request, strlen(request)); free(request); + if(s <= 0) luaI_error(L, s, "SSL_write error"); + str* a = str_init(""); char buffer[BUFFER_LEN]; int len = 0; @@ -475,6 +480,9 @@ int l_srequest(lua_State* L){ memset(buffer, 0, BUFFER_LEN); } + if(len < 0) luaI_error(L, len, "SSL_read error"); + + if(header_eof != NULL){ lua_newtable(L); int idx = lua_gettop(L); @@ -509,6 +517,8 @@ int l_srequest(lua_State* L){ memset(buffer, 0, BUFFER_LEN); } + if(len < 0) luaI_error(L, len, "SSL_read error"); + str_free(state.buffer); content = state.content; } @@ -520,6 +530,8 @@ int l_srequest(lua_State* L){ str_pushl(content, buffer, len); memset(buffer, 0, BUFFER_LEN); } + + if(len < 0) luaI_error(L, len, "SSL_read error"); } parray_clear(owo, STR); -- cgit v1.2.3