use crate::as_dev_err;
use alloc::{sync::Arc, vec::Vec};
use driver_common::{BaseDriverOps, DevError, DevResult, DeviceType};
use driver_net::{EthernetAddress, NetBuf, NetBufBox, NetBufPool, NetBufPtr, NetDriverOps};
use virtio_drivers::{device::net::VirtIONetRaw as InnerDev, transport::Transport, Hal};
extern crate alloc;
/// Per-buffer length handed to `NetBufPool::new` when `try_new` creates the
/// buffer pool shared by the RX and TX paths.
// NOTE(review): 1526 presumably covers a maximum-size Ethernet frame plus the
// virtio-net header — confirm against the device/header layout.
const NET_BUF_LEN: usize = 1526;
/// VirtIO network device driver built on top of `VirtIONetRaw`.
///
/// `QS` is the length of the RX/TX slot arrays and the queue-size parameter
/// forwarded to the inner raw device type.
pub struct VirtIoNetDev<H: Hal, T: Transport, const QS: usize> {
/// RX buffers currently handed to the device, indexed by the token returned
/// from `receive_begin`. A slot is emptied by `receive` and refilled by
/// `recycle_rx_buffer`.
rx_buffers: [Option<NetBufBox>; QS],
/// TX buffers currently in flight, indexed by the token returned from
/// `transmit_begin`. A slot is emptied by `recycle_tx_buffers`.
tx_buffers: [Option<NetBufBox>; QS],
/// TX buffers that are not in flight. `try_new` fills their header once and
/// records the header length; `alloc_tx_buffer` pops from this list and
/// `recycle_tx_buffers` pushes completed buffers back.
free_tx_bufs: Vec<NetBufBox>,
/// Pool that `try_new` allocates every RX and TX buffer from
/// (created with `2 * QS` buffers).
buf_pool: Arc<NetBufPool>,
/// The raw virtio-net device that owns the transport and the virtqueues.
inner: InnerDev<H, T, QS>,
}
// NOTE(review): these impls assert `Send`/`Sync` unconditionally for every
// `H`, `T` and `QS`. No justification is visible in this file; presumably the
// device is only ever accessed behind external synchronization — confirm with
// the callers and document the invariant here as a `SAFETY:` comment.
unsafe impl<H: Hal, T: Transport, const QS: usize> Send for VirtIoNetDev<H, T, QS> {}
unsafe impl<H: Hal, T: Transport, const QS: usize> Sync for VirtIoNetDev<H, T, QS> {}
impl<H: Hal, T: Transport, const QS: usize> VirtIoNetDev<H, T, QS> {
    /// Creates a new driver instance on top of the given VirtIO transport.
    ///
    /// The constructor creates the raw device, allocates a pool of `2 * QS`
    /// buffers, queues `QS` buffers for reception and prepares `QS` buffers
    /// (header already filled in) for transmission.
    ///
    /// # Errors
    ///
    /// * the converted device error if creating the raw device or queuing an
    ///   RX buffer fails;
    /// * whatever `NetBufPool::new` reports if the pool cannot be created;
    /// * [`DevError::NoMemory`] if the pool runs out of buffers;
    /// * [`DevError::InvalidParam`] if the header cannot be written into a
    ///   TX buffer.
    ///
    /// # Panics
    ///
    /// Panics if the token returned for the `i`-th queued RX buffer is not `i`.
    pub fn try_new(transport: T) -> DevResult<Self> {
        // A const item so the `[value; QS]` repeat expression can be used for
        // the slot arrays.
        const EMPTY_SLOT: Option<NetBufBox> = None;

        // Keep the fallible steps in this order: device first, then the pool.
        let inner = InnerDev::new(transport).map_err(as_dev_err)?;
        let buf_pool = NetBufPool::new(2 * QS, NET_BUF_LEN)?;

        let mut dev = Self {
            rx_buffers: [EMPTY_SLOT; QS],
            tx_buffers: [EMPTY_SLOT; QS],
            free_tx_bufs: Vec::with_capacity(QS),
            buf_pool,
            inner,
        };

        // 1. Hand one buffer per RX slot to the device.
        for slot_idx in 0..QS {
            let mut buf = dev.buf_pool.alloc_boxed().ok_or(DevError::NoMemory)?;
            // SAFETY (assumed contract of `receive_begin` — confirm): the
            // buffer must stay alive while the device holds it; it is parked
            // in `rx_buffers` right below.
            let token =
                unsafe { dev.inner.receive_begin(buf.raw_buf_mut()) }.map_err(as_dev_err)?;
            // The slot index doubles as the lookup key in `receive`, so the
            // token must match it.
            assert_eq!(token, slot_idx as u16);
            dev.rx_buffers[slot_idx] = Some(buf);
        }

        // 2. Pre-fill the header of every TX buffer once and remember its length.
        while dev.free_tx_bufs.len() < QS {
            let mut buf = dev.buf_pool.alloc_boxed().ok_or(DevError::NoMemory)?;
            let header_len = dev
                .inner
                .fill_buffer_header(buf.raw_buf_mut())
                .map_err(|_| DevError::InvalidParam)?;
            buf.set_header_len(header_len);
            dev.free_tx_bufs.push(buf);
        }

        Ok(dev)
    }
}
/// Basic driver identification for the virtio-net device.
// NOTE(review): `impl const` is unstable syntax; presumably the crate enables
// the matching nightly feature gate at its root — confirm.
impl<H: Hal, T: Transport, const QS: usize> const BaseDriverOps for VirtIoNetDev<H, T, QS> {
/// Returns the fixed device name, `"virtio-net"`.
fn device_name(&self) -> &str {
"virtio-net"
}
/// Reports this driver as a network device.
fn device_type(&self) -> DeviceType {
DeviceType::Net
}
}
impl<H: Hal, T: Transport, const QS: usize> NetDriverOps for VirtIoNetDev<H, T, QS> {
    /// Returns the MAC address reported by the underlying raw device.
    #[inline]
    fn mac_address(&self) -> EthernetAddress {
        EthernetAddress(self.inner.mac_address())
    }

    /// Returns `true` when a free TX buffer is available and the raw device
    /// reports that it can send.
    #[inline]
    fn can_transmit(&self) -> bool {
        !self.free_tx_bufs.is_empty() && self.inner.can_send()
    }

    /// Returns `true` when the raw device reports a pending received packet.
    #[inline]
    fn can_receive(&self) -> bool {
        self.inner.poll_receive().is_some()
    }

    /// Returns the RX queue size, which is the const parameter `QS`.
    #[inline]
    fn rx_queue_size(&self) -> usize {
        QS
    }

    /// Returns the TX queue size, which is the const parameter `QS`.
    #[inline]
    fn tx_queue_size(&self) -> usize {
        QS
    }

    /// Gives an RX buffer back to the device so it can be filled again.
    ///
    /// # Errors
    ///
    /// * the converted device error if the buffer cannot be queued;
    /// * [`DevError::BadState`] if the slot for the returned token is
    ///   already occupied.
    fn recycle_rx_buffer(&mut self, rx_buf: NetBufPtr) -> DevResult {
        // SAFETY (assumption — confirm against `NetBuf::from_buf_ptr`): the
        // pointer is expected to originate from `into_buf_ptr`, as produced
        // by `receive`.
        let mut rx_buf = unsafe { NetBuf::from_buf_ptr(rx_buf) };
        // SAFETY (assumed contract of `receive_begin` — confirm): ownership
        // of the buffer moves into `rx_buffers` below, so it stays alive
        // while it is queued.
        let new_token = unsafe {
            self.inner
                .receive_begin(rx_buf.raw_buf_mut())
                .map_err(as_dev_err)?
        };
        // The slot is expected to be empty because `receive` took the
        // previous buffer out of it.
        // NOTE(review): on the `BadState` path the buffer has already been
        // passed to `receive_begin` and is dropped here — verify that this
        // cannot leave the device with a dangling buffer.
        let slot = &mut self.rx_buffers[new_token as usize];
        if slot.is_some() {
            return Err(DevError::BadState);
        }
        *slot = Some(rx_buf);
        Ok(())
    }

    /// Reclaims every TX buffer the device has finished with and returns it
    /// to the free list.
    ///
    /// # Errors
    ///
    /// * [`DevError::BadState`] if a completed token has no buffer recorded;
    /// * the converted device error if completing the transmission fails.
    fn recycle_tx_buffers(&mut self) -> DevResult {
        while let Some(token) = self.inner.poll_transmit() {
            let tx_buf = self.tx_buffers[token as usize]
                .take()
                .ok_or(DevError::BadState)?;
            // SAFETY (assumed contract of `transmit_complete` — confirm):
            // the slice is the one recorded for this token in `transmit`.
            // NOTE(review): if this call fails, `tx_buf` is dropped and never
            // returns to `free_tx_bufs`.
            unsafe {
                self.inner
                    .transmit_complete(token, tx_buf.packet_with_header())
                    .map_err(as_dev_err)?;
            }
            self.free_tx_bufs.push(tx_buf);
        }
        Ok(())
    }

    /// Queues a packet for transmission and records the buffer under the
    /// token returned by the device.
    ///
    /// # Errors
    ///
    /// Returns the converted device error if the packet cannot be queued.
    fn transmit(&mut self, tx_buf: NetBufPtr) -> DevResult {
        // SAFETY (assumption — confirm against `NetBuf::from_buf_ptr`): the
        // pointer is expected to originate from `alloc_tx_buffer`.
        let tx_buf = unsafe { NetBuf::from_buf_ptr(tx_buf) };
        // SAFETY (assumed contract of `transmit_begin` — confirm): the
        // buffer is stored in `tx_buffers` below and stays alive until
        // `recycle_tx_buffers` completes it.
        // NOTE(review): on failure the buffer is dropped here and is not
        // returned to `free_tx_bufs`.
        let token = unsafe {
            self.inner
                .transmit_begin(tx_buf.packet_with_header())
                .map_err(as_dev_err)?
        };
        self.tx_buffers[token as usize] = Some(tx_buf);
        Ok(())
    }

    /// Takes the next received packet from the device.
    ///
    /// The returned buffer has its header and packet lengths set from the
    /// values reported by the device.
    ///
    /// # Errors
    ///
    /// * [`DevError::Again`] if no packet is pending;
    /// * [`DevError::BadState`] if the pending token has no buffer recorded;
    /// * the converted device error if completing the reception fails.
    fn receive(&mut self) -> DevResult<NetBufPtr> {
        let Some(token) = self.inner.poll_receive() else {
            return Err(DevError::Again);
        };
        let mut rx_buf = self.rx_buffers[token as usize]
            .take()
            .ok_or(DevError::BadState)?;
        // SAFETY (assumed contract of `receive_complete` — confirm): the
        // buffer is the one recorded for this token when it was queued with
        // `receive_begin`.
        let (hdr_len, pkt_len) = unsafe {
            self.inner
                .receive_complete(token, rx_buf.raw_buf_mut())
                .map_err(as_dev_err)?
        };
        rx_buf.set_header_len(hdr_len);
        rx_buf.set_packet_len(pkt_len);
        Ok(rx_buf.into_buf_ptr())
    }

    /// Hands out a free TX buffer with its packet length set to `size`.
    ///
    /// # Errors
    ///
    /// * [`DevError::NoMemory`] if no free TX buffer is available;
    /// * [`DevError::InvalidParam`] if header plus `size` does not fit in
    ///   the buffer. The buffer is returned to the free list in that case.
    fn alloc_tx_buffer(&mut self, size: usize) -> DevResult<NetBufPtr> {
        let mut net_buf = self.free_tx_bufs.pop().ok_or(DevError::NoMemory)?;

        // Use a checked sum so an enormous `size` cannot wrap around and
        // slip past the capacity check.
        let fits = match net_buf.header_len().checked_add(size) {
            Some(total_len) => total_len <= net_buf.capacity(),
            None => false,
        };
        if !fits {
            // Put the buffer back; otherwise every oversized request would
            // permanently shrink the set of usable TX buffers.
            self.free_tx_bufs.push(net_buf);
            return Err(DevError::InvalidParam);
        }

        net_buf.set_packet_len(size);
        Ok(net_buf.into_buf_ptr())
    }
}