diff --git a/src/app_state.rs b/src/app_state.rs
new file mode 100644
index 0000000..9d308d2
--- /dev/null
+++ b/src/app_state.rs
@@ -0,0 +1,26 @@
+use std::sync::{Arc, Mutex};
+
+use tracing::error;
+
+use crate::lru_cache::WeightedLRUCache;
+
+#[derive(Clone)]
+pub struct AppState {
+    pub lru_cache: Arc<Mutex<WeightedLRUCache>>,
+    pub config: crate::config::AppConfig,
+}
+
+impl AppState {
+    /// Small helper to simplify access to the LRU cache.
+    pub fn with_lru_cache(&self, f: impl FnOnce(&mut WeightedLRUCache)) {
+        let cache = self.lru_cache.lock();
+        match cache {
+            Ok(mut cache) => {
+                f(&mut cache);
+            }
+            Err(e) => {
+                error!("Could not lock the LRU cache: {}", e);
+            }
+        }
+    }
+}
diff --git a/src/config.rs b/src/config.rs
index 278ed93..a056ab6 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -10,6 +10,7 @@ pub struct AppConfig {
     pub read_tokens: Option<HashSet<String>>,
     pub read_write_tokens: HashSet<String>,
     pub bind_addresses: Option<Vec<String>>,
+    pub max_cache_size: Option<u64>,
 }
 
 pub fn load_config() -> AppConfig {
diff --git a/src/disk_cache.rs b/src/disk_cache.rs
new file mode 100644
index 0000000..5d8cba9
--- /dev/null
+++ b/src/disk_cache.rs
@@ -0,0 +1,76 @@
+use std::{path::PathBuf, sync::mpsc::Receiver};
+
+use tracing::{debug, warn};
+
+use crate::{config::AppConfig, lru_cache::WeightedLRUCache};
+
+pub fn get_cache_entry_path(config: &AppConfig, hash: &str) -> PathBuf {
+    PathBuf::from(&config.cache_dir).join(hash)
+}
+
+pub fn cache_cleanup(config: AppConfig, receiver: Receiver<Vec<String>>) {
+    debug!("Starting cache cleanup loop");
+    loop {
+        let keys = receiver.recv();
+        match keys {
+            Ok(keys) => {
+                debug!("Expunging {} keys from the cache", keys.len());
+                for key in keys {
+                    debug!("Expunging key {}", key);
+                    if let Err(e) = std::fs::remove_file(get_cache_entry_path(&config, &key)) {
+                        warn!("Could not remove cache entry: {}", e);
+                    }
+                }
+            }
+            Err(_) => {
+                debug!("Expunge channel closed. Stopping cleanup loop.");
+                break;
+            }
+        }
+    }
+}
+
+struct SortableCacheEntry {
+    key: String,
+    weight: u64,
+    creation_time: std::time::SystemTime,
+}
+
+/// Load the existing files from the cache directory into an LRU cache.
+pub fn fill_lru_cache_from_disk(
+    config: &AppConfig,
+    lru_cache: &mut WeightedLRUCache,
+) -> std::io::Result<()> {
+    let entries = std::fs::read_dir(&config.cache_dir)?;
+    let mut cache_entries = vec![];
+    for entry in entries {
+        match entry {
+            Ok(entry) => match entry.metadata() {
+                Ok(metadata) => {
+                    let weight = metadata.len();
+                    let key = entry.file_name().to_string_lossy().to_string();
+                    match metadata.created() {
+                        Ok(creation_time) => cache_entries.push(SortableCacheEntry {
+                            key,
+                            weight,
+                            creation_time,
+                        }),
+                        Err(e) => warn!(
+                            "Could not read creation time for cache entry: {}. Ignoring.",
+                            e
+                        ),
+                    }
+                }
+                Err(e) => warn!("Could not read metadata for cache entry: {}. Ignoring.", e),
+            },
+            Err(e) => {
+                warn!("Could not read cache entry: {}. Ignoring.", e);
+            }
+        }
+    }
+    cache_entries.sort_by(|a, b| a.creation_time.cmp(&b.creation_time)); // oldest first, so the oldest files are evicted first
+    for entry in cache_entries {
+        lru_cache.put(entry.key, entry.weight);
+    }
+    Ok(())
+}
diff --git a/src/handlers.rs b/src/handlers.rs
index fd9660b..828a5ab 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -1,4 +1,4 @@
-use std::{io::Write, path::PathBuf};
+use std::io::Write;
 
 use actix_web::{
     HttpResponse, Responder, get, put,
@@ -10,21 +10,18 @@ use tracing::{debug, error, trace};
 
 use crate::{
     access::{AccessLevel, has_access_level},
-    config::AppConfig,
+    app_state::AppState,
+    disk_cache::get_cache_entry_path,
 };
 
-fn get_cache_entry_path(config: &AppConfig, hash: &str) -> PathBuf {
-    PathBuf::from(&config.cache_dir).join(hash)
-}
-
 #[get("/v1/cache/{hash}")]
 pub async fn get_cache_item(
-    app_data: Data<AppConfig>,
+    app_data: Data<AppState>,
     auth: BearerAuth,
     hash: Path<String>,
 ) -> impl Responder {
     trace!("Requested cache item {}", hash.as_str());
-    if !has_access_level(&app_data, auth.token(), AccessLevel::Read) {
+    if !has_access_level(&app_data.config, auth.token(), AccessLevel::Read) {
         debug!(
             "Tried to read cache item {} without valid read access token.",
             hash.as_str()
@@ -33,25 +30,26 @@
             .content_type("text/plain")
             .body("Please provide a valid access token with at least read-level access.");
     }
-    let path = get_cache_entry_path(&app_data, &hash);
+    let path = get_cache_entry_path(&app_data.config, &hash);
     if !path.exists() {
         trace!("Cache item not found: {}", hash.as_str());
         HttpResponse::NotFound().body("The record was not found.")
     } else {
         trace!("Returning cache item {}", hash.as_str());
+        app_data.with_lru_cache(|cache| cache.promote(hash.to_string()));
         HttpResponse::Ok().body(std::fs::read(path).unwrap())
     }
 }
 
 #[put("/v1/cache/{hash}")]
 pub async fn put_cache_item(
-    app_data: Data<AppConfig>,
+    app_data: Data<AppState>,
     auth: BearerAuth,
     hash: Path<String>,
     mut body: Payload,
 ) -> impl Responder {
     trace!("Received cache item {}", hash.as_str());
-    if !has_access_level(&app_data, auth.token(), AccessLevel::ReadWrite) {
+    if !has_access_level(&app_data.config, auth.token(), AccessLevel::ReadWrite) {
         debug!(
             "Tried to write cache item {} without valid read-write access token.",
             hash.as_str()
@@ -60,21 +58,29 @@
             .content_type("text/plain")
             .body("Please provide a valid access token with read-write access.");
     }
-    let path = get_cache_entry_path(&app_data, &hash);
+    let path = get_cache_entry_path(&app_data.config, &hash);
     let file = std::fs::File::create_new(&path);
     match file {
         Ok(mut file) => {
+            let mut complete_size = 0u64;
             while let Some(chunk) = body.next().await {
                 match chunk {
-                    Ok(chunk) => file.write_all(&chunk).expect("This should actually work"),
+                    Ok(chunk) => {
+                        complete_size += chunk.len() as u64;
+                        file.write_all(&chunk).expect("This should actually work")
+                    }
                     Err(e) => {
                         error!("Could not write cache item chunk: {}", e);
                         drop(file);
                         std::fs::remove_file(path).unwrap(); // Clean up to make sure the block doesn't get half-written with the wrong content
+                        complete_size = 0;
                         break;
                     }
                 }
             }
+            if complete_size > 0 {
+                app_data.with_lru_cache(|cache| cache.put(hash.to_string(), complete_size));
+            }
             debug!("Created cache item {}", hash.as_str());
             HttpResponse::Accepted().finish()
         }
diff --git a/src/lru_cache.rs b/src/lru_cache.rs
new file mode 100644
index 0000000..ab294ef
--- /dev/null
+++ b/src/lru_cache.rs
@@ -0,0 +1,162 @@
+use std::{
+    collections::{HashMap, VecDeque},
+    sync::mpsc::Sender,
+};
+
+use tracing::error;
+
+struct CacheEntry {
+    weight: u64,
+}
+
+pub struct WeightedLRUCache {
+    current_weight: u64,
+    pub max_weight: u64,
+    entries: HashMap<String, CacheEntry>,
+    lru_list: VecDeque<String>,
+    eviction_notifier: Option<Sender<Vec<String>>>,
+}
+
+impl WeightedLRUCache {
+    pub fn new(max_weight: u64) -> WeightedLRUCache {
+        WeightedLRUCache {
+            current_weight: 0,
+            max_weight,
+            entries: HashMap::new(),
+            lru_list: VecDeque::new(),
+            eviction_notifier: None,
+        }
+    }
+
+    pub fn on_eviction(&mut self, sender: Sender<Vec<String>>) {
+        self.eviction_notifier = Some(sender);
+    }
+
+    pub fn put(&mut self, key: String, weight: u64) {
+        self.entries
+            .insert(key.clone(), CacheEntry { weight: weight });
+        self.current_weight += weight;
+        self.promote(key);
+        if self.current_weight > self.max_weight {
+            let removed_keys = self.evict();
+            if let Some(expunge_notifier) = &mut self.eviction_notifier {
+                if let Err(e) = expunge_notifier.send(removed_keys) {
+                    error!("Could not send expunge notification: {}", e);
+                }
+            }
+        }
+    }
+
+    pub fn promote(&mut self, key: String) {
+        self.lru_list.retain(|entry| entry != &key);
+        self.lru_list.push_back(key);
+    }
+
+    pub fn has_entry(&self, key: &str) -> bool {
+        self.entries.contains_key(key)
+    }
+
+    pub fn len(&self) -> usize {
+        self.entries.len()
+    }
+
+    fn evict(&mut self) -> Vec<String> {
+        let mut removed_keys = Vec::new();
+        while self.current_weight > self.max_weight {
+            let key = self
+                .lru_list
+                .pop_front()
+                .expect("The LRU list of a filled cache should never be empty");
+            let entry = self.entries.remove(&key);
+            match entry {
+                Some(entry) => {
+                    self.current_weight -= entry.weight;
+                    removed_keys.push(key);
+                }
+                None => error!("Could not find cache entry for evicted key {}", key),
+            }
+        }
+        removed_keys
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn fill_lru_cache() {
+        let mut lru_cache = WeightedLRUCache::new(100);
+
+        lru_cache.put("key1".to_string(), 10);
+        lru_cache.put("key2".to_string(), 20);
+        lru_cache.put("key3".to_string(), 30);
+
+        assert!(lru_cache.has_entry("key1"));
+        assert!(lru_cache.has_entry("key2"));
+        assert!(lru_cache.has_entry("key3"));
+    }
+
+    #[test]
+    fn overflow_lru_cache() {
+        let mut lru_cache = WeightedLRUCache::new(100);
+
+        lru_cache.put("key1".to_string(), 10);
+        lru_cache.put("key2".to_string(), 45);
+        lru_cache.put("key3".to_string(), 50);
+
+        assert!(!lru_cache.has_entry("key1"));
+        assert!(lru_cache.has_entry("key2"));
+        assert!(lru_cache.has_entry("key3"));
+    }
+
+    #[test]
+    fn calls_expunge_cb() {
+        let mut lru_cache = WeightedLRUCache::new(100);
+        let (sender, receiver) = std::sync::mpsc::channel();
+        {
+            lru_cache.on_eviction(sender);
+        }
+        lru_cache.put("key1".to_string(), 10);
+        lru_cache.put("key2".to_string(), 45);
+        lru_cache.put("key3".to_string(), 50);
+
+        assert_eq!(receiver.recv().unwrap(), vec!["key1".to_string()]);
+    }
+
+    #[test]
+    fn double_insert_pushes_entries_to_the_back() {
+        let mut lru_cache = WeightedLRUCache::new(100);
+
+        lru_cache.put("key1".to_string(), 10);
+        lru_cache.put("key2".to_string(), 45);
+        lru_cache.put("key1".to_string(), 20);
+        lru_cache.put("key3".to_string(), 50);
+
+        assert!(lru_cache.has_entry("key1"));
+        assert!(!lru_cache.has_entry("key2"));
+        assert!(lru_cache.has_entry("key3"));
+    }
+
+    #[test]
+    fn promote_pushes_entries_to_the_back() {
+        let mut lru_cache = WeightedLRUCache::new(100);
+
+        lru_cache.put("key1".to_string(), 10);
+        lru_cache.put("key2".to_string(), 45);
+        lru_cache.promote("key1".to_string());
+        lru_cache.put("key3".to_string(), 50);
+
+        assert!(lru_cache.has_entry("key1"));
+        assert!(!lru_cache.has_entry("key2"));
+        assert!(lru_cache.has_entry("key3"));
+    }
+
+    #[test]
+    fn gracefully_fails_with_zero_size() {
+        let mut lru_cache = WeightedLRUCache::new(0);
+        lru_cache.put("key1".to_string(), 10);
+
+        assert!(!lru_cache.has_entry("key1"));
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 94952e2..f7ea13a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,11 +1,22 @@
+use std::{
+    sync::{Arc, Mutex, mpsc::Receiver},
+    u64,
+};
+
 use actix_web::{App, HttpServer, web::Data};
+use app_state::AppState;
+use disk_cache::{cache_cleanup, fill_lru_cache_from_disk};
 use dotenvy::dotenv;
-use tracing::{Level, debug};
+use lru_cache::WeightedLRUCache;
+use tracing::{Level, debug, warn};
 use tracing_subscriber::FmtSubscriber;
 
 mod access;
+mod app_state;
 mod config;
+mod disk_cache;
 mod handlers;
+mod lru_cache;
 
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
@@ -21,12 +32,27 @@
     }
 
     let config = config::load_config();
+    let mut lru_cache = WeightedLRUCache::new(config.max_cache_size.unwrap_or(u64::MAX));
+    if lru_cache.max_weight == u64::MAX {
+        warn!("Cache size is unlimited. This is probably not what you want.");
+    } else {
+        debug!("Cache size is limited to {} bytes", lru_cache.max_weight);
+    }
+    let (tx, rx) = std::sync::mpsc::channel();
+    lru_cache.on_eviction(tx);
+    // Load the existing cache entries from disk
+    fill_lru_cache_from_disk(&config, &mut lru_cache)?;
+    debug!("Loaded cache with {} entries", lru_cache.len());
+
+    let app_state = AppState {
+        lru_cache: Arc::new(Mutex::new(lru_cache)),
+        config: config.clone(),
+    };
 
     let mut server = HttpServer::new({
-        let config = config.clone();
         move || {
             App::new()
-                .app_data(Data::new(config.clone()))
+                .app_data(Data::new(app_state.clone()))
                 .wrap(actix_web::middleware::Logger::default())
                 .service(handlers::get_cache_item)
                 .service(handlers::put_cache_item)
@@ -34,10 +60,18 @@
     })
     .keep_alive(None); // disable HTTP keep-alive because it seems to break NX (at least in version 20.8)
 
+    let cleanup_config = config.clone();
+    let cleanup_thread = std::thread::spawn(move || cache_cleanup(cleanup_config, rx));
     for bind_address in config.bind_addresses.unwrap_or(vec!["::0".to_string()]) {
        server = server
            .bind((bind_address.clone(), 8080))
            .expect(format!("Should have been able to bind to address {}", bind_address).as_str());
     }
-    server.run().await
+    server.run().await?;
+
+    if let Err(e) = cleanup_thread.join() {
+        debug!("Cache cleanup thread failed: {:?}", e);
+    }
+
+    Ok(())
 }
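
For reviewers, a minimal sketch of how the new pieces fit together, assuming it runs inside this crate. The function name `lru_eviction_walkthrough`, the 100-byte budget, and the keys are made-up illustration values and are not part of the change. It mirrors what main.rs does: build a `WeightedLRUCache`, attach the eviction channel, and let `put()` report evicted keys the same way `disk_cache::cache_cleanup` consumes them on its background thread.

```rust
use std::sync::mpsc;

use crate::lru_cache::WeightedLRUCache;

fn lru_eviction_walkthrough() {
    // Hypothetical weight budget of 100 bytes, matching the unit tests.
    let mut cache = WeightedLRUCache::new(100);

    // Evicted keys are sent over this channel; in main.rs the receiving end
    // is handed to disk_cache::cache_cleanup, which deletes the files on disk.
    let (tx, rx) = mpsc::channel();
    cache.on_eviction(tx);

    cache.put("a".to_string(), 60); // 60 of 100 bytes used
    cache.put("b".to_string(), 60); // 120 > 100, so the least recently used key "a" is evicted

    assert!(!cache.has_entry("a"));
    assert!(cache.has_entry("b"));
    assert_eq!(rx.recv().unwrap(), vec!["a".to_string()]);
}
```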