Skip to content

Commit

Permalink
add support for more mssql connection options
Browse files Browse the repository at this point in the history
  • Loading branch information
lovasoa committed Sep 13, 2023
1 parent 8e995fb commit 49e2779
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 7 deletions.
15 changes: 8 additions & 7 deletions sqlx-core/src/mssql/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ impl MssqlConnection {
PreLogin {
version: Version::default(),
encryption: Encrypt::NOT_SUPPORTED,
instance: options.instance.as_deref(),

..Default::default()
},
Expand All @@ -40,16 +41,16 @@ impl MssqlConnection {
Login7 {
// FIXME: use a version constant
version: 0x74000004, // SQL Server 2012 - SQL Server 2019
client_program_version: 0,
client_pid: 0,
client_program_version: options.client_program_version,
client_pid: options.client_pid,
packet_size: options.requested_packet_size, // max allowed size of TDS packet
hostname: "",
hostname: &options.hostname,
username: &options.username,
password: options.password.as_deref().unwrap_or_default(),
app_name: "",
server_name: "",
client_interface_name: "",
language: "",
app_name: &options.app_name,
server_name: &options.server_name,
client_interface_name: &options.client_interface_name,
language: &options.language,
database: &*options.database,
client_id: [0; 6],
},
Expand Down
62 changes: 62 additions & 0 deletions sqlx-core/src/mssql/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,28 @@ use crate::connection::LogSettings;
mod connect;
mod parse;

/// Options and flags which can be used to configure a Microsoft SQL Server connection.
///
/// Connection strings should be in the form:
/// ```text
/// mssql://[username[:password]@]host/database[?instance=instance_name&packet_size=packet_size&client_program_version=client_program_version&client_pid=client_pid&hostname=hostname&app_name=app_name&server_name=server_name&client_interface_name=client_interface_name&language=language]
/// ```
#[derive(Debug, Clone)]
pub struct MssqlConnectOptions {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) username: String,
pub(crate) database: String,
pub(crate) password: Option<String>,
pub(crate) instance: Option<String>,
pub(crate) log_settings: LogSettings,
pub(crate) client_program_version: u32,
pub(crate) client_pid: u32,
pub(crate) hostname: String,
pub(crate) app_name: String,
pub(crate) server_name: String,
pub(crate) client_interface_name: String,
pub(crate) language: String,
/// Size in bytes of TDS packets to exchange with the server
pub(crate) requested_packet_size: u32,
}
Expand All @@ -29,8 +43,16 @@ impl MssqlConnectOptions {
database: String::from("master"),
username: String::from("sa"),
password: None,
instance: None,
log_settings: Default::default(),
requested_packet_size: 4096,
client_program_version: 0,
client_pid: 0,
hostname: "".to_string(),
app_name: "".to_string(),
server_name: "".to_string(),
client_interface_name: "".to_string(),
language: "".to_string(),
}
}

Expand Down Expand Up @@ -59,6 +81,46 @@ impl MssqlConnectOptions {
self
}

pub fn instance(mut self, instance: &str) -> Self {
self.instance = Some(instance.to_owned());
self
}

pub fn client_program_version(mut self, client_program_version: u32) -> Self {
self.client_program_version = client_program_version.to_owned();
self
}

pub fn client_pid(mut self, client_pid: u32) -> Self {
self.client_pid = client_pid.to_owned();
self
}

pub fn hostname(mut self, hostname: &str) -> Self {
self.hostname = hostname.to_owned();
self
}

pub fn app_name(mut self, app_name: &str) -> Self {
self.app_name = app_name.to_owned();
self
}

pub fn server_name(mut self, server_name: &str) -> Self {
self.server_name = server_name.to_owned();
self
}

pub fn client_interface_name(mut self, client_interface_name: &str) -> Self {
self.client_interface_name = client_interface_name.to_owned();
self
}

pub fn language(mut self, language: &str) -> Self {
self.language = language.to_owned();
self
}

/// Size in bytes of TDS packets to exchange with the server.
/// Returns an error if the size is smaller than 512 bytes
pub fn requested_packet_size(mut self, size: u32) -> Result<Self, Self> {
Expand Down
42 changes: 42 additions & 0 deletions sqlx-core/src/mssql/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ use url::Url;
impl FromStr for MssqlConnectOptions {
type Err = Error;

/// Parse a connection string into a set of connection options.
///
/// The connection string is expected to be a valid URL with the following format:
/// ```text
/// mssql://[username[:password]@]host/database[?instance=instance_name&packet_size=packet_size&client_program_version=client_program_version&client_pid=client_pid&hostname=hostname&app_name=app_name&server_name=server_name&client_interface_name=client_interface_name&language=language]
/// ```
fn from_str(s: &str) -> Result<Self, Self::Err> {
let url: Url = s.parse().map_err(Error::config)?;
let mut options = Self::new();
Expand Down Expand Up @@ -41,10 +47,46 @@ impl FromStr for MssqlConnectOptions {
options = options.database(path);
}

for (key, value) in url.query_pairs() {
match key.as_ref() {
"instance" => {
options = options.instance(&*value);
}
"packet_size" => {
let size = value.parse().map_err(Error::config)?;
options = options
.requested_packet_size(size)
.map_err(|_| Error::config(MssqlInvalidOption(format!("packet_size={}", size))))?;
}
"client_program_version" => {
options = options.client_program_version(value.parse().map_err(Error::config)?)
}
"client_pid" => options = options.client_pid(value.parse().map_err(Error::config)?),
"hostname" => options = options.hostname(&*value),
"app_name" => options = options.app_name(&*value),
"server_name" => options = options.server_name(&*value),
"client_interface_name" => options = options.client_interface_name(&*value),
"language" => options = options.language(&*value),
_ => {
return Err(Error::config(MssqlInvalidOption(key.into())));
}
}
}
Ok(options)
}
}

#[derive(Debug)]
struct MssqlInvalidOption(String);

impl std::fmt::Display for MssqlInvalidOption {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "`{}` is not a valid mssql connection option", self.0)
}
}

impl std::error::Error for MssqlInvalidOption {}

#[test]
fn it_parses_username_with_at_sign_correctly() {
let url = "mysql://user@hostname:password@hostname:5432/database";
Expand Down

0 comments on commit 49e2779

Please sign in to comment.