Skip to content

Commit

Permalink
fix: bugs in scheduled task (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 1, 2024
1 parent 9714bce commit 4e29475
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 16 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ features = [
"Win32_Graphics_Dwm",
"Win32_Graphics_Gdi",
"Win32_Security",
"Win32_Security_Authorization",
"Win32_System_Console",
"Win32_System_LibraryLoader",
"Win32_System_Registry",
"Win32_System_SystemInformation",
"Win32_System_Threading",
"Win32_Storage_FileSystem",
]
Expand Down
20 changes: 11 additions & 9 deletions src/utils/admin.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::HandleWrapper;

use anyhow::{anyhow, Result};
use windows::Win32::{
Foundation::{CloseHandle, HANDLE},
Security::{GetTokenInformation, TokenElevation, TOKEN_ELEVATION, TOKEN_QUERY},
System::Threading::{GetCurrentProcess, OpenProcessToken},
};
Expand All @@ -12,22 +13,23 @@ pub fn is_running_as_admin() -> Result<bool> {

fn is_running_as_admin_impl() -> Result<bool> {
let is_elevated = unsafe {
let mut token_handle: HANDLE = HANDLE(0);
let mut token_handle = HandleWrapper::default();
let mut elevation = TOKEN_ELEVATION::default();
let mut returned_length = 0;
OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &mut token_handle)?;
OpenProcessToken(
GetCurrentProcess(),
TOKEN_QUERY,
token_handle.get_handle_mut(),
)?;

let token_information = GetTokenInformation(
token_handle,
GetTokenInformation(
token_handle.get_handle(),
TokenElevation,
Some(&mut elevation as *mut _ as *mut _),
std::mem::size_of::<TOKEN_ELEVATION>() as u32,
&mut returned_length,
);

CloseHandle(token_handle)?;
)?;

token_information?;
elevation.TokenIsElevated != 0
};
Ok(is_elevated)
Expand Down
29 changes: 29 additions & 0 deletions src/utils/handle_wrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use windows::Win32::Foundation::{CloseHandle, HANDLE};

#[derive(Debug, Clone, Default)]
pub struct HandleWrapper {
handle: HANDLE,
}

impl HandleWrapper {
pub fn new(handle: HANDLE) -> Self {
Self { handle }
}
pub fn get_handle(&self) -> HANDLE {
self.handle
}
pub fn get_handle_mut(&mut self) -> &mut HANDLE {
&mut self.handle
}
}

impl Drop for HandleWrapper {
fn drop(&mut self) {
if self.handle.is_invalid() {
return;
}
unsafe {
let _ = CloseHandle(self.handle);
}
}
}
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod admin;
mod check_error;
mod handle_wrapper;
mod regedit;
mod scheduled_task;
mod single_instance;
Expand All @@ -8,6 +9,7 @@ mod windows_icon;

pub use admin::*;
pub use check_error::*;
pub use handle_wrapper::*;
pub use regedit::*;
pub use scheduled_task::*;
pub use single_instance::*;
Expand Down
176 changes: 169 additions & 7 deletions src/utils/scheduled_task.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
use std::{os::windows::process::CommandExt, process::Command};
use super::HandleWrapper;

use anyhow::{bail, Result};
use windows::Win32::System::Threading::CREATE_NO_WINDOW;
use anyhow::{anyhow, bail, Result};
use std::{
env,
ffi::OsString,
fs,
os::windows::{ffi::OsStringExt, process::CommandExt},
process::Command,
};
use windows::core::{Result as WindowsResult, PWSTR};
use windows::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER;
use windows::Win32::Security::Authorization::ConvertSidToStringSidW;
use windows::Win32::Security::{
GetTokenInformation, LookupAccountSidW, TokenUser, SID_NAME_USE, TOKEN_QUERY, TOKEN_USER,
};
use windows::Win32::System::SystemInformation::GetLocalTime;
use windows::Win32::System::Threading::{GetCurrentProcess, OpenProcessToken, CREATE_NO_WINDOW};

pub fn create_scheduled_task(name: &str, exe_path: &str) -> Result<()> {
let task_xml_path = create_task_file(name, exe_path)
.map_err(|err| anyhow!("Failed to create scheduled task, {err}"))?;
debug!("scheduled task file: {}", task_xml_path);
let output = Command::new("schtasks")
.creation_flags(CREATE_NO_WINDOW.0) // CREATE_NO_WINDOW flag
.args([
"/create", "/tn", name, "/tr", exe_path, "/sc", "onlogon", "/rl", "highest", "/it",
"/f",
])
.args(["/create", "/tn", name, "/xml", &task_xml_path, "/f"])
.output()?;
if !output.status.success() {
bail!(
Expand Down Expand Up @@ -45,3 +59,151 @@ pub fn exist_scheduled_task(name: &str) -> Result<bool> {
Ok(false)
}
}

fn create_task_file(name: &str, exe_path: &str) -> Result<String> {
let (author, user_id) = get_author_and_userid()
.map_err(|err| anyhow!("Failed to get author and user id, {err}"))?;
let current_time = get_current_time();
let command_path = if exe_path.contains(|c: char| c.is_whitespace()) {
format!("\"{}\"", exe_path)
} else {
exe_path.to_string()
};
let xml_data = format!(
r#"<?xml version="1.0" encoding="UTF-16"?>
<Task version="1.2" xmlns="http://schemas.microsoft.com/windows/2004/02/mit/task">
<RegistrationInfo>
<Date>{current_time}</Date>
<Author>{author}</Author>
<URI>\{name}</URI>
</RegistrationInfo>
<Triggers>
<LogonTrigger>
<StartBoundary>{current_time}</StartBoundary>
<Enabled>true</Enabled>
</LogonTrigger>
</Triggers>
<Principals>
<Principal id="Author">
<UserId>{user_id}</UserId>
<LogonType>InteractiveToken</LogonType>
<RunLevel>HighestAvailable</RunLevel>
</Principal>
</Principals>
<Settings>
<MultipleInstancesPolicy>IgnoreNew</MultipleInstancesPolicy>
<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>
<StopIfGoingOnBatteries>true</StopIfGoingOnBatteries>
<AllowHardTerminate>true</AllowHardTerminate>
<StartWhenAvailable>false</StartWhenAvailable>
<RunOnlyIfNetworkAvailable>false</RunOnlyIfNetworkAvailable>
<IdleSettings>
<StopOnIdleEnd>true</StopOnIdleEnd>
<RestartOnIdle>false</RestartOnIdle>
</IdleSettings>
<AllowStartOnDemand>true</AllowStartOnDemand>
<Enabled>true</Enabled>
<Hidden>false</Hidden>
<RunOnlyIfIdle>false</RunOnlyIfIdle>
<WakeToRun>false</WakeToRun>
<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>
<Priority>7</Priority>
</Settings>
<Actions Context="Author">
<Exec>
<Command>{command_path}</Command>
</Exec>
</Actions>
</Task>"#
);
let xml_path = env::temp_dir().join("window-switcher-task.xml");
let xml_path = xml_path.display().to_string();
fs::write(&xml_path, xml_data)
.map_err(|err| anyhow!("Failed to write task xml file at '{xml_path}', {err}",))?;
Ok(xml_path)
}

fn get_author_and_userid() -> WindowsResult<(String, String)> {
let mut token_handle = HandleWrapper::default();
unsafe {
OpenProcessToken(
GetCurrentProcess(),
TOKEN_QUERY,
token_handle.get_handle_mut(),
)?
};

let mut token_info_length = 0;
if let Err(err) = unsafe {
GetTokenInformation(
token_handle.get_handle(),
TokenUser,
None,
0,
&mut token_info_length,
)
} {
if err != ERROR_INSUFFICIENT_BUFFER.into() {
return Err(err);
}
}

let mut token_user = Vec::<u8>::with_capacity(token_info_length as usize);
unsafe {
GetTokenInformation(
token_handle.get_handle(),
TokenUser,
Some(token_user.as_mut_ptr() as *mut _),
token_info_length,
&mut token_info_length,
)?
};

let user_sid = unsafe { *(token_user.as_ptr() as *const TOKEN_USER) }
.User
.Sid;

let mut name = Vec::<u16>::with_capacity(256);
let mut name_len = 256;
let mut domain = Vec::<u16>::with_capacity(256);
let mut domain_len = 256;
let mut sid_name_use = SID_NAME_USE(0);

unsafe {
LookupAccountSidW(
None,
user_sid,
PWSTR(name.as_mut_ptr()),
&mut name_len,
PWSTR(domain.as_mut_ptr()),
&mut domain_len,
&mut sid_name_use,
)?
};

unsafe {
name.set_len(name_len as usize);
domain.set_len(domain_len as usize);
}

let username = OsString::from_wide(&name).to_string_lossy().into_owned();
let domainname = OsString::from_wide(&domain).to_string_lossy().into_owned();

let mut sid_string = PWSTR::null();
unsafe { ConvertSidToStringSidW(user_sid, &mut sid_string)? };

let sid_str = OsString::from_wide(unsafe { sid_string.as_wide() })
.to_string_lossy()
.into_owned();

Ok((format!("{}\\{}", domainname, username), sid_str))
}

fn get_current_time() -> String {
let st = unsafe { GetLocalTime() };

format!(
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}",
st.wYear, st.wMonth, st.wDay, st.wHour, st.wMinute, st.wSecond,
)
}

0 comments on commit 4e29475

Please sign in to comment.