From 1ceb8a833edf3225249bd4514e665d885b1f096d Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Fri, 21 Oct 2022 18:21:54 +0800 Subject: [PATCH 01/12] update to edition 2021 generate by cargo `clippy --release --fix --allow-dirty --edition --all-features` delete obsolete extern crate restirct native-tls with !wasm32 --- Cargo.toml | 1 + src/bin/mstsc-rs.rs | 428 ++++++++++++++++---------- src/codec/mod.rs | 2 +- src/codec/rle.rs | 673 ++++++++++++++++++++++------------------- src/core/capability.rs | 86 ++++-- src/core/client.rs | 85 +++--- src/core/event.rs | 66 ++-- src/core/gcc.rs | 118 ++++---- src/core/global.rs | 482 ++++++++++++++++++++--------- src/core/license.rs | 37 ++- src/core/mcs.rs | 254 ++++++++++++---- src/core/mod.rs | 14 +- src/core/per.rs | 72 +++-- src/core/sec.rs | 46 +-- src/core/tpkt.rs | 97 +++--- src/core/x224.rs | 143 ++++++--- src/lib.rs | 22 +- src/model/data.rs | 139 +++++---- src/model/error.rs | 53 ++-- src/model/link.rs | 46 +-- src/model/mod.rs | 2 +- src/model/rnd.rs | 2 +- src/model/unicode.rs | 6 +- src/nla/asn1.rs | 62 ++-- src/nla/cssp.rs | 98 ++++-- src/nla/mod.rs | 2 +- src/nla/ntlm.rs | 496 ++++++++++++++++++++++-------- src/nla/rc4.rs | 17 +- src/nla/sspi.rs | 4 +- 29 files changed, 2228 insertions(+), 1325 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 07d020c..93ea77d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ keywords = ["rdp", "security", "network", "windows"] categories = ["network"] license = "MIT" documentation = "https://docs.rs/rdp-rs" +edition = "2021" [lib] name = "rdp" diff --git a/src/bin/mstsc-rs.rs b/src/bin/mstsc-rs.rs index e9328e1..a2d4488 100644 --- a/src/bin/mstsc-rs.rs +++ b/src/bin/mstsc-rs.rs @@ -1,40 +1,30 @@ -#[cfg(target_os = "windows")] -extern crate winapi; +use clap::{App, Arg, ArgMatches}; #[cfg(any(target_os = "linux", target_os = "macos"))] -extern crate libc; -extern crate minifb; -extern crate rdp; -extern crate hex; -extern crate clap; -extern crate hmac; - -use minifb::{Key, Window, WindowOptions, MouseMode, MouseButton, KeyRepeat}; -use std::net::{SocketAddr, TcpStream}; +use libc::{fd_set, select, FD_SET}; +use minifb::{Key, KeyRepeat, MouseButton, MouseMode, Window, WindowOptions}; +use rdp::core::client::{Connector, RdpClient}; +use rdp::core::event::{BitmapEvent, KeyboardEvent, PointerButton, PointerEvent, RdpEvent}; +use rdp::core::gcc::KeyboardLayout; +use rdp::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use std::convert::TryFrom; use std::io::{Read, Write}; -use std::time::{Instant}; -use std::ptr; use std::mem; -use std::mem::{size_of, forget}; -use rdp::core::client::{RdpClient, Connector}; -#[cfg(target_os = "windows")] -use winapi::um::winsock2::{select, fd_set}; +use std::mem::{forget, size_of}; +use std::net::{SocketAddr, TcpStream}; #[cfg(any(target_os = "linux", target_os = "macos"))] -use libc::{select, fd_set, FD_SET}; +use std::os::unix::io::AsRawFd; #[cfg(target_os = "windows")] -use std::os::windows::io::{AsRawSocket}; -#[cfg(any(target_os = "linux", target_os = "macos"))] -use std::os::unix::io::{AsRawFd}; -use rdp::core::event::{RdpEvent, BitmapEvent, PointerEvent, PointerButton, KeyboardEvent}; +use std::os::windows::io::AsRawSocket; +use std::ptr; use std::ptr::copy_nonoverlapping; -use std::convert::TryFrom; -use std::thread; -use std::sync::{mpsc, Arc, Mutex}; -use std::thread::{JoinHandle}; use std::sync::atomic::{AtomicBool, Ordering}; -use rdp::model::error::{Error, RdpErrorKind, RdpError, RdpResult}; -use clap::{Arg, App, ArgMatches}; -use rdp::core::gcc::KeyboardLayout; use std::sync::mpsc::{Receiver, Sender}; +use std::sync::{mpsc, Arc, Mutex}; +use std::thread; +use std::thread::JoinHandle; +use std::time::Instant; +#[cfg(target_os = "windows")] +use winapi::um::winsock2::{fd_set, select}; const APPLICATION_NAME: &str = "mstsc-rs"; @@ -46,7 +36,13 @@ fn wait_for_fd(fd: usize) -> bool { let mut raw_fds: fd_set = mem::zeroed(); raw_fds.fd_array[0] = fd; raw_fds.fd_count = 1; - let result = select(0, &mut raw_fds, ptr::null_mut(), ptr::null_mut(), ptr::null()); + let result = select( + 0, + &mut raw_fds, + ptr::null_mut(), + ptr::null_mut(), + ptr::null(), + ); result == 1 } } @@ -57,8 +53,14 @@ fn wait_for_fd(fd: usize) -> bool { let mut raw_fds: fd_set = mem::zeroed(); FD_SET(fd as i32, &mut raw_fds); - - let result = select(fd as i32 + 1, &mut raw_fds, ptr::null_mut(), ptr::null_mut(), ptr::null_mut()); + + let result = select( + fd as i32 + 1, + &mut raw_fds, + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ); result == 1 } } @@ -76,7 +78,7 @@ pub unsafe fn transmute_vec(mut vec: Vec) -> Vec { /// Copy a bitmap event into the buffer /// This function use unsafe copy /// to accelerate data transfer -fn fast_bitmap_transfer(buffer: &mut Vec, width: usize, bitmap: BitmapEvent) -> RdpResult<()>{ +fn fast_bitmap_transfer(buffer: &mut Vec, width: usize, bitmap: BitmapEvent) -> RdpResult<()> { let bitmap_dest_left = bitmap.dest_left as usize; let bitmap_dest_right = bitmap.dest_right as usize; let bitmap_dest_bottom = bitmap.dest_bottom as usize; @@ -88,15 +90,26 @@ fn fast_bitmap_transfer(buffer: &mut Vec, width: usize, bitmap: BitmapEvent // Use some unsafe method to faster // data transfer between buffers unsafe { - let data_aligned :Vec = transmute_vec(data); + let data_aligned: Vec = transmute_vec(data); for i in 0..(bitmap_dest_bottom - bitmap_dest_top + 1) { let dest_i = (i + bitmap_dest_top) * width + bitmap_dest_left; let src_i = i * bitmap_width; let count = bitmap_dest_right - bitmap_dest_left + 1; - if dest_i > buffer.len() || dest_i + count > buffer.len() || src_i > data_aligned.len() || src_i + count > data_aligned.len() { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "Image have invalide size"))) + if dest_i > buffer.len() + || dest_i + count > buffer.len() + || src_i > data_aligned.len() + || src_i + count > data_aligned.len() + { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Image have invalide size", + ))); } - copy_nonoverlapping(data_aligned.as_ptr().offset((src_i) as isize), buffer.as_mut_ptr().offset(dest_i as isize), count) + copy_nonoverlapping( + data_aligned.as_ptr().offset((src_i) as isize), + buffer.as_mut_ptr().offset(dest_i as isize), + count, + ) } } @@ -225,22 +238,32 @@ fn to_scancode(key: Key) -> u16 { Key::LeftSuper => 0xE05B, Key::RightSuper => 0xE05C, Key::Menu => 0xE05D, - _ => panic!("foo") + _ => panic!("foo"), } } /// Create a tcp stream from main args fn tcp_from_args(args: &ArgMatches) -> RdpResult { - let ip = args.value_of("host").expect("You need to provide a target argument"); + let ip = args + .value_of("host") + .expect("You need to provide a target argument"); let port = args.value_of("port").unwrap_or_default(); // TCP connection - let addr = format!("{}:{}", ip, port).parse::().map_err( |e| { - Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, &format!("Cannot parse the IP PORT input [{}]", e))) - })?; + let addr = format!("{}:{}", ip, port) + .parse::() + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + &format!("Cannot parse the IP PORT input [{}]", e), + )) + })?; let tcp = TcpStream::connect(&addr).unwrap(); tcp.set_nodelay(true).map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, &format!("Unable to set no delay option [{}]", e))) + Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + &format!("Unable to set no delay option [{}]", e), + )) })?; Ok(tcp) @@ -248,13 +271,26 @@ fn tcp_from_args(args: &ArgMatches) -> RdpResult { /// Create rdp client from args fn rdp_from_args(args: &ArgMatches, stream: S) -> RdpResult> { - - let width = args.value_of("width").unwrap_or_default().parse().map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, &format!("Cannot parse the input width argument [{}]", e))) - })?; - let height = args.value_of("height").unwrap_or_default().parse().map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, &format!("Cannot parse the input height argument [{}]", e))) - })?; + let width = args + .value_of("width") + .unwrap_or_default() + .parse() + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + &format!("Cannot parse the input width argument [{}]", e), + )) + })?; + let height = args + .value_of("height") + .unwrap_or_default() + .parse() + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + &format!("Cannot parse the input height argument [{}]", e), + )) + })?; let domain = args.value_of("domain").unwrap_or_default(); let username = args.value_of("username").unwrap_or_default(); let password = args.value_of("password").unwrap_or_default(); @@ -267,9 +303,13 @@ fn rdp_from_args(args: &ArgMatches, stream: S) -> RdpResult(args: &ArgMatches, stream: S) -> RdpResult(args: &ArgMatches, stream: S) -> RdpResult RdpResult { - let width = args.value_of("width").unwrap_or_default().parse().map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, &format!("Cannot parse the input width argument [{}]", e))) - })?; - let height = args.value_of("height").unwrap_or_default().parse().map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, &format!("Cannot parse the input height argument [{}]", e))) - })?; + let width = args + .value_of("width") + .unwrap_or_default() + .parse() + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + &format!("Cannot parse the input width argument [{}]", e), + )) + })?; + let height = args + .value_of("height") + .unwrap_or_default() + .parse() + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + &format!("Cannot parse the input height argument [{}]", e), + )) + })?; let window = Window::new( "mstsc-rs Remote Desktop in Rust", width, height, WindowOptions::default(), - ).map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::Unknown, &format!("Unable to create window [{}]", e))) + ) + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::Unknown, + &format!("Unable to create window [{}]", e), + )) })?; Ok(window) @@ -319,24 +380,23 @@ fn launch_rdp_thread( handle: usize, rdp_client: Arc>>, sync: Arc, - bitmap_channel: Sender) -> RdpResult> { + bitmap_channel: Sender, +) -> RdpResult> { // Create the rdp thread Ok(thread::spawn(move || { while wait_for_fd(handle as usize) && sync.load(Ordering::Relaxed) { let mut guard = rdp_client.lock().unwrap(); - if let Err(Error::RdpError(e)) = guard.read(|event| { - match event { - RdpEvent::Bitmap(bitmap) => { - bitmap_channel.send(bitmap).unwrap(); - }, - _ => println!("{}: ignore event", APPLICATION_NAME) + if let Err(Error::RdpError(e)) = guard.read(|event| match event { + RdpEvent::Bitmap(bitmap) => { + bitmap_channel.send(bitmap).unwrap(); } + _ => println!("{}: ignore event", APPLICATION_NAME), }) { match e.kind() { RdpErrorKind::Disconnect => { println!("{}: Server ask for disconnect", APPLICATION_NAME); - }, - _ => println!("{}: {:?}", APPLICATION_NAME, e) + } + _ => println!("{}: {:?}", APPLICATION_NAME, e), } break; } @@ -351,8 +411,8 @@ fn main_gui_loop( mut window: Window, rdp_client: Arc>>, sync: Arc, - bitmap_receiver: Receiver) -> RdpResult<()> { - + bitmap_receiver: Receiver, +) -> RdpResult<()> { let (width, height) = window.get_size(); // Now we continue with the graphical main thread // Limit to max ~60 fps update rate @@ -378,7 +438,7 @@ fn main_gui_loop( Err(mpsc::TryRecvError::Empty) => break, Err(mpsc::TryRecvError::Disconnected) => { sync.store(false, Ordering::Relaxed); - break + break; } }; } @@ -386,19 +446,24 @@ fn main_gui_loop( // Mouse position input if let Some((x, y)) = window.get_mouse_pos(MouseMode::Clamp) { let mut rdp_client_guard = rdp_client.lock().map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::Unknown, &format!("Thread error during access to mutex [{}]", e))) + Error::RdpError(RdpError::new( + RdpErrorKind::Unknown, + &format!("Thread error during access to mutex [{}]", e), + )) })?; // Button is down if not 0 let current_button = get_rdp_pointer_down(&window); - rdp_client_guard.try_write(RdpEvent::Pointer( - PointerEvent{ - x: x as u16, - y: y as u16, - button: if last_button == current_button { PointerButton::None } else { PointerButton::try_from(last_button as u8 | current_button as u8).unwrap() }, - down: (last_button != current_button) && last_button == PointerButton::None - }) - )?; + rdp_client_guard.try_write(RdpEvent::Pointer(PointerEvent { + x: x as u16, + y: y as u16, + button: if last_button == current_button { + PointerButton::None + } else { + PointerButton::try_from(last_button as u8 | current_button as u8).unwrap() + }, + down: (last_button != current_button) && last_button == PointerButton::None, + }))?; last_button = current_button; } @@ -409,23 +474,19 @@ fn main_gui_loop( for key in last_keys.iter() { if !keys.contains(key) { - rdp_client_guard.try_write(RdpEvent::Key( - KeyboardEvent { - code: to_scancode(*key), - down: false - }) - )? + rdp_client_guard.try_write(RdpEvent::Key(KeyboardEvent { + code: to_scancode(*key), + down: false, + }))? } } for key in keys.iter() { - if window.is_key_pressed(*key, KeyRepeat::Yes){ - rdp_client_guard.try_write(RdpEvent::Key( - KeyboardEvent { - code: to_scancode(*key), - down: true - }) - )? + if window.is_key_pressed(*key, KeyRepeat::Yes) { + rdp_client_guard.try_write(RdpEvent::Key(KeyboardEvent { + code: to_scancode(*key), + down: true, + }))? } } @@ -433,9 +494,14 @@ fn main_gui_loop( } // We unwrap here as we want this code to exit if it fails. Real applications may want to handle this in a different way - window.update_with_buffer(&buffer, width, height).map_err(|e| { - Error::RdpError(RdpError::new(RdpErrorKind::Unknown, &format!("Unable to update screen buffer [{}]", e))) - })?; + window + .update_with_buffer(&buffer, width, height) + .map_err(|e| { + Error::RdpError(RdpError::new( + RdpErrorKind::Unknown, + &format!("Unable to update screen buffer [{}]", e), + )) + })?; } sync.store(false, Ordering::Relaxed); @@ -449,68 +515,98 @@ fn main() { .version("0.1.0") .author("Sylvain Peyrefitte ") .about("Secure Remote Desktop Client in RUST") - .arg(Arg::with_name("host") - .long("host") - .takes_value(true) - .help("host IP of the target machine")) - .arg(Arg::with_name("port") - .long("port") - .takes_value(true) - .default_value("3389") - .help("Destination Port")) - .arg(Arg::with_name("width") - .long("width") - .takes_value(true) - .default_value("800") - .help("Screen width")) - .arg(Arg::with_name("height") - .long("height") - .takes_value(true) - .default_value("600") - .help("Screen height")) - .arg(Arg::with_name("domain") - .long("domain") - .takes_value(true) - .default_value("") - .help("Windows domain")) - .arg(Arg::with_name("username") - .long("user") - .takes_value(true) - .default_value("") - .help("Username")) - .arg(Arg::with_name("password") - .long("password") - .takes_value(true) - .default_value("") - .help("Password")) - .arg(Arg::with_name("hash") - .long("hash") - .takes_value(true) - .help("NTLM Hash")) - .arg(Arg::with_name("admin") - .long("admin") - .help("Restricted admin mode")) - .arg(Arg::with_name("layout") - .long("layout") - .takes_value(true) - .default_value("us") - .help("Keyboard layout: us or fr")) - .arg(Arg::with_name("auto_logon") - .long("auto") - .help("AutoLogon mode in case of SSL nego")) - .arg(Arg::with_name("blank_creds") - .long("blank") - .help("Do not send credentials at the last CredSSP payload")) - .arg(Arg::with_name("check_certificate") - .long("check") - .help("Check the target SSL certificate")) - .arg(Arg::with_name("disable_nla") - .long("ssl") - .help("Disable Network Level Authentication and only use SSL")) - .arg(Arg::with_name("name") - .long("name") - .default_value("mstsc-rs") - .help("Name of the client send to the server")) + .arg( + Arg::with_name("host") + .long("host") + .takes_value(true) + .help("host IP of the target machine"), + ) + .arg( + Arg::with_name("port") + .long("port") + .takes_value(true) + .default_value("3389") + .help("Destination Port"), + ) + .arg( + Arg::with_name("width") + .long("width") + .takes_value(true) + .default_value("800") + .help("Screen width"), + ) + .arg( + Arg::with_name("height") + .long("height") + .takes_value(true) + .default_value("600") + .help("Screen height"), + ) + .arg( + Arg::with_name("domain") + .long("domain") + .takes_value(true) + .default_value("") + .help("Windows domain"), + ) + .arg( + Arg::with_name("username") + .long("user") + .takes_value(true) + .default_value("") + .help("Username"), + ) + .arg( + Arg::with_name("password") + .long("password") + .takes_value(true) + .default_value("") + .help("Password"), + ) + .arg( + Arg::with_name("hash") + .long("hash") + .takes_value(true) + .help("NTLM Hash"), + ) + .arg( + Arg::with_name("admin") + .long("admin") + .help("Restricted admin mode"), + ) + .arg( + Arg::with_name("layout") + .long("layout") + .takes_value(true) + .default_value("us") + .help("Keyboard layout: us or fr"), + ) + .arg( + Arg::with_name("auto_logon") + .long("auto") + .help("AutoLogon mode in case of SSL nego"), + ) + .arg( + Arg::with_name("blank_creds") + .long("blank") + .help("Do not send credentials at the last CredSSP payload"), + ) + .arg( + Arg::with_name("check_certificate") + .long("check") + .help("Check the target SSL certificate"), + ) + .arg( + Arg::with_name("disable_nla") + .long("ssl") + .help("Disable Network Level Authentication and only use SSL"), + ) + .arg( + Arg::with_name("name") + .long("name") + .default_value("mstsc-rs") + .help("Name of the client send to the server"), + ) .get_matches(); // Create a tcp stream from args @@ -543,16 +639,12 @@ fn main() { handle as usize, Arc::clone(&rdp_client_mutex), Arc::clone(&sync), - bitmap_sender - ).unwrap(); + bitmap_sender, + ) + .unwrap(); // Launch the GUI - main_gui_loop( - window, - rdp_client_mutex, - sync, - bitmap_receiver - ).unwrap(); + main_gui_loop(window, rdp_client_mutex, sync, bitmap_receiver).unwrap(); rdp_thread.join().unwrap(); } diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 97d40b6..d09d240 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -1 +1 @@ -pub mod rle; \ No newline at end of file +pub mod rle; diff --git a/src/codec/rle.rs b/src/codec/rle.rs index 8e7f684..b44e908 100644 --- a/src/codec/rle.rs +++ b/src/codec/rle.rs @@ -1,346 +1,405 @@ -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use byteorder::{LittleEndian, ReadBytesExt}; use std::io::{Cursor, Read}; -use byteorder::{ReadBytesExt, LittleEndian}; /// All this uncompress code /// Are directly inspired from the source code /// of rdesktop and diretly port to rust /// Need a little bit of refactoring for rust -fn process_plane(input: &mut dyn Read, width: u32, height: u32, output: &mut [u8]) -> RdpResult<()> { +fn process_plane( + input: &mut dyn Read, + width: u32, + height: u32, + output: &mut [u8], +) -> RdpResult<()> { let mut indexw; - let mut indexh= 0; - let mut code ; - let mut collen; - let mut replen; - let mut color:i8; - let mut x; - let mut revcode; + let mut indexh = 0; + let mut code; + let mut collen; + let mut replen; + let mut color: i8; + let mut x; + let mut revcode; let mut this_line: u32; let mut last_line: u32 = 0; - while indexh < height { - let mut out = (width * height * 4) - ((indexh + 1) * width * 4); - color = 0; - this_line = out; - indexw = 0; - if last_line == 0 { - while indexw < width { - code = input.read_u8()?; - replen = code & 0xf; - collen = (code >> 4) & 0xf; - revcode = (replen << 4) | collen; - if (revcode <= 47) && (revcode >= 16) { - replen = revcode; - collen = 0; - } - while collen > 0 { - color = input.read_u8()? as i8; - output[out as usize] = color as u8; - out += 4; - indexw += 1; - collen -= 1; - } - while replen > 0 { - output[out as usize] = color as u8; - out += 4; - indexw += 1; - replen -= 1; - } - } - } - else - { - while indexw < width { - code = input.read_u8()?; - replen = code & 0xf; - collen = (code >> 4) & 0xf; - revcode = (replen << 4) | collen; - if (revcode <= 47) && (revcode >= 16) { - replen = revcode; - collen = 0; - } - while collen > 0 { - x = input.read_u8()?; - if x & 1 != 0{ - x = x >> 1; - x = x + 1; - color = -(x as i32) as i8; - } - else - { - x = x >> 1; - color = x as i8; - } - x = (output[(last_line + (indexw * 4)) as usize] as i32 + color as i32) as u8; - output[out as usize] = x; - out += 4; - indexw += 1; - collen -= 1; - } - while replen > 0 { - x = (output[(last_line + (indexw * 4)) as usize] as i32 + color as i32) as u8; - output[out as usize] = x; - out += 4; - indexw += 1; - replen -= 1; - } - } - } - indexh += 1; - last_line = this_line; - } + while indexh < height { + let mut out = (width * height * 4) - ((indexh + 1) * width * 4); + color = 0; + this_line = out; + indexw = 0; + if last_line == 0 { + while indexw < width { + code = input.read_u8()?; + replen = code & 0xf; + collen = (code >> 4) & 0xf; + revcode = (replen << 4) | collen; + if (revcode <= 47) && (revcode >= 16) { + replen = revcode; + collen = 0; + } + while collen > 0 { + color = input.read_u8()? as i8; + output[out as usize] = color as u8; + out += 4; + indexw += 1; + collen -= 1; + } + while replen > 0 { + output[out as usize] = color as u8; + out += 4; + indexw += 1; + replen -= 1; + } + } + } else { + while indexw < width { + code = input.read_u8()?; + replen = code & 0xf; + collen = (code >> 4) & 0xf; + revcode = (replen << 4) | collen; + if (revcode <= 47) && (revcode >= 16) { + replen = revcode; + collen = 0; + } + while collen > 0 { + x = input.read_u8()?; + if x & 1 != 0 { + x = x >> 1; + x = x + 1; + color = -(x as i32) as i8; + } else { + x = x >> 1; + color = x as i8; + } + x = (output[(last_line + (indexw * 4)) as usize] as i32 + color as i32) as u8; + output[out as usize] = x; + out += 4; + indexw += 1; + collen -= 1; + } + while replen > 0 { + x = (output[(last_line + (indexw * 4)) as usize] as i32 + color as i32) as u8; + output[out as usize] = x; + out += 4; + indexw += 1; + replen -= 1; + } + } + } + indexh += 1; + last_line = this_line; + } Ok(()) } /// Run length encoding decoding function for 32 bpp -pub fn rle_32_decompress(input: &[u8], width: u32, height: u32, output: &mut [u8]) -> RdpResult<()> { +pub fn rle_32_decompress( + input: &[u8], + width: u32, + height: u32, + output: &mut [u8], +) -> RdpResult<()> { let mut input_cursor = Cursor::new(input); - if input_cursor.read_u8()? != 0x10 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, "Bad header"))) - } + if input_cursor.read_u8()? != 0x10 { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + "Bad header", + ))); + } - process_plane(&mut input_cursor, width, height, &mut output[3..])?; - process_plane(&mut input_cursor, width, height, &mut output[2..])?; - process_plane(&mut input_cursor, width, height, &mut output[1..])?; - process_plane(&mut input_cursor, width, height, &mut output[0..])?; + process_plane(&mut input_cursor, width, height, &mut output[3..])?; + process_plane(&mut input_cursor, width, height, &mut output[2..])?; + process_plane(&mut input_cursor, width, height, &mut output[1..])?; + process_plane(&mut input_cursor, width, height, &mut output[0..])?; - Ok(()) + Ok(()) } macro_rules! repeat { ($expr:expr, $count:expr, $x:expr, $width:expr) => { - while (($count & !0x7) != 0) && ($x + 8) < $width { - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - $expr; $count -= 1; $x += 1; - } - while $count > 0 && $x < $width { - $expr; - $count -= 1; - $x += 1; - } + while (($count & !0x7) != 0) && ($x + 8) < $width { + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + $expr; + $count -= 1; + $x += 1; + } + while $count > 0 && $x < $width { + $expr; + $count -= 1; + $x += 1; + } }; } -pub fn rle_16_decompress(input: &[u8], width: usize, mut height: usize, output: &mut [u16]) -> RdpResult<()> { - let mut input_cursor = Cursor::new(input); +pub fn rle_16_decompress( + input: &[u8], + width: usize, + mut height: usize, + output: &mut [u16], +) -> RdpResult<()> { + let mut input_cursor = Cursor::new(input); - let mut code: u8; - let mut opcode: u8; - let mut lastopcode: u8 = 0xFF; - let mut count: u16; - let mut offset: u16; - let mut isfillormix; - let mut insertmix = false; - let mut x: usize = width; - let mut prevline : Option = None; - let mut line : Option = None; - let mut colour1= 0; - let mut colour2 = 0; - let mut mix = 0xffff; - let mut mask:u8 = 0; - let mut fom_mask : u8; - let mut mixmask:u8; - let mut bicolour = false; + let mut code: u8; + let mut opcode: u8; + let mut lastopcode: u8 = 0xFF; + let mut count: u16; + let mut offset: u16; + let mut isfillormix; + let mut insertmix = false; + let mut x: usize = width; + let mut prevline: Option = None; + let mut line: Option = None; + let mut colour1 = 0; + let mut colour2 = 0; + let mut mix = 0xffff; + let mut mask: u8 = 0; + let mut fom_mask: u8; + let mut mixmask: u8; + let mut bicolour = false; - while (input_cursor.position() as usize) < input.len() { - fom_mask = 0; - code = input_cursor.read_u8()?; - opcode = code >> 4; + while (input_cursor.position() as usize) < input.len() { + fom_mask = 0; + code = input_cursor.read_u8()?; + opcode = code >> 4; - match opcode { - 0xC | 0xD | 0xE => { - opcode -= 6; - count = (code & 0xf) as u16; - offset = 16; - } - 0xF => { - opcode = code & 0xf; - if opcode < 9 { - count = input_cursor.read_u16::()? - } else if opcode < 0xb { - count = 8 - } else { - count = 1 - } - offset = 0; - } - _ => { - opcode >>= 1; - count = (code & 0x1f) as u16; - offset = 32; - } - } + match opcode { + 0xC | 0xD | 0xE => { + opcode -= 6; + count = (code & 0xf) as u16; + offset = 16; + } + 0xF => { + opcode = code & 0xf; + if opcode < 9 { + count = input_cursor.read_u16::()? + } else if opcode < 0xb { + count = 8 + } else { + count = 1 + } + offset = 0; + } + _ => { + opcode >>= 1; + count = (code & 0x1f) as u16; + offset = 32; + } + } - if offset != 0 { - isfillormix = (opcode == 2) || (opcode == 7); - if count == 0 { - if isfillormix { - count = input_cursor.read_u8()? as u16 + 1; - } else { - count = input_cursor.read_u8()? as u16 + offset; - } - } else if isfillormix { - count <<= 3; - } - } + if offset != 0 { + isfillormix = (opcode == 2) || (opcode == 7); + if count == 0 { + if isfillormix { + count = input_cursor.read_u8()? as u16 + 1; + } else { + count = input_cursor.read_u8()? as u16 + offset; + } + } else if isfillormix { + count <<= 3; + } + } - match opcode { - 0 => { - if lastopcode == opcode && !(x == width && prevline == None) { - insertmix = true; - } - }, - 8 => { - colour1 = input_cursor.read_u16::()?; - colour2 = input_cursor.read_u16::()?; - }, - 3 => { - colour2 = input_cursor.read_u16::()?; - }, - 6 | 7 => { - mix = input_cursor.read_u16::()?; - opcode -= 5; - } - 9 => { - mask = 0x03; - opcode = 0x02; - fom_mask = 3; - }, - 0xa => { - mask = 0x05; - opcode = 0x02; - fom_mask = 5; - } - _ => () - } - lastopcode = opcode; - mixmask = 0; + match opcode { + 0 => { + if lastopcode == opcode && !(x == width && prevline == None) { + insertmix = true; + } + } + 8 => { + colour1 = input_cursor.read_u16::()?; + colour2 = input_cursor.read_u16::()?; + } + 3 => { + colour2 = input_cursor.read_u16::()?; + } + 6 | 7 => { + mix = input_cursor.read_u16::()?; + opcode -= 5; + } + 9 => { + mask = 0x03; + opcode = 0x02; + fom_mask = 3; + } + 0xa => { + mask = 0x05; + opcode = 0x02; + fom_mask = 5; + } + _ => (), + } + lastopcode = opcode; + mixmask = 0; - while count > 0 { - if x >= width { - if height <= 0 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "error during decompress"))) - } - x = 0; - height -= 1; - prevline = line; - line = Some(height * width); - } + while count > 0 { + if x >= width { + if height <= 0 { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "error during decompress", + ))); + } + x = 0; + height -= 1; + prevline = line; + line = Some(height * width); + } - match opcode { - 0 => { - if insertmix { - if let Some(e) = prevline { - output[line.unwrap() + x] = output[e + x] ^ mix; - } - else { - output[line.unwrap() + x] = mix; - } - insertmix = false; - count -= 1; - x += 1; - } + match opcode { + 0 => { + if insertmix { + if let Some(e) = prevline { + output[line.unwrap() + x] = output[e + x] ^ mix; + } else { + output[line.unwrap() + x] = mix; + } + insertmix = false; + count -= 1; + x += 1; + } - if let Some(e) = prevline { - repeat!(output[line.unwrap() + x] = output[e + x], count, x, width); - } - else { - repeat!(output[line.unwrap() + x] = 0, count, x, width); - } - }, - 1 => { - if let Some(e) = prevline { - repeat!(output[line.unwrap() + x] = output[e + x] ^ mix, count, x, width); - } - else { - repeat!(output[line.unwrap() + x] = mix, count, x, width); - } - }, - 2 => { - if let Some(e) = prevline { - repeat!({ - mixmask <<= 1; - if mixmask == 0 { - mask = if fom_mask != 0 { fom_mask } else { input_cursor.read_u8()? }; - mixmask = 1; - } - if (mask & mixmask) != 0 { - output[line.unwrap() + x] = output[e + x] ^ mix; - } - else { - output[line.unwrap() + x] = output[e + x]; - } - }, count, x, width); - } - else { - repeat!({ - mixmask <<= 1; - if mixmask == 0 { - mask = if fom_mask != 0 { fom_mask } else { input_cursor.read_u8()? }; - mixmask = 1; - } - if (mask & mixmask) != 0 { - output[line.unwrap() + x] = mix; - } - else { - output[line.unwrap() + x] = 0; - } - }, count, x, width); - } - }, - 3 => { - repeat!(output[line.unwrap() + x] = colour2, count, x, width); - }, - 4 => { - repeat!(output[line.unwrap() + x] = input_cursor.read_u16::()?, count, x, width); - }, - 8 => { - repeat!({ - if bicolour { - output[line.unwrap() + x] = colour2; - bicolour = false; - } else { - output[line.unwrap() + x] = colour1; - bicolour = true; - count += 1; - }; - }, count, x, width); - }, - 0xd => { - repeat!(output[line.unwrap() + x] = 0xffff, count, x, width); - }, - 0xe => { - repeat!(output[line.unwrap() + x] = 0, count, x, width); - } - _ => panic!("opcode") - } - } - } + if let Some(e) = prevline { + repeat!(output[line.unwrap() + x] = output[e + x], count, x, width); + } else { + repeat!(output[line.unwrap() + x] = 0, count, x, width); + } + } + 1 => { + if let Some(e) = prevline { + repeat!( + output[line.unwrap() + x] = output[e + x] ^ mix, + count, + x, + width + ); + } else { + repeat!(output[line.unwrap() + x] = mix, count, x, width); + } + } + 2 => { + if let Some(e) = prevline { + repeat!( + { + mixmask <<= 1; + if mixmask == 0 { + mask = if fom_mask != 0 { + fom_mask + } else { + input_cursor.read_u8()? + }; + mixmask = 1; + } + if (mask & mixmask) != 0 { + output[line.unwrap() + x] = output[e + x] ^ mix; + } else { + output[line.unwrap() + x] = output[e + x]; + } + }, + count, + x, + width + ); + } else { + repeat!( + { + mixmask <<= 1; + if mixmask == 0 { + mask = if fom_mask != 0 { + fom_mask + } else { + input_cursor.read_u8()? + }; + mixmask = 1; + } + if (mask & mixmask) != 0 { + output[line.unwrap() + x] = mix; + } else { + output[line.unwrap() + x] = 0; + } + }, + count, + x, + width + ); + } + } + 3 => { + repeat!(output[line.unwrap() + x] = colour2, count, x, width); + } + 4 => { + repeat!( + output[line.unwrap() + x] = input_cursor.read_u16::()?, + count, + x, + width + ); + } + 8 => { + repeat!( + { + if bicolour { + output[line.unwrap() + x] = colour2; + bicolour = false; + } else { + output[line.unwrap() + x] = colour1; + bicolour = true; + count += 1; + }; + }, + count, + x, + width + ); + } + 0xd => { + repeat!(output[line.unwrap() + x] = 0xffff, count, x, width); + } + 0xe => { + repeat!(output[line.unwrap() + x] = 0, count, x, width); + } + _ => panic!("opcode"), + } + } + } - Ok(()) + Ok(()) } - pub fn rgb565torgb32(input: &[u16], width: usize, height: usize) -> Vec { - let mut result_32_bpp = vec![0 as u8; width as usize * height as usize * 4]; - for i in 0..height { - for j in 0..width { - let index = (i * width + j) as usize; - let v = input[index]; - result_32_bpp[index * 4 + 3] = 0xff; - result_32_bpp[index * 4 + 2] = (((((v >> 11) & 0x1f) * 527) + 23) >> 6) as u8; - result_32_bpp[index * 4 + 1] = (((((v >> 5) & 0x3f) * 259) + 33) >> 6) as u8; - result_32_bpp[index * 4] = ((((v & 0x1f) * 527) + 23) >> 6) as u8; - } - } - result_32_bpp -} \ No newline at end of file + let mut result_32_bpp = vec![0 as u8; width as usize * height as usize * 4]; + for i in 0..height { + for j in 0..width { + let index = (i * width + j) as usize; + let v = input[index]; + result_32_bpp[index * 4 + 3] = 0xff; + result_32_bpp[index * 4 + 2] = (((((v >> 11) & 0x1f) * 527) + 23) >> 6) as u8; + result_32_bpp[index * 4 + 1] = (((((v >> 5) & 0x3f) * 259) + 33) >> 6) as u8; + result_32_bpp[index * 4] = ((((v & 0x1f) * 527) + 23) >> 6) as u8; + } + } + result_32_bpp +} diff --git a/src/core/capability.rs b/src/core/capability.rs index a3cf40e..0a166ab 100644 --- a/src/core/capability.rs +++ b/src/core/capability.rs @@ -1,9 +1,11 @@ -use model::data::{Component, U16, U32, DynOption, MessageOption, Message, DataType, Check, Trame, to_vec}; -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; -use std::io::Cursor; -use core::gcc::{KeyboardLayout, KeyboardType}; +use crate::core::gcc::{KeyboardLayout, KeyboardType}; +use crate::model::data::{ + to_vec, Check, Component, DataType, DynOption, Message, MessageOption, Trame, U16, U32, +}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use num_enum::TryFromPrimitive; use std::convert::TryFrom; +use std::io::Cursor; /// All capabilities that can be negotiated /// between client and server @@ -38,7 +40,7 @@ pub enum CapabilitySetType { CapsettypeLargePointer = 0x001B, CapsettypeSurfaceCommands = 0x001C, CapsettypeBitmapCodecs = 0x001D, - CapssettypeFrameAcknowledge = 0x001E + CapssettypeFrameAcknowledge = 0x001E, } /// A capability @@ -58,7 +60,7 @@ pub enum CapabilitySetType { /// ``` pub struct Capability { pub cap_type: CapabilitySetType, - pub message: Component + pub message: Component, } impl Capability { @@ -78,7 +80,10 @@ impl Capability { /// # } /// ``` pub fn from_capability_set(capability_set: &Component) -> RdpResult { - let cap_type = CapabilitySetType::try_from(cast!(DataType::U16, capability_set["capabilitySetType"])?)?; + let cap_type = CapabilitySetType::try_from(cast!( + DataType::U16, + capability_set["capabilitySetType"] + )?)?; let mut capability = match cap_type { CapabilitySetType::CapstypeGeneral => ts_general_capability_set(None), CapabilitySetType::CapstypeBitmap => ts_bitmap_capability_set(None, None, None), @@ -91,12 +96,20 @@ impl Capability { CapabilitySetType::CapstypeOffscreencache => ts_offscreen_capability_set(), CapabilitySetType::CapstypeVirtualchannel => ts_virtualchannel_capability_set(), CapabilitySetType::CapstypeSound => ts_sound_capability_set(), - CapabilitySetType::CapsettypeMultifragmentupdate => ts_multifragment_update_capability_ts(), + CapabilitySetType::CapsettypeMultifragmentupdate => { + ts_multifragment_update_capability_ts() + } _ => { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::Unknown, &format!("CAPABILITY: Unknown capability {:?}", cap_type)))) + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::Unknown, + &format!("CAPABILITY: Unknown capability {:?}", cap_type), + ))) } }; - capability.message.read(&mut Cursor::new(cast!(DataType::Slice, capability_set["capabilitySet"])?))?; + capability.message.read(&mut Cursor::new(cast!( + DataType::Slice, + capability_set["capabilitySet"] + )?))?; Ok(capability) } } @@ -118,7 +131,10 @@ impl Capability { /// } /// ``` pub fn capability_set(capability: Option) -> Component { - let default_capability = capability.unwrap_or(Capability{ cap_type: CapabilitySetType::CapstypeGeneral, message: component![]}); + let default_capability = capability.unwrap_or(Capability { + cap_type: CapabilitySetType::CapstypeGeneral, + message: component![], + }); component![ "capabilitySetType" => U16::LE(default_capability.cap_type as u16), "lengthCapability" => DynOption::new(U16::LE(default_capability.message.length() as u16 + 4), |length| MessageOption::Size("capabilitySet".to_string(), length.inner() as usize - 4)), @@ -136,7 +152,7 @@ enum MajorType { OsmajortypeUnix = 0x0004, OsmajortypeIos = 0x0005, OsmajortypeOsx = 0x0006, - OsmajortypeAndroid = 0x0007 + OsmajortypeAndroid = 0x0007, } #[allow(dead_code)] @@ -150,7 +166,7 @@ enum MinorType { OsminortypeMacintosh = 0x0006, OsminortypeNativeXserver = 0x0007, OsminortypePseudoXserver = 0x0008, - OsminortypeWindowsRt = 0x0009 + OsminortypeWindowsRt = 0x0009, } #[repr(u16)] @@ -159,7 +175,7 @@ pub enum GeneralExtraFlag { NoBitmapCompressionHdr = 0x0400, LongCredentialsSupported = 0x0004, AutoreconnectSupported = 0x0008, - EncSaltedChecksum = 0x0010 + EncSaltedChecksum = 0x0010, } /// General capability @@ -189,7 +205,7 @@ pub fn ts_general_capability_set(extra_flags: Option) -> Capability { "generalCompressionLevel" => Check::new(U16::LE(0)), "refreshRectSupport" => 0 as u8, "suppressOutputSupport" => 0 as u8 - ] + ], } } @@ -206,7 +222,11 @@ pub fn ts_general_capability_set(extra_flags: Option) -> Capability { /// let capability_set = capability_set(Some(ts_bitmap_capability_set(Some(24), Some(800), Some(600)))); /// assert_eq!(to_vec(&capability_set), [2, 0, 28, 0, 24, 0, 1, 0, 1, 0, 1, 0, 32, 3, 88, 2, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]) /// ``` -pub fn ts_bitmap_capability_set(preferred_bits_per_pixel: Option, desktop_width: Option, desktop_height: Option) -> Capability { +pub fn ts_bitmap_capability_set( + preferred_bits_per_pixel: Option, + desktop_width: Option, + desktop_height: Option, +) -> Capability { Capability { cap_type: CapabilitySetType::CapstypeBitmap, message: component![ @@ -223,7 +243,7 @@ pub fn ts_bitmap_capability_set(preferred_bits_per_pixel: Option, desktop_w "drawingFlags" => 0 as u8, "multipleRectangleSupport" => Check::new(U16::LE(0x0001)), "pad2octetsB" => U16::LE(0) - ] + ], } } @@ -234,7 +254,7 @@ pub enum OrderFlag { ZEROBOUNDSDELTASSUPPORT = 0x0008, COLORINDEXSUPPORT = 0x0020, SOLIDPATTERNBRUSHONLY = 0x0040, - OrderflagsExtraFlags = 0x0080 + OrderflagsExtraFlags = 0x0080, } /// Order capability @@ -270,7 +290,7 @@ pub fn ts_order_capability_set(order_flags: Option) -> Capability { "pad2octetsD" => U16::LE(0), "textANSICodePage" => U16::LE(0), "pad2octetsE" => U16::LE(0) - ] + ], } } @@ -301,7 +321,7 @@ pub fn ts_bitmap_cache_capability_set() -> Capability { "cache1MaximumCellSize" => U16::LE(0), "cache2Entries" => U16::LE(0), "cache2MaximumCellSize" => U16::LE(0) - ] + ], } } @@ -323,7 +343,7 @@ pub fn ts_pointer_capability_set() -> Capability { message: component![ "colorPointerFlag" => U16::LE(0), "colorPointerCacheSize" => U16::LE(20) - ] + ], } } @@ -350,7 +370,7 @@ pub enum InputFlags { InputFlagUnused2 = 0x0080, /// Support of the mouse wheel /// This feature is supported by rdp-rs - TsInputFlagMouseHwheel = 0x0100 + TsInputFlagMouseHwheel = 0x0100, } /// Send input capability @@ -365,7 +385,10 @@ pub enum InputFlags { /// let capability_set = capability_set(Some(ts_input_capability_set(Some(InputFlags::InputFlagScancodes as u16), Some(KeyboardLayout::French)))); /// assert_eq!(to_vec(&capability_set), vec![13, 0, 88, 0, 1, 0, 0, 0, 12, 4, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) /// ``` -pub fn ts_input_capability_set(input_flags: Option, keyboard_layout: Option) -> Capability { +pub fn ts_input_capability_set( + input_flags: Option, + keyboard_layout: Option, +) -> Capability { Capability { cap_type: CapabilitySetType::CapstypeInput, message: component![ @@ -376,7 +399,7 @@ pub fn ts_input_capability_set(input_flags: Option, keyboard_layout: Option "keyboardSubType" => U32::LE(0), "keyboardFunctionKey" => U32::LE(12), "imeFileName" => vec![0 as u8; 64] - ] + ], } } @@ -397,7 +420,7 @@ pub fn ts_brush_capability_set() -> Capability { cap_type: CapabilitySetType::CapstypeBrush, message: component![ "brushSupportLevel" => U32::LE(0) - ] + ], } } @@ -411,7 +434,6 @@ fn cache_entry() -> Component { ] } - /// Glyph capability set /// send from client to server /// @@ -436,7 +458,7 @@ pub fn ts_glyph_capability_set() -> Capability { "fragCache" => U32::LE(0), "glyphSupportLevel" => U16::LE(0), "pad2octets" => U16::LE(0) - ] + ], } } @@ -460,7 +482,7 @@ pub fn ts_offscreen_capability_set() -> Capability { "offscreenSupportLevel" => U32::LE(0), "offscreenCacheSize" => U16::LE(0), "offscreenCacheEntries" => U16::LE(0) - ] + ], } } @@ -483,7 +505,7 @@ pub fn ts_virtualchannel_capability_set() -> Capability { message: component![ "flags" => U32::LE(0), "VCChunkSize" => Some(U32::LE(0)) - ] + ], } } @@ -506,7 +528,7 @@ pub fn ts_sound_capability_set() -> Capability { message: component![ "soundFlags" => U16::LE(0), "pad2octetsA" => U16::LE(0) - ] + ], } } @@ -528,6 +550,6 @@ pub fn ts_multifragment_update_capability_ts() -> Capability { cap_type: CapabilitySetType::CapsettypeMultifragmentupdate, message: component![ "MaxRequestSize" => U32::LE(0) - ] + ], } -} \ No newline at end of file +} diff --git a/src/core/client.rs b/src/core/client.rs index 91ab8ed..72306dc 100644 --- a/src/core/client.rs +++ b/src/core/client.rs @@ -1,15 +1,15 @@ -use core::x224; -use core::gcc::KeyboardLayout; -use core::mcs; -use core::tpkt; -use core::sec; -use core::global; +use crate::core::event::{PointerButton, RdpEvent}; +use crate::core::gcc::KeyboardLayout; +use crate::core::global; +use crate::core::global::{ts_keyboard_event, ts_pointer_event, KeyboardFlag, PointerFlag}; +use crate::core::mcs; +use crate::core::sec; +use crate::core::tpkt; +use crate::core::x224; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::model::link::{Link, Stream}; +use crate::nla::ntlm::Ntlm; use std::io::{Read, Write}; -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; -use model::link::{Link, Stream}; -use core::event::{RdpEvent, PointerButton}; -use core::global::{ts_pointer_event, PointerFlag, ts_keyboard_event, KeyboardFlag}; -use nla::ntlm::Ntlm; impl From<&str> for KeyboardLayout { fn from(e: &str) -> Self { @@ -26,7 +26,7 @@ pub struct RdpClient { /// This is the main switch layer of the protocol mcs: mcs::Client, /// Global channel that implement the basic layer - global: global::Client + global: global::Client, } impl RdpClient { @@ -55,11 +55,16 @@ impl RdpClient { /// }).unwrap() /// ``` pub fn read(&mut self, callback: T) -> RdpResult<()> - where T: FnMut(RdpEvent) { + where + T: FnMut(RdpEvent), + { let (channel_name, message) = self.mcs.read()?; match channel_name.as_str() { "global" => self.global.read(message, &mut self.mcs, callback), - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, &format!("Invalid channel name {:?}", channel_name)))) + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + &format!("Invalid channel name {:?}", channel_name), + ))), } } @@ -106,17 +111,26 @@ impl RdpClient { flags |= PointerFlag::PtrflagsDown as u16; } - self.global.write_input_event(ts_pointer_event(Some(flags), Some(pointer.x), Some(pointer.y)), &mut self.mcs) - }, + self.global.write_input_event( + ts_pointer_event(Some(flags), Some(pointer.x), Some(pointer.y)), + &mut self.mcs, + ) + } // Raw keyboard input RdpEvent::Key(key) => { let mut flags: u16 = 0; if !key.down { flags |= KeyboardFlag::KbdflagsRelease as u16; } - self.global.write_input_event(ts_keyboard_event(Some(flags), Some(key.code)), &mut self.mcs) + self.global.write_input_event( + ts_keyboard_event(Some(flags), Some(key.code)), + &mut self.mcs, + ) } - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, "RDPCLIENT: This event can't be sent"))) + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + "RDPCLIENT: This event can't be sent", + ))), } } @@ -129,9 +143,9 @@ impl RdpClient { match result { Err(Error::RdpError(e)) => match e.kind() { RdpErrorKind::InvalidAutomata => Ok(()), - _ => Err(Error::RdpError(e)) + _ => Err(Error::RdpError(e)), }, - _ => result + _ => result, } } @@ -171,7 +185,7 @@ pub struct Connector { name: String, /// Use network level authentication /// default TRUE - use_nla: bool + use_nla: bool, } impl Connector { @@ -199,7 +213,7 @@ impl Connector { blank_creds: false, check_certificate: false, name: "rdp-rs".to_string(), - use_nla: true + use_nla: true, } } @@ -219,16 +233,18 @@ impl Connector { /// let mut client = connector.connect(tcp).unwrap(); /// ``` pub fn connect(&mut self, stream: S) -> RdpResult> { - // Create a wrapper around the stream - let tcp = Link::new( Stream::Raw(stream)); + let tcp = Link::new(Stream::Raw(stream)); // Compute authentication method let mut authentication = if let Some(hash) = &self.password_hash { Ntlm::from_hash(self.domain.clone(), self.username.clone(), hash) - } - else { - Ntlm::new(self.domain.clone(), self.username.clone(), self.password.clone()) + } else { + Ntlm::new( + self.domain.clone(), + self.username.clone(), + self.password.clone(), + ) }; // Create the x224 layer // With all negotiated security stuff and credentials @@ -243,7 +259,7 @@ impl Connector { self.check_certificate, Some(&mut authentication), self.restricted_admin_mode, - self.blank_creds + self.blank_creds, )?; // Create MCS layer and connect it @@ -256,7 +272,7 @@ impl Connector { &"".to_string(), &"".to_string(), &"".to_string(), - self.auto_logon + self.auto_logon, )?; } else { sec::connect( @@ -264,7 +280,7 @@ impl Connector { &self.domain, &self.username, &self.password, - self.auto_logon + self.auto_logon, )?; } @@ -275,13 +291,10 @@ impl Connector { self.width, self.height, self.layout, - &self.name + &self.name, ); - Ok(RdpClient { - mcs, - global - }) + Ok(RdpClient { mcs, global }) } /// Configure the screen size of the session @@ -348,4 +361,4 @@ impl Connector { self.use_nla = use_nla; self } -} \ No newline at end of file +} diff --git a/src/core/event.rs b/src/core/event.rs index b379171..1e9a0da 100644 --- a/src/core/event.rs +++ b/src/core/event.rs @@ -1,6 +1,6 @@ -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; +use crate::codec::rle::{rgb565torgb32, rle_16_decompress, rle_32_decompress}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use num_enum::TryFromPrimitive; -use codec::rle::{rle_32_decompress, rle_16_decompress, rgb565torgb32}; /// A bitmap event is used /// to notify client that it received @@ -28,7 +28,7 @@ pub struct BitmapEvent { /// true if bitmap buffer is compressed using RLE pub is_compress: bool, /// Bitmap data - pub data: Vec + pub data: Vec, } impl BitmapEvent { @@ -60,41 +60,59 @@ impl BitmapEvent { /// }).unwrap() /// ``` pub fn decompress(self) -> RdpResult> { - // actually only handle 32 bpp match self.bpp { 32 => { // 32 bpp is straight forward - Ok( - if self.is_compress { - let mut result = vec![0 as u8; self.width as usize * self.height as usize * 4]; - rle_32_decompress(&self.data, self.width as u32, self.height as u32, &mut result)?; - result - } else { - self.data - } - ) - }, + Ok(if self.is_compress { + let mut result = vec![0 as u8; self.width as usize * self.height as usize * 4]; + rle_32_decompress( + &self.data, + self.width as u32, + self.height as u32, + &mut result, + )?; + result + } else { + self.data + }) + } 16 => { // 16 bpp is more consumer let result_16bpp = if self.is_compress { let mut result = vec![0 as u16; self.width as usize * self.height as usize * 2]; - rle_16_decompress(&self.data, self.width as usize, self.height as usize, &mut result)?; + rle_16_decompress( + &self.data, + self.width as usize, + self.height as usize, + &mut result, + )?; result } else { let mut result = vec![0 as u16; self.width as usize * self.height as usize]; for i in 0..self.height { for j in 0..self.width { let src = (((self.height - i - 1) * self.width + j) * 2) as usize; - result[(i * self.width + j) as usize] = (self.data[src + 1] as u16) << 8 | self.data[src] as u16; + result[(i * self.width + j) as usize] = + (self.data[src + 1] as u16) << 8 | self.data[src] as u16; } } result }; - Ok(rgb565torgb32(&result_16bpp, self.width as usize, self.height as usize)) - }, - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, &format!("Decompression Algorithm not implemented for bpp {}", self.bpp)))) + Ok(rgb565torgb32( + &result_16bpp, + self.width as usize, + self.height as usize, + )) + } + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + &format!( + "Decompression Algorithm not implemented for bpp {}", + self.bpp + ), + ))), } } } @@ -109,7 +127,7 @@ pub enum PointerButton { /// Right mouse button Right = 2, /// Wheel mouse button - Middle = 3 + Middle = 3, } /// A mouse pointer event @@ -121,7 +139,7 @@ pub struct PointerEvent { /// Which button is pressed pub button: PointerButton, /// true if it's a down press action - pub down: bool + pub down: bool, } /// Keyboard event @@ -131,7 +149,7 @@ pub struct KeyboardEvent { /// Scancode of the key pub code: u16, /// State of the key - pub down: bool + pub down: bool, } /// All event handle by RDP protocol implemented by rdp-rs @@ -141,5 +159,5 @@ pub enum RdpEvent { /// Mouse event Pointer(PointerEvent), /// Keyboard event - Key(KeyboardEvent) -} \ No newline at end of file + Key(KeyboardEvent), +} diff --git a/src/core/gcc.rs b/src/core/gcc.rs index 6ca719f..309b363 100644 --- a/src/core/gcc.rs +++ b/src/core/gcc.rs @@ -1,12 +1,13 @@ -use model::data::{Component, U32, U16, Trame, to_vec, Message, DataType, DynOption, MessageOption, Check, Array}; -use model::unicode::Unicode; -use model::error::{RdpResult, RdpError, RdpErrorKind, Error}; -use core::per; -use std::io::{Cursor, Read}; +use crate::core::per; +use crate::model::data::{ + to_vec, Array, Check, Component, DataType, DynOption, Message, MessageOption, Trame, U16, U32, +}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::model::unicode::Unicode; use std::collections::HashMap; +use std::io::{Cursor, Read}; - -const T124_02_98_OID: [u8; 6] = [ 0, 0, 20, 124, 0, 1 ]; +const T124_02_98_OID: [u8; 6] = [0, 0, 20, 124, 0, 1]; const H221_CS_KEY: [u8; 4] = *b"Duca"; const H221_SC_KEY: [u8; 4] = *b"McDn"; /// RDP protocol version @@ -17,7 +18,7 @@ const H221_SC_KEY: [u8; 4] = *b"McDn"; pub enum Version { RdpVersion = 0x00080001, RdpVersion5plus = 0x00080004, - Unknown + Unknown, } impl From for Version { @@ -25,7 +26,7 @@ impl From for Version { match e { 0x00080001 => Version::RdpVersion5plus, 0x00080004 => Version::RdpVersion, - _ => Version::Unknown + _ => Version::Unknown, } } } @@ -39,12 +40,12 @@ enum ColorDepth { RnsUdColor8BPP = 0xCA01, RnsUdColor16BPP555 = 0xCA02, RnsUdColor16BPP565 = 0xCA03, - RnsUdColor24BPP = 0xCA04 + RnsUdColor24BPP = 0xCA04, } #[repr(u16)] enum Sequence { - RnsUdSasDel = 0xAA03 + RnsUdSasDel = 0xAA03, } /// Keyboard layout @@ -70,7 +71,7 @@ pub enum KeyboardLayout { Japanese = 0x00000411, Korean = 0x00000412, Dutch = 0x00000413, - Norwegian = 0x00000414 + Norwegian = 0x00000414, } /// Keyboard type @@ -78,13 +79,13 @@ pub enum KeyboardLayout { #[repr(u32)] #[allow(dead_code)] pub enum KeyboardType { - IbmPcXt83Key = 0x00000001, - Olivetti = 0x00000002, - IbmPcAt84Key = 0x00000003, - Ibm101102Keys = 0x00000004, - Nokia1050 = 0x00000005, - Nokia9140 = 0x00000006, - Japanese = 0x00000007 + IbmPcXt83Key = 0x00000001, + Olivetti = 0x00000002, + IbmPcAt84Key = 0x00000003, + Ibm101102Keys = 0x00000004, + Nokia1050 = 0x00000005, + Nokia9140 = 0x00000006, + Japanese = 0x00000007, } #[repr(u16)] @@ -94,10 +95,9 @@ enum HighColor { HighColor8BPP = 0x0008, HighColor15BPP = 0x000f, HighColor16BPP = 0x0010, - HighColor24BPP = 0x0018 + HighColor24BPP = 0x0018, } - /// Supported color depth /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/00f1da4a-ee9c-421a-852f-c19f92343d73?redirectedfrom=MSDN #[repr(u16)] @@ -106,7 +106,7 @@ enum Support { RnsUd24BPPSupport = 0x0001, RnsUd16BPPSupport = 0x0002, RnsUd15BPPSupport = 0x0004, - RnsUd32BPPSupport = 0x0008 + RnsUd32BPPSupport = 0x0008, } /// Negotiation of some capability for pdu layer @@ -117,14 +117,14 @@ enum CapabilityFlag { RnsUdCsSupportErrinfoPDU = 0x0001, RnsUdCsWant32BPPSession = 0x0002, RnsUdCsSupportStatusInfoPdu = 0x0004, - RnsUdCsStrongAsymmetricKeys = 0x0008, + RnsUdCsStrongAsymmetricKeys = 0x0008, RnsUdCsUnused = 0x0010, RnsUdCsValidConnectionType = 0x0020, RnsUdCsSupportMonitorLayoutPDU = 0x0040, RnsUdCsSupportNetcharAutodetect = 0x0080, RnsUdCsSupportDynvcGFXProtocol = 0x0100, RnsUdCsSupportDynamicTimezone = 0x0200, - RnsUdCsSupportHeartbeatPDU = 0x0400 + RnsUdCsSupportHeartbeatPDU = 0x0400, } /// Supported encryption method @@ -135,7 +135,7 @@ enum EncryptionMethod { EncryptionFlag40bit = 0x00000001, EncryptionFlag128bit = 0x00000002, EncryptionFlag56bit = 0x00000008, - FipsEncryptionFlag = 0x00000010 + FipsEncryptionFlag = 0x00000010, } /// Encryption level @@ -146,12 +146,12 @@ enum EncryptionLevel { Low = 0x00000001, ClientCompatible = 0x00000002, High = 0x00000003, - Fips = 0x00000004 + Fips = 0x00000004, } #[repr(u16)] #[derive(Eq, PartialEq, Hash)] -pub enum MessageType { +pub enum MessageType { //server -> client ScCore = 0x0C01, ScSecurity = 0x0C02, @@ -162,7 +162,7 @@ pub enum MessageType { CsNet = 0xC003, CsCluster = 0xC004, CsMonitor = 0xC005, - Unknown = 0 + Unknown = 0, } impl From for MessageType { @@ -176,7 +176,7 @@ impl From for MessageType { 0xC003 => MessageType::CsNet, 0xC004 => MessageType::CsCluster, 0xC005 => MessageType::CsMonitor, - _ => MessageType::Unknown + _ => MessageType::Unknown, } } } @@ -190,7 +190,7 @@ pub struct ClientData { pub layout: KeyboardLayout, pub server_selected_protocol: u32, pub rdp_version: Version, - pub name: String + pub name: String, } /// This is the first client specific data @@ -199,15 +199,14 @@ pub struct ClientData { /// RDP they are not use /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/00f1da4a-ee9c-421a-852f-c19f92343d73?redirectedfrom=MSDN pub fn client_core_data(parameter: Option) -> Component { - let client_parameter = parameter.unwrap_or( - ClientData { - width: 0, - height: 0, - layout: KeyboardLayout::French, - server_selected_protocol: 0, - rdp_version: Version::RdpVersion5plus, - name: "".to_string() - }); + let client_parameter = parameter.unwrap_or(ClientData { + width: 0, + height: 0, + layout: KeyboardLayout::French, + server_selected_protocol: 0, + rdp_version: Version::RdpVersion5plus, + name: "".to_string(), + }); let client_name = if client_parameter.name.len() >= 16 { (&client_parameter.name[0..16]).to_string() @@ -246,7 +245,7 @@ pub fn client_core_data(parameter: Option) -> Component { ] } -pub fn server_core_data() -> Component{ +pub fn server_core_data() -> Component { component![ "rdpVersion" => U32::LE(0), "clientRequestedProtocol" => Some(U32::LE(0)), @@ -307,7 +306,7 @@ pub fn block_header(data_type: Option, length: Option) -> Comp ] } -pub fn write_conference_create_request(user_data: &[u8]) ->RdpResult> { +pub fn write_conference_create_request(user_data: &[u8]) -> RdpResult> { let mut result = Cursor::new(vec![]); per::write_choice(0, &mut result)?; per::write_object_identifier(&T124_02_98_OID, &mut result)?; @@ -318,14 +317,14 @@ pub fn write_conference_create_request(user_data: &[u8]) ->RdpResult> { per::write_padding(1, &mut result)?; per::write_number_of_set(1, &mut result)?; per::write_choice(0xc0, &mut result)?; - per::write_octet_stream(&H221_CS_KEY, 4,&mut result)?; + per::write_octet_stream(&H221_CS_KEY, 4, &mut result)?; per::write_octet_stream(user_data, 0, &mut result)?; Ok(result.into_inner()) } pub struct ServerData { pub channel_ids: Vec, - pub rdp_version : Version + pub rdp_version: Version, } /// Read conference create response @@ -345,14 +344,17 @@ pub fn read_conference_create_response(cc_response: &mut dyn Read) -> RdpResult< let mut result = HashMap::new(); let mut sub = cc_response.take(length as u64); loop { - let mut header = block_header(None, None); // No more blocks to read if header.read(&mut sub).is_err() { break; } - let mut buffer = vec![0 as u8; (cast!(DataType::U16, header["length"])? - header.length() as u16) as usize]; + let mut buffer = vec![ + 0 as u8; + (cast!(DataType::U16, header["length"])? - header.length() as u16) + as usize + ]; sub.read_exact(&mut buffer)?; match MessageType::from(cast!(DataType::U16, header["type"])?) { @@ -360,24 +362,36 @@ pub fn read_conference_create_response(cc_response: &mut dyn Read) -> RdpResult< let mut server_core = server_core_data(); server_core.read(&mut Cursor::new(buffer))?; result.insert(MessageType::ScCore, server_core); - }, + } MessageType::ScSecurity => { let mut server_security = server_security_data(); server_security.read(&mut Cursor::new(buffer))?; result.insert(MessageType::ScSecurity, server_security); - }, + } MessageType::ScNet => { let mut server_net = server_network_data(); server_net.read(&mut Cursor::new(buffer))?; result.insert(MessageType::ScNet, server_net); } - _ => println!("GCC: Unknown server block {:?}", cast!(DataType::U16, header["type"])?) + _ => println!( + "GCC: Unknown server block {:?}", + cast!(DataType::U16, header["type"])? + ), } } // All section are important - Ok(ServerData{ - channel_ids: cast!(DataType::Trame, result[&MessageType::ScNet]["channelIdArray"])?.into_iter().map(|x| cast!(DataType::U16, x).unwrap()).collect(), - rdp_version: Version::from(cast!(DataType::U32, result[&MessageType::ScCore]["rdpVersion"])?) + Ok(ServerData { + channel_ids: cast!( + DataType::Trame, + result[&MessageType::ScNet]["channelIdArray"] + )? + .into_iter() + .map(|x| cast!(DataType::U16, x).unwrap()) + .collect(), + rdp_version: Version::from(cast!( + DataType::U32, + result[&MessageType::ScCore]["rdpVersion"] + )?), }) -} \ No newline at end of file +} diff --git a/src/core/global.rs b/src/core/global.rs index 8d9cf69..023c0dc 100644 --- a/src/core/global.rs +++ b/src/core/global.rs @@ -1,15 +1,16 @@ -use core::mcs; -use core::tpkt; -use std::io::{Read, Write, Cursor}; -use model::error::{RdpResult, Error, RdpErrorKind, RdpError}; -use model::data::{Component, MessageOption, U32, DynOption, U16, DataType, Message, Array, Trame, Check, to_vec}; -use core::event::{RdpEvent, BitmapEvent}; +use crate::core::capability; +use crate::core::capability::{capability_set, Capability}; +use crate::core::event::{BitmapEvent, RdpEvent}; +use crate::core::gcc::KeyboardLayout; +use crate::core::mcs; +use crate::core::tpkt; +use crate::model::data::{ + to_vec, Array, Check, Component, DataType, DynOption, Message, MessageOption, Trame, U16, U32, +}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use num_enum::TryFromPrimitive; use std::convert::TryFrom; -use core::capability::{Capability, capability_set}; -use core::capability; -use core::gcc::KeyboardLayout; - +use std::io::{Cursor, Read, Write}; /// Raw PDU type use by the protocol #[repr(u16)] @@ -19,7 +20,7 @@ enum PDUType { PdutypeConfirmactivepdu = 0x13, PdutypeDeactivateallpdu = 0x16, PdutypeDatapdu = 0x17, - PdutypeServerRedirPkt = 0x1A + PdutypeServerRedirPkt = 0x1A, } /// PDU type available @@ -27,7 +28,7 @@ enum PDUType { /// Then once connected only Data are send and received struct PDU { pub pdu_type: PDUType, - pub message: Component + pub message: Component, } impl PDU { @@ -46,9 +47,17 @@ impl PDU { PDUType::PdutypeDatapdu => share_data_header(None, None, None), PDUType::PdutypeConfirmactivepdu => ts_confirm_active_pdu(None, None, None), PDUType::PdutypeDeactivateallpdu => ts_deactivate_all_pdu(), - _ => return Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, "GLOBAL: PDU not implemented"))) + _ => { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + "GLOBAL: PDU not implemented", + ))) + } }; - pdu.message.read(&mut Cursor::new(cast!(DataType::Slice, control["pduMessage"])?))?; + pdu.message.read(&mut Cursor::new(cast!( + DataType::Slice, + control["pduMessage"] + )?))?; Ok(pdu) } } @@ -71,7 +80,7 @@ fn ts_demand_active_pdu() -> PDU { "pad2Octets" => U16::LE(0), "capabilitySets" => Array::new(|| capability_set(None)), "sessionId" => U32::LE(0) - ] + ], } } @@ -79,7 +88,11 @@ fn ts_demand_active_pdu() -> PDU { /// This PDU declare capabilities for the client /// /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/4e9722c3-ad83-43f5-af5a-529f73d88b48 -fn ts_confirm_active_pdu(share_id: Option, source: Option>, capabilities_set: Option>) -> PDU { +fn ts_confirm_active_pdu( + share_id: Option, + source: Option>, + capabilities_set: Option>, +) -> PDU { let default_capabilities_set = capabilities_set.unwrap_or(Array::new(|| capability_set(None))); let default_source = source.unwrap_or(vec![]); PDU { @@ -93,7 +106,7 @@ fn ts_confirm_active_pdu(share_id: Option, source: Option>, capabil "numberCapabilities" => U16::LE(default_capabilities_set.inner().len() as u16), "pad2Octets" => U16::LE(0), "capabilitySets" => default_capabilities_set - ] + ], } } @@ -107,12 +120,16 @@ fn ts_deactivate_all_pdu() -> PDU { "shareId" => U32::LE(0), "lengthSourceDescriptor" => DynOption::new(U16::LE(0), |length| MessageOption::Size("sourceDescriptor".to_string(), length.inner() as usize)), "sourceDescriptor" => Vec::::new() - ] + ], } } /// All Data PDU share the same layout -fn share_data_header(share_id: Option, pdu_type_2: Option, message: Option>) -> PDU { +fn share_data_header( + share_id: Option, + pdu_type_2: Option, + message: Option>, +) -> PDU { let default_message = message.unwrap_or(vec![]); PDU { pdu_type: PDUType::PdutypeDatapdu, @@ -125,14 +142,17 @@ fn share_data_header(share_id: Option, pdu_type_2: Option, messag "compressedType" => 0 as u8, "compressedLength" => U16::LE(0), "payload" => default_message - ] + ], } } - /// This is the main PDU payload format /// It use the share control header to dispatch between all PDU -fn share_control_header(pdu_type: Option, pdu_source: Option, message: Option>) -> Component { +fn share_control_header( + pdu_type: Option, + pdu_source: Option, + message: Option>, +) -> Component { let default_message = message.unwrap_or(vec![]); component![ "totalLength" => DynOption::new(U16::LE(default_message.length() as u16 + 6), |total| MessageOption::Size("pduMessage".to_string(), total.inner() as usize - 6)), @@ -169,13 +189,13 @@ enum PDUType2 { Pdutype2ArcStatusPdu = 0x32, Pdutype2StatusInfoPdu = 0x36, Pdutype2MonitorLayoutPdu = 0x37, - Unknown + Unknown, } /// Data PDU container struct DataPDU { pdu_type: PDUType2, - message: Component + message: Component, } impl DataPDU { @@ -190,9 +210,17 @@ impl DataPDU { PDUType2::Pdutype2Fontlist => ts_font_list_pdu(), PDUType2::Pdutype2Fontmap => ts_font_map_pdu(), PDUType2::Pdutype2SetErrorInfoPdu => ts_set_error_info_pdu(), - _ => return Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, &format!("GLOBAL: Data PDU parsing not implemented {:?}", pdu_type)))) + _ => { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + &format!("GLOBAL: Data PDU parsing not implemented {:?}", pdu_type), + ))) + } }; - result.message.read(&mut Cursor::new(cast!(DataType::Slice, data_pdu.message["payload"])?))?; + result.message.read(&mut Cursor::new(cast!( + DataType::Slice, + data_pdu.message["payload"] + )?))?; Ok(result) } } @@ -206,7 +234,7 @@ fn ts_synchronize_pdu(target_user: Option) -> DataPDU { message: component![ "messageType" => Check::new(U16::LE(1)), "targetUser" => Some(U16::LE(target_user.unwrap_or(0))) - ] + ], } } @@ -221,7 +249,7 @@ fn ts_font_list_pdu() -> DataPDU { "totalNumFonts" => U16::LE(0), "listFlags" => U16::LE(0x0003), "entrySize" => U16::LE(0x0032) - ] + ], } } @@ -233,7 +261,7 @@ fn ts_set_error_info_pdu() -> DataPDU { pdu_type: PDUType2::Pdutype2SetErrorInfoPdu, message: component![ "errorInfo" => U32::LE(0) - ] + ], } } @@ -243,7 +271,7 @@ enum Action { CtrlactionRequestControl = 0x0001, CtrlactionGrantedControl = 0x0002, CtrlactionDetach = 0x0003, - CtrlactionCooperate = 0x0004 + CtrlactionCooperate = 0x0004, } /// Control payload send during pdu handshake @@ -256,7 +284,7 @@ fn ts_control_pdu(action: Option) -> DataPDU { "action" => U16::LE(action.unwrap_or(Action::CtrlactionCooperate) as u16), "grantId" => U16::LE(0), "controlId" => U32::LE(0) - ] + ], } } @@ -271,7 +299,7 @@ fn ts_font_map_pdu() -> DataPDU { "totalNumEntries" => U16::LE(0), "mapFlags" => U16::LE(0x0003), "entrySize" => U16::LE(0x0004) - ] + ], } } @@ -284,7 +312,7 @@ fn ts_input_pdu_data(events: Option>) -> DataPDU { "numEvents" => U16::LE(default_events.inner().len() as u16), "pad2Octets" => U16::LE(0), "slowPathInputEvents" => default_events - ] + ], } } @@ -307,13 +335,13 @@ pub enum InputEventType { InputEventScancode = 0x0004, InputEventUnicode = 0x0005, InputEventMouse = 0x8001, - InputEventMousex = 0x8002 + InputEventMousex = 0x8002, } /// All Terminal Service Slow Path Input Event pub struct TSInputEvent { event_type: InputEventType, - message: Component + message: Component, } /// All supported flags for pointer event @@ -329,7 +357,7 @@ pub enum PointerFlag { PtrflagsDown = 0x8000, PtrflagsButton1 = 0x1000, PtrflagsButton2 = 0x2000, - PtrflagsButton3 = 0x4000 + PtrflagsButton3 = 0x4000, } /// A pointer event @@ -338,11 +366,11 @@ pub enum PointerFlag { pub fn ts_pointer_event(flags: Option, x: Option, y: Option) -> TSInputEvent { TSInputEvent { event_type: InputEventType::InputEventMouse, - message : component![ + message: component![ "pointerFlags" => U16::LE(flags.unwrap_or(0)), "xPos" => U16::LE(x.unwrap_or(0)), "yPos" => U16::LE(y.unwrap_or(0)) - ] + ], } } @@ -350,7 +378,7 @@ pub fn ts_pointer_event(flags: Option, x: Option, y: Option) -> T pub enum KeyboardFlag { KbdflagsExtended = 0x0100, KbdflagsDown = 0x4000, - KbdflagsRelease = 0x8000 + KbdflagsRelease = 0x8000, } /// Raw input keyboard event @@ -362,7 +390,7 @@ pub fn ts_keyboard_event(flags: Option, key_code: Option) -> TSInputEv "keyboardFlags" => U16::LE(flags.unwrap_or(0)), "keyCode" => U16::LE(key_code.unwrap_or(0)), "pad2Octets" => U16::LE(0) - ] + ], } } @@ -385,7 +413,6 @@ fn ts_fp_update() -> Component { ] } - #[repr(u8)] #[derive(Debug, TryFromPrimitive, Copy, Clone, Eq, PartialEq)] enum FastPathUpdateType { @@ -400,26 +427,38 @@ enum FastPathUpdateType { FastpathUpdatetypeColor = 0x9, FastpathUpdatetypeCached = 0xA, FastpathUpdatetypePointer = 0xB, - Unknown + Unknown, } -struct FastPathUpdate{ +struct FastPathUpdate { fp_type: FastPathUpdateType, - message: Component + message: Component, } impl FastPathUpdate { /// Parse Fast Path update order fn from_fp(fast_path: &Component) -> RdpResult { - let fp_update_type = FastPathUpdateType::try_from(cast!(DataType::U8, fast_path["updateHeader"])? & 0xf)?; + let fp_update_type = + FastPathUpdateType::try_from(cast!(DataType::U8, fast_path["updateHeader"])? & 0xf)?; let mut result = match fp_update_type { FastPathUpdateType::FastpathUpdatetypeBitmap => ts_fp_update_bitmap(), FastPathUpdateType::FastpathUpdatetypeColor => ts_colorpointerattribute(), FastPathUpdateType::FastpathUpdatetypeSynchronize => ts_fp_update_synchronize(), FastPathUpdateType::FastpathUpdatetypePtrNull => ts_fp_systempointerhiddenattribute(), - _ => return Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, &format!("GLOBAL: Fast Path parsing not implemented {:?}", fp_update_type)))) + _ => { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + &format!( + "GLOBAL: Fast Path parsing not implemented {:?}", + fp_update_type + ), + ))) + } }; - result.message.read(&mut Cursor::new(cast!(DataType::Slice, fast_path["updateData"])?))?; + result.message.read(&mut Cursor::new(cast!( + DataType::Slice, + fast_path["updateData"] + )?))?; Ok(result) } } @@ -470,7 +509,7 @@ fn ts_fp_update_bitmap() -> FastPathUpdate { "header" => Check::new(U16::LE(FastPathUpdateType::FastpathUpdatetypeBitmap as u16)), "numberRectangles" => U16::LE(0), "rectangles" => Array::new(|| ts_bitmap_data()) - ] + ], } } @@ -481,16 +520,16 @@ fn ts_colorpointerattribute() -> FastPathUpdate { FastPathUpdate { fp_type: FastPathUpdateType::FastpathUpdatetypeColor, message: component![ - "cacheIndex " => U16::LE(0), - "hotSpot " => U32::LE(0), - "width" => U16::LE(0), - "height" => U16::LE(0), - "lengthAndMask" => DynOption::new(U16::LE(0), |length| MessageOption::Size("andMaskData".to_string(), length.inner() as usize)), - "lengthXorMask" => DynOption::new(U16::LE(0), |length| MessageOption::Size("xorMaskData".to_string(), length.inner() as usize)), - "xorMaskData" => Vec::::new(), - "andMaskData" => Vec::::new(), - "pad" => Some(0 as u8) - ] + "cacheIndex " => U16::LE(0), + "hotSpot " => U32::LE(0), + "width" => U16::LE(0), + "height" => U16::LE(0), + "lengthAndMask" => DynOption::new(U16::LE(0), |length| MessageOption::Size("andMaskData".to_string(), length.inner() as usize)), + "lengthXorMask" => DynOption::new(U16::LE(0), |length| MessageOption::Size("xorMaskData".to_string(), length.inner() as usize)), + "xorMaskData" => Vec::::new(), + "andMaskData" => Vec::::new(), + "pad" => Some(0 as u8) + ], } } @@ -500,7 +539,7 @@ fn ts_colorpointerattribute() -> FastPathUpdate { fn ts_fp_update_synchronize() -> FastPathUpdate { FastPathUpdate { fp_type: FastPathUpdateType::FastpathUpdatetypeSynchronize, - message: component![] + message: component![], } } @@ -510,7 +549,7 @@ fn ts_fp_update_synchronize() -> FastPathUpdate { fn ts_fp_systempointerhiddenattribute() -> FastPathUpdate { FastPathUpdate { fp_type: FastPathUpdateType::FastpathUpdatetypePtrNull, - message: component![] + message: component![], } } @@ -527,7 +566,7 @@ enum ClientState { FontMap, /// wait for date /// either data pdu or fast path data - Data + Data, } pub struct Client { @@ -550,7 +589,7 @@ pub struct Client { /// Keep tracing of server capabilities server_capabilities: Vec, /// Name send to the server - name: String + name: String, } impl Client { @@ -571,7 +610,14 @@ impl Client { /// "mstsc-rs" /// ); /// ``` - pub fn new(user_id: u16, channel_id: u16, width: u16, height: u16, layout: KeyboardLayout, name: &str) -> Client { + pub fn new( + user_id: u16, + channel_id: u16, + width: u16, + height: u16, + layout: KeyboardLayout, + name: &str, + ) -> Client { Client { state: ClientState::DemandActivePDU, server_capabilities: Vec::new(), @@ -581,7 +627,7 @@ impl Client { width, height, layout, - name: String::from(name) + name: String::from(name), } } @@ -596,13 +642,13 @@ impl Client { for capability_set in cast!(DataType::Trame, pdu.message["capabilitySets"])?.iter() { match Capability::from_capability_set(cast!(DataType::Component, capability_set)?) { Ok(capability) => self.server_capabilities.push(capability), - Err(e) => println!("GLOBAL: {:?}", e) + Err(e) => println!("GLOBAL: {:?}", e), } } self.share_id = Some(cast!(DataType::U32, pdu.message["shareId"])?); - return Ok(true) + return Ok(true); } - return Ok(false) + return Ok(false); } /// Read server synchronize pdu @@ -612,10 +658,10 @@ impl Client { fn read_synchronize_pdu(&mut self, stream: &mut dyn Read) -> RdpResult { let pdu = PDU::from_stream(stream)?; if pdu.pdu_type != PDUType::PdutypeDatapdu { - return Ok(false) + return Ok(false); } if DataPDU::from_pdu(&pdu)?.pdu_type != PDUType2::Pdutype2Synchronize { - return Ok(false) + return Ok(false); } Ok(true) } @@ -626,16 +672,19 @@ impl Client { fn read_control_pdu(&mut self, stream: &mut dyn Read, action: Action) -> RdpResult { let pdu = PDU::from_stream(stream)?; if pdu.pdu_type != PDUType::PdutypeDatapdu { - return Ok(false) + return Ok(false); } let data_pdu = DataPDU::from_pdu(&pdu)?; if data_pdu.pdu_type != PDUType2::Pdutype2Control { - return Ok(false) + return Ok(false); } - if cast!(DataType::U16, data_pdu.message["action"])? != action as u16 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::UnexpectedType, "GLOBAL: bad message type"))) + if cast!(DataType::U16, data_pdu.message["action"])? != action as u16 { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::UnexpectedType, + "GLOBAL: bad message type", + ))); } Ok(true) @@ -644,13 +693,13 @@ impl Client { /// Read the server font data PDU /// /// This function return true if it read the expected PDU - fn read_font_map_pdu(&mut self, stream: &mut dyn Read) -> RdpResult { + fn read_font_map_pdu(&mut self, stream: &mut dyn Read) -> RdpResult { let pdu = PDU::from_stream(stream)?; if pdu.pdu_type != PDUType::PdutypeDatapdu { - return Ok(false) + return Ok(false); } if DataPDU::from_pdu(&pdu)?.pdu_type != PDUType2::Pdutype2Fontmap { - return Ok(false) + return Ok(false); } Ok(true) } @@ -678,13 +727,14 @@ impl Client { } match DataPDU::from_pdu(&pdu) { - Ok(data_pdu) => { - match data_pdu.pdu_type { - PDUType2::Pdutype2SetErrorInfoPdu => println!("GLOBAL: Receive error PDU from server {:?}", cast!(DataType::U32, data_pdu.message["errorInfo"])?), - _ => println!("GLOBAL: Data PDU not handle {:?}", data_pdu.pdu_type) - } + Ok(data_pdu) => match data_pdu.pdu_type { + PDUType2::Pdutype2SetErrorInfoPdu => println!( + "GLOBAL: Receive error PDU from server {:?}", + cast!(DataType::U32, data_pdu.message["errorInfo"])? + ), + _ => println!("GLOBAL: Data PDU not handle {:?}", data_pdu.pdu_type), }, - Err(e) => println!("GLOBAL: Parsing data PDU error {:?}", e) + Err(e) => println!("GLOBAL: Parsing data PDU error {:?}", e), }; } Ok(()) @@ -694,7 +744,9 @@ impl Client { /// Reading is processed using a callback patterm /// This is where bitmap are received fn read_fast_path(&mut self, stream: &mut dyn Read, mut callback: T) -> RdpResult<()> - where T: FnMut(RdpEvent) { + where + T: FnMut(RdpEvent), + { // it could be have one or more fast path payload let mut fp_messages = Array::new(|| ts_fp_update()); fp_messages.read(stream)?; @@ -706,27 +758,30 @@ impl Client { FastPathUpdateType::FastpathUpdatetypeBitmap => { for rectangle in cast!(DataType::Trame, order.message["rectangles"])? { let bitmap = cast!(DataType::Component, rectangle)?; - callback(RdpEvent::Bitmap( - BitmapEvent { - dest_left: cast!(DataType::U16, bitmap["destLeft"])?, - dest_top: cast!(DataType::U16, bitmap["destTop"])?, - dest_right: cast!(DataType::U16, bitmap["destRight"])?, - dest_bottom: cast!(DataType::U16, bitmap["destBottom"])?, - width: cast!(DataType::U16, bitmap["width"])?, - height: cast!(DataType::U16, bitmap["height"])?, - bpp: cast!(DataType::U16, bitmap["bitsPerPixel"])?, - is_compress: cast!(DataType::U16, bitmap["flags"])? & BitmapFlag::BitmapCompression as u16 != 0, - data: cast!(DataType::Slice, bitmap["bitmapDataStream"])?.to_vec() - } - )); + callback(RdpEvent::Bitmap(BitmapEvent { + dest_left: cast!(DataType::U16, bitmap["destLeft"])?, + dest_top: cast!(DataType::U16, bitmap["destTop"])?, + dest_right: cast!(DataType::U16, bitmap["destRight"])?, + dest_bottom: cast!(DataType::U16, bitmap["destBottom"])?, + width: cast!(DataType::U16, bitmap["width"])?, + height: cast!(DataType::U16, bitmap["height"])?, + bpp: cast!(DataType::U16, bitmap["bitsPerPixel"])?, + is_compress: cast!(DataType::U16, bitmap["flags"])? + & BitmapFlag::BitmapCompression as u16 + != 0, + data: cast!(DataType::Slice, bitmap["bitmapDataStream"])? + .to_vec(), + })); } - }, + } // do nothing - FastPathUpdateType::FastpathUpdatetypeColor | FastPathUpdateType::FastpathUpdatetypePtrNull | FastPathUpdateType::FastpathUpdatetypeSynchronize => (), - _ => println!("GLOBAL: Fast Path order not handled {:?}", order.fp_type) + FastPathUpdateType::FastpathUpdatetypeColor + | FastPathUpdateType::FastpathUpdatetypePtrNull + | FastPathUpdateType::FastpathUpdatetypeSynchronize => (), + _ => println!("GLOBAL: Fast Path order not handled {:?}", order.fp_type), } - }, - Err(e) => println!("GLOBAL: Unknown Fast Path order {:?}", e) + } + Err(e) => println!("GLOBAL: Unknown Fast Path order {:?}", e), }; } @@ -735,23 +790,47 @@ impl Client { /// Write confirm active pdu /// This PDU include all client capabilities - fn write_confirm_active_pdu(&mut self, mcs: &mut mcs::Client) -> RdpResult<()> { - let pdu = ts_confirm_active_pdu(self.share_id, Some(self.name.as_bytes().to_vec()), Some(Array::from_trame( - trame![ - capability_set(Some(capability::ts_general_capability_set(Some(capability::GeneralExtraFlag::LongCredentialsSupported as u16 | capability::GeneralExtraFlag::NoBitmapCompressionHdr as u16 | capability::GeneralExtraFlag::EncSaltedChecksum as u16 | capability::GeneralExtraFlag::FastpathOutputSupported as u16)))), - capability_set(Some(capability::ts_bitmap_capability_set(Some(0x0018), Some(self.width), Some(self.height)))), - capability_set(Some(capability::ts_order_capability_set(Some(capability::OrderFlag::NEGOTIATEORDERSUPPORT as u16 | capability::OrderFlag::ZEROBOUNDSDELTASSUPPORT as u16)))), + fn write_confirm_active_pdu( + &mut self, + mcs: &mut mcs::Client, + ) -> RdpResult<()> { + let pdu = ts_confirm_active_pdu( + self.share_id, + Some(self.name.as_bytes().to_vec()), + Some(Array::from_trame(trame![ + capability_set(Some(capability::ts_general_capability_set(Some( + capability::GeneralExtraFlag::LongCredentialsSupported as u16 + | capability::GeneralExtraFlag::NoBitmapCompressionHdr as u16 + | capability::GeneralExtraFlag::EncSaltedChecksum as u16 + | capability::GeneralExtraFlag::FastpathOutputSupported as u16 + )))), + capability_set(Some(capability::ts_bitmap_capability_set( + Some(0x0018), + Some(self.width), + Some(self.height) + ))), + capability_set(Some(capability::ts_order_capability_set(Some( + capability::OrderFlag::NEGOTIATEORDERSUPPORT as u16 + | capability::OrderFlag::ZEROBOUNDSDELTASSUPPORT as u16 + )))), capability_set(Some(capability::ts_bitmap_cache_capability_set())), capability_set(Some(capability::ts_pointer_capability_set())), capability_set(Some(capability::ts_sound_capability_set())), - capability_set(Some(capability::ts_input_capability_set(Some(capability::InputFlags::InputFlagScancodes as u16 | capability::InputFlags::InputFlagMousex as u16 | capability::InputFlags::InputFlagUnicode as u16), Some(self.layout)))), + capability_set(Some(capability::ts_input_capability_set( + Some( + capability::InputFlags::InputFlagScancodes as u16 + | capability::InputFlags::InputFlagMousex as u16 + | capability::InputFlags::InputFlagUnicode as u16 + ), + Some(self.layout) + ))), capability_set(Some(capability::ts_brush_capability_set())), capability_set(Some(capability::ts_glyph_capability_set())), capability_set(Some(capability::ts_offscreen_capability_set())), capability_set(Some(capability::ts_virtualchannel_capability_set())), capability_set(Some(capability::ts_multifragment_update_capability_ts())) - ] - ))); + ])), + ); self.write_pdu(pdu, mcs) } @@ -766,12 +845,30 @@ impl Client { /// Send a classic PDU to the global channel fn write_pdu(&self, message: PDU, mcs: &mut mcs::Client) -> RdpResult<()> { - mcs.write(&"global".to_string(), share_control_header(Some(message.pdu_type), Some(self.user_id), Some(to_vec(&message.message)))) + mcs.write( + &"global".to_string(), + share_control_header( + Some(message.pdu_type), + Some(self.user_id), + Some(to_vec(&message.message)), + ), + ) } /// Send Data pdu - fn write_data_pdu(&self, message: DataPDU, mcs: &mut mcs::Client) -> RdpResult<()> { - self.write_pdu(share_data_header(self.share_id, Some(message.pdu_type), Some(to_vec(&message.message))), mcs) + fn write_data_pdu( + &self, + message: DataPDU, + mcs: &mut mcs::Client, + ) -> RdpResult<()> { + self.write_pdu( + share_data_header( + self.share_id, + Some(message.pdu_type), + Some(to_vec(&message.message)), + ), + mcs, + ) } /// Public interface to sent input event @@ -797,10 +894,23 @@ impl Client { /// &mut self.mcs /// ) /// ``` - pub fn write_input_event(&self, event: TSInputEvent, mcs: &mut mcs::Client) -> RdpResult<()> { + pub fn write_input_event( + &self, + event: TSInputEvent, + mcs: &mut mcs::Client, + ) -> RdpResult<()> { match self.state { - ClientState::Data => Ok(self.write_data_pdu(ts_input_pdu_data(Some(Array::from_trame(trame![ts_input_event(Some(event.event_type), Some(to_vec(&event.message)))]))), mcs)?), - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidAutomata, "You cannot send data once it's not connected"))) + ClientState::Data => Ok(self.write_data_pdu( + ts_input_pdu_data(Some(Array::from_trame(trame![ts_input_event( + Some(event.event_type), + Some(to_vec(&event.message)) + )]))), + mcs, + )?), + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidAutomata, + "You cannot send data once it's not connected", + ))), } } @@ -819,8 +929,15 @@ impl Client { /// ... /// } /// ``` - pub fn read(&mut self, payload: tpkt::Payload, mcs: &mut mcs::Client, callback: T) -> RdpResult<()> - where T: FnMut(RdpEvent){ + pub fn read( + &mut self, + payload: tpkt::Payload, + mcs: &mut mcs::Client, + callback: T, + ) -> RdpResult<()> + where + T: FnMut(RdpEvent), + { match self.state { ClientState::DemandActivePDU => { if self.read_demand_active_pdu(&mut try_let!(tpkt::Payload::Raw, payload)?)? { @@ -837,33 +954,41 @@ impl Client { self.state = ClientState::ControlCooperate; } Ok(()) - }, + } ClientState::ControlCooperate => { - if self.read_control_pdu(&mut try_let!(tpkt::Payload::Raw, payload)?, Action::CtrlactionCooperate)? { + if self.read_control_pdu( + &mut try_let!(tpkt::Payload::Raw, payload)?, + Action::CtrlactionCooperate, + )? { // next state is control granted self.state = ClientState::ControlGranted; } Ok(()) - }, + } ClientState::ControlGranted => { - if self.read_control_pdu(&mut try_let!(tpkt::Payload::Raw, payload)?, Action::CtrlactionGrantedControl)? { + if self.read_control_pdu( + &mut try_let!(tpkt::Payload::Raw, payload)?, + Action::CtrlactionGrantedControl, + )? { // next state is font map pdu self.state = ClientState::FontMap; } Ok(()) - }, + } ClientState::FontMap => { if self.read_font_map_pdu(&mut try_let!(tpkt::Payload::Raw, payload)?)? { // finish handshake now wait for sdata self.state = ClientState::Data; } Ok(()) - }, + } ClientState::Data => { // Now we can receive update data match payload { tpkt::Payload::Raw(mut stream) => self.read_data_pdu(&mut stream), - tpkt::Payload::FastPath(_sec_flag, mut stream) => self.read_fast_path(&mut stream, callback) + tpkt::Payload::FastPath(_sec_flag, mut stream) => { + self.read_fast_path(&mut stream, callback) + } } } } @@ -877,53 +1002,124 @@ mod test { /// Test format message of demand active pdu #[test] fn test_demand_active_pdu() { - let mut stream = Cursor::new(vec![234, 3, 1, 0, 4, 0, 179, 1, 82, 68, 80, 0, 17, 0, 0, 0, 9, 0, 8, 0, 234, 3, 0, 0, 1, 0, 24, 0, 1, 0, 3, 0, 0, 2, 0, 0, 0, 0, 29, 4, 0, 0, 0, 0, 0, 0, 1, 1, 20, 0, 12, 0, 2, 0, 0, 0, 64, 6, 0, 0, 10, 0, 8, 0, 6, 0, 0, 0, 8, 0, 10, 0, 1, 0, 25, 0, 25, 0, 27, 0, 6, 0, 3, 0, 14, 0, 8, 0, 1, 0, 0, 0, 2, 0, 28, 0, 32, 0, 1, 0, 1, 0, 1, 0, 32, 3, 88, 2, 0, 0, 1, 0, 1, 0, 0, 30, 1, 0, 0, 0, 29, 0, 96, 0, 4, 185, 27, 141, 202, 15, 0, 79, 21, 88, 159, 174, 45, 26, 135, 226, 214, 0, 3, 0, 1, 1, 3, 18, 47, 119, 118, 114, 189, 99, 68, 175, 179, 183, 60, 156, 111, 120, 134, 0, 4, 0, 0, 0, 0, 0, 166, 81, 67, 156, 53, 53, 174, 66, 145, 12, 205, 252, 229, 118, 11, 88, 0, 4, 0, 0, 0, 0, 0, 212, 204, 68, 39, 138, 157, 116, 78, 128, 60, 14, 203, 238, 161, 156, 84, 0, 4, 0, 0, 0, 0, 0, 3, 0, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 66, 15, 0, 1, 0, 20, 0, 0, 0, 1, 0, 0, 0, 170, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 161, 6, 6, 0, 64, 66, 15, 0, 64, 66, 15, 0, 1, 0, 0, 0, 0, 0, 0, 0, 18, 0, 8, 0, 1, 0, 0, 0, 13, 0, 88, 0, 117, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 0, 8, 0, 255, 0, 0, 0, 24, 0, 11, 0, 2, 0, 0, 0, 3, 12, 0, 26, 0, 8, 0, 43, 72, 9, 0, 28, 0, 12, 0, 82, 0, 0, 0, 0, 0, 0, 0, 30, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + let mut stream = Cursor::new(vec![ + 234, 3, 1, 0, 4, 0, 179, 1, 82, 68, 80, 0, 17, 0, 0, 0, 9, 0, 8, 0, 234, 3, 0, 0, 1, 0, + 24, 0, 1, 0, 3, 0, 0, 2, 0, 0, 0, 0, 29, 4, 0, 0, 0, 0, 0, 0, 1, 1, 20, 0, 12, 0, 2, 0, + 0, 0, 64, 6, 0, 0, 10, 0, 8, 0, 6, 0, 0, 0, 8, 0, 10, 0, 1, 0, 25, 0, 25, 0, 27, 0, 6, + 0, 3, 0, 14, 0, 8, 0, 1, 0, 0, 0, 2, 0, 28, 0, 32, 0, 1, 0, 1, 0, 1, 0, 32, 3, 88, 2, + 0, 0, 1, 0, 1, 0, 0, 30, 1, 0, 0, 0, 29, 0, 96, 0, 4, 185, 27, 141, 202, 15, 0, 79, 21, + 88, 159, 174, 45, 26, 135, 226, 214, 0, 3, 0, 1, 1, 3, 18, 47, 119, 118, 114, 189, 99, + 68, 175, 179, 183, 60, 156, 111, 120, 134, 0, 4, 0, 0, 0, 0, 0, 166, 81, 67, 156, 53, + 53, 174, 66, 145, 12, 205, 252, 229, 118, 11, 88, 0, 4, 0, 0, 0, 0, 0, 212, 204, 68, + 39, 138, 157, 116, 78, 128, 60, 14, 203, 238, 161, 156, 84, 0, 4, 0, 0, 0, 0, 0, 3, 0, + 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 66, 15, 0, 1, 0, 20, 0, 0, + 0, 1, 0, 0, 0, 170, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 161, 6, 6, 0, 64, 66, 15, 0, 64, 66, 15, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 18, 0, 8, 0, 1, 0, 0, 0, 13, 0, 88, 0, 117, 3, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 0, 8, 0, 255, 0, 0, 0, 24, 0, 11, + 0, 2, 0, 0, 0, 3, 12, 0, 26, 0, 8, 0, 43, 72, 9, 0, 28, 0, 12, 0, 82, 0, 0, 0, 0, 0, 0, + 0, 30, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]); let mut pdu = ts_demand_active_pdu(); pdu.message.read(&mut stream).unwrap(); - assert_eq!(cast!(DataType::U16, pdu.message["numberCapabilities"]).unwrap(), 17) + assert_eq!( + cast!(DataType::U16, pdu.message["numberCapabilities"]).unwrap(), + 17 + ) } /// Test confirm active PDU format #[test] fn test_confirm_active_pdu() { let mut stream = Cursor::new(vec![]); - ts_confirm_active_pdu(Some(4), Some(b"rdp-rs".to_vec()), Some(Array::from_trame(trame![capability_set(Some(capability::ts_brush_capability_set()))]))).message.write(&mut stream).unwrap(); - assert_eq!(stream.into_inner(), [4, 0, 0, 0, 234, 3, 6, 0, 12, 0, 114, 100, 112, 45, 114, 115, 1, 0, 0, 0, 15, 0, 8, 0, 0, 0, 0, 0]); + ts_confirm_active_pdu( + Some(4), + Some(b"rdp-rs".to_vec()), + Some(Array::from_trame(trame![capability_set(Some( + capability::ts_brush_capability_set() + ))])), + ) + .message + .write(&mut stream) + .unwrap(); + assert_eq!( + stream.into_inner(), + [ + 4, 0, 0, 0, 234, 3, 6, 0, 12, 0, 114, 100, 112, 45, 114, 115, 1, 0, 0, 0, 15, 0, 8, + 0, 0, 0, 0, 0 + ] + ); } #[test] fn test_share_control_header() { let mut stream = Cursor::new(vec![]); - share_control_header(Some(PDUType::PdutypeConfirmactivepdu), Some(12), Some(to_vec(&ts_confirm_active_pdu(Some(4), Some(b"rdp-rs".to_vec()), Some(Array::from_trame(trame![capability_set(Some(capability::ts_brush_capability_set()))]))).message))).write(&mut stream).unwrap(); - - assert_eq!(stream.into_inner(), vec![34, 0, 19, 0, 12, 0, 4, 0, 0, 0, 234, 3, 6, 0, 12, 0, 114, 100, 112, 45, 114, 115, 1, 0, 0, 0, 15, 0, 8, 0, 0, 0, 0, 0]) + share_control_header( + Some(PDUType::PdutypeConfirmactivepdu), + Some(12), + Some(to_vec( + &ts_confirm_active_pdu( + Some(4), + Some(b"rdp-rs".to_vec()), + Some(Array::from_trame(trame![capability_set(Some( + capability::ts_brush_capability_set() + ))])), + ) + .message, + )), + ) + .write(&mut stream) + .unwrap(); + + assert_eq!( + stream.into_inner(), + vec![ + 34, 0, 19, 0, 12, 0, 4, 0, 0, 0, 234, 3, 6, 0, 12, 0, 114, 100, 112, 45, 114, 115, + 1, 0, 0, 0, 15, 0, 8, 0, 0, 0, 0, 0 + ] + ) } #[test] fn test_read_synchronize_pdu() { - let mut stream = Cursor::new(vec![22, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 22, 0, 31, 0, 0, 0, 1, 0, 0, 0]); - let mut global = Client::new(0,0, 800, 600, KeyboardLayout::US, "foo"); + let mut stream = Cursor::new(vec![ + 22, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 22, 0, 31, 0, 0, 0, 1, 0, 0, 0, + ]); + let mut global = Client::new(0, 0, 800, 600, KeyboardLayout::US, "foo"); assert!(global.read_synchronize_pdu(&mut stream).unwrap()) } #[test] fn test_read_control_cooperate_pdu() { - let mut stream = Cursor::new(vec![26, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 26, 0, 20, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); - let mut global = Client::new(0,0, 800, 600, KeyboardLayout::US, "foo"); - assert!(global.read_control_pdu(&mut stream, Action::CtrlactionCooperate).unwrap()) + let mut stream = Cursor::new(vec![ + 26, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 26, 0, 20, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, + ]); + let mut global = Client::new(0, 0, 800, 600, KeyboardLayout::US, "foo"); + assert!(global + .read_control_pdu(&mut stream, Action::CtrlactionCooperate) + .unwrap()) } #[test] fn test_read_control_granted_pdu() { - let mut stream = Cursor::new(vec![26, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 26, 0, 20, 0, 0, 0, 2, 0, 236, 3, 234, 3, 0, 0]); - let mut global = Client::new(0,0, 800, 600, KeyboardLayout::US, "foo"); - assert!(global.read_control_pdu(&mut stream, Action::CtrlactionGrantedControl).unwrap()) + let mut stream = Cursor::new(vec![ + 26, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 26, 0, 20, 0, 0, 0, 2, 0, 236, 3, 234, 3, 0, + 0, + ]); + let mut global = Client::new(0, 0, 800, 600, KeyboardLayout::US, "foo"); + assert!(global + .read_control_pdu(&mut stream, Action::CtrlactionGrantedControl) + .unwrap()) } #[test] fn test_read_font_map_pdu() { - let mut stream = Cursor::new(vec![26, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 26, 0, 40, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0]); - let mut global = Client::new(0,0, 800, 600, KeyboardLayout::US, "foo"); + let mut stream = Cursor::new(vec![ + 26, 0, 23, 0, 234, 3, 234, 3, 1, 0, 0, 2, 26, 0, 40, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, + ]); + let mut global = Client::new(0, 0, 800, 600, KeyboardLayout::US, "foo"); assert!(global.read_font_map_pdu(&mut stream).unwrap()) } -} \ No newline at end of file +} diff --git a/src/core/license.rs b/src/core/license.rs index 960c6b1..8a57e45 100644 --- a/src/core/license.rs +++ b/src/core/license.rs @@ -1,12 +1,12 @@ -use model::data::{Component, Check, DynOption, U16, MessageOption, U32, DataType, Message}; -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; -use std::io::{Cursor, Read}; +use crate::model::data::{Check, Component, DataType, DynOption, Message, MessageOption, U16, U32}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use num_enum::TryFromPrimitive; use std::convert::TryFrom; +use std::io::{Cursor, Read}; pub enum LicenseMessage { NewLicense, - ErrorAlert(Component) + ErrorAlert(Component), } /// License preambule @@ -16,7 +16,7 @@ pub enum LicenseMessage { enum Preambule { PreambleVersion20 = 0x2, PreambleVersion30 = 0x3, - ExtendedErrorMsgSupported = 0x80 + ExtendedErrorMsgSupported = 0x80, } /// All type of message @@ -32,7 +32,7 @@ pub enum MessageType { LicenseInfo = 0x12, NewLicenseRequest = 0x13, PlatformChallengeResponse = 0x15, - ErrorAlert = 0xFF + ErrorAlert = 0xFF, } /// Error code of the license automata @@ -48,7 +48,7 @@ pub enum ErrorCode { ErrInvalidClient = 0x00000008, ErrInvalidProductid = 0x0000000B, ErrInvalidMessageLen = 0x0000000C, - ErrInvalidMac = 0x00000003 + ErrInvalidMac = 0x00000003, } /// All valid state transition available @@ -60,7 +60,7 @@ pub enum StateTransition { StTotalAbort = 0x00000001, StNoTransition = 0x00000002, StResetPhaseToStart = 0x00000003, - StResendLastMessage = 0x00000004 + StResendLastMessage = 0x00000004, } /// This a license preamble @@ -94,7 +94,6 @@ fn licensing_error_message() -> Component { ] } - /// Parse a payload that follow an preamble /// Actualle we only accept payload with type NewLicense or ErrorAlert fn parse_payload(payload: &Component) -> RdpResult { @@ -106,7 +105,10 @@ fn parse_payload(payload: &Component) -> RdpResult { message.read(&mut stream)?; Ok(LicenseMessage::ErrorAlert(message)) } - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, "Licensing nego not implemented"))) + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + "Licensing nego not implemented", + ))), } } @@ -119,19 +121,24 @@ fn parse_payload(payload: &Component) -> RdpResult { /// ``` /// ``` pub fn client_connect(s: &mut dyn Read) -> RdpResult<()> { - let mut license_message = preamble(); license_message.read(s)?; match parse_payload(&license_message)? { LicenseMessage::NewLicense => Ok(()), LicenseMessage::ErrorAlert(blob) => { - if ErrorCode::try_from(cast!(DataType::U32, blob["dwErrorCode"])?)? == ErrorCode::StatusValidClient && - StateTransition::try_from(cast!(DataType::U32, blob["dwStateTransition"])?)? == StateTransition::StNoTransition { + if ErrorCode::try_from(cast!(DataType::U32, blob["dwErrorCode"])?)? + == ErrorCode::StatusValidClient + && StateTransition::try_from(cast!(DataType::U32, blob["dwStateTransition"])?)? + == StateTransition::StNoTransition + { Ok(()) } else { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidRespond, "Server reject license, Actually license nego is not implemented"))) + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidRespond, + "Server reject license, Actually license nego is not implemented", + ))) } } } -} \ No newline at end of file +} diff --git a/src/core/mcs.rs b/src/core/mcs.rs index 8fd6128..c2615e2 100644 --- a/src/core/mcs.rs +++ b/src/core/mcs.rs @@ -1,13 +1,19 @@ -use core::x224; -use core::tpkt; -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; -use core::gcc::{KeyboardLayout, client_core_data, ClientData, ServerData, client_security_data, client_network_data, block_header, write_conference_create_request, MessageType, read_conference_create_response, Version}; -use model::data::{Trame, to_vec, Message, DataType, U16}; -use nla::asn1::{Sequence, ImplicitTag, OctetString, Enumerate, ASN1Type, Integer, to_der, from_ber}; -use yasna::{Tag}; -use std::io::{Write, Read, BufRead, Cursor}; -use core::per; +use crate::core::gcc::{ + block_header, client_core_data, client_network_data, client_security_data, + read_conference_create_response, write_conference_create_request, ClientData, KeyboardLayout, + MessageType, ServerData, Version, +}; +use crate::core::per; +use crate::core::tpkt; +use crate::core::x224; +use crate::model::data::{to_vec, DataType, Message, Trame, U16}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::nla::asn1::{ + from_ber, to_der, ASN1Type, Enumerate, ImplicitTag, Integer, OctetString, Sequence, +}; use std::collections::HashMap; +use std::io::{BufRead, Cursor, Read, Write}; +use yasna::Tag; #[allow(dead_code)] #[repr(u8)] @@ -19,14 +25,21 @@ enum DomainMCSPDU { ChannelJoinRequest = 14, ChannelJoinConfirm = 15, SendDataRequest = 25, - SendDataIndication = 26 + SendDataIndication = 26, } /// ASN1 structure use by mcs layer /// to inform on conference capability -fn domain_parameters(max_channel_ids: u32, maw_user_ids: u32, max_token_ids: u32, - num_priorities: u32, min_thoughput: u32, max_height: u32, - max_mcs_pdu_size: u32, protocol_version: u32) -> Sequence { +fn domain_parameters( + max_channel_ids: u32, + maw_user_ids: u32, + max_token_ids: u32, + num_priorities: u32, + min_thoughput: u32, + max_height: u32, + max_mcs_pdu_size: u32, + protocol_version: u32, +) -> Sequence { sequence![ "maxChannelIds" => max_channel_ids, "maxUserIds" => maw_user_ids, @@ -44,26 +57,31 @@ fn domain_parameters(max_channel_ids: u32, maw_user_ids: u32, max_token_ids: u32 /// /// http://www.itu.int/rec/T-REC-T.125-199802-I/en page 25 fn connect_initial(user_data: Option) -> ImplicitTag { - ImplicitTag::new(Tag::application(101), sequence![ - "callingDomainSelector" => vec![1 as u8] as OctetString, - "calledDomainSelector" => vec![1 as u8] as OctetString, - "upwardFlag" => true, - "targetParameters" => domain_parameters(34, 2, 0, 1, 0, 1, 0xffff, 2), - "minimumParameters" => domain_parameters(1, 1, 1, 1, 0, 1, 0x420, 2), - "maximumParameters" => domain_parameters(0xffff, 0xfc17, 0xffff, 1, 0, 1, 0xffff, 2), - "userData" => user_data.unwrap_or(Vec::new()) - ]) + ImplicitTag::new( + Tag::application(101), + sequence![ + "callingDomainSelector" => vec![1 as u8] as OctetString, + "calledDomainSelector" => vec![1 as u8] as OctetString, + "upwardFlag" => true, + "targetParameters" => domain_parameters(34, 2, 0, 1, 0, 1, 0xffff, 2), + "minimumParameters" => domain_parameters(1, 1, 1, 1, 0, 1, 0x420, 2), + "maximumParameters" => domain_parameters(0xffff, 0xfc17, 0xffff, 1, 0, 1, 0xffff, 2), + "userData" => user_data.unwrap_or(Vec::new()) + ], + ) } /// Server response with channel capacity fn connect_response(user_data: Option) -> ImplicitTag { - ImplicitTag::new(Tag::application(102), -sequence![ - "result" => 0 as Enumerate, - "calledConnectId" => 0 as Integer, - "domainParameters" => domain_parameters(22, 3, 0, 1, 0, 1,0xfff8, 2), - "userData" => user_data.unwrap_or(Vec::new()) - ]) + ImplicitTag::new( + Tag::application(102), + sequence![ + "result" => 0 as Enumerate, + "calledConnectId" => 0 as Integer, + "domainParameters" => domain_parameters(22, 3, 0, 1, 0, 1,0xfff8, 2), + "userData" => user_data.unwrap_or(Vec::new()) + ], + ) } /// Create a basic MCS PDU header @@ -77,13 +95,21 @@ fn mcs_pdu_header(pdu: Option, options: Option) -> u8 { fn read_attach_user_confirm(buffer: &mut dyn Read) -> RdpResult { let mut confirm = trame![0 as u8, Vec::::new()]; confirm.read(buffer)?; - if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::AttachUserConfirm), None) >> 2 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: unexpected header on recv_attach_user_confirm"))); + if cast!(DataType::U8, confirm[0])? >> 2 + != mcs_pdu_header(Some(DomainMCSPDU::AttachUserConfirm), None) >> 2 + { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "MCS: unexpected header on recv_attach_user_confirm", + ))); } let mut request = Cursor::new(cast!(DataType::Slice, confirm[1])?); if per::read_enumerates(&mut request)? != 0 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::RejectedByServer, "MCS: recv_attach_user_confirm user rejected by server"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::RejectedByServer, + "MCS: recv_attach_user_confirm user rejected by server", + ))); } Ok(per::read_integer_16(1001, &mut request)?) } @@ -96,7 +122,6 @@ fn attach_user_request() -> u8 { mcs_pdu_header(Some(DomainMCSPDU::AttachUserRequest), None) } - /// Create a new domain for MCS layer fn erect_domain_request() -> RdpResult { let mut result = Cursor::new(vec![]); @@ -132,11 +157,20 @@ fn channel_join_request(user_id: Option, channel_id: Option) -> RdpRes /// /// Client -- channel_join_request -> Server /// Client <- channel_join_confirm -- Server -fn read_channel_join_confirm(user_id: u16, channel_id: u16, buffer: &mut dyn Read) -> RdpResult { +fn read_channel_join_confirm( + user_id: u16, + channel_id: u16, + buffer: &mut dyn Read, +) -> RdpResult { let mut confirm = trame![0 as u8, Vec::::new()]; confirm.read(buffer)?; - if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinConfirm), None) >> 2 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: unexpected header on read_channel_join_confirm"))); + if cast!(DataType::U8, confirm[0])? >> 2 + != mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinConfirm), None) >> 2 + { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "MCS: unexpected header on read_channel_join_confirm", + ))); } let mut request = Cursor::new(cast!(DataType::Slice, confirm[1])?); @@ -145,11 +179,17 @@ fn read_channel_join_confirm(user_id: u16, channel_id: u16, buffer: &mut dyn Rea let confirm_channel_id = per::read_integer_16(0, &mut request)?; if user_id != confirm_user_id { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: read_channel_join_confirm invalid user id"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "MCS: read_channel_join_confirm invalid user id", + ))); } if channel_id != confirm_channel_id { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: read_channel_join_confirm invalid channel_id"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "MCS: read_channel_join_confirm invalid channel_id", + ))); } Ok(confirm == 0) @@ -164,7 +204,7 @@ pub struct Client { /// User id session negotiated by the MCS user_id: Option, /// Map that translate channel name to channel id - channel_ids : HashMap + channel_ids: HashMap, } impl Client { @@ -173,28 +213,52 @@ impl Client { server_data: None, x224, user_id: None, - channel_ids: HashMap::new() + channel_ids: HashMap::new(), } } /// Write connection initial payload /// This payload include a lot of /// client specific config parameters - fn write_connect_initial(&mut self, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout, client_name: String) -> RdpResult<()> { + fn write_connect_initial( + &mut self, + screen_width: u16, + screen_height: u16, + keyboard_layout: KeyboardLayout, + client_name: String, + ) -> RdpResult<()> { let client_core_data = client_core_data(Some(ClientData { width: screen_width, height: screen_height, layout: keyboard_layout, server_selected_protocol: self.x224.get_selected_protocols() as u32, rdp_version: Version::RdpVersion5plus, - name: client_name + name: client_name, })); let client_security_data = client_security_data(); let client_network_data = client_network_data(trame![]); let user_data = to_vec(&trame![ - trame![block_header(Some(MessageType::CsCore), Some(client_core_data.length() as u16)), client_core_data], - trame![block_header(Some(MessageType::CsSecurity), Some(client_security_data.length() as u16)), client_security_data], - trame![block_header(Some(MessageType::CsNet), Some(client_network_data.length() as u16)), client_network_data] + trame![ + block_header( + Some(MessageType::CsCore), + Some(client_core_data.length() as u16) + ), + client_core_data + ], + trame![ + block_header( + Some(MessageType::CsSecurity), + Some(client_security_data.length() as u16) + ), + client_security_data + ], + trame![ + block_header( + Some(MessageType::CsNet), + Some(client_network_data.length() as u16) + ), + client_network_data + ] ]); let conference = write_conference_create_request(&user_data)?; self.x224.write(to_der(&connect_initial(Some(conference)))) @@ -210,7 +274,9 @@ impl Client { // Get server data // Read conference create response let cc_response = cast!(ASN1Type::OctetString, connect_response.inner["userData"])?; - self.server_data = Some(read_conference_create_response(&mut Cursor::new(cc_response))?); + self.server_data = Some(read_conference_create_response(&mut Cursor::new( + cc_response, + ))?); Ok(()) } @@ -223,23 +289,38 @@ impl Client { /// let mut mcs = mcs::Client(x224); /// mcs.connect(800, 600, KeyboardLayout::French).unwrap() /// ``` - pub fn connect(&mut self, client_name: String, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout) -> RdpResult<()> { + pub fn connect( + &mut self, + client_name: String, + screen_width: u16, + screen_height: u16, + keyboard_layout: KeyboardLayout, + ) -> RdpResult<()> { self.write_connect_initial(screen_width, screen_height, keyboard_layout, client_name)?; self.read_connect_response()?; self.x224.write(erect_domain_request()?)?; self.x224.write(attach_user_request())?; - self.user_id = Some(read_attach_user_confirm(&mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?)?); + self.user_id = Some(read_attach_user_confirm(&mut try_let!( + tpkt::Payload::Raw, + self.x224.read()? + )?)?); // Add static channel self.channel_ids.insert("global".to_string(), 1003); - self.channel_ids.insert("user".to_string(), self.user_id.unwrap()); + self.channel_ids + .insert("user".to_string(), self.user_id.unwrap()); // Create list of requested channels // Actually only the two static main channel are requested for channel_id in self.channel_ids.values() { - self.x224.write(channel_join_request(self.user_id, Some(*channel_id))?)?; - if !read_channel_join_confirm(self.user_id.unwrap(), *channel_id, &mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?)? { + self.x224 + .write(channel_join_request(self.user_id, Some(*channel_id))?)?; + if !read_channel_join_confirm( + self.user_id.unwrap(), + *channel_id, + &mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?, + )? { println!("Server reject channel id {:?}", channel_id); } } @@ -258,7 +339,9 @@ impl Client { /// mcs.write("global".to_string(), trame![U16::LE(0)]) /// ``` pub fn write(&mut self, channel_name: &String, message: T) -> RdpResult<()> - where T: Message { + where + T: Message, + { self.x224.write(trame![ mcs_pdu_header(Some(DomainMCSPDU::SendDataRequest), None), U16::BE(self.user_id.unwrap() - 1001), @@ -287,33 +370,48 @@ impl Client { let message = self.x224.read()?; match message { tpkt::Payload::Raw(mut payload) => { - let mut header = mcs_pdu_header(None, None); + let mut header = mcs_pdu_header(None, None); header.read(&mut payload)?; if header >> 2 == DomainMCSPDU::DisconnectProviderUltimatum as u8 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::Disconnect, "MCS: Disconnect Provider Ultimatum"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::Disconnect, + "MCS: Disconnect Provider Ultimatum", + ))); } if header >> 2 != DomainMCSPDU::SendDataIndication as u8 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: Invalid opcode"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "MCS: Invalid opcode", + ))); } // Server user id per::read_integer_16(1001, &mut payload)?; let channel_id = per::read_integer_16(0, &mut payload)?; - let channel = self.channel_ids.iter().find(|x| *x.1 == channel_id).ok_or(Error::RdpError(RdpError::new(RdpErrorKind::Unknown, "MCS: unknown channel")))?; + let channel = + self.channel_ids + .iter() + .find(|x| *x.1 == channel_id) + .ok_or(Error::RdpError(RdpError::new( + RdpErrorKind::Unknown, + "MCS: unknown channel", + )))?; per::read_enumerates(&mut payload)?; per::read_length(&mut payload)?; Ok((channel.0.clone(), tpkt::Payload::Raw(payload))) - }, + } tpkt::Payload::FastPath(sec_flag, payload) => { // fastpath packet are dedicated to global channel - Ok(("global".to_string(), tpkt::Payload::FastPath(sec_flag, payload))) + Ok(( + "global".to_string(), + tpkt::Payload::FastPath(sec_flag, payload), + )) } } - } /// Send a close event to server @@ -350,7 +448,10 @@ mod test { /// Test of read read_attach_user_confirm #[test] fn test_read_attach_user_confirm() { - assert_eq!(read_attach_user_confirm(&mut Cursor::new(vec![46, 0, 0, 3])).unwrap(), 1004) + assert_eq!( + read_attach_user_confirm(&mut Cursor::new(vec![46, 0, 0, 3])).unwrap(), + 1004 + ) } /// Attach user request payload @@ -368,27 +469,48 @@ mod test { /// Test format of the channel join request #[test] fn test_channel_join_request() { - assert_eq!(to_vec(&channel_join_request(None, None).unwrap()), [56, 0, 0, 0, 0]) + assert_eq!( + to_vec(&channel_join_request(None, None).unwrap()), + [56, 0, 0, 0, 0] + ) } /// Test domain parameters format #[test] fn test_domain_parameters() { - let result = to_der(&domain_parameters(1,2,3,4, 5, 6, 7, 8)); - assert_eq!(result, vec![48, 24, 2, 1, 1, 2, 1, 2, 2, 1, 3, 2, 1, 4, 2, 1, 5, 2, 1, 6, 2, 1, 7, 2, 1, 8]) + let result = to_der(&domain_parameters(1, 2, 3, 4, 5, 6, 7, 8)); + assert_eq!( + result, + vec![48, 24, 2, 1, 1, 2, 1, 2, 2, 1, 3, 2, 1, 4, 2, 1, 5, 2, 1, 6, 2, 1, 7, 2, 1, 8] + ) } /// Test connect initial #[test] fn test_connect_initial() { let result = to_der(&connect_initial(Some(vec![1, 2, 3]))); - assert_eq!(result, vec![127, 101, 103, 4, 1, 1, 4, 1, 1, 1, 1, 255, 48, 26, 2, 1, 34, 2, 1, 2, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 48, 25, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 2, 4, 32, 2, 1, 2, 48, 32, 2, 3, 0, 255, 255, 2, 3, 0, 252, 23, 2, 3, 0, 255, 255, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 4, 3, 1, 2, 3]) + assert_eq!( + result, + vec![ + 127, 101, 103, 4, 1, 1, 4, 1, 1, 1, 1, 255, 48, 26, 2, 1, 34, 2, 1, 2, 2, 1, 0, 2, + 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 48, 25, 2, 1, 1, 2, 1, 1, 2, 1, + 1, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 2, 4, 32, 2, 1, 2, 48, 32, 2, 3, 0, 255, 255, 2, + 3, 0, 252, 23, 2, 3, 0, 255, 255, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, + 1, 2, 4, 3, 1, 2, 3 + ] + ) } /// Test connect response #[test] fn test_connect_response() { let result = to_der(&connect_response(Some(vec![1, 2, 3]))); - assert_eq!(result, vec![127, 102, 39, 10, 1, 0, 2, 1, 0, 48, 26, 2, 1, 22, 2, 1, 3, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 248, 2, 1, 2, 4, 3, 1, 2, 3]) + assert_eq!( + result, + vec![ + 127, 102, 39, 10, 1, 0, 2, 1, 0, 48, 26, 2, 1, 22, 2, 1, 3, 2, 1, 0, 2, 1, 1, 2, 1, + 0, 2, 1, 1, 2, 3, 0, 255, 248, 2, 1, 2, 4, 3, 1, 2, 3 + ] + ) } -} \ No newline at end of file +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 329d6d4..357171e 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,11 +1,11 @@ -pub mod tpkt; -pub mod x224; +pub mod capability; pub mod client; -pub mod mcs; +pub mod event; pub mod gcc; +pub mod global; +pub mod license; +pub mod mcs; pub mod per; pub mod sec; -pub mod license; -pub mod global; -pub mod capability; -pub mod event; \ No newline at end of file +pub mod tpkt; +pub mod x224; diff --git a/src/core/per.rs b/src/core/per.rs index 453073c..6d39b42 100644 --- a/src/core/per.rs +++ b/src/core/per.rs @@ -1,7 +1,6 @@ -use model::data::{Message, U16, Trame, U32}; +use crate::model::data::{Message, Trame, U16, U32}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use std::io::{Read, Write}; -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; - /// PER encoding length /// read length of following payload @@ -19,12 +18,11 @@ pub fn read_length(s: &mut dyn Read) -> RdpResult { byte.read(s)?; if byte & 0x80 != 0 { byte = byte & !0x80; - let mut size = (byte as u16) << 8 ; + let mut size = (byte as u16) << 8; byte.read(s)?; size += byte as u16; Ok(size) - } - else { + } else { Ok(byte as u16) } } @@ -45,8 +43,7 @@ pub fn read_length(s: &mut dyn Read) -> RdpResult { pub fn write_length(length: u16) -> RdpResult { if length > 0x7f { Ok(trame![U16::BE(length | 0x8000)]) - } - else { + } else { Ok(trame![length as u8]) } } @@ -61,7 +58,7 @@ pub fn write_length(length: u16) -> RdpResult { /// assert_eq!(read_choice(&mut s).unwrap(), 1) /// ``` pub fn read_choice(s: &mut dyn Read) -> RdpResult { - let mut result : u8 = 0; + let mut result: u8 = 0; result.read(s)?; Ok(result) } @@ -92,7 +89,7 @@ pub fn write_choice(choice: u8, s: &mut dyn Write) -> RdpResult<()> { /// assert_eq!(read_selection(&mut s).unwrap(), 1) /// ``` pub fn read_selection(s: &mut dyn Read) -> RdpResult { - let mut result : u8 = 0; + let mut result: u8 = 0; result.read(s)?; Ok(result) } @@ -123,7 +120,7 @@ pub fn write_selection(selection: u8, s: &mut dyn Write) -> RdpResult<()> { /// assert_eq!(read_number_of_set(&mut s).unwrap(), 1) /// ``` pub fn read_number_of_set(s: &mut dyn Read) -> RdpResult { - let mut result : u8 = 0; + let mut result: u8 = 0; result.read(s)?; Ok(result) } @@ -154,7 +151,7 @@ pub fn write_number_of_set(number_of_set: u8, s: &mut dyn Write) -> RdpResult<() /// assert_eq!(read_enumerates(&mut s).unwrap(), 1) /// ``` pub fn read_enumerates(s: &mut dyn Read) -> RdpResult { - let mut result : u8 = 0; + let mut result: u8 = 0; result.read(s)?; Ok(result) } @@ -195,18 +192,21 @@ pub fn read_integer(s: &mut dyn Read) -> RdpResult { let mut result: u8 = 0; result.read(s)?; Ok(result as u32) - }, + } 2 => { let mut result = U16::BE(0); result.read(s)?; Ok(result.inner() as u32) - }, + } 4 => { let mut result = U32::BE(0); result.read(s)?; Ok(result.inner() as u32) - }, - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "PER integer encoded with an invalid size"))) + } + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "PER integer encoded with an invalid size", + ))), } } @@ -240,7 +240,6 @@ pub fn write_integer(integer: u32, s: &mut dyn Write) -> RdpResult<()> { Ok(()) } - /// Read u16 integer PER encoded /// /// # Example @@ -271,7 +270,6 @@ pub fn write_integer_16(integer: u16, minimum: u16, s: &mut dyn Write) -> RdpRes Ok(()) } - /// Read an object identifier encoded in PER /// /// # Example @@ -287,16 +285,22 @@ pub fn write_integer_16(integer: u16, minimum: u16, s: &mut dyn Write) -> RdpRes /// ``` pub fn read_object_identifier(oid: &[u8], s: &mut dyn Read) -> RdpResult { if oid.len() != 6 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "Oid to check have an invalid size"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Oid to check have an invalid size", + ))); } let length = read_length(s)?; if length != 5 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "Oid source have an invalid size"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Oid source have an invalid size", + ))); } let mut oid_parsed = [0; 6]; - let mut tmp : u8 = 0; + let mut tmp: u8 = 0; tmp.read(s)?; oid_parsed[0] = tmp >> 4; @@ -323,9 +327,12 @@ pub fn read_object_identifier(oid: &[u8], s: &mut dyn Read) -> RdpResult { /// write_object_identifier(&[1, 2, 3, 4, 5, 6], &mut s).unwrap(); /// assert_eq!(s.into_inner(), [5, 0x12, 3, 4, 5, 6]); /// ``` -pub fn write_object_identifier(oid: &[u8], s: &mut dyn Write) ->RdpResult<()> { +pub fn write_object_identifier(oid: &[u8], s: &mut dyn Write) -> RdpResult<()> { if oid.len() != 6 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "PER: oid source don't have the correct size"))) + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "PER: oid source don't have the correct size", + ))); } trame![ @@ -335,7 +342,8 @@ pub fn write_object_identifier(oid: &[u8], s: &mut dyn Write) ->RdpResult<()> { oid[3], oid[4], oid[5] - ].write(s) + ] + .write(s) } /// Read a numeric string @@ -354,7 +362,7 @@ pub fn read_numeric_string(minimum: usize, s: &mut dyn Read) -> RdpResult RdpResult<()> { +pub fn write_numeric_string(string: &[u8], minimum: usize, s: &mut dyn Write) -> RdpResult<()> { let mut length = string.len(); if length as i64 - minimum as i64 >= 0 { length -= minimum; @@ -365,7 +373,7 @@ pub fn write_numeric_string(string: &[u8], minimum: usize, s: &mut dyn Write) - for i in 0..string.len() { let mut c1 = string[i]; let mut c2 = if i + 1 < string.len() { - string[i+1] + string[i + 1] } else { 0x30 }; @@ -404,13 +412,19 @@ pub fn write_padding(length: usize, s: &mut dyn Write) -> RdpResult<()> { pub fn read_octet_stream(octet_stream: &[u8], minimum: usize, s: &mut dyn Read) -> RdpResult<()> { let length = read_length(s)? as usize + minimum; if length != octet_stream.len() { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "PER: source octet string have an invalid size"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "PER: source octet string have an invalid size", + ))); } for i in 0..length { let mut c: u8 = 0; c.read(s)?; if c != octet_stream[i] { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "PER: source octet string have an invalid char"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "PER: source octet string have an invalid char", + ))); } } @@ -427,4 +441,4 @@ pub fn write_octet_stream(octet_string: &[u8], minimum: usize, s: &mut dyn Write octet_string.to_vec().write(s)?; Ok(()) -} \ No newline at end of file +} diff --git a/src/core/sec.rs b/src/core/sec.rs index 606cd05..8b0d0cb 100644 --- a/src/core/sec.rs +++ b/src/core/sec.rs @@ -1,10 +1,10 @@ -use core::mcs; -use core::license; -use core::tpkt; -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; -use model::data::{Message, Component, U16, U32, DynOption, MessageOption, Trame, DataType}; -use std::io::{Write, Read}; -use model::unicode::Unicode; +use crate::core::license; +use crate::core::mcs; +use crate::core::tpkt; +use crate::model::data::{Component, DataType, DynOption, Message, MessageOption, Trame, U16, U32}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::model::unicode::Unicode; +use std::io::{Read, Write}; /// Security flag send as header flage in core ptotocol /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e13405c5-668b-4716-94b2-1c2654ca1ad4?redirectedfrom=MSDN @@ -25,7 +25,7 @@ enum SecurityFlag { SecAutodetectReq = 0x1000, SecAutodetectRsp = 0x2000, SecHeartbeat = 0x4000, - SecFlagshiValid = 0x8000 + SecFlagshiValid = 0x8000, } /// RDP option someone links to capabilities @@ -50,13 +50,13 @@ enum InfoFlag { InfoUsingSavedCreds = 0x00100000, InfoAudiocapture = 0x00200000, InfoVideoDisable = 0x00400000, - InfoCompressionTypeMask = 0x00001E00 + InfoCompressionTypeMask = 0x00001E00, } #[allow(dead_code)] enum AfInet { AfInet = 0x00002, - AfInet6 = 0x0017 + AfInet6 = 0x0017, } /// On RDP version > 5 @@ -77,7 +77,13 @@ fn rdp_extended_infos() -> Component { /// When CSSP is not used /// interactive logon used credentials /// present in this payload -fn rdp_infos(is_extended_info: bool, domain: &String, username: &String, password: &String, auto_logon: bool) -> Component { +fn rdp_infos( + is_extended_info: bool, + domain: &String, + username: &String, + password: &String, + auto_logon: bool, +) -> Component { let mut domain_format = domain.to_unicode(); domain_format.push(0); domain_format.push(0); @@ -123,7 +129,6 @@ fn security_header() -> Component { ] } - /// Security layer need mcs layer and send all message through /// the global channel /// @@ -136,7 +141,13 @@ fn security_header() -> Component { /// let mut mcs = mcs::Client(...).unwrap(); /// sec::connect(&mut mcs).unwrap(); /// ``` -pub fn connect(mcs: &mut mcs::Client, domain: &String, username: &String, password: &String, auto_logon: bool) -> RdpResult<()> { +pub fn connect( + mcs: &mut mcs::Client, + domain: &String, + username: &String, + password: &String, + auto_logon: bool, +) -> RdpResult<()> { mcs.write( &"global".to_string(), trame![ @@ -149,7 +160,7 @@ pub fn connect(mcs: &mut mcs::Client, domain: &String, usern password, auto_logon ) - ] + ], )?; let (_channel_name, payload) = mcs.read()?; @@ -157,11 +168,12 @@ pub fn connect(mcs: &mut mcs::Client, domain: &String, usern let mut header = security_header(); header.read(&mut stream)?; if cast!(DataType::U16, header["securityFlag"])? & SecurityFlag::SecLicensePkt as u16 == 0 { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "SEC: Invalid Licence packet"))); + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "SEC: Invalid Licence packet", + ))); } license::client_connect(&mut stream)?; Ok(()) } - - diff --git a/src/core/tpkt.rs b/src/core/tpkt.rs index cb06e37..49534d7 100644 --- a/src/core/tpkt.rs +++ b/src/core/tpkt.rs @@ -1,14 +1,14 @@ -use model::link::{Link}; -use model::data::{Message, U16, Component, Trame}; -use model::error::{RdpResult, RdpError, RdpErrorKind, Error}; -use std::io::{Cursor, Write, Read}; -use nla::cssp::cssp_connect; -use nla::sspi::AuthenticationProtocol; +use crate::model::data::{Component, Message, Trame, U16}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::model::link::Link; +use crate::nla::cssp::cssp_connect; +use crate::nla::sspi::AuthenticationProtocol; +use std::io::{Cursor, Read, Write}; /// TPKT must implement this two kind of payload pub enum Payload { Raw(Cursor>), - FastPath(u8, Cursor>) + FastPath(u8, Cursor>), } /// TPKT action header @@ -17,7 +17,7 @@ pub enum Payload { #[derive(Copy, Clone)] pub enum Action { FastPathActionFastPath = 0x0, - FastPathActionX224 = 0x3 + FastPathActionX224 = 0x3, } /// TPKT layer header @@ -42,15 +42,13 @@ fn tpkt_header(size: u16) -> Component { /// let tpkt_client = Client::new(Link::new(Stream::Raw(stream))); /// ``` pub struct Client { - transport: Link + transport: Link, } impl Client { /// Ctor of TPKT client layer - pub fn new (transport: Link) -> Self { - Client { - transport - } + pub fn new(transport: Link) -> Self { + Client { transport } } /// Send a message to the link layer @@ -79,13 +77,11 @@ impl Client { /// } /// ``` pub fn write(&mut self, message: T) -> RdpResult<()> - where T: Message { - self.transport.write( - &trame![ - tpkt_header(message.length() as u16), - message - ] - ) + where + T: Message, + { + self.transport + .write(&trame![tpkt_header(message.length() as u16), message]) } /// Read a payload from the underlying layer @@ -126,7 +122,6 @@ impl Client { let mut action: u8 = 0; action.read(&mut buffer)?; if action == Action::FastPathActionX224 as u8 { - // read padding let mut padding: u8 = 0; padding.read(&mut buffer)?; @@ -139,11 +134,15 @@ impl Client { // Minimal size must be 7 // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/18a27ef9-6f9a-4501-b000-94b1fe3c2c10 if size.inner() < 4 { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "Invalid minimal size for TPKT"))) - } - else { + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Invalid minimal size for TPKT", + ))) + } else { // now wait for body - Ok(Payload::Raw(Cursor::new(self.transport.read(size.inner() as usize - 4)?))) + Ok(Payload::Raw(Cursor::new( + self.transport.read(size.inner() as usize - 4)?, + ))) } } else { // fast path @@ -156,19 +155,30 @@ impl Client { let length: u16 = ((short_length & !0x80) as u16) << 8; let length = length | hi_length as u16; if length < 3 { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "Invalid minimal size for TPKT"))) + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Invalid minimal size for TPKT", + ))) } else { - Ok(Payload::FastPath(sec_flag, Cursor::new(self.transport.read(length as usize - 3)?))) + Ok(Payload::FastPath( + sec_flag, + Cursor::new(self.transport.read(length as usize - 3)?), + )) } - } - else { + } else { if short_length < 2 { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidSize, "Invalid minimal size for TPKT"))) + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Invalid minimal size for TPKT", + ))) } else { - Ok(Payload::FastPath(sec_flag, Cursor::new(self.transport.read(short_length as usize - 2)?))) + Ok(Payload::FastPath( + sec_flag, + Cursor::new(self.transport.read(short_length as usize - 2)?), + )) } } - } + } } /// This function transform the link layer with @@ -202,7 +212,12 @@ impl Client { /// let mut tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(tcp))); /// let mut tpkt_nla = tpkt.start_nla(false, &mut Ntlm::new("domain".to_string(), "username".to_string(), "password".to_string()), false); /// ``` - pub fn start_nla(self, check_certificate: bool, authentication_protocol: &mut dyn AuthenticationProtocol, restricted_admin_mode: bool) -> RdpResult> { + pub fn start_nla( + self, + check_certificate: bool, + authentication_protocol: &mut dyn AuthenticationProtocol, + restricted_admin_mode: bool, + ) -> RdpResult> { let mut link = self.transport.start_ssl(check_certificate)?; cssp_connect(&mut link, authentication_protocol, restricted_admin_mode)?; Ok(Client::new(link)) @@ -222,18 +237,15 @@ impl Client { #[cfg(test)] mod test { use super::*; + use crate::model::data::{DataType, U32}; + use crate::model::link::Stream; use std::io::Cursor; - use model::data::{U32, DataType}; - use model::link::Stream; /// Test the tpkt header type in write context #[test] fn test_write_tpkt_header() { let x = U32::BE(1); - let message = trame![ - tpkt_header(x.length() as u16), - x - ]; + let message = trame![tpkt_header(x.length() as u16), x]; let mut buffer = Cursor::new(Vec::new()); message.write(&mut buffer).unwrap(); assert_eq!(buffer.get_ref().as_slice(), [3, 0, 0, 8, 0, 0, 0, 1]); @@ -242,11 +254,14 @@ mod test { /// Test read of TPKT header #[test] fn test_read_tpkt_header() { - let mut message = tpkt_header(0); + let mut message = tpkt_header(0); let mut buffer = Cursor::new([3, 0, 0, 8, 0, 0, 0, 1]); message.read(&mut buffer).unwrap(); assert_eq!(cast!(DataType::U16, message["size"]).unwrap(), 8); - assert_eq!(cast!(DataType::U8, message["action"]).unwrap(), Action::FastPathActionX224 as u8); + assert_eq!( + cast!(DataType::U8, message["action"]).unwrap(), + Action::FastPathActionX224 as u8 + ); } fn process(data: &[u8]) { diff --git a/src/core/x224.rs b/src/core/x224.rs index 9a19062..f4db4a8 100644 --- a/src/core/x224.rs +++ b/src/core/x224.rs @@ -1,11 +1,11 @@ -use core::tpkt; -use model::data::{Message, Check, U16, U32, Component, DataType, Trame}; -use model::error::{Error, RdpError, RdpResult, RdpErrorKind}; -use std::io::{Read, Write}; -use std::option::{Option}; -use nla::sspi::AuthenticationProtocol; +use crate::core::tpkt; +use crate::model::data::{Check, Component, DataType, Message, Trame, U16, U32}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::nla::sspi::AuthenticationProtocol; use num_enum::TryFromPrimitive; use std::convert::TryFrom; +use std::io::{Read, Write}; +use std::option::Option; #[repr(u8)] #[derive(Copy, Clone, TryFromPrimitive)] @@ -19,7 +19,7 @@ pub enum NegotiationType { /// Negotiation failure /// Send when security level are not expected /// Server ask for NLA and client doesn't support it - TypeRDPNegFailure = 0x03 + TypeRDPNegFailure = 0x03, } #[repr(u32)] @@ -33,7 +33,7 @@ pub enum Protocols { /// Network Level Authentication over SSL ProtocolHybrid = 0x02, /// NLA + SSL + Quick respond - ProtocolHybridEx = 0x08 + ProtocolHybridEx = 0x08, } #[derive(Copy, Clone)] @@ -42,7 +42,7 @@ pub enum MessageType { X224TPDUConnectionConfirm = 0xD0, X224TPDUDisconnectRequest = 0x80, X224TPDUData = 0xF0, - X224TPDUError = 0x70 + X224TPDUError = 0x70, } /// Credential mode @@ -55,7 +55,7 @@ pub enum RequestMode { /// New feature present in lastest windows 10 /// Can't support acctually RedirectedAuthenticationModeRequired = 0x02, - CorrelationInfoPresent = 0x08 + CorrelationInfoPresent = 0x08, } /// RDP Negotiation Request @@ -63,7 +63,11 @@ pub enum RequestMode { /// Security protocol /// /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/902b090b-9cb3-4efc-92bf-ee13373371e3 -fn rdp_neg_req(neg_type: Option, result: Option, flag: Option) -> Component { +fn rdp_neg_req( + neg_type: Option, + result: Option, + flag: Option, +) -> Component { component! [ "type" => neg_type.unwrap_or(NegotiationType::TypeRDPNegReq) as u8, "flag" => flag.unwrap_or(0), @@ -87,12 +91,9 @@ fn x224_crq(len: u8, code: MessageType) -> Component { fn x224_connection_pdu( neg_type: Option, mode: Option, - protocols: Option) -> Component { - let negotiation = rdp_neg_req( - neg_type, - protocols, - mode - ); + protocols: Option, +) -> Component { + let negotiation = rdp_neg_req(neg_type, protocols, mode); component![ "header" => x224_crq(negotiation.length() as u8, MessageType::X224TPDUConnectionRequest), @@ -114,15 +115,15 @@ pub struct Client { /// Transport layer, x224 use a tpkt transport: tpkt::Client, /// Security selected protocol by the connector - selected_protocol: Protocols + selected_protocol: Protocols, } impl Client { /// Constructor use by the connector - fn new (transport: tpkt::Client, selected_protocol: Protocols) -> Self { + fn new(transport: tpkt::Client, selected_protocol: Protocols) -> Self { Client { transport, - selected_protocol + selected_protocol, } } @@ -141,7 +142,9 @@ impl Client { /// x224.write(trame![U16::LE(0)]).unwrap() /// ``` pub fn write(&mut self, message: T) -> RdpResult<()> - where T: Message { + where + T: Message, + { self.transport.write(trame![x224_header(), message]) } @@ -167,13 +170,12 @@ impl Client { let mut x224_header = x224_header(); x224_header.read(&mut payload)?; Ok(tpkt::Payload::Raw(payload)) - }, + } tpkt::Payload::FastPath(flag, payload) => { // nothing to do Ok(tpkt::Payload::FastPath(flag, payload)) } } - } /// Launch the connection sequence of the x224 stack @@ -205,43 +207,77 @@ impl Client { /// false /// ).unwrap() /// ``` - pub fn connect(mut tpkt: tpkt::Client, security_protocols: u32, check_certificate: bool, authentication_protocol: Option<&mut dyn AuthenticationProtocol>, restricted_admin_mode: bool, blank_creds: bool) -> RdpResult> { - Self::write_connection_request(&mut tpkt, security_protocols, Some(if restricted_admin_mode { RequestMode::RestrictedAdminModeRequired as u8} else { 0 }))?; + pub fn connect( + mut tpkt: tpkt::Client, + security_protocols: u32, + check_certificate: bool, + authentication_protocol: Option<&mut dyn AuthenticationProtocol>, + restricted_admin_mode: bool, + blank_creds: bool, + ) -> RdpResult> { + Self::write_connection_request( + &mut tpkt, + security_protocols, + Some(if restricted_admin_mode { + RequestMode::RestrictedAdminModeRequired as u8 + } else { + 0 + }), + )?; match Self::read_connection_confirm(&mut tpkt)? { - Protocols::ProtocolHybrid => Ok(Client::new(tpkt.start_nla(check_certificate, authentication_protocol.unwrap(), restricted_admin_mode || blank_creds)?,Protocols::ProtocolHybrid)), - Protocols::ProtocolSSL => Ok(Client::new(tpkt.start_ssl(check_certificate)?, Protocols::ProtocolSSL)), + Protocols::ProtocolHybrid => Ok(Client::new( + tpkt.start_nla( + check_certificate, + authentication_protocol.unwrap(), + restricted_admin_mode || blank_creds, + )?, + Protocols::ProtocolHybrid, + )), + Protocols::ProtocolSSL => Ok(Client::new( + tpkt.start_ssl(check_certificate)?, + Protocols::ProtocolSSL, + )), Protocols::ProtocolRDP => Ok(Client::new(tpkt, Protocols::ProtocolRDP)), - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidProtocol, "Security protocol not handled"))) + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidProtocol, + "Security protocol not handled", + ))), } } /// Send connection request - fn write_connection_request(tpkt: &mut tpkt::Client, security_protocols: u32, mode: Option) -> RdpResult<()> { - tpkt.write( - x224_connection_pdu( - Some(NegotiationType::TypeRDPNegReq), - mode, - Some(security_protocols) - ) - ) + fn write_connection_request( + tpkt: &mut tpkt::Client, + security_protocols: u32, + mode: Option, + ) -> RdpResult<()> { + tpkt.write(x224_connection_pdu( + Some(NegotiationType::TypeRDPNegReq), + mode, + Some(security_protocols), + )) } /// Expect a connection confirm payload fn read_connection_confirm(tpkt: &mut tpkt::Client) -> RdpResult { let mut buffer = try_let!(tpkt::Payload::Raw, tpkt.read()?)?; - let mut confirm = x224_connection_pdu( - None, - None, - None - ); + let mut confirm = x224_connection_pdu(None, None, None); confirm.read(&mut buffer)?; let nego = cast!(DataType::Component, confirm["negotiation"]).unwrap(); match NegotiationType::try_from(cast!(DataType::U8, nego["type"])?)? { - NegotiationType::TypeRDPNegFailure => Err(Error::RdpError(RdpError::new(RdpErrorKind::ProtocolNegFailure, "Error during negotiation step"))), - NegotiationType::TypeRDPNegReq => Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidAutomata, "Server reject security protocols"))), - NegotiationType::TypeRDPNegRsp => Ok(Protocols::try_from(cast!(DataType::U32, nego["result"])?)?) + NegotiationType::TypeRDPNegFailure => Err(Error::RdpError(RdpError::new( + RdpErrorKind::ProtocolNegFailure, + "Error during negotiation step", + ))), + NegotiationType::TypeRDPNegReq => Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidAutomata, + "Server reject security protocols", + ))), + NegotiationType::TypeRDPNegRsp => { + Ok(Protocols::try_from(cast!(DataType::U32, nego["result"])?)?) + } } } @@ -264,7 +300,9 @@ mod test { #[test] fn test_rdp_neg_req() { let mut s = Cursor::new(vec![]); - rdp_neg_req(Some(NegotiationType::TypeRDPNegRsp), Some(1), Some(0)).write(&mut s).unwrap(); + rdp_neg_req(Some(NegotiationType::TypeRDPNegRsp), Some(1), Some(0)) + .write(&mut s) + .unwrap(); assert_eq!(s.into_inner(), vec![2, 0, 8, 0, 1, 0, 0, 0]) } @@ -272,7 +310,9 @@ mod test { #[test] fn test_x224_crq() { let mut s = Cursor::new(vec![]); - x224_crq(20, MessageType::X224TPDUData).write(&mut s).unwrap(); + x224_crq(20, MessageType::X224TPDUData) + .write(&mut s) + .unwrap(); assert_eq!(s.into_inner(), vec![26, 240, 0, 0, 0, 0, 0]) } @@ -288,7 +328,12 @@ mod test { #[test] fn test_x224_connection_pdu() { let mut s = Cursor::new(vec![]); - x224_connection_pdu(Some(NegotiationType::TypeRDPNegReq), Some(0), Some(3)).write(&mut s).unwrap(); - assert_eq!(s.into_inner(), vec![14, 224, 0, 0, 0, 0, 0, 1, 0, 8, 0, 3, 0, 0, 0]) + x224_connection_pdu(Some(NegotiationType::TypeRDPNegReq), Some(0), Some(3)) + .write(&mut s) + .unwrap(); + assert_eq!( + s.into_inner(), + vec![14, 224, 0, 0, 0, 0, 0, 1, 0, 8, 0, 3, 0, 0, 0] + ) } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index d636315..b80a3b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,26 +1,6 @@ -extern crate byteorder; -extern crate indexmap; -extern crate yasna; -extern crate native_tls; -extern crate md4; -extern crate hmac; -extern crate md5; -extern crate rand; -extern crate num_bigint; -extern crate x509_parser; -extern crate num_enum; -#[cfg(feature = "mstsc-rs")] -extern crate minifb; -#[cfg(feature = "mstsc-rs")] -extern crate winapi; -#[cfg(feature = "mstsc-rs")] -extern crate hex; -#[cfg(feature = "mstsc-rs")] -extern crate clap; - #[macro_use] pub mod model; #[macro_use] pub mod nla; -pub mod core; pub mod codec; +pub mod core; diff --git a/src/model/data.rs b/src/model/data.rs index 98ed3c8..472916e 100644 --- a/src/model/data.rs +++ b/src/model/data.rs @@ -1,9 +1,8 @@ -use std::io::{Write, Read, Cursor}; -use model::error::{RdpResult, RdpErrorKind, RdpError, Error}; -use byteorder::{WriteBytesExt, ReadBytesExt, LittleEndian, BigEndian}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use byteorder::{BigEndian, LittleEndian, ReadBytesExt, WriteBytesExt}; use indexmap::IndexMap; -use std::collections::{HashSet, HashMap}; - +use std::collections::{HashMap, HashSet}; +use std::io::{Cursor, Read, Write}; /// All data type used /// @@ -42,10 +41,9 @@ pub enum DataType<'a> { /// A slice is just a raw u8 of vector Slice(&'a [u8]), /// Optional value can be absent - None + None, } - /// Retrieve leaf value into a type tree /// /// This is a facilitate macro use to visit a type tree @@ -67,10 +65,15 @@ pub enum DataType<'a> { /// ``` #[macro_export] macro_rules! cast { - ($ident:path, $expr:expr) => (match $expr.visit() { - $ident(e) => Ok(e), - _ => Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidCast, "Invalid Cast"))) - }) + ($ident:path, $expr:expr) => { + match $expr.visit() { + $ident(e) => Ok(e), + _ => Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidCast, + "Invalid Cast", + ))), + } + }; } /// Allow to a son to inform parent of something special @@ -87,14 +90,14 @@ pub enum MessageOption { /// for a particular field Size(String, usize), /// Non option - None + None, } /// All is a message /// /// A message can be Read or Write from a Stream /// -pub trait Message : Send { +pub trait Message: Send { /// Write node to the Stream /// /// Write current element into a writable stream @@ -126,7 +129,6 @@ pub trait Message : Send { /// /// Implement Message trait for basic type u8 impl Message for u8 { - /// Write u8 value into stream /// # Example /// @@ -141,7 +143,7 @@ impl Message for u8 { /// assert_eq!(*s.get_ref(), vec![8 as u8]); /// # } /// ``` - fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { + fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { Ok(writer.write_u8(*self)?) } @@ -267,9 +269,9 @@ impl Message for Trame { /// assert_eq!(s.into_inner(), [0, 2, 0, 0, 0]) /// # } /// ``` - fn write(&self, writer: &mut dyn Write) -> RdpResult<()>{ + fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { for v in self { - v.write(writer)?; + v.write(writer)?; } Ok(()) } @@ -294,9 +296,9 @@ impl Message for Trame { /// assert_eq!(cast!(DataType::U32, x[1]).unwrap(), 3); /// # } /// ``` - fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()>{ + fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()> { for v in self { - v.read(reader)?; + v.read(reader)?; } Ok(()) } @@ -317,7 +319,7 @@ impl Message for Trame { /// # } /// ``` fn length(&self) -> u64 { - let mut sum : u64 = 0; + let mut sum: u64 = 0; for v in self { sum += v.length(); } @@ -385,7 +387,7 @@ impl Message for Component { /// assert_eq!(s.into_inner(), [3, 6, 0, 0, 0]) /// # } /// ``` - fn write(&self, writer: &mut dyn Write) -> RdpResult<()>{ + fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { let mut filtering_key = HashSet::new(); for (name, value) in self.iter() { // ignore filtering keys @@ -420,7 +422,7 @@ impl Message for Component { /// assert_eq!(cast!(DataType::U32, x["field2"]).unwrap(), 6) /// # } /// ``` - fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()>{ + fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()> { let mut filtering_key = HashSet::new(); let mut dynamic_size = HashMap::new(); for (name, value) in self.into_iter() { @@ -430,20 +432,22 @@ impl Message for Component { } if dynamic_size.contains_key(name) { - let mut local =vec![0; dynamic_size[name]]; + let mut local = vec![0; dynamic_size[name]]; reader.read_exact(&mut local)?; value.read(&mut Cursor::new(local))?; - } - else { + } else { value.read(reader)?; } match value.options() { - MessageOption::SkipField(field) => { filtering_key.insert(field); }, - MessageOption::Size(field, size) => { dynamic_size.insert(field, size); }, - MessageOption::None => () + MessageOption::SkipField(field) => { + filtering_key.insert(field); + } + MessageOption::Size(field, size) => { + dynamic_size.insert(field, size); + } + MessageOption::None => (), } - } Ok(()) } @@ -468,7 +472,7 @@ impl Message for Component { /// # } /// ``` fn length(&self) -> u64 { - let mut sum : u64 = 0; + let mut sum: u64 = 0; let mut filtering_key = HashSet::new(); for (name, value) in self.iter() { // ignore filtering keys @@ -521,7 +525,7 @@ pub enum Value { /// Big Endianness BE(Type), /// Little Endianness - LE(Type) + LE(Type), } impl Value { @@ -535,7 +539,7 @@ impl Value { /// ``` pub fn inner(&self) -> Type { match self { - Value::::BE(e) | Value::::LE(e) => *e + Value::::BE(e) | Value::::LE(e) => *e, } } } @@ -543,7 +547,7 @@ impl Value { impl PartialEq for Value { /// Equality between all type fn eq(&self, other: &Self) -> bool { - return self.inner() == other.inner() + return self.inner() == other.inner(); } } @@ -551,7 +555,6 @@ impl PartialEq for Value { pub type U16 = Value; impl Message for U16 { - /// Write an unsigned 16 bits value /// /// # Example @@ -565,10 +568,10 @@ impl Message for U16 { /// U16::BE(4).write(&mut s2).unwrap(); /// assert_eq!(s2.into_inner(), [0, 4]); /// ``` - fn write(&self, writer: &mut dyn Write) -> RdpResult<()>{ + fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { match self { U16::BE(value) => Ok(writer.write_u16::(*value)?), - U16::LE(value) => Ok(writer.write_u16::(*value)?) + U16::LE(value) => Ok(writer.write_u16::(*value)?), } } @@ -588,10 +591,10 @@ impl Message for U16 { /// v2.read(&mut s2).unwrap(); /// assert_eq!(v2.inner(), 4); /// ``` - fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()>{ + fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()> { match self { U16::BE(value) => *value = reader.read_u16::()?, - U16::LE(value) => *value = reader.read_u16::()? + U16::LE(value) => *value = reader.read_u16::()?, } Ok(()) } @@ -634,7 +637,6 @@ impl Message for U16 { pub type U32 = Value; impl Message for U32 { - /// Write an unsigned 32 bits value /// /// # Example @@ -651,7 +653,7 @@ impl Message for U32 { fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { match self { U32::BE(value) => Ok(writer.write_u32::(*value)?), - U32::LE(value) => Ok(writer.write_u32::(*value)?) + U32::LE(value) => Ok(writer.write_u32::(*value)?), } } @@ -674,7 +676,7 @@ impl Message for U32 { fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()> { match self { U32::BE(value) => *value = reader.read_u32::()?, - U32::LE(value) => *value = reader.read_u32::()? + U32::LE(value) => *value = reader.read_u32::()?, } Ok(()) } @@ -716,7 +718,7 @@ impl Message for U32 { /// This is a wrapper around /// a copyable message to check constness pub struct Check { - value: T + value: T, } impl Check { @@ -733,15 +735,12 @@ impl Check { /// let mut s2 = Cursor::new(vec![5, 0]); /// assert!(x.read(&mut s2).is_err()); /// ``` - pub fn new(value: T) -> Self{ - Check { - value - } + pub fn new(value: T) -> Self { + Check { value } } } impl Message for Check { - /// Check values doesn't happen during write steps fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { self.value.write(writer) @@ -764,7 +763,10 @@ impl Message for Check { let old = self.value.clone(); self.value.read(reader)?; if old != self.value { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidConst, "Invalid constness of data"))) + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidConst, + "Invalid constness of data", + ))); } Ok(()) } @@ -812,8 +814,7 @@ impl Message for Vec { fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()> { if self.len() == 0 { reader.read_to_end(self)?; - } - else { + } else { reader.read_exact(self)?; } Ok(()) @@ -866,7 +867,7 @@ impl Message for Vec { pub type DynOptionFnSend = dyn Fn(&T) -> MessageOption + Send; pub struct DynOption { inner: T, - filter: Box> + filter: Box>, } /// The filter impl @@ -920,10 +921,13 @@ impl DynOption { /// # } /// ``` pub fn new(current: T, filter: F) -> Self - where F: Fn(&T) -> MessageOption, F: Send { + where + F: Fn(&T) -> MessageOption, + F: Send, + { DynOption { inner: current, - filter : Box::new(filter) + filter: Box::new(filter), } } } @@ -964,16 +968,16 @@ pub fn to_vec(message: &dyn Message) -> Vec { stream.into_inner() } - #[macro_export] macro_rules! is_none { - ($expr:expr) => (match $expr.visit() { - DataType::None => true, - _ => false - }) + ($expr:expr) => { + match $expr.visit() { + DataType::None => true, + _ => false, + } + }; } - /// This is an optional fields /// Actually always write but read if and only if the reader /// buffer could read the size of inner Message @@ -1048,8 +1052,7 @@ impl Message for Option { fn length(&self) -> u64 { if let Some(value) = self { value.length() - } - else { + } else { 0 } } @@ -1069,8 +1072,7 @@ impl Message for Option { fn visit(&self) -> DataType { if let Some(value) = self { value.visit() - } - else { + } else { DataType::None } } @@ -1088,7 +1090,7 @@ pub struct Array { /// This is the inner trame inner: Trame, /// function call to build each element of the array - factory: Box> + factory: Box>, } impl Array { @@ -1112,10 +1114,13 @@ impl Array { /// # } /// ``` pub fn new(factory: F) -> Self - where F: Fn() -> T, F: Send { + where + F: Fn() -> T, + F: Send, + { Array { inner: trame![], - factory: Box::new(factory) + factory: Box::new(factory), } } @@ -1125,7 +1130,7 @@ impl Array { pub fn from_trame(inner: Trame) -> Self { Array { inner, - factory: Box::new(|| panic!("Try reading a non empty array")) + factory: Box::new(|| panic!("Try reading a non empty array")), } } diff --git a/src/model/error.rs b/src/model/error.rs index 89ef68b..4b96c3b 100644 --- a/src/model/error.rs +++ b/src/model/error.rs @@ -1,12 +1,10 @@ -extern crate native_tls; - -use std::io::{Read, Write}; +use native_tls::Error as SslError; +use native_tls::HandshakeError; +use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; use std::io::Error as IoError; +use std::io::{Read, Write}; use std::string::String; -use self::native_tls::HandshakeError; -use self::native_tls::Error as SslError; use yasna::ASN1Error; -use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum RdpErrorKind { @@ -48,15 +46,16 @@ pub enum RdpErrorKind { Disconnect, /// Indicate an unknown field Unknown, - UnexpectedType + UnexpectedType, } #[derive(Debug)] +#[allow(dead_code)] pub struct RdpError { /// Kind of error kind: RdpErrorKind, /// Associated message of the context - message: String + message: String, } impl RdpError { @@ -66,10 +65,10 @@ impl RdpError { /// use rdp::model::error::{RdpError, RdpErrorKind}; /// let error = RdpError::new(RdpErrorKind::Disconnect, "disconnected"); /// ``` - pub fn new (kind: RdpErrorKind, message: &str) -> Self { + pub fn new(kind: RdpErrorKind, message: &str) -> Self { RdpError { kind, - message: String::from(message) + message: String::from(message), } } @@ -99,7 +98,7 @@ pub enum Error { /// ASN1 parser error ASN1Error(ASN1Error), /// try error - TryError(String) + TryError(String), } /// From IO Error @@ -129,7 +128,10 @@ impl From for Error { impl From> for Error { fn from(_: TryFromPrimitiveError) -> Self { - Error::RdpError(RdpError::new(RdpErrorKind::InvalidCast, "Invalid enum conversion")) + Error::RdpError(RdpError::new( + RdpErrorKind::InvalidCast, + "Invalid enum conversion", + )) } } @@ -139,22 +141,27 @@ pub type RdpResult = Result; #[macro_export] macro_rules! try_option { ($val: expr, $expr: expr) => { - if let Some(x) = $val { + if let Some(x) = $val { Ok(x) - } else { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidOptionalField, $expr))) - } - } + } else { + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidOptionalField, + $expr, + ))) + } + }; } #[macro_export] macro_rules! try_let { ($ident: path, $val: expr) => { - if let $ident(x) = $val { + if let $ident(x) = $val { Ok(x) - } else { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidCast, "Invalid Cast"))) - } - } + } else { + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidCast, + "Invalid Cast", + ))) + } + }; } - diff --git a/src/model/link.rs b/src/model/link.rs index 424a1c2..ba27fd5 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -1,9 +1,7 @@ -extern crate native_tls; - -use model::error::{RdpResult, Error, RdpError, RdpErrorKind}; +use crate::model::data::Message; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use native_tls::{Certificate, TlsConnector, TlsStream}; use std::io::{Cursor, Read, Write}; -use self::native_tls::{TlsConnector, TlsStream, Certificate}; -use model::data::{Message}; /// This a wrapper to work equals /// for a stream and a TLS stream @@ -11,7 +9,7 @@ pub enum Stream { /// Raw stream that implement Read + Write Raw(S), /// TLS Stream - Ssl(TlsStream) + Ssl(TlsStream), } impl Stream { @@ -26,10 +24,10 @@ impl Stream { /// s.read_exact(&mut result).unwrap(); /// assert_eq!(result, [1, 2]) /// ``` - pub fn read_exact(&mut self, buf: &mut[u8]) -> RdpResult<()> { + pub fn read_exact(&mut self, buf: &mut [u8]) -> RdpResult<()> { match self { Stream::Raw(e) => e.read_exact(buf)?, - Stream::Ssl(e) => e.read_exact(buf)? + Stream::Ssl(e) => e.read_exact(buf)?, }; Ok(()) } @@ -45,10 +43,10 @@ impl Stream { /// s.read(&mut result).unwrap(); /// assert_eq!(result, [1, 2, 3, 0]) /// ``` - pub fn read(&mut self, buf: &mut[u8]) -> RdpResult { + pub fn read(&mut self, buf: &mut [u8]) -> RdpResult { match self { Stream::Raw(e) => Ok(e.read(buf)?), - Stream::Ssl(e) => Ok(e.read(buf)?) + Stream::Ssl(e) => Ok(e.read(buf)?), } } @@ -71,7 +69,7 @@ impl Stream { pub fn write(&mut self, buffer: &[u8]) -> RdpResult { Ok(match self { Stream::Raw(e) => e.write(buffer)?, - Stream::Ssl(e) => e.write(buffer)? + Stream::Ssl(e) => e.write(buffer)?, }) } @@ -80,7 +78,7 @@ impl Stream { pub fn shutdown(&mut self) -> RdpResult<()> { Ok(match self { Stream::Ssl(e) => e.shutdown()?, - _ => () + _ => (), }) } } @@ -88,7 +86,7 @@ impl Stream { /// Link layer is a wrapper around TCP or SSL stream /// It can swicth from TCP to SSL pub struct Link { - stream: Stream + stream: Stream, } impl Link { @@ -104,9 +102,7 @@ impl Link { /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap())); /// ``` pub fn new(stream: Stream) -> Self { - Link { - stream - } + Link { stream } } /// This method is designed to write a Message @@ -155,8 +151,7 @@ impl Link { let size = self.stream.read(&mut buffer)?; buffer.resize(size, 0); Ok(buffer) - } - else { + } else { let mut buffer = vec![0; expected_size]; self.stream.read_exact(&mut buffer)?; Ok(buffer) @@ -181,9 +176,12 @@ impl Link { let connector = builder.build()?; if let Stream::Raw(stream) = self.stream { - return Ok(Link::new(Stream::Ssl(connector.connect("", stream)?))) + return Ok(Link::new(Stream::Ssl(connector.connect("", stream)?))); } - Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, "start_ssl on ssl stream is forbidden"))) + Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + "start_ssl on ssl stream is forbidden", + ))) } /// Retrive the peer certificate @@ -201,9 +199,11 @@ impl Link { pub fn get_peer_certificate(&self) -> RdpResult> { if let Stream::Ssl(stream) = &self.stream { Ok(stream.peer_certificate()?) - } - else { - Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "get peer certificate on non ssl link is impossible"))) + } else { + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidData, + "get peer certificate on non ssl link is impossible", + ))) } } diff --git a/src/model/mod.rs b/src/model/mod.rs index 1f96306..f2f00fe 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -4,4 +4,4 @@ pub mod link; #[macro_use] pub mod error; pub mod rnd; -pub mod unicode; \ No newline at end of file +pub mod unicode; diff --git a/src/model/rnd.rs b/src/model/rnd.rs index a3d2bb5..5112f13 100644 --- a/src/model/rnd.rs +++ b/src/model/rnd.rs @@ -11,4 +11,4 @@ use rand::Rng; pub fn random(size: usize) -> Vec { let mut rng = rand::thread_rng(); (0..size).map(|_| rng.gen()).collect() -} \ No newline at end of file +} diff --git a/src/model/unicode.rs b/src/model/unicode.rs index 5246a71..45f959e 100644 --- a/src/model/unicode.rs +++ b/src/model/unicode.rs @@ -1,4 +1,4 @@ -use model::data::{Message, U16}; +use crate::model::data::{Message, U16}; use std::io::Cursor; /// Use to to_unicode function for String @@ -21,6 +21,6 @@ impl Unicode for String { let encode_char = U16::LE(c); encode_char.write(&mut result).unwrap(); } - return result.into_inner() + return result.into_inner(); } -} \ No newline at end of file +} diff --git a/src/nla/asn1.rs b/src/nla/asn1.rs index 5b09bae..c0d2108 100644 --- a/src/nla/asn1.rs +++ b/src/nla/asn1.rs @@ -1,6 +1,6 @@ -use yasna::{Tag, DERWriter, BERReader}; -use model::error::{RdpResult, Error}; +use crate::model::error::{Error, RdpResult}; use indexmap::map::IndexMap; +use yasna::{BERReader, DERWriter, Tag}; /// Enum all possible value /// In an ASN 1 tree @@ -16,7 +16,7 @@ pub enum ASN1Type<'a> { /// Boolean Bool(bool), /// Enumerate - Enumerate(i64) + Enumerate(i64), } /// This trait is a wrapper around @@ -38,7 +38,7 @@ pub struct SequenceOf { /// The inner vector of ASN1 node pub inner: Vec>, /// Callback use as Factory - factory: Option Box>> + factory: Option Box>>, } impl SequenceOf { @@ -49,10 +49,10 @@ impl SequenceOf { /// use rdp::nla::asn1::SequenceOf; /// let so = SequenceOf::new(); /// ``` - pub fn new() -> Self{ + pub fn new() -> Self { SequenceOf { inner: Vec::new(), - factory : None + factory: None, } } @@ -65,10 +65,12 @@ impl SequenceOf { /// let so = SequenceOf::reader(|| Box::new(OctetString::new())); /// ``` pub fn reader(factory: F) -> Self - where F: Fn() -> Box { + where + F: Fn() -> Box, + { SequenceOf { inner: Vec::new(), - factory : Some(Box::new(factory)) + factory: Some(Box::new(factory)), } } } @@ -121,7 +123,7 @@ impl ASN1 for SequenceOf { if let Some(callback) = &self.factory { let mut element = (callback)(); if let Err(Error::ASN1Error(e)) = element.read_asn1(sequence_reader) { - return Err(e) + return Err(e); } self.inner.push(element); } @@ -227,7 +229,7 @@ pub struct ExplicitTag { /// Associate explicit Tag tag: Tag, /// The inner object - inner: T + inner: T, } impl ExplicitTag { @@ -241,10 +243,7 @@ impl ExplicitTag { /// let s = ExplicitTag::new(Tag::context(0), 2 as Integer); /// ``` pub fn new(tag: Tag, inner: T) -> Self { - ExplicitTag { - tag, - inner - } + ExplicitTag { tag, inner } } /// return the inner object @@ -296,8 +295,8 @@ impl ASN1 for ExplicitTag { /// ``` fn read_asn1(&mut self, reader: BERReader) -> RdpResult<()> { reader.read_tagged(self.tag, |tag_reader| { - if let Err(Error::ASN1Error(e)) = self.inner.read_asn1(tag_reader) { - return Err(e) + if let Err(Error::ASN1Error(e)) = self.inner.read_asn1(tag_reader) { + return Err(e); } Ok(()) })?; @@ -330,7 +329,7 @@ pub struct ImplicitTag { /// This implicit tag tag: Tag, /// The inner node - pub inner: T + pub inner: T, } impl ImplicitTag { @@ -344,10 +343,7 @@ impl ImplicitTag { /// let s = ImplicitTag::new(Tag::context(0), 1 as Integer); /// ``` pub fn new(tag: Tag, inner: T) -> Self { - ImplicitTag { - tag, - inner - } + ImplicitTag { tag, inner } } } @@ -394,8 +390,8 @@ impl ASN1 for ImplicitTag { /// ``` fn read_asn1(&mut self, reader: BERReader) -> RdpResult<()> { reader.read_tagged_implicit(self.tag, |tag_reader| { - if let Err(Error::ASN1Error(e)) = self.inner.read_asn1(tag_reader) { - return Err(e) + if let Err(Error::ASN1Error(e)) = self.inner.read_asn1(tag_reader) { + return Err(e); } Ok(()) })?; @@ -427,7 +423,6 @@ impl ASN1 for ImplicitTag { pub type Integer = u32; impl ASN1 for Integer { - /// Write an ASN1 Integer Node /// using a DERWriter /// @@ -494,7 +489,6 @@ impl ASN1 for Integer { /// ASN1 for boolean impl ASN1 for bool { - /// Write an ASN1 boolean Node /// using a DERWriter /// @@ -564,7 +558,6 @@ impl ASN1 for bool { pub type Sequence = IndexMap>; impl ASN1 for Sequence { - /// Write an ASN1 sequence Node /// using a DERWriter /// @@ -588,7 +581,7 @@ impl ASN1 for Sequence { writer.write_sequence(|sequence| { for (_name, child) in self.iter() { child.write_asn1(sequence.next()).unwrap(); - }; + } }); Ok(()) } @@ -618,9 +611,9 @@ impl ASN1 for Sequence { reader.read_sequence(|sequence_reader| { for (_name, child) in self.into_iter() { if let Err(Error::ASN1Error(e)) = child.read_asn1(sequence_reader.next()) { - return Err(e) + return Err(e); } - }; + } Ok(()) })?; Ok(()) @@ -655,7 +648,6 @@ impl ASN1 for Sequence { pub type Enumerate = i64; impl ASN1 for Enumerate { - /// Write an ASN1 Enumerate Node /// using a DERWriter /// @@ -729,20 +721,20 @@ pub fn to_der(message: &dyn ASN1) -> Vec { } /// Deserialize an ASN1 message from a stream -pub fn from_der(message: &mut dyn ASN1, stream: &[u8]) ->RdpResult<()> { +pub fn from_der(message: &mut dyn ASN1, stream: &[u8]) -> RdpResult<()> { Ok(yasna::parse_der(stream, |reader| { if let Err(Error::ASN1Error(e)) = message.read_asn1(reader) { - return Err(e) + return Err(e); } Ok(()) })?) } /// Deserialize an ASN1 message from a stream using BER -pub fn from_ber(message: &mut dyn ASN1, stream: &[u8]) ->RdpResult<()> { +pub fn from_ber(message: &mut dyn ASN1, stream: &[u8]) -> RdpResult<()> { Ok(yasna::parse_ber(stream, |reader| { if let Err(Error::ASN1Error(e)) = message.read_asn1(reader) { - return Err(e) + return Err(e); } Ok(()) })?) @@ -755,4 +747,4 @@ macro_rules! sequence { $( map.insert($key.to_string(), Box::new($val)); )* map }} -} \ No newline at end of file +} diff --git a/src/nla/cssp.rs b/src/nla/cssp.rs index d327106..309ef64 100644 --- a/src/nla/cssp.rs +++ b/src/nla/cssp.rs @@ -1,11 +1,13 @@ -use nla::asn1::{ASN1, Sequence, ExplicitTag, SequenceOf, ASN1Type, OctetString, Integer, to_der}; -use model::error::{RdpError, RdpErrorKind, Error, RdpResult}; -use num_bigint::{BigUint}; -use yasna::Tag; -use x509_parser::{parse_x509_der, X509Certificate}; -use nla::sspi::AuthenticationProtocol; -use model::link::Link; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::model::link::Link; +use crate::nla::asn1::{ + to_der, ASN1Type, ExplicitTag, Integer, OctetString, Sequence, SequenceOf, ASN1, +}; +use crate::nla::sspi::AuthenticationProtocol; +use num_bigint::BigUint; use std::io::{Read, Write}; +use x509_parser::{parse_x509_der, X509Certificate}; +use yasna::Tag; /// Create a ts request as expected by the specification /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-cssp/6aac4dea-08ef-47a6-8747-22ea7f6d8685?redirectedfrom=MSDN @@ -39,7 +41,7 @@ pub fn create_ts_request(nego: Vec) -> Vec { /// /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-cssp/6aac4dea-08ef-47a6-8747-22ea7f6d8685?redirectedfrom=MSDN /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-cssp/9664994d-0784-4659-b85b-83b8d54c2336 -/// +/// /// # Example /// ``` /// use rdp::nla::cssp::read_ts_server_challenge; @@ -61,7 +63,7 @@ pub fn read_ts_server_challenge(stream: &[u8]) -> RdpResult> { yasna::parse_der(stream, |reader| { if let Err(Error::ASN1Error(e)) = ts_request.read_asn1(reader) { - return Err(e) + return Err(e); } Ok(()) })?; @@ -124,7 +126,7 @@ pub fn read_ts_validate(request: &[u8]) -> RdpResult> { yasna::parse_der(request, |reader| { if let Err(Error::ASN1Error(e)) = ts_challenge.read_asn1(reader) { - return Err(e) + return Err(e); } Ok(()) })?; @@ -163,7 +165,11 @@ fn create_ts_authinfo(auth_info: Vec) -> Vec { /// This the main function for CSSP protocol /// It will use the raw link layer and the selected authenticate protocol /// to perform the NLA authenticate -pub fn cssp_connect(link: &mut Link, authentication_protocol: &mut dyn AuthenticationProtocol, restricted_admin_mode: bool) -> RdpResult<()> { +pub fn cssp_connect( + link: &mut Link, + authentication_protocol: &mut dyn AuthenticationProtocol, + restricted_admin_mode: bool, +) -> RdpResult<()> { // first step is to send the negotiate message from authentication protocol let negotiate_message = create_ts_request(authentication_protocol.create_negotiate_message()?); link.write(&negotiate_message)?; @@ -178,28 +184,66 @@ pub fn cssp_connect(link: &mut Link, authentication_protocol let mut security_interface = authentication_protocol.build_security_interface(); // Get the peer public certificate - let certificate_der = try_option!(link.get_peer_certificate()?, "No public certificate available")?.to_der()?; + let certificate_der = try_option!( + link.get_peer_certificate()?, + "No public certificate available" + )? + .to_der()?; let certificate = read_public_certificate(&certificate_der)?; // Now we can send back our challenge payload wit the public key encoded - let challenge = create_ts_authenticate(client_challenge, security_interface.gss_wrapex(certificate.tbs_certificate.subject_pki.subject_public_key.data)?); + let challenge = create_ts_authenticate( + client_challenge, + security_interface.gss_wrapex( + certificate + .tbs_certificate + .subject_pki + .subject_public_key + .data, + )?, + ); link.write(&challenge)?; // now server respond normally with the original public key incremented by one let inc_pub_key = security_interface.gss_unwrapex(&(read_ts_validate(&(link.read(0)?))?))?; // Check possible man in the middle using cssp - if BigUint::from_bytes_le(&inc_pub_key) != BigUint::from_bytes_le(certificate.tbs_certificate.subject_pki.subject_public_key.data) + BigUint::new(vec![1]) { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::PossibleMITM, "Man in the middle detected"))) + if BigUint::from_bytes_le(&inc_pub_key) + != BigUint::from_bytes_le( + certificate + .tbs_certificate + .subject_pki + .subject_public_key + .data, + ) + BigUint::new(vec![1]) + { + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::PossibleMITM, + "Man in the middle detected", + ))); } // compute the last message with encoded credentials - let domain = if restricted_admin_mode { vec![] } else { authentication_protocol.get_domain_name()}; - let user = if restricted_admin_mode { vec![] } else { authentication_protocol.get_user_name() }; - let password = if restricted_admin_mode { vec![] } else { authentication_protocol.get_password() }; + let domain = if restricted_admin_mode { + vec![] + } else { + authentication_protocol.get_domain_name() + }; + let user = if restricted_admin_mode { + vec![] + } else { + authentication_protocol.get_user_name() + }; + let password = if restricted_admin_mode { + vec![] + } else { + authentication_protocol.get_password() + }; - let credentials = create_ts_authinfo(security_interface.gss_wrapex(&create_ts_credentials(domain, user, password))?); + let credentials = create_ts_authinfo( + security_interface.gss_wrapex(&create_ts_credentials(domain, user, password))?, + ); link.write(&credentials)?; Ok(()) @@ -211,14 +255,22 @@ mod test { #[test] fn test_create_ts_credentials() { - let credentials = create_ts_credentials(b"domain".to_vec(), b"user".to_vec(), b"password".to_vec()); - let result = [48, 41, 160, 3, 2, 1, 1, 161, 34, 4, 32, 48, 30, 160, 8, 4, 6, 100, 111, 109, 97, 105, 110, 161, 6, 4, 4, 117, 115, 101, 114, 162, 10, 4, 8, 112, 97, 115, 115, 119, 111, 114, 100]; + let credentials = + create_ts_credentials(b"domain".to_vec(), b"user".to_vec(), b"password".to_vec()); + let result = [ + 48, 41, 160, 3, 2, 1, 1, 161, 34, 4, 32, 48, 30, 160, 8, 4, 6, 100, 111, 109, 97, 105, + 110, 161, 6, 4, 4, 117, 115, 101, 114, 162, 10, 4, 8, 112, 97, 115, 115, 119, 111, 114, + 100, + ]; assert_eq!(credentials[0..32], result[0..32]); assert_eq!(credentials[33..43], result[33..43]); } #[test] fn test_create_ts_authinfo() { - assert_eq!(create_ts_authinfo(b"foo".to_vec()), [48, 12, 160, 3, 2, 1, 2, 162, 5, 4, 3, 102, 111, 111]) + assert_eq!( + create_ts_authinfo(b"foo".to_vec()), + [48, 12, 160, 3, 2, 1, 2, 162, 5, 4, 3, 102, 111, 111] + ) } -} \ No newline at end of file +} diff --git a/src/nla/mod.rs b/src/nla/mod.rs index 6b45edd..af8d885 100644 --- a/src/nla/mod.rs +++ b/src/nla/mod.rs @@ -2,5 +2,5 @@ pub mod asn1; pub mod cssp; pub mod ntlm; -pub mod sspi; pub mod rc4; +pub mod sspi; diff --git a/src/nla/ntlm.rs b/src/nla/ntlm.rs index 3680c69..8d2bd06 100644 --- a/src/nla/ntlm.rs +++ b/src/nla/ntlm.rs @@ -1,15 +1,17 @@ -use nla::sspi::{AuthenticationProtocol, GenericSecurityService}; -use model::data::{Message, Component, U16, U32, Trame, DynOption, Check, DataType, MessageOption, to_vec}; -use std::io::{Cursor}; -use model::error::{RdpResult, RdpError, RdpErrorKind, Error}; -use std::collections::HashMap; -use md4::{Md4, Digest}; +use crate::model::data::{ + to_vec, Check, Component, DataType, DynOption, Message, MessageOption, Trame, U16, U32, +}; +use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +use crate::model::rnd::random; +use crate::nla::rc4::Rc4; +use crate::nla::sspi::{AuthenticationProtocol, GenericSecurityService}; use hmac::{Hmac, Mac}; -use md5::{Md5}; -use model::rnd::{random}; -use nla::rc4::{Rc4}; +use md4::{Digest, Md4}; +use md5::Md5; use num_enum::TryFromPrimitive; +use std::collections::HashMap; use std::convert::TryFrom; +use std::io::Cursor; #[repr(u32)] #[allow(dead_code)] @@ -34,14 +36,14 @@ enum Negotiate { NtlmsspNegociateSign = 0x00000010, NtlmsspRequestTarget = 0x00000004, NtlmNegotiateOEM = 0x00000002, - NtlmsspNegociateUnicode = 0x00000001 + NtlmsspNegociateUnicode = 0x00000001, } #[repr(u8)] #[allow(dead_code)] enum MajorVersion { WindowsMajorVersion5 = 0x05, - WindowsMajorVersion6 = 0x06 + WindowsMajorVersion6 = 0x06, } #[repr(u8)] @@ -50,12 +52,12 @@ enum MinorVersion { WindowsMinorVersion0 = 0x00, WindowsMinorVersion1 = 0x01, WindowsMinorVersion2 = 0x02, - WindowsMinorVersion3 = 0x03 + WindowsMinorVersion3 = 0x03, } #[repr(u8)] enum NTLMRevision { - NtlmSspRevisionW2K3 = 0x0F + NtlmSspRevisionW2K3 = 0x0F, } fn version() -> Component { @@ -120,43 +122,62 @@ fn challenge_message() -> Component { /// /// Due to Microsoft spec if you have to compute MIC you need /// separatly the packet and the payload -fn authenticate_message(lm_challenge_response: &[u8], nt_challenge_response:&[u8], domain: &[u8], user: &[u8], workstation: &[u8], encrypted_random_session_key: &[u8], flags: u32) -> (Component, Vec) { - let payload = [lm_challenge_response.to_vec(), nt_challenge_response.to_vec(), domain.to_vec(), user.to_vec(), workstation.to_vec(), encrypted_random_session_key.to_vec()].concat(); +fn authenticate_message( + lm_challenge_response: &[u8], + nt_challenge_response: &[u8], + domain: &[u8], + user: &[u8], + workstation: &[u8], + encrypted_random_session_key: &[u8], + flags: u32, +) -> (Component, Vec) { + let payload = [ + lm_challenge_response.to_vec(), + nt_challenge_response.to_vec(), + domain.to_vec(), + user.to_vec(), + workstation.to_vec(), + encrypted_random_session_key.to_vec(), + ] + .concat(); let offset = if flags & (Negotiate::NtlmsspNegociateVersion as u32) == 0 { 80 } else { 88 }; - (component![ - "Signature" => Check::new(b"NTLMSSP\x00".to_vec()), - "MessageType" => Check::new(U32::LE(3)), - "LmChallengeResponseLen" => U16::LE(lm_challenge_response.len() as u16), - "LmChallengeResponseMaxLen" => U16::LE(lm_challenge_response.len() as u16), - "LmChallengeResponseBufferOffset" => U32::LE(offset), - "NtChallengeResponseLen" => U16::LE(nt_challenge_response.len() as u16), - "NtChallengeResponseMaxLen" => U16::LE(nt_challenge_response.len() as u16), - "NtChallengeResponseBufferOffset" => U32::LE(offset + lm_challenge_response.len() as u32), - "DomainNameLen" => U16::LE(domain.len() as u16), - "DomainNameMaxLen" => U16::LE(domain.len() as u16), - "DomainNameBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len()) as u32), - "UserNameLen" => U16::LE(user.len() as u16), - "UserNameMaxLen" => U16::LE(user.len() as u16), - "UserNameBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len() + domain.len()) as u32), - "WorkstationLen" => U16::LE(workstation.len() as u16), - "WorkstationMaxLen" => U16::LE(workstation.len() as u16), - "WorkstationBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len() + domain.len() + user.len()) as u32), - "EncryptedRandomSessionLen" => U16::LE(encrypted_random_session_key.len() as u16), - "EncryptedRandomSessionMaxLen" => U16::LE(encrypted_random_session_key.len() as u16), - "EncryptedRandomSessionBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len() + domain.len() + user.len() + workstation.len()) as u32), - "NegotiateFlags" => DynOption::new(U32::LE(flags), |node| { - if node.inner() & (Negotiate::NtlmsspNegociateVersion as u32) == 0 { - return MessageOption::SkipField("Version".to_string()) - } - return MessageOption::None - }), - "Version" => version() - ] , payload) + ( + component![ + "Signature" => Check::new(b"NTLMSSP\x00".to_vec()), + "MessageType" => Check::new(U32::LE(3)), + "LmChallengeResponseLen" => U16::LE(lm_challenge_response.len() as u16), + "LmChallengeResponseMaxLen" => U16::LE(lm_challenge_response.len() as u16), + "LmChallengeResponseBufferOffset" => U32::LE(offset), + "NtChallengeResponseLen" => U16::LE(nt_challenge_response.len() as u16), + "NtChallengeResponseMaxLen" => U16::LE(nt_challenge_response.len() as u16), + "NtChallengeResponseBufferOffset" => U32::LE(offset + lm_challenge_response.len() as u32), + "DomainNameLen" => U16::LE(domain.len() as u16), + "DomainNameMaxLen" => U16::LE(domain.len() as u16), + "DomainNameBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len()) as u32), + "UserNameLen" => U16::LE(user.len() as u16), + "UserNameMaxLen" => U16::LE(user.len() as u16), + "UserNameBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len() + domain.len()) as u32), + "WorkstationLen" => U16::LE(workstation.len() as u16), + "WorkstationMaxLen" => U16::LE(workstation.len() as u16), + "WorkstationBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len() + domain.len() + user.len()) as u32), + "EncryptedRandomSessionLen" => U16::LE(encrypted_random_session_key.len() as u16), + "EncryptedRandomSessionMaxLen" => U16::LE(encrypted_random_session_key.len() as u16), + "EncryptedRandomSessionBufferOffset" => U32::LE(offset + (lm_challenge_response.len() + nt_challenge_response.len() + domain.len() + user.len() + workstation.len()) as u32), + "NegotiateFlags" => DynOption::new(U32::LE(flags), |node| { + if node.inner() & (Negotiate::NtlmsspNegociateVersion as u32) == 0 { + return MessageOption::SkipField("Version".to_string()) + } + return MessageOption::None + }), + "Version" => version() + ], + payload, + ) } /// This function is a shortcut to get a particular field from the payload field @@ -168,7 +189,6 @@ fn get_payload_field(message: &Component, length: u16, buffer_offset: u32) -> Rd Ok(&payload[start..end]) } - #[repr(u16)] #[derive(Eq, PartialEq, Hash, Debug, TryFromPrimitive)] enum AvId { @@ -182,7 +202,7 @@ enum AvId { MsvAvTimestamp = 0x0007, MsvAvSingleHost = 0x0008, MsvAvTargetName = 0x0009, - MsvChannelBindings = 0x000A + MsvChannelBindings = 0x000A, } /// Av Pair is a Key Value pair structure @@ -298,7 +318,7 @@ fn unicode(data: &String) -> Vec { let encode_char = U16::LE(c); encode_char.write(&mut result).unwrap(); } - return result.into_inner() + return result.into_inner(); } /// Compute HMAC with MD5 hash algorithm @@ -325,7 +345,10 @@ fn hmac_md5(key: &[u8], data: &[u8]) -> Vec { /// let key = ntowfv2("hello123".to_string(), "user".to_string(), "domain".to_string()) /// ``` fn ntowfv2(password: &String, user: &String, domain: &String) -> Vec { - hmac_md5(&md4(&unicode(password)), &unicode(&(user.to_uppercase() + &domain))) + hmac_md5( + &md4(&unicode(password)), + &unicode(&(user.to_uppercase() + &domain)), + ) } /// This function is used to compute init key of another hmac_md5 @@ -365,27 +388,57 @@ fn lmowfv2(password: &String, user: &String, domain: &String) -> Vec { /// let session_base_key = response.2; /// ``` fn compute_response_v2( - response_key_nt: &[u8], response_key_lm: &[u8], - server_challenge: &[u8], client_challenge: &[u8], time: &[u8], - server_name: &[u8] + response_key_nt: &[u8], + response_key_lm: &[u8], + server_challenge: &[u8], + client_challenge: &[u8], + time: &[u8], + server_name: &[u8], ) -> (Vec, Vec, Vec) { let response_version = b"\x01"; let hi_response_version = b"\x01"; - let temp = [response_version.to_vec(), hi_response_version.to_vec(), z(6), time.to_vec(), client_challenge.to_vec(), z(4), server_name.to_vec()].concat(); - let nt_proof_str = hmac_md5(response_key_nt, &[server_challenge.to_vec(), temp.clone()].concat()); + let temp = [ + response_version.to_vec(), + hi_response_version.to_vec(), + z(6), + time.to_vec(), + client_challenge.to_vec(), + z(4), + server_name.to_vec(), + ] + .concat(); + let nt_proof_str = hmac_md5( + response_key_nt, + &[server_challenge.to_vec(), temp.clone()].concat(), + ); let nt_challenge_response = [nt_proof_str.clone(), temp.clone()].concat(); - let lm_challenge_response = [hmac_md5(response_key_lm, &[server_challenge.to_vec(), client_challenge.to_vec()].concat()), client_challenge.to_vec()].concat(); + let lm_challenge_response = [ + hmac_md5( + response_key_lm, + &[server_challenge.to_vec(), client_challenge.to_vec()].concat(), + ), + client_challenge.to_vec(), + ] + .concat(); let session_base_key = hmac_md5(response_key_nt, &nt_proof_str); - (nt_challenge_response, lm_challenge_response, session_base_key) + ( + nt_challenge_response, + lm_challenge_response, + session_base_key, + ) } /// This is a function described in specification /// /// This is just ton follow specification -fn kx_key_v2(session_base_key: &[u8], _lm_challenge_response: &[u8], _server_challenge: &[u8]) -> Vec { +fn kx_key_v2( + session_base_key: &[u8], + _lm_challenge_response: &[u8], + _server_challenge: &[u8], +) -> Vec { session_base_key.to_vec() } @@ -400,17 +453,38 @@ fn rc4k(key: &[u8], plaintext: &[u8]) -> Vec { } /// Compute a signature of all data exchange during NTLMv2 handshake -fn mic(exported_session_key: &[u8], negotiate_message: &[u8], challenge_message: &[u8], authenticate_message: &[u8]) -> Vec{ - hmac_md5(exported_session_key, &[negotiate_message.to_vec(), challenge_message.to_vec(), authenticate_message.to_vec()].concat()) +fn mic( + exported_session_key: &[u8], + negotiate_message: &[u8], + challenge_message: &[u8], + authenticate_message: &[u8], +) -> Vec { + hmac_md5( + exported_session_key, + &[ + negotiate_message.to_vec(), + challenge_message.to_vec(), + authenticate_message.to_vec(), + ] + .concat(), + ) } /// NTLMv2 security interface generate a sign key /// By using MD5 of the session key + a static member (sentense) fn sign_key(exported_session_key: &[u8], is_client: bool) -> Vec { if is_client { - md5(&[exported_session_key, b"session key to client-to-server signing key magic constant\0"].concat()) + md5(&[ + exported_session_key, + b"session key to client-to-server signing key magic constant\0", + ] + .concat()) } else { - md5(&[exported_session_key, b"session key to server-to-client signing key magic constant\0"].concat()) + md5(&[ + exported_session_key, + b"session key to server-to-client signing key magic constant\0", + ] + .concat()) } } @@ -418,9 +492,17 @@ fn sign_key(exported_session_key: &[u8], is_client: bool) -> Vec { /// By using MD5 of the session key + a static member (sentense) fn seal_key(exported_session_key: &[u8], is_client: bool) -> Vec { if is_client { - md5(&[exported_session_key, b"session key to client-to-server sealing key magic constant\0"].concat()) + md5(&[ + exported_session_key, + b"session key to client-to-server sealing key magic constant\0", + ] + .concat()) } else { - md5(&[exported_session_key, b"session key to server-to-client sealing key magic constant\0"].concat()) + md5(&[ + exported_session_key, + b"session key to server-to-client sealing key magic constant\0", + ] + .concat()) } } @@ -431,13 +513,18 @@ fn seal_key(exported_session_key: &[u8], is_client: bool) -> Vec { /// let signature = mac(&mut Rc4::new(b"foo"), b"bar", 0, b"data"); /// ``` fn mac(rc4_handle: &mut Rc4, signing_key: &[u8], seq_num: u32, data: &[u8]) -> Vec { - - let signature = hmac_md5(signing_key, &[to_vec(&U32::LE(seq_num)).as_slice(), data].concat()); + let signature = hmac_md5( + signing_key, + &[to_vec(&U32::LE(seq_num)).as_slice(), data].concat(), + ); let mut encryped_signature = vec![0; 8]; rc4_handle.process(&signature[0..8], &mut encryped_signature); - to_vec(&message_signature_ex(Some(&encryped_signature), Some(seq_num))) + to_vec(&message_signature_ex( + Some(&encryped_signature), + Some(seq_num), + )) } pub struct Ntlm { @@ -456,7 +543,7 @@ pub struct Ntlm { /// Key use to ciphering messages exported_session_key: Option>, /// True if session use unicode - is_unicode: bool + is_unicode: bool, } impl Ntlm { @@ -479,7 +566,7 @@ impl Ntlm { password, negotiate_message: None, exported_session_key: None, - is_unicode: false + is_unicode: false, } } @@ -500,35 +587,34 @@ impl Ntlm { password: "".to_string(), negotiate_message: None, exported_session_key: None, - is_unicode: false + is_unicode: false, } } } -impl AuthenticationProtocol for Ntlm { +impl AuthenticationProtocol for Ntlm { /// Create Negotiate message for our NTLMv2 implementation /// This message is used to inform server /// about the capabilities of the client fn create_negotiate_message(&mut self) -> RdpResult> { let buffer = to_vec(&negotiate_message( - Negotiate::NtlmsspNegociateKeyExch as u32 | - Negotiate::NtlmsspNegociate128 as u32 | - Negotiate::NtlmsspNegociateExtendedSessionSecurity as u32 | - Negotiate::NtlmsspNegociateAlwaysSign as u32 | - Negotiate::NtlmsspNegociateNTLM as u32 | - Negotiate::NtlmsspNegociateSeal as u32 | - Negotiate::NtlmsspNegociateSign as u32 | - Negotiate::NtlmsspRequestTarget as u32 | - Negotiate::NtlmsspNegociateUnicode as u32 + Negotiate::NtlmsspNegociateKeyExch as u32 + | Negotiate::NtlmsspNegociate128 as u32 + | Negotiate::NtlmsspNegociateExtendedSessionSecurity as u32 + | Negotiate::NtlmsspNegociateAlwaysSign as u32 + | Negotiate::NtlmsspNegociateNTLM as u32 + | Negotiate::NtlmsspNegociateSeal as u32 + | Negotiate::NtlmsspNegociateSign as u32 + | Negotiate::NtlmsspRequestTarget as u32 + | Negotiate::NtlmsspNegociateUnicode as u32, )); self.negotiate_message = Some(buffer.clone()); - return Ok(buffer) + return Ok(buffer); } /// Read the server challenge /// This is the second payload in cssp connection fn read_challenge_message(&mut self, request: &[u8]) -> RdpResult> { - let mut stream = Cursor::new(request); let mut result = challenge_message(); result.read(&mut stream)?; @@ -538,48 +624,79 @@ impl AuthenticationProtocol for Ntlm { let target_name = get_payload_field( &result, cast!(DataType::U16, result["TargetInfoLen"])?, - cast!(DataType::U32, result["TargetInfoBufferOffset"])? + cast!(DataType::U32, result["TargetInfoBufferOffset"])?, )?; - let target_info = read_target_info( - get_payload_field( - &result, - cast!(DataType::U16, result["TargetInfoLen"])?, - cast!(DataType::U32, result["TargetInfoBufferOffset"])? - )? - )?; + let target_info = read_target_info(get_payload_field( + &result, + cast!(DataType::U16, result["TargetInfoLen"])?, + cast!(DataType::U32, result["TargetInfoBufferOffset"])?, + )?)?; let timestamp = if target_info.contains_key(&AvId::MsvAvTimestamp) { target_info[&AvId::MsvAvTimestamp].clone() - } - else { + } else { panic!("no timestamp available") }; // generate client challenge let client_challenge = random(8); - let response = compute_response_v2(&self.response_key_nt, &self.response_key_lm, &server_challenge, &client_challenge, ×tamp, &target_name); + let response = compute_response_v2( + &self.response_key_nt, + &self.response_key_lm, + &server_challenge, + &client_challenge, + ×tamp, + &target_name, + ); let nt_challenge_response = response.0; let lm_challenge_response = response.1; let session_base_key = response.2; - let key_exchange_key = kx_key_v2(&session_base_key, &lm_challenge_response, &server_challenge); + let key_exchange_key = + kx_key_v2(&session_base_key, &lm_challenge_response, &server_challenge); self.exported_session_key = Some(random(16)); - let encrypted_random_session_key = rc4k(&key_exchange_key, self.exported_session_key.as_ref().unwrap()); - - self.is_unicode = cast!(DataType::U32, result["NegotiateFlags"])? & Negotiate::NtlmsspNegociateUnicode as u32 == 1; + let encrypted_random_session_key = rc4k( + &key_exchange_key, + self.exported_session_key.as_ref().unwrap(), + ); + + self.is_unicode = cast!(DataType::U32, result["NegotiateFlags"])? + & Negotiate::NtlmsspNegociateUnicode as u32 + == 1; let domain = self.get_domain_name(); let user = self.get_user_name(); - let auth_message_compute = authenticate_message(&lm_challenge_response, &nt_challenge_response, &domain, &user, b"", &encrypted_random_session_key, cast!(DataType::U32, result["NegotiateFlags"])?); + let auth_message_compute = authenticate_message( + &lm_challenge_response, + &nt_challenge_response, + &domain, + &user, + b"", + &encrypted_random_session_key, + cast!(DataType::U32, result["NegotiateFlags"])?, + ); // need to write a tmp message to compute MIC and then include it into final message - let tmp_final_auth_message = to_vec(&trame![to_vec(&auth_message_compute.0), vec![0; 16], auth_message_compute.1.clone()]); - - let signature = mic(self.exported_session_key.as_ref().unwrap(), self.negotiate_message.as_ref().unwrap(), request, &tmp_final_auth_message); - Ok(to_vec(&trame![auth_message_compute.0, signature, auth_message_compute.1])) + let tmp_final_auth_message = to_vec(&trame![ + to_vec(&auth_message_compute.0), + vec![0; 16], + auth_message_compute.1.clone() + ]); + + let signature = mic( + self.exported_session_key.as_ref().unwrap(), + self.negotiate_message.as_ref().unwrap(), + request, + &tmp_final_auth_message, + ); + Ok(to_vec(&trame![ + auth_message_compute.0, + signature, + auth_message_compute.1 + ])) } /// We are now able to build a security interface @@ -591,14 +708,12 @@ impl AuthenticationProtocol for Ntlm { let client_sealing_key = seal_key(self.exported_session_key.as_ref().unwrap(), true); let server_sealing_key = seal_key(self.exported_session_key.as_ref().unwrap(), false); - Box::new( - NTLMv2SecurityInterface::new( - Rc4::new(&client_sealing_key), - Rc4::new(&server_sealing_key), - client_signing_key, - server_signing_key - ) - ) + Box::new(NTLMv2SecurityInterface::new( + Rc4::new(&client_sealing_key), + Rc4::new(&server_sealing_key), + client_signing_key, + server_signing_key, + )) } /// Retrieve the domain name encoded as expected during negotiate payload @@ -642,7 +757,7 @@ pub struct NTLMv2SecurityInterface { /// Key use message integrity that come from server verify_key: Vec, /// Payload number - seq_num: u32 + seq_num: u32, } impl NTLMv2SecurityInterface { @@ -660,7 +775,7 @@ impl NTLMv2SecurityInterface { decrypt, signing_key, verify_key, - seq_num: 0 + seq_num: 0, } } } @@ -716,10 +831,16 @@ impl GenericSecurityService for NTLMv2SecurityInterface { // compute signature let seq_num = to_vec(&U32::LE(cast!(DataType::U32, signature["SeqNum"])?)); - let computed_checksum = hmac_md5(&self.verify_key, &[seq_num, plaintext_payload.clone()].concat()); + let computed_checksum = hmac_md5( + &self.verify_key, + &[seq_num, plaintext_payload.clone()].concat(), + ); if plaintext_checksum.as_slice() != &(computed_checksum[0..8]) { - return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidChecksum, "Invalid checksum on NTLMv2"))) + return Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidChecksum, + "Invalid checksum on NTLMv2", + ))); } Ok(plaintext_payload) } @@ -734,47 +855,112 @@ mod test { #[test] fn test_ntlmv2_negotiate_message() { let mut buffer = Cursor::new(Vec::new()); - Ntlm::new("".to_string(), "".to_string(), "".to_string()).create_negotiate_message().unwrap().write(&mut buffer).unwrap(); - assert_eq!(buffer.get_ref().as_slice(), [78, 84, 76, 77, 83, 83, 80, 0, 1, 0, 0, 0, 53, 130, 8, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + Ntlm::new("".to_string(), "".to_string(), "".to_string()) + .create_negotiate_message() + .unwrap() + .write(&mut buffer) + .unwrap(); + assert_eq!( + buffer.get_ref().as_slice(), + [ + 78, 84, 76, 77, 83, 83, 80, 0, 1, 0, 0, 0, 53, 130, 8, 96, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0 + ] + ); } /// Test of md4 hash function #[test] fn test_md4() { - assert_eq!(md4(b"foo"), [0x0a, 0xc6, 0x70, 0x0c, 0x49, 0x1d, 0x70, 0xfb, 0x86, 0x50, 0x94, 0x0b, 0x1c, 0xa1, 0xe4, 0xb2]) + assert_eq!( + md4(b"foo"), + [ + 0x0a, 0xc6, 0x70, 0x0c, 0x49, 0x1d, 0x70, 0xfb, 0x86, 0x50, 0x94, 0x0b, 0x1c, 0xa1, + 0xe4, 0xb2 + ] + ) } /// Test of the unicode function #[test] fn test_unicode() { - assert_eq!(unicode(&"foo".to_string()), [0x66, 0x00, 0x6f, 0x00, 0x6f, 0x00]) + assert_eq!( + unicode(&"foo".to_string()), + [0x66, 0x00, 0x6f, 0x00, 0x6f, 0x00] + ) } /// Test HMAC_MD5 function #[test] fn test_hmacmd5() { - assert_eq!(hmac_md5(b"foo", b"bar"), [0x0c, 0x7a, 0x25, 0x02, 0x81, 0x31, 0x5a, 0xb8, 0x63, 0x54, 0x9f, 0x66, 0xcd, 0x8a, 0x3a, 0x53]) + assert_eq!( + hmac_md5(b"foo", b"bar"), + [ + 0x0c, 0x7a, 0x25, 0x02, 0x81, 0x31, 0x5a, 0xb8, 0x63, 0x54, 0x9f, 0x66, 0xcd, 0x8a, + 0x3a, 0x53 + ] + ) } /// Test NTOWFv2 function #[test] fn test_ntowfv2() { - assert_eq!(ntowfv2(&"foo".to_string(), &"user".to_string(), &"domain".to_string()), [0x6e, 0x53, 0xb9, 0x0, 0x97, 0x8c, 0x87, 0x1f, 0x91, 0xde, 0x6, 0x44, 0x9d, 0x8b, 0x8b, 0x81]) + assert_eq!( + ntowfv2( + &"foo".to_string(), + &"user".to_string(), + &"domain".to_string() + ), + [ + 0x6e, 0x53, 0xb9, 0x0, 0x97, 0x8c, 0x87, 0x1f, 0x91, 0xde, 0x6, 0x44, 0x9d, 0x8b, + 0x8b, 0x81 + ] + ) } /// Test LMOWFv2 function #[test] fn test_lmowfv2() { - assert_eq!(lmowfv2(&"foo".to_string(), &"user".to_string(), &"domain".to_string()), ntowfv2(&"foo".to_string(), &"user".to_string(), &"domain".to_string())) + assert_eq!( + lmowfv2( + &"foo".to_string(), + &"user".to_string(), + &"domain".to_string() + ), + ntowfv2( + &"foo".to_string(), + &"user".to_string(), + &"domain".to_string() + ) + ) } /// Test compute response v2 function #[test] fn test_compute_response_v2() { let response = compute_response_v2(b"a", b"b", b"c", b"d", b"e", b"f"); - assert_eq!(response.0, [0xb4, 0x23, 0x84, 0xf, 0x6e, 0x83, 0xc1, 0x5a, 0x45, 0x4f, 0x4c, 0x92, 0x7a, 0xf2, 0xc3, 0x3e, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x65, 0x64, 0x0, 0x0, 0x0, 0x0, 0x66]); - assert_eq!(response.1, [0x56, 0xba, 0xff, 0x2d, 0x98, 0xbe, 0xcd, 0xa5, 0x6d, 0xe6, 0x17, 0x89, 0xe1, 0xed, 0xca, 0xae, 0x64]); - assert_eq!(response.2, [0x40, 0x3b, 0x33, 0xe5, 0x24, 0x34, 0x3c, 0xc3, 0x24, 0xa0, 0x4d, 0x77, 0x75, 0x34, 0xa4, 0xd0]); + assert_eq!( + response.0, + [ + 0xb4, 0x23, 0x84, 0xf, 0x6e, 0x83, 0xc1, 0x5a, 0x45, 0x4f, 0x4c, 0x92, 0x7a, 0xf2, + 0xc3, 0x3e, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x65, 0x64, 0x0, 0x0, 0x0, 0x0, + 0x66 + ] + ); + assert_eq!( + response.1, + [ + 0x56, 0xba, 0xff, 0x2d, 0x98, 0xbe, 0xcd, 0xa5, 0x6d, 0xe6, 0x17, 0x89, 0xe1, 0xed, + 0xca, 0xae, 0x64 + ] + ); + assert_eq!( + response.2, + [ + 0x40, 0x3b, 0x33, 0xe5, 0x24, 0x34, 0x3c, 0xc3, 0x24, 0xa0, 0x4d, 0x77, 0x75, 0x34, + 0xa4, 0xd0 + ] + ); } /// Test of rc4k function @@ -786,32 +972,76 @@ mod test { /// Test of sign_key function #[test] fn test_sign_key() { - assert_eq!(sign_key(b"foo", true), [253, 238, 149, 155, 221, 78, 43, 179, 82, 61, 111, 132, 168, 68, 222, 15]); - assert_eq!(sign_key(b"foo", false), [90, 201, 12, 225, 140, 156, 151, 61, 156, 56, 31, 254, 10, 223, 252, 74]) + assert_eq!( + sign_key(b"foo", true), + [253, 238, 149, 155, 221, 78, 43, 179, 82, 61, 111, 132, 168, 68, 222, 15] + ); + assert_eq!( + sign_key(b"foo", false), + [90, 201, 12, 225, 140, 156, 151, 61, 156, 56, 31, 254, 10, 223, 252, 74] + ) } /// Test of seal_key function #[test] fn test_seal_key() { - assert_eq!(seal_key(b"foo", true), [20, 213, 185, 176, 168, 142, 134, 244, 36, 249, 89, 247, 180, 36, 162, 101]); - assert_eq!(seal_key(b"foo", false), [64, 125, 160, 17, 144, 165, 62, 226, 22, 125, 128, 31, 103, 141, 55, 40]); + assert_eq!( + seal_key(b"foo", true), + [20, 213, 185, 176, 168, 142, 134, 244, 36, 249, 89, 247, 180, 36, 162, 101] + ); + assert_eq!( + seal_key(b"foo", false), + [64, 125, 160, 17, 144, 165, 62, 226, 22, 125, 128, 31, 103, 141, 55, 40] + ); } /// Test signature function #[test] fn test_mac() { - assert_eq!(mac(&mut Rc4::new(b"foo"), b"bar", 0, b"data"), [1, 0, 0, 0, 77, 211, 144, 84, 51, 242, 202, 176, 0, 0, 0, 0]) + assert_eq!( + mac(&mut Rc4::new(b"foo"), b"bar", 0, b"data"), + [1, 0, 0, 0, 77, 211, 144, 84, 51, 242, 202, 176, 0, 0, 0, 0] + ) } /// Test challenge message #[test] fn test_auth_message() { - let result = authenticate_message(b"foo", b"foo", b"domain", b"user", b"workstation", b"foo", 0); + let result = authenticate_message( + b"foo", + b"foo", + b"domain", + b"user", + b"workstation", + b"foo", + 0, + ); let compare_result = [to_vec(&result.0), vec![0; 16], result.1].concat(); - assert_eq!(compare_result[0..32], [78, 84, 76, 77, 83, 83, 80, 0, 3, 0, 0, 0, 3, 0, 3, 0, 80, 0, 0, 0, 3, 0, 3, 0, 83, 0, 0, 0, 6, 0, 6, 0]); - assert_eq!(compare_result[32..64], [86, 0, 0, 0, 4, 0, 4, 0, 92, 0, 0, 0, 11, 0, 11, 0, 96, 0, 0, 0, 3, 0, 3, 0, 107, 0, 0, 0, 0, 0, 0, 0]); - assert_eq!(compare_result[64..96], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 102, 111, 111, 102, 111, 111, 100, 111, 109, 97, 105, 110, 117, 115, 101, 114]); - assert_eq!(compare_result[96..110], [119, 111, 114, 107, 115, 116, 97, 116, 105, 111, 110, 102, 111, 111]); + assert_eq!( + compare_result[0..32], + [ + 78, 84, 76, 77, 83, 83, 80, 0, 3, 0, 0, 0, 3, 0, 3, 0, 80, 0, 0, 0, 3, 0, 3, 0, 83, + 0, 0, 0, 6, 0, 6, 0 + ] + ); + assert_eq!( + compare_result[32..64], + [ + 86, 0, 0, 0, 4, 0, 4, 0, 92, 0, 0, 0, 11, 0, 11, 0, 96, 0, 0, 0, 3, 0, 3, 0, 107, + 0, 0, 0, 0, 0, 0, 0 + ] + ); + assert_eq!( + compare_result[64..96], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 102, 111, 111, 102, 111, 111, 100, + 111, 109, 97, 105, 110, 117, 115, 101, 114 + ] + ); + assert_eq!( + compare_result[96..110], + [119, 111, 114, 107, 115, 116, 97, 116, 105, 111, 110, 102, 111, 111] + ); } #[test] @@ -826,4 +1056,4 @@ mod test { key.process(plaintext2, &mut cipher2); assert_eq!(cipher2, [75, 169, 19]); } -} \ No newline at end of file +} diff --git a/src/nla/rc4.rs b/src/nla/rc4.rs index 1197a2e..97e0310 100644 --- a/src/nla/rc4.rs +++ b/src/nla/rc4.rs @@ -1,19 +1,25 @@ pub struct Rc4 { i: u8, j: u8, - state: [u8; 256] + state: [u8; 256], } impl Rc4 { pub fn new(key: &[u8]) -> Rc4 { assert!(key.len() >= 1 && key.len() <= 256); - let mut rc4 = Rc4 { i: 0, j: 0, state: [0; 256] }; + let mut rc4 = Rc4 { + i: 0, + j: 0, + state: [0; 256], + }; for (i, x) in rc4.state.iter_mut().enumerate() { *x = i as u8; } let mut j: u8 = 0; for i in 0..256 { - j = j.wrapping_add(rc4.state[i]).wrapping_add(key[i % key.len()]); + j = j + .wrapping_add(rc4.state[i]) + .wrapping_add(key[i % key.len()]); rc4.state.swap(i, j as usize); } rc4 @@ -22,7 +28,8 @@ impl Rc4 { self.i = self.i.wrapping_add(1); self.j = self.j.wrapping_add(self.state[self.i as usize]); self.state.swap(self.i as usize, self.j as usize); - let k = self.state[(self.state[self.i as usize].wrapping_add(self.state[self.j as usize])) as usize]; + let k = self.state + [(self.state[self.i as usize].wrapping_add(self.state[self.j as usize])) as usize]; k } @@ -32,4 +39,4 @@ impl Rc4 { *y = *x ^ self.next(); } } -} \ No newline at end of file +} diff --git a/src/nla/sspi.rs b/src/nla/sspi.rs index 4c8f848..7173b61 100644 --- a/src/nla/sspi.rs +++ b/src/nla/sspi.rs @@ -1,4 +1,4 @@ -use model::error::RdpResult; +use crate::model::error::RdpResult; /// This is a trait use by authentication /// protocol to provide a context @@ -36,4 +36,4 @@ pub trait AuthenticationProtocol { /// Get password encoded as expected in the negotiated payload fn get_password(&self) -> Vec; -} \ No newline at end of file +} From d49a122ff082d2950c92c49aaccef9528bbe74a7 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Fri, 21 Oct 2022 18:21:56 +0800 Subject: [PATCH 02/12] Add default feature which use openssl to start ssl connections --- Cargo.toml | 4 +++- src/model/error.rs | 6 ++++++ src/model/link.rs | 24 +++++++++++++++++------- src/nla/cssp.rs | 5 ++--- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 93ea77d..56d456d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,13 +21,15 @@ path = "src/bin/mstsc-rs.rs" required-features = ["mstsc-rs"] [features] +default = ["openssl"] # The reason we do this is because doctests don't get cfg(test) # See: https://github.com/rust-lang/cargo/issues/4669 integration = [] mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc"] +openssl = ["native-tls"] [dependencies] -native-tls = "^0.2" +native-tls = { version = "^0.2", optional = true } byteorder = "^1.3" bufstream = "0.1" indexmap = "^1.3" diff --git a/src/model/error.rs b/src/model/error.rs index 4b96c3b..38adc4a 100644 --- a/src/model/error.rs +++ b/src/model/error.rs @@ -1,7 +1,10 @@ +#[cfg(feature = "openssl")] use native_tls::Error as SslError; +#[cfg(feature = "openssl")] use native_tls::HandshakeError; use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; use std::io::Error as IoError; +#[cfg(feature = "openssl")] use std::io::{Read, Write}; use std::string::String; use yasna::ASN1Error; @@ -94,6 +97,7 @@ pub enum Error { /// SSL handshake error SslHandshakeError, /// SSL error + #[cfg(feature = "openssl")] SslError(SslError), /// ASN1 parser error ASN1Error(ASN1Error), @@ -108,12 +112,14 @@ impl From for Error { } } +#[cfg(feature = "openssl")] impl From> for Error { fn from(_: HandshakeError) -> Error { Error::SslHandshakeError } } +#[cfg(feature = "openssl")] impl From for Error { fn from(e: SslError) -> Error { Error::SslError(e) diff --git a/src/model/link.rs b/src/model/link.rs index ba27fd5..6572acc 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -1,6 +1,7 @@ use crate::model::data::Message; use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; -use native_tls::{Certificate, TlsConnector, TlsStream}; +#[cfg(feature = "openssl")] +use native_tls::{TlsConnector, TlsStream}; use std::io::{Cursor, Read, Write}; /// This a wrapper to work equals @@ -9,6 +10,7 @@ pub enum Stream { /// Raw stream that implement Read + Write Raw(S), /// TLS Stream + #[cfg(feature = "openssl")] Ssl(TlsStream), } @@ -27,6 +29,7 @@ impl Stream { pub fn read_exact(&mut self, buf: &mut [u8]) -> RdpResult<()> { match self { Stream::Raw(e) => e.read_exact(buf)?, + #[cfg(feature = "openssl")] Stream::Ssl(e) => e.read_exact(buf)?, }; Ok(()) @@ -46,6 +49,7 @@ impl Stream { pub fn read(&mut self, buf: &mut [u8]) -> RdpResult { match self { Stream::Raw(e) => Ok(e.read(buf)?), + #[cfg(feature = "openssl")] Stream::Ssl(e) => Ok(e.read(buf)?), } } @@ -69,6 +73,7 @@ impl Stream { pub fn write(&mut self, buffer: &[u8]) -> RdpResult { Ok(match self { Stream::Raw(e) => e.write(buffer)?, + #[cfg(feature = "openssl")] Stream::Ssl(e) => e.write(buffer)?, }) } @@ -77,6 +82,7 @@ impl Stream { /// Only works when stream is a SSL stream pub fn shutdown(&mut self) -> RdpResult<()> { Ok(match self { + #[cfg(feature = "openssl")] Stream::Ssl(e) => e.shutdown()?, _ => (), }) @@ -168,6 +174,7 @@ impl Link { /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap())); /// let link_ssl = link_tcp.start_ssl(false).unwrap(); /// ``` + #[cfg(feature = "openssl")] pub fn start_ssl(self, check_certificate: bool) -> RdpResult> { let mut builder = TlsConnector::builder(); builder.danger_accept_invalid_certs(!check_certificate); @@ -196,14 +203,17 @@ impl Link { /// let link_ssl = link_tcp.start_ssl(false).unwrap(); /// let certificate = link_ssl.get_peer_certificate().unwrap().unwrap(); /// ``` - pub fn get_peer_certificate(&self) -> RdpResult> { - if let Stream::Ssl(stream) = &self.stream { - Ok(stream.peer_certificate()?) - } else { - Err(Error::RdpError(RdpError::new( + pub fn get_peer_certificate_der(&self) -> RdpResult>> { + match &self.stream { + #[cfg(feature = "openssl")] + Stream::Ssl(stream) => Ok(match stream.peer_certificate()? { + Some(cert) => Some(cert.to_der()?), + None => None, + }), + _ => Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidData, "get peer certificate on non ssl link is impossible", - ))) + ))), } } diff --git a/src/nla/cssp.rs b/src/nla/cssp.rs index 309ef64..7b5c44d 100644 --- a/src/nla/cssp.rs +++ b/src/nla/cssp.rs @@ -185,10 +185,9 @@ pub fn cssp_connect( // Get the peer public certificate let certificate_der = try_option!( - link.get_peer_certificate()?, + link.get_peer_certificate_der()?, "No public certificate available" - )? - .to_der()?; + )?; let certificate = read_public_certificate(&certificate_der)?; // Now we can send back our challenge payload wit the public key encoded From 0c30c03acac261fddb45838a5074c3f7143cd4e8 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Fri, 21 Oct 2022 18:21:56 +0800 Subject: [PATCH 03/12] Allow bio r/w --- src/core/client.rs | 15 +++++++++++++++ src/model/link.rs | 28 ++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/core/client.rs b/src/core/client.rs index 72306dc..339b48e 100644 --- a/src/core/client.rs +++ b/src/core/client.rs @@ -7,6 +7,8 @@ use crate::core::sec; use crate::core::tpkt; use crate::core::x224; use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; +#[cfg(not(feature = "openssl"))] +use crate::model::link::SecureBio; use crate::model::link::{Link, Stream}; use crate::nla::ntlm::Ntlm; use std::io::{Read, Write}; @@ -232,10 +234,23 @@ impl Connector { /// .credentials("domain".to_string(), "username".to_string(), "password".to_string()); /// let mut client = connector.connect(tcp).unwrap(); /// ``` + #[cfg(feature = "openssl")] pub fn connect(&mut self, stream: S) -> RdpResult> { // Create a wrapper around the stream let tcp = Link::new(Stream::Raw(stream)); + self.connect_further(tcp) + } + #[cfg(not(feature = "openssl"))] + pub fn connect + 'static>( + &mut self, + stream: Box, + ) -> RdpResult> { + // Create a wrapper around the stream + let tcp = Link::new(Stream::Bio(stream)); + self.connect_further(tcp) + } + fn connect_further(&self, tcp: Link) -> RdpResult> { // Compute authentication method let mut authentication = if let Some(hash) = &self.password_hash { Ntlm::from_hash(self.domain.clone(), self.username.clone(), hash) diff --git a/src/model/link.rs b/src/model/link.rs index 6572acc..5d7b43a 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -4,6 +4,13 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use native_tls::{TlsConnector, TlsStream}; use std::io::{Cursor, Read, Write}; +#[cfg(not(feature = "openssl"))] +pub trait SecureBio: Read + Write { + fn start_ssl(&self, check_certificate: bool) -> RdpResult>; + fn get_peer_certificate_der(&self) -> RdpResult>>; + fn shutdown(&mut self) -> std::io::Result<()>; +} + /// This a wrapper to work equals /// for a stream and a TLS stream pub enum Stream { @@ -12,6 +19,8 @@ pub enum Stream { /// TLS Stream #[cfg(feature = "openssl")] Ssl(TlsStream), + #[cfg(not(feature = "openssl"))] + Bio(Box>), } impl Stream { @@ -31,6 +40,8 @@ impl Stream { Stream::Raw(e) => e.read_exact(buf)?, #[cfg(feature = "openssl")] Stream::Ssl(e) => e.read_exact(buf)?, + #[cfg(not(feature = "openssl"))] + Stream::Bio(bio) => bio.read_exact(buf)?, }; Ok(()) } @@ -51,6 +62,8 @@ impl Stream { Stream::Raw(e) => Ok(e.read(buf)?), #[cfg(feature = "openssl")] Stream::Ssl(e) => Ok(e.read(buf)?), + #[cfg(not(feature = "openssl"))] + Stream::Bio(e) => Ok(e.read(buf)?), } } @@ -75,6 +88,8 @@ impl Stream { Stream::Raw(e) => e.write(buffer)?, #[cfg(feature = "openssl")] Stream::Ssl(e) => e.write(buffer)?, + #[cfg(not(feature = "openssl"))] + Stream::Bio(e) => e.write(buffer)?, }) } @@ -84,6 +99,8 @@ impl Stream { Ok(match self { #[cfg(feature = "openssl")] Stream::Ssl(e) => e.shutdown()?, + #[cfg(not(feature = "openssl"))] + Stream::Bio(e) => e.shutdown()?, _ => (), }) } @@ -191,6 +208,17 @@ impl Link { ))) } + #[cfg(not(feature = "openssl"))] + pub fn start_ssl(self, check_certificate: bool) -> RdpResult> { + if let Stream::Bio(ref stream) = self.stream { + stream.start_ssl(check_certificate)?; + return Ok(self); + } + Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + "start_ssl on ssl stream is forbidden", + ))) + } /// Retrive the peer certificate /// Use by the NLA authentication protocol /// to avoid MITM attack From 43fe7fcb3518b5991087d4e259cb088eaae20925 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Sat, 22 Oct 2022 21:51:05 +0800 Subject: [PATCH 04/12] use a wrapper on bio_stream --- src/model/link.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/model/link.rs b/src/model/link.rs index 5d7b43a..ef53930 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -5,10 +5,14 @@ use native_tls::{TlsConnector, TlsStream}; use std::io::{Cursor, Read, Write}; #[cfg(not(feature = "openssl"))] -pub trait SecureBio: Read + Write { - fn start_ssl(&self, check_certificate: bool) -> RdpResult>; +pub trait SecureBio +where + S: Read + Write, +{ + fn start_ssl(&mut self, check_certificate: bool) -> RdpResult<()>; fn get_peer_certificate_der(&self) -> RdpResult>>; fn shutdown(&mut self) -> std::io::Result<()>; + fn get_io(&mut self) -> &mut S; } /// This a wrapper to work equals @@ -41,7 +45,7 @@ impl Stream { #[cfg(feature = "openssl")] Stream::Ssl(e) => e.read_exact(buf)?, #[cfg(not(feature = "openssl"))] - Stream::Bio(bio) => bio.read_exact(buf)?, + Stream::Bio(bio) => bio.get_io().read_exact(buf)?, }; Ok(()) } @@ -63,7 +67,7 @@ impl Stream { #[cfg(feature = "openssl")] Stream::Ssl(e) => Ok(e.read(buf)?), #[cfg(not(feature = "openssl"))] - Stream::Bio(e) => Ok(e.read(buf)?), + Stream::Bio(e) => Ok(e.get_io().read(buf)?), } } @@ -89,7 +93,7 @@ impl Stream { #[cfg(feature = "openssl")] Stream::Ssl(e) => e.write(buffer)?, #[cfg(not(feature = "openssl"))] - Stream::Bio(e) => e.write(buffer)?, + Stream::Bio(e) => e.get_io().write(buffer)?, }) } @@ -210,9 +214,9 @@ impl Link { #[cfg(not(feature = "openssl"))] pub fn start_ssl(self, check_certificate: bool) -> RdpResult> { - if let Stream::Bio(ref stream) = self.stream { + if let Stream::Bio(mut stream) = self.stream { stream.start_ssl(check_certificate)?; - return Ok(self); + return Ok(Link::new(Stream::Bio(stream))); } Err(Error::RdpError(RdpError::new( RdpErrorKind::NotImplemented, From 46971d29cab6a600811ecc25413536c7bf1da9b0 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Fri, 21 Oct 2022 23:41:29 +0800 Subject: [PATCH 05/12] swith to async-io --- Cargo.toml | 5 +- src/codec/rle.rs | 14 +-- src/core/capability.rs | 14 +-- src/core/client.rs | 67 ++++++++------ src/core/event.rs | 6 +- src/core/gcc.rs | 50 +++++------ src/core/global.rs | 198 ++++++++++++++++++++++------------------- src/core/license.rs | 2 +- src/core/mcs.rs | 104 ++++++++++++---------- src/core/per.rs | 14 +-- src/core/sec.rs | 61 ++++++------- src/core/tpkt.rs | 58 ++++++------ src/core/x224.rs | 41 +++++---- src/model/data.rs | 15 ++-- src/model/error.rs | 14 +-- src/model/link.rs | 70 ++++++++------- src/model/unicode.rs | 2 +- src/nla/asn1.rs | 6 +- src/nla/cssp.rs | 15 ++-- src/nla/ntlm.rs | 61 +++++-------- src/nla/rc4.rs | 7 +- 21 files changed, 424 insertions(+), 400 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 56d456d..3e24dd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,10 +26,11 @@ default = ["openssl"] # See: https://github.com/rust-lang/cargo/issues/4669 integration = [] mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc"] -openssl = ["native-tls"] +openssl = ["async-native-tls"] [dependencies] -native-tls = { version = "^0.2", optional = true } +tokio = { version = "1.21.2", features = ["full"] } +async-native-tls = { version = "^0.4", optional = true, default-features = false, features = ["runtime-tokio"] } byteorder = "^1.3" bufstream = "0.1" indexmap = "^1.3" diff --git a/src/codec/rle.rs b/src/codec/rle.rs index b44e908..061b400 100644 --- a/src/codec/rle.rs +++ b/src/codec/rle.rs @@ -36,7 +36,7 @@ fn process_plane( replen = code & 0xf; collen = (code >> 4) & 0xf; revcode = (replen << 4) | collen; - if (revcode <= 47) && (revcode >= 16) { + if (16..=47).contains(&revcode) { replen = revcode; collen = 0; } @@ -60,18 +60,18 @@ fn process_plane( replen = code & 0xf; collen = (code >> 4) & 0xf; revcode = (replen << 4) | collen; - if (revcode <= 47) && (revcode >= 16) { + if (16..=47).contains(&revcode) { replen = revcode; collen = 0; } while collen > 0 { x = input.read_u8()?; if x & 1 != 0 { - x = x >> 1; - x = x + 1; + x >>= 1; + x += 1; color = -(x as i32) as i8; } else { - x = x >> 1; + x >>= 1; color = x as i8; } x = (output[(last_line + (indexw * 4)) as usize] as i32 + color as i32) as u8; @@ -257,7 +257,7 @@ pub fn rle_16_decompress( while count > 0 { if x >= width { - if height <= 0 { + if height == 0 { return Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidData, "error during decompress", @@ -390,7 +390,7 @@ pub fn rle_16_decompress( } pub fn rgb565torgb32(input: &[u16], width: usize, height: usize) -> Vec { - let mut result_32_bpp = vec![0 as u8; width as usize * height as usize * 4]; + let mut result_32_bpp = vec![0_u8; width as usize * height as usize * 4]; for i in 0..height { for j in 0..width { let index = (i * width + j) as usize; diff --git a/src/core/capability.rs b/src/core/capability.rs index 0a166ab..8208734 100644 --- a/src/core/capability.rs +++ b/src/core/capability.rs @@ -203,8 +203,8 @@ pub fn ts_general_capability_set(extra_flags: Option) -> Capability { "updateCapabilityFlag" => Check::new(U16::LE(0)), "remoteUnshareFlag" => Check::new(U16::LE(0)), "generalCompressionLevel" => Check::new(U16::LE(0)), - "refreshRectSupport" => 0 as u8, - "suppressOutputSupport" => 0 as u8 + "refreshRectSupport" => 0_u8, + "suppressOutputSupport" => 0_u8 ], } } @@ -239,8 +239,8 @@ pub fn ts_bitmap_capability_set( "pad2octets" => U16::LE(0), "desktopResizeFlag" => U16::LE(0), "bitmapCompressionFlag" => Check::new(U16::LE(0x0001)), - "highColorFlags" => Check::new(0 as u8), - "drawingFlags" => 0 as u8, + "highColorFlags" => Check::new(0_u8), + "drawingFlags" => 0_u8, "multipleRectangleSupport" => Check::new(U16::LE(0x0001)), "pad2octetsB" => U16::LE(0) ], @@ -273,7 +273,7 @@ pub fn ts_order_capability_set(order_flags: Option) -> Capability { Capability { cap_type: CapabilitySetType::CapstypeOrder, message: component![ - "terminalDescriptor" => vec![0 as u8; 16], + "terminalDescriptor" => vec![0_u8; 16], "pad4octetsA" => U32::LE(0), "desktopSaveXGranularity" => U16::LE(1), "desktopSaveYGranularity" => U16::LE(20), @@ -281,7 +281,7 @@ pub fn ts_order_capability_set(order_flags: Option) -> Capability { "maximumOrderLevel" => U16::LE(1), "numberFonts" => U16::LE(0), "orderFlags" => U16::LE(order_flags.unwrap_or(OrderFlag::NEGOTIATEORDERSUPPORT as u16)), - "orderSupport" => vec![0 as u8; 32], + "orderSupport" => vec![0_u8; 32], "textFlags" => U16::LE(0), "orderSupportExFlags" => U16::LE(0), "pad4octetsB" => U32::LE(0), @@ -398,7 +398,7 @@ pub fn ts_input_capability_set( "keyboardType" => U32::LE(KeyboardType::Ibm101102Keys as u32), "keyboardSubType" => U32::LE(0), "keyboardFunctionKey" => U32::LE(12), - "imeFileName" => vec![0 as u8; 64] + "imeFileName" => vec![0_u8; 64] ], } } diff --git a/src/core/client.rs b/src/core/client.rs index 339b48e..815602d 100644 --- a/src/core/client.rs +++ b/src/core/client.rs @@ -11,7 +11,7 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use crate::model::link::SecureBio; use crate::model::link::{Link, Stream}; use crate::nla::ntlm::Ntlm; -use std::io::{Read, Write}; +use tokio::io::*; impl From<&str> for KeyboardLayout { fn from(e: &str) -> Self { @@ -31,7 +31,7 @@ pub struct RdpClient { global: global::Client, } -impl RdpClient { +impl RdpClient { /// Read a payload from the server /// RDpClient use a callback pattern that can be called more than once /// during a read call @@ -56,13 +56,13 @@ impl RdpClient { /// } /// }).unwrap() /// ``` - pub fn read(&mut self, callback: T) -> RdpResult<()> + pub async fn read(&mut self, callback: T) -> RdpResult<()> where T: FnMut(RdpEvent), { - let (channel_name, message) = self.mcs.read()?; + let (channel_name, message) = self.mcs.read().await?; match channel_name.as_str() { - "global" => self.global.read(message, &mut self.mcs, callback), + "global" => self.global.read(message, &mut self.mcs, callback).await, _ => Err(Error::RdpError(RdpError::new( RdpErrorKind::UnexpectedType, &format!("Invalid channel name {:?}", channel_name), @@ -94,7 +94,7 @@ impl RdpClient { /// } /// )).unwrap() /// ``` - pub fn write(&mut self, event: RdpEvent) -> RdpResult<()> { + pub async fn write(&mut self, event: RdpEvent) -> RdpResult<()> { match event { // Pointer event // Mouse position an d button position @@ -113,10 +113,12 @@ impl RdpClient { flags |= PointerFlag::PtrflagsDown as u16; } - self.global.write_input_event( - ts_pointer_event(Some(flags), Some(pointer.x), Some(pointer.y)), - &mut self.mcs, - ) + self.global + .write_input_event( + ts_pointer_event(Some(flags), Some(pointer.x), Some(pointer.y)), + &mut self.mcs, + ) + .await } // Raw keyboard input RdpEvent::Key(key) => { @@ -124,10 +126,12 @@ impl RdpClient { if !key.down { flags |= KeyboardFlag::KbdflagsRelease as u16; } - self.global.write_input_event( - ts_keyboard_event(Some(flags), Some(key.code)), - &mut self.mcs, - ) + self.global + .write_input_event( + ts_keyboard_event(Some(flags), Some(key.code)), + &mut self.mcs, + ) + .await } _ => Err(Error::RdpError(RdpError::new( RdpErrorKind::UnexpectedType, @@ -140,8 +144,8 @@ impl RdpClient { /// once the global channel is not connected /// This will disable InvalidAutomata error in case /// if you sent input before end of the sync process - pub fn try_write(&mut self, event: RdpEvent) -> RdpResult<()> { - let result = self.write(event); + pub async fn try_write(&mut self, event: RdpEvent) -> RdpResult<()> { + let result = self.write(event).await; match result { Err(Error::RdpError(e)) => match e.kind() { RdpErrorKind::InvalidAutomata => Ok(()), @@ -152,11 +156,12 @@ impl RdpClient { } /// Close client is indeed close the switch layer - pub fn shutdown(&mut self) -> RdpResult<()> { - self.mcs.shutdown() + pub async fn shutdown(&mut self) -> RdpResult<()> { + self.mcs.shutdown().await } } +#[derive(Default)] pub struct Connector { /// Screen width width: u16, @@ -235,13 +240,16 @@ impl Connector { /// let mut client = connector.connect(tcp).unwrap(); /// ``` #[cfg(feature = "openssl")] - pub fn connect(&mut self, stream: S) -> RdpResult> { + pub async fn connect( + &mut self, + stream: S, + ) -> RdpResult> { // Create a wrapper around the stream let tcp = Link::new(Stream::Raw(stream)); - self.connect_further(tcp) + self.connect_further(tcp).await } #[cfg(not(feature = "openssl"))] - pub fn connect + 'static>( + pub fn connect + 'static>( &mut self, stream: Box, ) -> RdpResult> { @@ -250,7 +258,10 @@ impl Connector { self.connect_further(tcp) } - fn connect_further(&self, tcp: Link) -> RdpResult> { + async fn connect_further( + &self, + tcp: Link, + ) -> RdpResult> { // Compute authentication method let mut authentication = if let Some(hash) = &self.password_hash { Ntlm::from_hash(self.domain.clone(), self.username.clone(), hash) @@ -275,11 +286,13 @@ impl Connector { Some(&mut authentication), self.restricted_admin_mode, self.blank_creds, - )?; + ) + .await?; // Create MCS layer and connect it let mut mcs = mcs::Client::new(x224); - mcs.connect(self.name.clone(), self.width, self.height, self.layout)?; + mcs.connect(self.name.clone(), self.width, self.height, self.layout) + .await?; // state less connection for old secure layer if self.restricted_admin_mode { sec::connect( @@ -288,7 +301,8 @@ impl Connector { &"".to_string(), &"".to_string(), self.auto_logon, - )?; + ) + .await?; } else { sec::connect( &mut mcs, @@ -296,7 +310,8 @@ impl Connector { &self.username, &self.password, self.auto_logon, - )?; + ) + .await?; } // Now the global channel diff --git a/src/core/event.rs b/src/core/event.rs index 1e9a0da..7560033 100644 --- a/src/core/event.rs +++ b/src/core/event.rs @@ -65,7 +65,7 @@ impl BitmapEvent { 32 => { // 32 bpp is straight forward Ok(if self.is_compress { - let mut result = vec![0 as u8; self.width as usize * self.height as usize * 4]; + let mut result = vec![0_u8; self.width as usize * self.height as usize * 4]; rle_32_decompress( &self.data, self.width as u32, @@ -80,7 +80,7 @@ impl BitmapEvent { 16 => { // 16 bpp is more consumer let result_16bpp = if self.is_compress { - let mut result = vec![0 as u16; self.width as usize * self.height as usize * 2]; + let mut result = vec![0_u16; self.width as usize * self.height as usize * 2]; rle_16_decompress( &self.data, self.width as usize, @@ -89,7 +89,7 @@ impl BitmapEvent { )?; result } else { - let mut result = vec![0 as u16; self.width as usize * self.height as usize]; + let mut result = vec![0_u16; self.width as usize * self.height as usize]; for i in 0..self.height { for j in 0..self.width { let src = (((self.height - i - 1) * self.width + j) * 2) as usize; diff --git a/src/core/gcc.rs b/src/core/gcc.rs index 309b363..5c73ccc 100644 --- a/src/core/gcc.rs +++ b/src/core/gcc.rs @@ -51,7 +51,7 @@ enum Sequence { /// Keyboard layout /// https://docs.microsoft.com/en-us/previous-versions/windows/it-pro/windows-vista/cc766503(v=ws.10)?redirectedfrom=MSDN #[repr(u32)] -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Default)] pub enum KeyboardLayout { Arabic = 0x00000401, Bulgarian = 0x00000402, @@ -60,6 +60,7 @@ pub enum KeyboardLayout { Danish = 0x00000406, German = 0x00000407, Greek = 0x00000408, + #[default] US = 0x00000409, Spanish = 0x0000040a, Finnish = 0x0000040b, @@ -90,23 +91,23 @@ pub enum KeyboardType { #[repr(u16)] #[allow(dead_code)] -enum HighColor { - HighColor4BPP = 0x0004, - HighColor8BPP = 0x0008, - HighColor15BPP = 0x000f, - HighColor16BPP = 0x0010, - HighColor24BPP = 0x0018, +enum HighColorBpp { + Four = 0x0004, + Eight = 0x0008, + Fifteen = 0x000f, + Sixteen = 0x0010, + TwentyFour = 0x0018, } /// Supported color depth /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/00f1da4a-ee9c-421a-852f-c19f92343d73?redirectedfrom=MSDN #[repr(u16)] #[allow(dead_code)] -enum Support { - RnsUd24BPPSupport = 0x0001, - RnsUd16BPPSupport = 0x0002, - RnsUd15BPPSupport = 0x0004, - RnsUd32BPPSupport = 0x0008, +enum RnsUdBppSupport { + TwentyFour = 0x0001, + Sixteen = 0x0002, + Fifteen = 0x0004, + ThirtyTwo = 0x0008, } /// Negotiation of some capability for pdu layer @@ -209,7 +210,7 @@ pub fn client_core_data(parameter: Option) -> Component { }); let client_name = if client_parameter.name.len() >= 16 { - (&client_parameter.name[0..16]).to_string() + client_parameter.name[0..16].to_string() } else { client_parameter.name.clone() + &"\x00".repeat(16 - client_parameter.name.len()) }; @@ -222,25 +223,25 @@ pub fn client_core_data(parameter: Option) -> Component { "sasSequence" => U16::LE(Sequence::RnsUdSasDel as u16), "kbdLayout" => U32::LE(client_parameter.layout as u32), "clientBuild" => U32::LE(3790), - "clientName" => client_name.to_string().to_unicode(), + "clientName" => client_name.to_unicode(), "keyboardType" => U32::LE(KeyboardType::Ibm101102Keys as u32), "keyboardSubType" => U32::LE(0), "keyboardFnKeys" => U32::LE(12), - "imeFileName" => vec![0 as u8; 64], + "imeFileName" => vec![0_u8; 64], "postBeta2ColorDepth" => U16::LE(ColorDepth::RnsUdColor8BPP as u16), "clientProductId" => U16::LE(1), "serialNumber" => U32::LE(0), - "highColorDepth" => U16::LE(HighColor::HighColor24BPP as u16), + "highColorDepth" => U16::LE(HighColorBpp::TwentyFour as u16), "supportedColorDepths" => U16::LE( //Support::RnsUd15BPPSupport as u16 | - Support::RnsUd16BPPSupport as u16 | + RnsUdBppSupport::Sixteen as u16 | //Support::RnsUd24BPPSupport as u16 | - Support::RnsUd32BPPSupport as u16 + RnsUdBppSupport::ThirtyTwo as u16 ), "earlyCapabilityFlags" => U16::LE(CapabilityFlag::RnsUdCsSupportErrinfoPDU as u16), "clientDigProductId" => vec![0; 64], - "connectionType" => 0 as u8, - "pad1octet" => 0 as u8, + "connectionType" => 0_u8, + "pad1octet" => 0_u8, "serverSelectedProtocol" => U32::LE(client_parameter.server_selected_protocol) ] } @@ -350,11 +351,8 @@ pub fn read_conference_create_response(cc_response: &mut dyn Read) -> RdpResult< break; } - let mut buffer = vec![ - 0 as u8; - (cast!(DataType::U16, header["length"])? - header.length() as u16) - as usize - ]; + let mut buffer = + vec![0_u8; (cast!(DataType::U16, header["length"])? - header.length() as u16) as usize]; sub.read_exact(&mut buffer)?; match MessageType::from(cast!(DataType::U16, header["type"])?) { @@ -386,7 +384,7 @@ pub fn read_conference_create_response(cc_response: &mut dyn Read) -> RdpResult< DataType::Trame, result[&MessageType::ScNet]["channelIdArray"] )? - .into_iter() + .iter() .map(|x| cast!(DataType::U16, x).unwrap()) .collect(), rdp_version: Version::from(cast!( diff --git a/src/core/global.rs b/src/core/global.rs index 023c0dc..2126b99 100644 --- a/src/core/global.rs +++ b/src/core/global.rs @@ -10,43 +10,44 @@ use crate::model::data::{ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use num_enum::TryFromPrimitive; use std::convert::TryFrom; -use std::io::{Cursor, Read, Write}; +use std::io::{Cursor, Read}; +use tokio::io::*; /// Raw PDU type use by the protocol #[repr(u16)] #[derive(Copy, Clone, Eq, PartialEq, Debug, TryFromPrimitive)] -enum PDUType { - PdutypeDemandactivepdu = 0x11, - PdutypeConfirmactivepdu = 0x13, - PdutypeDeactivateallpdu = 0x16, - PdutypeDatapdu = 0x17, - PdutypeServerRedirPkt = 0x1A, +enum PduType { + Demandactivepdu = 0x11, + Confirmactivepdu = 0x13, + Deactivateallpdu = 0x16, + Datapdu = 0x17, + ServerRedirPkt = 0x1A, } /// PDU type available /// Most of them are used for initial handshake /// Then once connected only Data are send and received -struct PDU { - pub pdu_type: PDUType, +struct Pdu { + pub pdu_type: PduType, pub message: Component, } -impl PDU { +impl Pdu { /// Build a PDU structure from reading stream pub fn from_stream(stream: &mut dyn Read) -> RdpResult { let mut header = share_control_header(None, None, None); header.read(stream)?; - PDU::from_control(&header) + Pdu::from_control(&header) } /// Build a PDU data directly from a control message pub fn from_control(control: &Component) -> RdpResult { let pdu_type = cast!(DataType::U16, control["pduType"])?; - let mut pdu = match PDUType::try_from(pdu_type)? { - PDUType::PdutypeDemandactivepdu => ts_demand_active_pdu(), - PDUType::PdutypeDatapdu => share_data_header(None, None, None), - PDUType::PdutypeConfirmactivepdu => ts_confirm_active_pdu(None, None, None), - PDUType::PdutypeDeactivateallpdu => ts_deactivate_all_pdu(), + let mut pdu = match PduType::try_from(pdu_type)? { + PduType::Demandactivepdu => ts_demand_active_pdu(), + PduType::Datapdu => share_data_header(None, None, None), + PduType::Confirmactivepdu => ts_confirm_active_pdu(None, None, None), + PduType::Deactivateallpdu => ts_deactivate_all_pdu(), _ => { return Err(Error::RdpError(RdpError::new( RdpErrorKind::NotImplemented, @@ -68,9 +69,9 @@ impl PDU { /// of the target server /// /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/bd612af5-cb54-43a2-9646-438bc3ecf5db -fn ts_demand_active_pdu() -> PDU { - PDU { - pdu_type: PDUType::PdutypeDemandactivepdu, +fn ts_demand_active_pdu() -> Pdu { + Pdu { + pdu_type: PduType::Demandactivepdu, message: component![ "shareId" => U32::LE(0), "lengthSourceDescriptor" => DynOption::new(U16::LE(0), |length| MessageOption::Size("sourceDescriptor".to_string(), length.inner() as usize)), @@ -92,11 +93,12 @@ fn ts_confirm_active_pdu( share_id: Option, source: Option>, capabilities_set: Option>, -) -> PDU { - let default_capabilities_set = capabilities_set.unwrap_or(Array::new(|| capability_set(None))); - let default_source = source.unwrap_or(vec![]); - PDU { - pdu_type: PDUType::PdutypeConfirmactivepdu, +) -> Pdu { + let default_capabilities_set = + capabilities_set.unwrap_or_else(|| Array::new(|| capability_set(None))); + let default_source = source.unwrap_or_default(); + Pdu { + pdu_type: PduType::Confirmactivepdu, message: component![ "shareId" => U32::LE(share_id.unwrap_or(0)), "originatorId" => Check::new(U16::LE(0x03EA)), @@ -113,9 +115,9 @@ fn ts_confirm_active_pdu( /// Use to inform user that a session already exist /// /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/fc191c40-e688-4d5a-a550-6609cd5b8b59 -fn ts_deactivate_all_pdu() -> PDU { - PDU { - pdu_type: PDUType::PdutypeDeactivateallpdu, +fn ts_deactivate_all_pdu() -> Pdu { + Pdu { + pdu_type: PduType::Deactivateallpdu, message: component![ "shareId" => U32::LE(0), "lengthSourceDescriptor" => DynOption::new(U16::LE(0), |length| MessageOption::Size("sourceDescriptor".to_string(), length.inner() as usize)), @@ -129,17 +131,17 @@ fn share_data_header( share_id: Option, pdu_type_2: Option, message: Option>, -) -> PDU { - let default_message = message.unwrap_or(vec![]); - PDU { - pdu_type: PDUType::PdutypeDatapdu, +) -> Pdu { + let default_message = message.unwrap_or_default(); + Pdu { + pdu_type: PduType::Datapdu, message: component![ "shareId" => U32::LE(share_id.unwrap_or(0)), - "pad1" => 0 as u8, - "streamId" => 1 as u8, + "pad1" => 0_u8, + "streamId" => 1_u8, "uncompressedLength" => DynOption::new(U16::LE(default_message.length() as u16 + 18), | size | MessageOption::Size("payload".to_string(), size.inner() as usize - 18)), "pduType2" => pdu_type_2.unwrap_or(PDUType2::Pdutype2ArcStatusPdu) as u8, - "compressedType" => 0 as u8, + "compressedType" => 0_u8, "compressedLength" => U16::LE(0), "payload" => default_message ], @@ -149,14 +151,14 @@ fn share_data_header( /// This is the main PDU payload format /// It use the share control header to dispatch between all PDU fn share_control_header( - pdu_type: Option, + pdu_type: Option, pdu_source: Option, message: Option>, ) -> Component { - let default_message = message.unwrap_or(vec![]); + let default_message = message.unwrap_or_default(); component![ "totalLength" => DynOption::new(U16::LE(default_message.length() as u16 + 6), |total| MessageOption::Size("pduMessage".to_string(), total.inner() as usize - 6)), - "pduType" => U16::LE(pdu_type.unwrap_or(PDUType::PdutypeDemandactivepdu) as u16), + "pduType" => U16::LE(pdu_type.unwrap_or(PduType::Demandactivepdu) as u16), "PDUSource" => Some(U16::LE(pdu_source.unwrap_or(0))), "pduMessage" => default_message ] @@ -202,7 +204,7 @@ impl DataPDU { /// Build a DATA PDU from a PDU container /// User must check that the PDU is a DATA PDU /// If not this function will panic - pub fn from_pdu(data_pdu: &PDU) -> RdpResult { + pub fn from_pdu(data_pdu: &Pdu) -> RdpResult { let pdu_type = PDUType2::try_from(cast!(DataType::U8, data_pdu.message["pduType2"])?)?; let mut result = match pdu_type { PDUType2::Pdutype2Synchronize => ts_synchronize_pdu(None), @@ -267,21 +269,21 @@ fn ts_set_error_info_pdu() -> DataPDU { #[repr(u16)] #[allow(dead_code)] -enum Action { - CtrlactionRequestControl = 0x0001, - CtrlactionGrantedControl = 0x0002, - CtrlactionDetach = 0x0003, - CtrlactionCooperate = 0x0004, +enum CtrlAction { + RequestControl = 0x0001, + GrantedControl = 0x0002, + Detach = 0x0003, + Cooperate = 0x0004, } /// Control payload send during pdu handshake /// /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/0448f397-aa11-455d-81b1-f1265085239d -fn ts_control_pdu(action: Option) -> DataPDU { +fn ts_control_pdu(action: Option) -> DataPDU { DataPDU { pdu_type: PDUType2::Pdutype2Control, message: component![ - "action" => U16::LE(action.unwrap_or(Action::CtrlactionCooperate) as u16), + "action" => U16::LE(action.unwrap_or(CtrlAction::Cooperate) as u16), "grantId" => U16::LE(0), "controlId" => U32::LE(0) ], @@ -305,7 +307,7 @@ fn ts_font_map_pdu() -> DataPDU { /// Send input event as slow path fn ts_input_pdu_data(events: Option>) -> DataPDU { - let default_events = events.unwrap_or(Array::new(|| ts_input_event(None, None))); + let default_events = events.unwrap_or_else(|| Array::new(|| ts_input_event(None, None))); DataPDU { pdu_type: PDUType2::Pdutype2Input, message: component![ @@ -321,7 +323,7 @@ fn ts_input_event(message_type: Option, data: Option>) - component![ "eventTime" => U32::LE(0), "messageType" => U16::LE(message_type.unwrap_or(InputEventType::InputEventMouse) as u16), - "slowPathInputData" => data.unwrap_or(vec![]) + "slowPathInputData" => data.unwrap_or_default() ] } @@ -399,15 +401,15 @@ pub fn ts_keyboard_event(flags: Option, key_code: Option) -> TSInputEv /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/a1c4caa8-00ed-45bb-a06e-5177473766d3 fn ts_fp_update() -> Component { component![ - "updateHeader" => DynOption::new(0 as u8, |header| { - if (header >> 4) & 0x2 as u8 == 0 as u8 { + "updateHeader" => DynOption::new(0_u8, |header| { + if (header >> 4) & 0x2_u8 == 0_u8 { MessageOption::SkipField("compressionFlags".to_string()) } else { MessageOption::None } }), - "compressionFlags" => 0 as u8, + "compressionFlags" => 0_u8, "size" => DynOption::new(U16::LE(0), | size | MessageOption::Size("updateData".to_string(), size.inner() as usize)), "updateData" => Vec::::new() ] @@ -508,7 +510,7 @@ fn ts_fp_update_bitmap() -> FastPathUpdate { message: component![ "header" => Check::new(U16::LE(FastPathUpdateType::FastpathUpdatetypeBitmap as u16)), "numberRectangles" => U16::LE(0), - "rectangles" => Array::new(|| ts_bitmap_data()) + "rectangles" => Array::new(ts_bitmap_data) ], } } @@ -528,7 +530,7 @@ fn ts_colorpointerattribute() -> FastPathUpdate { "lengthXorMask" => DynOption::new(U16::LE(0), |length| MessageOption::Size("xorMaskData".to_string(), length.inner() as usize)), "xorMaskData" => Vec::::new(), "andMaskData" => Vec::::new(), - "pad" => Some(0 as u8) + "pad" => Some(0_u8) ], } } @@ -637,8 +639,8 @@ impl Client { /// /// This function return true if it read the expected PDU fn read_demand_active_pdu(&mut self, stream: &mut dyn Read) -> RdpResult { - let pdu = PDU::from_stream(stream)?; - if pdu.pdu_type == PDUType::PdutypeDemandactivepdu { + let pdu = Pdu::from_stream(stream)?; + if pdu.pdu_type == PduType::Demandactivepdu { for capability_set in cast!(DataType::Trame, pdu.message["capabilitySets"])?.iter() { match Capability::from_capability_set(cast!(DataType::Component, capability_set)?) { Ok(capability) => self.server_capabilities.push(capability), @@ -648,7 +650,7 @@ impl Client { self.share_id = Some(cast!(DataType::U32, pdu.message["shareId"])?); return Ok(true); } - return Ok(false); + Ok(false) } /// Read server synchronize pdu @@ -656,8 +658,8 @@ impl Client { /// /// This function return true if it read the expected PDU fn read_synchronize_pdu(&mut self, stream: &mut dyn Read) -> RdpResult { - let pdu = PDU::from_stream(stream)?; - if pdu.pdu_type != PDUType::PdutypeDatapdu { + let pdu = Pdu::from_stream(stream)?; + if pdu.pdu_type != PduType::Datapdu { return Ok(false); } if DataPDU::from_pdu(&pdu)?.pdu_type != PDUType2::Pdutype2Synchronize { @@ -669,9 +671,9 @@ impl Client { /// Read the server control PDU with the expected action /// /// This function return true if it read the expected PDU with the expected action - fn read_control_pdu(&mut self, stream: &mut dyn Read, action: Action) -> RdpResult { - let pdu = PDU::from_stream(stream)?; - if pdu.pdu_type != PDUType::PdutypeDatapdu { + fn read_control_pdu(&mut self, stream: &mut dyn Read, action: CtrlAction) -> RdpResult { + let pdu = Pdu::from_stream(stream)?; + if pdu.pdu_type != PduType::Datapdu { return Ok(false); } @@ -694,8 +696,8 @@ impl Client { /// /// This function return true if it read the expected PDU fn read_font_map_pdu(&mut self, stream: &mut dyn Read) -> RdpResult { - let pdu = PDU::from_stream(stream)?; - if pdu.pdu_type != PDUType::PdutypeDatapdu { + let pdu = Pdu::from_stream(stream)?; + if pdu.pdu_type != PduType::Datapdu { return Ok(false); } if DataPDU::from_pdu(&pdu)?.pdu_type != PDUType2::Pdutype2Fontmap { @@ -713,15 +715,15 @@ impl Client { message.read(stream)?; for pdu in message.inner() { - let pdu = PDU::from_control(cast!(DataType::Component, pdu)?)?; + let pdu = Pdu::from_control(cast!(DataType::Component, pdu)?)?; // Ask for a new handshake - if pdu.pdu_type == PDUType::PdutypeDeactivateallpdu { + if pdu.pdu_type == PduType::Deactivateallpdu { println!("GLOBAL: deactive/reactive sequence initiated"); self.state = ClientState::DemandActivePDU; continue; } - if pdu.pdu_type != PDUType::PdutypeDatapdu { + if pdu.pdu_type != PduType::Datapdu { println!("GLOBAL: Ignore PDU {:?}", pdu.pdu_type); continue; } @@ -748,7 +750,7 @@ impl Client { T: FnMut(RdpEvent), { // it could be have one or more fast path payload - let mut fp_messages = Array::new(|| ts_fp_update()); + let mut fp_messages = Array::new(ts_fp_update); fp_messages.read(stream)?; for fp_message in fp_messages.inner().iter() { @@ -790,7 +792,7 @@ impl Client { /// Write confirm active pdu /// This PDU include all client capabilities - fn write_confirm_active_pdu( + async fn write_confirm_active_pdu( &mut self, mcs: &mut mcs::Client, ) -> RdpResult<()> { @@ -831,20 +833,30 @@ impl Client { capability_set(Some(capability::ts_multifragment_update_capability_ts())) ])), ); - self.write_pdu(pdu, mcs) + self.write_pdu(pdu, mcs).await } /// This is the finalize connection sequence /// sent from client to server - fn write_client_finalize(&self, mcs: &mut mcs::Client) -> RdpResult<()> { - self.write_data_pdu(ts_synchronize_pdu(Some(self.channel_id)), mcs)?; - self.write_data_pdu(ts_control_pdu(Some(Action::CtrlactionCooperate)), mcs)?; - self.write_data_pdu(ts_control_pdu(Some(Action::CtrlactionRequestControl)), mcs)?; - self.write_data_pdu(ts_font_list_pdu(), mcs) + async fn write_client_finalize( + &self, + mcs: &mut mcs::Client, + ) -> RdpResult<()> { + self.write_data_pdu(ts_synchronize_pdu(Some(self.channel_id)), mcs) + .await?; + self.write_data_pdu(ts_control_pdu(Some(CtrlAction::Cooperate)), mcs) + .await?; + self.write_data_pdu(ts_control_pdu(Some(CtrlAction::RequestControl)), mcs) + .await?; + self.write_data_pdu(ts_font_list_pdu(), mcs).await } /// Send a classic PDU to the global channel - fn write_pdu(&self, message: PDU, mcs: &mut mcs::Client) -> RdpResult<()> { + async fn write_pdu( + &self, + message: Pdu, + mcs: &mut mcs::Client, + ) -> RdpResult<()> { mcs.write( &"global".to_string(), share_control_header( @@ -853,10 +865,11 @@ impl Client { Some(to_vec(&message.message)), ), ) + .await } /// Send Data pdu - fn write_data_pdu( + async fn write_data_pdu( &self, message: DataPDU, mcs: &mut mcs::Client, @@ -869,6 +882,7 @@ impl Client { ), mcs, ) + .await } /// Public interface to sent input event @@ -894,19 +908,21 @@ impl Client { /// &mut self.mcs /// ) /// ``` - pub fn write_input_event( + pub async fn write_input_event( &self, event: TSInputEvent, mcs: &mut mcs::Client, ) -> RdpResult<()> { match self.state { - ClientState::Data => Ok(self.write_data_pdu( - ts_input_pdu_data(Some(Array::from_trame(trame![ts_input_event( - Some(event.event_type), - Some(to_vec(&event.message)) - )]))), - mcs, - )?), + ClientState::Data => Ok(self + .write_data_pdu( + ts_input_pdu_data(Some(Array::from_trame(trame![ts_input_event( + Some(event.event_type), + Some(to_vec(&event.message)) + )]))), + mcs, + ) + .await?), _ => Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidAutomata, "You cannot send data once it's not connected", @@ -929,7 +945,7 @@ impl Client { /// ... /// } /// ``` - pub fn read( + pub async fn read( &mut self, payload: tpkt::Payload, mcs: &mut mcs::Client, @@ -941,8 +957,8 @@ impl Client { match self.state { ClientState::DemandActivePDU => { if self.read_demand_active_pdu(&mut try_let!(tpkt::Payload::Raw, payload)?)? { - self.write_confirm_active_pdu(mcs)?; - self.write_client_finalize(mcs)?; + self.write_confirm_active_pdu(mcs).await?; + self.write_client_finalize(mcs).await?; // now wait for server synchronize self.state = ClientState::SynchronizePDU; } @@ -958,7 +974,7 @@ impl Client { ClientState::ControlCooperate => { if self.read_control_pdu( &mut try_let!(tpkt::Payload::Raw, payload)?, - Action::CtrlactionCooperate, + CtrlAction::Cooperate, )? { // next state is control granted self.state = ClientState::ControlGranted; @@ -968,7 +984,7 @@ impl Client { ClientState::ControlGranted => { if self.read_control_pdu( &mut try_let!(tpkt::Payload::Raw, payload)?, - Action::CtrlactionGrantedControl, + CtrlAction::GrantedControl, )? { // next state is font map pdu self.state = ClientState::FontMap; @@ -1057,7 +1073,7 @@ mod test { fn test_share_control_header() { let mut stream = Cursor::new(vec![]); share_control_header( - Some(PDUType::PdutypeConfirmactivepdu), + Some(PduType::Confirmactivepdu), Some(12), Some(to_vec( &ts_confirm_active_pdu( @@ -1098,7 +1114,7 @@ mod test { ]); let mut global = Client::new(0, 0, 800, 600, KeyboardLayout::US, "foo"); assert!(global - .read_control_pdu(&mut stream, Action::CtrlactionCooperate) + .read_control_pdu(&mut stream, CtrlAction::Cooperate) .unwrap()) } @@ -1110,7 +1126,7 @@ mod test { ]); let mut global = Client::new(0, 0, 800, 600, KeyboardLayout::US, "foo"); assert!(global - .read_control_pdu(&mut stream, Action::CtrlactionGrantedControl) + .read_control_pdu(&mut stream, CtrlAction::GrantedControl) .unwrap()) } diff --git a/src/core/license.rs b/src/core/license.rs index 8a57e45..c9a224f 100644 --- a/src/core/license.rs +++ b/src/core/license.rs @@ -68,7 +68,7 @@ pub enum StateTransition { /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/73170ca2-5f82-4a2d-9d1b-b439f3d8dadc fn preamble() -> Component { component![ - "bMsgtype" => 0 as u8, + "bMsgtype" => 0_u8, "flag" => Check::new(Preambule::PreambleVersion30 as u8), "wMsgSize" => DynOption::new(U16::LE(0), |size| MessageOption::Size("message".to_string(), size.inner() as usize - 4)), "message" => Vec::::new() diff --git a/src/core/mcs.rs b/src/core/mcs.rs index c2615e2..07969a9 100644 --- a/src/core/mcs.rs +++ b/src/core/mcs.rs @@ -12,7 +12,8 @@ use crate::nla::asn1::{ from_ber, to_der, ASN1Type, Enumerate, ImplicitTag, Integer, OctetString, Sequence, }; use std::collections::HashMap; -use std::io::{BufRead, Cursor, Read, Write}; +use std::io::{BufRead, Cursor, Read}; +use tokio::io::*; use yasna::Tag; #[allow(dead_code)] @@ -30,6 +31,7 @@ enum DomainMCSPDU { /// ASN1 structure use by mcs layer /// to inform on conference capability +#[allow(clippy::too_many_arguments)] fn domain_parameters( max_channel_ids: u32, maw_user_ids: u32, @@ -60,13 +62,13 @@ fn connect_initial(user_data: Option) -> ImplicitTag { ImplicitTag::new( Tag::application(101), sequence![ - "callingDomainSelector" => vec![1 as u8] as OctetString, - "calledDomainSelector" => vec![1 as u8] as OctetString, + "callingDomainSelector" => vec![1_u8] as OctetString, + "calledDomainSelector" => vec![1_u8] as OctetString, "upwardFlag" => true, "targetParameters" => domain_parameters(34, 2, 0, 1, 0, 1, 0xffff, 2), "minimumParameters" => domain_parameters(1, 1, 1, 1, 0, 1, 0x420, 2), "maximumParameters" => domain_parameters(0xffff, 0xfc17, 0xffff, 1, 0, 1, 0xffff, 2), - "userData" => user_data.unwrap_or(Vec::new()) + "userData" => user_data.unwrap_or_default() ], ) } @@ -79,7 +81,7 @@ fn connect_response(user_data: Option) -> ImplicitTag { "result" => 0 as Enumerate, "calledConnectId" => 0 as Integer, "domainParameters" => domain_parameters(22, 3, 0, 1, 0, 1,0xfff8, 2), - "userData" => user_data.unwrap_or(Vec::new()) + "userData" => user_data.unwrap_or_default() ], ) } @@ -93,7 +95,7 @@ fn mcs_pdu_header(pdu: Option, options: Option) -> u8 { /// Client -- attach_user_request -> Server /// Client <- attach_user_confirm -- Server fn read_attach_user_confirm(buffer: &mut dyn Read) -> RdpResult { - let mut confirm = trame![0 as u8, Vec::::new()]; + let mut confirm = trame![0_u8, Vec::::new()]; confirm.read(buffer)?; if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::AttachUserConfirm), None) >> 2 @@ -111,7 +113,7 @@ fn read_attach_user_confirm(buffer: &mut dyn Read) -> RdpResult { "MCS: recv_attach_user_confirm user rejected by server", ))); } - Ok(per::read_integer_16(1001, &mut request)?) + per::read_integer_16(1001, &mut request) } /// Create a session for the current user @@ -162,7 +164,7 @@ fn read_channel_join_confirm( channel_id: u16, buffer: &mut dyn Read, ) -> RdpResult { - let mut confirm = trame![0 as u8, Vec::::new()]; + let mut confirm = trame![0_u8, Vec::::new()]; confirm.read(buffer)?; if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinConfirm), None) >> 2 @@ -207,7 +209,7 @@ pub struct Client { channel_ids: HashMap, } -impl Client { +impl Client { pub fn new(x224: x224::Client) -> Self { Client { server_data: None, @@ -220,7 +222,7 @@ impl Client { /// Write connection initial payload /// This payload include a lot of /// client specific config parameters - fn write_connect_initial( + async fn write_connect_initial( &mut self, screen_width: u16, screen_height: u16, @@ -261,15 +263,17 @@ impl Client { ] ]); let conference = write_conference_create_request(&user_data)?; - self.x224.write(to_der(&connect_initial(Some(conference)))) + self.x224 + .write(to_der(&connect_initial(Some(conference)))) + .await } /// Read a connect response comming from server to client - fn read_connect_response(&mut self) -> RdpResult<()> { + async fn read_connect_response(&mut self) -> RdpResult<()> { // Now read response from the server let mut connect_response = connect_response(None); - let mut payload = try_let!(tpkt::Payload::Raw, self.x224.read()?)?; - from_ber(&mut connect_response, payload.fill_buf()?)?; + let mut payload = try_let!(tpkt::Payload::Raw, self.x224.read().await?)?; + from_ber(&mut connect_response, BufRead::fill_buf(&mut payload)?)?; // Get server data // Read conference create response @@ -289,21 +293,22 @@ impl Client { /// let mut mcs = mcs::Client(x224); /// mcs.connect(800, 600, KeyboardLayout::French).unwrap() /// ``` - pub fn connect( + pub async fn connect( &mut self, client_name: String, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout, ) -> RdpResult<()> { - self.write_connect_initial(screen_width, screen_height, keyboard_layout, client_name)?; - self.read_connect_response()?; - self.x224.write(erect_domain_request()?)?; - self.x224.write(attach_user_request())?; + self.write_connect_initial(screen_width, screen_height, keyboard_layout, client_name) + .await?; + self.read_connect_response().await?; + self.x224.write(erect_domain_request()?).await?; + self.x224.write(attach_user_request()).await?; self.user_id = Some(read_attach_user_confirm(&mut try_let!( tpkt::Payload::Raw, - self.x224.read()? + self.x224.read().await? )?)?); // Add static channel @@ -315,11 +320,12 @@ impl Client { // Actually only the two static main channel are requested for channel_id in self.channel_ids.values() { self.x224 - .write(channel_join_request(self.user_id, Some(*channel_id))?)?; + .write(channel_join_request(self.user_id, Some(*channel_id))?) + .await?; if !read_channel_join_confirm( self.user_id.unwrap(), *channel_id, - &mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?, + &mut try_let!(tpkt::Payload::Raw, self.x224.read().await?)?, )? { println!("Server reject channel id {:?}", channel_id); } @@ -338,18 +344,20 @@ impl Client { /// mcs.connect(800, 600, KeyboardLayout::French).unwrap(); /// mcs.write("global".to_string(), trame![U16::LE(0)]) /// ``` - pub fn write(&mut self, channel_name: &String, message: T) -> RdpResult<()> + pub async fn write(&mut self, channel_name: &String, message: T) -> RdpResult<()> where T: Message, { - self.x224.write(trame![ - mcs_pdu_header(Some(DomainMCSPDU::SendDataRequest), None), - U16::BE(self.user_id.unwrap() - 1001), - U16::BE(self.channel_ids[channel_name]), - 0x70 as u8, - per::write_length(message.length() as u16)?, - message - ]) + self.x224 + .write(trame![ + mcs_pdu_header(Some(DomainMCSPDU::SendDataRequest), None), + U16::BE(self.user_id.unwrap() - 1001), + U16::BE(self.channel_ids[channel_name]), + 0x70_u8, + per::write_length(message.length() as u16)?, + message + ]) + .await } /// Receive a message for a specific channel @@ -366,8 +374,8 @@ impl Client { /// ... /// } /// ``` - pub fn read(&mut self) -> RdpResult<(String, tpkt::Payload)> { - let message = self.x224.read()?; + pub async fn read(&mut self) -> RdpResult<(String, tpkt::Payload)> { + let message = self.x224.read().await?; match message { tpkt::Payload::Raw(mut payload) => { let mut header = mcs_pdu_header(None, None); @@ -390,14 +398,16 @@ impl Client { per::read_integer_16(1001, &mut payload)?; let channel_id = per::read_integer_16(0, &mut payload)?; - let channel = - self.channel_ids - .iter() - .find(|x| *x.1 == channel_id) - .ok_or(Error::RdpError(RdpError::new( + let channel = self + .channel_ids + .iter() + .find(|x| *x.1 == channel_id) + .ok_or_else(|| { + Error::RdpError(RdpError::new( RdpErrorKind::Unknown, "MCS: unknown channel", - )))?; + )) + })?; per::read_enumerates(&mut payload)?; per::read_length(&mut payload)?; @@ -415,13 +425,15 @@ impl Client { } /// Send a close event to server - pub fn shutdown(&mut self) -> RdpResult<()> { - self.x224.write(trame![ - mcs_pdu_header(Some(DomainMCSPDU::DisconnectProviderUltimatum), Some(1)), - per::write_enumerates(0x80)?, - b"\x00\x00\x00\x00\x00\x00".to_vec() - ])?; - self.x224.shutdown() + pub async fn shutdown(&mut self) -> RdpResult<()> { + self.x224 + .write(trame![ + mcs_pdu_header(Some(DomainMCSPDU::DisconnectProviderUltimatum), Some(1)), + per::write_enumerates(0x80)?, + b"\x00\x00\x00\x00\x00\x00".to_vec() + ]) + .await?; + self.x224.shutdown().await } /// This function check if the client diff --git a/src/core/per.rs b/src/core/per.rs index 6d39b42..8d44b95 100644 --- a/src/core/per.rs +++ b/src/core/per.rs @@ -17,7 +17,7 @@ pub fn read_length(s: &mut dyn Read) -> RdpResult { let mut byte: u8 = 0; byte.read(s)?; if byte & 0x80 != 0 { - byte = byte & !0x80; + byte &= !0x80; let mut size = (byte as u16) << 8; byte.read(s)?; size += byte as u16; @@ -336,7 +336,7 @@ pub fn write_object_identifier(oid: &[u8], s: &mut dyn Write) -> RdpResult<()> { } trame![ - 5 as u8, + 5_u8, oid[0] << 4 | oid[1] & 0xF, oid[2], oid[3], @@ -357,7 +357,7 @@ pub fn write_object_identifier(oid: &[u8], s: &mut dyn Write) -> RdpResult<()> { /// ``` pub fn read_numeric_string(minimum: usize, s: &mut dyn Read) -> RdpResult> { let length = read_length(s)?; - let mut result = vec![0 as u8; length as usize + minimum + 1]; + let mut result = vec![0_u8; length as usize + minimum + 1]; result.read(s)?; Ok(result) } @@ -388,13 +388,13 @@ pub fn write_numeric_string(string: &[u8], minimum: usize, s: &mut dyn Write) -> /// Read exactly a number of bytes pub fn read_padding(length: usize, s: &mut dyn Read) -> RdpResult<()> { let mut padding = vec![0; length]; - s.read(&mut padding)?; + s.read_exact(&mut padding)?; Ok(()) } /// Write length zero bytes pub fn write_padding(length: usize, s: &mut dyn Write) -> RdpResult<()> { - vec![0 as u8; length].write(s)?; + vec![0_u8; length].write(s)?; Ok(()) } @@ -417,10 +417,10 @@ pub fn read_octet_stream(octet_stream: &[u8], minimum: usize, s: &mut dyn Read) "PER: source octet string have an invalid size", ))); } - for i in 0..length { + for oc in octet_stream.iter().take(length) { let mut c: u8 = 0; c.read(s)?; - if c != octet_stream[i] { + if &c != oc { return Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidData, "PER: source octet string have an invalid char", diff --git a/src/core/sec.rs b/src/core/sec.rs index 8b0d0cb..54f79d9 100644 --- a/src/core/sec.rs +++ b/src/core/sec.rs @@ -4,7 +4,7 @@ use crate::core::tpkt; use crate::model::data::{Component, DataType, DynOption, Message, MessageOption, Trame, U16, U32}; use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use crate::model::unicode::Unicode; -use std::io::{Read, Write}; +use tokio::io::*; /// Security flag send as header flage in core ptotocol /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e13405c5-668b-4716-94b2-1c2654ca1ad4?redirectedfrom=MSDN @@ -32,25 +32,25 @@ enum SecurityFlag { /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/732394f5-e2b5-4ac5-8a0a-35345386b0d1?redirectedfrom=MSDN #[allow(dead_code)] enum InfoFlag { - InfoMouse = 0x00000001, - InfoDisablectrlaltdel = 0x00000002, - InfoAutologon = 0x00000008, - InfoUnicode = 0x00000010, - InfoMaximizeshell = 0x00000020, - InfoLogonnotify = 0x00000040, - InfoCompression = 0x00000080, - InfoEnablewindowskey = 0x00000100, - InfoRemoteconsoleaudio = 0x00002000, - InfoForceEncryptedCsPdu = 0x00004000, - InfoRail = 0x00008000, - InfoLogonerrors = 0x00010000, - InfoMouseHasWheel = 0x00020000, - InfoPasswordIsScPin = 0x00040000, - InfoNoaudioplayback = 0x00080000, - InfoUsingSavedCreds = 0x00100000, - InfoAudiocapture = 0x00200000, - InfoVideoDisable = 0x00400000, - InfoCompressionTypeMask = 0x00001E00, + Mouse = 0x00000001, + Disablectrlaltdel = 0x00000002, + Autologon = 0x00000008, + Unicode = 0x00000010, + Maximizeshell = 0x00000020, + Logonnotify = 0x00000040, + Compression = 0x00000080, + Enablewindowskey = 0x00000100, + Remoteconsoleaudio = 0x00002000, + ForceEncryptedCsPdu = 0x00004000, + Rail = 0x00008000, + Logonerrors = 0x00010000, + MouseHasWheel = 0x00020000, + PasswordIsScPin = 0x00040000, + Noaudioplayback = 0x00080000, + UsingSavedCreds = 0x00100000, + Audiocapture = 0x00200000, + VideoDisable = 0x00400000, + CompressionTypeMask = 0x00001E00, } #[allow(dead_code)] @@ -99,13 +99,13 @@ fn rdp_infos( component![ "codePage" => U32::LE(0), "flag" => U32::LE( - InfoFlag::InfoMouse as u32 | - InfoFlag::InfoUnicode as u32 | - InfoFlag::InfoLogonnotify as u32 | - InfoFlag::InfoLogonerrors as u32 | - InfoFlag::InfoDisablectrlaltdel as u32 | - InfoFlag::InfoEnablewindowskey as u32 | - if auto_logon { InfoFlag::InfoAutologon as u32 } else { 0 } + InfoFlag::Mouse as u32 | + InfoFlag::Unicode as u32 | + InfoFlag::Logonnotify as u32 | + InfoFlag::Logonerrors as u32 | + InfoFlag::Disablectrlaltdel as u32 | + InfoFlag::Enablewindowskey as u32 | + if auto_logon { InfoFlag::Autologon as u32 } else { 0 } ), "cbDomain" => U16::LE((domain_format.len() - 2) as u16), "cbUserName" => U16::LE((username_format.len() - 2) as u16), @@ -141,7 +141,7 @@ fn security_header() -> Component { /// let mut mcs = mcs::Client(...).unwrap(); /// sec::connect(&mut mcs).unwrap(); /// ``` -pub fn connect( +pub async fn connect( mcs: &mut mcs::Client, domain: &String, username: &String, @@ -161,9 +161,10 @@ pub fn connect( auto_logon ) ], - )?; + ) + .await?; - let (_channel_name, payload) = mcs.read()?; + let (_channel_name, payload) = mcs.read().await?; let mut stream = try_let!(tpkt::Payload::Raw, payload)?; let mut header = security_header(); header.read(&mut stream)?; diff --git a/src/core/tpkt.rs b/src/core/tpkt.rs index 49534d7..14fe33e 100644 --- a/src/core/tpkt.rs +++ b/src/core/tpkt.rs @@ -3,7 +3,8 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use crate::model::link::Link; use crate::nla::cssp::cssp_connect; use crate::nla::sspi::AuthenticationProtocol; -use std::io::{Cursor, Read, Write}; +use std::io::Cursor; +use tokio::io::*; /// TPKT must implement this two kind of payload pub enum Payload { @@ -26,7 +27,7 @@ pub enum Action { fn tpkt_header(size: u16) -> Component { component![ "action" => Action::FastPathActionX224 as u8, - "flag" => 0 as u8, + "flag" => 0_u8, "size" => U16::BE(size + 4) ] } @@ -45,7 +46,7 @@ pub struct Client { transport: Link, } -impl Client { +impl Client { /// Ctor of TPKT client layer pub fn new(transport: Link) -> Self { Client { transport } @@ -76,12 +77,13 @@ impl Client { /// } /// } /// ``` - pub fn write(&mut self, message: T) -> RdpResult<()> + pub async fn write(&mut self, message: T) -> RdpResult<()> where T: Message, { self.transport .write(&trame![tpkt_header(message.length() as u16), message]) + .await } /// Read a payload from the underlying layer @@ -117,8 +119,8 @@ impl Client { /// panic!("unexpected result") /// } /// ``` - pub fn read(&mut self) -> RdpResult { - let mut buffer = Cursor::new(self.transport.read(2)?); + pub async fn read(&mut self) -> RdpResult { + let mut buffer = Cursor::new(self.transport.read(2).await?); let mut action: u8 = 0; action.read(&mut buffer)?; if action == Action::FastPathActionX224 as u8 { @@ -126,7 +128,7 @@ impl Client { let mut padding: u8 = 0; padding.read(&mut buffer)?; // now wait extended header - buffer = Cursor::new(self.transport.read(2)?); + buffer = Cursor::new(self.transport.read(2).await?); let mut size = U16::BE(0); size.read(&mut buffer)?; @@ -141,7 +143,7 @@ impl Client { } else { // now wait for body Ok(Payload::Raw(Cursor::new( - self.transport.read(size.inner() as usize - 4)?, + self.transport.read(size.inner() as usize - 4).await?, ))) } } else { @@ -151,7 +153,7 @@ impl Client { short_length.read(&mut buffer)?; if short_length & 0x80 != 0 { let mut hi_length: u8 = 0; - hi_length.read(&mut Cursor::new(self.transport.read(1)?))?; + hi_length.read(&mut Cursor::new(self.transport.read(1).await?))?; let length: u16 = ((short_length & !0x80) as u16) << 8; let length = length | hi_length as u16; if length < 3 { @@ -162,21 +164,19 @@ impl Client { } else { Ok(Payload::FastPath( sec_flag, - Cursor::new(self.transport.read(length as usize - 3)?), + Cursor::new(self.transport.read(length as usize - 3).await?), )) } + } else if short_length < 2 { + Err(Error::RdpError(RdpError::new( + RdpErrorKind::InvalidSize, + "Invalid minimal size for TPKT", + ))) } else { - if short_length < 2 { - Err(Error::RdpError(RdpError::new( - RdpErrorKind::InvalidSize, - "Invalid minimal size for TPKT", - ))) - } else { - Ok(Payload::FastPath( - sec_flag, - Cursor::new(self.transport.read(short_length as usize - 2)?), - )) - } + Ok(Payload::FastPath( + sec_flag, + Cursor::new(self.transport.read(short_length as usize - 2).await?), + )) } } } @@ -194,8 +194,10 @@ impl Client { /// let mut tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(tcp))); /// let mut tpkt_ssl = tpkt.start_ssl(false).unwrap(); /// ``` - pub fn start_ssl(self, check_certificate: bool) -> RdpResult> { - Ok(Client::new(self.transport.start_ssl(check_certificate)?)) + pub async fn start_ssl(self, check_certificate: bool) -> RdpResult> { + Ok(Client::new( + self.transport.start_ssl(check_certificate).await?, + )) } /// This function is used when NLA (Network Level Authentication) @@ -212,20 +214,20 @@ impl Client { /// let mut tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(tcp))); /// let mut tpkt_nla = tpkt.start_nla(false, &mut Ntlm::new("domain".to_string(), "username".to_string(), "password".to_string()), false); /// ``` - pub fn start_nla( + pub async fn start_nla( self, check_certificate: bool, authentication_protocol: &mut dyn AuthenticationProtocol, restricted_admin_mode: bool, ) -> RdpResult> { - let mut link = self.transport.start_ssl(check_certificate)?; - cssp_connect(&mut link, authentication_protocol, restricted_admin_mode)?; + let mut link = self.transport.start_ssl(check_certificate).await?; + cssp_connect(&mut link, authentication_protocol, restricted_admin_mode).await?; Ok(Client::new(link)) } /// Shutdown current connection - pub fn shutdown(&mut self) -> RdpResult<()> { - self.transport.shutdown() + pub async fn shutdown(&mut self) -> RdpResult<()> { + self.transport.shutdown().await } #[cfg(feature = "integration")] diff --git a/src/core/x224.rs b/src/core/x224.rs index f4db4a8..f800d50 100644 --- a/src/core/x224.rs +++ b/src/core/x224.rs @@ -4,8 +4,8 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use crate::nla::sspi::AuthenticationProtocol; use num_enum::TryFromPrimitive; use std::convert::TryFrom; -use std::io::{Read, Write}; use std::option::Option; +use tokio::io::*; #[repr(u8)] #[derive(Copy, Clone, TryFromPrimitive)] @@ -81,7 +81,7 @@ fn x224_crq(len: u8, code: MessageType) -> Component { component! [ "len" => (len + 6) as u8, "code" => code as u8, - "padding" => trame! [U16::LE(0), U16::LE(0), 0 as u8] + "padding" => trame! [U16::LE(0), U16::LE(0), 0_u8] ] } @@ -104,9 +104,9 @@ fn x224_connection_pdu( /// X224 header fn x224_header() -> Component { component![ - "header" => 2 as u8, + "header" => 2_u8, "messageType" => MessageType::X224TPDUData as u8, - "separator" => Check::new(0x80 as u8) + "separator" => Check::new(0x80_u8) ] } @@ -118,7 +118,7 @@ pub struct Client { selected_protocol: Protocols, } -impl Client { +impl Client { /// Constructor use by the connector fn new(transport: tpkt::Client, selected_protocol: Protocols) -> Self { Client { @@ -141,11 +141,11 @@ impl Client { /// ).unwrap(); /// x224.write(trame![U16::LE(0)]).unwrap() /// ``` - pub fn write(&mut self, message: T) -> RdpResult<()> + pub async fn write(&mut self, message: T) -> RdpResult<()> where T: Message, { - self.transport.write(trame![x224_header(), message]) + self.transport.write(trame![x224_header(), message]).await } /// Start reading an entire X224 paylaod @@ -163,8 +163,8 @@ impl Client { /// ).unwrap(); /// let payload = x224.read().unwrap(); // you have to check the type /// ``` - pub fn read(&mut self) -> RdpResult { - let s = self.transport.read()?; + pub async fn read(&mut self) -> RdpResult { + let s = self.transport.read().await?; match s { tpkt::Payload::Raw(mut payload) => { let mut x224_header = x224_header(); @@ -207,7 +207,7 @@ impl Client { /// false /// ).unwrap() /// ``` - pub fn connect( + pub async fn connect( mut tpkt: tpkt::Client, security_protocols: u32, check_certificate: bool, @@ -223,18 +223,20 @@ impl Client { } else { 0 }), - )?; - match Self::read_connection_confirm(&mut tpkt)? { + ) + .await?; + match Self::read_connection_confirm(&mut tpkt).await? { Protocols::ProtocolHybrid => Ok(Client::new( tpkt.start_nla( check_certificate, authentication_protocol.unwrap(), restricted_admin_mode || blank_creds, - )?, + ) + .await?, Protocols::ProtocolHybrid, )), Protocols::ProtocolSSL => Ok(Client::new( - tpkt.start_ssl(check_certificate)?, + tpkt.start_ssl(check_certificate).await?, Protocols::ProtocolSSL, )), Protocols::ProtocolRDP => Ok(Client::new(tpkt, Protocols::ProtocolRDP)), @@ -246,7 +248,7 @@ impl Client { } /// Send connection request - fn write_connection_request( + async fn write_connection_request( tpkt: &mut tpkt::Client, security_protocols: u32, mode: Option, @@ -256,11 +258,12 @@ impl Client { mode, Some(security_protocols), )) + .await } /// Expect a connection confirm payload - fn read_connection_confirm(tpkt: &mut tpkt::Client) -> RdpResult { - let mut buffer = try_let!(tpkt::Payload::Raw, tpkt.read()?)?; + async fn read_connection_confirm(tpkt: &mut tpkt::Client) -> RdpResult { + let mut buffer = try_let!(tpkt::Payload::Raw, tpkt.read().await?)?; let mut confirm = x224_connection_pdu(None, None, None); confirm.read(&mut buffer)?; @@ -286,8 +289,8 @@ impl Client { self.selected_protocol } - pub fn shutdown(&mut self) -> RdpResult<()> { - self.transport.shutdown() + pub async fn shutdown(&mut self) -> RdpResult<()> { + self.transport.shutdown().await } } diff --git a/src/model/data.rs b/src/model/data.rs index 472916e..beba00d 100644 --- a/src/model/data.rs +++ b/src/model/data.rs @@ -243,7 +243,7 @@ pub type Trame = Vec>; macro_rules! trame { () => { Trame::new() }; ($( $val: expr ),*) => {{ - let mut vec = Trame::new(); + let mut vec = vec![] as Trame; $( vec.push(Box::new($val)); )* vec }} @@ -547,7 +547,7 @@ impl Value { impl PartialEq for Value { /// Equality between all type fn eq(&self, other: &Self) -> bool { - return self.inner() == other.inner(); + self.inner() == other.inner() } } @@ -807,12 +807,12 @@ impl Message for Check { impl Message for Vec { fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { - writer.write(self)?; + writer.write_all(self)?; Ok(()) } fn read(&mut self, reader: &mut dyn Read) -> RdpResult<()> { - if self.len() == 0 { + if self.is_empty() { reader.read_to_end(self)?; } else { reader.read_exact(self)?; @@ -997,9 +997,10 @@ impl Message for Option { /// assert_eq!(s2.into_inner(), []) /// ``` fn write(&self, writer: &mut dyn Write) -> RdpResult<()> { - Ok(if let Some(value) = self { + if let Some(value) = self { value.write(writer)? - }) + }; + Ok(()) } /// Read an optional field @@ -1215,7 +1216,7 @@ mod test { #[test] fn test_data_u8_write() { let mut stream = Cursor::new(Vec::::new()); - let x = 1 as u8; + let x = 1_u8; x.write(&mut stream).unwrap(); assert_eq!(stream.get_ref().as_slice(), [1]) } diff --git a/src/model/error.rs b/src/model/error.rs index 38adc4a..37c0bb6 100644 --- a/src/model/error.rs +++ b/src/model/error.rs @@ -1,11 +1,8 @@ #[cfg(feature = "openssl")] -use native_tls::Error as SslError; -#[cfg(feature = "openssl")] -use native_tls::HandshakeError; +use async_native_tls::Error as SslError; use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; use std::io::Error as IoError; #[cfg(feature = "openssl")] -use std::io::{Read, Write}; use std::string::String; use yasna::ASN1Error; @@ -94,8 +91,6 @@ pub enum Error { RdpError(RdpError), /// All kind of IO error Io(IoError), - /// SSL handshake error - SslHandshakeError, /// SSL error #[cfg(feature = "openssl")] SslError(SslError), @@ -112,13 +107,6 @@ impl From for Error { } } -#[cfg(feature = "openssl")] -impl From> for Error { - fn from(_: HandshakeError) -> Error { - Error::SslHandshakeError - } -} - #[cfg(feature = "openssl")] impl From for Error { fn from(e: SslError) -> Error { diff --git a/src/model/link.rs b/src/model/link.rs index ef53930..1c40a6b 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -1,13 +1,14 @@ use crate::model::data::Message; use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; #[cfg(feature = "openssl")] -use native_tls::{TlsConnector, TlsStream}; -use std::io::{Cursor, Read, Write}; +use async_native_tls::{TlsConnector, TlsStream}; +use std::io::Cursor; +use tokio::io::*; #[cfg(not(feature = "openssl"))] pub trait SecureBio where - S: Read + Write, + S: AsyncRead + AsyncWrite, { fn start_ssl(&mut self, check_certificate: bool) -> RdpResult<()>; fn get_peer_certificate_der(&self) -> RdpResult>>; @@ -18,7 +19,7 @@ where /// This a wrapper to work equals /// for a stream and a TLS stream pub enum Stream { - /// Raw stream that implement Read + Write + /// Raw stream that implement AsyncRead + AsyncWrite Raw(S), /// TLS Stream #[cfg(feature = "openssl")] @@ -27,7 +28,7 @@ pub enum Stream { Bio(Box>), } -impl Stream { +impl Stream { /// Read exactly the number of bytes present in buffer /// /// # Example @@ -39,11 +40,11 @@ impl Stream { /// s.read_exact(&mut result).unwrap(); /// assert_eq!(result, [1, 2]) /// ``` - pub fn read_exact(&mut self, buf: &mut [u8]) -> RdpResult<()> { + pub async fn read_exact(&mut self, buf: &mut [u8]) -> RdpResult<()> { match self { - Stream::Raw(e) => e.read_exact(buf)?, + Stream::Raw(e) => e.read_exact(buf).await?, #[cfg(feature = "openssl")] - Stream::Ssl(e) => e.read_exact(buf)?, + Stream::Ssl(e) => e.read_exact(buf).await?, #[cfg(not(feature = "openssl"))] Stream::Bio(bio) => bio.get_io().read_exact(buf)?, }; @@ -61,11 +62,11 @@ impl Stream { /// s.read(&mut result).unwrap(); /// assert_eq!(result, [1, 2, 3, 0]) /// ``` - pub fn read(&mut self, buf: &mut [u8]) -> RdpResult { + pub async fn read(&mut self, buf: &mut [u8]) -> RdpResult { match self { - Stream::Raw(e) => Ok(e.read(buf)?), + Stream::Raw(e) => Ok(e.read(buf).await?), #[cfg(feature = "openssl")] - Stream::Ssl(e) => Ok(e.read(buf)?), + Stream::Ssl(e) => Ok(e.read(buf).await?), #[cfg(not(feature = "openssl"))] Stream::Bio(e) => Ok(e.get_io().read(buf)?), } @@ -87,11 +88,11 @@ impl Stream { /// panic!("invalid") /// } /// ``` - pub fn write(&mut self, buffer: &[u8]) -> RdpResult { + pub async fn write(&mut self, buffer: &[u8]) -> RdpResult { Ok(match self { - Stream::Raw(e) => e.write(buffer)?, + Stream::Raw(e) => e.write(buffer).await?, #[cfg(feature = "openssl")] - Stream::Ssl(e) => e.write(buffer)?, + Stream::Ssl(e) => e.write(buffer).await?, #[cfg(not(feature = "openssl"))] Stream::Bio(e) => e.get_io().write(buffer)?, }) @@ -99,14 +100,15 @@ impl Stream { /// Shutdown the stream /// Only works when stream is a SSL stream - pub fn shutdown(&mut self) -> RdpResult<()> { - Ok(match self { + pub async fn shutdown(&mut self) -> RdpResult<()> { + match self { #[cfg(feature = "openssl")] - Stream::Ssl(e) => e.shutdown()?, + Stream::Ssl(e) => e.shutdown().await?, #[cfg(not(feature = "openssl"))] Stream::Bio(e) => e.shutdown()?, _ => (), - }) + }; + Ok(()) } } @@ -116,7 +118,7 @@ pub struct Link { stream: Stream, } -impl Link { +impl Link { /// Create a new link layer from a Stream /// /// # Example @@ -156,10 +158,10 @@ impl Link { /// } /// # } /// ``` - pub fn write(&mut self, message: &dyn Message) -> RdpResult<()> { + pub async fn write(&mut self, message: &dyn Message) -> RdpResult<()> { let mut buffer = Cursor::new(Vec::new()); message.write(&mut buffer)?; - self.stream.write(buffer.into_inner().as_slice())?; + self.stream.write(buffer.into_inner().as_slice()).await?; Ok(()) } @@ -172,15 +174,15 @@ impl Link { /// let mut link = Link::new(Stream::Raw(Cursor::new(vec![0, 1, 2]))); /// assert_eq!(link.read(2).unwrap(), [0, 1]) /// ``` - pub fn read(&mut self, expected_size: usize) -> RdpResult> { + pub async fn read(&mut self, expected_size: usize) -> RdpResult> { if expected_size == 0 { let mut buffer = vec![0; 1500]; - let size = self.stream.read(&mut buffer)?; + let size = self.stream.read(&mut buffer).await?; buffer.resize(size, 0); Ok(buffer) } else { let mut buffer = vec![0; expected_size]; - self.stream.read_exact(&mut buffer)?; + self.stream.read_exact(&mut buffer).await?; Ok(buffer) } } @@ -196,15 +198,15 @@ impl Link { /// let link_ssl = link_tcp.start_ssl(false).unwrap(); /// ``` #[cfg(feature = "openssl")] - pub fn start_ssl(self, check_certificate: bool) -> RdpResult> { - let mut builder = TlsConnector::builder(); - builder.danger_accept_invalid_certs(!check_certificate); - builder.use_sni(false); - - let connector = builder.build()?; - + pub async fn start_ssl(self, check_certificate: bool) -> RdpResult> { if let Stream::Raw(stream) = self.stream { - return Ok(Link::new(Stream::Ssl(connector.connect("", stream)?))); + return Ok(Link::new(Stream::Ssl( + TlsConnector::new() + .danger_accept_invalid_certs(!check_certificate) + .use_sni(false) + .connect("", stream) + .await?, + ))); } Err(Error::RdpError(RdpError::new( RdpErrorKind::NotImplemented, @@ -251,8 +253,8 @@ impl Link { /// Close the stream /// Only works on SSL Stream - pub fn shutdown(&mut self) -> RdpResult<()> { - self.stream.shutdown() + pub async fn shutdown(&mut self) -> RdpResult<()> { + self.stream.shutdown().await } #[cfg(feature = "integration")] diff --git a/src/model/unicode.rs b/src/model/unicode.rs index 45f959e..d734a9e 100644 --- a/src/model/unicode.rs +++ b/src/model/unicode.rs @@ -21,6 +21,6 @@ impl Unicode for String { let encode_char = U16::LE(c); encode_char.write(&mut result).unwrap(); } - return result.into_inner(); + result.into_inner() } } diff --git a/src/nla/asn1.rs b/src/nla/asn1.rs index c0d2108..97de25b 100644 --- a/src/nla/asn1.rs +++ b/src/nla/asn1.rs @@ -34,6 +34,7 @@ pub trait ASN1 { /// A sequence of is dynamically build /// using a callback factory +#[derive(Default)] pub struct SequenceOf { /// The inner vector of ASN1 node pub inner: Vec>, @@ -50,9 +51,8 @@ impl SequenceOf { /// let so = SequenceOf::new(); /// ``` pub fn new() -> Self { - SequenceOf { - inner: Vec::new(), - factory: None, + Self { + ..Default::default() } } diff --git a/src/nla/cssp.rs b/src/nla/cssp.rs index 7b5c44d..082bba2 100644 --- a/src/nla/cssp.rs +++ b/src/nla/cssp.rs @@ -5,7 +5,7 @@ use crate::nla::asn1::{ }; use crate::nla::sspi::AuthenticationProtocol; use num_bigint::BigUint; -use std::io::{Read, Write}; +use tokio::io::*; use x509_parser::{parse_x509_der, X509Certificate}; use yasna::Tag; @@ -165,17 +165,17 @@ fn create_ts_authinfo(auth_info: Vec) -> Vec { /// This the main function for CSSP protocol /// It will use the raw link layer and the selected authenticate protocol /// to perform the NLA authenticate -pub fn cssp_connect( +pub async fn cssp_connect( link: &mut Link, authentication_protocol: &mut dyn AuthenticationProtocol, restricted_admin_mode: bool, ) -> RdpResult<()> { // first step is to send the negotiate message from authentication protocol let negotiate_message = create_ts_request(authentication_protocol.create_negotiate_message()?); - link.write(&negotiate_message)?; + link.write(&negotiate_message).await?; // now receive server challenge - let server_challenge = read_ts_server_challenge(&(link.read(0)?))?; + let server_challenge = read_ts_server_challenge(&(link.read(0).await?))?; // now ask for to authenticate protocol let client_challenge = authentication_protocol.read_challenge_message(&server_challenge)?; @@ -201,10 +201,11 @@ pub fn cssp_connect( .data, )?, ); - link.write(&challenge)?; + link.write(&challenge).await?; // now server respond normally with the original public key incremented by one - let inc_pub_key = security_interface.gss_unwrapex(&(read_ts_validate(&(link.read(0)?))?))?; + let inc_pub_key = + security_interface.gss_unwrapex(&(read_ts_validate(&(link.read(0).await?))?))?; // Check possible man in the middle using cssp if BigUint::from_bytes_le(&inc_pub_key) @@ -243,7 +244,7 @@ pub fn cssp_connect( let credentials = create_ts_authinfo( security_interface.gss_wrapex(&create_ts_credentials(domain, user, password))?, ); - link.write(&credentials)?; + link.write(&credentials).await?; Ok(()) } diff --git a/src/nla/ntlm.rs b/src/nla/ntlm.rs index 8d2bd06..3e0a96a 100644 --- a/src/nla/ntlm.rs +++ b/src/nla/ntlm.rs @@ -65,7 +65,7 @@ fn version() -> Component { "ProductMajorVersion" => MajorVersion::WindowsMajorVersion6 as u8, "ProductMinorVersion" => MinorVersion::WindowsMinorVersion0 as u8, "ProductBuild" => U16::LE(6002), - "Reserved" => trame![U16::LE(0), 0 as u8], + "Reserved" => trame![U16::LE(0), 0_u8], "NTLMRevisionCurrent" => NTLMRevision::NtlmSspRevisionW2K3 as u8 ) } @@ -80,7 +80,7 @@ fn negotiate_message(flags: u32) -> Component { if node.inner() & (Negotiate::NtlmsspNegociateVersion as u32) == 0 { return MessageOption::SkipField("Version".to_string()) } - return MessageOption::None + MessageOption::None }), "DomainNameLen" => U16::LE(0), "DomainNameMaxLen" => U16::LE(0), @@ -106,7 +106,7 @@ fn challenge_message() -> Component { if node.inner() & (Negotiate::NtlmsspNegociateVersion as u32) == 0 { return MessageOption::SkipField("Version".to_string()) } - return MessageOption::None + MessageOption::None }), "ServerChallenge" => vec![0; 8], "Reserved" => vec![0; 8], @@ -172,7 +172,7 @@ fn authenticate_message( if node.inner() & (Negotiate::NtlmsspNegociateVersion as u32) == 0 { return MessageOption::SkipField("Version".to_string()) } - return MessageOption::None + MessageOption::None }), "Version" => version() ], @@ -257,7 +257,7 @@ fn read_target_info(data: &[u8]) -> RdpResult>> { result.insert(av_id, cast!(DataType::Slice, element["Value"])?.to_vec()); } - return Ok(result); + Ok(result) } /// Zero filled array @@ -312,13 +312,13 @@ fn md5(data: &[u8]) -> Vec { /// ```rust, ignore /// let encoded_string = unicode("foo".to_string()); /// ``` -fn unicode(data: &String) -> Vec { +fn unicode(data: &str) -> Vec { let mut result = Cursor::new(Vec::new()); for c in data.encode_utf16() { let encode_char = U16::LE(c); encode_char.write(&mut result).unwrap(); } - return result.into_inner(); + result.into_inner() } /// Compute HMAC with MD5 hash algorithm @@ -344,10 +344,10 @@ fn hmac_md5(key: &[u8], data: &[u8]) -> Vec { /// ```rust, ignore /// let key = ntowfv2("hello123".to_string(), "user".to_string(), "domain".to_string()) /// ``` -fn ntowfv2(password: &String, user: &String, domain: &String) -> Vec { +fn ntowfv2(password: &str, user: &str, domain: &str) -> Vec { hmac_md5( &md4(&unicode(password)), - &unicode(&(user.to_uppercase() + &domain)), + &unicode(&(user.to_uppercase() + domain)), ) } @@ -361,8 +361,8 @@ fn ntowfv2(password: &String, user: &String, domain: &String) -> Vec { /// ```rust, ignore /// let key = ntowfv2("hello123".to_string(), "user".to_string(), "domain".to_string()) /// ``` -fn ntowfv2_hash(hash: &[u8], user: &String, domain: &String) -> Vec { - hmac_md5(hash, &unicode(&(user.to_uppercase() + &domain))) +fn ntowfv2_hash(hash: &[u8], user: &str, domain: &str) -> Vec { + hmac_md5(hash, &unicode(&(user.to_uppercase() + domain))) } /// This function is used to compute init key of another hmac_md5 @@ -372,7 +372,7 @@ fn ntowfv2_hash(hash: &[u8], user: &String, domain: &String) -> Vec { /// ```rust, ignore /// let key = lmowfv2("hello123".to_string(), "user".to_string(), "domain".to_string()) /// ``` -fn lmowfv2(password: &String, user: &String, domain: &String) -> Vec { +fn lmowfv2(password: &str, user: &str, domain: &str) -> Vec { ntowfv2(password, user, domain) } @@ -412,7 +412,7 @@ fn compute_response_v2( response_key_nt, &[server_challenge.to_vec(), temp.clone()].concat(), ); - let nt_challenge_response = [nt_proof_str.clone(), temp.clone()].concat(); + let nt_challenge_response = [nt_proof_str.clone(), temp].concat(); let lm_challenge_response = [ hmac_md5( response_key_lm, @@ -448,7 +448,7 @@ fn kx_key_v2( fn rc4k(key: &[u8], plaintext: &[u8]) -> Vec { let mut result = vec![0; plaintext.len()]; let mut rc4_handle = Rc4::new(key); - rc4_handle.process(&plaintext, &mut result); + rc4_handle.process(plaintext, &mut result); result } @@ -609,7 +609,7 @@ impl AuthenticationProtocol for Ntlm { | Negotiate::NtlmsspNegociateUnicode as u32, )); self.negotiate_message = Some(buffer.clone()); - return Ok(buffer); + Ok(buffer) } /// Read the server challenge @@ -645,16 +645,16 @@ impl AuthenticationProtocol for Ntlm { let response = compute_response_v2( &self.response_key_nt, &self.response_key_lm, - &server_challenge, + server_challenge, &client_challenge, ×tamp, - &target_name, + target_name, ); let nt_challenge_response = response.0; let lm_challenge_response = response.1; let session_base_key = response.2; let key_exchange_key = - kx_key_v2(&session_base_key, &lm_challenge_response, &server_challenge); + kx_key_v2(&session_base_key, &lm_challenge_response, server_challenge); self.exported_session_key = Some(random(16)); let encrypted_random_session_key = rc4k( @@ -797,7 +797,7 @@ impl GenericSecurityService for NTLMv2SecurityInterface { let mut encrypted_data = vec![0; data.len()]; self.encrypt.process(data, &mut encrypted_data); let signature = mac(&mut self.encrypt, &self.signing_key, self.seq_num, data); - self.seq_num = self.seq_num + 1; + self.seq_num += 1; Ok(to_vec(&trame![signature, encrypted_data])) } @@ -884,10 +884,7 @@ mod test { /// Test of the unicode function #[test] fn test_unicode() { - assert_eq!( - unicode(&"foo".to_string()), - [0x66, 0x00, 0x6f, 0x00, 0x6f, 0x00] - ) + assert_eq!(unicode("foo"), [0x66, 0x00, 0x6f, 0x00, 0x6f, 0x00]) } /// Test HMAC_MD5 function @@ -906,11 +903,7 @@ mod test { #[test] fn test_ntowfv2() { assert_eq!( - ntowfv2( - &"foo".to_string(), - &"user".to_string(), - &"domain".to_string() - ), + ntowfv2("foo", "user", "domain"), [ 0x6e, 0x53, 0xb9, 0x0, 0x97, 0x8c, 0x87, 0x1f, 0x91, 0xde, 0x6, 0x44, 0x9d, 0x8b, 0x8b, 0x81 @@ -922,16 +915,8 @@ mod test { #[test] fn test_lmowfv2() { assert_eq!( - lmowfv2( - &"foo".to_string(), - &"user".to_string(), - &"domain".to_string() - ), - ntowfv2( - &"foo".to_string(), - &"user".to_string(), - &"domain".to_string() - ) + lmowfv2("foo", "user", "domain"), + ntowfv2("foo", "user", "domain") ) } diff --git a/src/nla/rc4.rs b/src/nla/rc4.rs index 97e0310..12a37c3 100644 --- a/src/nla/rc4.rs +++ b/src/nla/rc4.rs @@ -6,7 +6,7 @@ pub struct Rc4 { impl Rc4 { pub fn new(key: &[u8]) -> Rc4 { - assert!(key.len() >= 1 && key.len() <= 256); + assert!(!key.is_empty() && key.len() <= 256); let mut rc4 = Rc4 { i: 0, j: 0, @@ -28,9 +28,8 @@ impl Rc4 { self.i = self.i.wrapping_add(1); self.j = self.j.wrapping_add(self.state[self.i as usize]); self.state.swap(self.i as usize, self.j as usize); - let k = self.state - [(self.state[self.i as usize].wrapping_add(self.state[self.j as usize])) as usize]; - k + + self.state[(self.state[self.i as usize].wrapping_add(self.state[self.j as usize])) as usize] } pub fn process(&mut self, input: &[u8], output: &mut [u8]) { From d5f00c6069721b41150e51c372cdc9e32ca2798f Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Sat, 22 Oct 2022 00:25:02 +0800 Subject: [PATCH 06/12] mstsc-rs use async-io too --- Cargo.toml | 5 +- src/bin/mstsc-rs.rs | 132 ++++++++++++++++++++++++-------------------- 2 files changed, 76 insertions(+), 61 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3e24dd0..f0c9b67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ default = ["openssl"] # The reason we do this is because doctests don't get cfg(test) # See: https://github.com/rust-lang/cargo/issues/4669 integration = [] -mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc"] +mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc", "openssl", "futures"] openssl = ["async-native-tls"] [dependencies] @@ -48,4 +48,5 @@ hex = { version = "^0.4", optional = true } winapi = { version = "^0.3", features = ["winsock2"], optional = true } minifb = { version = "^0.15", optional = true } clap = { version = "^2.33", optional = true} -libc = { version = "^0.2", optional = true} \ No newline at end of file +libc = { version = "^0.2", optional = true} +futures = { version = "0.3", optional = true } \ No newline at end of file diff --git a/src/bin/mstsc-rs.rs b/src/bin/mstsc-rs.rs index a2d4488..ccc3283 100644 --- a/src/bin/mstsc-rs.rs +++ b/src/bin/mstsc-rs.rs @@ -7,10 +7,9 @@ use rdp::core::event::{BitmapEvent, KeyboardEvent, PointerButton, PointerEvent, use rdp::core::gcc::KeyboardLayout; use rdp::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use std::convert::TryFrom; -use std::io::{Read, Write}; use std::mem; use std::mem::{forget, size_of}; -use std::net::{SocketAddr, TcpStream}; +use std::net::SocketAddr; #[cfg(any(target_os = "linux", target_os = "macos"))] use std::os::unix::io::AsRawFd; #[cfg(target_os = "windows")] @@ -19,10 +18,13 @@ use std::ptr; use std::ptr::copy_nonoverlapping; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{Receiver, Sender}; -use std::sync::{mpsc, Arc, Mutex}; -use std::thread; -use std::thread::JoinHandle; +use std::sync::{mpsc, Arc}; use std::time::Instant; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + sync::Mutex, +}; #[cfg(target_os = "windows")] use winapi::um::winsock2::{fd_set, select}; @@ -106,8 +108,8 @@ fn fast_bitmap_transfer(buffer: &mut Vec, width: usize, bitmap: BitmapEvent ))); } copy_nonoverlapping( - data_aligned.as_ptr().offset((src_i) as isize), - buffer.as_mut_ptr().offset(dest_i as isize), + data_aligned.as_ptr().add(src_i), + buffer.as_mut_ptr().add(dest_i), count, ) } @@ -243,7 +245,7 @@ fn to_scancode(key: Key) -> u16 { } /// Create a tcp stream from main args -fn tcp_from_args(args: &ArgMatches) -> RdpResult { +async fn tcp_from_args(args: &ArgMatches<'_>) -> RdpResult { let ip = args .value_of("host") .expect("You need to provide a target argument"); @@ -258,7 +260,7 @@ fn tcp_from_args(args: &ArgMatches) -> RdpResult { &format!("Cannot parse the IP PORT input [{}]", e), )) })?; - let tcp = TcpStream::connect(&addr).unwrap(); + let tcp = TcpStream::connect(&addr).await.unwrap(); tcp.set_nodelay(true).map_err(|e| { Error::RdpError(RdpError::new( RdpErrorKind::InvalidData, @@ -270,7 +272,10 @@ fn tcp_from_args(args: &ArgMatches) -> RdpResult { } /// Create rdp client from args -fn rdp_from_args(args: &ArgMatches, stream: S) -> RdpResult> { +async fn rdp_from_args( + args: &ArgMatches<'_>, + stream: S, +) -> RdpResult> { let width = args .value_of("width") .unwrap_or_default() @@ -327,7 +332,7 @@ fn rdp_from_args(args: &ArgMatches, stream: S) -> RdpResult RdpResult { /// This will launch the thread in charge /// of receiving event (mostly bitmap event) /// And send back to the gui thread -fn launch_rdp_thread( +fn launch_rdp_thread( handle: usize, rdp_client: Arc>>, sync: Arc, bitmap_channel: Sender, -) -> RdpResult> { +) -> RdpResult> { // Create the rdp thread - Ok(thread::spawn(move || { - while wait_for_fd(handle as usize) && sync.load(Ordering::Relaxed) { - let mut guard = rdp_client.lock().unwrap(); - if let Err(Error::RdpError(e)) = guard.read(|event| match event { - RdpEvent::Bitmap(bitmap) => { - bitmap_channel.send(bitmap).unwrap(); - } - _ => println!("{}: ignore event", APPLICATION_NAME), - }) { - match e.kind() { - RdpErrorKind::Disconnect => { - println!("{}: Server ask for disconnect", APPLICATION_NAME); + Ok(std::thread::spawn(move || { + futures::executor::block_on(async { + while wait_for_fd(handle as usize) && sync.load(Ordering::Relaxed) { + let mut guard = rdp_client.lock().await; + if let Err(Error::RdpError(e)) = guard + .read(|event| match event { + RdpEvent::Bitmap(bitmap) => { + bitmap_channel.send(bitmap).unwrap(); + } + _ => println!("{}: ignore event", APPLICATION_NAME), + }) + .await + { + match e.kind() { + RdpErrorKind::Disconnect => { + println!("{}: Server ask for disconnect", APPLICATION_NAME); + } + _ => println!("{}: {:?}", APPLICATION_NAME, e), } - _ => println!("{}: {:?}", APPLICATION_NAME, e), + break; } - break; } - } + }) })) } /// This is the main loop /// Print Window and handle all input (mous + keyboard) /// to RDP -fn main_gui_loop( +async fn main_gui_loop( mut window: Window, rdp_client: Arc>>, sync: Arc, @@ -445,48 +455,49 @@ fn main_gui_loop( // Mouse position input if let Some((x, y)) = window.get_mouse_pos(MouseMode::Clamp) { - let mut rdp_client_guard = rdp_client.lock().map_err(|e| { - Error::RdpError(RdpError::new( - RdpErrorKind::Unknown, - &format!("Thread error during access to mutex [{}]", e), - )) - })?; + let mut rdp_client_guard = rdp_client.lock().await; // Button is down if not 0 let current_button = get_rdp_pointer_down(&window); - rdp_client_guard.try_write(RdpEvent::Pointer(PointerEvent { - x: x as u16, - y: y as u16, - button: if last_button == current_button { - PointerButton::None - } else { - PointerButton::try_from(last_button as u8 | current_button as u8).unwrap() - }, - down: (last_button != current_button) && last_button == PointerButton::None, - }))?; + rdp_client_guard + .try_write(RdpEvent::Pointer(PointerEvent { + x: x as u16, + y: y as u16, + button: if last_button == current_button { + PointerButton::None + } else { + PointerButton::try_from(last_button as u8 | current_button as u8).unwrap() + }, + down: (last_button != current_button) && last_button == PointerButton::None, + })) + .await?; last_button = current_button; } // Keyboard inputs if let Some(keys) = window.get_keys() { - let mut rdp_client_guard = rdp_client.lock().unwrap(); + let mut rdp_client_guard = rdp_client.lock().await; for key in last_keys.iter() { if !keys.contains(key) { - rdp_client_guard.try_write(RdpEvent::Key(KeyboardEvent { - code: to_scancode(*key), - down: false, - }))? + rdp_client_guard + .try_write(RdpEvent::Key(KeyboardEvent { + code: to_scancode(*key), + down: false, + })) + .await? } } for key in keys.iter() { if window.is_key_pressed(*key, KeyRepeat::Yes) { - rdp_client_guard.try_write(RdpEvent::Key(KeyboardEvent { - code: to_scancode(*key), - down: true, - }))? + rdp_client_guard + .try_write(RdpEvent::Key(KeyboardEvent { + code: to_scancode(*key), + down: true, + })) + .await? } } @@ -505,11 +516,12 @@ fn main_gui_loop( } sync.store(false, Ordering::Relaxed); - rdp_client.lock().unwrap().shutdown()?; + rdp_client.lock().await.shutdown().await?; Ok(()) } -fn main() { +#[tokio::main] +async fn main() { // Parsing argument let matches = App::new(APPLICATION_NAME) .version("0.1.0") @@ -610,7 +622,7 @@ fn main() { .get_matches(); // Create a tcp stream from args - let tcp = tcp_from_args(&matches).unwrap(); + let tcp = tcp_from_args(&matches).await.unwrap(); // Keep trace of the handle #[cfg(target_os = "windows")] @@ -620,7 +632,7 @@ fn main() { let handle = tcp.as_raw_fd(); // Create rdp client - let rdp_client = rdp_from_args(&matches, tcp).unwrap(); + let rdp_client = rdp_from_args(&matches, tcp).await.unwrap(); let window = window_from_args(&matches).unwrap(); @@ -644,7 +656,9 @@ fn main() { .unwrap(); // Launch the GUI - main_gui_loop(window, rdp_client_mutex, sync, bitmap_receiver).unwrap(); + main_gui_loop(window, rdp_client_mutex, sync, bitmap_receiver) + .await + .unwrap(); rdp_thread.join().unwrap(); } From b08c27094da0b4c375aee83689ab9b420a09b529 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Sat, 22 Oct 2022 01:27:47 +0800 Subject: [PATCH 07/12] fix crashes --- src/model/link.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/link.rs b/src/model/link.rs index 1c40a6b..c56630a 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -204,7 +204,7 @@ impl Link { TlsConnector::new() .danger_accept_invalid_certs(!check_certificate) .use_sni(false) - .connect("", stream) + .connect("not_in_use.com", stream) .await?, ))); } From fab65c817069be508fe412265fa819b784bba993 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Sat, 22 Oct 2022 09:46:27 +0800 Subject: [PATCH 08/12] Bump dependencies to the latest --- Cargo.toml | 18 +++++++++--------- src/bin/mstsc-rs.rs | 1 + src/nla/cssp.rs | 4 ++-- src/nla/ntlm.rs | 16 ++++++++-------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f0c9b67..17b2106 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,19 +29,19 @@ mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc", "openssl", "futures"] openssl = ["async-native-tls"] [dependencies] -tokio = { version = "1.21.2", features = ["full"] } +tokio = { version = "^1", features = ["full"] } async-native-tls = { version = "^0.4", optional = true, default-features = false, features = ["runtime-tokio"] } byteorder = "^1.3" bufstream = "0.1" indexmap = "^1.3" -yasna = { version = "^0.3" } -md4 = "^0.8" -hmac = "^0.7" -md-5 = "^0.8" -rand = "^0.7" -num-bigint = "^0.2" -x509-parser = "0.6.5" -num_enum = "0.4.3" +yasna = { version = "^0.4" } +md4 = "^0.9" +hmac = "^0.11" +md-5 = "^0.9" +rand = "^0.8" +num-bigint = "^0.4" +x509-parser = "^0.12" +num_enum = "^0.5" # for mtsc-rs hex = { version = "^0.4", optional = true } diff --git a/src/bin/mstsc-rs.rs b/src/bin/mstsc-rs.rs index ccc3283..b560f12 100644 --- a/src/bin/mstsc-rs.rs +++ b/src/bin/mstsc-rs.rs @@ -69,6 +69,7 @@ fn wait_for_fd(fd: usize) -> bool { /// Transmute is use to convert Vec -> Vec /// To accelerate data convert +#[allow(clippy::missing_safety_doc)] pub unsafe fn transmute_vec(mut vec: Vec) -> Vec { let ptr = vec.as_mut_ptr(); let capacity = vec.capacity() * size_of::() / size_of::(); diff --git a/src/nla/cssp.rs b/src/nla/cssp.rs index 082bba2..584cbec 100644 --- a/src/nla/cssp.rs +++ b/src/nla/cssp.rs @@ -6,7 +6,7 @@ use crate::nla::asn1::{ use crate::nla::sspi::AuthenticationProtocol; use num_bigint::BigUint; use tokio::io::*; -use x509_parser::{parse_x509_der, X509Certificate}; +use x509_parser::prelude::*; use yasna::Tag; /// Create a ts request as expected by the specification @@ -102,7 +102,7 @@ pub fn create_ts_authenticate(nego: Vec, pub_key_auth: Vec) -> Vec { } pub fn read_public_certificate(stream: &[u8]) -> RdpResult { - let res = parse_x509_der(stream).unwrap(); + let res = X509Certificate::from_der(stream).unwrap(); Ok(res.1) } diff --git a/src/nla/ntlm.rs b/src/nla/ntlm.rs index 3e0a96a..7ae5e48 100644 --- a/src/nla/ntlm.rs +++ b/src/nla/ntlm.rs @@ -5,7 +5,7 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use crate::model::rnd::random; use crate::nla::rc4::Rc4; use crate::nla::sspi::{AuthenticationProtocol, GenericSecurityService}; -use hmac::{Hmac, Mac}; +use hmac::{Hmac, Mac, NewMac}; use md4::{Digest, Md4}; use md5::Md5; use num_enum::TryFromPrimitive; @@ -284,8 +284,8 @@ fn z(m: usize) -> Vec { /// ``` fn md4(data: &[u8]) -> Vec { let mut hasher = Md4::new(); - hasher.input(data); - hasher.result().to_vec() + hasher.update(data); + hasher.finalize().to_vec() } /// Compute the MD5 Hash of input vector @@ -299,8 +299,8 @@ fn md4(data: &[u8]) -> Vec { /// ``` fn md5(data: &[u8]) -> Vec { let mut hasher = Md5::new(); - hasher.input(data); - hasher.result().to_vec() + hasher.update(data); + hasher.finalize().to_vec() } /// Encode a string into utf-16le @@ -330,9 +330,9 @@ fn unicode(data: &str) -> Vec { /// let signature = hmac_md5(b"foo", b"bar"); /// ``` fn hmac_md5(key: &[u8], data: &[u8]) -> Vec { - let mut stream = Hmac::::new_varkey(key).unwrap(); - stream.input(data); - stream.result().code().to_vec() + let mut stream = Hmac::::new_from_slice(key).unwrap(); + stream.update(data); + stream.finalize().into_bytes().to_vec() } /// This function is used to compute init key of another hmac_md5 From b8f884557dced3b5770d187730982714fcfe0e54 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Sat, 22 Oct 2022 18:10:01 +0800 Subject: [PATCH 09/12] Async Bio Support --- Cargo.toml | 1 + src/codec/rle.rs | 6 +++--- src/core/client.rs | 6 +++--- src/core/gcc.rs | 2 +- src/core/per.rs | 2 +- src/core/x224.rs | 2 +- src/model/link.rs | 52 ++++++++++++++++++++++++++-------------------- 7 files changed, 40 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 17b2106..6e98020 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ openssl = ["async-native-tls"] [dependencies] tokio = { version = "^1", features = ["full"] } async-native-tls = { version = "^0.4", optional = true, default-features = false, features = ["runtime-tokio"] } +async-trait = { version = "^0.1" } byteorder = "^1.3" bufstream = "0.1" indexmap = "^1.3" diff --git a/src/codec/rle.rs b/src/codec/rle.rs index 061b400..ce05d44 100644 --- a/src/codec/rle.rs +++ b/src/codec/rle.rs @@ -225,7 +225,7 @@ pub fn rle_16_decompress( match opcode { 0 => { - if lastopcode == opcode && !(x == width && prevline == None) { + if lastopcode == opcode && !(x == width && prevline.is_none()) { insertmix = true; } } @@ -390,10 +390,10 @@ pub fn rle_16_decompress( } pub fn rgb565torgb32(input: &[u16], width: usize, height: usize) -> Vec { - let mut result_32_bpp = vec![0_u8; width as usize * height as usize * 4]; + let mut result_32_bpp = vec![0_u8; width * height * 4]; for i in 0..height { for j in 0..width { - let index = (i * width + j) as usize; + let index = i * width + j; let v = input[index]; result_32_bpp[index * 4 + 3] = 0xff; result_32_bpp[index * 4 + 2] = (((((v >> 11) & 0x1f) * 527) + 23) >> 6) as u8; diff --git a/src/core/client.rs b/src/core/client.rs index 815602d..db29073 100644 --- a/src/core/client.rs +++ b/src/core/client.rs @@ -8,7 +8,7 @@ use crate::core::tpkt; use crate::core::x224; use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; #[cfg(not(feature = "openssl"))] -use crate::model::link::SecureBio; +use crate::model::link::AsyncSecureBio; use crate::model::link::{Link, Stream}; use crate::nla::ntlm::Ntlm; use tokio::io::*; @@ -249,13 +249,13 @@ impl Connector { self.connect_further(tcp).await } #[cfg(not(feature = "openssl"))] - pub fn connect + 'static>( + pub async fn connect + 'static>( &mut self, stream: Box, ) -> RdpResult> { // Create a wrapper around the stream let tcp = Link::new(Stream::Bio(stream)); - self.connect_further(tcp) + self.connect_further(tcp).await } async fn connect_further( diff --git a/src/core/gcc.rs b/src/core/gcc.rs index 5c73ccc..7f32d77 100644 --- a/src/core/gcc.rs +++ b/src/core/gcc.rs @@ -303,7 +303,7 @@ pub fn server_network_data() -> Component { pub fn block_header(data_type: Option, length: Option) -> Component { component![ "type" => U16::LE(data_type.unwrap_or(MessageType::CsCore) as u16), - "length" => U16::LE(length.unwrap_or(0) as u16 + 4) + "length" => U16::LE(length.unwrap_or(0) + 4) ] } diff --git a/src/core/per.rs b/src/core/per.rs index 8d44b95..5bfe2a8 100644 --- a/src/core/per.rs +++ b/src/core/per.rs @@ -201,7 +201,7 @@ pub fn read_integer(s: &mut dyn Read) -> RdpResult { 4 => { let mut result = U32::BE(0); result.read(s)?; - Ok(result.inner() as u32) + Ok(result.inner()) } _ => Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidSize, diff --git a/src/core/x224.rs b/src/core/x224.rs index f800d50..22c949c 100644 --- a/src/core/x224.rs +++ b/src/core/x224.rs @@ -79,7 +79,7 @@ fn rdp_neg_req( /// X224 request header fn x224_crq(len: u8, code: MessageType) -> Component { component! [ - "len" => (len + 6) as u8, + "len" => len + 6, "code" => code as u8, "padding" => trame! [U16::LE(0), U16::LE(0), 0_u8] ] diff --git a/src/model/link.rs b/src/model/link.rs index c56630a..2e4eed2 100644 --- a/src/model/link.rs +++ b/src/model/link.rs @@ -6,13 +6,17 @@ use std::io::Cursor; use tokio::io::*; #[cfg(not(feature = "openssl"))] -pub trait SecureBio +use async_trait::async_trait; + +#[cfg(not(feature = "openssl"))] +#[async_trait] +pub trait AsyncSecureBio where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { - fn start_ssl(&mut self, check_certificate: bool) -> RdpResult<()>; + async fn start_ssl(&mut self, check_certificate: bool) -> RdpResult<()>; fn get_peer_certificate_der(&self) -> RdpResult>>; - fn shutdown(&mut self) -> std::io::Result<()>; + async fn shutdown(&mut self) -> std::io::Result<()>; fn get_io(&mut self) -> &mut S; } @@ -25,7 +29,7 @@ pub enum Stream { #[cfg(feature = "openssl")] Ssl(TlsStream), #[cfg(not(feature = "openssl"))] - Bio(Box>), + Bio(Box>), } impl Stream { @@ -46,7 +50,7 @@ impl Stream { #[cfg(feature = "openssl")] Stream::Ssl(e) => e.read_exact(buf).await?, #[cfg(not(feature = "openssl"))] - Stream::Bio(bio) => bio.get_io().read_exact(buf)?, + Stream::Bio(bio) => bio.get_io().read_exact(buf).await?, }; Ok(()) } @@ -68,7 +72,7 @@ impl Stream { #[cfg(feature = "openssl")] Stream::Ssl(e) => Ok(e.read(buf).await?), #[cfg(not(feature = "openssl"))] - Stream::Bio(e) => Ok(e.get_io().read(buf)?), + Stream::Bio(e) => Ok(e.get_io().read(buf).await?), } } @@ -94,20 +98,21 @@ impl Stream { #[cfg(feature = "openssl")] Stream::Ssl(e) => e.write(buffer).await?, #[cfg(not(feature = "openssl"))] - Stream::Bio(e) => e.get_io().write(buffer)?, + Stream::Bio(e) => e.get_io().write(buffer).await?, }) } /// Shutdown the stream /// Only works when stream is a SSL stream pub async fn shutdown(&mut self) -> RdpResult<()> { - match self { - #[cfg(feature = "openssl")] - Stream::Ssl(e) => e.shutdown().await?, - #[cfg(not(feature = "openssl"))] - Stream::Bio(e) => e.shutdown()?, - _ => (), - }; + #[cfg(feature = "openssl")] + if let Stream::Ssl(e) = self { + e.shutdown().await? + } + #[cfg(not(feature = "openssl"))] + if let Stream::Bio(e) = self { + e.shutdown().await?; + } Ok(()) } } @@ -215,15 +220,16 @@ impl Link { } #[cfg(not(feature = "openssl"))] - pub fn start_ssl(self, check_certificate: bool) -> RdpResult> { + pub async fn start_ssl(self, check_certificate: bool) -> RdpResult> { if let Stream::Bio(mut stream) = self.stream { - stream.start_ssl(check_certificate)?; - return Ok(Link::new(Stream::Bio(stream))); + stream.start_ssl(check_certificate).await?; + Ok(Link::new(Stream::Bio(stream))) + } else { + Err(Error::RdpError(RdpError::new( + RdpErrorKind::NotImplemented, + "start_ssl on ssl stream is forbidden", + ))) } - Err(Error::RdpError(RdpError::new( - RdpErrorKind::NotImplemented, - "start_ssl on ssl stream is forbidden", - ))) } /// Retrive the peer certificate /// Use by the NLA authentication protocol @@ -244,6 +250,8 @@ impl Link { Some(cert) => Some(cert.to_der()?), None => None, }), + #[cfg(not(feature = "openssl"))] + Stream::Bio(stream) => stream.get_peer_certificate_der(), _ => Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidData, "get peer certificate on non ssl link is impossible", From c340a749868f70fde328bf15c751f843da6e119e Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Sat, 22 Oct 2022 22:41:39 +0800 Subject: [PATCH 10/12] Support build for wasm32 --- Cargo.toml | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6e98020..9699e7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,6 @@ mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc", "openssl", "futures"] openssl = ["async-native-tls"] [dependencies] -tokio = { version = "^1", features = ["full"] } async-native-tls = { version = "^0.4", optional = true, default-features = false, features = ["runtime-tokio"] } async-trait = { version = "^0.1" } byteorder = "^1.3" @@ -50,4 +49,18 @@ winapi = { version = "^0.3", features = ["winsock2"], optional = true } minifb = { version = "^0.15", optional = true } clap = { version = "^2.33", optional = true} libc = { version = "^0.2", optional = true} -futures = { version = "0.3", optional = true } \ No newline at end of file +futures = { version = "0.3", optional = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { version = "^1", features = ["full"] } + + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2", features = ["js"] } +tokio = { version = "^1", features = [ + "sync", + "macros", + "io-util", + "rt", + "time" + ]} From 21a40e7f32e4769d74f428176da34539d7eb8539 Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Mon, 31 Oct 2022 08:39:07 +0000 Subject: [PATCH 11/12] update the lib version as api changed --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9699e7a..d2825dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rdp-rs" -version = "0.1.0" +version = "0.2.0" authors = ["Sylvain Peyrefitte "] description = "A Pure RUST imlementation of Remote Desktop Protocol" repository = "https://github.com/citronneur/rdp-rs" From ffdbc22db8888f8a8465b27d19f14824accab81c Mon Sep 17 00:00:00 2001 From: Jovi Hsu Date: Mon, 31 Oct 2022 08:39:16 +0000 Subject: [PATCH 12/12] Using tracing crate to generate more compatible log --- Cargo.toml | 12 +++++++++++- src/bin/mstsc-rs.rs | 20 +++++++++++++++++--- src/core/gcc.rs | 4 +++- src/core/global.rs | 26 ++++++++++++++++++-------- src/core/mcs.rs | 3 ++- src/core/tpkt.rs | 13 ++++++++++--- 6 files changed, 61 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d2825dd..2a0f212 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,15 @@ default = ["openssl"] # The reason we do this is because doctests don't get cfg(test) # See: https://github.com/rust-lang/cargo/issues/4669 integration = [] -mstsc-rs = ["hex", "winapi", "minifb", "clap", "libc", "openssl", "futures"] +mstsc-rs = ["hex", + "winapi", + "minifb", + "clap", + "libc", + "openssl", + "futures", + "tracing-subscriber" + ] openssl = ["async-native-tls"] [dependencies] @@ -42,6 +50,7 @@ rand = "^0.8" num-bigint = "^0.4" x509-parser = "^0.12" num_enum = "^0.5" +tracing = { version = "^0.1", features = ["log"] } # for mtsc-rs hex = { version = "^0.4", optional = true } @@ -50,6 +59,7 @@ minifb = { version = "^0.15", optional = true } clap = { version = "^2.33", optional = true} libc = { version = "^0.2", optional = true} futures = { version = "0.3", optional = true } +tracing-subscriber = { version = "^0.3", optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio = { version = "^1", features = ["full"] } diff --git a/src/bin/mstsc-rs.rs b/src/bin/mstsc-rs.rs index b560f12..064dded 100644 --- a/src/bin/mstsc-rs.rs +++ b/src/bin/mstsc-rs.rs @@ -25,6 +25,7 @@ use tokio::{ net::TcpStream, sync::Mutex, }; +use tracing::{event, Level}; #[cfg(target_os = "windows")] use winapi::um::winsock2::{fd_set, select}; @@ -398,15 +399,19 @@ fn launch_rdp_thread( RdpEvent::Bitmap(bitmap) => { bitmap_channel.send(bitmap).unwrap(); } - _ => println!("{}: ignore event", APPLICATION_NAME), + _ => event!(Level::WARN, "{}: ignore event", APPLICATION_NAME), }) .await { match e.kind() { RdpErrorKind::Disconnect => { - println!("{}: Server ask for disconnect", APPLICATION_NAME); + event!( + Level::WARN, + "{}: Server ask for disconnect", + APPLICATION_NAME + ); } - _ => println!("{}: {:?}", APPLICATION_NAME, e), + _ => event!(Level::WARN, "{}: {:?}", APPLICATION_NAME, e), } break; } @@ -622,6 +627,15 @@ async fn main() { ) .get_matches(); + // Create tracing subscriber + let subscriber = tracing_subscriber::FmtSubscriber::builder() + // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) + // will be written to stdout. + .with_max_level(Level::INFO) + // completes the builder. + .finish(); + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + // Create a tcp stream from args let tcp = tcp_from_args(&matches).await.unwrap(); diff --git a/src/core/gcc.rs b/src/core/gcc.rs index 7f32d77..8e3b771 100644 --- a/src/core/gcc.rs +++ b/src/core/gcc.rs @@ -6,6 +6,7 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult}; use crate::model::unicode::Unicode; use std::collections::HashMap; use std::io::{Cursor, Read}; +use tracing::{event, Level}; const T124_02_98_OID: [u8; 6] = [0, 0, 20, 124, 0, 1]; const H221_CS_KEY: [u8; 4] = *b"Duca"; @@ -371,7 +372,8 @@ pub fn read_conference_create_response(cc_response: &mut dyn Read) -> RdpResult< server_net.read(&mut Cursor::new(buffer))?; result.insert(MessageType::ScNet, server_net); } - _ => println!( + _ => event!( + Level::WARN, "GCC: Unknown server block {:?}", cast!(DataType::U16, header["type"])? ), diff --git a/src/core/global.rs b/src/core/global.rs index 2126b99..c05a6c6 100644 --- a/src/core/global.rs +++ b/src/core/global.rs @@ -12,6 +12,7 @@ use num_enum::TryFromPrimitive; use std::convert::TryFrom; use std::io::{Cursor, Read}; use tokio::io::*; +use tracing::{event, Level}; /// Raw PDU type use by the protocol #[repr(u16)] @@ -644,7 +645,7 @@ impl Client { for capability_set in cast!(DataType::Trame, pdu.message["capabilitySets"])?.iter() { match Capability::from_capability_set(cast!(DataType::Component, capability_set)?) { Ok(capability) => self.server_capabilities.push(capability), - Err(e) => println!("GLOBAL: {:?}", e), + Err(e) => event!(Level::WARN, "GLOBAL: {:?}", e), } } self.share_id = Some(cast!(DataType::U32, pdu.message["shareId"])?); @@ -719,24 +720,29 @@ impl Client { // Ask for a new handshake if pdu.pdu_type == PduType::Deactivateallpdu { - println!("GLOBAL: deactive/reactive sequence initiated"); + event!(Level::WARN, "GLOBAL: deactive/reactive sequence initiated"); self.state = ClientState::DemandActivePDU; continue; } if pdu.pdu_type != PduType::Datapdu { - println!("GLOBAL: Ignore PDU {:?}", pdu.pdu_type); + event!(Level::WARN, "GLOBAL: Ignore PDU {:?}", pdu.pdu_type); continue; } match DataPDU::from_pdu(&pdu) { Ok(data_pdu) => match data_pdu.pdu_type { - PDUType2::Pdutype2SetErrorInfoPdu => println!( + PDUType2::Pdutype2SetErrorInfoPdu => event!( + Level::WARN, "GLOBAL: Receive error PDU from server {:?}", cast!(DataType::U32, data_pdu.message["errorInfo"])? ), - _ => println!("GLOBAL: Data PDU not handle {:?}", data_pdu.pdu_type), + _ => event!( + Level::WARN, + "GLOBAL: Data PDU not handle {:?}", + data_pdu.pdu_type + ), }, - Err(e) => println!("GLOBAL: Parsing data PDU error {:?}", e), + Err(e) => event!(Level::WARN, "GLOBAL: Parsing data PDU error {:?}", e), }; } Ok(()) @@ -780,10 +786,14 @@ impl Client { FastPathUpdateType::FastpathUpdatetypeColor | FastPathUpdateType::FastpathUpdatetypePtrNull | FastPathUpdateType::FastpathUpdatetypeSynchronize => (), - _ => println!("GLOBAL: Fast Path order not handled {:?}", order.fp_type), + _ => event!( + Level::WARN, + "GLOBAL: Fast Path order not handled {:?}", + order.fp_type + ), } } - Err(e) => println!("GLOBAL: Unknown Fast Path order {:?}", e), + Err(e) => event!(Level::WARN, "GLOBAL: Unknown Fast Path order {:?}", e), }; } diff --git a/src/core/mcs.rs b/src/core/mcs.rs index 07969a9..1e536d9 100644 --- a/src/core/mcs.rs +++ b/src/core/mcs.rs @@ -14,6 +14,7 @@ use crate::nla::asn1::{ use std::collections::HashMap; use std::io::{BufRead, Cursor, Read}; use tokio::io::*; +use tracing::{event, Level}; use yasna::Tag; #[allow(dead_code)] @@ -327,7 +328,7 @@ impl Client { *channel_id, &mut try_let!(tpkt::Payload::Raw, self.x224.read().await?)?, )? { - println!("Server reject channel id {:?}", channel_id); + event!(Level::WARN, "Server reject channel id {:?}", channel_id); } } diff --git a/src/core/tpkt.rs b/src/core/tpkt.rs index 14fe33e..d8d83c0 100644 --- a/src/core/tpkt.rs +++ b/src/core/tpkt.rs @@ -138,7 +138,11 @@ impl Client { if size.inner() < 4 { Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidSize, - "Invalid minimal size for TPKT", + &format!( + "Invalid minimal size for TPKT #1 ({}, {})", + action, + size.inner() + ), ))) } else { // now wait for body @@ -159,7 +163,7 @@ impl Client { if length < 3 { Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidSize, - "Invalid minimal size for TPKT", + &format!("Invalid minimal size for TPKT #2 ({}, {})", action, length), ))) } else { Ok(Payload::FastPath( @@ -170,7 +174,10 @@ impl Client { } else if short_length < 2 { Err(Error::RdpError(RdpError::new( RdpErrorKind::InvalidSize, - "Invalid minimal size for TPKT", + &format!( + "Invalid minimal size for TPKT #3 ({}, {})", + action, short_length + ), ))) } else { Ok(Payload::FastPath(