diff --git a/src/connector.cpp b/src/connector.cpp index 577a500..afcc6d1 100644 --- a/src/connector.cpp +++ b/src/connector.cpp @@ -4,11 +4,14 @@ #include #include +#include "struct/message.h" #include "helper/protocol_const.h" #include "helper/functions.h" #include "settings.h" + Connector::Connector() + : _usbBuffer(Settings::usbQueue * 2, Settings::usbBufferSize) { try { @@ -25,28 +28,18 @@ Connector::Connector() throw std::runtime_error(std::string("Can't initialise USB: ") + libusb_error_name(result)); _usbTransfers = Settings::usbQueue; - if(_usbTransfers > MAX_USB_REQUESTS) + if (_usbTransfers > MAX_USB_REQUESTS) _usbTransfers = MAX_USB_REQUESTS; for (int i = 0; i < _usbTransfers; i++) { - _usbTransfer[i] = libusb_alloc_transfer(0); - uint8_t *buf = static_cast(malloc(Settings::usbBufferSize)); - libusb_fill_bulk_transfer(_usbTransfer[i], _device, _endpoint_in, buf, Settings::usbBufferSize, Connector::onUsbRead, this, READ_TIMEOUT); + _usbContext[i].transfer = libusb_alloc_transfer(0); } } Connector::~Connector() { - _active = false; - _queue.notify(); - libusb_interrupt_event_handler(_context); - - if (_write_thread.joinable()) - _write_thread.join(); - - if (_read_thread.joinable()) - _read_thread.join(); + stop(); if (_cipher) { @@ -56,24 +49,22 @@ Connector::~Connector() for (int i = 0; i < _usbTransfers; i++) { - if (_usbTransfer[i]) + if (_usbContext[i].transfer) { - timeval timeout{0, 100000}; - libusb_cancel_transfer(_usbTransfer[i]); - libusb_handle_events_timeout_completed(_context, &timeout, nullptr); + timeval timeout{0, 100000}; + libusb_cancel_transfer(_usbContext[i].transfer); + libusb_handle_events_timeout_completed(_context, &timeout, nullptr); } } for (int i = 0; i < _usbTransfers; i++) { - if (_usbTransfer[i]) + if (_usbContext[i].transfer) { - if (_usbTransfer[i]->buffer) - free(_usbTransfer[i]->buffer); - libusb_free_transfer(_usbTransfer[i]); - _usbTransfer[i] = nullptr; + libusb_free_transfer(_usbContext[i].transfer); + _usbContext[i].transfer = nullptr; } - } + } if (_device) { @@ -102,18 +93,12 @@ void Connector::stop() { if (!_active) return; - _active = false; _queue.notify(); - libusb_interrupt_event_handler(_context); - state(PROTOCOL_STATUS_INITIALISING); if (_write_thread.joinable()) _write_thread.join(); - - if (_read_thread.joinable()) - _read_thread.join(); } bool Connector::connect(uint16_t vendor_id, uint16_t product_id) @@ -355,24 +340,27 @@ void Connector::printMessage(uint32_t cmd, uint32_t length, uint8_t *data, bool void Connector::onUsbRead(libusb_transfer *transfer) { - Connector *c = static_cast(transfer->user_data); + UsbContext *c = static_cast(transfer->user_data); - if (!c->_active) + if (!c->owner->_active) return; - if (transfer->status == LIBUSB_TRANSFER_NO_DEVICE) + c->slot->commit(transfer->actual_length); + try { - c->onDisconnect(); + c->slot = c->owner->_usbBuffer.get(); + } + catch (const std::exception &e) + { + std::cout << "[Connection] USB buffer unavailable: " << e.what() << std::endl; + c->owner->onDisconnect(); return; } - - if (c->_active && transfer->status == LIBUSB_TRANSFER_COMPLETED && transfer->actual_length > 0) - c->onData(transfer->buffer, transfer->actual_length); - - if (c->_active && (libusb_submit_transfer(transfer) != LIBUSB_SUCCESS)) + libusb_fill_bulk_transfer(c->transfer, c->owner->_device, c->owner->_endpoint_in, c->slot->data, c->slot->size, Connector::onUsbRead, c, 0); + if (c->owner->_active && (libusb_submit_transfer(c->transfer) != LIBUSB_SUCCESS)) { std::cout << "[Connection] USB transfer re-submit failed" << std::endl; - c->onDisconnect(); + c->owner->onDisconnect(); } } @@ -383,9 +371,10 @@ void Connector::readLoop() for (int i = 0; i < _usbTransfers; i++) { - _usbTransfer[i]->dev_handle = _device; - _usbTransfer[i]->endpoint = _endpoint_in; - int status = libusb_submit_transfer(_usbTransfer[i]); + _usbContext[i].slot = _usbBuffer.get(); + _usbContext[i].owner = this; + libusb_fill_bulk_transfer(_usbContext[i].transfer, _device, _endpoint_in, _usbContext[i].slot->data, _usbContext[i].slot->size, Connector::onUsbRead, &_usbContext[i], 0); + int status = libusb_submit_transfer(_usbContext[i].transfer); if (status != LIBUSB_SUCCESS) { std::cout << "[Connection] USB transfer submit " << i << " failed: " << status << std::endl; @@ -400,6 +389,63 @@ void Connector::readLoop() } } +void Connector::bufferReadLoop() +{ + setThreadName("protocol-log"); + + while (_active && _connected) + { + Header header{0, 0, 0, 0}; + uint8_t *data = nullptr; + + if (!_usbBuffer.read(reinterpret_cast(&header), sizeof(Header))) + break; + + int32_t payloadLength = static_cast(header.length); + int32_t padding = header.type == CMD_VIDEO_DATA ? AV_INPUT_BUFFER_PADDING_SIZE : 0; + + //std::cout << "[Connection] Chunk: cmd " << header.type << " len " << header.length << " magic " << header.magic << " queue state " << _usbBuffer.count() << std::endl; + + if (payloadLength > 0) + { + data = static_cast(malloc(payloadLength + padding)); + + if (!_usbBuffer.read(data, payloadLength)) + { + free(data); + break; + } + } + + if (header.magic == MAGIC_ENC && payloadLength > 0) + { + if (!_cipher) + { + std::cout << "[Connection] Received encrypted buffered command " << header.type + << " but cipher is not initialised" << std::endl; + free(data); + continue; + } + + if (!_cipher->Decrypt(data, payloadLength)) + { + std::cout << "[Connection] Can't decrypt buffered command " << header.type << std::endl; + free(data); + continue; + } + } + +#ifdef PROTOCOL_DEBUG + printMessage(header.type, payloadLength, data, header.magic == MAGIC_ENC, false); +#endif + + if (padding > 0 && data) + std::fill(data + payloadLength, data + payloadLength + padding, 0); + + onData(header.type, payloadLength, data); + } +} + void Connector::writeLoop() { // Set thread name @@ -413,10 +459,13 @@ void Connector::writeLoop() { std::cout << "[Connection] Device connected" << std::endl; + _usbBuffer.start(); _read_thread = std::thread(&Connector::readLoop, this); - onDevice(true); + _buffer_thread = std::thread(&Connector::bufferReadLoop, this); + onDevice(true); state(PROTOCOL_STATUS_ONLINE); + while (_connected && _active) { std::unique_ptr message = _queue.pop(); @@ -440,10 +489,14 @@ void Connector::writeLoop() } _queue.clear(); + _usbBuffer.stop(); if (_read_thread.joinable()) _read_thread.join(); + + if (_buffer_thread.joinable()) + _buffer_thread.join(); } - _queue.waitFor(_active, 1000); + _queue.waitFor(_active, 100); } } diff --git a/src/connector.h b/src/connector.h index 0fee361..414c727 100644 --- a/src/connector.h +++ b/src/connector.h @@ -9,12 +9,12 @@ #include #include "helper/isender.h" -#include "aes_cipher.h" #include "struct/atomic_queue.h" #include "struct/command.h" +#include "struct/usb_buffer.h" +#include "aes_cipher.h" -#define READ_TIMEOUT 10000 -#define MAX_USB_REQUESTS 64 +#define MAX_USB_REQUESTS 128 #define COMMAND_QUEUE_SIZE 256 #define ENCRYPTION_BASE "SkBRDy3gmrw1ieH0" @@ -24,6 +24,13 @@ #define PROTOCOL_DEBUG_OUT 3 #define PROTOCOL_DEBUG_ALL 4 +class Connector; + +struct UsbContext { + Connector* owner; + DataSlot* slot; + libusb_transfer* transfer; +}; class Connector : public ISender { @@ -37,7 +44,7 @@ public: bool send(std::unique_ptr packet) override; protected: - virtual void onData(uint8_t *data, uint32_t length) = 0; + virtual void onData(uint32_t cmd, uint32_t length, uint8_t *data) = 0; virtual void onStatus(u_int8_t status) = 0; virtual void onDevice(bool connected) = 0; @@ -48,12 +55,14 @@ protected: static void printBytes(uint8_t *data, uint32_t length, uint16_t max); static const char *cmdString(int cmd); - AESCipher *_cipher = nullptr; + AESCipher *_cipher = nullptr; + UsbBuffer _usbBuffer; private: static void onUsbRead(libusb_transfer *transfer); void readLoop(); + void bufferReadLoop(); void writeLoop(); void onDisconnect(); bool connect(uint16_t vendor_id, uint16_t product_id); @@ -76,11 +85,12 @@ private: uint8_t _nodeviceCount; std::thread _read_thread; + std::thread _buffer_thread; std::thread _write_thread; std::mutex _write_mutex; std::atomic _active = false; AtomicQueue _queue{COMMAND_QUEUE_SIZE}; - libusb_transfer *_usbTransfer[MAX_USB_REQUESTS] = {}; + UsbContext _usbContext[MAX_USB_REQUESTS] = {}; }; #endif /* SRC_CONNECTOR */ diff --git a/src/protocol.cpp b/src/protocol.cpp index e2ecb9c..af772ec 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -18,7 +18,6 @@ Protocol::Protocol(uint16_t width, uint16_t height, uint16_t fps) _fps(fps), _phoneConnected(false) { - } Protocol::~Protocol() @@ -163,54 +162,64 @@ void Protocol::onControl(int cmd) } } -void Protocol::onData(uint8_t *data, uint32_t length) +void Protocol::onData(uint32_t cmd, uint32_t length, uint8_t *data) { - uint32_t offset = 0; - while (offset < length) + bool dispose = true; + + switch (cmd) { - if (_message == nullptr) - _message = std::make_unique(); - offset += _message->parse(data + offset, length - offset); - - if (!_message->ready()) - continue; - - if (!_message->valid()) + case CMD_CONTROL: + if (length == 4) { - std::cout << "[Connection] Invalid message received" << std::endl; - _message = nullptr; + int value = 0; + memcpy(&value, data, sizeof(int)); + onControl(value); + } + break; - while(true) + case CMD_PLUGGED: + onPhone(true); + break; + + case CMD_UNPLUGGED: + onPhone(false); + break; + + case CMD_VIDEO_DATA: + if (length > 20) + { + videoData.pushDiscard(std::make_unique(data, length, 20)); + dispose = false; + } + break; + + case CMD_AUDIO_DATA: + if (length > 16) + { + int channel = 0; + memcpy(&channel, data + 8, sizeof(int)); + + if (channel == 1) { - if (length - offset < sizeof(uint32_t)) - return; - uint32_t magic = 0; - memcpy(&magic, data + offset, sizeof(uint32_t)); - if (magic == MAGIC || magic == MAGIC_ENC) - break; - offset++; + audioStreamMain.pushDiscard(std::make_unique(data, length, 12)); + dispose = false; + } + else if (channel == 2) + { + audioStreamAux.pushDiscard(std::make_unique(data, length, 12)); + dispose = false; } - - continue; } + break; - if (_message->encrypted() && !_message->decrypt(_cipher)) - { - if (!_cipher) - std::cout << "[Connection] Received encrypted command " << _message->type() << " but cipher is not initialised" << std::endl; - else - std::cout << "[Connection] Can't decrypt command " << _message->type() << std::endl; - _message = nullptr; - continue; - } - -#ifdef PROTOCOL_DEBUG - printMessage(_message->type(), _message->length(), _message->data(), _message->encrypted(), false); -#endif - - dispatch(std::move(_message)); - _message = nullptr; + case CMD_ENCRYPTION: + if (length == 0) + setEncryption(true); + break; } + + if (dispose && data && length > 0) + free(data); } void Protocol::dispatch(std::unique_ptr msg) @@ -233,7 +242,7 @@ void Protocol::dispatch(std::unique_ptr msg) case CMD_VIDEO_DATA: { - if(msg->setOffset(20)) + if (msg->setOffset(20)) videoData.pushDiscard(std::move(msg)); break; } diff --git a/src/protocol.h b/src/protocol.h index 056e323..87f134c 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -28,7 +28,7 @@ private: void onStatus(uint8_t status) override; void onDevice(bool connected) override; - void onData(uint8_t *data, uint32_t length) override; + void onData(uint32_t cmd, uint32_t length, uint8_t *data) override; void dispatch(std::unique_ptr msg); void onControl(int cmd); @@ -42,8 +42,6 @@ private: uint32_t _evtStatusId = (uint32_t)-1; uint32_t _evtPhoneId = (uint32_t)-1; - - std::unique_ptr _message; }; #endif /* SRC_PROTOCOL */ diff --git a/src/struct/message.h b/src/struct/message.h index aa3f136..34fca75 100644 --- a/src/struct/message.h +++ b/src/struct/message.h @@ -28,6 +28,17 @@ public: { } + Message(uint8_t *data, uint32_t length, uint32_t offset) + : _header({0, static_cast(length), 0, 0}), + _data(data), + _offset(offset <= length ? offset : length), + _headerLegth(sizeof(Header)), + _dataLength(length), + _valid(true), + _ready(true) + { + } + ~Message() { if (_data) diff --git a/src/struct/usb_buffer.h b/src/struct/usb_buffer.h new file mode 100644 index 0000000..3e7c421 --- /dev/null +++ b/src/struct/usb_buffer.h @@ -0,0 +1,197 @@ +#ifndef SRC_STRUCT_USB_BUFFER +#define SRC_STRUCT_USB_BUFFER + +#include +#include +#include +#include +#include +#include +#include +#include + +class DataSlot +{ +public: + DataSlot() + : ready(false), offset(0), length(0), size(0), data(nullptr), _cv(nullptr) + { + } + + ~DataSlot() + { + size = 0; + if (data) + { + free(data); + data = nullptr; + } + } + + void init(uint32_t slotSize, std::condition_variable *condition) + { + ready = false; + offset = 0; + length = 0; + size = slotSize; + data = static_cast(malloc(size)); + _cv = condition; + } + + void reset() + { + ready = false; + offset = 0; + length = 0; + } + + void commit(size_t dataSize) + { + length = dataSize; + offset = 0; + ready = true; + + if (_cv) + _cv->notify_one(); + } + + bool consume(size_t dataSize) + { + offset += dataSize; + if (offset < length) + return false; + ready = false; + return true; + } + + size_t remain() const { return length > offset ? length - offset : 0; } + + bool ready; + size_t offset; + size_t length; + size_t size; + uint8_t *data; + +private: + std::condition_variable *_cv; +}; + +class UsbBuffer +{ +public: + UsbBuffer(uint16_t slotCount, size_t slotSize) + : _active(true), _size(slotCount) + { + if (slotCount == 0 || slotSize == 0) + throw std::invalid_argument("[UsbBuffer] Number of slots and slot size must be greater than 0"); + + _slots = new DataSlot[_size]; + + for (uint16_t i = 0; i < _size; i++) + { + _slots[i].init(slotSize, &_cvReady); + } + } + + UsbBuffer(const UsbBuffer &) = delete; + UsbBuffer &operator=(const UsbBuffer &) = delete; + + ~UsbBuffer() + { + stop(); + if (_slots) + { + delete[] _slots; + } + } + + void start() + { + _readSlot = 0; + _writeSlot = 0; + for (uint16_t i = 0; i < _size; i++) + { + _slots[i].reset(); + } + _active = true; + } + + void stop() + { + _active = false; + std::lock_guard lock(_mutex); + _cvReady.notify_all(); + } + + DataSlot *get() + { + if (!_active || _slots[_writeSlot].ready) + throw std::runtime_error("[UsbBuffer] No free slots available"); + DataSlot *slot = &(_slots[_writeSlot]); + _writeSlot++; + if (_writeSlot >= _size) + _writeSlot = 0; + return slot; + } + + bool read(uint8_t *dst, size_t length) + { + if (length == 0) + return true; + + if (dst == nullptr) + throw std::invalid_argument("[UsbBuffer] Read destination is null"); + + size_t done = 0; + while (length > 0) + { + if (!_active) + return false; + + while (!_slots[_readSlot].ready) + { + std::unique_lock lock(_mutex); + _cvReady.wait(lock, [&]() + { return !_active || _slots[_readSlot].ready; }); + if (!_active) + return false; + } + + size_t copy = _slots[_readSlot].remain(); + if (copy > length) + copy = length; + std::memcpy(dst + done, _slots[_readSlot].data + _slots[_readSlot].offset, copy); + if (_slots[_readSlot].consume(copy)) + { + _readSlot++; + if (_readSlot >= _size) + _readSlot = 0; + } + done += copy; + length -= copy; + } + + return true; + } + + int count() const { + int result = _writeSlot - _readSlot; + if(result<0) + result += _size; + return result; + } + +private: + mutable std::mutex _mutex; + std::condition_variable _cvReady; + + std::atomic _active; + + uint16_t _writeSlot = 0; + uint16_t _readSlot = 0; + + DataSlot *_slots = nullptr; + uint16_t _size = 0; +}; + +#endif /* SRC_STRUCT_USB_BUFFER */