diff options
| author | ame <[email protected]> | 2025-01-22 02:45:23 -0600 |
|---|---|---|
| committer | ame <[email protected]> | 2025-01-22 02:45:23 -0600 |
| commit | e7cafd86b947ad654c8081e238faba4df5bd3c33 (patch) | |
| tree | 0de207c53c902fc86fedb9a6196e4abebbd5277d /src | |
| parent | a86088dc5ded54ccc41bec5d547ff044c641ae19 (diff) | |
work on websockets
Diffstat (limited to 'src')
| -rw-r--r-- | src/lua.c | 4 | ||||
| -rw-r--r-- | src/net.c | 297 | ||||
| -rw-r--r-- | src/net.h | 4 | ||||
| -rw-r--r-- | src/net/util.c | 6 |
4 files changed, 296 insertions, 15 deletions
@@ -144,8 +144,8 @@ void luaI_deepcopy(lua_State* src, lua_State* dest, enum deep_copy_flags flags){ break;
default:
printf("unknown type %i vs (old)%i\n",lua_type(src, -1), type);
- abort();
- lua_pushnumber(dest, 5);
+ //abort();
+ lua_pushnil(dest);
break;
}
int tidx = lua_gettop(dest);
@@ -15,6 +15,9 @@ #include <openssl/ssl.h>
#include <openssl/err.h>
+#include <assert.h>
+#include <signal.h>
+
#define pab(M) {printf(M);abort();}
SSL* ssl_connect(SSL_CTX* ctx, int sockfd, const char* hostname){
@@ -103,6 +106,273 @@ int chunked_encoding_round(char* input, int length, struct chunked_encoding_stat return 0;
}
+struct url {
+ str* proto,
+ * domain,
+ * port,
+ * path;
+};
+
+struct url parse_url(char* url, int len){
+ struct url awa = {0};
+ str* buffer = str_init("");
+ int read = 0;
+
+ for(int i = 0; i != len; i++){
+ if(url[i] == ':'){
+ if(awa.proto == NULL && i + 2 < len && url[i + 1] == '/' && url[i + 2] == '/'){
+ awa.proto = buffer;
+ buffer = str_init("");
+ i += 2;
+ } else if (awa.port == NULL){
+ awa.domain = buffer;
+ buffer = str_init("");
+ read = 1;
+ }
+ } else if(read != 2 && url[i] == '/'){
+ if(read == 1){
+ awa.port = buffer;
+ } else {
+ awa.domain = buffer;
+ }
+ buffer = str_init("");
+ read = 2;
+ i--;
+ } else {
+ str_pushl(buffer, url + i, 1);
+ }
+ }
+
+ if(read == 0) awa.domain = buffer;
+ else if(read == 1) awa.port = buffer;
+ else awa.path = buffer;
+
+ return awa;
+}
+
+void free_url(struct url u){
+ if(u.proto != NULL) str_free(u.proto);
+ if(u.domain != NULL) str_free(u.domain);
+ if(u.path != NULL) str_free(u.path);
+ if(u.port != NULL) str_free(u.port);
+}
+
+#define BUFFER_LEN 4096
+struct wss_data {
+ SSL* ssl;
+ SSL_CTX* ctx;
+ int sock;
+ str* buffer;
+};
+
+struct ws_frame_info {
+ int fin;
+ int rsv1;
+ int rsv2;
+ int rsv3;
+ int opcode;
+ int mask;
+ int payload;
+};
+
+int i_ws_read(lua_State* L){
+ lua_pushstring(L, "_");
+ lua_gettable(L, 1);
+ struct wss_data* data = lua_touserdata(L, -1);
+ char buffer[BUFFER_LEN] = {0};
+ int total_len = 0;
+ int len;
+
+ for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){
+ if(len < 0){
+ lua_pushinteger(L, len);
+ return 1;
+ }
+ total_len += len;
+ if(total_len >= 2) break;
+ }
+
+ uint64_t payload = 0;
+ struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1,
+ .rsv2 = (buffer[0] >> 5) & 1, .rsv3 = (buffer[0] >> 4) & 1, .opcode = buffer[0] & 0b1111,
+ .mask = (buffer[1] >> 7) & 1, .payload = buffer[1] & 0b1111111};
+ //printf("fin: %i\npayload: %i\n", frame_info.fin, frame_info.payload);
+ memset(buffer, 0, total_len);
+ total_len = 0;
+
+ if(frame_info.payload <= 125) payload = frame_info.payload;
+ else if(frame_info.payload == 126) {
+ for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){
+ if(len < 0){
+ lua_pushinteger(L, len);
+ return 1;
+ }
+ total_len += len;
+ if(total_len >= 2) break;
+ }
+
+ payload = (buffer[0] & 0xff) << 8 | buffer[1] & 0xff;
+ } else {
+ for(; (len = SSL_read(data->ssl, buffer + total_len, 8 - total_len)) > 0;){
+ if(len < 0){
+ lua_pushinteger(L, len);
+ return 1;
+ }
+ total_len += len;
+ if(total_len >= 8) break;
+ }
+
+ payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 |
+ ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff);
+ }
+ //printf("final payload: %lu\n", payload);
+
+ str* message = str_init("");
+ memset(buffer, 0, BUFFER_LEN);
+ for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){
+ if(len < 0){
+ str_free(message);
+ lua_pushinteger(L, len);
+ return 1;
+ }
+ str_pushl(message, buffer, len);
+ memset(buffer, 0, len);
+ }
+
+ lua_newtable(L);
+ int idx = lua_gettop(L);
+ luaI_tsetsl(L, idx, "content", message->c, message->len);
+ luaI_tseti(L, idx, "opcode", frame_info.opcode);
+
+ str_free(message);
+ return 1;
+}
+
+int i_ws_write(lua_State* L){
+ lua_pushstring(L, "_");
+ lua_gettable(L, 1);
+ struct wss_data* data = lua_touserdata(L, -1);
+
+ uint64_t clen;
+ const char* content = luaL_tolstring(L, 2, &clen);
+ str* send_data = str_init("");
+
+ str_pushl(send_data, (const char[]){0b10000001}, 1);
+ if(clen <= 125) str_pushl(send_data, (const char[]){(0x1 << 7) | clen}, 1);
+ else if(clen <= 65535)
+ str_pushl(send_data, (const char[]){(0x1 << 7) | 126, (clen >> 8) & 0xff, clen & 0xff}, 3);
+ else
+ str_pushl(send_data, (const char[]){(0x1 << 7) | 127, (clen >> 56) & 0xff, (clen >> 48) & 0xff,
+ (clen >> 40) & 0xff, (clen >> 32) & 0xff, (clen >> 24) & 0xff, (clen >> 16) & 0xff,
+ (clen >> 8) & 0xff, clen & 0xff}, 9);
+ str_pushl(send_data, (const char[]){0, 0, 0, 0}, 4);
+ str_pushl(send_data, content, clen);
+
+ int s = SSL_write(data->ssl, send_data->c, send_data->len);
+ lua_pushinteger(L, 1);
+
+ str_free(send_data);
+ return 1;
+}
+
+int i_ws_close(lua_State* L){
+ printf("free\n");
+ lua_pushstring(L, "_");
+ lua_gettable(L, 1);
+ struct wss_data* data = lua_touserdata(L, -1);
+
+ str_free(data->buffer);
+
+ SSL_set_shutdown(data->ssl, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN);
+ SSL_shutdown(data->ssl);
+ SSL_free(data->ssl);
+ SSL_CTX_free(data->ctx);
+
+ close(data->sock);
+
+ free(data);
+ return 0;
+}
+
+int l_wss(lua_State* L){
+ uint64_t len = 0;
+ char* request_url = (char*)lua_tolstring(L, 1, &len);
+ struct url awa = parse_url(request_url, len);
+ if(awa.proto != NULL && strcmp(awa.proto->c, "ws") == 0){
+ //send to l_ws, todo
+ abort();
+ }
+
+ char* port = awa.port == NULL ? "443" : awa.port->c;
+ char* path = awa.path == NULL ? "/" : awa.path->c;
+ int sock = get_host(awa.domain->c, port);
+ int set = 1;
+ signal(SIGPIPE, SIG_IGN);
+
+ SSL_library_init();
+ SSL_load_error_strings();
+ SSL_CTX* ctx = SSL_CTX_new(SSLv23_client_method());
+ SSL* ssl = ssl_connect(ctx, sock, awa.domain->c);
+
+ char* request = calloc(512, sizeof * request);
+ sprintf(request, "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"\
+ "Sec-Websocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-Websocket-Version: 13\r\n\r\n", path, awa.domain->c);
+
+ SSL_write(ssl, request, strlen(request));
+ free_url(awa);
+ free(request);
+
+ char buffer[BUFFER_LEN];
+ int extra_len = len = 0;
+ str* a = str_init("");
+ char* header_eof = NULL;
+ for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){
+ str_pushl(a, buffer, len);
+ if((header_eof = memmem(a->c, a->len, "\r\n\r\n", 4)) != NULL){
+ extra_len = a->len - (header_eof - a->c);
+ break;
+ }
+ memset(buffer, 0, BUFFER_LEN);
+ }
+
+ if(header_eof == NULL){
+ printf("header error\n");
+ lua_pushinteger(L, -1);
+ return 1;
+ }
+
+ struct wss_data *data = malloc(sizeof * data);
+ data->ssl = ssl;
+ data->ctx = ctx;
+ data->sock = sock;
+ data->buffer = str_init("");//str_initl(header_eof, extra_len);
+ str_free(a);
+
+ lua_newtable(L);
+ int idx = lua_gettop(L);
+
+ luaI_tsetlud(L, idx, "_", data);
+ luaI_tsetcf(L, idx, "read", i_ws_read);
+ luaI_tsetcf(L, idx, "write", i_ws_write);
+ luaI_tsetcf(L, idx, "close", i_ws_close);
+
+ lua_newtable(L);
+ int meta_idx = lua_gettop(L);
+
+ luaI_tsetcf(L, meta_idx, "__gc", i_ws_close);
+
+ lua_pushvalue(L, meta_idx);
+ lua_setmetatable(L, idx);
+
+ lua_pushvalue(L, idx);
+
+ //verify stuff here
+ //parray_t* owo = NULL;
+ //parse_header(a->c, header_eof - a->c, &owo);
+
+ return 1;
+}
+
int l_srequest(lua_State* L){
int params = lua_gettop(L);
@@ -166,26 +436,29 @@ int l_srequest(lua_State* L){ //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n";
- char* request = calloc(cont_len + header->len + 256, sizeof * request);
+ char* request = calloc(cont_len + header->len + 512, sizeof * request);
sprintf(request, "%s %s HTTP/1.1\r\nHost: %s\r\nConnection: Close%s\r\n\r\n%s", action, path, host, header->c, cont);
+ //printf("%s\n", request);
str_free(header);
SSL_write(ssl, request, strlen(request));
free(request);
str* a = str_init("");
- char buffer[512];
+ char buffer[BUFFER_LEN];
int len = 0;
int extra_len = 0;
char* header_eof = NULL;
- for(; (len = SSL_read(ssl, buffer, 511)) > 0;){
+ for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){
+ int blen = a->len;
str_pushl(a, buffer, len);
- if((header_eof = memmem(a->c, a->len, "\r\n\r\n", 4)) != NULL){
+ int offset = blen >= 4 ? 4 : blen;
+ if((header_eof = memmem(a->c + blen - offset, len + offset, "\r\n\r\n", 4)) != NULL){
extra_len = a->len - (header_eof - a->c);
break;
}
- memset(buffer, 0, 512);
+ memset(buffer, 0, BUFFER_LEN);
}
if(header_eof != NULL){
@@ -193,7 +466,7 @@ int l_srequest(lua_State* L){ int idx = lua_gettop(L);
parray_t* owo = NULL;
- parse_header(a->c, header_eof - a->c, &owo);
+ int err = parse_header(a->c, header_eof - a->c, &owo);
for(int i = 0; i != owo->len; i++){
luaI_tsets(L, idx, (owo->P[i].key)->c, ((str*)owo->P[i].value)->c);
@@ -213,22 +486,22 @@ int l_srequest(lua_State* L){ .content = content
};
chunked_encoding_round(header_eof + 4, extra_len - 4, &state);
- memset(buffer, 0, 512);
+ memset(buffer, 0, BUFFER_LEN);
- for(; (len = SSL_read(ssl, buffer, 511)) > 0;){
+ for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){
chunked_encoding_round(buffer, len, &state);
- memset(buffer, 0, 512);
+ memset(buffer, 0, BUFFER_LEN);
}
str_free(state.buffer);
content = state.content;
}
} else {
- memset(buffer, 0, 512);
+ memset(buffer, 0, BUFFER_LEN);
- for(; (len = SSL_read(ssl, buffer, 511)) > 0;){
+ for(; (len = SSL_read(ssl, buffer, BUFFER_LEN)) > 0;){
str_pushl(content, buffer, len);
- memset(buffer, 0, 512);
+ memset(buffer, 0, BUFFER_LEN);
}
}
@@ -18,6 +18,7 @@ int l_listen(lua_State*); int l_request(lua_State*);
int l_srequest(lua_State*);
+int l_wss(lua_State*);
int64_t recv_full_buffer(int client_fd, char** _buffer, int* header_eof, int* state);
@@ -41,8 +42,9 @@ static char* http_codes[600] = {0}; static const luaL_Reg net_function_list [] = {
{"listen",l_listen},
- {"request",l_request},
+ //{"request",l_request},
{"srequest",l_srequest},
+ {"wss",l_wss},
{NULL,NULL}
};
diff --git a/src/net/util.c b/src/net/util.c index 10a58fa..b54c253 100644 --- a/src/net/util.c +++ b/src/net/util.c @@ -114,6 +114,7 @@ int parse_header(char* buffer, int header_eof, parray_t** _table){ str* current = str_init(""); int oi = 0; int item = 0; + for(; oi != header_eof; oi++){ if(buffer[oi] == ' ' || buffer[oi] == '\n'){ if(buffer[oi] == '\n') current->c[current->len - 1] = 0; @@ -133,6 +134,7 @@ int parse_header(char* buffer, int header_eof, parray_t** _table){ } if(item != 3){ + str_free(current); *_table = table; return -1; } @@ -148,6 +150,10 @@ int parse_header(char* buffer, int header_eof, parray_t** _table){ key = 0; } else { if(buffer[oi] == '\n') current->c[current->len - 1] = 0; + //duplicate keys would cause memory leaks, ignore them for now + //todo: figure out system to handle this + str* id = (str*)parray_get(table, sw->c); + if(id != NULL) str_free(id); parray_set(table, sw->c, (void*)str_init(current->c)); str_clear(current); str_free(sw); |
