diff -r a40139603cde -r 92225a708bda rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Sat Apr 13 00:37:35 2019 +0300 +++ b/rust/hedgewars-server/src/server/network.rs Mon Apr 15 21:22:51 2019 +0300 @@ -13,6 +13,7 @@ net::{TcpListener, TcpStream}, Poll, PollOpt, Ready, Token, }; +use mio_extras::timer; use netbuf; use slab::Slab; @@ -34,8 +35,11 @@ SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode, }, }; +use std::time::Duration; const MAX_BYTES_PER_READ: usize = 2048; +const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30); +const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(10); #[derive(Hash, Eq, PartialEq, Copy, Clone)] pub enum NetworkClientState { @@ -80,16 +84,23 @@ peer_addr: SocketAddr, decoder: ProtocolDecoder, buf_out: netbuf::Buf, + timeout: timer::Timeout, } impl NetworkClient { - pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { + pub fn new( + id: ClientId, + socket: ClientSocket, + peer_addr: SocketAddr, + timeout: timer::Timeout, + ) -> NetworkClient { NetworkClient { id, socket, peer_addr, decoder: ProtocolDecoder::new(), buf_out: netbuf::Buf::new(), + timeout, } } @@ -231,6 +242,10 @@ pub fn send_string(&mut self, msg: &str) { self.send_raw_msg(&msg.as_bytes()); } + + pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout { + replace(&mut self.timeout, timeout) + } } #[cfg(feature = "tls-connections")] @@ -288,6 +303,13 @@ } } +enum TimeoutEvent { + SendPing, + DropClient, +} + +struct TimerData(TimeoutEvent, ClientId); + pub struct NetworkLayer { listener: TcpListener, server: HWServer, @@ -298,6 +320,21 @@ ssl: ServerSsl, #[cfg(feature = "official-server")] io: IoLayer, + timer: timer::Timer, +} + +fn create_ping_timeout(timer: &mut timer::Timer, client_id: ClientId) -> timer::Timeout { + timer.set_timeout( + SEND_PING_TIMEOUT, + TimerData(TimeoutEvent::SendPing, client_id), + ) +} + +fn create_drop_timeout(timer: &mut timer::Timer, client_id: ClientId) -> timer::Timeout { + timer.set_timeout( + DROP_CLIENT_TIMEOUT, + TimerData(TimeoutEvent::DropClient, client_id), + ) } impl NetworkLayer { @@ -306,6 +343,7 @@ let clients = Slab::with_capacity(clients_limit); let pending = HashSet::with_capacity(2 * clients_limit); let pending_cache = Vec::with_capacity(2 * clients_limit); + let timer = timer::Builder::default().build(); NetworkLayer { listener, @@ -317,6 +355,7 @@ ssl: NetworkLayer::create_ssl_context(), #[cfg(feature = "official-server")] io: IoLayer::new(), + timer, } } @@ -346,6 +385,13 @@ PollOpt::edge(), )?; + poll.register( + &self.timer, + utils::TIMER_TOKEN, + Ready::readable(), + PollOpt::edge(), + )?; + #[cfg(feature = "official-server")] self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?; @@ -384,7 +430,12 @@ ) .expect("could not register socket with event loop"); - let client = NetworkClient::new(client_id, client_socket, addr); + let client = NetworkClient::new( + client_id, + client_socket, + addr, + create_ping_timeout(&mut self.timer, client_id), + ); info!("client {} ({}) added", client.id, client.peer_addr); entry.insert(client); @@ -419,6 +470,29 @@ } } + pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> { + while let Some(TimerData(event, client_id)) = self.timer.poll() { + match event { + TimeoutEvent::SendPing => { + if let Some(ref mut client) = self.clients.get_mut(client_id) { + client.send_string(&HWServerMessage::Ping.to_raw_protocol()); + client.write()?; + client.replace_timeout(create_drop_timeout(&mut self.timer, client_id)); + } + } + TimeoutEvent::DropClient => { + self.operation_failed( + poll, + client_id, + &ErrorKind::TimedOut.into(), + "No ping response", + )?; + } + } + } + Ok(()) + } + #[cfg(feature = "official-server")] pub fn handle_io_result(&mut self) { if let Some((client_id, result)) = self.io.try_recv() { @@ -486,6 +560,8 @@ pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { + let timeout = client.replace_timeout(create_ping_timeout(&mut self.timer, client_id)); + self.timer.cancel_timeout(&timeout); client.read() } else { warn!("invalid readable client: {}", client_id);