Skip to content

Commit

Permalink
private key
Browse files Browse the repository at this point in the history
  • Loading branch information
sfackler committed Jun 24, 2017
1 parent 1d2a529 commit c7bc56e
Show file tree
Hide file tree
Showing 11 changed files with 340 additions and 79 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ schannel = "0.1.7"

[target.'cfg(not(any(target_os = "windows", target_os = "macos")))'.dependencies]
openssl = "0.9.2"

[replace]
"schannel:0.1.7" = { git = "https://github.com/steffengy/schannel-rs" }
"openssl:0.9.14" = { git = "https://github.com/sfackler/rust-openssl", branch = "pkey-private-key-from-der" }
"security-framework:0.1.14" = { git = "https://github.com/sfackler/rust-security-framework" }
24 changes: 24 additions & 0 deletions src/imp/openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::io;
use std::fmt;
use std::error;
use self::openssl::pkcs12;
use self::openssl::pkey::PKey;
use self::openssl::error::ErrorStack;
use self::openssl::ssl::{self, SslMethod, SslConnectorBuilder, SslConnector, SslAcceptorBuilder,
SslAcceptor, MidHandshakeSslStream, SslContextBuilder};
Expand Down Expand Up @@ -87,6 +88,15 @@ impl Certificate {
}
}

pub struct PrivateKey(PKey);

impl PrivateKey {
pub fn from_der(buf: &[u8]) -> Result<PrivateKey, Error> {
let key = try!(PKey::private_key_from_der(buf));
Ok(PrivateKey(key))
}
}

pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);

impl<S> fmt::Debug for MidHandshakeTlsStream<S>
Expand Down Expand Up @@ -254,6 +264,20 @@ impl TlsAcceptor {
Ok(TlsAcceptorBuilder(builder))
}

pub fn builder2(
key: PrivateKey,
cert: Certificate,
chain: Vec<Certificate>,
) -> Result<TlsAcceptorBuilder, Error> {
let builder = try!(SslAcceptorBuilder::mozilla_intermediate(
SslMethod::tls(),
&key.0,
&cert.0,
chain.iter().map(|c| &c.0),
));
Ok(TlsAcceptorBuilder(builder))
}

pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
Expand Down
107 changes: 75 additions & 32 deletions src/imp/schannel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ use std::io;
use std::fmt;
use std::error;
use std::sync::Arc;
use self::schannel::crypt_prov::{AcquireOptions, CryptProv, ProviderType};
use self::schannel::cert_store::{PfxImportOptions, Memory, CertStore, CertAdd};
use self::schannel::cert_context::CertContext;
use self::schannel::cert_context::{CertContext, KeySpec};
use self::schannel::schannel_cred::{Direction, SchannelCred, Protocol};
use self::schannel::tls_stream;

const CONTAINER_NAME: &'static str = "native-tls";

This comment has been minimized.

Copy link
@jethrogb

jethrogb Jun 24, 2017

How does namespacing work? Will using a constant work if you have multiple servers with different keys active at the same time?

This comment has been minimized.

Copy link
@sfackler

sfackler Jun 24, 2017

Author Owner

I'll pull in another key to confirm, but I believe containers can hold multiple keys at the same time.

This comment has been minimized.

Copy link
@sfackler

sfackler Jun 24, 2017

Author Owner

Looks like things do go bad actually - I'll switch to different containers per key.


fn convert_protocols(protocols: &[::Protocol]) -> Vec<Protocol> {
protocols
.iter()
protocols.iter()
.map(|p| match *p {
::Protocol::Sslv3 => Protocol::Ssl3,
::Protocol::Tlsv10 => Protocol::Tls10,
Expand Down Expand Up @@ -66,21 +68,17 @@ impl Pkcs12 {
.silent(true)
.compare_key(true)
.acquire()
.is_ok()
{
.is_ok() {
identity = Some(cert);
}
}

let identity = match identity {
Some(identity) => identity,
None => {
return Err(
io::Error::new(
io::ErrorKind::InvalidInput,
"No identity found in PKCS #12 archive",
).into(),
);
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"No identity found in PKCS #12 archive")
.into());
}
};

Expand All @@ -97,20 +95,49 @@ impl Certificate {
}
}

pub struct PrivateKey(CryptProv);

impl PrivateKey {
pub fn from_der(buf: &[u8]) -> Result<PrivateKey, Error> {
let mut options = AcquireOptions::new();
options.container(CONTAINER_NAME)
.new_keyset(true);
let type_ = ProviderType::rsa_full();

// this is kind of a mess - we have to tell WinAPI to either open an
// existing container or create a new one, but there's no "open or
// create" option. If you try to create it and it exists it'll error
// and if you try to open it and it doesn't exist it'll error. We first
// try to open an existing one, then try to create it, then finally try
// to open it in case a parallel caller created it concurrently.
let mut container = match options.acquire(type_) {
Ok(container) => container,
Err(_) => {
match options.new_keyset(true).acquire(type_) {
Ok(container) => container,
Err(_) => options.new_keyset(false).acquire(type_)?,
}
}
};

container.import().import(buf)?;

Ok(PrivateKey(container))
}
}

pub struct MidHandshakeTlsStream<S>(tls_stream::MidHandshakeTlsStream<S>);

impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
where S: fmt::Debug
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}

impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
where S: io::Read + io::Write
{
pub fn get_ref(&self) -> &S {
self.0.get_ref()
Expand Down Expand Up @@ -192,26 +219,22 @@ impl TlsConnector {
}

pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
where S: io::Read + io::Write
{
self._connect(Some(domain), stream)
}

pub fn connect_no_domain<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
where S: io::Read + io::Write
{
self._connect(None, stream)
}

fn _connect<S>(
&self,
domain: Option<&str>,
stream: S,
) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
fn _connect<S>(&self,
domain: Option<&str>,
stream: S)
-> Result<TlsStream<S>, HandshakeError<S>>
where S: io::Read + io::Write
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(&self.protocols);
Expand Down Expand Up @@ -262,9 +285,31 @@ impl TlsAcceptor {
}))
}

pub fn builder2(_key: PrivateKey,
cert: Certificate,
chain: Vec<Certificate>)
-> Result<TlsAcceptorBuilder, Error> {
let mut store = try!(Memory::new()).into_store();
for cert in chain {
try!(store.add_cert(&cert.0, CertAdd::ReplaceExisting));
}
let cert = try!(store.add_cert(&cert.0, CertAdd::ReplaceExisting));

try!(cert.set_key_prov_info()
.container(CONTAINER_NAME)
.type_(ProviderType::rsa_full())
.keep_open(true)
.key_spec(KeySpec::key_exchange())
.set());

Ok(TlsAcceptorBuilder(TlsAcceptor {
cert: cert,
protocols: vec![Protocol::Tls10, Protocol::Tls11, Protocol::Tls12],
}))
}

pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
where S: io::Read + io::Write
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(&self.protocols);
Expand Down Expand Up @@ -326,14 +371,12 @@ pub trait TlsConnectorBuilderExt {
/// Sets a callback function which decides if the server's certificate chain
/// is to be trusted.
fn verify_callback<F>(&mut self, callback: F)
where
F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync;
where F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync;
}

impl TlsConnectorBuilderExt for ::TlsConnectorBuilder {
fn verify_callback<F>(&mut self, callback: F)
where
F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync,
where F: Fn(tls_stream::CertValidationResult) -> io::Result<()> + 'static + Send + Sync
{
(self.0).0.callback = Some(Arc::new(callback));
}
Expand Down
72 changes: 57 additions & 15 deletions src/imp/security_framework.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ extern crate tempdir;
use self::security_framework::base;
use self::security_framework::certificate::SecCertificate;
use self::security_framework::identity::SecIdentity;
use self::security_framework::import_export::Pkcs12ImportOptions;
use self::security_framework::keychain::SecKeychain;
use self::security_framework::import_export::{Pkcs12ImportOptions};
use self::security_framework::secure_transport::{self, SslContext, ProtocolSide, ConnectionType,
SslProtocol, ClientBuilder};
use self::security_framework::os::macos::identity::SecIdentityExt;
use self::security_framework::os::macos::import_export::ImportOptions;
use self::security_framework::os::macos::keychain::{self, KeychainSettings};
use self::security_framework_sys::base::errSecIO;
use self::tempdir::TempDir;
Expand Down Expand Up @@ -71,6 +74,21 @@ impl From<base::Error> for Error {
}
}

fn temp_keychain(pass: &str) -> Result<(SecKeychain, TempDir), Error> {
let dir = match TempDir::new("native-tls") {
Ok(dir) => dir,
Err(_) => return Err(Error(base::Error::from(errSecIO))),
};

let mut keychain = try!(keychain::CreateOptions::new().password(pass).create(
dir.path().join("tmp.keychain"),
));
// disable lock on sleep and timeouts
try!(keychain.set_settings(&KeychainSettings::new()));

Ok((keychain, dir))
}

#[derive(Clone)]
pub struct Pkcs12 {
identity: SecIdentity,
Expand All @@ -79,16 +97,7 @@ pub struct Pkcs12 {

impl Pkcs12 {
pub fn from_der(buf: &[u8], pass: &str) -> Result<Pkcs12, Error> {
let dir = match TempDir::new("native-tls") {
Ok(dir) => dir,
Err(_) => return Err(Error(base::Error::from(errSecIO))),
};

let mut keychain = try!(keychain::CreateOptions::new().password(pass).create(
dir.path().join("tmp.keychain"),
));
// disable lock on sleep and timeouts
try!(keychain.set_settings(&KeychainSettings::new()));
let (keychain, _dir) = try!(temp_keychain(pass));

let mut imports = try!(
Pkcs12ImportOptions::new()
Expand Down Expand Up @@ -121,6 +130,23 @@ impl Certificate {
}
}

pub struct PrivateKey(SecKeychain);

impl PrivateKey {
pub fn from_der(buf: &[u8]) -> Result<PrivateKey, Error> {
let (mut keychain, _dir) = try!(temp_keychain(""));

try!(
ImportOptions::new()
.filename(".der")
.keychain(&mut keychain)
.import(buf)
);

Ok(PrivateKey(keychain))
}
}

pub enum HandshakeError<S> {
Interrupted(MidHandshakeTlsStream<S>),
Failure(Error),
Expand Down Expand Up @@ -303,14 +329,30 @@ impl TlsAcceptorBuilder {

#[derive(Clone)]
pub struct TlsAcceptor {
pkcs12: Pkcs12,
identity: SecIdentity,
chain: Vec<SecCertificate>,
protocols: Vec<Protocol>,
}

impl TlsAcceptor {
pub fn builder(pkcs12: Pkcs12) -> Result<TlsAcceptorBuilder, Error> {
Ok(TlsAcceptorBuilder(TlsAcceptor {
pkcs12: pkcs12,
identity: pkcs12.identity,
chain: pkcs12.chain,
protocols: vec![Protocol::Tlsv10, Protocol::Tlsv11, Protocol::Tlsv12],
}))
}

pub fn builder2(
key: PrivateKey,
cert: Certificate,
chain: Vec<Certificate>,
) -> Result<TlsAcceptorBuilder, Error> {
let identity = try!(SecIdentity::with_certificate(&[key.0], &cert.0));
let chain = chain.into_iter().map(|c| c.0).collect();
Ok(TlsAcceptorBuilder(TlsAcceptor {
identity: identity,
chain: chain,
protocols: vec![Protocol::Tlsv10, Protocol::Tlsv11, Protocol::Tlsv12],
}))
}
Expand All @@ -328,8 +370,8 @@ impl TlsAcceptor {
try!(ctx.set_protocol_version_min(min));
try!(ctx.set_protocol_version_max(max));
try!(ctx.set_certificate(
&self.pkcs12.identity,
&self.pkcs12.chain,
&self.identity,
&self.chain,
));
match ctx.handshake(stream) {
Ok(s) => Ok(TlsStream(s)),
Expand Down
Loading

0 comments on commit c7bc56e

Please sign in to comment.