Skip to content

Instantly share code, notes, and snippets.

@vadv
Created October 3, 2023 20:09
Show Gist options
  • Save vadv/617ba6b74fd494307546d84c83607c52 to your computer and use it in GitHub Desktop.
Save vadv/617ba6b74fd494307546d84c83607c52 to your computer and use it in GitHub Desktop.
diff --git a/.gitignore b/.gitignore
index 94d2001..2078a91 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
/target
.idea
/examples
+/vendor
diff --git a/Cargo.lock b/Cargo.lock
index 9a82980..f105049 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -605,6 +605,9 @@ name = "ipnet"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"
+dependencies = [
+ "serde",
+]
[[package]]
name = "itoa"
diff --git a/Cargo.toml b/Cargo.toml
index 1b71a16..87e2ceb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ log = "0.4.20"
clap = { version = "4.3.1", features = ["derive", "env"] }
serde = { version = "1", features = ["derive"] }
serde_derive = "1"
-ipnet = "2.8.0"
+ipnet = { version = "2.8.0", features = ["serde"] }
once_cell = "1"
arc-swap = "1"
toml = "0.7"
diff --git a/pg_doorman.toml b/pg_doorman.toml
index 1d3e79a..0b1440a 100644
--- a/pg_doorman.toml
+++ b/pg_doorman.toml
@@ -27,6 +27,8 @@ admin_password = "doorman_admin_password"
prometheus_exporter_port = 9075
+hba = ["10.0.0.0/8", "192.168.0.0/16"]
+
[pools]
[pools.example_db]
diff --git a/src/client.rs b/src/client.rs
index 1b130be..7c98876 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -12,9 +12,7 @@ use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
-use crate::config::{
- get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
-};
+use crate::config::{get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, addr_in_hba};
use crate::constants::*;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
@@ -463,6 +461,16 @@ where
return Err(Error::ShuttingDown);
}
+ if !addr_in_hba(addr.ip()) {
+ error_response_terminal(
+ &mut write,
+ "hba forbidden for this ip address",
+ ).await?;
+ return Err(Error::HbaForbiddenError(format!(
+ "hba forbidden client: {} from address: {:?}", client_identifier, addr
+ )));
+ }
+
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
diff --git a/src/config.rs b/src/config.rs
index c4d6bad..4c36865 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -7,6 +7,8 @@ use log::{error, info};
use std::collections::{BTreeMap, HashMap};
use std::path::Path;
use std::collections::hash_map::DefaultHasher;
+use std::net::IpAddr;
+use ipnet::IpNet;
use once_cell::sync::Lazy;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
@@ -259,6 +261,8 @@ pub struct General {
#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
+    #[serde(default)] // optional: configs without `hba` keep loading, and an empty list admits every client address
+    pub hba: Vec<IpNet>,
}
impl General {
@@ -375,6 +379,7 @@ impl Default for General {
validate_config: true,
prepared_statements: false,
prepared_statements_cache_size: 500,
+ hba: vec![],
}
}
}
@@ -859,8 +864,17 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
}
}
+/// `true` if `general.hba` is empty (allow all) or has a network containing `addr` or, for an IPv4-mapped IPv6 peer (`::ffff:a.b.c.d`, as reported by dual-stack listeners), its IPv4 form.
+pub fn addr_in_hba(addr: IpAddr) -> bool {
+    let config = get_config();
+    let hba = &config.general.hba;
+    let mapped = if let IpAddr::V6(v6) = addr { v6.to_ipv4_mapped().map(IpAddr::V4) } else { None };
+    hba.is_empty() || hba.iter().any(|net| net.contains(&addr) || matches!(mapped, Some(m) if net.contains(&m)))
+}
+
#[cfg(test)]
mod test {
+ use std::net::Ipv4Addr;
use super::*;
#[tokio::test]
@@ -883,6 +897,9 @@ mod test {
assert_eq!(get_config().pools["example_db"].users["1"].username, "example_user_2");
assert_eq!(get_config().pools["example_db"].users["0"].pool_size, 40);
assert_eq!(get_config().pools["example_db"].users["0"].pool_mode, Some(PoolMode::Session));
+ assert_eq!(addr_in_hba(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))), true);
+ assert_eq!(addr_in_hba(IpAddr::V4(Ipv4Addr::new(172, 0, 0, 1))), false);
+ assert_eq!(addr_in_hba(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))), true);
}
#[tokio::test]
diff --git a/src/errors.rs b/src/errors.rs
index 49553d8..6a73402 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -26,6 +26,7 @@ pub enum Error {
QueryError(String),
ScramClientError(String),
ScramServerError(String),
+ HbaForbiddenError(String),
}
#[derive(Clone, PartialEq, Debug)]
@@ -51,7 +52,7 @@ impl std::fmt::Display for ClientIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
- "{{ {}@{}{}/?application_name={} }}",
+ "{{ {}@{}/{}?application_name={} }}",
self.username, self.addr, self.pool_name, self.application_name
)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment