本教程基于libuv,需要自行下载和搭建libuv开发环境,参考我之前的文章:C/C++服务器基础(网络、协议、数据库)
netbus框架
1: 每个服务器,都是基于netbus框架;
2: 基于netbus框架来开发上面的每个服务器;
3: 通用模块:
(1): 网络UDP/TCP/websocket;
(2): 数据库mysql/redis模块;
(3): timer模块, 时间戳;
(4): Log模块;
(5): protobuf协议模块;
(6): http模块;
(7): json, base64, MD5, sha1;
(8): service业务服务 –> 来开发自己的业务逻辑;
…..
1: 搭建目录结构
3rd: 第三方代码
apps: 放各个服务器的代码:
gateway, center_server, game_server, system_server, test
build: 放跨平台的编译工程与Makefile;
netbus: 基本的框架;
utils: 自己扩展的工具函数;
Session管理
1: 每一个TCP连接进来后,服务器都要保存住,来和他通讯;
2: 当服务器要发送数据到某个客户端的时候,我们要找到这个连接然后发送出去;
3: 所以我们要做好这些客户端的连接管理,称为session管理;
4: 我们把每一个连接,以及处理这个连接的上下文,称为session;
5: session管理的三大考虑要素:
(1) 服务器监听session是否有数据可读,随时需要内存来存放读到的数据,所以要为每个session准备好读缓冲区;
(2)异步写数据的时候,需要保存写的buffer —> 4K;
(3) 客户端会有几万的连接进来和离开,session就会面临不断的释放和分配,所以要做好session内存池; 6W 人同时在线需要 6W * 8K ≈ 480M(约500M)的内存;
Session设计:
1: 发送数据;
2 获取session信息;
3: 提供 close函数主动关闭socket;
// Abstract connection interface used by the rest of the framework.
class session {
public:
	// Polymorphic base: a virtual destructor makes deleting a concrete session
	// through a session* run the derived destructor.
	virtual ~session() {}
	// Actively close the connection.
	virtual void close() = 0;
	// Send `len` bytes starting at `body` to the peer.
	virtual void send_data(unsigned char* body, int len) = 0;
	// Return the peer address text; the peer port is written to *client_port.
	virtual const char* get_address(int* client_port) = 0;
};
基础框架完整代码
main.cc代码
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "../../netbus/netbus.h"
int main(int argc, char** argv) {
netbus::instance()->start_tcp_server(6080);
netbus::instance()->start_tcp_server(6081);
netbus::instance()->run();
return 0;
}
netbus.h、netbus.cc代码
#ifndef __NETBUS_H__
#define __NETBUS_H__

// Network facade: one shared object that owns the listening servers.
class netbus {
public:
	// The process-wide netbus object.
	static netbus* instance();

	// Listen for raw TCP clients on `port`.
	void start_tcp_server(int port);
	// Listen for WebSocket clients on `port` (declared; not defined in this version).
	void start_ws_server(int port);
	// Run the event loop.
	void run();
};

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "uv.h"
#include "session.h"
#include "session_uv.h"
#include "netbus.h"
extern "C" {
static void
uv_alloc_buf(uv_handle_t* handle,
	size_t suggested_size,
	uv_buf_t* buf) {
	// Give libuv the unused tail of the session's fixed receive buffer.
	uv_session* owner = (uv_session*)handle->data;
	char* tail = owner->recv_buf + owner->recved;
	int tail_len = RECV_LEN - owner->recved;
	*buf = uv_buf_init(tail, tail_len);
}
static void
on_close(uv_handle_t* handle) {
	// The TCP handle is fully closed: release the session that embeds it.
	uv_session::destroy((uv_session*)handle->data);
}
static void
on_shutdown(uv_shutdown_t* req, int status) {
	// Write side shut down (status ignored): close the whole handle.
	uv_handle_t* tcp = (uv_handle_t*)req->handle;
	uv_close(tcp, on_close);
}
// libuv read callback: echoes every received chunk back to the client (test code).
static void
after_read(uv_stream_t* stream,
	ssize_t nread,
	const uv_buf_t* buf) {
	uv_session* s = (uv_session*)stream->data;
	if (nread < 0) { // EOF, read error, or UV_ENOBUFS when the buffer is full
		s->close();
		return;
	}
	if (nread == 0) { // nothing read (EAGAIN)
		return;
	}
	// Do not write a terminator at buf->base[nread]: when the read fills the
	// buffer handed out by uv_alloc_buf that index is one past recv_buf.
	// Print with an explicit length instead; nread is ssize_t, not int.
	printf("recv %d\n", (int)nread);
	printf("%.*s\n", (int)nread, buf->base);
	// test: echo back
	s->send_data((unsigned char*)buf->base, (int)nread);
	s->recved = 0;
}
// Listen callback: create a session for the new client and start reading.
static void
uv_connection(uv_stream_t* server, int status) {
	if (status < 0) { // listen error: there is no connection to accept
		printf("new connection error: %s\n", uv_strerror(status));
		return;
	}
	uv_session* s = uv_session::create();
	uv_tcp_t* client = &s->tcp_handler;
	memset(client, 0, sizeof(uv_tcp_t));
	uv_tcp_init(uv_default_loop(), client);
	client->data = (void*)s;
	if (uv_accept(server, (uv_stream_t*)client) != 0) {
		// The handle is already registered with the loop, so it must be
		// uv_close'd; on_close then releases the session.
		uv_close((uv_handle_t*)client, on_close);
		return;
	}
	struct sockaddr_in addr;
	int len = sizeof(addr);
	uv_tcp_getpeername(client, (sockaddr*)&addr, &len);
	// Pass the real size of c_address instead of a hard-coded 64.
	uv_ip4_name(&addr, (char*)s->c_address, sizeof(s->c_address));
	s->c_port = ntohs(addr.sin_port);
	// The listen handle's data field holds the socket type as an integer;
	// go through intptr_t so the pointer->int cast is valid on 64-bit targets.
	s->socket_type = (int)(intptr_t)(server->data);
	printf("new client comming %s:%d\n", s->c_address, s->c_port);
	uv_read_start((uv_stream_t*)client, uv_alloc_buf, after_read);
}
}
// The single netbus object handed out by netbus::instance().
static netbus g_netbus;
netbus* netbus::instance() {
	netbus* const shared_bus = &g_netbus;
	return shared_bus;
}
void netbus::start_tcp_server(int port) {
uv_tcp_t* listen = (uv_tcp_t*)malloc(sizeof(uv_tcp_t));
memset(listen, 0, sizeof(uv_tcp_t));
uv_tcp_init(uv_default_loop(), listen);
struct sockaddr_in addr;
uv_ip4_addr("0.0.0.0", port, &addr);
int ret = uv_tcp_bind(listen, (const struct sockaddr*) &addr, 0);
if (ret != 0) {
printf("bind error\n");
free(listen);
return;
}
uv_listen((uv_stream_t*)listen, SOMAXCONN, uv_connection);
listen->data = (void*) TCP_SOCKET;
}
void netbus::run() {
uv_run(uv_default_loop(), UV_RUN_DEFAULT);
}
session.h、session_uv.h、session_uv.cc代码
#ifndef __SESSION_H__
#define __SESSION_H__

// Abstract connection interface used by the rest of the framework.
class session {
public:
	// Polymorphic base: a virtual destructor makes deleting a concrete session
	// through a session* run the derived destructor.
	virtual ~session() {}
	// Actively close the connection.
	virtual void close() = 0;
	// Send `len` bytes starting at `body` to the peer.
	virtual void send_data(unsigned char* body, int len) = 0;
	// Return the peer address text; the peer port is written to *client_port.
	virtual const char* get_address(int* client_port) = 0;
};

#endif
#ifndef __SESSION_UV_H__
#define __SESSION_UV_H__

// Size of the fixed per-session receive buffer.
#define RECV_LEN 4096

// Stored in a listen handle's data field to tell accepted sessions apart.
enum {
	TCP_SOCKET,
	WS_SOCKET,
};

// libuv-backed session. Inherit publicly so a uv_session* converts to a
// session* implicitly (`class X : Base` alone means private inheritance).
class uv_session : public session {
public:
	uv_tcp_t tcp_handler;   // client connection; its data field points back to this session
	char c_address[32];     // peer IP as text
	int c_port;             // peer port, host byte order
	uv_shutdown_t shutdown; // request used by close()
	bool is_shutdown;       // set once close() has started

public:
	char recv_buf[RECV_LEN]; // fixed receive buffer
	int recved;              // bytes currently buffered (in recv_buf or long_pkg)
	int socket_type;         // TCP_SOCKET or WS_SOCKET
	int is_ws_shake;         // non-zero once the WebSocket handshake is done
	char* long_pkg;          // malloc'ed buffer for a package larger than RECV_LEN, or NULL
	int long_pkg_size;       // capacity of long_pkg

private:
	void init();
	void exit();

public:
	static uv_session* create();
	static void destroy(uv_session* s);
	void* operator new(size_t size);  // class-specific allocation (needed for Linux)
	void operator delete(void* mem);  // class-specific deallocation (needed for Linux)

public:
	virtual void close();
	virtual void send_data(unsigned char* body, int len);
	virtual const char* get_address(int* client_port);
};

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "uv.h"
#include "session.h"
#include "session_uv.h"
extern "C" {
	// uv_write completion: report success, then free the request.
	static void
	after_write(uv_write_t* req, int status) {
		if (!status) {
			printf("write success\n");
		}
		free(req);
	}
	// Handle fully closed: release the owning session.
	static void
	on_close(uv_handle_t* handle) {
		uv_session* owner = static_cast<uv_session*>(handle->data);
		uv_session::destroy(owner);
	}
	// Write side shut down (status ignored): close the TCP handle.
	static void
	on_shutdown(uv_shutdown_t* req, int status) {
		uv_close(reinterpret_cast<uv_handle_t*>(req->handle), on_close);
	}
}
uv_session*
uv_session::new() {
return cache_alloc(session_allocer, sizeof(uv_session));
}
uv_session*
uv_session::delete(void* mem) {
cache_free(session_allocer,mem)
}
uv_session*
uv_session::create() {
	// Calling the constructor explicitly on raw memory is rejected on Linux,
	// so allocation goes through the class-specific operator new/delete.
	uv_session* fresh = new uv_session();
	fresh->init();
	return fresh;
}
void
uv_session::destroy(uv_session* s) {
	// Undo init(), then release the object through the class operator delete.
	s->exit();
	delete s;
}
// Reset per-connection state for a freshly created session.
void
uv_session::init() {
	memset(this->c_address, 0, sizeof(this->c_address));
	this->c_port = 0;
	this->recved = 0;
	// The member is is_shutdown; the misspelled `is_shuntdown` does not compile.
	this->is_shutdown = false;
}
// Counterpart of init(); nothing to release in this version.
void
uv_session::exit() {
}
void
uv_session::close() {
if(this->is_shutdown){
return;
}
this->is_shutdown = true;
uv_shutdown_t* reg = &this->shutdown;
memset(reg, 0, sizeof(uv_shutdown_t));
uv_shutdown(reg, (uv_stream_t*)&this->tcp_handler, on_shutdown);
}
// Queue `len` bytes of `body` for sending to the peer.
void
uv_session::send_data(unsigned char* body, int len) {
	if (len <= 0) {
		return;
	}
	// uv_write does not copy the bytes, and the caller's buffer (e.g. the
	// receive buffer) may be reused before the write completes. So the request
	// and a private copy of the payload share one malloc block, which
	// after_write releases with its single free(req).
	uv_write_t* w_req = (uv_write_t*)malloc(sizeof(uv_write_t) + len);
	if (w_req == NULL) {
		return;
	}
	char* payload = (char*)(w_req + 1);
	memcpy(payload, body, len);
	// A uv_buf_t value, not a pointer: uv_buf_init returns the struct.
	uv_buf_t w_buf = uv_buf_init(payload, len);
	if (uv_write(w_req, (uv_stream_t*)&this->tcp_handler, &w_buf, 1, after_write) != 0) {
		free(w_req); // the callback is not invoked when uv_write fails immediately
	}
}
const char*
uv_session::get_address(int* port) {
*port = this->c_port;
return this->c_address;
}
内存池管理
1: 高效的内存管理, 将大量的重复的内存分配做好缓冲池,避免”内存碎片”。
2: 编写C和C++同时支持的内存池;
3: C++对象在构建/销毁的时候, 要调用构造函数/析构函数;
new: malloc + 构造函数
delete: 析构函数 + free;
cache_alloc.h和cache_alloc.c
#ifndef __CACHE_ALLOC_H__
#define __CACHE_ALLOC_H__

#ifdef __cplusplus
extern "C" {
#endif

/* Fixed-size block pool, callable from both C and C++. */
struct cache_allocer;

/* Create a pool holding `capacity` blocks of `elem_size` bytes. */
struct cache_allocer* create_cache_allocer(int capacity, int elem_size);
/* Release a pool and its backing memory. */
void destroy_cache_allocer(struct cache_allocer* allocer);
/* Get a block of `elem_size` bytes, from the pool when possible. */
void* cache_alloc(struct cache_allocer* allocer, int elem_size);
/* Give back a block obtained from cache_alloc. */
void cache_free(struct cache_allocer* allocer, void* mem);

#ifdef __cplusplus
}
#endif

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "cache_alloc.h"
/* Free-list link stored in the first bytes of every unused pooled block. */
struct node {
	struct node* next;
};

struct cache_allocer {
	unsigned char* cache_mem; /* capacity * elem_size bytes of pooled blocks (NULL if capacity is 0) */
	int capacity;             /* number of pooled blocks */
	struct node* free_list;   /* pooled blocks currently unused */
	int elem_size;            /* size of one pooled block in bytes */
};

/* Create a pool of `capacity` blocks of at least `elem_size` bytes.
 * The block size is rounded up to a multiple of sizeof(struct node) so every
 * block can hold the free-list link at a pointer-sized offset.
 * Returns NULL on negative/oversized arguments or allocation failure. */
struct cache_allocer*
create_cache_allocer(int capacity, int elem_size) {
	const int link_size = (int)sizeof(struct node);
	if (capacity < 0 || elem_size < 0 || elem_size > 0x7fffffff - link_size) {
		return NULL;
	}
	if (elem_size < link_size) {
		elem_size = link_size;
	}
	elem_size = (elem_size + link_size - 1) / link_size * link_size;
	if (capacity > 0 && elem_size > 0x7fffffff / capacity) {
		return NULL; /* capacity * elem_size would overflow int */
	}

	struct cache_allocer* allocer = (struct cache_allocer*)malloc(sizeof(struct cache_allocer));
	if (allocer == NULL) {
		return NULL;
	}
	memset(allocer, 0, sizeof(struct cache_allocer));
	allocer->capacity = capacity;
	allocer->elem_size = elem_size;
	allocer->free_list = NULL;
	allocer->cache_mem = NULL;
	if (capacity > 0) {
		size_t total = (size_t)capacity * (size_t)elem_size;
		allocer->cache_mem = (unsigned char*)malloc(total);
		if (allocer->cache_mem == NULL) {
			free(allocer);
			return NULL;
		}
		memset(allocer->cache_mem, 0, total);
	}
	for (int i = 0; i < capacity; i++) {
		struct node* walk = (struct node*)(allocer->cache_mem + (size_t)i * (size_t)elem_size);
		walk->next = allocer->free_list;
		allocer->free_list = walk;
	}
	return allocer;
}

/* Release the pool memory and the pool itself. Accepts NULL. */
void
destroy_cache_allocer(struct cache_allocer* allocer) {
	if (allocer == NULL) {
		return;
	}
	free(allocer->cache_mem); /* free(NULL) is a no-op */
	free(allocer);
}

/* Get a block of `elem_size` bytes. A pooled block is used while one is free
 * and large enough. Otherwise the block comes from malloc; this includes
 * oversized requests, which previously returned NULL, and a NULL allocer.
 * cache_free tells the two kinds apart by address. */
void*
cache_alloc(struct cache_allocer* allocer, int elem_size) {
	if (elem_size < 0) {
		return NULL;
	}
	if (allocer != NULL && elem_size <= allocer->elem_size && allocer->free_list != NULL) {
		struct node* now = allocer->free_list;
		allocer->free_list = now->next;
		return now;
	}
	return malloc((size_t)elem_size);
}

/* Return a block from cache_alloc: pooled blocks go back on the free list,
 * anything else is passed to free(). */
void cache_free(struct cache_allocer* allocer, void* mem) {
	if (mem == NULL) {
		return;
	}
	if (allocer != NULL && allocer->cache_mem != NULL) {
		unsigned char* p = (unsigned char*)mem;
		unsigned char* begin = allocer->cache_mem;
		unsigned char* end = begin + (size_t)allocer->capacity * (size_t)allocer->elem_size;
		if (p >= begin && p < end) {
			struct node* freed = (struct node*)mem;
			freed->next = allocer->free_list;
			allocer->free_list = freed;
			return;
		}
	}
	free(mem);
}
小内存分配器small_alloc.c
编写一个全局的小内存块分配器,每个元素的内存字节 <= 128字节时可以使用它来分配;用small_alloc代替malloc和strdup,用small_free代替free。
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "cache_alloc.h"
#include "small_alloc.h"
#define SMALL_ELEM_NUM (10 * 1024)
#define SMALL_ELEM_SIZE 128
static struct cache_allocer* small_allocer = NULL;

/* The shared small-block pool, created on first use. */
static struct cache_allocer*
small_pool(void) {
	if (!small_allocer) {
		small_allocer = create_cache_allocer(SMALL_ELEM_NUM, SMALL_ELEM_SIZE);
	}
	return small_allocer;
}

/* malloc replacement for blocks of at most SMALL_ELEM_SIZE bytes. */
void*
small_alloc(int size) {
	return cache_alloc(small_pool(), size);
}

/* free replacement for blocks obtained from small_alloc. */
void
small_free(void* mem) {
	cache_free(small_pool(), mem);
}
添加WebSocket协议支持
关于websocket的代码及说明参考之前的文章C/C++服务器基础(网络、协议、数据库)
ws_protocol.h、ws_protocol.cc
#ifndef __WS_PROTOCOL_H__
#define __WS_PROTOCOL_H__

class session;

// WebSocket helpers: handshake, frame header parsing, unmasking and framing.
class ws_protocol {
public:
	// Try to complete the HTTP upgrade handshake from the bytes in `body`.
	static bool ws_shake_hand(session* s, char* body, int len);
	// Parse a frame header; outputs the total frame size and the header size.
	static bool read_ws_header(unsigned char* pkg_data, int pkg_len, int* pkg_size, int* out_header_size);
	// Unmask `raw_len` payload bytes in place with the 4-byte `mask`.
	static void parser_ws_recv_data(unsigned char* raw_data, unsigned char* mask, int raw_len);
	// Wrap a payload in a frame; the size goes to *ws_data_len.
	static unsigned char* package_ws_send_data(const unsigned char* raw_data, int len, int* ws_data_len);
	// Release a frame returned by package_ws_send_data.
	static void free_ws_send_pkg(unsigned char* ws_pkg);
};

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "../3rd/http_parser/http_parser.h"
#include "../3rd/crypto/base64_encoder.h"
#include "../3rd/crypto/sha1.h"
#include "session.h"
#include "ws_protocol.h"
//
// RFC 6455 GUID appended to Sec-WebSocket-Key before hashing.
// const: binding a string literal to a plain char* is ill-formed since C++11.
static const char* wb_migic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// Handshake reply; %s receives base64(sha1(key + wb_migic)).
static const char* wb_accept = "HTTP/1.1 101 Switching Protocols\r\n"
	"Upgrade:websocket\r\n"
	"Connection: Upgrade\r\n"
	"Sec-WebSocket-Accept: %s\r\n"
	"WebSocket-Protocol:chat\r\n\r\n";
// http_parser callback state for the handshake being parsed. These are
// file-static, so only one handshake is parsed at a time.
static char filed_sec_key[512];
static char value_sec_key[512];
static int is_sec_key = 0;
static int has_sec_key = 0;
static int is_shaker_ended = 0;
extern "C" {
	// http_parser on_message_complete callback: the full request was parsed.
	int on_message_end(http_parser* p) {
		(void)p;
		is_shaker_ended = 1;
		return 0;
	}
}
// http_parser on_header_field callback: flags whether the header that follows
// is Sec-WebSocket-Key.
static int
on_ws_header_field(http_parser* p, const char *at, size_t length) {
	// The old strncmp(at, key, length) also matched any prefix of the key name
	// (e.g. a header called "Sec"). HTTP header names are case-insensitive
	// (RFC 7230), so compare the full name ignoring ASCII case.
	static const char key[] = "Sec-WebSocket-Key";
	const size_t key_len = sizeof(key) - 1;
	is_sec_key = 0;
	if (length != key_len) {
		return 0;
	}
	for (size_t i = 0; i < key_len; i++) {
		char a = at[i];
		char b = key[i];
		if (a >= 'A' && a <= 'Z') {
			a = (char)(a - 'A' + 'a');
		}
		if (b >= 'A' && b <= 'Z') {
			b = (char)(b - 'A' + 'a');
		}
		if (a != b) {
			return 0;
		}
	}
	is_sec_key = 1;
	return 0;
}
// http_parser on_header_value callback: stores the Sec-WebSocket-Key value.
static int
on_ws_header_value(http_parser* p, const char *at, size_t length) {
	if (!is_sec_key) {
		return 0;
	}
	// The value comes from the network: clamp it to the buffer, because an
	// oversized header would otherwise overflow value_sec_key.
	if (length >= sizeof(value_sec_key)) {
		length = sizeof(value_sec_key) - 1;
	}
	strncpy(value_sec_key, at, length);
	value_sec_key[length] = 0;
	has_sec_key = 1;
	return 0;
}
// Parse the buffered HTTP upgrade request. When it is complete and carries a
// Sec-WebSocket-Key, send the 101 Switching Protocols reply and return true.
// Otherwise return false so the caller can wait for more bytes.
bool ws_protocol::ws_shake_hand(session* s, char* body, int len) {
	http_parser_settings settings;
	http_parser_settings_init(&settings);
	settings.on_header_field = on_ws_header_field;
	settings.on_header_value = on_ws_header_value;
	settings.on_message_complete = on_message_end;
	http_parser p;
	http_parser_init(&p, HTTP_REQUEST);
	is_sec_key = 0;
	has_sec_key = 0;
	is_shaker_ended = 0;
	http_parser_execute(&p, &settings, body, len);
	if (!(has_sec_key && is_shaker_ended)) {
		return false;
	}
	// Found the Sec-WebSocket-Key of the WebSocket upgrade request.
	printf("Sec-WebSocket-Key: %s\n", value_sec_key);
	// Sec-WebSocket-Accept = base64(sha1(key + magic)).
	// key_migic must hold value_sec_key (up to 511 chars) plus the 36-char magic.
	static char key_migic[512 + 64];
	static char sha1_key_migic[SHA1_DIGEST_SIZE];
	static char send_client[512];
	int sha1_size;
	// snprintf: the key is client-controlled, never overrun the buffers.
	snprintf(key_migic, sizeof(key_migic), "%s%s", value_sec_key, wb_migic);
	crypt_sha1((unsigned char*)key_migic, strlen(key_migic), (unsigned char*)&sha1_key_migic, &sha1_size);
	int base64_len;
	char* base_buf = base64_encode((uint8_t*)sha1_key_migic, sha1_size, &base64_len);
	if (base_buf == NULL) {
		return false;
	}
	snprintf(send_client, sizeof(send_client), wb_accept, base_buf);
	base64_encode_free(base_buf);
	s->send_data((unsigned char*)send_client, strlen(send_client));
	return true;
}
// Parse the header of a client->server WebSocket frame.
// On success, *pkg_size receives the whole frame size (header + payload) and
// *out_header_size the header size, including the 4-byte masking key.
// Returns false when:
//  - the header bytes are not all available yet;
//  - the frame is not a final text/binary frame (0x81 / 0x82);
//  - the frame is unmasked;
//  - the payload is too large for an int.
bool
ws_protocol::read_ws_header(unsigned char* recv_data, int recv_len, int* pkg_size, int* out_header_size) {
	if (recv_len < 2) { // check before reading recv_data[0] / recv_data[1]
		return false;
	}
	if (recv_data[0] != 0x81 && recv_data[0] != 0x82) {
		return false;
	}
	if ((recv_data[1] & 0x80) == 0) {
		// RFC 6455 5.1: client frames must be masked. Otherwise the "mask"
		// bytes assumed below would really be payload bytes.
		return false;
	}
	unsigned int data_len = recv_data[1] & 0x7f;
	int head_size = 2;
	if (data_len == 126) { // 16-bit big-endian length in bytes 2..3
		head_size += 2;
		if (recv_len < head_size) {
			return false;
		}
		data_len = recv_data[3] | (recv_data[2] << 8);
	}
	else if (data_len == 127) { // 64-bit big-endian length in bytes 2..9
		head_size += 8;
		if (recv_len < head_size) {
			return false;
		}
		// Bytes 2..5 are the HIGH word and bytes 6..9 the low word; the old
		// code used the high word as the length.
		unsigned int high = ((unsigned int)recv_data[2] << 24) | ((unsigned int)recv_data[3] << 16) |
			((unsigned int)recv_data[4] << 8) | (unsigned int)recv_data[5];
		unsigned int low = ((unsigned int)recv_data[6] << 24) | ((unsigned int)recv_data[7] << 16) |
			((unsigned int)recv_data[8] << 8) | (unsigned int)recv_data[9];
		if (high != 0) {
			return false; // larger than 4 GiB: not supported
		}
		data_len = low;
	}
	head_size += 4; // 4-byte masking key
	if (data_len > (unsigned int)(0x7fffffff - head_size)) {
		return false; // frame size would overflow int
	}
	*pkg_size = (int)data_len + head_size;
	*out_header_size = head_size;
	return true;
}
void
ws_protocol::parser_ws_recv_data(unsigned char* raw_data, unsigned char* mask, int raw_len) {
	// Unmask in place: byte k is XOR-ed with mask byte k mod 4.
	int k = 0;
	while (k < raw_len) {
		raw_data[k] ^= mask[k & 3];
		++k;
	}
}
// Wrap `len` payload bytes in one unmasked, final text frame (0x81).
// Returns a malloc'ed frame (release with free_ws_send_pkg) and stores its size
// in *ws_data_len. Returns NULL, with *ws_data_len = 0, for a negative len or
// when allocation fails. Payloads of 64 KiB and more use the 8-byte length
// form; they used to be rejected.
unsigned char*
ws_protocol::package_ws_send_data(const unsigned char* raw_data, int len, int* ws_data_len) {
	*ws_data_len = 0;
	if (len < 0) {
		return NULL;
	}
	int head_size = 2;
	if (len > 125 && len <= 0xFFFF) {
		head_size += 2;
	}
	else if (len > 0xFFFF) {
		head_size += 8;
	}
	if (len > 0x7fffffff - head_size) {
		return NULL;
	}
	unsigned char* data_buf = (unsigned char*)malloc(head_size + len);
	if (data_buf == NULL) {
		return NULL;
	}
	data_buf[0] = 0x81;
	if (len <= 125) {
		data_buf[1] = (unsigned char)len;
	}
	else if (len <= 0xFFFF) {
		data_buf[1] = 126;
		data_buf[2] = (unsigned char)((len >> 8) & 0xff);
		data_buf[3] = (unsigned char)(len & 0xff);
	}
	else {
		data_buf[1] = 127;
		// 64-bit big-endian length; an int never needs the upper four bytes.
		data_buf[2] = 0;
		data_buf[3] = 0;
		data_buf[4] = 0;
		data_buf[5] = 0;
		data_buf[6] = (unsigned char)((len >> 24) & 0xff);
		data_buf[7] = (unsigned char)((len >> 16) & 0xff);
		data_buf[8] = (unsigned char)((len >> 8) & 0xff);
		data_buf[9] = (unsigned char)(len & 0xff);
	}
	if (len > 0) {
		memcpy(data_buf + head_size, raw_data, len);
	}
	*ws_data_len = head_size + len;
	return data_buf;
}
void
ws_protocol::free_ws_send_pkg(unsigned char* ws_pkg) {
	// Frames are allocated with malloc in package_ws_send_data.
	void* frame = ws_pkg;
	free(frame);
}
更改基础框架
netbus.h、netbus.cc代码
#ifndef __NETBUS_H__
#define __NETBUS_H__

// Network facade: one shared object that owns the listening servers.
class netbus {
public:
	// The process-wide netbus object.
	static netbus* instance();

	// Create the memory pools; call once before starting any server.
	void init();
	// Listen for raw TCP clients on `port`.
	void start_tcp_server(int port);
	// Listen for WebSocket clients on `port`.
	void start_ws_server(int port);
	// Run the event loop.
	void run();
};

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "uv.h"
#include "session.h"
#include "session_uv.h"
#include "netbus.h"
#include "ws_protocol.h"
extern "C" {
static void
on_recv_client_cmd(uv_session* s, unsigned char* body, int len) {
	printf("client command !!!!\n");
	// test: echo the command back to the sender
	uv_session* const sender = s;
	sender->send_data(body, len);
}
// Process every complete WebSocket frame currently buffered for `s`.
// Frames are read from long_pkg while a large frame is being assembled there,
// otherwise from recv_buf.
static void
on_recv_ws_data(uv_session* s) {
	unsigned char* pkg_data = (unsigned char*)((s->long_pkg != NULL) ? s->long_pkg : s->recv_buf);
	while (s->recved > 0) {
		int pkg_size = 0;
		int head_size = 0;
		if (pkg_data[0] == 0x88) { // close frame
			s->close();
			break;
		}
		if (pkg_data[0] != 0x81 && pkg_data[0] != 0x82) {
			// Only final text (0x81) / binary (0x82) frames are handled.
			// read_ws_header rejects anything else, so such a frame would sit
			// at the head of the buffer and never be consumed: close instead.
			s->close();
			break;
		}
		// pkg_size - head_size = body size
		if (!ws_protocol::read_ws_header(pkg_data, s->recved, &pkg_size, &head_size)) {
			break; // header not complete yet
		}
		if (s->recved < pkg_size) { // frame body not complete yet
			break;
		}
		unsigned char* raw_data = pkg_data + head_size;
		unsigned char* mask = raw_data - 4; // masking key = last 4 header bytes
		ws_protocol::parser_ws_recv_data(raw_data, mask, pkg_size - head_size);
		on_recv_client_cmd(s, raw_data, pkg_size - head_size);
		if (s->recved > pkg_size) { // bytes of the next frame follow: shift them down
			memmove(pkg_data, pkg_data + pkg_size, s->recved - pkg_size);
		}
		s->recved -= pkg_size;
		if (s->recved == 0 && s->long_pkg != NULL) {
			free(s->long_pkg);
			s->long_pkg = NULL;
			s->long_pkg_size = 0;
		}
	}
}
// libuv alloc callback. Reads normally go into the tail of recv_buf. When
// recv_buf is full without a complete package, the pending package is moved to
// a heap buffer (long_pkg) sized from its header, and reading continues there.
// An empty buffer makes libuv report UV_ENOBUFS to after_read, which closes the
// session; that is used for unrecoverable states.
static void
uv_alloc_buf(uv_handle_t* handle,
	size_t suggested_size,
	uv_buf_t* buf) {
	uv_session* s = (uv_session*)handle->data;
	if (s->recved < RECV_LEN) {
		*buf = uv_buf_init(s->recv_buf + s->recved, RECV_LEN - s->recved);
		return;
	}
	if (s->long_pkg == NULL) { // allocate the large-package buffer
		if (s->socket_type == WS_SOCKET && !s->is_ws_shake) {
			// A handshake request larger than recv_buf is not supported.
			*buf = uv_buf_init(NULL, 0);
			return;
		}
		int pkg_size = 0;
		int head_size = 0;
		bool ok;
		if (s->socket_type == WS_SOCKET) { // ws package > RECV_LEN
			ok = ws_protocol::read_ws_header((unsigned char*)s->recv_buf, s->recved, &pkg_size, &head_size);
		}
		else { // tcp package > RECV_LEN
			ok = tp_protocol::read_header((unsigned char*)s->recv_buf, s->recved, &pkg_size, &head_size);
		}
		// pkg_size is only meaningful when the header parsed, and it must exceed
		// what is already buffered or there is no room left to read into.
		if (!ok || pkg_size <= s->recved) {
			*buf = uv_buf_init(NULL, 0);
			return;
		}
		s->long_pkg = (char*)malloc(pkg_size);
		if (s->long_pkg == NULL) {
			*buf = uv_buf_init(NULL, 0);
			return;
		}
		s->long_pkg_size = pkg_size;
		memcpy(s->long_pkg, s->recv_buf, s->recved);
	}
	*buf = uv_buf_init(s->long_pkg + s->recved, s->long_pkg_size - s->recved);
}
static void
on_close(uv_handle_t* handle) {
	// Handle closed for good: destroy the session stored in its data field.
	uv_session* finished = (uv_session*)handle->data;
	uv_session::destroy(finished);
}
static void
on_shutdown(uv_shutdown_t* req, int status) {
	// Shutdown finished (status ignored): close the underlying handle.
	uv_close((uv_handle_t*)req->handle, on_close);
}
// libuv read callback: account for the new bytes, then run the WebSocket
// handshake or frame parser (the TCP branch is empty in this version).
static void
after_read(uv_stream_t* stream,
	ssize_t nread,
	const uv_buf_t* buf) {
	uv_session* s = (uv_session*)stream->data;
	if (nread < 0) { // EOF, read error, or UV_ENOBUFS from uv_alloc_buf
		s->close();
		return;
	}
	s->recved += (int)nread;
	if (s->socket_type == WS_SOCKET) { // websocket
		if (s->is_ws_shake == 0) { // handshake
			if (ws_protocol::ws_shake_hand((session*)s, s->recv_buf, s->recved)) {
				s->is_ws_shake = 1;
				s->recved = 0;
			}
			else if (s->recved >= RECV_LEN) {
				// The handshake filled recv_buf without completing. Reading on
				// would go past recv_buf while ws_shake_hand keeps parsing
				// recv_buf with s->recved as its length, so drop the client.
				s->close();
			}
		}
		else { // websocket recv/send data
			on_recv_ws_data(s);
		}
	}
	else { // TCP socket
	}
}
// Listen callback: create a session for the new client and start reading.
static void
uv_connection(uv_stream_t* server, int status) {
	if (status < 0) { // listen error: there is no connection to accept
		printf("new connection error: %s\n", uv_strerror(status));
		return;
	}
	uv_session* s = uv_session::create();
	uv_tcp_t* client = &s->tcp_handler;
	memset(client, 0, sizeof(uv_tcp_t));
	uv_tcp_init(uv_default_loop(), client);
	client->data = (void*)s;
	if (uv_accept(server, (uv_stream_t*)client) != 0) {
		// The handle is already registered with the loop, so it must be
		// uv_close'd; on_close then releases the session.
		uv_close((uv_handle_t*)client, on_close);
		return;
	}
	struct sockaddr_in addr;
	int len = sizeof(addr);
	uv_tcp_getpeername(client, (sockaddr*)&addr, &len);
	// Pass the real size of c_address instead of a hard-coded 64.
	uv_ip4_name(&addr, (char*)s->c_address, sizeof(s->c_address));
	s->c_port = ntohs(addr.sin_port);
	// The listen handle's data field holds the socket type as an integer;
	// go through intptr_t so the pointer->int cast is valid on 64-bit targets.
	s->socket_type = (int)(intptr_t)(server->data);
	printf("new client comming %s:%d\n", s->c_address, s->c_port);
	uv_read_start((uv_stream_t*)client, uv_alloc_buf, after_read);
}
}
// Storage for the shared netbus object.
static netbus g_netbus;
netbus* netbus::instance() {
	return &(g_netbus);
}
void netbus::start_tcp_server(int port) {
uv_tcp_t* listen = (uv_tcp_t*)malloc(sizeof(uv_tcp_t));
memset(listen, 0, sizeof(uv_tcp_t));
uv_tcp_init(uv_default_loop(), listen);
struct sockaddr_in addr;
uv_ip4_addr("0.0.0.0", port, &addr);
int ret = uv_tcp_bind(listen, (const struct sockaddr*) &addr, 0);
if (ret != 0) {
printf("bind error\n");
free(listen);
return;
}
uv_listen((uv_stream_t*)listen, SOMAXCONN, uv_connection);
listen->data = (void*) TCP_SOCKET;
}
void netbus::start_ws_server(int port) {
uv_tcp_t* listen = (uv_tcp_t*)malloc(sizeof(uv_tcp_t));
memset(listen, 0, sizeof(uv_tcp_t));
uv_tcp_init(uv_default_loop(), listen);
struct sockaddr_in addr;
uv_ip4_addr("0.0.0.0", port, &addr);
int ret = uv_tcp_bind(listen, (const struct sockaddr*) &addr, 0);
if (ret != 0) {
printf("bind error\n");
free(listen);
return;
}
uv_listen((uv_stream_t*)listen, SOMAXCONN, uv_connection);
listen->data = (void*)WS_SOCKET;
}
// Run libuv's default loop (UV_RUN_DEFAULT).
void netbus::run() {
	uv_run(uv_default_loop(), UV_RUN_DEFAULT);
}

// Defined in session_uv.cc. None of the headers included here declares it,
// so declare it before use.
void init_session_allocer();

// Create the session and write-request pools. uv_session::create allocates
// from them, so this must run before the first client is accepted.
void netbus::init() {
	init_session_allocer();
}
session_uv.cc
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "uv.h"
#include "session.h"
#include "session_uv.h"
#include "../utils/cache_alloc.h"
#include "ws_protocol.h"
#define SESSION_CACHE_CAPACITY 6000
#define WQ_CACHE_CAPCITY 4096
// Pool for uv_session objects (non-static so other code can reach it).
struct cache_allocer* session_allocer = NULL;
// Pool for the uv_write_t requests created by send_data.
static cache_allocer* wr_allocer = NULL;

// Create any pool that does not exist yet.
void init_session_allocer() {
	if (!session_allocer) {
		session_allocer = create_cache_allocer(SESSION_CACHE_CAPACITY, sizeof(uv_session));
	}
	if (!wr_allocer) {
		wr_allocer = create_cache_allocer(WQ_CACHE_CAPCITY, sizeof(uv_write_t));
	}
}
extern "C" {
	// uv_write completion: report success, then return the request to its pool.
	static void
	after_write(uv_write_t* req, int status) {
		if (!status) {
			printf("write success\n");
		}
		cache_free(wr_allocer, req);
	}
	// Handle fully closed: release the owning session.
	static void
	on_close(uv_handle_t* handle) {
		uv_session::destroy(static_cast<uv_session*>(handle->data));
	}
	// Write side shut down (status ignored): close the TCP handle.
	static void
	on_shutdown(uv_shutdown_t* req, int status) {
		uv_handle_t* tcp = reinterpret_cast<uv_handle_t*>(req->handle);
		uv_close(tcp, on_close);
	}
}
#include <new> // placement new

// Allocate a session from the session pool and construct it in place.
// Returns NULL if no memory could be obtained.
uv_session*
uv_session::create() {
	void* mem = cache_alloc(session_allocer, sizeof(uv_session));
	if (mem == NULL) {
		return NULL;
	}
	// Construct in the pooled memory. Calling the constructor explicitly
	// (uv_s->uv_session::uv_session()) is not valid C++ and GCC/Clang reject
	// it. The leading :: selects the global placement operator new, bypassing
	// any class-specific operator new.
	uv_session* uv_s = ::new (mem) uv_session;
	uv_s->init();
	return uv_s;
}
// Destroy a session from create(): run the destructor, then give the memory
// back to the session pool.
void
uv_session::destroy(uv_session* s) {
	if (s == NULL) {
		return;
	}
	s->exit();
	s->uv_session::~uv_session();
	cache_free(session_allocer, s);
}
// Reset per-connection state for a freshly created session.
void
uv_session::init() {
	memset(this->c_address, 0, sizeof(this->c_address));
	this->c_port = 0;
	this->recved = 0;
	this->is_shutdown = false;
	this->is_ws_shake = 0;
	this->long_pkg = NULL;
	this->long_pkg_size = 0;
}
// Counterpart of init(). A connection can be closed while a large package is
// still being assembled, so release long_pkg here; it would otherwise leak.
void
uv_session::exit() {
	if (this->long_pkg != NULL) {
		free(this->long_pkg);
		this->long_pkg = NULL;
		this->long_pkg_size = 0;
	}
}
void
uv_session::close() {
if (this->is_shutdown) {
return;
}
this->is_shutdown = true;
uv_shutdown_t* reg = &this->shutdown;
memset(reg, 0, sizeof(uv_shutdown_t));
uv_shutdown(reg, (uv_stream_t*)&this->tcp_handler, on_shutdown);
}
void
uv_session::send_data(unsigned char* body, int len) {
// uv_write_t* w_req = (uv_write_t*)malloc(sizeof(uv_write_t));
uv_write_t* w_req = (uv_write_t*)cache_alloc(wr_allocer, sizeof(uv_write_t));
uv_buf_t w_buf;
if (this->socket_type == WS_SOCKET && this->is_ws_shake) {
int ws_pkg_len;
unsigned char* ws_pkg = ws_protocol::package_ws_send_data(body, len, &ws_pkg_len);
w_buf = uv_buf_init((char*)ws_pkg, ws_pkg_len);
uv_write(w_req, (uv_stream_t*)&this->tcp_handler, &w_buf, 1, after_write);
ws_protocol::free_ws_send_pkg(ws_pkg);
}
else {
w_buf = uv_buf_init((char*)body, len);
uv_write(w_req, (uv_stream_t*)&this->tcp_handler, &w_buf, 1, after_write);
}
}
const char*
uv_session::get_address(int* port) {
	// Hand back the peer's IP text and write its port to *port.
	const char* ip_text = this->c_address;
	*port = this->c_port;
	return ip_text;
}
使用
int main(int argc, char** argv) {
netbus::instance()->init();
netbus::instance()->start_tcp_server(6080);
netbus::instance()->start_ws_server(8001);
netbus::instance()->run();
return 0;
}
TCP封包与拆包
问题:
1: 当我们客户端发送 数据包1,数据包2的时候,我们的服务器,可能会同时收到,也就是说,数据包1和数据包2黏在一起了。
2: 当我们收数据的时候,不知道要收多少这个数据包才算结束。
3: 我们要对每个独立的数据包,进行封包,封包后发送。
做法:
1: 前两个字节表示当前包的大小,后面表示包的数据;
2: 收到数据后,根据前两个字节的大小,完整的收完一个数据包后,将数据提取出来。
3: 发送一个数据的时候打入封包信息。
4: websocket其实就是一种封包和拆包协议。
拆包流程:
1: 接收数据,如果不够2个字节,直接返回继续接收数据;
2: 如果超过两个字节,读取前两个字节的大小,根据大小,判断是否收到超过1个包的数据。
3: 如果没有,继续接收。
4: 如果有,处理一个包,处理完后继续处理下一个包;如果剩下的数据不足一个包或为0,
继续接收数据。
tcp_protocol.h和tcp_protocol.cc代码
#ifndef __TP_PROTOCOL_H__
#define __TP_PROTOCOL_H__

// Length-prefixed TCP framing: each package starts with a 2-byte size header.
class tp_protocol {
public:
	// Read the size header; outputs the package size and the header size.
	static bool read_header(unsigned char* data, int data_len, int* pkg_size, int* out_header_size);
	// Build a package around `raw_data`; its size goes to *pkg_len.
	static unsigned char* package(const unsigned char* raw_data, int len, int* pkg_len);
	// Release a package built by package().
	static void release_package(unsigned char* tp_pkg);
};

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "tp_protocol.h"
#include "../utils/cache_alloc.h"
// Pool for outgoing packages; it must be defined in another translation unit.
extern cache_allocer* wbuf_allocer;
// Decode the 2-byte little-endian size prefix. The size counts the prefix itself.
bool
tp_protocol::read_header(unsigned char* data, int data_len,
	int* pkg_size, int* out_header_size) {
	const int prefix_len = 2;
	if (data_len < prefix_len) {
		return false; // prefix not fully received yet
	}
	*out_header_size = prefix_len;
	*pkg_size = data[0] + (data[1] << 8);
	return true;
}
// Build a package: a 2-byte little-endian size prefix (counting the prefix
// itself) followed by `len` bytes of raw_data. On success returns the buffer
// (release with release_package) and stores its size in *pkg_len. Returns NULL
// with *pkg_len = 0 when len is negative, too large for the 16-bit size, or
// memory is unavailable.
unsigned char*
tp_protocol::package(const unsigned char* raw_data, int len, int* pkg_len) {
	const int head_size = 2;
	*pkg_len = 0;
	// Larger sizes used to be silently truncated in the 16-bit prefix.
	if (len < 0 || len > 0xFFFF - head_size) {
		return NULL;
	}
	int total = head_size + len;
	unsigned char* data_buf = (unsigned char*)cache_alloc(wbuf_allocer, total);
	if (data_buf == NULL) {
		// cache_alloc may return NULL (e.g. a request larger than the pool's
		// block size). cache_free releases non-pool pointers with free(), so a
		// heap fallback stays compatible with release_package.
		data_buf = (unsigned char*)malloc(total);
		if (data_buf == NULL) {
			return NULL;
		}
	}
	data_buf[0] = (unsigned char)(total & 0xff);
	data_buf[1] = (unsigned char)((total >> 8) & 0xff);
	if (len > 0) {
		memcpy(data_buf + head_size, raw_data, len);
	}
	*pkg_len = total;
	return data_buf;
}
void
tp_protocol::release_package(unsigned char* tp_pkg) {
	// Give the package back to the same allocator package() used.
	void* pkg = tp_pkg;
	cache_free(wbuf_allocer, pkg);
}
客户端测试代码
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#ifdef WIN32 // WIN32 宏, Linux宏不存在
#include <WinSock2.h>
#include <Windows.h>
#pragma comment (lib, "WSOCK32.LIB")
#endif
#include "tp_protocol.h"
// Test client: send "Hello" as one tp_protocol package to 127.0.0.1:6080 and
// print the echoed payload.
int main(int argc, char** argv) {
	int ret;
	// Every local is declared before the first `goto failed`: C++ rejects a
	// goto that jumps past a declaration with an initializer.
	int s = (int)INVALID_SOCKET;
	struct sockaddr_in sock_addr;
	char buf[32];
	int pkg_len = 0;
	int head_size = 0;
	int recv_len = 0;
	unsigned char* data = NULL;
	// Select the Windows socket version; without this, old socket versions
	// cause many obscure problems.
#ifdef WIN32
	WORD wVersionRequested;
	WSADATA wsaData;
	wVersionRequested = MAKEWORD(2, 2);
	ret = WSAStartup(wVersionRequested, &wsaData);
	if (ret != 0) {
		printf("WSAStart up failed\n");
		system("pause");
		return -1;
	}
#endif
	s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (s == INVALID_SOCKET) {
		goto failed;
	}
	// Server address: the listening socket at 127.0.0.1:6080.
	memset(&sock_addr, 0, sizeof(sock_addr));
	sock_addr.sin_addr.S_un.S_addr = inet_addr("127.0.0.1");
	sock_addr.sin_family = AF_INET;
	sock_addr.sin_port = htons(6080);
	ret = connect((SOCKET)s, (const sockaddr*)&sock_addr, sizeof(sock_addr));
	if (ret != 0) {
		goto failed;
	}
	// Connected. The client side gets any unused local IP + port from the OS.
	data = tp_protocol::package((unsigned char*)"Hello", 5, &pkg_len);
	if (data == NULL) {
		goto failed;
	}
	send(s, (const char*)data, pkg_len, 0);
	tp_protocol::release_package(data);
	// Expected reply: 2-byte size prefix + "Hello". buf is zeroed beyond the
	// received bytes, so the payload prints as a C string.
	memset(buf, 0, sizeof(buf));
	recv_len = recv(s, buf, 7, 0);
	if (recv_len > 0 &&
		tp_protocol::read_header((unsigned char*)buf, recv_len, &pkg_len, &head_size)) {
		printf("%s\n", buf + head_size);
	}
failed:
	if (s != INVALID_SOCKET) {
		closesocket(s);
		s = INVALID_SOCKET;
	}
#ifdef WIN32
	WSACleanup();
#endif
	system("pause");
	return 0;
}
服务端修改后的代码
session_uv.cc部分代码
void
uv_session::send_data(unsigned char* body, int len) {
// uv_write_t* w_req = (uv_write_t*)malloc(sizeof(uv_write_t));
uv_write_t* w_req = (uv_write_t*)cache_alloc(wr_allocer, sizeof(uv_write_t));
uv_buf_t w_buf;
if (this->socket_type == WS_SOCKET) {
if (this->is_ws_shake) {
int ws_pkg_len;
unsigned char* ws_pkg = ws_protocol::package_ws_send_data(body, len, &ws_pkg_len);
w_buf = uv_buf_init((char*)ws_pkg, ws_pkg_len);
uv_write(w_req, (uv_stream_t*)&this->tcp_handler, &w_buf, 1, after_write);
ws_protocol::free_ws_send_pkg(ws_pkg);
}
else {
w_buf = uv_buf_init((char*)body, len);
uv_write(w_req, (uv_stream_t*)&this->tcp_handler, &w_buf, 1, after_write);
}
}
else { // tcp,
int tp_pkg_len;
unsigned char* tp_pkg = tp_protocol::package(body, len, &tp_pkg_len);
w_buf = uv_buf_init((char*)tp_pkg, tp_pkg_len);
uv_write(w_req, (uv_stream_t*)&this->tcp_handler, &w_buf, 1, after_write);
tp_protocol::release_package(tp_pkg);
}
}
netbus.cc部分代码
static void
on_recv_tcp_data(uv_session* s) {
	// Consume every complete length-prefixed package currently buffered,
	// reading from long_pkg while it is in use and from recv_buf otherwise.
	unsigned char* base = (unsigned char*)(s->long_pkg ? s->long_pkg : s->recv_buf);
	for (;;) {
		if (s->recved <= 0) {
			break;
		}
		int pkg_size = 0;
		int head_size = 0;
		if (!tp_protocol::read_header(base, s->recved, &pkg_size, &head_size)) {
			break; // size prefix not complete yet
		}
		if (pkg_size <= head_size) { // empty or corrupt package: drop the client
			s->close();
			break;
		}
		if (pkg_size > s->recved) {
			break; // body not complete yet
		}
		on_recv_client_cmd(s, base + head_size, pkg_size - head_size);
		int remaining = s->recved - pkg_size;
		if (remaining > 0) {
			memmove(base, base + pkg_size, remaining);
		}
		s->recved = remaining;
		if (remaining == 0 && s->long_pkg) {
			free(s->long_pkg);
			s->long_pkg = NULL;
			s->long_pkg_size = 0;
		}
	}
}
// libuv read callback: account for the new bytes, then run the WebSocket
// handshake / frame parser or the TCP package parser.
static void
after_read(uv_stream_t* stream,
	ssize_t nread,
	const uv_buf_t* buf) {
	uv_session* s = (uv_session*)stream->data;
	if (nread < 0) { // EOF, read error, or UV_ENOBUFS from the alloc callback
		s->close();
		return;
	}
	s->recved += (int)nread;
	if (s->socket_type == WS_SOCKET) { // websocket
		if (s->is_ws_shake == 0) { // handshake
			if (ws_protocol::ws_shake_hand((session*)s, s->recv_buf, s->recved)) {
				s->is_ws_shake = 1;
				s->recved = 0;
			}
			else if (s->recved >= RECV_LEN) {
				// The handshake filled recv_buf without completing. Reading on
				// would go past recv_buf while ws_shake_hand keeps parsing
				// recv_buf with s->recved as its length, so drop the client.
				s->close();
			}
		}
		else { // websocket recv/send data
			on_recv_ws_data(s);
		}
	}
	else { // TCP socket
		on_recv_tcp_data(s);
	}
}
命令格式和协议管理
命令格式如下:
服务号(2字节)|命令号(2字节)|用户标识(4字节)|数据体协议(protobuf)
服务号: 属于哪个服务;
命令号: 属于哪个命令;
用户标识: 服务器内部来存放用户的UID信息;
数据体:加密/解密 protobuf协议/json协议来封装数据,根据用户来选择;
proto_man.h和proto_man.cc
#ifndef __PROTO_MAN_H__
#define __PROTO_MAN_H__

// Body encodings selectable with proto_man::init.
enum {
	PROTO_JSON = 0, // body is a NUL-terminated JSON string
	PROTO_BUF = 1,  // body is a google::protobuf::Message*
};

// One command: stype(2 bytes) | ctype(2 bytes) | utag(4 bytes) | body.
struct cmd_msg{
	int stype;         // service number
	int ctype;         // command number
	unsigned int utag; // user tag (server-side UID)
	void* body;        // JSON string or protobuf message, per proto_type()
};

// Encoding/decoding of commands plus the ctype -> protobuf type-name registry.
class proto_man {
public:
	static void init(int proto_type);
	static void register_pf_cmd_map(char** pf_map, int len);
	static int proto_type();
	static bool decode_cmd_msg(unsigned char* cmd, int cmd_len, struct cmd_msg** out_msg);
	static void cmd_msg_free(struct cmd_msg* msg);
	static unsigned char* encode_msg_to_raw(const struct cmd_msg* msg, int* out_len);
	static void msg_raw_free(unsigned char* raw);
};

#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "google/protobuf/message.h"
#include "proto_man.h"
#define MAX_PF_MAP_SIZE 1024
#define CMD_HEADER 8
static int g_proto_type = PROTO_BUF;    // body encoding in use
static char* g_pf_map[MAX_PF_MAP_SIZE]; // ctype -> protobuf message type name
static int g_cmd_count = 0;             // used entries of g_pf_map
// Choose the body encoding for later encode/decode calls.
void
proto_man::init(int proto_type) {
	g_proto_type = proto_type;
}
// Report the body encoding in use.
int
proto_man::proto_type() {
	const int current = g_proto_type;
	return current;
}
void
proto_man::register_pf_cmd_map(char** pf_map, int len) {
	// Append copies of the names after the ones already registered; entries
	// that do not fit in g_pf_map are dropped.
	const int room = MAX_PF_MAP_SIZE - g_cmd_count;
	const int count = (len > room) ? room : len;
	for (int k = 0; k < count; ++k) {
		g_pf_map[g_cmd_count + k] = strdup(pf_map[k]);
	}
	g_cmd_count += count;
}
// Create an empty message of the named protobuf type from the generated pool,
// or return NULL if the type or its prototype is unknown.
static google::protobuf::Message*
create_message(const char* type_name) {
	const google::protobuf::Descriptor* descriptor =
		google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(type_name);
	if (descriptor == NULL) {
		return NULL;
	}
	const google::protobuf::Message* prototype =
		google::protobuf::MessageFactory::generated_factory()->GetPrototype(descriptor);
	return (prototype != NULL) ? prototype->New() : NULL;
}
// Destroy a message obtained from create_message.
static void
release_message(google::protobuf::Message* m) {
	delete m;
}
// stype(2 byte) | ctype(2byte) | utag(4byte) | body
// Decode a raw command; all header fields are little-endian. On success,
// *out_msg receives a malloc'ed cmd_msg (release with cmd_msg_free). Its body
// is NULL for a header-only command, otherwise a NUL-terminated JSON string or
// a protobuf message, depending on the configured encoding.
// Returns false, leaving *out_msg NULL, on:
//  - a short input;
//  - an unregistered ctype;
//  - a protobuf parse error;
//  - an allocation failure.
bool
proto_man::decode_cmd_msg(unsigned char* cmd, int cmd_len, struct cmd_msg** out_msg) {
	*out_msg = NULL;
	if (cmd == NULL || cmd_len < CMD_HEADER) {
		return false;
	}
	struct cmd_msg* msg = (struct cmd_msg*)malloc(sizeof(struct cmd_msg));
	if (msg == NULL) {
		return false;
	}
	msg->stype = cmd[0] | (cmd[1] << 8);
	msg->ctype = cmd[2] | (cmd[3] << 8);
	// Build utag in unsigned arithmetic: `cmd[7] << 24` on a plain int
	// overflows (undefined behaviour) whenever utag's top bit is set.
	msg->utag = (unsigned int)cmd[4] | ((unsigned int)cmd[5] << 8) |
		((unsigned int)cmd[6] << 16) | ((unsigned int)cmd[7] << 24);
	msg->body = NULL;
	if (cmd_len == CMD_HEADER) { // header only
		*out_msg = msg;
		return true;
	}
	if (g_proto_type == PROTO_JSON) {
		int json_len = cmd_len - CMD_HEADER;
		char* json_str = (char*)malloc(json_len + 1);
		if (json_str == NULL) {
			free(msg);
			return false;
		}
		memcpy(json_str, cmd + CMD_HEADER, json_len);
		json_str[json_len] = 0;
		msg->body = (void*)json_str;
	}
	else { // protobuf: ctype selects the registered message type
		if (msg->ctype < 0 || msg->ctype >= g_cmd_count || g_pf_map[msg->ctype] == NULL) {
			free(msg);
			return false;
		}
		google::protobuf::Message* p_m = create_message(g_pf_map[msg->ctype]);
		if (p_m == NULL) {
			free(msg);
			return false;
		}
		if (!p_m->ParseFromArray(cmd + CMD_HEADER, cmd_len - CMD_HEADER)) {
			release_message(p_m);
			free(msg);
			return false;
		}
		msg->body = p_m;
	}
	*out_msg = msg;
	return true;
}
void
proto_man::cmd_msg_free(struct cmd_msg* msg) {
if (msg->body) {
if (g_proto_type == PROTO_JSON) {
free(msg->body);
msg->body = NULL;
}
else {
google::protobuf::Message* p_m = (google::protobuf::Message*) msg->body;
delete p_m;
msg->body = NULL;
}
}
free(msg);
}
// Encode `msg` as stype(2) | ctype(2) | utag(4) | body, all header fields
// little-endian. The JSON body is written with its NUL terminator; the protobuf
// body is serialized. Returns a malloc'ed buffer (release with msg_raw_free)
// and stores its size in *out_len. Returns NULL, with *out_len = 0, for an
// unknown encoding, a serialization failure or an allocation failure.
unsigned char*
proto_man::encode_msg_to_raw(const struct cmd_msg* msg, int* out_len) {
	unsigned char* raw_data = NULL;
	*out_len = 0;
	if (g_proto_type == PROTO_JSON) {
		const char* json_str = NULL;
		int len = 0;
		if (msg->body) {
			json_str = (const char*)msg->body;
			len = (int)strlen(json_str) + 1; // include the terminator
		}
		raw_data = (unsigned char*)malloc(CMD_HEADER + len);
		if (raw_data == NULL) {
			return NULL;
		}
		if (json_str != NULL) {
			memcpy(raw_data + CMD_HEADER, json_str, len - 1);
			// The terminator is the last byte of the CMD_HEADER + len buffer;
			// the old raw_data[8 + len] wrote one byte past the end.
			raw_data[CMD_HEADER + len - 1] = 0;
		}
		*out_len = CMD_HEADER + len;
	}
	else if (g_proto_type == PROTO_BUF) { // protobuf
		google::protobuf::Message* p_m = NULL;
		int pf_len = 0;
		if (msg->body) {
			p_m = (google::protobuf::Message*)msg->body;
			pf_len = p_m->ByteSize();
		}
		raw_data = (unsigned char*)malloc(CMD_HEADER + pf_len);
		if (raw_data == NULL) {
			return NULL;
		}
		if (p_m != NULL && !p_m->SerializePartialToArray(raw_data + CMD_HEADER, pf_len)) {
			free(raw_data);
			return NULL;
		}
		*out_len = CMD_HEADER + pf_len;
	}
	else {
		return NULL;
	}
	// header, little-endian (utag used to be copied in host byte order)
	raw_data[0] = (unsigned char)(msg->stype & 0xff);
	raw_data[1] = (unsigned char)((msg->stype >> 8) & 0xff);
	raw_data[2] = (unsigned char)(msg->ctype & 0xff);
	raw_data[3] = (unsigned char)((msg->ctype >> 8) & 0xff);
	raw_data[4] = (unsigned char)(msg->utag & 0xff);
	raw_data[5] = (unsigned char)((msg->utag >> 8) & 0xff);
	raw_data[6] = (unsigned char)((msg->utag >> 16) & 0xff);
	raw_data[7] = (unsigned char)((msg->utag >> 24) & 0xff);
	return raw_data;
}
void
proto_man::msg_raw_free(unsigned char* raw) {
	// Buffers from encode_msg_to_raw are malloc'ed.
	void* block = raw;
	free(block);
}
pf_cmd_map.h和pf_cmd_map.cc
// pf_cmd_map.h: declares the hook that registers this server's protobuf
// message-name table with proto_man.
#ifndef __PF_CMD_MAP_H__
#define __PF_CMD_MAP_H__
// Registers the protobuf command table (defined in pf_cmd_map.cc).
void init_pf_cmd_map();
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "../../../netbus/proto_man.h"
// Protobuf message type names for this server. decode_cmd_msg looks names up
// via g_pf_map[msg->ctype], so presumably the position in this table is the
// command type (ctype). Verify against proto_man::register_pf_cmd_map.
char* pf_cmd_map[] = {
	"LoginReq", // ctype 0 (assumed)
	"LoginRes", // ctype 1 (assumed)
};

// Hands the table above, and its element count, to proto_man.
void init_pf_cmd_map() {
	proto_man::register_pf_cmd_map(pf_cmd_map, sizeof(pf_cmd_map) / sizeof(pf_cmd_map[0]));
}
根据proto协议文件的内容在这里注册对应的message即可
使用
客户端测试代码
#ifdef WIN32
#include <WinSock2.h>
#include <Windows.h>
#pragma comment (lib, "WSOCK32.LIB")
#endif
#include "tp_protocol.h"
// Closes the test socket (if open) and, on Windows, tears down Winsock.
static void close_client_socket(int s) {
	if (s != INVALID_SOCKET) {
		closesocket(s);
	}
#ifdef WIN32
	WSACleanup();
#endif
}

// Test client: sends one LoginReq (behind an all-zero 8-byte command header,
// i.e. stype 0 / ctype 0 / utag 0) to 127.0.0.1:6080 and prints the LoginReq
// echoed back by the server.
int main(int argc, char** argv) {
	int ret;
#ifdef WIN32
	DWORD wVersionRequested;
	WSADATA wsaData;
	wVersionRequested = MAKEWORD(2, 2);
	ret = WSAStartup(wVersionRequested, &wsaData);
	if (ret != 0) {
		printf("WSAStart up failed\n");
		system("pause");
		return -1;
	}
#endif
	int s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (s == INVALID_SOCKET) {
		close_client_socket(s);
		return 0;
	}
	struct sockaddr_in sockaddr;
	memset(&sockaddr, 0, sizeof(sockaddr)); // also clears sin_zero
	sockaddr.sin_addr.S_un.S_addr = inet_addr("127.0.0.1");
	sockaddr.sin_family = AF_INET;
	sockaddr.sin_port = htons(6080);
	ret = connect((SOCKET)s, (const struct sockaddr*)&sockaddr, sizeof(sockaddr));
	if (ret != 0) {
		close_client_socket(s);
		return 0;
	}
	LoginReq req;
	req.set_age(10);
	req.set_name("name");
	req.set_email("12345@sina.com");
	int len = req.ByteSize();
	char* data = (char*)malloc(8 + len);
	if (data == NULL) {
		close_client_socket(s);
		return 0;
	}
	memset(data, 0, 8 + len);
	req.SerializePartialToArray(data + 8, len);
	int pkg_len = 0;
	unsigned char* pkg_data = tp_protocol::package((unsigned char*)data, 8 + len, &pkg_len);
	free(data);
	if (pkg_data != NULL) {
		send(s, (const char*)pkg_data, pkg_len, 0);
		tp_protocol::release_package(pkg_data);
	}
	unsigned char recv_buf[256];
	int recv_len = recv(s, (char*)recv_buf, 256, 0);
	int pkg_size = 0;
	int header_size = 0;
	if (recv_len > 0) {
		tp_protocol::read_header(recv_buf, recv_len, &pkg_size, &header_size);
	}
	close_client_socket(s);
	s = INVALID_SOCKET;
	// Parse only a complete package: header + 8-byte command header + body,
	// all inside the bytes actually received.
	if (header_size > 0 && pkg_size >= header_size + 8 && pkg_size <= recv_len) {
		req.ParseFromArray(recv_buf + header_size + 8, pkg_size - header_size - 8);
		printf("%s: %d\n", req.name().c_str(), req.age());
	}
	else {
		printf("no valid reply received\n");
	}
	system("pause");
	return 0;
}
服务端netbus.cc部分代码
// Test-only echo handler: decodes the command, re-encodes it and sends it
// straight back to the same client. Undecodable packets are ignored.
static void
on_recv_client_cmd(uv_session* s, unsigned char* body, int len) {
	printf("client command !!!!\n");
	struct cmd_msg* msg = NULL;
	if (!proto_man::decode_cmd_msg(body, len, &msg)) {
		return;
	}
	int encode_len = 0;
	unsigned char* encode_pkg = proto_man::encode_msg_to_raw(msg, &encode_len);
	if (encode_pkg != NULL) {
		s->send_data(encode_pkg, encode_len);
		proto_man::msg_raw_free(encode_pkg);
	}
	proto_man::cmd_msg_free(msg);
}
Service管理
1: 每一个服务都对应一个服务对象;
2: 当请求发给对应服务的时候,服务根据命令 类型来处理业务逻辑;
3: 每一个服务都向服务管理对象注册好: 服务号 —> service对象;
4: service是所有服务的基类,各种服务都继承于它;
5: service manager注册一个服务保存好stype–> service对象的映射;
6: 当解码出命令的时候,根据服务号,转发给对应的模块;
7: 服务器开发人员基于service来开发,你只要开发自己的service并注册给netbus就可以了;
service.h和service.cc(每一个服务都需要继承service类)
#ifndef __SERVICE_H__
#define __SERVICE_H__
class session;
struct cmd_msg;
// Base class for a business service. Concrete services override the hooks
// below and are registered with service_man under their service type (stype).
class service {
public:
	// Virtual so a concrete service can be safely deleted through a service*.
	virtual ~service() {}
	// Handles one decoded command. Returning false tells the caller the command
	// was not handled; the network layer then closes the session.
	virtual bool on_session_recv_cmd(session* s, struct cmd_msg* msg);
	// Called when session s disconnects.
	virtual void on_session_disconnect(session* s);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "session.h"
#include "proto_man.h"
#include "service.h"
bool
service::on_session_recv_cmd(session* s, struct cmd_msg* msg) {
return false;
}
void
service::on_session_disconnect(session* s) {
}
service_man.h和service_man.cc(管理类)
#ifndef __SERVICE_MAN_H__
#define __SERVICE_MAN_H__
class session;
class service;
struct cmd_msg;
// Static registry mapping a service type (stype) to its service object; it
// dispatches decoded commands and disconnect notifications.
class service_man {
public:
// Clears the registry.
static void init();
// Registers s under stype; returns false if stype is out of range or already taken.
static bool register_service(int stype, service* s);
// Forwards msg to the service registered for msg->stype; returns false when no
// service handles it (the network layer then closes the session).
static bool on_recv_cmd_msg(session* s, struct cmd_msg* msg);
// Notifies every registered service that session s is going away.
static void on_session_disconnect(session* s);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "session.h"
#include "proto_man.h"
#include "service.h"
#include "service_man.h"
// Capacity of the stype -> service table.
#define MAX_SERVICE 512 // valid stype values: 0 ~ MAX_SERVICE-1 (0 ~ 511)
// Indexed by stype; NULL means "no service registered" (static, so zero-initialized).
static service* g_service_set[MAX_SERVICE];
// Registers service s under stype. Fails (returns false) when stype is outside
// [0, MAX_SERVICE) or another service already owns that slot.
bool
service_man::register_service(int stype, service* s) {
	const bool in_range = (stype >= 0) && (stype < MAX_SERVICE);
	if (!in_range || g_service_set[stype] != NULL) {
		return false;
	}
	g_service_set[stype] = s;
	return true;
}
// Routes a decoded command to the service registered for msg->stype.
// Returns false when msg is NULL, stype is out of range, or no service is
// registered for it; otherwise returns the service's own verdict.
bool
service_man::on_recv_cmd_msg(session* s, struct cmd_msg* msg) {
	if (msg == NULL) {
		return false;
	}
	// stype is decoded from client data, so validate it before indexing the table.
	if (msg->stype < 0 || msg->stype >= MAX_SERVICE) {
		return false;
	}
	service* handler = g_service_set[msg->stype];
	if (handler == NULL) {
		return false;
	}
	return handler->on_session_recv_cmd(s, msg);
}
// Broadcasts a disconnect for session s to every registered service.
void
service_man::on_session_disconnect(session* s) {
	for (int stype = 0; stype < MAX_SERVICE; ++stype) {
		service* svc = g_service_set[stype];
		if (svc != NULL) {
			svc->on_session_disconnect(s);
		}
	}
}
void
service_man::init() {
memset(g_service_set, 0, sizeof(g_service_set));
}
使用
netbus.cc部分代码
// Decodes one client command and routes it by stype through service_man.
// A command that no service handles terminates the connection; packets that
// fail to decode are dropped silently.
static void
on_recv_client_cmd(uv_session* s, unsigned char* body, int len) {
	printf("client command !!!!\n");
	struct cmd_msg* msg = NULL;
	if (!proto_man::decode_cmd_msg(body, len, &msg)) {
		return;
	}
	const bool handled = service_man::on_recv_cmd_msg((session*)s, msg);
	if (!handled) {
		s->close();
	}
	proto_man::cmd_msg_free(msg);
}
// Framework initialization: clears the service registry, then calls
// init_session_allocer() (presumably sets up the session pool; not shown here).
void netbus::init() {
service_man::init();
init_session_allocer();
}
session_uv.cc部分代码
void
uv_session::close() {
if (this->is_shutdown) {
return;
}
// broadcast serive client is disconnect;
service_man::on_session_disconnect(this);
// end
this->is_shutdown = true;
uv_shutdown_t* reg = &this->shutdown;
memset(reg, 0, sizeof(uv_shutdown_t));
uv_shutdown(reg, (uv_stream_t*)&this->tcp_handler, on_shutdown);
}
log日志管理和Timer时间戳
1: 日志系统需求:
配置一个日志的输出路径,与日志的前缀名字(区分是哪个服务打印的日志,例如网关等);
每天的日志输出到不同的文件里面,根据日期将文件名字整理好;
每一行日志: 哪个文件哪一行代码输出的日志,方便定位;
每一行日志: 分等级: DEBUG, WARNING, ERROR三个等级;
每一行日志记录下打印的时间,精度到秒;
日志是要异步的输出;
time_list.h和time_list.cc
#ifndef __MY_TIMER_LIST_H__
#define __MY_TIMER_LIST_H__
#ifdef __cplusplus
extern "C" {
#endif
// Timers running on the libuv default loop.
// schedule():
//   on_timer: callback invoked each time the timer fires.
//   udata: user-defined data; on_timer receives exactly this pointer.
//   after_msec: delay in MILLISECONDS before the first run; the implementation
//               also uses it as the interval between repeats.
//   repeat_count: number of runs; -1 (any negative value) runs forever.
//     NOTE(review): 0 is not "never" - it is decremented to -1 on the first run
//     and then behaves like -1; pass a value >= 1 for a bounded timer.
//   Returns the timer handle. A bounded timer is released automatically after
//   its last run, so the handle must not be used afterwards.
struct timer;
struct timer*
schedule(void(*on_timer)(void* udata),
void* udata,
int after_msec,
int repeat_count);
// Stops and releases timer t; returns immediately if t's remaining repeat count is 0.
void
cancel_timer(struct timer* t);
// Runs on_timer(udata) once, after_msec milliseconds from now.
struct timer*
schedule_once(void(*on_timer)(void* udata),
void* udata,
int after_msec);
#ifdef __cplusplus
}
#endif
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "uv.h"
#include "time_list.h"
// Allocation hooks; plain malloc/free for now.
#define my_malloc malloc
#define my_free free
// One scheduled timer: the libuv handle plus the user callback and its state.
struct timer {
uv_timer_t uv_timer; // libuv timer handle
void(*on_timer)(void* udata); // user callback
void* udata; // user pointer passed to on_timer
int repeat_count; // remaining runs; -1 (any negative value) repeats forever
};
// Allocates a zeroed timer, stores the callback state and initializes its
// libuv handle on the default loop. Returns NULL if allocation fails.
static struct timer*
alloc_timer(void(*on_timer)(void* udata),
	void* udata, int repeat_count) {
	// The explicit cast keeps this valid C++ as well as C.
	struct timer* t = (struct timer*)my_malloc(sizeof(struct timer));
	if (t == NULL) {
		return NULL;
	}
	memset(t, 0, sizeof(struct timer));
	t->on_timer = on_timer;
	t->repeat_count = repeat_count;
	t->udata = udata;
	uv_timer_init(uv_default_loop(), &t->uv_timer);
	return t;
}
static void
free_timer(struct timer* t) {
my_free(t);
}
// libuv timer callback: runs the user callback and, for bounded timers,
// counts down the remaining runs, releasing the timer after the last one.
static void
on_uv_timer(uv_timer_t* handle) {
	// handle->data is set to the owning timer by schedule(); the explicit cast
	// keeps this valid C++ as well as C.
	struct timer* t = (struct timer*)handle->data;
	if (t->repeat_count < 0) { // negative count: fire forever
		t->on_timer(t->udata);
	}
	else {
		t->repeat_count--;
		t->on_timer(t->udata);
		if (t->repeat_count == 0) { // that was the last run
			uv_timer_stop(&t->uv_timer);
			free_timer(t);
		}
	}
}
// Starts a timer that fires after after_msec ms and then every after_msec ms,
// repeat_count times (negative = forever). Returns NULL if allocation fails.
struct timer*
schedule(void(*on_timer)(void* udata),
	void* udata,
	int after_msec,
	int repeat_count) {
	struct timer* t = alloc_timer(on_timer, udata, repeat_count);
	uint64_t interval;
	if (t == NULL) {
		return NULL;
	}
	// A negative int would wrap to a huge uint64_t delay; treat it as "now".
	if (after_msec < 0) {
		after_msec = 0;
	}
	// libuv treats a repeat interval of 0 as "do not repeat", which would stall
	// (and leak) a multi-shot timer after its first run; use at least 1 ms.
	interval = (after_msec > 0) ? (uint64_t)after_msec : 1;
	t->uv_timer.data = t;
	uv_timer_start(&t->uv_timer, on_uv_timer, (uint64_t)after_msec, interval);
	return t;
}
// Stops and releases timer t. A NULL handle (e.g. from a failed schedule) is
// ignored. A timer whose repeat count is already 0 has finished and been
// released by on_uv_timer, so nothing is done for it.
void
cancel_timer(struct timer* t) {
	if (t == NULL) {
		return;
	}
	if (t->repeat_count == 0) { // all runs done; already released
		return;
	}
	uv_timer_stop(&t->uv_timer);
	free_timer(t);
}
// One-shot convenience wrapper: a timer whose repeat count is exactly one.
struct timer*
schedule_once(void(*on_timer)(void* udata),
	void* udata,
	int after_msec) {
	const int run_once = 1;
	return schedule(on_timer, udata, after_msec, run_once);
}
timestamp.h和timestamp.c
#ifndef __TIMESTAMP_H__
#define __TIMESTAMP_H__
#ifdef __cplusplus
extern "C" {
#endif
// Current Unix timestamp in seconds.
unsigned long timestamp();
// Parses `date` with the strptime-style format fmt_date (e.g. "%Y%m%d%H%M%S":
// year, month, day, hour, minute, second) as local time; returns its timestamp.
unsigned long date2timestamp(const char* fmt_date, const char* date);
// Formats timestamp t as local time into out_buf (buf_len bytes) using a
// strftime-style fmt_date such as "%Y%m%d%H%M%S".
void timestamp2date(unsigned long t, char*fmt_date, char* out_buf, int buf_len);
// Timestamp of today's local midnight (00:00:00).
unsigned long timestamp_today();
// Today's midnight minus 24 hours.
unsigned long timestamp_yesterday();
#ifdef __cplusplus
}
#endif
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include <ctype.h>
#include "timestamp.h"
/*
%a 星期几的简写
%A 星期几的全称
%b 月分的简写
%B 月份的全称
%c 标准的日期的时间串
%C 世纪数(年份除以100取整,例如2024年为20)
%d 十进制表示的每月的第几天
%D 月/天/年
%e 在两字符域中,十进制表示的每月的第几天
%F 年-月-日
%g 年份的后两位数字,使用基于周的年
%G 年分,使用基于周的年
%h 简写的月份名
%H 24小时制的小时
%I 12小时制的小时
%j 十进制表示的每年的第几天
%m 十进制表示的月份
%M 十进制表示的分钟数
%n 新行符
%p 本地的AM或PM的等价显示
%r 12小时的时间
%R 显示小时和分钟:hh:mm
%S 十进制的秒数
%t 水平制表符
%T 显示时分秒:hh:mm:ss
%u 每周的第几天,星期一为第一天 (值从1到7,星期一为1)
%U 每年的第几周,把星期日作为第一天(值从0到53)
%V 每年的第几周,使用基于周的年
%w 十进制表示的星期几(值从0到6,星期天为0)
%W 每年的第几周,把星期一做为第一天(值从0到53)
%x 标准的日期串
%X 标准的时间串
%y 不带世纪的十进制年份(值从0到99)
%Y 带世纪部分的十进制年份
%z 相对UTC的时区偏移(例如+0800);%Z 时区名称,如果不能得到时区名称则返回空字符。
*/
// struct tm stores tm_year as years since 1900.
#define TM_YEAR_BASE 1900
/*
* We do not implement alternate representations. However, we always
* check whether a given modifier is allowed for a certain conversion.
*/
// Flags recording a %E / %O modifier. LEGAL_ALT(x) makes the enclosing parser
// return 0 (failure) if any modifier outside the set x was given.
#define ALT_E 0x01
#define ALT_O 0x02
#define LEGAL_ALT(x) { if (alt_format & ~(x)) return (0); }
// Bounded decimal-number reader used by my_strptime; defined below it.
static int conv_num(const char **, int *, int, int);
// English (C-locale) day, month and AM/PM names recognized by my_strptime.
static const char *day[7] = {
"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday",
"Friday", "Saturday"
};
static const char *abday[7] = {
"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"
};
static const char *mon[12] = {
"January", "February", "March", "April", "May", "June", "July",
"August", "September", "October", "November", "December"
};
static const char *abmon[12] = {
"Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
};
static const char *am_pm[2] = {
"AM", "PM"
};
//Windows上没有strptime,所以这里自己实现my_strptime函数;Linux系统本身已经提供strptime
//my_strptime: 为Windows平台实现的strptime替代版本(在各平台上都可以使用)
// Portable strptime() replacement: parses buf according to fmt and stores the
// fields it recognizes into *tm. Returns a pointer to the first unparsed
// character of buf, or 0 (NULL) on a mismatch. Only English/C-locale names are
// recognized; %E/%O modifiers are accepted but have no effect.
char*
my_strptime(const char *buf, const char *fmt, struct tm *tm)
{
	char c;
	const char *bp;
	size_t len = 0;
	int alt_format, i, split_year = 0;
	bp = buf;
	while ((c = *fmt) != '\0') {
		/* Clear `alternate' modifier prior to new conversion. */
		alt_format = 0;
		/* Eat up white-space (cast: isspace() on a negative char is undefined). */
		if (isspace((unsigned char)c)) {
			while (isspace((unsigned char)*bp))
				bp++;
			fmt++;
			continue;
		}
		if ((c = *fmt++) != '%')
			goto literal;
again:
		switch (c = *fmt++) {
		case '%': /* "%%" is converted to "%". */
literal:
			if (c != *bp++)
				return (0);
			break;
		/*
		 * "Alternative" modifiers. Just set the appropriate flag
		 * and start over again.
		 */
		case 'E': /* "%E?" alternative conversion modifier. */
			LEGAL_ALT(0);
			alt_format |= ALT_E;
			goto again;
		case 'O': /* "%O?" alternative conversion modifier. */
			LEGAL_ALT(0);
			alt_format |= ALT_O;
			goto again;
		/*
		 * "Complex" conversion rules, implemented through recursion.
		 */
		case 'c': /* Date and time, using the locale's format. */
			LEGAL_ALT(ALT_E);
			if (!(bp = my_strptime(bp, "%x %X", tm)))
				return (0);
			break;
		case 'D': /* The date as "%m/%d/%y". */
			LEGAL_ALT(0);
			if (!(bp = my_strptime(bp, "%m/%d/%y", tm)))
				return (0);
			break;
		case 'R': /* The time as "%H:%M". */
			LEGAL_ALT(0);
			if (!(bp = my_strptime(bp, "%H:%M", tm)))
				return (0);
			break;
		case 'r': /* The time in 12-hour clock representation. */
			LEGAL_ALT(0);
			if (!(bp = my_strptime(bp, "%I:%M:%S %p", tm)))
				return (0);
			break;
		case 'T': /* The time as "%H:%M:%S". */
			LEGAL_ALT(0);
			if (!(bp = my_strptime(bp, "%H:%M:%S", tm)))
				return (0);
			break;
		case 'X': /* The time, using the locale's format. */
			LEGAL_ALT(ALT_E);
			if (!(bp = my_strptime(bp, "%H:%M:%S", tm)))
				return (0);
			break;
		case 'x': /* The date, using the locale's format. */
			LEGAL_ALT(ALT_E);
			if (!(bp = my_strptime(bp, "%m/%d/%y", tm)))
				return (0);
			break;
		/*
		 * "Elementary" conversion rules.
		 */
		case 'A': /* The day of week, using the locale's form. */
		case 'a':
			LEGAL_ALT(0);
			for (i = 0; i < 7; i++) {
				/* Full name. */
				len = strlen(day[i]);
				if (strncmp(day[i], bp, len) == 0)
					break;
				/* Abbreviated name. */
				len = strlen(abday[i]);
				if (strncmp(abday[i], bp, len) == 0)
					break;
			}
			/* Nothing matched. */
			if (i == 7)
				return (0);
			tm->tm_wday = i;
			bp += len;
			break;
		case 'B': /* The month, using the locale's form. */
		case 'b':
		case 'h':
			LEGAL_ALT(0);
			for (i = 0; i < 12; i++) {
				/* Full name. */
				len = strlen(mon[i]);
				if (strncmp(mon[i], bp, len) == 0)
					break;
				/* Abbreviated name. */
				len = strlen(abmon[i]);
				if (strncmp(abmon[i], bp, len) == 0)
					break;
			}
			/* Nothing matched. */
			if (i == 12)
				return (0);
			tm->tm_mon = i;
			bp += len;
			break;
		case 'C': /* The century number. */
			LEGAL_ALT(ALT_E);
			if (!(conv_num(&bp, &i, 0, 99)))
				return (0);
			if (split_year) {
				tm->tm_year = (tm->tm_year % 100) + (i * 100);
			}
			else {
				tm->tm_year = i * 100;
				split_year = 1;
			}
			break;
		case 'd': /* The day of month. */
		case 'e':
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &tm->tm_mday, 1, 31)))
				return (0);
			break;
		case 'k': /* The hour (24-hour clock representation). */
			LEGAL_ALT(0);
			/* FALLTHROUGH */
		case 'H':
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &tm->tm_hour, 0, 23)))
				return (0);
			break;
		case 'l': /* The hour (12-hour clock representation). */
			LEGAL_ALT(0);
			/* FALLTHROUGH */
		case 'I':
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &tm->tm_hour, 1, 12)))
				return (0);
			if (tm->tm_hour == 12)
				tm->tm_hour = 0;
			break;
		case 'j': /* The day of year. */
			LEGAL_ALT(0);
			if (!(conv_num(&bp, &i, 1, 366)))
				return (0);
			tm->tm_yday = i - 1;
			break;
		case 'M': /* The minute. */
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &tm->tm_min, 0, 59)))
				return (0);
			break;
		case 'm': /* The month. */
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &i, 1, 12)))
				return (0);
			tm->tm_mon = i - 1;
			break;
		case 'p': /* The locale's equivalent of AM/PM. */
			LEGAL_ALT(0);
			/* AM? Compare only the marker's length: AM/PM need not end the input. */
			len = strlen(am_pm[0]);
			if (strncmp(am_pm[0], bp, len) == 0) {
				if (tm->tm_hour > 11)
					return (0);
				bp += len;
				break;
			}
			/* PM? */
			len = strlen(am_pm[1]);
			if (strncmp(am_pm[1], bp, len) == 0) {
				if (tm->tm_hour > 11)
					return (0);
				tm->tm_hour += 12;
				bp += len;
				break;
			}
			/* Nothing matched. */
			return (0);
		case 'S': /* The seconds. */
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &tm->tm_sec, 0, 61)))
				return (0);
			break;
		case 'U': /* The week of year, beginning on sunday. */
		case 'W': /* The week of year, beginning on monday. */
			LEGAL_ALT(ALT_O);
			/*
			 * XXX This is bogus, as we can not assume any valid
			 * information present in the tm structure at this
			 * point to calculate a real value, so just check the
			 * range for now.
			 */
			if (!(conv_num(&bp, &i, 0, 53)))
				return (0);
			break;
		case 'w': /* The day of week, beginning on sunday. */
			LEGAL_ALT(ALT_O);
			if (!(conv_num(&bp, &tm->tm_wday, 0, 6)))
				return (0);
			break;
		case 'Y': /* The year. */
			LEGAL_ALT(ALT_E);
			if (!(conv_num(&bp, &i, 0, 9999)))
				return (0);
			tm->tm_year = i - TM_YEAR_BASE;
			break;
		case 'y': /* The year within 100 years of the epoch. */
			LEGAL_ALT(ALT_E | ALT_O);
			if (!(conv_num(&bp, &i, 0, 99)))
				return (0);
			if (split_year) {
				tm->tm_year = ((tm->tm_year / 100) * 100) + i;
				break;
			}
			split_year = 1;
			if (i <= 68)
				tm->tm_year = i + 2000 - TM_YEAR_BASE;
			else
				tm->tm_year = i + 1900 - TM_YEAR_BASE;
			break;
		/*
		 * Miscellaneous conversions.
		 */
		case 'n': /* Any kind of white-space. */
		case 't':
			LEGAL_ALT(0);
			while (isspace((unsigned char)*bp))
				bp++;
			break;
		default: /* Unknown/unsupported conversion. */
			return (0);
		}
	}
	/* LINTED functional specification */
	return ((char *)bp);
}
/* Reads an unsigned decimal number from *buf into *dest.
 * It consumes at most as many digits as ulim has, and stops early as soon as
 * another digit would push the value past ulim. *buf is advanced past the
 * digits consumed. Returns 1 if the value lies in [llim, ulim]; otherwise
 * returns 0 and leaves *dest untouched (*buf has still been advanced). */
static int
conv_num(const char **buf, int *dest, int llim, int ulim)
{
	const char *p = *buf;
	int value = 0;
	int digit_budget = ulim;

	if (!(*p >= '0' && *p <= '9'))
		return (0);
	for (;;) {
		value = value * 10 + (*p++ - '0');
		digit_budget /= 10;
		if (value * 10 > ulim || digit_budget == 0 || *p < '0' || *p > '9')
			break;
	}
	*buf = p;
	if (value < llim || value > ulim)
		return (0);
	*dest = value;
	return (1);
}
/* Seconds since the Unix epoch, straight from the C library clock. */
unsigned long
timestamp() {
	return (unsigned long)time(NULL);
}
/* Parses `date` with the strptime-style format fmt_date as LOCAL time and
 * returns the matching Unix timestamp. Returns 0 when the text does not match
 * the format. */
unsigned long
date2timestamp(const char* fmt_date, const char* date) {
	struct tm tmp_time;
	time_t t;
	memset(&tmp_time, 0, sizeof(struct tm));
	/* Windows has no strptime, so the portable my_strptime is used everywhere. */
	if (my_strptime(date, fmt_date, &tmp_time) == NULL) {
		return 0;
	}
	/* Let mktime decide whether DST applies; the memset left tm_isdst = 0,
	 * which is one hour off for summer dates in DST zones. */
	tmp_time.tm_isdst = -1;
	t = mktime(&tmp_time);
	return (unsigned long)t;
}
/* Formats timestamp t as local time into out_buf (buf_len bytes, always
 * NUL-terminated) using the strftime-style fmt_date. Does nothing if out_buf
 * is NULL or buf_len <= 0. Produces an empty string if the time cannot be
 * converted or the result does not fit. */
void
timestamp2date(unsigned long t, char*fmt_date, char* out_buf, int buf_len) {
	time_t t_value;
	struct tm* date;
	if (out_buf == NULL || buf_len <= 0) {
		return;
	}
	t_value = (time_t)t;
	date = localtime(&t_value);
	/* strftime returns 0 (contents indeterminate) when the output does not fit. */
	if (date == NULL || fmt_date == NULL ||
		strftime(out_buf, (size_t)buf_len, fmt_date, date) == 0) {
		out_buf[0] = '\0';
	}
}
/* Timestamp of the current local day's midnight (00:00:00). */
unsigned long
timestamp_today() {
	time_t now_sec = time(NULL);
	struct tm* local = localtime(&now_sec);
	local->tm_sec = 0;
	local->tm_min = 0;
	local->tm_hour = 0;
	return (unsigned long)mktime(local);
}
/* Yesterday's midnight, taken as exactly 86400 seconds before today's midnight.
 * NOTE(review): this is off by an hour when a DST change falls in between. */
unsigned long
timestamp_yesterday() {
	const unsigned long one_day = 24UL * 60 * 60;
	return timestamp_today() - one_day;
}
logger.h和logger.cc代码
#ifndef __LOGGER_H__
#define __LOGGER_H__
// Log levels; the value indexes the level-name table in logger.cc.
enum {
	DEBUG = 0,
	WARNING,
	ERROR,
};
// Logging macros that capture the calling file and line. They expand to a bare
// expression with no trailing ';', so `if (c) log_debug("x"); else ...` parses
// as intended (a built-in ';' would detach the else).
#define log_debug(msg, ...) logger::log(__FILE__, __LINE__, DEBUG, msg, ## __VA_ARGS__)
#define log_warning(msg, ...) logger::log(__FILE__, __LINE__, WARNING, msg, ## __VA_ARGS__)
#define log_error(msg, ...) logger::log(__FILE__, __LINE__, ERROR, msg, ## __VA_ARGS__)
class logger {
public:
	// path: log directory (created if missing); prefix: file-name prefix that
	// identifies the server; std_output: also echo every line to stdout.
	static void init(char* path, char* prefix, bool std_output = false);
	// Writes one printf-style formatted line at the given level.
	static void log(const char* file_name,
		int line_num,
		int level, const char* msg, ...);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdarg.h>
#include <fcntl.h>
#include <time.h>
#include <string>
using namespace std;
#include "logger.h"
#include "uv.h"
// Logger state (file scope; accessed without any locking).
static string g_log_path; // log directory, always ending in '/' after init()
static string g_prefix; // file-name prefix identifying the server
static uv_fs_t g_file_handle; // .result holds the open log file descriptor (0 = not opened yet)
static uint32_t g_current_day; // day of month of the currently open log file
static uint32_t g_last_second; // second for which g_format_time was built
static char g_format_time[64] = { 0 }; // cached "YYYYMMDD HH:MM:SS " prefix
static char* g_log_level[] = { "DEBUG ", "WARNING ", "ERROR "}; // indexed by DEBUG/WARNING/ERROR
static bool g_std_out = false; // also echo log lines to stdout
static void
open_file(tm* time_struct) {
int result = 0;
char fileName[128] = { 0 };
if (g_file_handle.result != 0) {
uv_fs_close(uv_default_loop(), &g_file_handle, g_file_handle.result, NULL);
uv_fs_req_cleanup(&g_file_handle);
g_file_handle.result = 0;
}
sprintf(fileName, "%s%s.%4d%02d%02d.log", g_log_path.c_str(), g_prefix.c_str(), time_struct->tm_year + 1900, time_struct->tm_mon + 1, time_struct->tm_mday);
result = uv_fs_open(NULL, &g_file_handle, fileName, O_CREAT | O_RDWR | O_APPEND, S_IREAD | S_IWRITE, NULL);
if (result < 0) {
fprintf(stderr, "open file failed! name=%s, reason=%s", fileName, uv_strerror(result));
}
}
// Makes sure a log file is open for the current day. Dates use a fixed UTC+8
// offset, independent of the host time zone. A new file is opened on first use
// and whenever the day of month changes.
static void
prepare_file() {
	time_t now = time(NULL) + 8 * 60 * 60;
	tm* time_struct = gmtime(&now);
	const bool first_use = (g_file_handle.result == 0);
	const bool day_changed = (g_current_day != (uint32_t)time_struct->tm_mday);
	if (first_use || day_changed) {
		g_current_day = time_struct->tm_mday;
		open_file(time_struct);
	}
}
// Refreshes the cached "YYYYMMDD HH:MM:SS " prefix (fixed UTC+8 offset). The
// string is rebuilt only when the second changes.
static void
format_time() {
	time_t now = time(NULL) + 8 * 60 * 60;
	if (now == g_last_second) {
		return;
	}
	tm* time_struct = gmtime(&now);
	g_last_second = (uint32_t)now;
	memset(g_format_time, 0, sizeof(g_format_time));
	sprintf(g_format_time, "%4d%02d%02d %02d:%02d:%02d ",
		time_struct->tm_year + 1900, time_struct->tm_mon + 1, time_struct->tm_mday,
		time_struct->tm_hour, time_struct->tm_min, time_struct->tm_sec);
}
// Configures the logger: directory, file-name prefix and stdout echo.
// Creates every level of the directory path (like `mkdir -p`).
void
logger::init(char* path, char* prefix, bool std_output) {
	g_prefix = prefix;
	g_log_path = path;
	g_std_out = std_output;
	// An empty path means the working directory; otherwise guarantee a trailing
	// '/'. The old code dereferenced end() - 1 of an empty string here.
	if (g_log_path.empty()) {
		g_log_path = "./";
	}
	else if (g_log_path[g_log_path.size() - 1] != '/') {
		g_log_path += "/";
	}
	// Create "a", then "a/b", ... for path "a/b/". A leading '/' (absolute path)
	// yields an empty prefix, which is skipped.
	std::string::size_type pos = g_log_path.find('/');
	while (pos != std::string::npos) {
		if (pos > 0) {
			uv_fs_t req;
			// Synchronous (NULL callback); errors such as "already exists" are ignored.
			uv_fs_mkdir(uv_default_loop(), &req, g_log_path.substr(0, pos).c_str(), 0755, NULL);
			uv_fs_req_cleanup(&req);
		}
		pos = g_log_path.find('/', pos + 1);
	}
}
// Writes one log line to the current day's file:
//   "<time><LEVEL ><file>:<line> \n<message>\n"
// and, if enabled, echoes it to stdout. Uses function-local static buffers, so
// it is not reentrant.
void
logger::log(const char* file_name,
	int line_num,
	int level, const char* msg, ...) {
	prepare_file();
	format_time();
	static char msg_meta_info[1024] = { 0 };
	static char msg_content[1024 * 10] = { 0 };
	static char new_line = '\n';
	// Unknown levels would index past g_log_level; clamp them to the last (ERROR).
	const int level_count = (int)(sizeof(g_log_level) / sizeof(g_log_level[0]));
	if (level < 0 || level >= level_count) {
		level = level_count - 1;
	}
	va_list args;
	va_start(args, msg);
	vsnprintf(msg_content, sizeof(msg_content), msg, args);
	va_end(args);
	// Bounded: file_name comes from the caller and may be arbitrarily long.
	snprintf(msg_meta_info, sizeof(msg_meta_info), "%s:%d ", file_name, line_num);
	uv_buf_t buf[6]; // time, level, file:line, newline, content, newline
	buf[0] = uv_buf_init(g_format_time, strlen(g_format_time));
	buf[1] = uv_buf_init(g_log_level[level], strlen(g_log_level[level]));
	buf[2] = uv_buf_init(msg_meta_info, strlen(msg_meta_info));
	buf[3] = uv_buf_init(&new_line, 1);
	buf[4] = uv_buf_init(msg_content, strlen(msg_content));
	buf[5] = uv_buf_init(&new_line, 1);
	uv_fs_t writeReq;
	// Synchronous write (NULL callback) at the current file offset (-1).
	int result = uv_fs_write(NULL, &writeReq, g_file_handle.result, buf, sizeof(buf) / sizeof(buf[0]), -1, NULL);
	if (result < 0) {
		fprintf(stderr, "log failed %s%s%s%s", g_format_time, g_log_level[level], msg_meta_info, msg_content);
	}
	uv_fs_req_cleanup(&writeReq);
	if (g_std_out) {
		printf("%s:%d\n[%s] %s\n", file_name, line_num, g_log_level[level], msg_content);
	}
}
使用
// Usage example. logger::init and timestamp2date take non-const char*, so
// writable arrays are passed instead of string literals.
char log_dir[] = "logger/gateway/";
char log_prefix[] = "gateway";
logger::init(log_dir, log_prefix, true);
// The timestamp functions return unsigned long, so print with %lu.
log_debug("%lu", timestamp());
log_debug("%lu", timestamp_today());
// date2timestamp expects the date as a string, not an integer literal.
log_debug("%lu", date2timestamp("%Y%m%d%H%M%S", "20240701000000"));
unsigned long yesterday = timestamp_yesterday();
char out_buf[64];
char date_fmt[] = "%Y-%m-%d %H:%M:%S";
timestamp2date(yesterday, date_fmt, out_buf, sizeof(out_buf));
log_debug("%s", out_buf);
注意设置工作目录,会在工作目录下生成logger文件夹
UDP服务
1: 为什么UDP没有粘包问题:
TCP是面向流的, 流, 要说明就像河水一样, 只要有水, 就会一直流向低处, 不会间断. TCP为了提高传输效率, 发送数据的时候, 并不是直接发送数据到网路, 而是先暂存到系统缓冲, 超过时间或者缓冲满了, 才把缓冲区的内容发送出去, 这样, 就可以有效提高发送效率. 所以会造成所谓的粘包, 即前一份Send的数据跟后一份Send的数据可能会暂存到缓冲当中, 然后一起发送.
UDP就不同了, 它是面向报文的: 每次Send都会生成一个独立的数据报(datagram), 系统不会把多次Send的数据合并在一起, 也不做类似Nagle算法那样的优化; Send的时候数据报直接发往网络, 对方收不收到也不管。接收端每次recv也恰好取出一个完整的数据报(UDP套接字同样有收发缓冲区, 但报文边界会被保留), 所以数据总是一包一包地接收到, 而不会出现前一个包跟后一个包粘在一起的情况。
udp_session.h和udp_session.cc
#ifndef __UDP_SESSION_H__
#define __UDP_SESSION_H__
// Session wrapper for one UDP peer: records where a datagram came from so a
// reply can be sent back through the same UDP handle.
// Public inheritance: a udp_session must be usable wherever a session* is
// expected, without C-style casts around a private base.
class udp_session : public session {
public:
	uv_udp_t* udp_handler; // UDP handle the datagram arrived on
	char c_address[32]; // peer IP address as text
	int c_port; // peer port, host byte order
	const struct sockaddr* addr; // peer address, used as the send destination
public:
	virtual void close();
	virtual void send_data(unsigned char* body, int len);
	virtual const char* get_address(int* client_port);
	virtual void send_msg(struct cmd_msg* msg);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "uv.h"
#include "session.h"
#include "udp_session.h"
#include "proto_man.h"
// UDP is connectionless and udp_handler is the shared server handle (see
// after_uv_udp_recv), so there is nothing to close per peer: intentionally a no-op.
void
udp_session::close() {
}
// Completion callback for udp_session::send_data. The request was malloc'ed
// there, so it is freed here whatever the outcome; status is not reported.
static void
on_uv_udp_send_end(uv_udp_send_t* req, int status) {
	(void)status;
	free(req);
}
// Sends body[0..len) as one datagram to this session's peer.
// uv_udp_send() is asynchronous and libuv does not copy the payload, while
// callers (e.g. send_msg) free body as soon as this returns. The payload is
// therefore copied into the same allocation as the request, right after it;
// on_uv_udp_send_end's free(req) releases both.
void
udp_session::send_data(unsigned char* body, int len) {
	if (len < 0 || (len > 0 && body == NULL)) {
		return;
	}
	uv_udp_send_t* req = (uv_udp_send_t*)malloc(sizeof(uv_udp_send_t) + (size_t)len);
	if (req == NULL) {
		return;
	}
	char* payload = (char*)(req + 1);
	if (len > 0) {
		memcpy(payload, body, (size_t)len);
	}
	uv_buf_t w_buf = uv_buf_init(payload, (unsigned int)len);
	int ret = uv_udp_send(req, this->udp_handler, &w_buf, 1, this->addr, on_uv_udp_send_end);
	if (ret != 0) {
		// Synchronous failure: the send callback will not run, so free here.
		free(req);
	}
}
// Returns the peer IP (text) and stores its port (host byte order) in *port.
const char*
udp_session::get_address(int* port) {
	const char* ip = this->c_address;
	*port = this->c_port;
	return ip;
}
// Encodes msg to the raw wire format, sends it to the peer, then frees the
// encoded buffer. Nothing is sent if encoding fails.
void
udp_session::send_msg(struct cmd_msg* msg) {
	int encode_len = 0;
	unsigned char* encode_pkg = proto_man::encode_msg_to_raw(msg, &encode_len);
	if (encode_pkg == NULL) {
		return;
	}
	this->send_data(encode_pkg, encode_len);
	proto_man::msg_raw_free(encode_pkg);
}
netbus.cc新增代码
// Transport-independent command entry point (TCP and UDP sessions alike):
// decodes the command and routes it by stype. An unhandled command closes the
// session; undecodable input is dropped.
static void
on_recv_client_cmd(session* s, unsigned char* body, int len) {
	// printf("client command !!!!\n");
	struct cmd_msg* msg = NULL;
	const bool decoded = proto_man::decode_cmd_msg(body, len, &msg);
	if (!decoded) {
		return;
	}
	if (!service_man::on_recv_cmd_msg(s, msg)) {
		s->close();
	}
	proto_man::cmd_msg_free(msg);
}
// Per-UDP-handle receive buffer, stored in the handle's data pointer and
// reused (and grown on demand) by udp_uv_alloc_buf.
struct udp_recv_buf {
char* recv_buf; // malloc'ed buffer, NULL until first allocation
size_t max_recv_len; // current capacity of recv_buf in bytes
};
// libuv alloc callback for the UDP server: hands out the handle's reusable
// receive buffer (at least 8096 bytes), growing it when libuv asks for more.
// On allocation failure it returns a NULL/0-length buffer, which libuv reports
// to the recv callback as UV_ENOBUFS. Capacity is recorded only for a buffer
// that really exists, so the next call retries the allocation.
static void
udp_uv_alloc_buf(uv_handle_t* handle,
	size_t suggested_size,
	uv_buf_t* buf) {
	if (suggested_size < 8096) {
		suggested_size = 8096;
	}
	struct udp_recv_buf* udp_buf = (struct udp_recv_buf*)handle->data;
	if (udp_buf->max_recv_len < suggested_size) {
		free(udp_buf->recv_buf);
		udp_buf->recv_buf = (char*)malloc(suggested_size);
		udp_buf->max_recv_len = (udp_buf->recv_buf != NULL) ? suggested_size : 0;
	}
	buf->base = udp_buf->recv_buf;
	buf->len = (udp_buf->recv_buf != NULL) ? suggested_size : 0;
}
after_uv_udp_recv(uv_udp_t* handle,
ssize_t nread,
const uv_buf_t* buf,
const struct sockaddr* addr,
unsigned flags) {
udp_session udp_s;
udp_s.udp_handler = handle;
udp_s.addr = addr;
uv_ip4_name((struct sockaddr_in*)addr, udp_s.c_address, 32);
udp_s.c_port = ntohs(((struct sockaddr_in*)addr)->sin_port);
on_recv_client_cmd((session*)&udp_s, (unsigned char*)buf->base, nread);
}
void
netbus::start_upd_server(int port){
uv_udp_t* server = (uv_udp_t*)malloc(sizeof(uv_udp_t));
memset(server, 0, sizeof(uv_udp_t));
uv_udp_init(uv_default_loop(), server);
struct udp_recv_buf* udp_buf = (struct udp_recv_buf*)malloc(sizeof(struct udp_recv_buf));
memset(udp_buf, 0, sizeof(struct udp_recv_buf));
server->data = (struct udp_recv_buf*) udp_buf;
struct sockaddr_in addr;
uv_ip4_addr("0.0.0.0", port, &addr);
uv_udp_bind(server, (const struct sockaddr*)&addr, 0);
uv_udp_recv_start(server, udp_uv_alloc_buf, after_uv_udp_recv);
}
开启服务器
// Start a UDP server listening on port 8002 (see netbus::start_upd_server).
netbus::instance()->start_upd_server(8002);
异步mysql模块
1: mysql client 依赖库与开发环境;
下载地址: https://dev.mysql.com/downloads/connector/c/
2: 从32bit与64bit的zip包中,解压出include, lib文件, 配置好x86与x64的环境;
3:将mysql client的配置文件放到build路径下;
4: 添加头文件搜索路径到编译器;
5: 添加库文件搜索路径到编译器;
6: mysql 库都是同步的,等待服务器返回结果,所以使用工作队列不阻塞主线程;
7: 工作队列在处理mysql的时候,由于同一个mysql连接(MYSQL句柄)不是线程安全的,所以加锁来保证同步,同一个连接每次只处理一个请求;
mysql_wrapper.h和mysql_wrapper.cc
#ifndef __MYSQL_WRAPPER_H__
#define __MYSQL_WRAPPER_H__
#include "mysql.h"
// Asynchronous wrapper around the blocking MySQL client API: each call is run
// through uv_queue_work on the libuv thread pool, and its callback is invoked
// from the completion handler on the loop.
class mysql_wrapper {
public:
	// Connects to ip:port/db_name. open_cb receives an error string (NULL on
	// success) and an opaque context to pass to query()/close().
	// udata is handed back unchanged to open_cb (used later to carry a Lua
	// callback handle: C++ pushes the arguments and calls that handle).
	static void connect(char* ip, int port,
		char* db_name, char* uname, char* pwd,
		void(*open_cb)(const char* err, void* context, void* udata), void* udata = NULL);
	// Closes the connection asynchronously; further queries on it are ignored.
	static void close(void* context);
	// Runs sql on the connection. query_cb receives an error string (NULL on
	// success) and the result set; udata is handed back unchanged (Lua handle).
	static void query(void* context,
		char* sql,
		void(*query_cb)(const char* err, MYSQL_RES* result, void* udata), void* udata = NULL);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "uv.h"
#ifdef WIN32
#include <winsock.h>
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "libmysql.lib")
#endif
#include "mysql.h"
#include "mysql_wrapper.h"
// Allocation hooks; plain malloc/free.
#define my_malloc malloc
#define my_free free
// State for one asynchronous connect. The strings are strdup'ed copies owned
// by the request; err/context are filled in by the worker and given to open_cb.
struct connect_req {
char* ip;
int port;
char* db_name;
char* uname;
char* upwd;
void(*open_cb)(const char* err, void* context, void* udata);
char* err; // strdup'ed error text, or NULL on success
void* context; // struct mysql_context* on success, NULL on failure
void* udata; // user pointer passed back to open_cb
};
// Connection object handed to users as the opaque "context".
struct mysql_context {
void* pConn; // MYSQL* connection
uv_mutex_t lock; // serializes use of pConn across worker threads
int is_closed; // set by mysql_wrapper::close(); later queries are ignored
};
// Worker-thread part of mysql_wrapper::connect: opens the connection and
// stores either a new mysql_context or an error string in the request.
static void
connect_work(uv_work_t* req) {
	struct connect_req* r = (struct connect_req*)req->data;
	MYSQL* pConn = mysql_init(NULL);
	if (pConn == NULL) {
		r->context = NULL;
		r->err = strdup("mysql_init failed: out of memory");
		return;
	}
	if (mysql_real_connect(pConn, r->ip, r->uname, r->upwd, r->db_name, r->port, NULL, 0)) {
		struct mysql_context* c = (struct mysql_context*)my_malloc(sizeof(struct mysql_context));
		if (c == NULL) {
			mysql_close(pConn);
			r->context = NULL;
			r->err = strdup("out of memory");
			return;
		}
		memset(c, 0, sizeof(struct mysql_context));
		c->pConn = pConn;
		uv_mutex_init(&c->lock);
		r->context = c;
		r->err = NULL;
		mysql_set_character_set(pConn, "utf8");
	}
	else {
		r->context = NULL;
		r->err = strdup(mysql_error(pConn));
		// A failed connect still owns the handle allocated by mysql_init().
		mysql_close(pConn);
	}
}
// Loop-thread completion of a connect: reports to the user, then releases the
// request and all of its strdup'ed strings (free(NULL) is a no-op).
static void
on_connect_complete(uv_work_t* req, int status) {
	struct connect_req* r = (struct connect_req*)req->data;
	r->open_cb(r->err, r->context, r->udata);
	free(r->ip);
	free(r->db_name);
	free(r->uname);
	free(r->upwd);
	free(r->err);
	my_free(r);
	my_free(req);
}
// Queues an asynchronous connect. The request keeps private copies of every
// string, because the worker thread runs after this call has returned.
void
mysql_wrapper::connect(char* ip, int port,
	char* db_name, char* uname, char* pwd,
	void(*open_cb)(const char* err, void* context, void* udata), void* udata) {
	struct connect_req* r = (struct connect_req*)my_malloc(sizeof(struct connect_req));
	memset(r, 0, sizeof(struct connect_req));
	r->ip = strdup(ip);
	r->port = port;
	r->db_name = strdup(db_name);
	r->uname = strdup(uname);
	r->upwd = strdup(pwd);
	r->open_cb = open_cb;
	r->udata = udata;
	uv_work_t* w = (uv_work_t*)my_malloc(sizeof(uv_work_t));
	memset(w, 0, sizeof(uv_work_t));
	w->data = (void*)r;
	uv_queue_work(uv_default_loop(), w, connect_work, on_connect_complete);
}
// Worker-thread close: takes the connection lock (also held by query_work) so
// the close cannot run in the middle of a query.
static void
close_work(uv_work_t* req) {
	struct mysql_context* ctx = (struct mysql_context*)req->data;
	uv_mutex_lock(&ctx->lock);
	mysql_close((MYSQL*)ctx->pConn);
	uv_mutex_unlock(&ctx->lock);
}
// Loop-thread completion of a close: releases the context and the work request.
static void
on_close_complete(uv_work_t* req, int status) {
	struct mysql_context* r = (struct mysql_context*)(req->data);
	// The lock was created with uv_mutex_init() in connect_work; destroy it
	// before the memory goes away.
	uv_mutex_destroy(&r->lock);
	my_free(r);
	my_free(req);
}
// Queues an asynchronous close of the connection. Repeated calls, and a NULL
// context (what open_cb receives after a failed connect), are ignored.
void
mysql_wrapper::close(void* context) {
	struct mysql_context* c = (struct mysql_context*) context;
	if (c == NULL || c->is_closed) {
		return;
	}
	uv_work_t* w = (uv_work_t*)my_malloc(sizeof(uv_work_t));
	if (w == NULL) {
		return;
	}
	memset(w, 0, sizeof(uv_work_t));
	w->data = (context);
	c->is_closed = 1;
	uv_queue_work(uv_default_loop(), w, close_work, on_close_complete);
}
// State for one asynchronous MySQL query; sql, err and result are released in
// on_query_complete after query_cb returns.
struct query_req {
void* context; // struct mysql_context* the query runs on
char* sql; // strdup'ed SQL text
void(*query_cb)(const char* err, MYSQL_RES* result, void* udata);
char* err; // strdup'ed error text, or NULL on success
MYSQL_RES* result; // stored result set from mysql_store_result
void* udata; // user pointer passed back to query_cb
};
// Worker-thread query: runs the SQL and stores either the result set or an
// error string. A MySQL connection must not be used concurrently, so the whole
// exchange happens under the context lock.
static void
query_work(uv_work_t* req) {
	query_req* r = (query_req*)req->data;
	struct mysql_context* my_conn = (struct mysql_context*)(r->context);
	uv_mutex_lock(&my_conn->lock);
	MYSQL* pConn = (MYSQL*)my_conn->pConn;
	if (mysql_query(pConn, r->sql) == 0) {
		r->err = NULL;
		r->result = mysql_store_result(pConn);
	}
	else {
		r->err = strdup(mysql_error(pConn));
		r->result = NULL;
	}
	uv_mutex_unlock(&my_conn->lock);
}
// Loop-thread completion of a query: reports to the caller, then releases
// everything the request owns, including the result set.
static void
on_query_complete(uv_work_t* req, int status) {
	query_req* r = (query_req*)req->data;
	r->query_cb(r->err, r->result, r->udata);
	free(r->sql);
	if (r->result != NULL) {
		mysql_free_result(r->result);
		r->result = NULL;
	}
	free(r->err);
	my_free(r);
	my_free(req);
}
void
mysql_wrapper::query(void* context,
char* sql,
void(*query_cb)(const char* err, MYSQL_RES* result, void* udata),
void* udata) {
struct mysql_context* c = (struct mysql_context*) context;
if (c->is_closed) {
return;
}
uv_work_t* w = (uv_work_t*)my_malloc(sizeof(uv_work_t));
memset(w, 0, sizeof(uv_work_t));
query_req* r = (query_req*)my_malloc(sizeof(query_req));
memset(r, 0, sizeof(query_req));
r->context = context;
r->sql = strdup(sql);
r->query_cb = query_cb;
r->udata = udata;
w->data = r;
uv_queue_work(uv_default_loop(), w, query_work, on_query_complete);
}
使用
static void
on_query_cb(const char* err, std::vector<std::vector<std::string>>* result) {
if (err) {
printf("err");
return;
}
printf("success");
}
static void
on_open_cb(const char* err, void* context) {
if (err != NULL) {
printf("%s\n", err);
return;
}
printf("connect success");
// mysql_wrapper::query(context, "update class_test set name = \"blake haha\" where id = 8", on_query_cb);
mysql_wrapper::query(context, "select * from class_test", on_query_cb);
// mysql_wrapper::close(context);
}
mysql_wrapper::connect("127.0.0.1", 3306, "class_sql", "root", "123456", on_open_cb);
异步redis模块
1: 准备好redis client库;
hiredis: include头文件, x64_libs, x86_libs;
2: 编译器配置:
#include <hiredis.h>
#define NO_QFORKIMPL //这一行必须加才能正常使用
#include <Win32_Interop/win32fixes.h>
#pragma comment(lib,"hiredis.lib")
#pragma comment(lib,"Win32_Interop.lib")
(1)添加库搜索路径
(2)添加头文件搜索路径
(3)将win32fixes.c 加入到工程编译;
(4)win32fixes.h之前加上#define NO_QFORKIMPL
(5)右击项目->属性->配置属性->C/C++->代码生成->运行库->改成多线程调试(/MTd)或多线程(/MT)
(6)右击项目->属性->配置属性->链接器->命令行中输入/NODEFAULTLIB:libcmt.lib
(7)右击项目->属性->配置属性->C/C++->预处理器->预处理器定义->添加“_CRT_SECURE_NO_WARNINGS”
redis_wrapper.h和redis_wrapper.cc
#ifndef __REDIS_WRAPPER_H__
#define __REDIS_WRAPPER_H__
#include <hiredis.h>
// Asynchronous wrapper around the blocking hiredis client: each call is run
// through uv_queue_work on the libuv thread pool, and its callback is invoked
// from the completion handler on the loop.
class redis_wrapper {
public:
	// Connects to ip:port. open_cb receives an error string (NULL on success)
	// and an opaque context. udata is handed back unchanged (used later to
	// carry a Lua callback handle).
	static void connect(char* ip, int port,
		void(*open_cb)(const char* err, void* context, void* udata), void* udata = NULL);
	// Closes the connection asynchronously.
	static void close_redis(void* context);
	// Runs a redis command. query_cb receives an error string (NULL on success)
	// and the reply; udata is handed back unchanged (Lua handle).
	static void query(void* context,
		char* cmd,
		void(*query_cb)(const char* err, redisReply* result, void* udata), void* udata = NULL);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <hiredis.h>
#ifdef WIN32
#define NO_QFORKIMPL //这一行必须加才能正常使用
#include <Win32_Interop/win32fixes.h>
#pragma comment(lib,"hiredis.lib")
#pragma comment(lib,"Win32_Interop.lib")
#endif
#include "uv.h"
#include "redis_wrapper.h"
// Allocation hooks; plain malloc/free.
#define my_malloc malloc
#define my_free free
// State for one asynchronous redis connect; ip is a strdup'ed copy owned by the request.
struct connect_req {
char* ip;
int port;
void(*open_cb)(const char* err, void* context, void* udata);
char* err; // strdup'ed error text, or NULL on success
void* context; // struct redis_context* on success, NULL on failure
void* udata; // user pointer passed back to open_cb
};
// Connection object handed to users as the opaque "context".
struct redis_context {
void* pConn; // redisContext* connection
uv_mutex_t lock; // serializes use of pConn across worker threads
int is_closed; // set by redis_wrapper::close_redis()
};
// Worker-thread part of redis_wrapper::connect: connects with a 5 s timeout and
// stores either a new redis_context or an error string in the request.
static void
connect_work(uv_work_t* req) {
	struct connect_req* r = (struct connect_req*)req->data;
	struct timeval timeout = { 5, 0 }; // 5 seconds
	redisContext* rc = redisConnectWithTimeout((char*)r->ip, r->port, timeout);
	if (rc == NULL) {
		// hiredis returns NULL when it cannot even allocate the context.
		printf("Connection error: can't allocate redis context\n");
		r->err = strdup("can't allocate redis context");
		r->context = NULL;
		return;
	}
	if (rc->err) {
		printf("Connection error: %s\n", rc->errstr);
		r->err = strdup(rc->errstr);
		r->context = NULL;
		redisFree(rc);
		return;
	}
	struct redis_context* c = (struct redis_context*)my_malloc(sizeof(struct redis_context));
	if (c == NULL) {
		redisFree(rc);
		r->err = strdup("out of memory");
		r->context = NULL;
		return;
	}
	memset(c, 0, sizeof(struct redis_context));
	c->pConn = rc;
	uv_mutex_init(&c->lock);
	r->err = NULL;
	r->context = c;
}
// Loop-side completion for redis_wrapper::connect(): reports the outcome to the user
// callback, then releases everything the request owned. The error message is freed
// right after the callback returns; the context stays alive for the caller.
static void
on_connect_complete(uv_work_t* req, int status) {
	struct connect_req* creq = (struct connect_req*)req->data;
	creq->open_cb(creq->err, creq->context, creq->udata);
	// Both strings come from strdup(); free(NULL) is a no-op, so no guards are needed.
	free(creq->ip);
	free(creq->err);
	my_free(creq);
	my_free(req);
}
void
redis_wrapper::connect(char* ip, int port,
void(*open_cb)(const char* err, void* context, void* udata), void* udata) {
uv_work_t* w = (uv_work_t*)my_malloc(sizeof(uv_work_t));
memset(w, 0, sizeof(uv_work_t));
struct connect_req* r = (struct connect_req*)my_malloc(sizeof(struct connect_req));
memset(r, 0, sizeof(struct connect_req));
r->ip = strdup(ip);
r->port = port;
r->open_cb = open_cb;
r->udata = udata;
w->data = (void*) r;
uv_queue_work(uv_default_loop(), w, connect_work, on_connect_complete);
}
// Thread-pool worker for close_redis(): releases the hiredis connection while holding the
// context lock and clears pConn so that later readers can see it is gone.
static void
close_work(uv_work_t* req) {
	struct redis_context* ctx = (struct redis_context*)(req->data);
	uv_mutex_lock(&ctx->lock);
	redisFree((redisContext*)ctx->pConn);
	ctx->pConn = NULL;
	uv_mutex_unlock(&ctx->lock);
}
// Loop-side completion for close_redis(). close_work has already freed the connection,
// so only the lock and the context memory itself remain to be released.
static void
on_close_complete(uv_work_t* req, int status) {
	struct redis_context* r = (struct redis_context*)(req->data);
	// The lock was created with uv_mutex_init() in connect_work; destroy it before freeing
	// the memory that holds it, otherwise the OS lock object leaks.
	uv_mutex_destroy(&r->lock);
	// NOTE(review): a query queued before close_redis() may still be pending in the thread
	// pool and would then touch freed memory; do not close a context with queries in flight.
	my_free(r);
	my_free(req);
}
void
redis_wrapper::close_redis(void* context) {
struct redis_context* c = (struct redis_context*) context;
if (c->is_closed) {
return;
}
uv_work_t* w = (uv_work_t*)my_malloc(sizeof(uv_work_t));
memset(w, 0, sizeof(uv_work_t));
w->data = (context);
c->is_closed = 1;
uv_queue_work(uv_default_loop(), w, close_work, on_close_complete);
}
// Per-request state for redis_wrapper::query(). It is filled by query(), completed by
// query_work on a thread-pool thread, and consumed by on_query_complete.
struct query_req {
void* context;      // struct redis_context* the command runs on
char* cmd;          // strdup'ed copy of the command text; freed in on_query_complete
void(*query_cb)(const char* err, redisReply* result, void* udata);
char* err;          // strdup'ed error message, or NULL on success
redisReply* result; // reply on success (freed after query_cb returns), otherwise NULL
void* udata;        // opaque caller data forwarded to query_cb
};
// Copy cmd with every '%' doubled, so that hiredis' printf-style redisCommand() treats the
// text literally (the command comes from scripts and must not act as a format string).
// Returns a malloc'ed string, or NULL if the allocation fails.
static char*
escape_redis_format(const char* cmd) {
	size_t extra = 0;
	for (const char* p = cmd; *p; ++p) {
		if (*p == '%') {
			++extra;
		}
	}
	char* out = (char*)malloc(strlen(cmd) + extra + 1);
	if (out == NULL) {
		return NULL;
	}
	char* w = out;
	for (const char* p = cmd; *p; ++p) {
		*w++ = *p;
		if (*p == '%') {
			*w++ = '%';
		}
	}
	*w = '\0';
	return out;
}

// Thread-pool worker for redis_wrapper::query(): runs the command synchronously on the
// context's connection while holding the context lock. On success r->result holds the
// reply; otherwise r->err holds a strdup'ed message.
static void
query_work(uv_work_t* req) {
	query_req* r = (query_req*)req->data;
	struct redis_context* my_conn = (struct redis_context*)(r->context);
	r->result = NULL;
	r->err = NULL;

	uv_mutex_lock(&my_conn->lock);
	// Read pConn only under the lock: close_work frees it and sets it to NULL under the same lock.
	redisContext* rc = (redisContext*)my_conn->pConn;
	if (rc == NULL) {
		r->err = strdup("redis connection is closed");
		uv_mutex_unlock(&my_conn->lock);
		return;
	}

	char* fmt = escape_redis_format(r->cmd);
	if (fmt == NULL) {
		r->err = strdup("out of memory");
	}
	else {
		redisReply* reply = (redisReply*)redisCommand(rc, fmt);
		free(fmt);
		if (reply == NULL) {
			// hiredis returns NULL on I/O or protocol errors; the details are in rc->errstr.
			r->err = strdup(rc->errstr[0] ? rc->errstr : "redis command failed");
		}
		else if (reply->type == REDIS_REPLY_ERROR) {
			r->err = strdup(reply->str);
			freeReplyObject(reply);
		}
		else {
			r->result = reply;
		}
	}
	uv_mutex_unlock(&my_conn->lock);
}
// Loop-side completion for redis_wrapper::query(): delivers the reply (or the error) to the
// user callback, then frees the reply, the copied command, the error text and the request.
// The callback must not keep the redisReply pointer: it is freed as soon as the callback returns.
static void
on_query_complete(uv_work_t* req, int status) {
	query_req* qreq = (query_req*)req->data;
	qreq->query_cb(qreq->err, qreq->result, qreq->udata);
	free(qreq->cmd);
	if (qreq->result != NULL) {
		freeReplyObject(qreq->result);
	}
	free(qreq->err);
	my_free(qreq);
	my_free(req);
}
void
redis_wrapper::query(void* context,
char* cmd,
void(*query_cb)(const char* err, redisReply* result, void* udata), void* udata) {
struct redis_context* c = (struct redis_context*) context;
if (c->is_closed) {
return;
}
uv_work_t* w = (uv_work_t*)my_malloc(sizeof(uv_work_t));
memset(w, 0, sizeof(uv_work_t));
query_req* r = (query_req*)my_malloc(sizeof(query_req));
memset(r, 0, sizeof(query_req));
r->context = context;
r->cmd = strdup(cmd);
r->query_cb = query_cb;
r->udata = udata;
w->data = r;
uv_queue_work(uv_default_loop(), w, query_work, on_query_complete);
}
使用
// Example query callback. Its signature must match the query_cb type expected by
// redis_wrapper::query: (const char* err, redisReply* result, void* udata).
static void
on_redis_query(const char* err, redisReply* result, void* udata) {
	if (err) {
		printf("%s\n", err);
		return;
	}
	printf("success\n");
}
// Example connect callback. Its signature must match the open_cb type expected by
// redis_wrapper::connect: (const char* err, void* context, void* udata).
static void
on_redis_open(const char* err, void* context, void* udata) {
	if (err != NULL) {
		printf("%s\n", err);
		return;
	}
	printf("connect success\n");
	// query() takes a non-const char*, which a string literal cannot bind to in standard C++;
	// query() copies the text, so a local buffer is fine.
	char cmd[] = "select 1";
	redis_wrapper::query(context, cmd, on_redis_query);
	// redis_wrapper::close_redis(context);
}
// connect() takes a non-const char*, which a string literal cannot bind to in standard C++;
// connect() copies the text, so a local buffer is fine.
char redis_ip[] = "127.0.0.1";
redis_wrapper::connect(redis_ip, 6379, on_redis_open);
内置lua脚本解释
使用lua开发效率更高,新手上手快,因此很多公司的Service层都使用lua作为脚本语言编程,遇到需要性能的部分也可以通过c/c++开发,比较灵活,因此服务器框架需要兼容lua脚本。
为了能够兼容lua脚本,需要在c/c++框架中嵌入一个lua解释器(lua虚拟机),由它解释执行lua代码。
步骤:
1: 下载Lua源码:www.lua.org,版本lua5.3
2: 将Lua代码(src文件夹下)拉入框架进行编译(去掉lua.c和luac.c,这两个文件各自带有main函数;lua.h是必需的头文件,不能去掉)
3: 新建lua_wrapper文件夹, 编写相关的代码融合Lua与框架;
lua_wrapper.h和lua_wrapper.cc
#ifndef __LUA_WRAPER_H__
#define __LUA_WRAPER_H__
#include "lua.hpp"
// Static facade around the single process-wide Lua VM (g_lua_State in lua_wrapper.cc).
class lua_wrapper {
public:
// Create the VM and open the standard Lua libraries.
static void init();
// Close the VM if it exists.
static void exit();
// Load and run a Lua file; returns false if loading or running it failed.
static bool exe_lua_file(const char* lua_file);
};
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "lua_wrapper.h"
lua_State* g_lua_State = NULL;
void
lua_wrapper::init() {
g_lua_State = luaL_newstate();
luaL_openlibs(g_lua_State);
}
// Destroy the VM created by init(); a no-op if there is none.
void
lua_wrapper::exit() {
	if (g_lua_State == NULL) {
		return;
	}
	lua_close(g_lua_State);
	g_lua_State = NULL;
}
bool
lua_wrapper::exe_lua_file(const char* lua_file) {
if (luaL_dofile(g_lua_State, lua_file)) {
return false;
}
return true;
}
使用
// Create the VM, then run the entry script (the path is relative to the working directory).
lua_wrapper::init();
// exe_lua_file returns false if ./main.lua fails to load or run.
lua_wrapper::exe_lua_file("./main.lua");
日志函数导出
1: Lua 只能调用 int (*lua_CFunction) (lua_State *L) 这种类型的C 函数, 所有的函数如果要给Lua
调用,只能用这样的函数来封装;
2: 怎么获得Lua传递过来的参数? 通过操作Lua 虚拟机的栈;
lua_wrapper.cc
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "../utils/logger.h"
#include "lua_wrapper.h"
lua_State* g_lua_State = NULL;
static void
print_error(const char* file_name, int line_num, const char* msg) {
logger::log(file_name, line_num, ERROR, msg);
}
// Log msg at WARNING level, attributed to file_name:line_num.
static void
print_warning(const char* file_name, int line_num, const char* msg) {
	// msg comes from Lua scripts: pass it as an argument, never as the format string.
	logger::log(file_name, line_num, WARNING, "%s", msg);
}
// Log msg at DEBUG level, attributed to file_name:line_num.
static void
print_debug(const char* file_name, int line_num, const char* msg) {
	// msg comes from Lua scripts: pass it as an argument, never as the format string.
	logger::log(file_name, line_num, DEBUG, "%s", msg);
}
// Log msg through `log`, tagged with the file and line of the nearest active Lua frame
// whose chunk was loaded from a file (lua_Debug::source starting with '@').
// If there is no such frame (no Lua code running, or the code came from a string chunk),
// the message is still logged, under "trunk":0.
static void
do_log_message(void(*log)(const char* file_name, int line_num, const char* msg), const char* msg) {
	lua_Debug info;
	for (int depth = 0; lua_getstack(g_lua_State, depth, &info); ++depth) {
		// "S" fills info.source, "l" fills info.currentline.
		lua_getinfo(g_lua_State, "Sl", &info);
		if (info.source[0] == '@') {
			log(&info.source[1], info.currentline, msg); // skip the leading '@'
			return;
		}
	}
	// Reached both when no frame exists and when no frame comes from a file.
	log("trunk", 0, msg);
}
// Lua binding log_debug(msg): logs the value on top of the stack (the last argument) at
// debug level, tagged with the calling script's file and line.
static int
lua_log_debug(lua_State *L) {
	// luaL_checkstring raises a Lua error for values that are neither strings nor numbers.
	if (const char* text = luaL_checkstring(L, -1)) {
		do_log_message(print_debug, text);
	}
	return 0; // no results for Lua
}
// Lua binding log_warning(msg): logs the value on top of the stack (the last argument) at
// warning level, tagged with the calling script's file and line.
static int
lua_log_warning(lua_State *L) {
	// luaL_checkstring reads the string on top of the Lua stack and raises a Lua error if
	// the value there is not a string (or a number convertible to one).
	if (const char* text = luaL_checkstring(L, -1)) {
		do_log_message(print_warning, text);
	}
	return 0; // no results for Lua
}
// Lua binding log_error(msg): logs the value on top of the stack (the last argument) at
// error level, tagged with the calling script's file and line.
static int
lua_log_error(lua_State *L) {
	if (const char* text = luaL_checkstring(L, -1)) {
		do_log_message(print_error, text);
	}
	return 0; // no results for Lua
}
// Panic handler installed with lua_atpanic(). Lua calls it for errors raised outside any
// protected call, and it writes the error object to the log. Note: when a panic handler
// returns, Lua still aborts the process, so this makes the failure visible in the log but
// does NOT keep the process alive.
static int
lua_panic(lua_State *L) {
	// Raising another error here would be fatal (luaL_checkstring raises on non-strings),
	// so read the error object leniently.
	const char* msg = lua_tostring(L, -1);
	do_log_message(print_error, msg ? msg : "unprotected error in call to Lua API (error object is not a string)");
	return 0;
}
void
lua_wrapper::init() {
g_lua_State = luaL_newstate();
lua_atpanic(g_lua_State, lua_panic); // default abort;
luaL_openlibs(g_lua_State);
// export log
lua_wrapper::reg_func2lua("log_error", lua_log_error);
lua_wrapper::reg_func2lua("log_debug", lua_log_debug);
lua_wrapper::reg_func2lua("log_warning", lua_log_warning);
// end
}
// Tear down the VM created by init(); calling it again afterwards does nothing.
void
lua_wrapper::exit() {
	if (g_lua_State) {
		lua_close(g_lua_State);
		g_lua_State = NULL; // guards against a double close
	}
}
bool
lua_wrapper::exe_lua_file(const char* lua_file) {
if (luaL_dofile(g_lua_State, lua_file)) {
lua_log_error(g_lua_State);
return false;
}
return true;
}
void
lua_wrapper::reg_func2lua(const char* name, int(*c_func)(lua_State *L)) {
// 函数压入栈中,g_lua_State是全局的Lua状态指针
lua_pushcfunction(g_lua_State, c_func);
// 将栈顶的C函数弹出并将其存储在全局变量中,变量名为name。这样,Lua脚本就可以通过name来调用这个C函数
lua_setglobal(g_lua_State, name);
}
这样在main.lua中就可以调用log_debug("HelloWorld")了
拓展log支持打印不定参数
debug函数改为如下实现,warning和error同理:
// Lua binding log_debug(...): joins every argument with tabs (similar to print) and logs
// the line at debug level, tagged with the caller's file and line.
static int
lua_log_debug(lua_State *luastate) {
	const int argc = lua_gettop(luastate);
	std::string line;
	for (int idx = 1; idx <= argc; ++idx) {
		if (idx > 1) {
			line += "\t";
		}
		switch (lua_type(luastate, idx)) {
		case LUA_TTABLE:         line += "table"; break;
		case LUA_TNONE:          line += "none"; break;
		case LUA_TNIL:           line += "nil"; break;
		case LUA_TBOOLEAN:       line += lua_toboolean(luastate, idx) ? "true" : "false"; break;
		case LUA_TFUNCTION:      line += "function"; break;
		case LUA_TLIGHTUSERDATA: line += "lightuserdata"; break;
		case LUA_TTHREAD:        line += "thread"; break;
		default: {
			// Numbers and strings convert; anything else (full userdata) falls back to its type name.
			const char* text = lua_tostring(luastate, idx);
			line += text ? text : lua_typename(luastate, lua_type(luastate, idx));
			break;
		}
		}
	}
	do_log_message(print_debug, line.c_str());
	return 0;
}
tolua++模块导出
1: Lua提供了相应的机制,能把C/C++ 接口,导出给lua使用。
2: 这些接口都需要手动的导出注册,才能给Lua调用;
3: tolua++ 封装了Lua的一些接口,方便我们来做C/C++的代码导出;
4: tolua ++ 提供了一个工具,能够自动生成Lua导出代码,我们只要做好对应的导出配置文件;
5: tolua支持的版本是Lua5.1, Lua5.3不能直接支持,需要打上一些补丁,网络上有实现:
https://aur.archlinux.org/cgit/aur.git/tree/tolua53.patch?h=tolua%2B%2B_5.3
6: 我们编译使用tolua库,使用它的封装,来自己手动导出代码接口,服务器要导出的不多;
7: 使用tolua++, 参考了quick-cocos 3.6架构;
1: 我们的服务器会有几个模块,每个模块我们会导出Lua接口给脚本使用;
2: 每个模块编写一个.h与.cc, 命名规则: 模块名字_export_to_lua, mysql_export_to_lua.h
3: 编写模块注册函数: register_模块名字_export, 在初始化的时候调用;
lua_wrapper修改
#ifndef __LUA_WRAPER_H__
#define __LUA_WRAPER_H__
#include "lua.hpp"
// Static facade around the process-wide Lua VM used by the server.
class lua_wrapper {
public:
// Create the VM and register the C functions/modules exposed to scripts.
static void init();
static void exit();
// Load and run a Lua file; returns false on failure.
static bool exe_lua_file(const char* lua_file);
// The VM, for binding code such as the *_export_to_lua modules.
static lua_State* lua_state();
public:
// Expose a C function to Lua as the global `name`.
static void reg_func2lua(const char* name, int(*c_func)(lua_State *L));
public:
// Call the Lua function referenced by the handle nHandler (received earlier from Lua),
// passing the numArgs values already pushed on the stack as its arguments.
static int execute_script_handler(int nHandler, int numArgs);
// Release the handle nHandler.
static void remove_script_handler(int nHandler);
};
#endif
.cc文件修改或新增,主要用于支持tolua++,用handle执行lua脚本的函数
void
lua_wrapper::init() {
g_lua_State = luaL_newstate();
lua_atpanic(g_lua_State, lua_panic); // default abort;
luaL_openlibs(g_lua_State);
toluafix_open(g_lua_State);
// export log
lua_wrapper::reg_func2lua("log_error", lua_log_error);
lua_wrapper::reg_func2lua("log_debug", lua_log_debug);
lua_wrapper::reg_func2lua("log_warning", lua_log_warning);
// end
}
// Push the Lua function referenced by the tolua_fix ref id nHandler.
// Returns false, leaving the stack unchanged, if the ref does not resolve to a function.
// (The return type was missing: a function definition without one is ill-formed C++.)
static bool
pushFunctionByHandler(int nHandler)
{
	toluafix_get_function_by_refid(g_lua_State, nHandler); /* L: ... func */
	if (!lua_isfunction(g_lua_State, -1))
	{
		log_error("[LUA ERROR] function refid '%d' does not reference a Lua function", nHandler);
		lua_pop(g_lua_State, 1);
		return false;
	}
	return true;
}
static int
executeFunction(int numArgs)
{
int functionIndex = -(numArgs + 1);
if (!lua_isfunction(g_lua_State, functionIndex))
{
log_error("value at stack [%d] is not function", functionIndex);
lua_pop(g_lua_State, numArgs + 1); // remove function and arguments
return 0;
}
int traceback = 0;
lua_getglobal(g_lua_State, "__G__TRACKBACK__"); /* L: ... func arg1 arg2 ... G */
if (!lua_isfunction(g_lua_State, -1))
{
lua_pop(g_lua_State, 1); /* L: ... func arg1 arg2 ... */
}
else
{
lua_insert(g_lua_State, functionIndex - 1); /* L: ... G func arg1 arg2 ... */
traceback = functionIndex - 1;
}
int error = 0;
error = lua_pcall(g_lua_State, numArgs, 1, traceback); /* L: ... [G] ret */
if (error)
{
if (traceback == 0)
{
log_error("[LUA ERROR] %s", lua_tostring(g_lua_State, -1)); /* L: ... error */
lua_pop(g_lua_State, 1); // remove error message from stack
}
else /* L: ... G error */
{
lua_pop(g_lua_State, 2); // remove __G__TRACKBACK__ and error message from stack
}
return 0;
}
// get return value
int ret = 0;
if (lua_isnumber(g_lua_State, -1))
{
ret = (int)lua_tointeger(g_lua_State, -1);
}
else if (lua_isboolean(g_lua_State, -1))
{
ret = (int)lua_toboolean(g_lua_State, -1);
}
// remove return value from stack
lua_pop(g_lua_State, 1); /* L: ... [G] */
if (traceback)
{
lua_pop(g_lua_State, 1); // remove __G__TRACKBACK__ from stack /* L: ... */
}
return ret;
}
// Call the Lua function referenced by nHandler with the numArgs values currently on top of
// the stack as its arguments. Returns executeFunction's result (0 if the handler is missing).
// The whole Lua stack is cleared afterwards.
int
lua_wrapper::execute_script_handler(int nHandler, int numArgs) {
	int result = 0;
	/* L: ... arg1 .. argN */
	if (pushFunctionByHandler(nHandler)) { /* L: ... arg1 .. argN func */
		if (numArgs > 0) {
			// Move the function below its arguments: L: ... func arg1 .. argN
			lua_insert(g_lua_State, -(numArgs + 1));
		}
		result = executeFunction(numArgs);
	}
	// Drop whatever is left, including the arguments when the handler was not found.
	lua_settop(g_lua_State, 0);
	return result;
}
void
lua_wrapper::remove_script_handler(int nHandler)
{
toluafix_remove_function_by_refid(g_lua_State, nHandler);
}
mysql模块导出和使用
mysql_export_to_lua.h和mysql_export_to_lua.cc
#ifndef __MYSQL_EXPORT_TO_LUA_H__
#define __MYSQL_EXPORT_TO_LUA_H__
// Forward declaration, so that including this header does not pull in the Lua headers.
struct lua_State;
// Register the global Lua module `mysql_wrapper` (connect / close / query). Returns 0.
int register_mysql_export(lua_State* L);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "lua_wrapper.h"
#include "../database/mysql_wrapper.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "tolua_fix.h"
static void
on_open_cb(const char* err, void* context, void* udata) {
if (err) {
lua_pushstring(lua_wrapper::lua_state(), err);
lua_pushnil(lua_wrapper::lua_state());
}
else {
lua_pushnil(lua_wrapper::lua_state());
tolua_pushuserdata(lua_wrapper::lua_state(), context);
}
lua_wrapper::execute_script_handler((int)udata, 2);
lua_wrapper::remove_script_handler((int)udata);
}
// Lua binding: mysql_wrapper.connect(ip, port, db_name, uname, upwd, function(err, context) end).
// Returns nothing to Lua; a missing string argument makes the call a silent no-op.
// Early returns replace the former `goto lua_failed`, which jumped over initialised locals
// (ill-formed C++, rejected by GCC/Clang).
static int
lua_mysql_connect(lua_State* tolua_S) {
	char* ip = (char*)tolua_tostring(tolua_S, 1, 0);
	if (ip == NULL) {
		return 0;
	}
	int port = (int)tolua_tonumber(tolua_S, 2, 0);
	char* db_name = (char*)tolua_tostring(tolua_S, 3, 0);
	if (db_name == NULL) {
		return 0;
	}
	char* uname = (char*)tolua_tostring(tolua_S, 4, 0);
	if (uname == NULL) {
		return 0;
	}
	char* upwd = (char*)tolua_tostring(tolua_S, 5, 0);
	if (upwd == NULL) {
		return 0;
	}
	int handler = toluafix_ref_function(tolua_S, 6, 0);
	// The int handler id travels through the void* udata slot; widen it via size_t.
	mysql_wrapper::connect(ip, port, db_name, uname, upwd, on_open_cb,
		reinterpret_cast<void*>(static_cast<size_t>(handler)));
	return 0;
}
// Lua binding: mysql_wrapper.close(context). Ignores a missing or nil context.
static int
lua_mysql_close(lua_State* tolua_S) {
	void* context = tolua_touserdata(tolua_S, 1, 0);
	if (context == NULL) {
		return 0; // nothing to close
	}
	mysql_wrapper::close(context);
	return 0;
}
// Push one result row as a Lua array {col1, col2, ...}; SQL NULL columns become nil.
static void
push_mysql_row(MYSQL_ROW row, int num) {
	lua_State* L = lua_wrapper::lua_state();
	lua_newtable(L); /* L: row_table */
	for (int col = 0; col < num; ++col) {
		if (row[col] != NULL) {
			lua_pushstring(L, row[col]);
		}
		else {
			lua_pushnil(L);
		}
		lua_rawseti(L, -2, col + 1); /* row_table[col + 1] = value */
	}
}
static void
on_lua_query_cb(const char* err, MYSQL_RES* result, void* udata) {
if (err) {
lua_pushstring(lua_wrapper::lua_state(), err);
lua_pushnil(lua_wrapper::lua_state());
}
else {
lua_pushnil(lua_wrapper::lua_state());
if (result) { // 把查询得到的结果push成一个表; { {}, {}, {}, ...}
lua_newtable(lua_wrapper::lua_state());
int index = 1;
int num = mysql_num_fields(result);
MYSQL_ROW row;
while (row = mysql_fetch_row(result)) {
push_mysql_row(row, num); /* L: table value */
lua_rawseti(lua_wrapper::lua_state(), -2, index); /* table[index] = value, L: table */
++index;
}
}
else {
lua_pushnil(lua_wrapper::lua_state());
}
}
lua_wrapper::execute_script_handler((int)udata, 2);
lua_wrapper::remove_script_handler((int)udata);
}
// Lua binding: mysql_wrapper.query(context, sql, function(err, rows) end).
// Does nothing unless all three arguments are valid. Early returns replace the former
// `goto lua_failed`, which jumped over initialised locals (ill-formed C++).
static int
lua_mysql_query(lua_State* tolua_S) {
	void* context = tolua_touserdata(tolua_S, 1, 0);
	if (!context) {
		return 0;
	}
	char* sql = (char*)tolua_tostring(tolua_S, 2, 0);
	if (sql == NULL) {
		return 0;
	}
	int handler = toluafix_ref_function(tolua_S, 3, 0);
	if (handler == 0) {
		return 0;
	}
	// The int handler id travels through the void* udata slot; widen it via size_t.
	mysql_wrapper::query(context, sql, on_lua_query_cb,
		reinterpret_cast<void*>(static_cast<size_t>(handler)));
	return 0;
}
// Register the global Lua module `mysql_wrapper` with connect / close / query. Returns 0.
int
register_mysql_export(lua_State* tolua_S) {
	lua_getglobal(tolua_S, "_G"); /* L: _G */
	if (!lua_istable(tolua_S, -1)) {
		lua_pop(tolua_S, 1);
		return 0;
	}
	tolua_open(tolua_S);
	tolua_module(tolua_S, "mysql_wrapper", 0);
	tolua_beginmodule(tolua_S, "mysql_wrapper");
	tolua_function(tolua_S, "connect", lua_mysql_connect);
	tolua_function(tolua_S, "close", lua_mysql_close);
	tolua_function(tolua_S, "query", lua_mysql_query);
	tolua_endmodule(tolua_S);
	lua_pop(tolua_S, 1); /* L: */
	return 0;
}
lua_wrapper.cc修改
void
lua_wrapper::init() {
g_lua_State = luaL_newstate();
lua_atpanic(g_lua_State, lua_panic); // default abort;
luaL_openlibs(g_lua_State);
toluafix_open(g_lua_State);
register_mysql_export(g_lua_State);
// export log
lua_wrapper::reg_func2lua("log_error", lua_log_error);
lua_wrapper::reg_func2lua("log_debug", lua_log_debug);
lua_wrapper::reg_func2lua("log_warning", lua_log_warning);
// end
}
lua脚本测试
log_debug("HelloWorld")
-- Global scratch: name of the nested table PrintTable is about to print
-- (set by the parent before recursing, reset once it has been printed).
key = ""
-- Pretty-print a value, recursing into nested tables.
-- A non-table value (e.g. the nil delivered when a query produced no result set)
-- is printed as-is instead of crashing inside pairs().
function PrintTable(table , level)
    level = level or 1
    local indent = ""
    for i = 1, level do
        indent = indent.." "
    end
    if type(table) ~= "table" then
        print(indent .. tostring(table))
        key = ""
        return
    end
    if key ~= "" then
        print(indent..key.." ".."=".." ".."{")
    else
        print(indent .. "{")
    end
    key = ""
    for k,v in pairs(table) do
        if type(v) == "table" then
            key = k
            PrintTable(v, level + 1)
        else
            local content = string.format("%s%s = %s", indent .. " ",tostring(k), tostring(v))
            print(content)
        end
    end
    print(indent .. "}")
end
-- Query callback: print the error, or dump the returned rows.
local function on_query(err, ret)
    if err then
        print(err)
        return
    end
    print("success")
    PrintTable(ret)
end

-- Connect to MySQL, then select every row of class_test.
mysql_wrapper.connect("127.0.0.1", 3306, "class_sql", "root", "123456", function(err, context)
    log_debug("event call")
    if err then
        print(err)
        return
    end
    mysql_wrapper.query(context, "select * from class_test", on_query)
end)
这里会输出mysql连接和查询的结果,并且log日志中也有对应lua脚本的文件名和代码行号。
Redis的导出和使用
用法同mysql
redis_export_to_lua.h和redis_export_to_lua.cc
#ifndef __REDIS_EXPORT_LUA_H__
#define __REDIS_EXPORT_LUA_H__
// Forward declaration, so that including this header does not pull in the Lua headers.
struct lua_State;
// Register the global Lua module `redis_wrapper` (connect / close_redis / query). Returns 0.
int register_redis_export(lua_State* tolua_S);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "lua_wrapper.h"
#include "../database/redis_wrapper.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "tolua_fix.h"
#include "redis_export_to_lua.h"
static void
on_open_cb(const char* err, void* context, void* udata) {
if (err) {
lua_pushstring(lua_wrapper::lua_state(), err);
lua_pushnil(lua_wrapper::lua_state());
}
else {
lua_pushnil(lua_wrapper::lua_state());
tolua_pushuserdata(lua_wrapper::lua_state(), context);
}
lua_wrapper::execute_script_handler((int)udata, 2);
lua_wrapper::remove_script_handler((int)udata);
}
// Lua binding: redis_wrapper.connect(ip, port, function(err, context) end).
// A missing ip makes the call a silent no-op. The early return replaces the former
// `goto lua_failed`, which jumped over initialised locals (ill-formed C++).
static int
lua_redis_connect(lua_State* tolua_S) {
	char* ip = (char*)tolua_tostring(tolua_S, 1, 0);
	if (ip == NULL) {
		return 0;
	}
	int port = (int)tolua_tonumber(tolua_S, 2, 0);
	int handler = toluafix_ref_function(tolua_S, 3, 0);
	// The int handler id travels through the void* udata slot; widen it via size_t.
	redis_wrapper::connect(ip, port, on_open_cb, reinterpret_cast<void*>(static_cast<size_t>(handler)));
	return 0;
}
// Lua binding: redis_wrapper.close_redis(context). Ignores a missing or nil context.
static int
lua_redis_close(lua_State* tolua_S) {
	void* context = tolua_touserdata(tolua_S, 1, 0);
	if (context == NULL) {
		return 0; // nothing to close
	}
	redis_wrapper::close_redis(context);
	return 0;
}
// Convert a hiredis reply into a Lua value and push it. Exactly one value is always pushed,
// because callers store it with lua_rawseti:
//   string / status / error -> string (binary-safe), integer -> integer, nil -> nil,
//   array -> Lua array (recursively), any other reply type -> nil.
static void
push_result_to_lua(redisReply* result) {
	lua_State* L = lua_wrapper::lua_state();
	switch (result->type) {
	case REDIS_REPLY_STRING:
	case REDIS_REPLY_STATUS:
	case REDIS_REPLY_ERROR: // e.g. an error element inside an EXEC reply
		// lua_pushlstring keeps values with embedded '\0' bytes intact.
		lua_pushlstring(L, result->str, result->len);
		break;
	case REDIS_REPLY_INTEGER:
		lua_pushinteger(L, (lua_Integer)result->integer);
		break;
	case REDIS_REPLY_ARRAY:
		lua_newtable(L);
		for (size_t i = 0; i < result->elements; ++i) {
			push_result_to_lua(result->element[i]);
			lua_rawseti(L, -2, (lua_Integer)(i + 1)); /* table[i + 1] = value, L: table */
		}
		break;
	case REDIS_REPLY_NIL:
	default:
		// Unknown types must still push a value; otherwise the caller's lua_rawseti would
		// consume an unrelated stack slot.
		lua_pushnil(L);
		break;
	}
}
static void
on_lua_query_cb(const char* err, redisReply* result, void* udata) {
if (err) {
lua_pushstring(lua_wrapper::lua_state(), err);
lua_pushnil(lua_wrapper::lua_state());
}
else {
lua_pushnil(lua_wrapper::lua_state());
if (result) { // 把查询得到的结果push lua
push_result_to_lua(result);
}
else {
lua_pushnil(lua_wrapper::lua_state());
}
}
lua_wrapper::execute_script_handler((int)udata, 2);
lua_wrapper::remove_script_handler((int)udata);
}
// Lua binding: redis_wrapper.query(context, cmd, function(err, value) end).
// Does nothing unless all three arguments are valid. Early returns replace the former
// `goto lua_failed`, which jumped over initialised locals (ill-formed C++).
static int
lua_redis_query(lua_State* tolua_S) {
	void* context = tolua_touserdata(tolua_S, 1, 0);
	if (!context) {
		return 0;
	}
	char* cmd = (char*)tolua_tostring(tolua_S, 2, 0);
	if (cmd == NULL) {
		return 0;
	}
	int handler = toluafix_ref_function(tolua_S, 3, 0);
	if (handler == 0) {
		return 0;
	}
	// The int handler id travels through the void* udata slot; widen it via size_t.
	redis_wrapper::query(context, cmd, on_lua_query_cb, reinterpret_cast<void*>(static_cast<size_t>(handler)));
	return 0;
}
// Register the global Lua module `redis_wrapper` with connect / close_redis / query. Returns 0.
int
register_redis_export(lua_State* tolua_S) {
	lua_getglobal(tolua_S, "_G"); /* L: _G */
	if (!lua_istable(tolua_S, -1)) {
		lua_pop(tolua_S, 1);
		return 0;
	}
	tolua_open(tolua_S);
	tolua_module(tolua_S, "redis_wrapper", 0);
	tolua_beginmodule(tolua_S, "redis_wrapper");
	tolua_function(tolua_S, "connect", lua_redis_connect);
	tolua_function(tolua_S, "close_redis", lua_redis_close);
	tolua_function(tolua_S, "query", lua_redis_query);
	tolua_endmodule(tolua_S);
	lua_pop(tolua_S, 1); /* L: */
	return 0;
}
lua_wrapper.cc修改
void
lua_wrapper::init() {
g_lua_State = luaL_newstate();
lua_atpanic(g_lua_State, lua_panic); // default abort;
luaL_openlibs(g_lua_State);
toluafix_open(g_lua_State);
register_mysql_export(g_lua_State);
register_redis_export(g_lua_State);
// export log
lua_wrapper::reg_func2lua("log_error", lua_log_error);
lua_wrapper::reg_func2lua("log_debug", lua_log_debug);
lua_wrapper::reg_func2lua("log_warning", lua_log_warning);
// end
}
lua脚本测试
log_debug("HelloWorld")
-- Global scratch: name of the nested table PrintTable is about to print
-- (set by the parent before recursing, reset once it has been printed).
key = ""
-- Pretty-print a value, recursing into nested tables.
-- A non-table value (e.g. the nil a redis reply can be converted to) is printed
-- as-is instead of crashing inside pairs().
function PrintTable(table , level)
    level = level or 1
    local indent = ""
    for i = 1, level do
        indent = indent.." "
    end
    if type(table) ~= "table" then
        print(indent .. tostring(table))
        key = ""
        return
    end
    if key ~= "" then
        print(indent..key.." ".."=".." ".."{")
    else
        print(indent .. "{")
    end
    key = ""
    for k,v in pairs(table) do
        if type(v) == "table" then
            key = k
            PrintTable(v, level + 1)
        else
            local content = string.format("%s%s = %s", indent .. " ",tostring(k), tostring(v))
            print(content)
        end
    end
    print(indent .. "}")
end
-- hgetall callback: print the error, or dump the returned field/value array.
local function on_hgetall(err, result)
    if err then
        print(err)
        return
    end
    PrintTable(result)
end

-- Connect to redis, then read hash 002001.
redis_wrapper.connect("127.0.0.1", 6379, function (err, context)
    if err then
        print(err)
        return
    end
    print("redis connect success")
    --redis_wrapper.close_redis(context);
    --[[
    redis_wrapper.query(context, "hmset 001001 name \"blake\" age \"34\"", function (err, result)
        if err then
            print(err)
            return
        end
        print(result)
    end);
    ]]
    redis_wrapper.query(context, "hgetall 002001", on_hgetall)
end)
Service模块导出
1: 导出service模块;
2: 注册模块内部函数: register;
3: 编写service Handler模块,实现方式参考toluafix的ref func;之所以不直接复用toluafix,是担心随着运行时间过长func_id会溢出,
所以把它独立出来,单独建一个表来保存service的handler;
#define SERVICE_FUNCTION_MAPPING "service_function_mapping"
init_service_func_map: 初始化这个表;
save_service_function: 保存handler 函数;
get_service_function: 获取service函数;
push_service_function: push函数;
exe_service_function: 调用函数;
execute_service_handler: 完整的执行函数和参数;
4: 注册函数导出: service.register注册函数;
5: register_service_export: 注册service模块的导出函数;
service_export_to_lua.h和service_export_to_lua.cc
#ifndef __SERVICE_EXPORT_TO_LUA_H__
#define __SERVICE_EXPORT_TO_LUA_H__
// Forward declaration, so that including this header does not pull in the Lua headers.
struct lua_State;
// Create the service function mapping table and register the global Lua module
// `service` (service.register). Returns 0.
int register_service_export(lua_State* tolua_S);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "lua_wrapper.h"
#include "../netbus/service.h"
#include "../netbus/session.h"
#include "../netbus/proto_man.h"
#include "../netbus/service_man.h"
#include "../utils/logger.h"
#include "google/protobuf/message.h"
using namespace google::protobuf;
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "service_export_to_lua.h"
#define SERVICE_FUNCTION_MAPPING "service_function_mapping"
// Create the process-wide table that maps service handler ids to Lua functions:
// registry[SERVICE_FUNCTION_MAPPING] = {}.
static
void init_service_function_map(lua_State* L) {
	lua_pushstring(L, SERVICE_FUNCTION_MAPPING); /* L: ... key */
	lua_newtable(L);                             /* L: ... key {} */
	lua_rawset(L, LUA_REGISTRYINDEX);            /* L: ... */
}
// Last id handed out by save_service_function(); ids start at 1.
static unsigned int s_function_ref_id = 0;
// Store the Lua function at stack index lo in the SERVICE_FUNCTION_MAPPING table and return
// its new id, so C++ can call it later by id. Returns 0 if the value is not a function.
// `def` is unused. (An alternative would be lua_pushvalue(L, lo) + luaL_ref(L, LUA_REGISTRYINDEX).)
static unsigned int
save_service_function(lua_State* L, int lo, int def)
{
	if (!lua_isfunction(L, lo)) {
		return 0;
	}
	const unsigned int ref_id = ++s_function_ref_id;
	lua_pushstring(L, SERVICE_FUNCTION_MAPPING);
	lua_rawget(L, LUA_REGISTRYINDEX); /* L: ... map */
	lua_pushinteger(L, ref_id);       /* L: ... map id */
	lua_pushvalue(L, lo);             /* L: ... map id fn */
	lua_rawset(L, -3);                /* map[id] = fn, L: ... map */
	lua_pop(L, 1);                    /* L: ... */
	return ref_id;
}
// Push the value stored under refid in the SERVICE_FUNCTION_MAPPING table (nil if absent).
// L: ... -> ... value
static void
get_service_function(lua_State* L, int refid)
{
	lua_pushstring(L, SERVICE_FUNCTION_MAPPING);
	lua_rawget(L, LUA_REGISTRYINDEX); /* L: ... map */
	lua_rawgeti(L, -1, refid);        /* L: ... map value */
	lua_remove(L, -2);                /* L: ... value */
}
// 给Get封装了一层
static bool
push_service_function(int nHandler)
{
get_service_function(lua_wrapper::lua_state(), nHandler); /* L: ... func */
if (!lua_isfunction(lua_wrapper::lua_state(), -1))
{
log_error("[LUA ERROR] function refid '%d' does not reference a Lua function", nHandler);
lua_pop(lua_wrapper::lua_state(), 1);
return false;
}
return true;
}
static int
exe_function(int numArgs)
{
int functionIndex = -(numArgs + 1);
if (!lua_isfunction(lua_wrapper::lua_state(), functionIndex))
{
log_error("value at stack [%d] is not function", functionIndex);
lua_pop(lua_wrapper::lua_state(), numArgs + 1); // remove function and arguments
return 0;
}
int traceback = 0;
lua_getglobal(lua_wrapper::lua_state(), "__G__TRACKBACK__"); /* L: ... func arg1 arg2 ... G */
if (!lua_isfunction(lua_wrapper::lua_state(), -1))
{
lua_pop(lua_wrapper::lua_state(), 1); /* L: ... func arg1 arg2 ... */
}
else
{
lua_insert(lua_wrapper::lua_state(), functionIndex - 1); /* L: ... G func arg1 arg2 ... */
traceback = functionIndex - 1;
}
int error = 0;
error = lua_pcall(lua_wrapper::lua_state(), numArgs, 1, traceback); /* L: ... [G] ret */
if (error)
{
if (traceback == 0)
{
log_error("[LUA ERROR] %s", lua_tostring(lua_wrapper::lua_state(), -1)); /* L: ... error */
lua_pop(lua_wrapper::lua_state(), 1); // remove error message from stack
}
else /* L: ... G error */
{
lua_pop(lua_wrapper::lua_state(), 2); // remove __G__TRACKBACK__ and error message from stack
}
return 0;
}
// get return value
int ret = 0;
if (lua_isnumber(lua_wrapper::lua_state(), -1))
{
ret = (int)lua_tointeger(lua_wrapper::lua_state(), -1);
}
else if (lua_isboolean(lua_wrapper::lua_state(), -1))
{
ret = (int)lua_toboolean(lua_wrapper::lua_state(), -1);
}
// remove return value from stack
lua_pop(lua_wrapper::lua_state(), 1); /* L: ... [G] */
if (traceback)
{
lua_pop(lua_wrapper::lua_state(), 1); // remove __G__TRACKBACK__ from stack /* L: ... */
}
return ret;
}
static int
execute_service_function(int nHandler, int numArgs) {
int ret = 0;
if (push_service_function(nHandler)) /* L: ... arg1 arg2 ... func */
{
if (numArgs > 0)
{
lua_insert(lua_wrapper::lua_state(), -(numArgs + 1)); /* L: ... func arg1 arg2 ... */
}
ret = exe_function(numArgs);
}
lua_settop(lua_wrapper::lua_state(), 0);
return ret;
}
// service implementation that forwards session events to Lua functions registered
// through service.register().
class lua_service : public service {
public:
// ids in the SERVICE_FUNCTION_MAPPING table (returned by save_service_function)
unsigned int lua_recv_cmd_handler;
unsigned int lua_disconnect_handler;
public:
// Calls the Lua handler as handler(session, {stype, ctype, utag, body}); always returns true.
virtual bool on_session_recv_cmd(session* s, struct cmd_msg* msg);
// Calls the Lua handler as handler(session).
virtual void on_session_disconnect(session* s);
};
static void
push_proto_message_tolua(const Message* message) {
lua_State* state = lua_wrapper::lua_state();
if (!message) {
// printf("PushProtobuf2LuaTable failed, message is NULL");
return;
}
const Reflection* reflection = message->GetReflection();
// 顶层table
lua_newtable(state);
const Descriptor* descriptor = message->GetDescriptor();
for (int32_t index = 0; index < descriptor->field_count(); ++index) {
const FieldDescriptor* fd = descriptor->field(index);
const std::string& name = fd->lowercase_name();
// key
lua_pushstring(state, name.c_str());
bool bReapeted = fd->is_repeated();
if (bReapeted) {
// repeated这层的table
lua_newtable(state);
int size = reflection->FieldSize(*message, fd);
for (int i = 0; i < size; ++i) {
char str[32] = { 0 };
switch (fd->cpp_type()) {
case FieldDescriptor::CPPTYPE_DOUBLE:
lua_pushnumber(state, reflection->GetRepeatedDouble(*message, fd, i));
break;
case FieldDescriptor::CPPTYPE_FLOAT:
lua_pushnumber(state, (double)reflection->GetRepeatedFloat(*message, fd, i));
break;
case FieldDescriptor::CPPTYPE_INT64:
sprintf(str, "%lld", (long long)reflection->GetRepeatedInt64(*message, fd, i));
lua_pushstring(state, str);
break;
case FieldDescriptor::CPPTYPE_UINT64:
sprintf(str, "%llu", (unsigned long long)reflection->GetRepeatedUInt64(*message, fd, i));
lua_pushstring(state, str);
break;
case FieldDescriptor::CPPTYPE_ENUM: // 与int32一样处理
lua_pushinteger(state, reflection->GetRepeatedEnum(*message, fd, i)->number());
break;
case FieldDescriptor::CPPTYPE_INT32:
lua_pushinteger(state, reflection->GetRepeatedInt32(*message, fd, i));
break;
case FieldDescriptor::CPPTYPE_UINT32:
lua_pushinteger(state, reflection->GetRepeatedUInt32(*message, fd, i));
break;
case FieldDescriptor::CPPTYPE_STRING:
{
std::string value = reflection->GetRepeatedString(*message, fd, i);
lua_pushlstring(state, value.c_str(), value.size());
}
break;
case FieldDescriptor::CPPTYPE_BOOL:
lua_pushboolean(state, reflection->GetRepeatedBool(*message, fd, i));
break;
case FieldDescriptor::CPPTYPE_MESSAGE:
push_proto_message_tolua(&(reflection->GetRepeatedMessage(*message, fd, i)));
break;
default:
break;
}
lua_rawseti(state, -2, i + 1); // lua's index start at 1
}
}
else {
char str[32] = { 0 };
switch (fd->cpp_type()) {
case FieldDescriptor::CPPTYPE_DOUBLE:
lua_pushnumber(state, reflection->GetDouble(*message, fd));
break;
case FieldDescriptor::CPPTYPE_FLOAT:
lua_pushnumber(state, (double)reflection->GetFloat(*message, fd));
break;
case FieldDescriptor::CPPTYPE_INT64:
sprintf(str, "%lld", (long long)reflection->GetInt64(*message, fd));
lua_pushstring(state, str);
break;
case FieldDescriptor::CPPTYPE_UINT64:
sprintf(str, "%llu", (unsigned long long)reflection->GetUInt64(*message, fd));
lua_pushstring(state, str);
break;
case FieldDescriptor::CPPTYPE_ENUM: // 与int32一样处理
lua_pushinteger(state, (int)reflection->GetEnum(*message, fd)->number());
break;
case FieldDescriptor::CPPTYPE_INT32:
lua_pushinteger(state, reflection->GetInt32(*message, fd));
break;
case FieldDescriptor::CPPTYPE_UINT32:
lua_pushinteger(state, reflection->GetUInt32(*message, fd));
break;
case FieldDescriptor::CPPTYPE_STRING:
{
std::string value = reflection->GetString(*message, fd);
lua_pushlstring(state, value.c_str(), value.size());
}
break;
case FieldDescriptor::CPPTYPE_BOOL:
lua_pushboolean(state, reflection->GetBool(*message, fd));
break;
case FieldDescriptor::CPPTYPE_MESSAGE:
push_proto_message_tolua(&(reflection->GetMessage(*message, fd)));
break;
default:
break;
}
}
lua_rawset(state, -3);
}
}
// Forward a received command to Lua as handler(session, msg), where
// msg = {1: stype, 2: ctype, 3: utag, 4: body}. For protobuf, body is the message converted
// to a Lua table; for JSON, body is the raw JSON string; with no body, msg[4] is nil.
// Always returns true.
bool
lua_service::on_session_recv_cmd(session* s, struct cmd_msg* msg) {
	lua_State* L = lua_wrapper::lua_state();
	tolua_pushuserdata(L, (void*)s); /* L: session */
	lua_newtable(L);                 /* L: session msg */
	lua_pushinteger(L, msg->stype);
	lua_rawseti(L, -2, 1);
	lua_pushinteger(L, msg->ctype);
	lua_rawseti(L, -2, 2);
	lua_pushinteger(L, msg->utag);
	lua_rawseti(L, -2, 3);
	if (!msg->body) {
		lua_pushnil(L);
	}
	else if (proto_man::proto_type() == PROTO_JSON) {
		lua_pushstring(L, (char*)msg->body);
	}
	else { // protobuf
		push_proto_message_tolua((Message*)msg->body);
	}
	lua_rawseti(L, -2, 4);
	execute_service_function(this->lua_recv_cmd_handler, 2);
	return true;
}
// Forward a session disconnect to Lua as handler(session).
void
lua_service::on_session_disconnect(session* s) {
	lua_State* L = lua_wrapper::lua_state();
	tolua_pushuserdata(L, (void*)s);
	execute_service_function(this->lua_disconnect_handler, 1);
}
// Lua binding: service.register(stype, { on_session_recv_cmd = f, on_session_disconnect = g }).
// Stores both functions in the service function map and registers a lua_service for stype.
// Returns true to Lua on success, false otherwise.
static int
lua_register_service(lua_State* tolua_S) {
	// Normalise the stack to exactly (stype, table), so that the two handlers fetched below
	// land at slots 3 and 4. With extra arguments, slot 3 would otherwise be one of those.
	lua_settop(tolua_S, 2);
	int stype = (int)tolua_tonumber(tolua_S, 1, 0);
	// Early returns replace the former `goto lua_failed`, which jumped over an initialised
	// local (ill-formed C++).
	if (!lua_istable(tolua_S, 2)) {
		lua_pushboolean(tolua_S, 0);
		return 1;
	}
	lua_getfield(tolua_S, 2, "on_session_recv_cmd");   /* slot 3 */
	lua_getfield(tolua_S, 2, "on_session_disconnect"); /* slot 4 */
	unsigned int lua_recv_cmd_handler = save_service_function(tolua_S, 3, 0);
	unsigned int lua_disconnect_handler = save_service_function(tolua_S, 4, 0);
	if (lua_recv_cmd_handler == 0 || lua_disconnect_handler == 0) {
		lua_pushboolean(tolua_S, 0);
		return 1;
	}
	lua_service* s = new lua_service();
	s->lua_disconnect_handler = lua_disconnect_handler;
	s->lua_recv_cmd_handler = lua_recv_cmd_handler;
	// NOTE(review): if register_service() fails, `s` is presumably not stored and leaks;
	// confirm its ownership rules before deleting it here.
	bool ret = service_man::register_service(stype, s);
	lua_pushboolean(tolua_S, ret ? 1 : 0);
	return 1;
}
// Publishes the `service` Lua module (service.register) and prepares the
// service-function map used to store the Lua handlers.
int
register_service_export(lua_State* tolua_S) {
    init_service_function_map(tolua_S);
    lua_getglobal(tolua_S, "_G");
    const bool globals_ok = lua_istable(tolua_S, -1);
    if (globals_ok) {
        tolua_open(tolua_S);
        tolua_module(tolua_S, "service", 0);
        tolua_beginmodule(tolua_S, "service");
        tolua_function(tolua_S, "register", lua_register_service);
        tolua_endmodule(tolua_S);
    }
    lua_pop(tolua_S, 1);  // drop _G
    return 0;
}
应用层调整
pf_cmd_map.cc
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <string>
#include <map>
#include "../../../netbus/proto_man.h"
// ctype -> protobuf message name table for this application.
std::map<int, std::string> cmd_map = {
    { 0, "LoginReq" },
    { 1, "LoginRes" },
};

// Hands the application's command table to proto_man.
void init_pf_cmd_map() {
    std::map<int, std::string>& table = cmd_map;
    proto_man::register_pb_cmd_map(table);
}
proto_man部分
#define CMD_HEADER 8
static int g_proto_type = PROTO_BUF;
// ctype -> protobuf message name.
static std::map<int, std::string> g_pb_cmd_map;

// Merges `map` into the global ctype -> message-name table; existing entries
// with the same ctype are overwritten.
void
proto_man::register_pb_cmd_map(std::map<int, std::string>& map) {
    for (const auto& entry : map) {
        g_pb_cmd_map[entry.first] = entry.second;
    }
}

// Returns the protobuf message name registered for `ctype`, or "" if none is.
// Uses find() rather than operator[] so unknown ctypes do not insert empty
// entries into the table.
const char*
proto_man::protobuf_cmd_name(int ctype) {
    std::map<int, std::string>::const_iterator it = g_pb_cmd_map.find(ctype);
    if (it == g_pb_cmd_map.end()) {
        return "";
    }
    return it->second.c_str();
}
Session模块导出
1: 导出session模块函数: send_msg, close, get_address; (后续代码中还导出了 set_utag, get_utag, asclient)
2: 注册模块内部函数:
send_msg: 发送一个消息体 {1: stype, 2: ctype, 3: utag, 4: body};
close: 关闭一个session;
get_address: 获取session对应的IP地址与端口;
3: 编写函数: lua_table_to_protobuf,将lua的消息表转成protobuf对应的对象;
session_export_to_lua.h和session_export_to_lua.cc
#ifndef __SESSION_EXPORT_TO_LUA_H__
#define __SESSION_EXPORT_TO_LUA_H__
// Forward declaration so including this header does not require lua.h.
struct lua_State;
// Registers the global Lua module `session` (close, send_msg, get_address, ...).
int register_session_export(lua_State* tolua_S);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "../netbus/service.h"
#include "../netbus/session.h"
#include "../netbus/proto_man.h"
#include "../netbus/service_man.h"
#include "../utils/logger.h"
#include "lua_wrapper.h"
#include "google/protobuf/message.h"
using namespace google::protobuf;
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "session_export_to_lua.h"
// session.close(s): closes the given session; ignored when no session is passed.
static int
lua_session_close(lua_State* tolua_S) {
    session* target = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (target != NULL) {
        target->close();
    }
    return 0;
}
// 把lua表转换成Message对象
static google::protobuf::Message*
lua_table_to_protobuf(lua_State* L, int stack_index, const char* msg_name) {
if (!lua_istable(L, stack_index)) {
return NULL;
}
Message* message = proto_man::create_message(msg_name);
if (!message) {
log_error("cant find message %s from compiled poll \n", msg_name);
return NULL;
}
const Reflection* reflection = message->GetReflection();
const Descriptor* descriptor = message->GetDescriptor();
// 遍历table的所有key, 并且与 protobuf结构相比较。如果require的字段没有赋值, 报错! 如果找不到字段,报错!
for (int32_t index = 0; index < descriptor->field_count(); ++index) {
const FieldDescriptor* fd = descriptor->field(index);
const string& name = fd->name();
bool isRequired = fd->is_required();
bool bReapeted = fd->is_repeated();
lua_pushstring(L, name.c_str());
lua_rawget(L, stack_index);
bool isNil = lua_isnil(L, -1);
if (bReapeted) {
if (isNil) {
lua_pop(L, 1);
continue;
}
else {
bool isTable = lua_istable(L, -1);
if (!isTable) {
log_error("cant find required repeated field %s\n", name.c_str());
proto_man::release_message(message);
return NULL;
}
}
lua_pushnil(L);
for (; lua_next(L, -2) != 0;) {
switch (fd->cpp_type()) {
case FieldDescriptor::CPPTYPE_DOUBLE:
{
double value = luaL_checknumber(L, -1);
reflection->AddDouble(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_FLOAT:
{
float value = luaL_checknumber(L, -1);
reflection->AddFloat(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_INT64:
{
int64_t value = luaL_checknumber(L, -1);
reflection->AddInt64(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_UINT64:
{
uint64_t value = luaL_checknumber(L, -1);
reflection->AddUInt64(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_ENUM: // 与int32一样处理
{
int32_t value = luaL_checknumber(L, -1);
const EnumDescriptor* enumDescriptor = fd->enum_type();
const EnumValueDescriptor* valueDescriptor = enumDescriptor->FindValueByNumber(value);
reflection->AddEnum(message, fd, valueDescriptor);
}
break;
case FieldDescriptor::CPPTYPE_INT32:
{
int32_t value = luaL_checknumber(L, -1);
reflection->AddInt32(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_UINT32:
{
uint32_t value = luaL_checknumber(L, -1);
reflection->AddUInt32(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_STRING:
{
size_t size = 0;
const char* value = luaL_checklstring(L, -1, &size);
reflection->AddString(message, fd, std::string(value, size));
}
break;
case FieldDescriptor::CPPTYPE_BOOL:
{
bool value = lua_toboolean(L, -1);
reflection->AddBool(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_MESSAGE:
{
Message* value = lua_table_to_protobuf(L, lua_gettop(L), fd->message_type()->name().c_str());
if (!value) {
log_error("convert to message %s failed whith value %s\n", fd->message_type()->name().c_str(), name.c_str());
proto_man::release_message(value);
return NULL;
}
Message* msg = reflection->AddMessage(message, fd);
msg->CopyFrom(*value);
proto_man::release_message(value);
}
break;
default:
break;
}
// remove value, keep the key
lua_pop(L, 1);
}
}
else {
if (isRequired) {
if (isNil) {
log_error("cant find required field %s\n", name.c_str());
proto_man::release_message(message);
return NULL;
}
}
else {
if (isNil) {
lua_pop(L, 1);
continue;
}
}
switch (fd->cpp_type()) {
case FieldDescriptor::CPPTYPE_DOUBLE:
{
double value = luaL_checknumber(L, -1);
reflection->SetDouble(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_FLOAT:
{
float value = luaL_checknumber(L, -1);
reflection->SetFloat(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_INT64:
{
int64_t value = luaL_checknumber(L, -1);
reflection->SetInt64(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_UINT64:
{
uint64_t value = luaL_checknumber(L, -1);
reflection->SetUInt64(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_ENUM: // 与int32一样处理
{
int32_t value = luaL_checknumber(L, -1);
const EnumDescriptor* enumDescriptor = fd->enum_type();
const EnumValueDescriptor* valueDescriptor = enumDescriptor->FindValueByNumber(value);
reflection->SetEnum(message, fd, valueDescriptor);
}
break;
case FieldDescriptor::CPPTYPE_INT32:
{
int32_t value = luaL_checknumber(L, -1);
reflection->SetInt32(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_UINT32:
{
uint32_t value = luaL_checknumber(L, -1);
reflection->SetUInt32(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_STRING:
{
size_t size = 0;
const char* value = luaL_checklstring(L, -1, &size);
reflection->SetString(message, fd, std::string(value, size));
}
break;
case FieldDescriptor::CPPTYPE_BOOL:
{
bool value = lua_toboolean(L, -1);
reflection->SetBool(message, fd, value);
}
break;
case FieldDescriptor::CPPTYPE_MESSAGE:
{
Message* value = lua_table_to_protobuf(L, lua_gettop(L), fd->message_type()->name().c_str());
if (!value) {
log_error("convert to message %s failed whith value %s \n", fd->message_type()->name().c_str(), name.c_str());
proto_man::release_message(message);
return NULL;
}
Message* msg = reflection->MutableMessage(message, fd);
msg->CopyFrom(*value);
proto_man::release_message(value);
}
break;
default:
break;
}
}
// pop value
lua_pop(L, 1);
}
return message;
}
// session.send_msg(s, {stype, ctype, utag[, body]})
// Sends one command on session `s`. A 3-element table sends a message with no
// body. In JSON mode the body is a string. In protobuf mode the body is a table
// converted with the message type registered for `ctype`.
static int
lua_send_msg(lua_State* tolua_S) {
    // Early returns instead of goto: jumping over `int n = ...` is ill-formed C++.
    session* s = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (s == NULL || !lua_istable(tolua_S, 2)) {  // stack: 1 s, 2 table
        return 0;
    }
    int n = luaL_len(tolua_S, 2);
    if (n != 4 && n != 3) {  // 3 = message without body
        return 0;
    }

    struct cmd_msg msg;
    lua_pushinteger(tolua_S, 1);
    lua_gettable(tolua_S, 2);
    msg.stype = luaL_checkinteger(tolua_S, -1);
    lua_pushinteger(tolua_S, 2);
    lua_gettable(tolua_S, 2);
    msg.ctype = luaL_checkinteger(tolua_S, -1);
    lua_pushinteger(tolua_S, 3);
    lua_gettable(tolua_S, 2);
    msg.utag = luaL_checkinteger(tolua_S, -1);

    if (n == 3) {  // empty body
        msg.body = NULL;
        s->send_msg(&msg);
        return 0;
    }

    lua_pushinteger(tolua_S, 4);
    lua_gettable(tolua_S, 2);
    if (proto_man::proto_type() == PROTO_JSON) {
        msg.body = (char*)lua_tostring(tolua_S, -1);
        s->send_msg(&msg);
        return 0;
    }
    if (!lua_istable(tolua_S, -1)) {  // protobuf mode without a body table
        msg.body = NULL;
        s->send_msg(&msg);
        return 0;
    }

    // Protobuf message table.
    const char* msg_name = proto_man::protobuf_cmd_name(msg.ctype);
    google::protobuf::Message* body = lua_table_to_protobuf(tolua_S, lua_gettop(tolua_S), msg_name);
    if (body == NULL) {
        // Do not send a packet whose body silently went missing.
        log_error("send_msg: cannot convert body of ctype %d to protobuf message %s\n",
                  (int)msg.ctype, msg_name ? msg_name : "");
        return 0;
    }
    msg.body = body;
    s->send_msg(&msg);
    proto_man::release_message(body);
    return 0;
}
// session.get_address(s) -> ip, port
// Returns nothing when no session is passed.
static int
lua_get_addr(lua_State* tolua_S) {
    // Early return instead of goto: jumping over `const char* ip = ...` is ill-formed C++.
    session* s = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (s == NULL) {
        return 0;
    }
    int client_port = 0;
    const char* ip = s->get_address(&client_port);
    lua_pushstring(tolua_S, ip);
    lua_pushinteger(tolua_S, client_port);
    return 2;
}
// session.set_utag(s, utag): stores the integer tag on the session.
static int
lua_set_utag(lua_State* tolua_S) {
    // No goto: jumping over `unsigned int utag = ...` is ill-formed C++.
    session* s = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (s != NULL) {
        s->utag = (unsigned int)lua_tointeger(tolua_S, 2);
    }
    return 0;
}
// session.get_utag(s) -> utag (nothing when no session is passed).
static int
lua_get_utag(lua_State* tolua_S) {
    session* target = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (target == NULL) {
        return 0;
    }
    lua_pushinteger(tolua_S, target->utag);
    return 1;
}
// session.asclient(s) -> the session's as_client flag (nothing without a session).
static int
lua_as_client(lua_State* tolua_S) {
    session* target = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (target == NULL) {
        return 0;
    }
    lua_pushinteger(tolua_S, target->as_client);
    return 1;
}
// Publishes the `session` Lua module and its functions.
int
register_session_export(lua_State* tolua_S) {
    lua_getglobal(tolua_S, "_G");
    const bool globals_ok = lua_istable(tolua_S, -1);
    if (globals_ok) {
        tolua_open(tolua_S);
        tolua_module(tolua_S, "session", 0);
        tolua_beginmodule(tolua_S, "session");
        tolua_function(tolua_S, "close", lua_session_close);
        tolua_function(tolua_S, "send_msg", lua_send_msg);
        tolua_function(tolua_S, "get_address", lua_get_addr);
        tolua_function(tolua_S, "set_utag", lua_set_utag);
        tolua_function(tolua_S, "get_utag", lua_get_utag);
        tolua_function(tolua_S, "asclient", lua_as_client);
        tolua_endmodule(tolua_S);
    }
    lua_pop(tolua_S, 1);  // drop _G
    return 0;
}
timer模块的导出
1: 导出timer模块函数: schedule, once, cancel,三个接口
2: 注册scheduler模块内部函数:
schedule: 重复的调用;
once: 调用一次;
cancel: 取消定时器;
3: 时间戳、日期: Lua标准库里已经有了, 就不导出了, 如果有需要再说;
time_list.c回顾
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "uv.h"
#include "time_list.h"
#define my_malloc malloc
#define my_free free
/* One scheduled libuv timer together with the user callback it drives. */
struct timer {
uv_timer_t uv_timer; // libuv timer handle
void(*on_timer)(void* udata); // user callback, called with udata on every tick
void* udata; // opaque user pointer passed to on_timer
int repeat_count; // ticks left; a negative value (-1) repeats forever
};
static struct timer*
alloc_timer(void(*on_timer)(void* udata),
void* udata, int repeat_count) {
struct timer* t = my_malloc(sizeof(struct timer));
memset(t, 0, sizeof(struct timer));
t->on_timer = on_timer;
t->repeat_count = repeat_count;
t->udata = udata;
uv_timer_init(uv_default_loop(), &t->uv_timer);
return t;
}
static void
free_timer(struct timer* t) {
my_free(t);
}
static void
on_uv_timer(uv_timer_t* handle) {
struct timer* t = handle->data;
if (t->repeat_count < 0) { // 不断的触发;
t->on_timer(t->udata);
}
else {
t->repeat_count --;
t->on_timer(t->udata);
if (t->repeat_count == 0) { // 函数time结束
uv_timer_stop(&t->uv_timer); // 停止这个timer
free_timer(t);
}
}
}
struct timer*
schedule_repeat(void(*on_timer)(void* udata),
void* udata,
int after_msec,
int repeat_count,
int repeat_msec) {
struct timer* t = alloc_timer(on_timer, udata, repeat_count);
// 启动一个timer;
t->uv_timer.data = t;
uv_timer_start(&t->uv_timer, on_uv_timer, after_msec, repeat_msec);
// end
return t;
}
/* Stops and releases a timer that still has repetitions pending.
 * During the final on_timer callback repeat_count is already 0 and the timer
 * is released by on_uv_timer right afterwards, so cancelling there is a no-op.
 * Callers must not pass a timer after it has been released. */
void
cancel_timer(struct timer* t) {
    if (t->repeat_count != 0) { /* still pending */
        uv_timer_stop(&t->uv_timer);
        free_timer(t);
    }
}
/* One-shot timer: fires once after after_msec ms and then releases itself. */
struct timer*
schedule_once(void(*on_timer)(void* udata),
              void* udata,
              int after_msec) {
    const int fire_once = 1;
    return schedule_repeat(on_timer, udata, after_msec, fire_once, after_msec);
}
/* Returns the user pointer passed when the timer was scheduled. */
void*
get_timer_udata(struct timer* t) {
    void* user_ptr = t->udata;
    return user_ptr;
}
scheduler_export_to_lua.h和scheduler_export_to_lua.cc
#ifndef __SCHEDULER_EXPORT_TO_LUA_H__
#define __SCHEDULER_EXPORT_TO_LUA_H__
// Forward declaration: without it this header does not compile unless lua.h
// happens to be included first.
struct lua_State;
// Registers the global Lua module `scheduler` (schedule, once, cancel).
int register_scheduler_export(lua_State* tolua_S);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "../utils/time_list.h"
#include "lua_wrapper.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "tolua_fix.h"
#include "scheduler_export_to_lua.h"
#define my_malloc malloc
#define my_free free
// Per-timer state for a Lua-scheduled callback; passed as the timer's udata.
struct timer_repeat {
int handler; // Lua function reference (from toluafix_ref_function)
int repeat_count; // calls left; -1 repeats forever
};
// Timer callback for Lua-scheduled timers: runs the Lua handler, then, for
// finite timers, releases the handler and this state after the last call.
// NOTE(review): if the Lua handler calls scheduler.cancel on its own timer,
// `tr` is freed before the count check below runs; confirm callers avoid that.
static void
on_lua_repeat_timer(void* udata) {
    struct timer_repeat* state = (struct timer_repeat*)udata;
    lua_wrapper::execute_script_handler(state->handler, 0);
    if (state->repeat_count != -1) {  // -1: endless, never released here
        state->repeat_count--;
        if (state->repeat_count <= 0) {
            lua_wrapper::remove_script_handler(state->handler);
            my_free(state);
        }
    }
}
// scheduler.schedule(func, after_msec, repeat_count[, repeat_msec]) -> timer | nil
// Calls `func` after `after_msec` ms, then every `repeat_msec` ms (defaults to
// after_msec), `repeat_count` times in total; a negative count repeats forever.
// Returns nil (and releases the function reference) on invalid arguments.
static int
lua_schedule_repeat(lua_State* tolua_S) {
    int handler = toluafix_ref_function(tolua_S, 1, 0);
    if (handler == 0) {
        lua_pushnil(tolua_S);
        return 1;
    }
    // lua_tointeger only takes (L, idx); the old 3-argument calls did not
    // compile. tolua_tonumber returns the default when the argument is absent.
    int after_msec = (int)tolua_tonumber(tolua_S, 2, 0);
    int repeat_count = (int)tolua_tonumber(tolua_S, 3, 0);
    int repeat_msec = (int)tolua_tonumber(tolua_S, 4, 0);

    struct timer_repeat* tr = NULL;
    if (after_msec > 0 && repeat_count != 0) {
        tr = (struct timer_repeat*)my_malloc(sizeof(struct timer_repeat));
    }
    if (tr == NULL) {  // invalid arguments or out of memory
        lua_wrapper::remove_script_handler(handler);
        lua_pushnil(tolua_S);
        return 1;
    }
    if (repeat_count < 0) {  // -1 forever
        repeat_count = -1;
    }
    if (repeat_msec <= 0) {
        repeat_msec = after_msec;
    }
    tr->handler = handler;
    tr->repeat_count = repeat_count;
    struct timer* t = schedule_repeat(on_lua_repeat_timer, tr, after_msec, repeat_count, repeat_msec);
    tolua_pushuserdata(tolua_S, t);
    return 1;
}
// scheduler.once(func, after_msec) -> timer | nil
// Calls `func` once after `after_msec` ms. Returns nil (and releases the
// function reference) when after_msec is not positive.
static int
lua_schedule_once(lua_State* tolua_S) {
    int handler = toluafix_ref_function(tolua_S, 1, 0);
    if (handler == 0) {
        lua_pushnil(tolua_S);
        return 1;
    }
    // lua_tointeger only takes (L, idx); tolua_tonumber supplies the default.
    int after_msec = (int)tolua_tonumber(tolua_S, 2, 0);
    struct timer_repeat* tr = NULL;
    if (after_msec > 0) {
        tr = (struct timer_repeat*)my_malloc(sizeof(struct timer_repeat));
    }
    if (tr == NULL) {  // invalid delay or out of memory
        lua_wrapper::remove_script_handler(handler);
        lua_pushnil(tolua_S);
        return 1;
    }
    tr->handler = handler;
    tr->repeat_count = 1;
    struct timer* t = schedule_once(on_lua_repeat_timer, (void*)tr, after_msec);
    tolua_pushuserdata(tolua_S, t);
    return 1;
}
// scheduler.cancel(timer): stops a timer from scheduler.schedule/once and
// releases its Lua handler.
// NOTE(review): a timer that already fired its last time has been freed by
// time_list/on_lua_repeat_timer; cancelling it afterwards reads freed memory.
// Lua scripts must drop timer handles once they expire.
static int
lua_schedule_cancel(lua_State* tolua_S) {
    // Early return instead of goto: jumping over `struct timer* t = ...` is ill-formed C++.
    if (!lua_isuserdata(tolua_S, 1)) {
        return 0;
    }
    struct timer* t = (struct timer*)lua_touserdata(tolua_S, 1);
    struct timer_repeat* tr = (struct timer_repeat*)get_timer_udata(t);
    lua_wrapper::remove_script_handler(tr->handler);
    my_free(tr);
    cancel_timer(t);
    return 0;
}
// Publishes the `scheduler` Lua module (schedule, once, cancel).
int
register_scheduler_export(lua_State* tolua_S) {
    lua_getglobal(tolua_S, "_G");
    const bool globals_ok = lua_istable(tolua_S, -1);
    if (globals_ok) {
        tolua_open(tolua_S);
        tolua_module(tolua_S, "scheduler", 0);
        tolua_beginmodule(tolua_S, "scheduler");
        tolua_function(tolua_S, "schedule", lua_schedule_repeat);
        tolua_function(tolua_S, "once", lua_schedule_once);
        tolua_function(tolua_S, "cancel", lua_schedule_cancel);
        tolua_endmodule(tolua_S);
    }
    lua_pop(tolua_S, 1);  // drop _G
    return 0;
}
proto_man模块的导出
1: 导出proto_man模块函数:
init: 初始化协议类型;
proto_type: 获取协议类型;
register_protobuf_cmd_map: 注册protobuf 名字与ctype的数据类型映射表:
// proto_man.register_protobuf_cmd_map({"Name1", "Name2", ...})
// Element i of the array becomes the message name for ctype i (starting at 1).
// Non-string entries are skipped; a non-table argument is ignored.
static int
lua_register_protobuf_cmd_map(lua_State* L) {
    if (!lua_istable(L, 1)) {
        return 0;
    }
    std::map<int, std::string> map;
    int n = luaL_len(L, 1);
    for (int i = 1; i <= n; ++i) {
        lua_pushinteger(L, i);
        lua_gettable(L, 1);
        // lua_tostring returns NULL for non-strings; std::string(NULL) is UB.
        const char* name = lua_tostring(L, -1);
        if (name != NULL) {
            map[i] = name;
        }
        lua_pop(L, 1);
    }
    proto_man::register_protobuf_cmd_map(map);
    return 0;
}
代码调整
1: 修改netbus 接口名字:
start_tcp_server —> tcp_listen
start_ws_server —> ws_listen
start_udp_server –> udp_listen
2: 调整proto_man函数名字: register_pb_cmd_map 改名为 register_protobuf_cmd_map;
3: lua_wrapper类里面添加一个函数: add_search_path(std::string),并导出来独立的函数,用于在lua脚本添加脚本的搜索路径
4: exe_lua_file函数的参数改成std::string
在lua_wrapper添加add_search_path函数
// Lua global add_search_path(path): forwards to lua_wrapper::add_search_path.
static int
lua_add_search_path(lua_State* L) {
    const char* raw = luaL_checkstring(L, 1);
    if (raw != NULL) {
        std::string str_path(raw);
        lua_wrapper::add_search_path(str_path);
    }
    return 0;
}
// Creates the global Lua state and registers every native module exported to
// scripts. main() calls this before running any script.
void
lua_wrapper::init() {
g_lua_State = luaL_newstate();
lua_atpanic(g_lua_State, lua_panic); // replace the default panic behaviour (abort) with lua_panic
luaL_openlibs(g_lua_State); // standard Lua libraries
toluafix_open(g_lua_State); // presumably sets up the tables used by toluafix_ref_function — confirm
lua_wrapper::reg_func2lua("add_search_path", lua_add_search_path); // add_search_path(path) for scripts
// Native modules (logger, mysql, redis, service, session, scheduler, netbus, proto_man):
register_logger_export(g_lua_State);
register_mysql_export(g_lua_State);
register_redis_export(g_lua_State);
register_service_export(g_lua_State);
register_session_export(g_lua_State);
register_scheduler_export(g_lua_State);
register_netbus_export(g_lua_State);
register_proto_man_export(g_lua_State);
}
void
lua_wrapper::add_search_path(std::string& path) {
char strPath[1024] = { 0 };
sprintf(strPath, "local path = string.match([[%s]],[[(.*)/[^/]*$]])\n package.path = package.path .. [[;]] .. path .. [[/?.lua;]] .. path .. [[/?/init.lua]]\n", path.c_str());
luaL_dostring(g_lua_State, strPath);
}
proto_man_export_to_lua.h和proto_man_export_to_lua.cc
#ifndef __PROTO_MAN_EXPORT_TO_LUA_H__
#define __PROTO_MAN_EXPORT_TO_LUA_H__
// Forward declaration: the header otherwise needs lua.h to be included first.
struct lua_State;
// Registers the global Lua module `proto_man` (init, proto_type, register_protobuf_cmd_map).
int register_proto_man_export(lua_State* tolua_S);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "../netbus/proto_man.h"
#include "lua_wrapper.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "tolua_fix.h"
#include "proto_man_export_to_lua.h"
// proto_man.register_protobuf_cmd_map({"Name1", "Name2", ...})
// Lua arrays are 1-based and ctypes start at 1 too: element i names ctype i.
// Non-string elements raise a Lua error (luaL_checkstring).
static int
lua_register_protobuf_cmd_map(lua_State* L) {
    int count = luaL_len(L, 1);
    if (count <= 0) {
        return 0;
    }
    std::map<int, std::string> cmd_map;
    for (int ctype = 1; ctype <= count; ++ctype) {
        lua_pushnumber(L, ctype);
        lua_gettable(L, 1);
        const char* msg_name = luaL_checkstring(L, -1);
        if (msg_name != NULL) {
            cmd_map[ctype] = msg_name;
        }
        lua_pop(L, 1);
    }
    proto_man::register_protobuf_cmd_map(cmd_map);
    return 0;
}
// proto_man.proto_type() -> the currently configured protocol type.
static int
lua_proto_type(lua_State* tolua_S) {
    const int current = proto_man::proto_type();
    lua_pushinteger(tolua_S, current);
    return 1;
}
// proto_man.init(proto_type): selects PROTO_JSON or PROTO_BUF.
// Calls with a wrong argument count or an unknown type are ignored.
static int
lua_proto_man_init(lua_State* tolua_S) {
    // Early returns instead of goto: jumping over `int proto_type = ...` is ill-formed C++.
    if (lua_gettop(tolua_S) != 1) {
        return 0;
    }
    int proto_type = (int)lua_tointeger(tolua_S, 1);
    if (proto_type == PROTO_JSON || proto_type == PROTO_BUF) {
        proto_man::init(proto_type);
    }
    return 0;
}
// Publishes the `proto_man` Lua module.
int
register_proto_man_export(lua_State* tolua_S) {
    lua_getglobal(tolua_S, "_G");
    const bool globals_ok = lua_istable(tolua_S, -1);
    if (globals_ok) {
        tolua_open(tolua_S);
        tolua_module(tolua_S, "proto_man", 0);
        tolua_beginmodule(tolua_S, "proto_man");
        tolua_function(tolua_S, "init", lua_proto_man_init);
        tolua_function(tolua_S, "proto_type", lua_proto_type);
        tolua_function(tolua_S, "register_protobuf_cmd_map", lua_register_protobuf_cmd_map);
        tolua_endmodule(tolua_S);
    }
    lua_pop(tolua_S, 1);  // drop _G
    return 0;
}
netbus模块的导出
1: 导出netbus模块函数: tcp_listen, ws_listen, udp_listen,三个接口
2: 注册netbus模块内部函数:
tcp_listen: 开启tcp监听端口;
ws_listen: websocket 监听端口;
udp_listen: udp服务器端口;
netbus_export_to_lua.h和netbus_export_to_lua.cc
#ifndef __NETBUS_EXPORT_TO_LUA_H__
#define __NETBUS_EXPORT_TO_LUA_H__
// Forward declaration: the header otherwise needs lua.h to be included first.
struct lua_State;
// Registers the global Lua module `netbus` (udp_listen, tcp_listen, ws_listen, ...).
int register_netbus_export(lua_State* tolua_S);
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "../netbus/netbus.h"
#include "lua_wrapper.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "tolua++.h"
#ifdef __cplusplus
}
#endif
#include "tolua_fix.h"
#include "netbus_export_to_lua.h"
// netbus.udp_listen(port): starts the UDP server on `port`.
static int
lua_udp_listen(lua_State* tolua_S) {
    // Early return instead of goto: jumping over `int port = ...` is ill-formed C++.
    if (lua_gettop(tolua_S) != 1) {
        return 0;
    }
    int port = (int)lua_tointeger(tolua_S, 1);
    netbus::instance()->udp_listen(port);
    return 0;
}
// netbus.tcp_listen(port): starts listening for TCP connections on `port`.
static int
lua_tcp_listen(lua_State* tolua_S) {
    // Early return instead of goto: jumping over `int port = ...` is ill-formed C++.
    if (lua_gettop(tolua_S) != 1) {
        return 0;
    }
    int port = (int)lua_tointeger(tolua_S, 1);
    netbus::instance()->tcp_listen(port);
    return 0;
}
// netbus.ws_listen(port): starts listening for websocket connections on `port`.
static int
lua_ws_listen(lua_State* tolua_S) {
    // Early return instead of goto: jumping over `int port = ...` is ill-formed C++.
    if (lua_gettop(tolua_S) != 1) {
        return 0;
    }
    int port = (int)lua_tointeger(tolua_S, 1);
    netbus::instance()->ws_listen(port);
    return 0;
}
// Publishes the `netbus` Lua module (udp_listen, tcp_listen, ws_listen).
int
register_netbus_export(lua_State* tolua_S) {
    lua_getglobal(tolua_S, "_G");
    const bool globals_ok = lua_istable(tolua_S, -1);
    if (globals_ok) {
        tolua_open(tolua_S);
        tolua_module(tolua_S, "netbus", 0);
        tolua_beginmodule(tolua_S, "netbus");
        tolua_function(tolua_S, "udp_listen", lua_udp_listen);
        tolua_function(tolua_S, "tcp_listen", lua_tcp_listen);
        tolua_function(tolua_S, "ws_listen", lua_ws_listen);
        tolua_endmodule(tolua_S);
    }
    lua_pop(tolua_S, 1);  // drop _G
    return 0;
}
项目模板
基于上面的代码框架,项目既可以使用C/C++进行开发,也可以完全使用lua进行开发,下面是框架的基础模板
C/C++代码
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <string>
using namespace std;
#include "../../netbus/proto_man.h"
#include "../../netbus/netbus.h"
#include "../../utils/logger.h"
#include "../../utils/time_list.h"
#include "../../utils/timestamp.h"
#include "../../database/mysql_wrapper.h"
#include "../../database/redis_wrapper.h"
#include "../../lua_wrapper/lua_wrapper.h"
// Entry point.
// Usage: <exe> <script_dir> <entry_script>. With any other argument count it
// runs the bundled test scripts in ../../apps/lua_test/scripts/main.lua.
int main(int argc, char** argv) {
    netbus::instance()->init();
    lua_wrapper::init();

    std::string search_path;
    std::string entry_script;
    if (argc != 3) { // test mode
        search_path = "../../apps/lua_test/scripts/";
        entry_script = "main.lua";
    }
    else {
        search_path = argv[1];
        // The old `*(end() - 1)` was undefined behaviour for an empty argument.
        if (search_path.empty()) {
            search_path = "./";
        }
        else if (search_path[search_path.size() - 1] != '/') {
            search_path += "/";
        }
        entry_script = argv[2];
    }
    lua_wrapper::add_search_path(search_path);
    std::string lua_file = search_path + entry_script;
    lua_wrapper::do_file(lua_file);

    netbus::instance()->run();
    lua_wrapper::exit();
#ifdef _WIN32
    system("pause");  // "pause" exists only in the Windows shell
#endif
    return 0;
}
lua代码
main.lua
-- Startup script: configures logging and the protocol, then opens the network listeners.
-- Initialize the logger module
logger.init("logger/gateway/", "gateway", true)
--end
-- Initialize the protocol module.
-- NOTE(review): these values presumably mirror the C++ PROTO_JSON / PROTO_BUF constants; confirm.
local proto_type = {
PROTO_JSON = 0,
PROTO_BUF = 1,
}
proto_man.init(proto_type.PROTO_BUF)
-- With protobuf, the ctype -> message-name table must be registered as well
if proto_man.proto_type() == proto_type.PROTO_BUF then
local cmd_name_map = require("cmd_name_map")
if cmd_name_map then
proto_man.register_protobuf_cmd_map(cmd_name_map)
end
end
--end
-- Start the network services (TCP, websocket, UDP)
netbus.tcp_listen(6080)
netbus.ws_listen(8001)
netbus.udp_listen(8002)
--end
print("start service success !!!!")
-- Protobuf message names indexed by ctype (presumably the cmd_name_map module
-- required by main.lua). proto_man.register_protobuf_cmd_map maps array index
-- i to ctype i, starting at 1.
local cmd_name_map = {
"LoginReq", -- ctype 1
"LoginRes", -- ctype 2
}
return cmd_name_map
服务器链接到其他服务器
session添加字段as_client: as_client = 1 表示该session是本服务器主动连接其他服务器建立的, 0 表示是客户端连入的session。另外加上整型字段utag, 用于标识用户。
netbus.cc新增代码
// Heap-allocated user callback + context for one tcp_connect request; stored in
// uv_connect_t::data and freed in after_connect.
struct connect_cb {
void(*on_connected)(int err, session* s, void* udata); // err != 0 on failure (s is NULL then)
void* udata; // passed back unchanged to on_connected
};
// uv_tcp_connect completion callback. Reports the result to the user callback
// (err = 1 and a NULL session on failure), then closes the session on failure
// or starts reading on success. Frees the connect request and its connect_cb.
static void
after_connect(uv_connect_t* handle, int status) {
    uv_session* s = (uv_session*)handle->handle->data;
    struct connect_cb* cb = (struct connect_cb*)handle->data;
    const bool failed = (status != 0);
    if (cb->on_connected) {
        cb->on_connected(failed ? 1 : 0, failed ? NULL : (session*)s, cb->udata);
    }
    if (failed) {
        s->close();
    }
    else {
        uv_read_start((uv_stream_t*)handle->handle, uv_alloc_buf, after_read);
    }
    free(cb);
    free(handle);
}
// Starts an asynchronous TCP connection to server_ip:port (dotted IPv4).
// on_connected (if non-NULL) is called once with the result: err = 0 and the
// new session on success, err = 1 and NULL on failure. Setup failures are now
// reported too, so callers can release whatever they attached to `udata`.
void
netbus::tcp_connect(char* server_ip, int port,
                    void(*on_connected)(int err, session* s, void* udata),
                    void* udata) {
    struct sockaddr_in bind_addr;
    if (uv_ip4_addr(server_ip, port, &bind_addr) != 0) {
        if (on_connected) {
            on_connected(1, NULL, udata);
        }
        return;
    }

    uv_session* s = uv_session::create();
    uv_tcp_t* client = &s->tcp_handler;
    memset(client, 0, sizeof(uv_tcp_t));
    uv_tcp_init(uv_default_loop(), client);
    client->data = (void*)s;
    s->as_client = 1;
    s->socket_type = TCP_SOCKET;
    // NOTE(review): assumes c_address can hold a dotted IPv4 string; uv_ip4_addr
    // above has already rejected anything that is not one.
    strcpy(s->c_address, server_ip);
    s->c_port = port;

    uv_connect_t* connect_req = (uv_connect_t*)malloc(sizeof(uv_connect_t));
    struct connect_cb* cb = (struct connect_cb*)malloc(sizeof(struct connect_cb));
    if (connect_req != NULL && cb != NULL) {
        cb->on_connected = on_connected;
        cb->udata = udata;
        connect_req->data = (void*)cb;
        if (uv_tcp_connect(connect_req, client, (struct sockaddr*)&bind_addr, after_connect) == 0) {
            return;  // after_connect now owns connect_req, cb and the session
        }
    }
    // after_connect will never run: release what it would have released.
    free(cb);
    free(connect_req);
    s->close();
    if (on_connected) {
        on_connected(1, NULL, udata);
    }
}
使用:
// tcp_connect takes a non-const char*; a string literal cannot bind to it in
// C++11 and later, so pass a writable copy.
char server_ip[] = "127.0.0.1";
netbus::instance()->tcp_connect(server_ip, 7788, NULL, NULL);
导出到lua
netbus_export_to_lua.cc
static void
on_tcp_connected(int err, session* s, void* udata) {
if (err) {
lua_pushinteger(lua_wrapper::lua_state(), err);
lua_pushnil(lua_wrapper::lua_state());
}
else {
lua_pushinteger(lua_wrapper::lua_state(), err);
tolua_pushuserdata(lua_wrapper::lua_state(), s);
}
lua_wrapper::execute_script_handler((int)udata, 2);
lua_wrapper::remove_script_handler((int)udata);
}
// netbus.tcp_connect(ip, port, function(err, session) ... end)
// Starts a TCP connection; the Lua function is called once with the result.
static int
lua_tcp_connect(lua_State* tolua_S) {
    // Early returns instead of goto: jumping over `int port = ...` is ill-formed C++.
    const char* ip = luaL_checkstring(tolua_S, 1);
    if (ip == NULL) {
        return 0;
    }
    int port = (int)luaL_checkinteger(tolua_S, 2);
    int handler = toluafix_ref_function(tolua_S, 3, 0);
    if (handler == 0) {
        return 0;
    }
    // tcp_connect takes char* but only reads the string (uv_ip4_addr, strcpy
    // source), so dropping const is safe. The handler id is packed into the
    // void* udata via size_t to be well-formed on 64-bit targets.
    netbus::instance()->tcp_connect(const_cast<char*>(ip), port, on_tcp_connected,
                                    (void*)(size_t)handler);
    return 0;
}
// Registers netbus.tcp_connect. Like the other netbus functions, this line goes
// between tolua_beginmodule(tolua_S, "netbus") and tolua_endmodule(tolua_S)
// in register_netbus_export.
tolua_function(tolua_S, "tcp_connect", lua_tcp_connect);
Unity的TCP网络模块
Unity网络模块分为连接模块,发送数据模块、接收数据模块,其中接收数据模块需要在独立的线程进行,并设置消息队列和消息分发,当主线程检测到消息队列中有消息时才会进行处理,通过这样的方式进行异步处理,防止主线程的等待。
Network基础代码(连接)
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
// TCP client for the game server. It connects on Start, and a dedicated thread
// blocks on Receive() so the main thread never waits on the socket.
public class network : MonoBehaviour {
    public string server_ip;
    public int port;

    private Socket client_socket = null;
    // Written by the connect callback (thread pool) and the receive thread,
    // read on the main thread.
    private volatile bool is_connect = false;
    private Thread recv_thread = null;
    private byte[] recv_buffer = new byte[8192];

    void Awake() {
        DontDestroyOnLoad(this.gameObject);
    }

    // Use this for initialization
    void Start () {
        this.connect_to_server();
        // test
        this.Invoke("close", 5.0f);
        // end
    }

    void on_conntect_timeout() {
    }

    void on_connect_error(string err) {
    }

    // Starts an asynchronous connect and waits up to 5 s for it on this thread.
    void connect_to_server() {
        try {
            this.client_socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            IPAddress ipAddress = IPAddress.Parse(this.server_ip);
            IPEndPoint ipEndpoint = new IPEndPoint(ipAddress, this.port);
            IAsyncResult result = this.client_socket.BeginConnect(ipEndpoint, new AsyncCallback(this.on_connected), this.client_socket);
            bool success = result.AsyncWaitHandle.WaitOne(5000, true);
            if (!success) { // timeout
                // Abandon the pending connect instead of leaking the socket.
                this.close_socket();
                this.on_conntect_timeout();
            }
        }
        catch (System.Exception e) {
            Debug.Log(e.ToString());
            this.close_socket();
            this.on_connect_error(e.ToString());
        }
    }

    // Receive loop, run on recv_thread until the connection ends.
    void on_recv_data() {
        while (this.is_connect) {
            Socket sock = this.client_socket;
            if (sock == null || !sock.Connected) {
                break;
            }
            try {
                int recv_len = sock.Receive(this.recv_buffer);
                if (recv_len <= 0) {
                    // 0 bytes: the server closed the connection. The old loop
                    // spun forever here.
                    break;
                }
                // recv data from server
            }
            catch (System.Exception e) {
                Debug.Log(e.ToString());
                break;
            }
        }
        this.close_socket();
    }

    // BeginConnect callback: starts the receive thread on success.
    void on_connected(IAsyncResult iar) {
        try {
            Socket client = (Socket)iar.AsyncState;
            client.EndConnect(iar);
            this.is_connect = true;
            this.recv_thread = new Thread(new ThreadStart(this.on_recv_data));
            this.recv_thread.Start();
            Debug.Log("connect to server success" + this.server_ip + ":" + this.port + "!");
        }
        catch (System.Exception e) {
            Debug.Log(e.ToString());
            this.on_connect_error(e.ToString());
            this.close_socket();
        }
    }

    // Marks the connection closed and releases the socket. Intended to be
    // callable more than once. Errors from an already-broken connection are
    // ignored because there is nothing left to shut down.
    private void close_socket() {
        this.is_connect = false;
        Socket sock = this.client_socket;
        if (sock == null) {
            return;
        }
        try {
            if (sock.Connected) {
                sock.Shutdown(SocketShutdown.Both);
            }
        }
        catch (System.Exception) {
            // already disconnected or disposed
        }
        sock.Close();
    }

    void close() {
        if (!this.is_connect) {
            return;
        }
        // Closing the socket makes the blocking Receive() on recv_thread throw,
        // so the thread exits on its own. Thread.Abort is not needed and is
        // unsupported on newer .NET runtimes.
        this.close_socket();
        this.recv_thread = null;
    }
}
data_viewer
用于按小尾(little-endian)格式向byte[]缓冲区读写数据
1: 读取byte[]里面的unsigned short, unsigned int 数据
小尾: 低位存在低内存地址, 高位存在高内存地址;
大尾: 低位存在高内存地址, 高位存在低内存地址;
2: write_ushort_le: 写入无符号2个字节;
3: write_uint_le: 写入无符号4个字节;
4: write_bytes: 写入数组;
using System;
// Little-endian read/write helpers for byte[] buffers.
public class data_viewer {
    // Copies BitConverter output into buf at offset in little-endian order,
    // reversing it first when the host is big-endian.
    private static void copy_as_le(byte[] bytes, byte[] buf, int offset) {
        if (!BitConverter.IsLittleEndian) {
            Array.Reverse(bytes);
        }
        Array.Copy(bytes, 0, buf, offset, bytes.Length);
    }

    // Writes a 2-byte unsigned value, little-endian.
    public static void write_ushort_le(byte[] buf, int offset, ushort value) {
        copy_as_le(BitConverter.GetBytes(value), buf, offset);
    }

    // Writes a 4-byte unsigned value, little-endian.
    public static void write_uint_le(byte[] buf, int offset, uint value) {
        copy_as_le(BitConverter.GetBytes(value), buf, offset);
    }

    // Copies the whole of value into dst starting at offset.
    public static void write_bytes(byte[] dst, int offset, byte[] value)
    {
        Array.Copy(value, 0, dst, offset, value.Length);
    }

    // Reads a little-endian 2-byte unsigned value.
    public static ushort read_ushort_le(byte[] data, int offset) {
        int low = data[offset];
        int high = data[offset + 1];
        return (ushort)(low | (high << 8));
    }
}
发送数据
proto_man.cs
将消息用protobuf转换并发送,主要负责编码/解码数据包;
using System;
using System.IO;
using System.Text;
using ProtoBuf;
// Encodes commands as [stype:u16 LE][ctype:u16 LE][utag:u32, left 0][body].
public class proto_man {
    private const int HEADER_SIZE = 8; // 2 stype, 2 ctype, 4 utag, then the body

    // Serializes a protobuf-net message into a new byte array.
    private static byte[] protobuf_serializer(ProtoBuf.IExtensible data)
    {
        using (MemoryStream m = new MemoryStream())
        {
            Serializer.Serialize(m, data);
            return m.ToArray();
        }
    }

    // Builds the command header followed by cmd_body, which may be null (empty body).
    private static byte[] pack_cmd(int stype, int ctype, byte[] cmd_body) {
        int cmd_len = HEADER_SIZE + (cmd_body != null ? cmd_body.Length : 0);
        byte[] cmd = new byte[cmd_len];
        data_viewer.write_ushort_le(cmd, 0, (ushort)stype);
        data_viewer.write_ushort_le(cmd, 2, (ushort)ctype);
        // utag (bytes 4..7) is reserved and stays 0
        if (cmd_body != null) {
            data_viewer.write_bytes(cmd, HEADER_SIZE, cmd_body);
        }
        return cmd;
    }

    // Packs a protobuf command; msg may be null for a command without a body.
    public static byte[] pack_protobuf_cmd(int stype, int ctype, ProtoBuf.IExtensible msg) {
        byte[] cmd_body = (msg != null) ? protobuf_serializer(msg) : null;
        return pack_cmd(stype, ctype, cmd_body);
    }

    // Packs a JSON command (UTF-8 body); a null or empty string sends no body.
    // The old code threw NullReferenceException for a null string.
    public static byte[] pack_json_cmd(int stype, int ctype, string json_msg) {
        byte[] cmd_body = null;
        if (!string.IsNullOrEmpty(json_msg)) {
            cmd_body = Encoding.UTF8.GetBytes(json_msg);
        }
        return pack_cmd(stype, ctype, cmd_body);
    }
}
tcp_packer.cs
为了防止tcp的粘包问题,来做TCP数据包的封包/拆包协议;
using System;
// TCP framing: each packet is prefixed with its total length (2-byte
// little-endian, header included) so the receiver can split the stream.
public class tcp_packer {
    private const int HEADER_SIZE = 2;

    // Returns the framed packet, or null when it would not fit the 16-bit length.
    public static byte[] pack(byte[] cmd_data) {
        int body_len = cmd_data.Length;
        if (body_len > 65535 - HEADER_SIZE) {
            return null;
        }
        int pkg_len = HEADER_SIZE + body_len;
        byte[] pkg = new byte[pkg_len];
        data_viewer.write_ushort_le(pkg, 0, (ushort)pkg_len);
        data_viewer.write_bytes(pkg, HEADER_SIZE, cmd_data);
        return pkg;
    }
}
network.cs新增
发送json或protobuf数据入口
// BeginSend completion: finishes the send and logs any failure.
private void on_send_data(IAsyncResult iar)
{
    try
    {
        ((Socket)iar.AsyncState).EndSend(iar);
    }
    catch (System.Exception e)
    {
        Debug.Log(e.ToString());
    }
}
// Packs a protobuf command, frames it for TCP and sends it asynchronously.
// Packets too large for the 16-bit frame header are dropped with a log line
// (tcp_packer.pack returns null for them; the old code then threw).
public void send_protobuf_cmd(int stype, int ctype, ProtoBuf.IExtensible body) {
    byte[] cmd_data = proto_man.pack_protobuf_cmd(stype, ctype, body);
    if (cmd_data == null) {
        return;
    }
    byte[] tcp_pkg = tcp_packer.pack(cmd_data);
    if (tcp_pkg == null) {
        Debug.Log("send_protobuf_cmd: packet too large, dropped");
        return;
    }
    Socket sock = this.client_socket;
    if (sock == null) {
        return;
    }
    try {
        sock.BeginSend(tcp_pkg, 0, tcp_pkg.Length, SocketFlags.None, new AsyncCallback(this.on_send_data), sock);
    }
    catch (System.Exception e) {
        Debug.Log(e.ToString());
    }
}
// Packs a JSON command, frames it for TCP and sends it asynchronously.
// Packets too large for the 16-bit frame header are dropped with a log line.
public void send_json_cmd(int stype, int ctype, string json_body)
{
    byte[] cmd_data = proto_man.pack_json_cmd(stype, ctype, json_body);
    if (cmd_data == null) {
        return;
    }
    byte[] tcp_pkg = tcp_packer.pack(cmd_data);
    if (tcp_pkg == null) {
        Debug.Log("send_json_cmd: packet too large, dropped");
        return;
    }
    Socket sock = this.client_socket;
    if (sock == null) {
        return;
    }
    try {
        sock.BeginSend(tcp_pkg, 0, tcp_pkg.Length, SocketFlags.None, new AsyncCallback(this.on_send_data), sock);
    }
    catch (System.Exception e) {
        Debug.Log(e.ToString());
    }
}
接收数据
注册服务
session_export_to_lua.cc修改
为了迎合lua脚本中send_msg传进来的msg表包含的数据形式是数组
// session.send_msg(s, {stype, ctype, utag, body})
// Sends one command on session `s`; the table must have exactly 4 elements.
// In JSON mode the body is a string. In protobuf mode the body is a table
// converted with the message type registered for `ctype`.
static int
lua_send_msg(lua_State* tolua_S) {
    // Early returns instead of goto: jumping over `int n = ...` is ill-formed C++.
    session* s = (session*)tolua_touserdata(tolua_S, 1, NULL);
    if (s == NULL || !lua_istable(tolua_S, 2)) {  // stack: 1 s, 2 table
        return 0;
    }
    int n = luaL_len(tolua_S, 2);
    if (n != 4) {
        return 0;
    }

    struct cmd_msg msg;
    lua_pushinteger(tolua_S, 1);
    lua_gettable(tolua_S, 2);
    msg.stype = luaL_checkinteger(tolua_S, -1);
    lua_pushinteger(tolua_S, 2);
    lua_gettable(tolua_S, 2);
    msg.ctype = luaL_checkinteger(tolua_S, -1);
    lua_pushinteger(tolua_S, 3);
    lua_gettable(tolua_S, 2);
    msg.utag = luaL_checkinteger(tolua_S, -1);
    lua_pushinteger(tolua_S, 4);
    lua_gettable(tolua_S, 2);

    if (proto_man::proto_type() == PROTO_JSON) {
        msg.body = (char*)lua_tostring(tolua_S, -1);
        s->send_msg(&msg);
        return 0;
    }
    if (!lua_istable(tolua_S, -1)) {  // protobuf mode without a body table
        msg.body = NULL;
        s->send_msg(&msg);
        return 0;
    }

    // Protobuf message table.
    const char* msg_name = proto_man::protobuf_cmd_name(msg.ctype);
    google::protobuf::Message* body = lua_table_to_protobuf(tolua_S, lua_gettop(tolua_S), msg_name);
    if (body == NULL) {
        // Do not send a packet whose body silently went missing.
        log_error("send_msg: cannot convert body of ctype %d to protobuf message %s\n",
                  (int)msg.ctype, msg_name ? msg_name : "");
        return 0;
    }
    msg.body = body;
    s->send_msg(&msg);
    proto_man::release_message(body);
    return 0;
}
echo_server.lua
-- Echo service command handler.
-- msg is {stype, ctype, utag, body}. body is the decoded request, which may be
-- nil when the client sent the command without a body.
-- Always replies to the client with {stype 1, ctype 2, utag 0, {status = 200}}.
function echo_recv_cmd(s, msg)
    print(msg[1]) -- stype
    print(msg[2]) -- ctype
    print(msg[3]) -- utag
    local body = msg[4]
    -- Indexing a nil body would raise an error inside the service callback
    -- and the reply below would never be sent.
    if body ~= nil then
        print(body.name)
        print(body.email)
        print(body.age)
    end
    -- send to client
    local to_client = {1, 2, 0, {status = 200}}
    session.send_msg(s, to_client)
end
-- Called when a client session disconnects. The echo service keeps no
-- per-session state, so there is nothing to release.
function echo_session_disconnect(_s)
end
-- Callbacks for the echo service.
local echo_service = {}
echo_service.on_session_recv_cmd = echo_recv_cmd
echo_service.on_session_disconnect = echo_session_disconnect

-- Module result: the echo service paired with its service type (stype 1).
local echo_server = {}
echo_server.stype = 1
echo_server.service = echo_service

return echo_server
客户端拆包
1: 编写tcp_package模块,来做TCP数据包的封包/拆包协议;
2: network添加数据成员:
recved: 已经收到的数据;
long_pkg: 大数据包缓冲区:当一个包的大小超过默认接收缓冲区大小 8192 时,用它存放整个包;没有在接收大包时为 null;
long_pkg_size: 当前大数据包的大小;
3: 对比服务器模块编写tcp的拆包;
tcp_packer: 添加接口: bool read_header(byte[] data, int data_len, out int pkg_size, out int head_size);
data_viewer: 添加接口: ushort read_ushort_le(byte[] data, int offset);
4: proto_man解码 cmd:
static bool unpack_cmd_msg(byte[] data, int start, int cmd_len, out cmd_msg msg);
5: proto_man: 添加protobuf 解码函数:
Network.cs
thread_recv_worker是连接后启动的接收数据线程
// Size of the fixed receive buffer. Packets larger than this are assembled in long_pkg.
private const int RECV_LEN = 8192;
// Fixed buffer the receive thread reads into.
private byte[] recv_buf = new byte[RECV_LEN];
// Number of valid buffered bytes (in recv_buf, or in long_pkg while it is non-null).
private int recved;
// Buffer for a single packet larger than RECV_LEN; null when no such packet is in progress.
private byte[] long_pkg = null;
// Total size (header included) of the packet being assembled in long_pkg.
private int long_pkg_size = 0;
// Decode one complete command (data[start .. start + data_len)) into a cmd_msg.
// Test version: treats the payload as a LoginRes and logs its status.
void on_recv_tcp_cmd(byte[] data, int start, int data_len) {
    cmd_msg decoded;
    proto_man.unpack_cmd_msg(data, start, data_len, out decoded);
    if (decoded == null) {
        return;
    }
    // test
    gprotocol.LoginRes login_res = proto_man.protobuf_deserialize<gprotocol.LoginRes>(decoded.body);
    Debug.Log("########## res = " + login_res.status);
    // end
}
// Split the buffered bytes into complete packets and pass each packet's payload
// to on_recv_tcp_cmd. A trailing partial packet is moved to the start of
// recv_buf so the next Receive appends to it.
void on_recv_tcp_data() {
    byte[] pkg_data = (this.long_pkg != null) ? this.long_pkg : this.recv_buf;
    while (this.recved > 0) {
        int pkg_size = 0;
        int head_size = 0;
        if (!tcp_packer.read_header(pkg_data, this.recved, out pkg_size, out head_size)) {
            break;  // not even a complete length header yet
        }
        // A length smaller than the header itself can never be satisfied: the
        // stream is corrupt. Throwing lets the receive thread's handler drop the
        // connection instead of decoding garbage.
        if (pkg_size < head_size) {
            throw new System.IO.InvalidDataException("invalid tcp package size: " + pkg_size);
        }
        // Packet (possibly a large one) not fully received yet; wait for more data.
        if (this.recved < pkg_size) {
            break;
        }
        on_recv_tcp_cmd(pkg_data, head_size, pkg_size - head_size);
        this.recved -= pkg_size;

        // long_pkg holds exactly one packet. Once it has been handled, go back
        // to the fixed buffer. pkg_data still references it for the copy below.
        if (this.long_pkg != null) {
            this.long_pkg = null;
            this.long_pkg_size = 0;
        }
        // Several packets arrived together: move the unprocessed bytes to the
        // front of recv_buf. Array.Copy handles the overlapping case where
        // pkg_data is recv_buf itself, so no new buffer is needed.
        if (this.recved > 0) {
            Array.Copy(pkg_data, pkg_size, this.recv_buf, 0, this.recved);
            pkg_data = this.recv_buf;
        }
    }
}
// Receive thread entry point: reads from the socket until the connection is
// closed or fails, feeding the bytes to on_recv_tcp_data, then tears the
// connection down.
void thread_recv_worker() {
    if (this.is_connect == false) {
        return;
    }
    while (this.client_socket.Connected) {
        try
        {
            int recv_len = 0;
            if (this.recved < RECV_LEN) {
                recv_len = this.client_socket.Receive(this.recv_buf, this.recved, RECV_LEN - this.recved, SocketFlags.None);
            }
            // recv_buf is full without a complete packet: the packet at its start
            // is larger than RECV_LEN, so continue it in long_pkg.
            else {
                if (this.long_pkg == null) {
                    int pkg_size;
                    int head_size;
                    tcp_packer.read_header(this.recv_buf, this.recved, out pkg_size, out head_size);
                    this.long_pkg_size = pkg_size;
                    this.long_pkg = new byte[pkg_size];
                    Array.Copy(this.recv_buf, 0, this.long_pkg, 0, this.recved);
                }
                recv_len = this.client_socket.Receive(this.long_pkg, this.recved, this.long_pkg_size - this.recved, SocketFlags.None);
            }
            // Receive returns 0 when the server has closed the connection.
            // Looping again would spin forever on the dead socket.
            if (recv_len == 0) {
                Debug.Log("connection closed by server");
                break;
            }
            this.recved += recv_len;
            this.on_recv_tcp_data();
        }
        catch (System.Exception e) {
            Debug.Log(e.ToString());
            break;
        }
    }
    // Tear down in one place. Shutdown/Disconnect on an already broken socket
    // can throw, so guard it (previously such a throw escaped the catch block
    // and left is_connect set).
    try {
        this.client_socket.Shutdown(SocketShutdown.Both);
    }
    catch (System.Exception) {
        // Already disconnected; nothing to shut down.
    }
    this.client_socket.Close();
    this.is_connect = false;
}
tcp_package.cs
// Parse the TCP package header: a 2-byte little-endian total package size.
// Returns false (with both outs zeroed) when fewer than 2 bytes are available.
public static bool read_header(byte[] data, int data_len, out int pkg_size, out int head_size) {
    const int LEN_FIELD_BYTES = 2;
    if (data_len >= LEN_FIELD_BYTES) {
        head_size = LEN_FIELD_BYTES;
        pkg_size = data[0] + (data[1] << 8);
        return true;
    }
    pkg_size = 0;
    head_size = 0;
    return false;
}
proto_man.cs
// One decoded command: service type, command type and raw payload bytes.
public class cmd_msg {
// Service type; the network listeners are keyed by it.
public int stype;
// Command type within the service.
public int ctype;
public byte[] body; // payload: serialized protobuf bytes, or UTF-8 encoded JSON text
}
// Decode the command stored in data[start .. start + cmd_len) into a cmd_msg.
// Layout: stype (ushort, little endian) at offset 0, ctype at offset 2, and the
// payload after the first HEADER_SIZE bytes. The payload is copied into a new array.
public static bool unpack_cmd_msg(byte[] data, int start, int cmd_len, out cmd_msg msg) {
    int payload_len = cmd_len - HEADER_SIZE;
    cmd_msg decoded = new cmd_msg();
    decoded.stype = data_viewer.read_ushort_le(data, start);
    decoded.ctype = data_viewer.read_ushort_le(data, start + 2);
    decoded.body = new byte[payload_len];
    Array.Copy(data, start + HEADER_SIZE, decoded.body, 0, payload_len);
    msg = decoded;
    return true;
}
// Deserialize protobuf bytes (protobuf-net Serializer) into a message of type T.
public static T protobuf_deserialize<T>(byte[] _data) {
    T result;
    using (MemoryStream stream = new MemoryStream(_data)) {
        result = Serializer.Deserialize<T>(stream);
    }
    return result;
}
补充调整
Network.cs需要变成一个单例
在 OnDestroy 的时候需要停止网络线程,关闭连接
// The network component registered in Awake; null until then.
public static network _instance;
// Global access point used by other scripts (e.g. network.instance.send_protobuf_cmd).
public static network instance {
get {
return _instance;
}
}
// Register this component as the network singleton and keep it alive across
// scene loads. Because of DontDestroyOnLoad, reloading the scene that contains
// the network object creates a second instance. That instance destroys itself
// instead of replacing the live singleton.
void Awake() {
    if (_instance != null && _instance != this) {
        Destroy(this.gameObject);
        return;
    }
    _instance = this;
    DontDestroyOnLoad(this.gameObject);
}
// Unity callback when this component is destroyed: drop the connection.
void OnDestroy() {
    // Debug.Log("network onDestroy!");
    close();
}
// Unity message sent before the application quits: close the connection.
// (Previously misspelled "OnApplicaitonQuit", which Unity never invokes.)
void OnApplicationQuit() {
    // Debug.Log("OnApplicationQuit");
    this.close();
}
消息队列和消息分发
原理:首先建立一个事件队列。网络线程每解出一个完整的数据包,就把它 push 到队列尾部;主线程在每帧的 Update 中依次取出队列里的所有事件进行处理。每类事件(按 stype 区分)都可以注册监听者,事件触发时调用对应的回调函数。两个线程通过 lock 同步访问队列,以保证线程安全。
Network.cs修改
// Event queue: filled by the receive thread, drained by Update on the main thread
// (both sides lock on the queue object).
private Queue<cmd_msg> net_events = new Queue<cmd_msg>();
// Listener callback invoked for each received command of a given service type (stype).
public delegate void net_message_handler(cmd_msg msg);
// stype -> listeners (a multicast delegate when several are registered).
private Dictionary<int, net_message_handler> event_listeners = new Dictionary<int, net_message_handler>();
// Main thread, once per frame: dispatch the commands queued by the receive
// thread to the listeners registered for their stype.
// (The original snippet was missing this method's closing brace.)
void Update () {
    Queue<cmd_msg> pending;
    lock (this.net_events) {
        if (this.net_events.Count == 0) {
            return;
        }
        // Copy the events out under the lock and dispatch after releasing it,
        // so a slow handler never blocks the receive thread.
        pending = new Queue<cmd_msg>(this.net_events);
        this.net_events.Clear();
    }
    while (pending.Count > 0) {
        cmd_msg msg = pending.Dequeue();
        net_message_handler handler;
        if (!this.event_listeners.TryGetValue(msg.stype, out handler)) {
            continue;
        }
        try {
            handler(msg);
        }
        catch (System.Exception e) {
            // Keep dispatching the remaining events even if one handler fails.
            Debug.Log(e.ToString());
        }
    }
}
// Receive thread: decode one command and queue it for dispatch on the main thread.
void on_recv_tcp_cmd(byte[] data, int start, int data_len) {
    cmd_msg msg;
    proto_man.unpack_cmd_msg(data, start, data_len, out msg);
    if (msg == null) {
        return;
    }
    // net_events is drained by the main thread too, so access it under its lock.
    lock (this.net_events) {
        this.net_events.Enqueue(msg);
    }
}
// Register handler for every command received for service type stype.
// Several handlers for the same stype are combined into one multicast delegate.
public void add_service_listener(int stype, net_message_handler handler) {
    net_message_handler existing;
    if (this.event_listeners.TryGetValue(stype, out existing)) {
        this.event_listeners[stype] = existing + handler;
    }
    else {
        this.event_listeners.Add(stype, handler);
    }
}
// Unregister handler from service type stype. The stype entry is removed once
// no handlers remain.
public void remove_service_listener(int stype, net_message_handler handler) {
    net_message_handler existing;
    if (!this.event_listeners.TryGetValue(stype, out existing)) {
        return;
    }
    net_message_handler remaining = existing - handler;
    if (remaining == null) {
        this.event_listeners.Remove(stype);
    }
    else {
        this.event_listeners[stype] = remaining;
    }
}
使用
// Test scene: sends a LoginReq to service 1 five seconds after start and logs
// the LoginRes status that comes back.
public class game_scene : MonoBehaviour {
    void Start () {
        this.Invoke("test", 5.0f);
        network.instance.add_service_listener(1, this.on_service_event);
    }

    // Listener for service 1; only ctype 2 (LoginRes) is handled.
    void on_service_event(cmd_msg msg) {
        if (msg.ctype != 2) {
            return;
        }
        gprotocol.LoginRes login_res = proto_man.protobuf_deserialize<gprotocol.LoginRes>(msg.body);
        Debug.Log("########## res = " + login_res.status);
    }

    void OnDestroy() {
        network net = network.instance;
        if (net) {
            net.remove_service_listener(1, this.on_service_event);
        }
    }

    // Invoked by name from Start.
    void test()
    {
        gprotocol.LoginReq req = new gprotocol.LoginReq {
            name = "blake",
            email = "blake@bycw.edu",
            age = 34,
            int_set = 8,
        };
        network.instance.send_protobuf_cmd(1, 1, req);
    }
}
Network.cs完整代码
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
// Client-side TCP connection to the game server (Unity component, singleton).
// A background thread receives and unpacks packets. Decoded commands are
// queued and dispatched on the main thread in Update to the listeners
// registered per service type (stype).
public class network : MonoBehaviour {
    // Server endpoint, configured on the component.
    public string server_ip;
    public int port;

    private Socket client_socket = null;
    // Set by the connect callback and cleared by close(). Read by the main
    // thread and the receive thread, hence volatile.
    private volatile bool is_connect = false;
    private Thread recv_thread = null;

    // Fixed receive buffer size. Packets larger than this are assembled in long_pkg.
    private const int RECV_LEN = 8192;
    private byte[] recv_buf = new byte[RECV_LEN];
    // Number of valid buffered bytes (in recv_buf, or in long_pkg while it is non-null).
    private int recved;
    // Buffer for one packet larger than RECV_LEN; null when not in use.
    private byte[] long_pkg = null;
    private int long_pkg_size = 0;

    // Event queue: filled by the receive thread, drained by Update on the main thread.
    private Queue<cmd_msg> net_events = new Queue<cmd_msg>();
    // Listener callback for commands of one service type (stype).
    public delegate void net_message_handler(cmd_msg msg);
    // stype -> listeners (a multicast delegate when several are registered).
    private Dictionary<int, net_message_handler> event_listeners = new Dictionary<int, net_message_handler>();

    public static network _instance;
    public static network instance {
        get {
            return _instance;
        }
    }

    // Register as the singleton and survive scene loads. A duplicate created by
    // reloading the scene that holds the network object destroys itself.
    void Awake() {
        if (_instance != null && _instance != this) {
            Destroy(this.gameObject);
            return;
        }
        _instance = this;
        DontDestroyOnLoad(this.gameObject);
    }

    // Use this for initialization
    void Start () {
        // A duplicate that is being destroyed must not open a second connection.
        if (_instance != this) {
            return;
        }
        this.connect_to_server();
    }

    void OnDestroy() {
        // Debug.Log("network onDestroy!");
        this.close();
        if (_instance == this) {
            _instance = null;
        }
    }

    // Unity message sent before the application quits (the previous spelling
    // "OnApplicaitonQuit" was never invoked by Unity).
    void OnApplicationQuit() {
        // Debug.Log("OnApplicationQuit");
        this.close();
    }

    // Main thread, once per frame: dispatch the queued commands to the
    // listeners registered for their stype.
    void Update () {
        Queue<cmd_msg> pending;
        lock (this.net_events) {
            if (this.net_events.Count == 0) {
                return;
            }
            // Copy out under the lock and dispatch after releasing it, so a slow
            // handler never blocks the receive thread.
            pending = new Queue<cmd_msg>(this.net_events);
            this.net_events.Clear();
        }
        while (pending.Count > 0) {
            cmd_msg msg = pending.Dequeue();
            net_message_handler handler;
            if (!this.event_listeners.TryGetValue(msg.stype, out handler)) {
                continue;
            }
            try {
                handler(msg);
            }
            catch (System.Exception e) {
                // Keep dispatching the remaining events even if one handler fails.
                Debug.Log(e.ToString());
            }
        }
    }

    // Hook for connect timeouts (currently no-op).
    void on_conntect_timeout() {
    }

    // Hook for connect failures (currently no-op).
    void on_connect_error(string err) {
    }

    // Start an asynchronous connect to server_ip:port. Note: this waits up to
    // 5 seconds on the calling (main) thread for the connect to complete.
    void connect_to_server() {
        try {
            this.client_socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            IPAddress ipAddress = IPAddress.Parse(this.server_ip);
            IPEndPoint ipEndpoint = new IPEndPoint(ipAddress, this.port);
            IAsyncResult result = this.client_socket.BeginConnect(ipEndpoint, new AsyncCallback(this.on_connected), this.client_socket);
            bool success = result.AsyncWaitHandle.WaitOne(5000, true);
            if (!success) { // timeout;
                // Abandon the pending attempt instead of leaving it open forever.
                this.client_socket.Close();
                this.on_conntect_timeout();
            }
        }
        catch (System.Exception e) {
            Debug.Log(e.ToString());
            this.on_connect_error(e.ToString());
        }
    }

    // Receive thread: decode one complete command and queue it for the main thread.
    void on_recv_tcp_cmd(byte[] data, int start, int data_len) {
        cmd_msg msg;
        proto_man.unpack_cmd_msg(data, start, data_len, out msg);
        if (msg != null) {
            lock (this.net_events) { // recv thread
                this.net_events.Enqueue(msg);
            }
        }
    }

    // Split the buffered bytes into complete packets. A trailing partial packet
    // is moved to the start of recv_buf for the next Receive.
    void on_recv_tcp_data() {
        byte[] pkg_data = (this.long_pkg != null) ? this.long_pkg : this.recv_buf;
        while (this.recved > 0) {
            int pkg_size = 0;
            int head_size = 0;
            if (!tcp_packer.read_header(pkg_data, this.recved, out pkg_size, out head_size)) {
                break;
            }
            // A length smaller than the header can never be satisfied: corrupt stream.
            if (pkg_size < head_size) {
                throw new System.IO.InvalidDataException("invalid tcp package size: " + pkg_size);
            }
            if (this.recved < pkg_size) {
                break;  // packet not complete yet
            }
            on_recv_tcp_cmd(pkg_data, head_size, pkg_size - head_size);
            this.recved -= pkg_size;

            // long_pkg holds exactly one packet; once handled, return to recv_buf.
            if (this.long_pkg != null) {
                this.long_pkg = null;
                this.long_pkg_size = 0;
            }
            // Move leftover bytes to the front of recv_buf. Array.Copy handles
            // the overlapping case (pkg_data == recv_buf) correctly.
            if (this.recved > 0) {
                Array.Copy(pkg_data, pkg_size, this.recv_buf, 0, this.recved);
                pkg_data = this.recv_buf;
            }
        }
    }

    // Receive thread entry point: reads until the connection closes or fails,
    // then closes the connection.
    void thread_recv_worker() {
        while (this.is_connect && this.client_socket.Connected) {
            try {
                int recv_len;
                if (this.recved < RECV_LEN) {
                    recv_len = this.client_socket.Receive(this.recv_buf, this.recved, RECV_LEN - this.recved, SocketFlags.None);
                }
                else {
                    // recv_buf is full without a complete packet: the packet at its
                    // start is larger than RECV_LEN, so continue it in long_pkg.
                    if (this.long_pkg == null) {
                        int pkg_size;
                        int head_size;
                        tcp_packer.read_header(this.recv_buf, this.recved, out pkg_size, out head_size);
                        this.long_pkg_size = pkg_size;
                        this.long_pkg = new byte[pkg_size];
                        Array.Copy(this.recv_buf, 0, this.long_pkg, 0, this.recved);
                    }
                    recv_len = this.client_socket.Receive(this.long_pkg, this.recved, this.long_pkg_size - this.recved, SocketFlags.None);
                }
                // Receive returns 0 when the server closed the connection; looping
                // again would spin forever on the dead socket.
                if (recv_len == 0) {
                    Debug.Log("connection closed by server");
                    break;
                }
                this.recved += recv_len;
                this.on_recv_tcp_data();
            }
            catch (System.Exception e) {
                // After close() the socket error is expected; only log real failures.
                if (this.is_connect) {
                    Debug.Log(e.ToString());
                }
                break;
            }
        }
        this.close();
    }

    void on_connected(IAsyncResult iar) {
        try {
            Socket client = (Socket)iar.AsyncState;
            client.EndConnect(iar);
            this.is_connect = true;
            this.recv_thread = new Thread(new ThreadStart(this.thread_recv_worker));
            // Background thread: it never keeps the process alive on exit.
            this.recv_thread.IsBackground = true;
            this.recv_thread.Start();
            Debug.Log("connect to server success" + this.server_ip + ":" + this.port + "!");
        }
        catch (System.Exception e) {
            Debug.Log(e.ToString());
            this.on_connect_error(e.ToString());
            this.is_connect = false;
        }
    }

    // Close the connection. Safe to call more than once and from either thread.
    // Closing the socket makes a Receive blocked in the receive thread fail,
    // which ends that thread (replaces Thread.Abort).
    void close() {
        // Clear the flag first so the receive thread treats the resulting
        // socket error as a normal shutdown.
        this.is_connect = false;
        Socket sock = this.client_socket;
        if (sock == null) {
            return;
        }
        try {
            if (sock.Connected) {
                sock.Shutdown(SocketShutdown.Both);
            }
        }
        catch (System.Exception) {
            // Already disconnected or disposed.
        }
        sock.Close();
    }

    private void on_send_data(IAsyncResult iar)
    {
        try
        {
            Socket client = (Socket)iar.AsyncState;
            client.EndSend(iar);
        }
        catch (System.Exception e)
        {
            Debug.Log(e.ToString());
        }
    }

    // Frame cmd_data as a TCP package and start an asynchronous send. The
    // command is dropped (with a log line) when there is no live connection,
    // instead of throwing into the caller.
    private void send_cmd_data(byte[] cmd_data) {
        if (cmd_data == null) {
            return;
        }
        Socket sock = this.client_socket;
        if (!this.is_connect || sock == null) {
            Debug.Log("not connected to server, command dropped");
            return;
        }
        byte[] tcp_pkg = tcp_packer.pack(cmd_data);
        try {
            sock.BeginSend(tcp_pkg, 0, tcp_pkg.Length, SocketFlags.None, new AsyncCallback(this.on_send_data), sock);
        }
        catch (System.Exception e) {
            // BeginSend can fail synchronously, e.g. if the socket was just closed.
            Debug.Log(e.ToString());
        }
    }

    // Send a protobuf command; body may be null for commands without a payload.
    public void send_protobuf_cmd(int stype, int ctype, ProtoBuf.IExtensible body) {
        this.send_cmd_data(proto_man.pack_protobuf_cmd(stype, ctype, body));
    }

    // Send a JSON command.
    public void send_json_cmd(int stype, int ctype, string json_body)
    {
        this.send_cmd_data(proto_man.pack_json_cmd(stype, ctype, json_body));
    }

    // Register handler for every command of service type stype.
    public void add_service_listener(int stype, net_message_handler handler) {
        if (this.event_listeners.ContainsKey(stype)) {
            this.event_listeners[stype] += handler;
        }
        else {
            this.event_listeners.Add(stype, handler);
        }
    }

    // Unregister handler; the stype entry is dropped when no handlers remain.
    public void remove_service_listener(int stype, net_message_handler handler) {
        if (!this.event_listeners.ContainsKey(stype)) {
            return;
        }
        this.event_listeners[stype] -= handler;
        if (this.event_listeners[stype] == null) {
            this.event_listeners.Remove(stype);
        }
    }
}
Unity网络聊天室案例
1: game.proto 服务协议:
客户端发的Req, 服务器响应客户端的Res结尾;
Onxxxxx开头的协议,表示是服务器主动发送的;
例如A发了一个消息,其他所有玩家都可以收到, 服务器主动发的;
LoginReq: 登陆聊天室,他没有body数据部分,所以不需要定义
LoginRes: 登陆返回: 返回的是一个状态码, 是否登陆成功了;
ExitReq: 离开请求: 他没有body数据部分,所以不需要定义;
ExitRes: 离开返回: 返回的是一个状态码, 是否离开成功了
SendMsgReq: 发送消息请求;
SendMsgRes: 发送消息返回;
OnUserLogin: 登陆广播;
OnUserExit: 离开广播;
OnSendMsg: 发送消息广播;
生成protobuf
syntax = "proto2";
// Chat-room protocol. Cmd values are used as the ctype of each command.
// *Req: client -> server request; *Res: server reply to a request;
// On*: pushed by the server without a request (e.g. broadcasts to other users).
enum Cmd {
INVALID_CMD = 0;
eLoginReq = 1;
eLoginRes = 2;
eExitReq = 3;
eExitRes = 4;
eSendMsgReq = 5;
eSendMsgRes = 6;
eOnUserLogin = 7;
eOnUserExit = 8;
eOnSendMsg = 9;
}
// LoginReq and ExitReq carry no body, so no message is defined for them.
// Reply to LoginReq; the chat server script sends 1 on success, -1 if already in the room.
message LoginRes {
required int32 status = 1;
}
// Reply to ExitReq; the chat server script sends 1 on success, -1 if not in the room.
message ExitRes {
required int32 status = 1;
}
// Chat message sent by a client.
message SendMsgReq {
required string content = 1;
}
// Reply to SendMsgReq; the chat server script sends 1 on success, -1 if not in the room.
message SendMsgRes {
required int32 status = 1;
}
// Broadcast: a user (identified by address) joined the room.
message OnUserLogin {
required string ip = 1;
required int32 port = 2;
}
// Broadcast: a user (identified by address) left the room.
message OnUserExit {
required string ip = 1;
required int32 port = 2;
}
// Broadcast: a user sent a chat message.
message OnSendMsg {
required string ip = 1;
required int32 port = 2;
required string content = 3;
}
服务端注册Service
cmd_name_map.lua
-- ctype -> protobuf message name; the index equals the Cmd enum value.
local cmd_name_map = {
    [1] = "LoginReq",
    [2] = "LoginRes",
    [3] = "ExitReq",
    [4] = "ExitRes",
    [5] = "SendMsgReq",
    [6] = "SendMsgRes",
    [7] = "OnUserLogin",
    [8] = "OnUserExit",
    [9] = "OnSendMsg",
}
return cmd_name_map
trm_server.lua
-- Chat-room service command handler (stub).
-- msg layout: {stype, ctype, utag, body}
function on_trm_recv_cmd(_s, _msg)
end
-- Log the address of a client whose session has disconnected.
function on_trm_session_disconnect(s)
    local peer_ip, peer_port = session.get_address(s)
    local line = "trm service on recv disconnect: " .. peer_ip .. " : " .. peer_port
    print(line)
end
-- Callbacks for the chat-room (trm) service.
local trm_service = {}
trm_service.on_session_recv_cmd = on_trm_recv_cmd
trm_service.on_session_disconnect = on_trm_session_disconnect

-- Module result: the service paired with its service type (stype 1).
local trm_server = {}
trm_server.stype = 1
trm_server.service = trm_service

return trm_server
服务器消息处理逻辑
trm_server.lua
-- Sessions of all clients currently in the chat room (a sequence).
local session_set = {}

-- Send msg to every session in the room except except_session.
function broadcast_except(msg, except_session)
    for _, target in ipairs(session_set) do
        if target ~= except_session then
            session.send_msg(target, msg)
        end
    end
end
-- LoginReq (ctype 1): add s to the room, reply LoginRes (ctype 2) and
-- announce the newcomer to everyone else with OnUserLogin (ctype 7).
function on_recv_login_cmd(s)
    local already_in = false
    for _, member in ipairs(session_set) do
        if member == s then
            already_in = true
            break
        end
    end

    if already_in then
        -- status -1: this client is already in the room
        session.send_msg(s, {1, 2, 0, {status = -1}})
        return
    end

    table.insert(session_set, s)
    -- status 1: login succeeded
    session.send_msg(s, {1, 2, 0, {status = 1}})

    local s_ip, s_port = session.get_address(s)
    broadcast_except({1, 7, 0, {ip = s_ip, port = s_port}}, s)
end
-- ExitReq (ctype 3): remove s from the room, reply ExitRes (ctype 4) and
-- announce the departure with OnUserExit (ctype 8).
function on_recv_exit_cmd(s)
    local idx = nil
    for i, member in ipairs(session_set) do
        if member == s then
            idx = i
            break
        end
    end

    if idx == nil then
        -- status -1: the client was not in the room
        session.send_msg(s, {1, 4, 0, {status = -1}})
        return
    end

    table.remove(session_set, idx)
    -- status 1: left the room
    session.send_msg(s, {1, 4, 0, {status = 1}})
    local s_ip, s_port = session.get_address(s)
    broadcast_except({1, 8, 0, {ip = s_ip, port = s_port}}, s)
end
-- SendMsgReq (ctype 5): reply SendMsgRes (ctype 6) and, when s is in the
-- room, broadcast the text str to the other members as OnSendMsg (ctype 9).
function on_recv_send_msg_cmd(s, str)
    local in_room = false
    for _, member in ipairs(session_set) do
        if member == s then
            in_room = true
            break
        end
    end

    if not in_room then
        -- status -1: sender is not in the room
        session.send_msg(s, {1, 6, 0, {status = -1}})
        return
    end

    -- status 1: message accepted
    session.send_msg(s, {1, 6, 0, {status = 1}})
    local s_ip, s_port = session.get_address(s)
    broadcast_except({1, 9, 0, {ip = s_ip, port = s_port, content = str}}, s)
end
-- Dispatch a command received by the chat-room service.
-- msg is {stype, ctype, utag, body}. body is the decoded request table and may
-- be nil when the client sent no body (LoginReq and ExitReq define none).
function on_trm_recv_cmd(s, msg)
    local ctype = msg[2]
    local body = msg[4]
    if ctype == 1 then
        on_recv_login_cmd(s)
    elseif ctype == 3 then
        on_recv_exit_cmd(s)
    elseif ctype == 5 then
        -- A SendMsgReq without a body (or without content) would make
        -- body.content raise an error here. Reject it with status -1 instead.
        if body == nil or body.content == nil then
            session.send_msg(s, {1, 6, 0, {status = -1}})
            return
        end
        on_recv_send_msg_cmd(s, body.content)
    end
end
-- A client's session closed: if it was in the room, remove it and announce
-- the departure to the others with OnUserExit (ctype 8).
function on_trm_session_disconnect(s)
    local ip, port = session.get_address(s)
    local idx = nil
    for i, member in ipairs(session_set) do
        if member == s then
            idx = i
            break
        end
    end
    if idx == nil then
        return
    end

    print("remove from talk room: ".. ip .." : "..port)
    table.remove(session_set, idx)
    local s_ip, s_port = session.get_address(s)
    broadcast_except({1, 8, 0, {ip = s_ip, port = s_port}}, s)
end
-- Callbacks for the chat-room (trm) service.
local trm_service = {}
trm_service.on_session_recv_cmd = on_trm_recv_cmd
trm_service.on_session_disconnect = on_trm_session_disconnect

-- Module result: the service paired with its service type (stype 1).
local trm_server = {}
trm_server.stype = 1
trm_server.service = trm_service

return trm_server
客户端部分代码
talkroom.cs
// Chat-room UI (partial listing): sends chat-room requests to service 1 and
// renders the replies and broadcasts it receives from that service.
public class talkroom : MonoBehaviour {
    // LoginRes: 1 = joined, -1 = already in the room.
    void on_login_return(byte[] body) {
        LoginRes res = proto_man.protobuf_deserialize<LoginRes>(body);
        if (res.status == 1) {
            this.add_status_option("你成功进入聊天室!");
        }
        else if (res.status == -1) {
            this.add_status_option("你已经在聊天室了!");
        }
    }

    // ExitRes: 1 = left, -1 = was not in the room.
    void on_exit_return(byte[] body) {
        ExitRes res = proto_man.protobuf_deserialize<ExitRes>(body);
        if (res.status == 1)
        {
            this.add_status_option("你离开聊天室!");
        }
        else if (res.status == -1)
        {
            this.add_status_option("你早已不在聊天室了!");
        }
    }

    // SendMsgRes: on success show our own last sent message (kept in send_msg).
    void on_send_msg_return(byte[] body) {
        SendMsgRes res = proto_man.protobuf_deserialize<SendMsgRes>(body);
        if (res.status == 1) {
            this.add_self_option(this.send_msg);
        }
        else if (res.status == -1)
        {
            this.add_status_option("你不在聊天室!");
        }
    }

    // Broadcast: another user joined.
    void on_other_user_enter(byte[] body) {
        OnUserLogin res = proto_man.protobuf_deserialize<OnUserLogin>(body);
        this.add_status_option(res.ip + ":" + res.port + "进入聊天室!" );
    }

    // Broadcast: another user left.
    void on_other_user_exit(byte[] body) {
        OnUserExit res = proto_man.protobuf_deserialize<OnUserExit>(body);
        this.add_status_option(res.ip + ":" + res.port + "离开聊天室!");
    }

    // Broadcast: another user sent a chat message.
    void on_other_user_send_msg(byte[] body) {
        OnSendMsg res = proto_man.protobuf_deserialize<OnSendMsg>(body);
        this.add_talk_option(res.ip, res.port, res.content);
    }

    // Listener for service 1: route by command type.
    void on_trm_server_return(cmd_msg msg) {
        switch (msg.ctype) {
            case (int) Cmd.eLoginRes:
                this.on_login_return(msg.body);
                break;
            case (int) Cmd.eExitRes:
                this.on_exit_return(msg.body);
                break;
            case (int) Cmd.eSendMsgRes:
                this.on_send_msg_return(msg.body);
                break;
            case (int) Cmd.eOnUserLogin:
                this.on_other_user_enter(msg.body);
                break;
            case (int) Cmd.eOnUserExit:
                this.on_other_user_exit(msg.body);
                break;
            case (int) Cmd.eOnSendMsg:
                this.on_other_user_send_msg(msg.body);
                break;
        }
    }

    void Start () {
        network.instance.add_service_listener(1, this.on_trm_server_return);
    }

    // The network singleton can outlive this component. Unregister here, or it
    // keeps invoking on_trm_server_return on a destroyed object.
    void OnDestroy() {
        if (network.instance) {
            network.instance.remove_service_listener(1, this.on_trm_server_return);
        }
    }

    // Button click handler (entry point): join the chat room.
    public void on_enter_talkroom() {
        network.instance.send_protobuf_cmd(1, (int)Cmd.eLoginReq, null);
    }

    // Button click handler (entry point): leave the chat room.
    public void on_exit_talkroom() {
        network.instance.send_protobuf_cmd(1, (int)Cmd.eExitReq, null);
    }

    // Button click handler (entry point): send the text in the input field.
    public void on_send_msg() {
        if (this.input.text.Length <= 0) {
            return;
        }
        SendMsgReq req = new SendMsgReq();
        req.content = this.input.text;
        // Remembered so on_send_msg_return can display it once the server confirms.
        this.send_msg = this.input.text;
        network.instance.send_protobuf_cmd(1, (int)Cmd.eSendMsgReq, req);
    }
}