1use std::collections::hash_map::Entry;
2use std::collections::{HashMap, VecDeque};
3use std::io::{self, Read};
4use std::sync::Mutex;
5
6use crate::agent::AgentState;
7use crate::stream::Stream;
8use crate::{Agent, Proxy};
9
10use log::debug;
11use url::Url;
12
13pub(crate) struct ConnectionPool {
37 inner: Mutex<Inner>,
38 max_idle_connections: usize,
39 max_idle_connections_per_host: usize,
40}
41
42struct Inner {
43 recycle: HashMap<PoolKey, VecDeque<Stream>>,
45 lru: VecDeque<PoolKey>,
50}
51
52impl fmt::Debug for ConnectionPool {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("ConnectionPool")
55 .field("max_idle", &self.max_idle_connections)
56 .field("max_idle_per_host", &self.max_idle_connections_per_host)
57 .field("connections", &self.inner.lock().unwrap().lru.len())
58 .finish()
59 }
60}
61fn remove_first_match(list: &mut VecDeque<PoolKey>, key: &PoolKey) -> Option<PoolKey> {
62 match list.iter().position(|x| x == key) {
63 Some(i) => list.remove(i),
64 None => None,
65 }
66}
67
68fn remove_last_match(list: &mut VecDeque<PoolKey>, key: &PoolKey) -> Option<PoolKey> {
69 match list.iter().rposition(|x| x == key) {
70 Some(i) => list.remove(i),
71 None => None,
72 }
73}
74
75impl ConnectionPool {
76 pub(crate) fn new_with_limits(
77 max_idle_connections: usize,
78 max_idle_connections_per_host: usize,
79 ) -> Self {
80 ConnectionPool {
81 inner: Mutex::new(Inner {
82 recycle: HashMap::new(),
83 lru: VecDeque::new(),
84 }),
85 max_idle_connections,
86 max_idle_connections_per_host,
87 }
88 }
89
90 fn noop(&self) -> bool {
92 self.max_idle_connections == 0 || self.max_idle_connections_per_host == 0
93 }
94
95 pub fn try_get_connection(&self, url: &Url, proxy: Option<Proxy>) -> Option<Stream> {
97 let key = PoolKey::new(url, proxy);
98 self.remove(&key)
99 }
100
101 fn remove(&self, key: &PoolKey) -> Option<Stream> {
102 let mut inner = self.inner.lock().unwrap();
103 match inner.recycle.entry(key.clone()) {
104 Entry::Occupied(mut occupied_entry) => {
105 let streams = occupied_entry.get_mut();
106 let stream = streams.pop_back();
108 let stream = stream.expect("invariant failed: empty VecDeque in `recycle`");
109
110 if streams.is_empty() {
111 occupied_entry.remove();
112 }
113
114 remove_last_match(&mut inner.lru, key)
117 .expect("invariant failed: key in recycle but not in lru");
118
119 debug!("pulling stream from pool: {:?} -> {:?}", key, stream);
120 Some(stream)
121 }
122 Entry::Vacant(_) => None,
123 }
124 }
125
126 pub(crate) fn add(&self, key: &PoolKey, stream: Stream) {
127 if self.noop() {
128 return;
129 }
130 debug!("adding stream to pool: {:?} -> {:?}", key, stream);
131
132 let mut inner = self.inner.lock().unwrap();
133 match inner.recycle.entry(key.clone()) {
134 Entry::Occupied(mut occupied_entry) => {
135 let streams = occupied_entry.get_mut();
136 streams.push_back(stream);
137 if streams.len() > self.max_idle_connections_per_host {
138 let stream = streams.pop_front().expect("empty streams list");
140 debug!(
141 "host {:?} has {} conns, dropping oldest: {:?}",
142 key,
143 streams.len(),
144 stream
145 );
146 remove_first_match(&mut inner.lru, key)
147 .expect("invariant failed: key in recycle but not in lru");
148 }
149 }
150 Entry::Vacant(vacant_entry) => {
151 vacant_entry.insert(vec![stream].into());
152 }
153 }
154 inner.lru.push_back(key.clone());
155 if inner.lru.len() > self.max_idle_connections {
156 drop(inner);
157 self.remove_oldest()
158 }
159 }
160
161 fn remove_oldest(&self) {
164 assert!(!self.noop(), "remove_oldest called on Pool with max of 0");
165 let mut inner = self.inner.lock().unwrap();
166 let key = inner.lru.pop_front();
167 let key = key.expect("tried to remove oldest but no entries found!");
168 match inner.recycle.entry(key) {
169 Entry::Occupied(mut occupied_entry) => {
170 let streams = occupied_entry.get_mut();
171 let stream = streams
172 .pop_front()
173 .expect("invariant failed: key existed in recycle but no streams available");
174 debug!("dropping oldest stream in pool: {:?}", stream);
175 if streams.is_empty() {
176 occupied_entry.remove();
177 }
178 }
179 Entry::Vacant(_) => panic!("invariant failed: key existed in lru but not in recycle"),
180 }
181 }
182
183 #[cfg(test)]
184 pub fn len(&self) -> usize {
185 self.inner.lock().unwrap().lru.len()
186 }
187}
188
189#[derive(PartialEq, Clone, Eq, Hash)]
190pub(crate) struct PoolKey {
191 scheme: String,
192 hostname: String,
193 port: Option<u16>,
194 proxy: Option<Proxy>,
195}
196
197use std::fmt;
198
199impl fmt::Debug for PoolKey {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.write_fmt(format_args!(
202 "{}|{}|{}",
203 self.scheme,
204 self.hostname,
205 self.port.unwrap_or(0)
206 ))
207 }
208}
209
210impl PoolKey {
211 fn new(url: &Url, proxy: Option<Proxy>) -> Self {
212 let port = url.port_or_known_default();
213 PoolKey {
214 scheme: url.scheme().to_string(),
215 hostname: url.host_str().unwrap_or("").to_string(),
216 port,
217 proxy,
218 }
219 }
220
221 pub(crate) fn from_parts(scheme: &str, hostname: &str, port: u16) -> Self {
222 PoolKey {
223 scheme: scheme.to_string(),
224 hostname: hostname.to_string(),
225 port: Some(port),
226 proxy: None,
227 }
228 }
229}
230
231#[derive(Clone, Debug)]
232pub(crate) struct PoolReturner {
233 inner: Option<(std::sync::Weak<AgentState>, PoolKey)>,
237}
238
239impl PoolReturner {
240 pub(crate) fn new(agent: &Agent, pool_key: PoolKey) -> Self {
242 Self {
243 inner: Some((agent.weak_state(), pool_key)),
244 }
245 }
246
247 pub(crate) fn none() -> Self {
249 Self { inner: None }
250 }
251
252 pub(crate) fn return_to_pool(&self, stream: Stream) {
253 if let Some((weak_state, pool_key)) = &self.inner {
254 if let Some(state) = weak_state.upgrade() {
255 state.pool.add(pool_key, stream);
256 }
257 }
258 }
259}
260
261pub(crate) struct PoolReturnRead<R: Read + Sized + Into<Stream>> {
266 reader: Option<R>,
269}
270
271impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
272 pub fn new(reader: R) -> Self {
273 PoolReturnRead {
274 reader: Some(reader),
275 }
276 }
277
278 fn return_connection(&mut self) -> io::Result<()> {
279 if let Some(reader) = self.reader.take() {
281 let stream: Stream = reader.into();
283 stream.return_to_pool()?;
284 }
285
286 Ok(())
287 }
288
289 fn do_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
290 match self.reader.as_mut() {
291 None => Ok(0),
292 Some(reader) => reader.read(buf),
293 }
294 }
295}
296
297impl<R: Read + Sized + Into<Stream>> Read for PoolReturnRead<R> {
298 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
299 let amount = self.do_read(buf)?;
300 if amount == 0 {
303 self.return_connection()?;
304 }
305 Ok(amount)
306 }
307}
308
#[cfg(test)]
mod tests {
    use std::io;

    use crate::stream::{remote_addr_for_test, Stream};
    use crate::ReadWrite;

    use super::*;

    /// Fake transport: reports every read as filling the whole buffer (without writing to
    /// it) and accepts every write, so it never signals end-of-stream on its own.
    #[derive(Debug)]
    struct NoopStream;

    impl NoopStream {
        /// Wraps a `NoopStream` in a `Stream` that uses `pool_returner`.
        fn stream(pool_returner: PoolReturner) -> Stream {
            Stream::new(NoopStream, remote_addr_for_test(), pool_returner)
        }
    }

    impl Read for NoopStream {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            Ok(buf.len())
        }
    }

    impl std::io::Write for NoopStream {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    impl ReadWrite for NoopStream {
        fn socket(&self) -> Option<&std::net::TcpStream> {
            None
        }
    }

    #[test]
    fn poolkey_new() {
        // An unrecognised scheme must not panic; it simply has no default port.
        let key = PoolKey::new(&Url::parse("zzz:///example.com").unwrap(), None);
        assert_eq!(key.scheme, "zzz");
        assert_eq!(key.port, None);
        assert!(key.proxy.is_none());
    }

    #[test]
    fn pool_connections_limit() {
        // Add twice the total limit, each stream under a distinct key, then check that
        // only the newest `max_idle_connections` streams survived.
        let pool = ConnectionPool::new_with_limits(10, 1);
        let hostnames = (0..pool.max_idle_connections * 2).map(|i| format!("{}.example", i));
        let poolkeys = hostnames.map(|hostname| PoolKey {
            scheme: "https".to_string(),
            hostname,
            port: Some(999),
            proxy: None,
        });
        for key in poolkeys.clone() {
            pool.add(&key, NoopStream::stream(PoolReturner::none()));
        }
        assert_eq!(pool.len(), pool.max_idle_connections);

        for key in poolkeys.skip(pool.max_idle_connections) {
            let result = pool.remove(&key);
            assert!(result.is_some(), "expected key was not in pool");
        }
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn pool_per_host_connections_limit() {
        // Add twice the per-host limit under a single key; only the per-host limit's
        // worth of streams may remain.
        let pool = ConnectionPool::new_with_limits(10, 2);
        let poolkey = PoolKey {
            scheme: "https".to_string(),
            hostname: "example.com".to_string(),
            port: Some(999),
            proxy: None,
        };

        for _ in 0..pool.max_idle_connections_per_host * 2 {
            pool.add(&poolkey, NoopStream::stream(PoolReturner::none()))
        }
        assert_eq!(pool.len(), pool.max_idle_connections_per_host);

        for _ in 0..pool.max_idle_connections_per_host {
            let result = pool.remove(&poolkey);
            assert!(result.is_some(), "expected key was not in pool");
        }
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn pool_checks_proxy() {
        // Keys that differ only in their proxy (including proxy credentials) must be
        // pooled separately.
        let pool = ConnectionPool::new_with_limits(10, 1);
        let url = Url::parse("zzz:///example.com").unwrap();
        let pool_key = PoolKey::new(&url, None);

        pool.add(&pool_key, NoopStream::stream(PoolReturner::none()));
        assert_eq!(pool.len(), 1);

        let pool_key = PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap()));

        pool.add(&pool_key, NoopStream::stream(PoolReturner::none()));
        assert_eq!(pool.len(), 2);

        let pool_key = PoolKey::new(
            &url,
            Some(Proxy::new("user:password@localhost:9999").unwrap()),
        );

        pool.add(&pool_key, NoopStream::stream(PoolReturner::none()));
        assert_eq!(pool.len(), 3);
    }

    #[test]
    // Reading exactly the limit through a `LimitedRead` (with no trailing zero-byte read
    // from the caller) must still put the stream back into the agent's pool.
    fn read_exact() {
        use crate::response::LimitedRead;

        let url = Url::parse("https:///example.com").unwrap();

        let mut out_buf = [0u8; 500];

        let agent = Agent::new();
        let pool_key = PoolKey::new(&url, None);
        let stream = NoopStream::stream(PoolReturner::new(&agent, pool_key));
        let mut limited_read = LimitedRead::new(stream, std::num::NonZeroUsize::new(500).unwrap());

        limited_read.read_exact(&mut out_buf).unwrap();

        assert_eq!(agent.state.pool.len(), 1);
    }

    #[test]
    // A chunk-encoded gzip body read through chunked decoder -> PoolReturnRead -> gzip
    // decoder must end with the connection back in the agent's pool.
    #[cfg(feature = "gzip")]
    fn read_exact_chunked_gzip() {
        use crate::response::Compression;
        use std::io::Cursor;

        let gz_body = vec![
            b'E', b'\r', b'\n', // first chunk: 0xE = 14 bytes
            0x1F, 0x8B, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x03, 0xCB, 0x48, 0xCD, 0xC9,
            b'\r', b'\n', //
            b'E', b'\r', b'\n', // second chunk: 14 bytes
            0xC9, 0x57, 0x28, 0xCF, 0x2F, 0xCA, 0x49, 0x51, 0xC8, 0x18, 0xBC, 0x6C, 0x00, 0xA5,
            b'\r', b'\n', //
            b'7', b'\r', b'\n', // third chunk: 7 bytes
            0x5C, 0x7C, 0xEF, 0xA7, 0x00, 0x00, 0x00, //
            b'\r', b'\n', //
            b'0', b'\r', b'\n', // terminating zero-length chunk
            b'\r', b'\n', //
        ];

        let agent = Agent::new();
        assert_eq!(agent.state.pool.len(), 0);

        let ro = crate::test::TestStream::new(Cursor::new(gz_body), std::io::sink());
        let stream = Stream::new(
            ro,
            "1.1.1.1:4343".parse().unwrap(),
            PoolReturner::new(&agent, PoolKey::from_parts("http", "1.1.1.1", 8080)),
        );

        let chunked = crate::chunked::Decoder::new(stream);
        let pool_return_read: Box<dyn Read + Send + Sync + 'static> =
            Box::new(PoolReturnRead::new(chunked));

        let compression = Compression::Gzip;
        let mut stream = compression.wrap_reader(pool_return_read);

        io::copy(&mut stream, &mut io::sink()).unwrap();

        assert_eq!(agent.state.pool.len(), 1);
    }
}