use core::net::SocketAddr;
use core::sync::atomic::{AtomicBool, Ordering};
use axerrno::{ax_err, ax_err_type, AxError, AxResult};
use axhal::time::current_ticks;
use axio::{PollState, Read, Write};
use axsync::Mutex;
use spin::RwLock;
use smoltcp::iface::SocketHandle;
use smoltcp::socket::udp::{self, BindError, SendError};
use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
use super::addr::{from_core_sockaddr, into_core_sockaddr, is_unspecified, UNSPECIFIED_ENDPOINT};
use super::{SocketSetWrapper, SOCKET_SET};
/// A UDP socket built on top of a smoltcp `udp::Socket` stored in `SOCKET_SET`.
pub struct UdpSocket {
/// Handle identifying the underlying smoltcp socket inside `SOCKET_SET`.
handle: SocketHandle,
/// Local endpoint; starts as `None` and is set once `bind` succeeds.
local_addr: RwLock<Option<IpEndpoint>>,
/// Remote endpoint; starts as `None` and is set by `connect`.
peer_addr: RwLock<Option<IpEndpoint>>,
/// When `true`, operations that would wait return `WouldBlock` after a single attempt.
nonblock: AtomicBool,
}
impl UdpSocket {
/// Creates a new UDP socket.
///
/// A fresh smoltcp UDP socket is added to `SOCKET_SET`; the returned wrapper
/// starts out unbound, unconnected and in blocking mode.
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
    let handle = SOCKET_SET.add(SocketSetWrapper::new_udp_socket());
    Self {
        handle,
        local_addr: RwLock::new(None),
        peer_addr: RwLock::new(None),
        nonblock: AtomicBool::new(false),
    }
}
/// Returns the local address this socket is bound to.
///
/// # Errors
///
/// Returns [`AxError::NotConnected`] if the socket has not been bound, or if
/// the address lock cannot be acquired for reading right now (the lock is
/// only tried, never waited on).
pub fn local_addr(&self) -> AxResult<SocketAddr> {
    let guard = self.local_addr.try_read().ok_or(AxError::NotConnected)?;
    (*guard)
        .map(into_core_sockaddr)
        .ok_or(AxError::NotConnected)
}
pub fn peer_addr(&self) -> AxResult<SocketAddr> {
self.remote_endpoint().map(into_core_sockaddr)
}
/// Returns whether this socket is in non-blocking mode.
#[inline]
pub fn is_nonblocking(&self) -> bool {
self.nonblock.load(Ordering::Acquire)
}
/// Switches the socket into (`true`) or out of (`false`) non-blocking mode.
///
/// The flag is consulted by `block_on` each time a send/receive is attempted.
#[inline]
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblock.store(nonblocking, Ordering::Release);
}
pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
let mut self_local_addr = self.local_addr.write();
if local_addr.port() == 0 {
local_addr.set_port(get_ephemeral_port()?);
}
if self_local_addr.is_some() {
return ax_err!(InvalidInput, "socket bind() failed: already bound");
}
let local_endpoint = from_core_sockaddr(local_addr);
let endpoint = IpListenEndpoint {
addr: (!is_unspecified(local_endpoint.addr)).then_some(local_endpoint.addr),
port: local_endpoint.port,
};
SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
socket.bind(endpoint).or_else(|e| match e {
BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
})
})?;
*self_local_addr = Some(local_endpoint);
debug!("UDP socket {}: bound on {}", self.handle, endpoint);
Ok(())
}
/// Sends `buf` as one datagram to `remote_addr`.
///
/// Returns the number of bytes sent as reported by `send_impl`.
///
/// # Errors
///
/// * [`AxError::InvalidInput`] if `remote_addr` has port `0` or an
///   unspecified IP address.
/// * Any error produced by `send_impl`.
pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
    let addr_is_invalid = remote_addr.port() == 0 || remote_addr.ip().is_unspecified();
    if addr_is_invalid {
        return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
    }
    let endpoint = from_core_sockaddr(remote_addr);
    self.send_impl(buf, endpoint)
}
pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
self.recv_impl(|socket| match socket.recv_slice(buf) {
Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
Err(_) => ax_err!(BadState, "socket recv_from() failed"),
})
}
pub fn recv_from_timeout(&self, buf: &mut [u8], ticks: u64) -> AxResult<(usize, SocketAddr)> {
let expire_at = current_ticks() + ticks;
self.recv_impl(|socket| match socket.recv_slice(buf) {
Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
Err(_) => {
if current_ticks() > expire_at {
Err(AxError::Timeout)
} else {
Err(AxError::WouldBlock)
}
}
})
}
/// Copies the next datagram into `buf` using smoltcp's `peek_slice`, i.e.
/// without taking it out of the receive queue.
///
/// On success returns the number of bytes copied together with the sender's
/// address taken from the datagram metadata.
///
/// # Errors
///
/// * [`AxError::BadState`] if smoltcp's `peek_slice` fails.
/// * Any error produced by `recv_impl`.
pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
    self.recv_impl(|socket| match socket.peek_slice(buf) {
        Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
        // Name the operation that actually failed in the log message.
        Err(_) => ax_err!(BadState, "socket peek_from() failed"),
    })
}
/// Records `addr` as the peer of this socket.
///
/// If the socket is not bound yet it is first bound to the unspecified
/// endpoint, which lets `bind` choose the local port. After this call,
/// `send` and `recv` use `addr` as the remote endpoint.
///
/// # Errors
///
/// Propagates any error from the implicit `bind`.
pub fn connect(&self, addr: SocketAddr) -> AxResult {
    let mut peer_slot = self.peer_addr.write();

    // The read guard is released at the end of this statement, so `bind`
    // below is free to take the write lock on `local_addr`.
    let unbound = self.local_addr.read().is_none();
    if unbound {
        self.bind(into_core_sockaddr(UNSPECIFIED_ENDPOINT))?;
    }

    *peer_slot = Some(from_core_sockaddr(addr));
    debug!("UDP socket {}: connected to {}", self.handle, addr);
    Ok(())
}
/// Sends `buf` as one datagram to the peer recorded by `connect`.
///
/// # Errors
///
/// * The error of `remote_endpoint` when no peer is available.
/// * Any error produced by `send_impl`.
pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
    self.remote_endpoint()
        .and_then(|peer| self.send_impl(buf, peer))
}
pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
let remote_endpoint = self.remote_endpoint()?;
self.recv_impl(|socket| {
let (len, meta) = socket
.recv_slice(buf)
.map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
if !is_unspecified(remote_endpoint.addr) && remote_endpoint.addr != meta.endpoint.addr {
return Err(AxError::WouldBlock);
}
if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
return Err(AxError::WouldBlock);
}
Ok(len)
})
}
/// Closes the underlying smoltcp socket and then polls the network
/// interfaces once.
///
/// Always returns `Ok(())`.
pub fn shutdown(&self) -> AxResult {
    debug!("UDP socket {}: shutting down", self.handle);
    SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |sock| sock.close());
    SOCKET_SET.poll_interfaces();
    Ok(())
}
/// Reports whether the socket can currently receive and/or send.
///
/// An unbound socket is reported as neither readable nor writable; otherwise
/// the answer comes from smoltcp's `can_recv` / `can_send`.
pub fn poll(&self) -> AxResult<PollState> {
    let bound = self.local_addr.read().is_some();
    if !bound {
        return Ok(PollState {
            readable: false,
            writable: false,
        });
    }

    let state = SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |sock| PollState {
        readable: sock.can_recv(),
        writable: sock.can_send(),
    });
    Ok(state)
}
}
impl UdpSocket {
/// Returns the peer endpoint recorded by `connect`.
///
/// # Errors
///
/// Returns [`AxError::NotConnected`] if no peer has been set, or if the peer
/// lock cannot be acquired for reading right now (the lock is only tried,
/// never waited on).
fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
    let guard = self.peer_addr.try_read().ok_or(AxError::NotConnected)?;
    (*guard).ok_or(AxError::NotConnected)
}
/// Shared send path used by `send` and `send_to`.
///
/// Queues `buf` as one datagram addressed to `remote_endpoint` and returns
/// `buf.len()`. The attempt is driven by `block_on`, so in blocking mode it
/// is retried for as long as it reports `WouldBlock`.
///
/// # Errors
///
/// * [`AxError::NotConnected`] if the socket is unbound or the underlying
///   smoltcp socket is not open.
/// * [`AxError::WouldBlock`] (non-blocking mode only) if the socket cannot
///   send right now or its transmit buffer is full.
/// * [`AxError::ConnectionRefused`] if smoltcp reports the destination as
///   unaddressable.
fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
    let unbound = self.local_addr.read().is_none();
    if unbound {
        return ax_err!(NotConnected, "socket send() failed");
    }

    self.block_on(|| {
        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |sock| {
            if !sock.is_open() {
                return ax_err!(NotConnected, "socket send() failed");
            }
            if !sock.can_send() {
                return Err(AxError::WouldBlock);
            }
            match sock.send_slice(buf, remote_endpoint) {
                Ok(()) => Ok(buf.len()),
                Err(SendError::BufferFull) => Err(AxError::WouldBlock),
                Err(SendError::Unaddressable) => {
                    ax_err!(ConnectionRefused, "socket send() failed")
                }
            }
        })
    })
}
/// Shared receive path used by `recv`, `recv_from` and `peek_from`.
///
/// `op` performs the actual read on the smoltcp socket. It is only invoked
/// when the socket reports `can_recv()`; otherwise the attempt yields
/// `WouldBlock` without calling `op`. The attempt is driven by `block_on`,
/// so in blocking mode it is retried for as long as it reports `WouldBlock`.
///
/// # Errors
///
/// * [`AxError::NotConnected`] if the socket is unbound or the underlying
///   smoltcp socket is not open.
/// * [`AxError::WouldBlock`] (non-blocking mode only) if nothing can be
///   received right now.
/// * Any other error returned by `op`.
fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
where
    F: FnMut(&mut udp::Socket) -> AxResult<T>,
{
    if self.local_addr.read().is_none() {
        // This is the receive path: log it as such rather than as a send.
        return ax_err!(NotConnected, "socket recv() failed");
    }

    self.block_on(|| {
        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
            if !socket.is_open() {
                ax_err!(NotConnected, "socket recv() failed")
            } else if socket.can_recv() {
                op(socket)
            } else {
                Err(AxError::WouldBlock)
            }
        })
    })
}
/// Runs `f` according to the socket's blocking mode.
///
/// * Non-blocking: `f` is called exactly once and its result is returned
///   as-is (including `WouldBlock`).
/// * Blocking: the network interfaces are polled before every attempt, and
///   whenever `f` reports `WouldBlock` the current task yields and the
///   attempt is repeated. Any other result ends the loop.
fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
where
    F: FnMut() -> AxResult<T>,
{
    if self.is_nonblocking() {
        return f();
    }

    loop {
        SOCKET_SET.poll_interfaces();
        match f() {
            Err(AxError::WouldBlock) => axtask::yield_now(),
            outcome => return outcome,
        }
    }
}
/// Calls `f` with a shared reference to the underlying smoltcp UDP socket
/// and returns whatever `f` returns.
pub fn with_socket<R>(&self, f: impl FnOnce(&udp::Socket) -> R) -> R {
    let handle = self.handle;
    SOCKET_SET.with_socket(handle, |sock| f(sock))
}
}
impl Read for UdpSocket {
fn read(&mut self, buf: &mut [u8]) -> AxResult<usize> {
self.recv(buf)
}
}
impl Write for UdpSocket {
fn write(&mut self, buf: &[u8]) -> AxResult<usize> {
self.send(buf)
}
fn flush(&mut self) -> AxResult {
Err(AxError::Unsupported)
}
}
/// Dropping the wrapper shuts the socket down and removes it from
/// `SOCKET_SET`.
impl Drop for UdpSocket {
    fn drop(&mut self) {
        // Best effort: there is nobody to report a shutdown error to here.
        let _ = self.shutdown();
        SOCKET_SET.remove(self.handle);
    }
}
/// Hands out the next port from the range `0xc000..=0xffff`.
///
/// Ports are issued sequentially from a counter shared by all callers and
/// wrap around to the start of the range after `0xffff`. The function does
/// not check whether the returned port is already in use.
fn get_ephemeral_port() -> AxResult<u16> {
    const PORT_START: u16 = 0xc000;
    const PORT_END: u16 = 0xffff;
    static CURR: Mutex<u16> = Mutex::new(PORT_START);

    let mut next = CURR.lock();
    let port = *next;
    // Advance the counter for the following caller, wrapping at the end.
    *next = if port == PORT_END {
        PORT_START
    } else {
        port + 1
    };
    Ok(port)
}