#include "lwip/sys.h"
+#include "lwip/netif.h"
#include "lwip/tcpip.h"
#include "rosip.h"
#include <debug.h>
+/* Human-readable TCP state names, looked up by pcb->state in LibTCPDumpPcb.
+ * NOTE(review): the order presumably mirrors lwIP's enum tcp_state - confirm
+ * against the lwIP headers before adding or reordering entries. */
static const char * const tcp_state_str[] = {
- "CLOSED",
- "LISTEN",
- "SYN_SENT",
- "SYN_RCVD",
- "ESTABLISHED",
- "FIN_WAIT_1",
- "FIN_WAIT_2",
- "CLOSE_WAIT",
- "CLOSING",
- "LAST_ACK",
- "TIME_WAIT"
+ "CLOSED",
+ "LISTEN",
+ "SYN_SENT",
+ "SYN_RCVD",
+ "ESTABLISHED",
+ "FIN_WAIT_1",
+ "FIN_WAIT_2",
+ "CLOSE_WAIT",
+ "CLOSING",
+ "LAST_ACK",
+ "TIME_WAIT"
};
/* The way that lwIP does multi-threading is really not ideal for our purposes but
extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
+/* Required for ERR_T to NTSTATUS translation in receive error handling */
+NTSTATUS TCPTranslateError(const err_t err);
+
+/**
+ * Prints the TCP state and the remote endpoint of a PCB to the debugger.
+ *
+ * SocketContext - the lwIP struct tcp_pcb to dump. A NULL context is
+ *                 reported and otherwise ignored.
+ */
+void
+LibTCPDumpPcb(PVOID SocketContext)
+{
+    struct tcp_pcb *pcb = (struct tcp_pcb*)SocketContext;
+    const unsigned int StateCount = sizeof(tcp_state_str) / sizeof(tcp_state_str[0]);
+    const char *StateName;
+    unsigned int addr;
+
+    /* Other code in this file resets the socket context to NULL, so tolerate it */
+    if (pcb == NULL)
+    {
+        DbgPrint("\tState: (no PCB)\n");
+        return;
+    }
+
+    /* Never index past the end of the state name table */
+    if ((unsigned int)pcb->state < StateCount)
+        StateName = tcp_state_str[pcb->state];
+    else
+        StateName = "UNKNOWN";
+
+    addr = ntohl(pcb->remote_ip.addr);
+
+    DbgPrint("\tState: %s\n", StateName);
+    /* All arguments are unsigned, so use %u rather than %d */
+    DbgPrint("\tRemote: (%u.%u.%u.%u, %u)\n",
+             (addr >> 24) & 0xFF,
+             (addr >> 16) & 0xFF,
+             (addr >> 8) & 0xFF,
+             addr & 0xFF,
+             (unsigned int)pcb->remote_port);
+}
+
static
void
LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
PQUEUE_ENTRY qp = NULL;
ReferenceObject(Connection);
-
+
while (!IsListEmpty(&Connection->PacketQueue))
{
Entry = RemoveHeadList(&Connection->PacketQueue);
qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
qp->p = p;
+ qp->Offset = 0;
ExInterlockedInsertTailList(&Connection->PacketQueue, &qp->ListEntry, &Connection->Lock);
}
if (IsListEmpty(&Connection->PacketQueue)) return NULL;
Entry = RemoveHeadList(&Connection->PacketQueue);
-
+
qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
return qp;
{
PQUEUE_ENTRY qp;
struct pbuf* p;
- NTSTATUS Status = STATUS_PENDING;
- UINT ReadLength, ExistingDataLength, SpaceLeft;
+ NTSTATUS Status;
+ UINT ReadLength, PayloadLength, Offset, Copied;
KIRQL OldIrql;
(*Received) = 0;
- SpaceLeft = RecvLen;
LockObject(Connection, &OldIrql);
while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
{
p = qp->p;
- ExistingDataLength = (*Received);
- Status = STATUS_SUCCESS;
+ /* Calculate the payload length first */
+ PayloadLength = p->tot_len;
+ PayloadLength -= qp->Offset;
+ Offset = qp->Offset;
- ReadLength = MIN(p->tot_len, SpaceLeft);
- if (ReadLength != p->tot_len)
+ /* Check if we're reading the whole buffer */
+ ReadLength = MIN(PayloadLength, RecvLen);
+ ASSERT(ReadLength != 0);
+ if (ReadLength != PayloadLength)
{
- if (ExistingDataLength)
- {
- /* The packet was too big but we used some data already so give it another shot later */
- InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
- break;
- }
- else
- {
- /* The packet is just too big to fit fully in our buffer, even when empty so
- * return an informative status but still copy all the data we can fit.
- */
- Status = STATUS_BUFFER_OVERFLOW;
- }
+ /* Save this one for later */
+ qp->Offset += ReadLength;
+ InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
+ qp = NULL;
}
UnlockObject(Connection, OldIrql);
- /* Return to a lower IRQL because the receive buffer may be pageable memory */
- for (; (*Received) < ReadLength + ExistingDataLength; (*Received) += p->len, p = p->next)
- {
- RtlCopyMemory(RecvBuffer + (*Received), p->payload, p->len);
- }
+ Copied = pbuf_copy_partial(p, RecvBuffer, ReadLength, Offset);
+ ASSERT(Copied == ReadLength);
LockObject(Connection, &OldIrql);
- SpaceLeft -= ReadLength;
+ /* Update trackers */
+ RecvLen -= ReadLength;
+ RecvBuffer += ReadLength;
+ (*Received) += ReadLength;
- /* Use this special pbuf free callback function because we're outside tcpip thread */
- pbuf_free_callback(qp->p);
+ if (qp != NULL)
+ {
+ /* Use this special pbuf free callback function because we're outside tcpip thread */
+ pbuf_free_callback(qp->p);
- ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
+ ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
+ }
+ else
+ {
+ /* If we get here, it means we've filled the buffer */
+ ASSERT(RecvLen == 0);
+ }
- if (!RecvLen)
- break;
+ ASSERT((*Received) != 0);
+ Status = STATUS_SUCCESS;
- if (Status != STATUS_SUCCESS)
+ if (!RecvLen)
break;
}
}
else
{
if (Connection->ReceiveShutdown)
- Status = STATUS_SUCCESS;
+ Status = Connection->ReceiveShutdownStatus;
else
Status = STATUS_PENDING;
}
WaitForEventSafely(PRKEVENT Event)
{
PVOID WaitObjects[] = {Event, &TerminationEvent};
-
+
if (KeWaitForMultipleObjects(2,
WaitObjects,
WaitAny,
{
/* Make sure the socket didn't get closed */
if (!arg) return ERR_OK;
-
+
TCPSendEventHandler(arg, space);
-
+
return ERR_OK;
}
InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
{
PCONNECTION_ENDPOINT Connection = arg;
- u32_t len;
/* Make sure the socket didn't get closed */
if (!arg)
return ERR_OK;
}
-
+
if (p)
{
- len = TCPRecvEventHandler(arg, p);
- if (len == p->tot_len)
- {
- tcp_recved(pcb, len);
+ LibTCPEnqueuePacket(Connection, p);
- pbuf_free(p);
+ tcp_recved(pcb, p->tot_len);
- return ERR_OK;
- }
- else if (len != 0)
- {
- DbgPrint("UNTESTED CASE: NOT ALL DATA TAKEN! EXTRA DATA MAY BE LOST!\n");
-
- tcp_recved(pcb, len);
-
- /* Possible memory leak of pbuf here? */
-
- return ERR_OK;
- }
- else
- {
- LibTCPEnqueuePacket(Connection, p);
-
- tcp_recved(pcb, p->tot_len);
-
- return ERR_OK;
- }
+ TCPRecvEventHandler(arg);
}
else if (err == ERR_OK)
{
* whole socket here (by calling tcp_close()) as that would violate TCP specs
*/
Connection->ReceiveShutdown = TRUE;
- TCPFinEventHandler(arg, ERR_OK);
+ Connection->ReceiveShutdownStatus = STATUS_SUCCESS;
+
+ /* If we already did a send shutdown, we're in TIME_WAIT so we can't use this PCB anymore */
+ if (Connection->SendShutdown)
+ {
+ Connection->SocketContext = NULL;
+ tcp_arg(pcb, NULL);
+ }
+
+ /* Indicate the graceful close event */
+ TCPRecvEventHandler(arg);
+
+ /* If the PCB is gone, clean up the connection */
+ if (Connection->SendShutdown)
+ {
+ TCPFinEventHandler(Connection, ERR_CLSD);
+ }
}
return ERR_OK;
}
+/* Accept callback installed on listening PCBs with tcp_accept().
+ *
+ * arg    - callback argument of the listener; NULL means the socket was closed.
+ * newpcb - PCB offered for the incoming connection.
+ * err    - not examined by this handler.
+ *
+ * Returns ERR_OK when the connection was taken, ERR_CLSD otherwise.
+ *
+ * This function MUST return an error value that is not ERR_ABRT or ERR_OK if the connection
+ * is not accepted to avoid leaking the new PCB */
static
err_t
InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
{
/* Make sure the socket didn't get closed */
if (!arg)
- return ERR_ABRT;
+ return ERR_CLSD;
TCPAcceptEventHandler(arg, newpcb);
+    /* A non-NULL callback_arg after the upcall is treated as "accepted";
+     * presumably it is set through tcp_arg() in LibTCPAccept - confirm */
if (newpcb->callback_arg)
return ERR_OK;
else
- return ERR_ABRT;
+ return ERR_CLSD;
}
+/* Error callback installed on connection PCBs with tcp_err().
+ *
+ * arg - the connection endpoint, or NULL if callbacks were detached.
+ * err - lwIP error code describing why the connection failed.
+ *
+ * Does nothing when arg is NULL or the endpoint's SocketContext is already
+ * NULL. Otherwise it clears SocketContext, marks the receive side as shut
+ * down with the translated error, signals the receive handler once, and then
+ * signals the FIN handler with the original error. */
static
void
InternalErrorEventHandler(void *arg, const err_t err)
{
+ PCONNECTION_ENDPOINT Connection = arg;
+
/* Make sure the socket didn't get closed */
- if (!arg) return;
+ if (!arg || Connection->SocketContext == NULL) return;
- TCPFinEventHandler(arg, err);
+ /* The PCB is dead now */
+ Connection->SocketContext = NULL;
+
+ /* Give them one shot to receive the remaining data */
+ Connection->ReceiveShutdown = TRUE;
+ Connection->ReceiveShutdownStatus = TCPTranslateError(err);
+ TCPRecvEventHandler(Connection);
+
+ /* Terminate the connection */
+ TCPFinEventHandler(Connection, err);
}
static
{
struct lwip_callback_msg *msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
struct tcp_pcb *ret;
-
+
if (msg)
{
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
ret = msg->Output.Socket.NewPcb;
else
ret = NULL;
-
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
-
+
return NULL;
}
LibTCPBindCallback(void *arg)
{
struct lwip_callback_msg *msg = arg;
+ PTCP_PCB pcb = msg->Input.Bind.Connection->SocketContext;
ASSERT(msg);
goto done;
}
- msg->Output.Bind.Error = tcp_bind((PTCP_PCB)msg->Input.Bind.Connection->SocketContext,
+ /* We're guaranteed that the local address is valid to bind at this point */
+ pcb->so_options |= SOF_REUSEADDR;
+
+ msg->Output.Bind.Error = tcp_bind(pcb,
msg->Input.Bind.IpAddress,
ntohs(msg->Input.Bind.Port));
-
+
done:
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
{
struct lwip_callback_msg *msg;
err_t ret;
-
+
msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
if (msg)
{
msg->Input.Bind.Connection = Connection;
msg->Input.Bind.IpAddress = ipaddr;
msg->Input.Bind.Port = port;
-
+
tcpip_callback_with_block(LibTCPBindCallback, msg, 1);
-
+
if (WaitForEventSafely(&msg->Event))
ret = msg->Output.Bind.Error;
else
ret = ERR_CLSD;
-
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
-
+
return ERR_MEM;
}
LibTCPListenCallback(void *arg)
{
struct lwip_callback_msg *msg = arg;
-
+
ASSERT(msg);
-
+
if (!msg->Input.Listen.Connection->SocketContext)
{
msg->Output.Listen.NewPcb = NULL;
}
msg->Output.Listen.NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Input.Listen.Connection->SocketContext, msg->Input.Listen.Backlog);
-
+
if (msg->Output.Listen.NewPcb)
{
tcp_accept(msg->Output.Listen.NewPcb, InternalAcceptEventHandler);
}
-
+
done:
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
{
struct lwip_callback_msg *msg;
PTCP_PCB ret;
-
+
msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
if (msg)
{
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Input.Listen.Connection = Connection;
msg->Input.Listen.Backlog = backlog;
-
+
tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
-
+
if (WaitForEventSafely(&msg->Event))
ret = msg->Output.Listen.NewPcb;
else
ret = NULL;
-
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
LibTCPSendCallback(void *arg)
{
+    /* Executes one queued send request described by a struct lwip_callback_msg.
+     *
+     * Output.Send.Error is set to:
+     *   ERR_CLSD       - the endpoint has no PCB or its send side is shut down
+     *   ERR_INPROGRESS - the send buffer is full, or tcp_write() reported ERR_MEM
+     *   otherwise      - whatever tcp_write() returned
+     * On ERR_OK, Output.Send.Information holds the number of bytes queued, which
+     * may be less than the requested length when the send buffer is short.
+     * msg->Event is always signalled before returning. */
struct lwip_callback_msg *msg = arg;
-
+    PTCP_PCB pcb;
+ ULONG SendLength;
+ UCHAR SendFlags;
+
ASSERT(msg);
-
+
+    /* Fetch the PCB only after the assertion, so a NULL msg trips the ASSERT
+     * instead of faulting on the dereference */
+    pcb = msg->Input.Send.Connection->SocketContext;
+
if (!msg->Input.Send.Connection->SocketContext)
{
msg->Output.Send.Error = ERR_CLSD;
goto done;
}
-
+
if (msg->Input.Send.Connection->SendShutdown)
{
msg->Output.Send.Error = ERR_CLSD;
goto done;
}
-
- msg->Output.Send.Error = tcp_write((PTCP_PCB)msg->Input.Send.Connection->SocketContext,
- msg->Input.Send.Data,
- msg->Input.Send.DataLength,
- TCP_WRITE_FLAG_COPY);
- if (msg->Output.Send.Error == ERR_MEM)
+
+ SendFlags = TCP_WRITE_FLAG_COPY;
+ SendLength = msg->Input.Send.DataLength;
+ if (tcp_sndbuf(pcb) == 0)
{
/* No buffer space so return pending */
msg->Output.Send.Error = ERR_INPROGRESS;
+ goto done;
}
- else if (msg->Output.Send.Error == ERR_OK)
+ else if (tcp_sndbuf(pcb) < SendLength)
{
- /* Queued successfully so try to send it */
+ /* We've got some room so let's send what we can */
+ SendLength = tcp_sndbuf(pcb);
+
+ /* Don't set the push flag */
+ SendFlags |= TCP_WRITE_FLAG_MORE;
+ }
+
+ msg->Output.Send.Error = tcp_write(pcb,
+ msg->Input.Send.Data,
+ SendLength,
+ SendFlags);
+ if (msg->Output.Send.Error == ERR_OK)
+ {
+ /* Queued successfully so try to send it */
tcp_output((PTCP_PCB)msg->Input.Send.Connection->SocketContext);
+ msg->Output.Send.Information = SendLength;
}
-
+ else if (msg->Output.Send.Error == ERR_MEM)
+ {
+ /* The queue is too long */
+ msg->Output.Send.Error = ERR_INPROGRESS;
+ }
+
done:
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
err_t
-LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe)
+LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, u32_t *sent, const int safe)
{
err_t ret;
struct lwip_callback_msg *msg;
-
+
msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
if (msg)
{
msg->Input.Send.Connection = Connection;
msg->Input.Send.Data = dataptr;
msg->Input.Send.DataLength = len;
-
+
if (safe)
LibTCPSendCallback(msg);
else
tcpip_callback_with_block(LibTCPSendCallback, msg, 1);
-
+
if (WaitForEventSafely(&msg->Event))
ret = msg->Output.Send.Error;
else
ret = ERR_CLSD;
-
+
+ if (ret == ERR_OK)
+ *sent = msg->Output.Send.Information;
+ else
+ *sent = 0;
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
LibTCPConnectCallback(void *arg)
{
+    /* Executes one queued connect request described by a struct lwip_callback_msg.
+     * Output.Connect.Error is ERR_CLSD when the endpoint has no PCB; otherwise
+     * it is tcp_connect()'s result, with ERR_OK reported as ERR_INPROGRESS.
+     * msg->Event is always signalled before returning. */
struct lwip_callback_msg *msg = arg;
-
+ err_t Error;
+
ASSERT(arg);
-
+
if (!msg->Input.Connect.Connection->SocketContext)
{
msg->Output.Connect.Error = ERR_CLSD;
goto done;
}
-
+
+    /* Install the receive and sent callbacks before starting the connect */
tcp_recv((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalRecvEventHandler);
tcp_sent((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalSendEventHandler);
- err_t Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
- msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
- InternalConnectEventHandler);
+ Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
+ msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
+ InternalConnectEventHandler);
+    /* A successful tcp_connect() call is reported to the caller as "in progress" */
msg->Output.Connect.Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
-
+
done:
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
{
struct lwip_callback_msg *msg;
err_t ret;
-
+
msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
if (msg)
{
msg->Input.Connect.Connection = Connection;
msg->Input.Connect.IpAddress = ipaddr;
msg->Input.Connect.Port = port;
-
+
tcpip_callback_with_block(LibTCPConnectCallback, msg, 1);
-
+
if (WaitForEventSafely(&msg->Event))
{
ret = msg->Output.Connect.Error;
}
else
ret = ERR_CLSD;
-
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
-
+
return ERR_MEM;
}
goto done;
}
- if (pcb->state == CLOSE_WAIT)
- {
- /* This case actually results in a socket closure later (lwIP bug?) */
- msg->Input.Shutdown.Connection->SocketContext = NULL;
+ /* LwIP makes the (questionable) assumption that SHUTDOWN_RDWR is equivalent to tcp_close().
+ * This assumption holds even if the shutdown calls are done separately (even through multiple
+ * WinSock shutdown() calls). This assumption means that lwIP has the right to deallocate our
+ * PCB without telling us if we shutdown TX and RX. To avoid these problems, we'll clear the
+ * socket context if we have called shutdown for TX and RX.
+ */
+ if (msg->Input.Shutdown.shut_rx) {
+ msg->Output.Shutdown.Error = tcp_shutdown(pcb, TRUE, FALSE);
}
-
- msg->Output.Shutdown.Error = tcp_shutdown(pcb, msg->Input.Shutdown.shut_rx, msg->Input.Shutdown.shut_tx);
- if (msg->Output.Shutdown.Error)
- {
- msg->Input.Shutdown.Connection->SocketContext = pcb;
+ if (msg->Input.Shutdown.shut_tx) {
+ msg->Output.Shutdown.Error = tcp_shutdown(pcb, FALSE, TRUE);
}
- else
+
+ if (!msg->Output.Shutdown.Error)
{
if (msg->Input.Shutdown.shut_rx)
+ {
msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
-
+ msg->Input.Shutdown.Connection->ReceiveShutdownStatus = STATUS_FILE_CLOSED;
+ }
+
if (msg->Input.Shutdown.shut_tx)
msg->Input.Shutdown.Connection->SendShutdown = TRUE;
+
+ if (msg->Input.Shutdown.Connection->ReceiveShutdown &&
+ msg->Input.Shutdown.Connection->SendShutdown)
+ {
+ /* The PCB is not ours anymore */
+ msg->Input.Shutdown.Connection->SocketContext = NULL;
+ tcp_arg(pcb, NULL);
+ TCPFinEventHandler(msg->Input.Shutdown.Connection, ERR_CLSD);
+ }
}
-
+
done:
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
if (msg)
{
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
-
+
msg->Input.Shutdown.Connection = Connection;
msg->Input.Shutdown.shut_rx = shut_rx;
msg->Input.Shutdown.shut_tx = shut_tx;
-
+
tcpip_callback_with_block(LibTCPShutdownCallback, msg, 1);
-
+
if (WaitForEventSafely(&msg->Event))
ret = msg->Output.Shutdown.Error;
else
ret = ERR_CLSD;
-
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
-
+
return ERR_MEM;
}
{
struct lwip_callback_msg *msg = arg;
PTCP_PCB pcb = msg->Input.Close.Connection->SocketContext;
- int state;
/* Empty the queue even if we're already "closed" */
LibTCPEmptyQueue(msg->Input.Close.Connection);
+ /* Check if we've already been closed */
+ if (msg->Input.Close.Connection->Closing)
+ {
+ msg->Output.Close.Error = ERR_OK;
+ goto done;
+ }
+
+ /* Enter "closing" mode if we're doing a normal close */
+ if (msg->Input.Close.Callback)
+ msg->Input.Close.Connection->Closing = TRUE;
+
+ /* Check if the PCB was already "closed" but the client doesn't know it yet */
if (!msg->Input.Close.Connection->SocketContext)
{
msg->Output.Close.Error = ERR_OK;
goto done;
}
- /* Clear the PCB pointer */
+ /* Clear the PCB pointer and stop callbacks */
msg->Input.Close.Connection->SocketContext = NULL;
+ tcp_arg(pcb, NULL);
- /* Save the old PCB state */
- state = pcb->state;
-
+ /* This may generate additional callbacks but we don't care,
+ * because they're too inconsistent to rely on */
msg->Output.Close.Error = tcp_close(pcb);
- if (!msg->Output.Close.Error)
- {
- if (msg->Input.Close.Callback)
- {
- /* Call the FIN handler in the cases where it will not be called by lwIP */
- switch (state)
- {
- case CLOSED:
- case LISTEN:
- case SYN_SENT:
- TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
- break;
-
- default:
- break;
- }
- }
- }
- else
+
+ if (msg->Output.Close.Error)
{
/* Restore the PCB pointer */
msg->Input.Close.Connection->SocketContext = pcb;
+ msg->Input.Close.Connection->Closing = FALSE;
+ }
+ else if (msg->Input.Close.Callback)
+ {
+ TCPFinEventHandler(msg->Input.Close.Connection, ERR_CLSD);
}
done:
{
err_t ret;
struct lwip_callback_msg *msg;
-
+
msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
if (msg)
{
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
-
+
msg->Input.Close.Connection = Connection;
msg->Input.Close.Callback = callback;
-
+
if (safe)
LibTCPCloseCallback(msg);
else
tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
-
+
if (WaitForEventSafely(&msg->Event))
ret = msg->Output.Close.Error;
else
ret = ERR_CLSD;
-
+
ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-
+
return ret;
}
LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
{
+    /* Wires an accepted PCB to this file's receive, sent and error handlers,
+     * attaches arg as its callback argument, and calls tcp_accepted() on the
+     * listening PCB. arg must not be NULL. */
ASSERT(arg);
-
+
+    /* NOTE(review): the argument is cleared first and set last, presumably so
+     * the handlers see a NULL arg until all of them are installed - confirm */
tcp_arg(pcb, NULL);
tcp_recv(pcb, InternalRecvEventHandler);
tcp_sent(pcb, InternalSendEventHandler);
tcp_err(pcb, InternalErrorEventHandler);
tcp_arg(pcb, arg);
-
+
tcp_accepted(listen_pcb);
}
{
if (!pcb)
return ERR_CLSD;
-
+
*ipaddr = pcb->local_ip;
*port = pcb->local_port;
-
+
return ERR_OK;
}
+/* Copies the remote address and port stored in a PCB to the caller.
+ *
+ * pcb    - PCB to query; may be NULL.
+ * ipaddr - receives pcb->remote_ip.
+ * port   - receives pcb->remote_port.
+ *
+ * Returns ERR_CLSD when pcb is NULL (outputs untouched), ERR_OK otherwise. */
err_t
LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const port)
-{
+{
if (!pcb)
return ERR_CLSD;
-
+
*ipaddr = pcb->remote_ip;
*port = pcb->remote_port;
-
+
return ERR_OK;
}
+
+/**
+ * Sets or clears lwIP's TF_NODELAY flag on a PCB.
+ *
+ * pcb - PCB to modify. NULL is ignored, matching the NULL check in
+ *       LibTCPGetPeerName.
+ * Set - TRUE sets TF_NODELAY, FALSE clears it.
+ */
+void
+LibTCPSetNoDelay(
+    PTCP_PCB pcb,
+    BOOLEAN Set)
+{
+    /* Nothing to configure when there is no PCB */
+    if (!pcb)
+        return;
+
+    if (Set)
+        pcb->flags |= TF_NODELAY;
+    else
+        pcb->flags &= ~TF_NODELAY;
+}