1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use alloc::vec::Vec;
use axerrno::{ax_err_type, AxError, AxResult};
use core::net::IpAddr;

use smoltcp::iface::SocketHandle;
use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError};
use smoltcp::wire::DnsQueryType;

use super::addr::into_core_ipaddr;
use super::{SocketSetWrapper, SOCKET_SET};

/// Owning wrapper around a smoltcp DNS socket that has been added to `SOCKET_SET`.
struct DnsSocket {
    /// Handle identifying this socket inside the global socket set.
    handle: Option<SocketHandle>,
}

impl DnsSocket {
    #[allow(clippy::new_without_default)]
    /// Creates a new DNS socket.
    pub fn new() -> Self {
        let socket = SocketSetWrapper::new_dns_socket();
        let handle = Some(SOCKET_SET.add(socket));
        Self { handle }
    }

    #[allow(dead_code)]
    /// Update the list of DNS servers, will replace all existing servers.
    pub fn update_servers(self, servers: &[smoltcp::wire::IpAddress]) {
        SOCKET_SET.with_socket_mut::<dns::Socket, _, _>(self.handle.unwrap(), |socket| {
            socket.update_servers(servers)
        });
    }

    /// Query a address with given DNS query type.
    pub fn query(&self, name: &str, query_type: DnsQueryType) -> AxResult<Vec<IpAddr>> {
        // let local_addr = self.local_addr.unwrap_or_else(f);
        let handle = self.handle.ok_or_else(|| ax_err_type!(InvalidInput))?;
        #[cfg(not(feature = "ip"))]
        let iface = &super::ETH0.iface;
        #[cfg(feature = "ip")]
        let iface = super::LOOPBACK.try_get().unwrap();
        let query_handle = SOCKET_SET
            .with_socket_mut::<dns::Socket, _, _>(handle, |socket| {
                socket.start_query(iface.lock().context(), name, query_type)
            })
            .map_err(|e| match e {
                StartQueryError::NoFreeSlot => {
                    ax_err_type!(ResourceBusy, "socket query() failed: no free slot")
                }
                StartQueryError::InvalidName => {
                    ax_err_type!(InvalidInput, "socket query() failed: invalid name")
                }
                StartQueryError::NameTooLong => {
                    ax_err_type!(InvalidInput, "socket query() failed: too long name")
                }
            })?;
        loop {
            SOCKET_SET.poll_interfaces();
            match SOCKET_SET.with_socket_mut::<dns::Socket, _, _>(handle, |socket| {
                socket.get_query_result(query_handle).map_err(|e| match e {
                    GetQueryResultError::Pending => AxError::WouldBlock,
                    GetQueryResultError::Failed => {
                        ax_err_type!(ConnectionRefused, "socket query() failed")
                    }
                })
            }) {
                Ok(n) => {
                    let mut res = Vec::with_capacity(n.capacity());
                    for ip in n {
                        res.push(into_core_ipaddr(ip))
                    }
                    return Ok(res);
                }
                Err(AxError::WouldBlock) => axtask::yield_now(),
                Err(e) => return Err(e),
            }
        }
    }
}

impl Drop for DnsSocket {
    /// Removes the socket from the global socket set if it holds a handle.
    fn drop(&mut self) {
        let Some(handle) = self.handle else {
            return;
        };
        SOCKET_SET.remove(handle);
    }
}

/// Resolves `name` by issuing a DNS `A` (IPv4 address) query.
///
/// A temporary DNS socket is created for the lookup and released when this
/// function returns.
///
/// # Errors
///
/// Propagates any error returned by `DnsSocket::query`.
pub fn dns_query(name: &str) -> AxResult<alloc::vec::Vec<IpAddr>> {
    DnsSocket::new().query(name, DnsQueryType::A)
}