use crate::{ http::{ requests::{Request, RequestStatus, URL}, responses::{Body, Response, UnitOrBoxedError}, }, threads::ThreadPool, }; use std::{ collections::HashMap, net::{Shutdown, TcpListener}, sync::Arc, }; type DynHandlerFn = dyn Fn(Request) -> Response + Send + Sync; type BoxedHandlerFn = Box; type HandlersMap = HashMap>; // Collection of handlers for requests pub struct Handlers { matchers: HandlersMap, } impl Handlers { /// Create a new Handlers struct to handle requests /// /// !! This struct does nothing until you add handlers to it (add_handler) and then bind it to a TcpListener !! pub fn new() -> Self { Handlers { matchers: HashMap::new(), } } /// Add a request handler /// /// !! Be sure to bind this Handlers struct to a TcpListener to actually handle requests !! /// /// path: Path to match (no trailing /) /// handler: Function to handle the request. Must return a Response pub fn add_handler( &mut self, path: &str, handler: impl Fn(Request) -> Response + Send + Sync + 'static, ) { self.matchers .insert(path.to_string(), Arc::new(Box::from(handler))); } /// Bind these handlers to a listener in order to handle incoming requests. /// You will need to pass in a TcpListener. /// This method creates a ThreadPool to handle incoming requests so no multithreading is needed on the part of the caller. /// /// !! Call this *after* adding all your handlers with add_handler !! /// /// listener: TcpListener to bind to pub fn bind(&self, listener: TcpListener) -> UnitOrBoxedError { let pool = match ThreadPool::build(4) { Some(pool) => pool, None => return Err(Box::from("Failed to create ThreadPool")), }; for stream in listener.incoming() { match stream { Ok(stream) => { let request = match Request::parse_stream(stream) { RequestStatus::Ok(req) => req, RequestStatus::MalformedHTTP(stream) => { stream.shutdown(Shutdown::Both).unwrap_or_else(|_| { eprintln!("Failed to close malformed HTTP stream") }); continue; } }; let handler = match Handlers::match_handler(&self.matchers, &request.url) { Some(handler) => handler.clone(), None => { Response::new(request, 404, Body::Static("Not Found")).send()?; continue; } }; pool.execute(move || { handler.as_ref()(request) .send() .unwrap_or_else(|_| eprintln!("Failed to send response")); }) .unwrap_or_else(|_| { eprintln!("Failed to send job to ThreadPool"); }); } Err(e) => { eprintln!("Failed to establish connection: {}", e); } }; } return Ok(()); } fn match_handler<'a>(matchers: &'a HandlersMap, url: &URL) -> Option<&'a Arc> { 'matching_loop: for (path, handler) in matchers.iter() { // Exact match if path == &url.path { return Some(handler); }; // Segment matching let url_segments = url.path.split('/').collect::>(); let path_segments = path.split('/').collect::>(); // If the URL has more segments than the path, it can't match // or if the path has no segments, it can't match if (url_segments.len() != path_segments.len()) || (path_segments.len() == 0) { continue; } // Check each segment of the url for (url_segment, path_segment) in url_segments.iter().zip(path_segments.iter()) { if path_segment.starts_with('[') { // e.g. /path/[id] if path_segment.ends_with(']') { continue; } // e.g. /path/prefix[id].suffix let prefix = path_segment.split('[').collect::>()[0]; let suffix = path_segment.split(']').collect::>()[1]; // begins_with and starts_with always return true on empty strings if url_segment.starts_with(prefix) && url_segment.ends_with(suffix) { continue; } continue 'matching_loop; } if url_segment == path_segment { continue; } continue 'matching_loop; } return Some(handler); } None } }