Skip to content

Commit

Permalink
Remove Channel::request_cancellation.
Browse files Browse the repository at this point in the history
This trait fn returns a private type, which means it's useless for
anyone using the Channel.

Instead, add an inert (now-public) ResponseGuard to TrackedRequest that,
when taken out of the ManuallyDrop, ensures a Channel's request state is
cleaned up. It's preferable to make ResponseGuard public instead of
RequestCancellations because it's a smaller API surface (no public
methods, just a Drop fn) and harder to misuse, because it is already
associated with the correct request ID to cancel.
  • Loading branch information
tikue committed Aug 12, 2022
1 parent 453ba1c commit 68863e3
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 53 deletions.
8 changes: 8 additions & 0 deletions tarpc/src/cancellations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ pub fn cancellations() -> (RequestCancellation, CanceledRequests) {

impl RequestCancellation {
/// Cancels the request with ID `request_id`.
///
/// No validation is done of `request_id`. There is no way to know if the request id provided
/// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
/// a one-way communication channel.
///
/// Once request data is cleaned up, a response will never be received by the client. This is
/// useful primarily when request processing ends prematurely for requests with long deadlines
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
pub fn cancel(&self, request_id: u64) {
let _ = self.0.send(request_id);
}
Expand Down
76 changes: 42 additions & 34 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin};
use std::{
convert::TryFrom,
error::Error,
fmt,
marker::PhantomData,
mem::{self, ManuallyDrop},
pin::Pin,
};
use tracing::{info_span, instrument::Instrument, Span};

mod in_flight_requests;
Expand Down Expand Up @@ -199,9 +206,13 @@ where
Ok(abort_registration) => {
drop(entered);
Ok(TrackedRequest {
request,
abort_registration,
span,
response_guard: ManuallyDrop::new(ResponseGuard {
request_id: request.id,
request_cancellation: self.request_cancellation.clone(),
}),
request,
})
}
Err(AlreadyExistsError) => {
Expand All @@ -228,6 +239,8 @@ pub struct TrackedRequest<Req> {
pub abort_registration: AbortRegistration,
/// A span representing the server processing of this request.
pub span: Span,
/// An inert response guard. Becomes active in an InFlightRequest.
pub response_guard: ManuallyDrop<ResponseGuard>,
}

/// The server end of an open connection with a client, receiving requests from, and sending
Expand All @@ -246,13 +259,15 @@ pub struct TrackedRequest<Req> {
/// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
/// from, and send responses into, a Channel in lieu of the previous methods. Channels stream
/// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response
/// logic in an [`Abortable`] future using the abort registration will ensure that the response
/// does not execute longer than the request deadline. The `Channel` itself will clean up
/// request state once either the deadline expires, or a cancellation message is received, or a
/// response is sent. Because there is no guarantee that a cancellation message will ever be
/// received for a request, or that requests come with reasonably short deadlines, services
/// should strive to clean up Channel resources by sending a response for every request.
/// server [`Span`], request lifetime [`AbortRegistration`], and an inert [`ResponseGuard`].
/// Wrapping response logic in an [`Abortable`] future using the abort registration will ensure
/// that the response does not execute longer than the request deadline. The `Channel` itself
/// will clean up request state once either the deadline expires, or the response guard is
/// dropped, or a response is sent.
///
/// Channels must be implemented using the decorator pattern: the only way to create a
/// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are
/// created by [`BaseChannel`].
pub trait Channel
where
Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
Expand All @@ -275,12 +290,6 @@ where
/// Returns the transport underlying the channel.
fn transport(&self) -> &Self::Transport;

/// Returns a reference to the channel's request cancellation channel, which can be used to
/// clean up request data when request processing ends prematurely.
///
/// Once request data is cleaned up, a response cannot be sent back to the client.
fn request_cancellation(&self) -> &RequestCancellation;

/// Caps the number of concurrent requests to `limit`. An error will be returned for requests
/// over the concurrency limit.
///
Expand Down Expand Up @@ -525,10 +534,6 @@ where
fn transport(&self) -> &Self::Transport {
self.get_ref()
}

fn request_cancellation(&self) -> &RequestCancellation {
&self.request_cancellation
}
}

/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
Expand Down Expand Up @@ -571,19 +576,22 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
self.channel_pin_mut().poll_next(cx).map_ok(|request| {
let request_id = request.request.id;
InFlightRequest {
request: request.request,
abort_registration: request.abort_registration,
span: request.span,
response_tx: self.responses_tx.clone(),
response_guard: ResponseGuard {
request_id,
request_cancellation: self.channel.request_cancellation().clone(),
},
}
})
self.channel_pin_mut().poll_next(cx).map_ok(
|TrackedRequest {
request,
abort_registration,
span,
response_guard,
}| {
InFlightRequest {
request,
abort_registration,
span,
response_guard: ManuallyDrop::into_inner(response_guard),
response_tx: self.responses_tx.clone(),
}
},
)
}

fn pump_write(
Expand Down Expand Up @@ -660,10 +668,10 @@ where
}
}

/// A fail-safe to ensure requests are properly canceled if an InFlightRequest is dropped before
/// A fail-safe to ensure requests are properly canceled if request processing is aborted before
/// completing.
#[derive(Debug)]
struct ResponseGuard {
pub struct ResponseGuard {
request_cancellation: RequestCancellation,
request_id: u64,
}
Expand Down
5 changes: 0 additions & 5 deletions tarpc/src/server/limits/channels_per_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// https://opensource.org/licenses/MIT.

use crate::{
cancellations::RequestCancellation,
server::{self, Channel},
util::Compact,
};
Expand Down Expand Up @@ -120,10 +119,6 @@ where
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}

fn request_cancellation(&self) -> &RequestCancellation {
self.inner.request_cancellation()
}
}

impl<C, K> TrackedChannel<C, K> {
Expand Down
8 changes: 0 additions & 8 deletions tarpc/src/server/limits/requests_per_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// https://opensource.org/licenses/MIT.

use crate::{
cancellations::RequestCancellation,
server::{Channel, Config},
Response, ServerError,
};
Expand Down Expand Up @@ -132,10 +131,6 @@ where
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}

fn request_cancellation(&self) -> &RequestCancellation {
self.inner.request_cancellation()
}
}

/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
Expand Down Expand Up @@ -315,9 +310,6 @@ mod tests {
fn transport(&self) -> &() {
&()
}
fn request_cancellation(&self) -> &RequestCancellation {
unreachable!()
}
}
}

Expand Down
13 changes: 7 additions & 6 deletions tarpc/src/server/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context,
server::{Channel, Config, TrackedRequest},
server::{Channel, Config, ResponseGuard, TrackedRequest},
Request, Response,
};
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime};
use tracing::Span;

#[pin_project]
Expand Down Expand Up @@ -84,15 +84,12 @@ where
fn transport(&self) -> &() {
&()
}

fn request_cancellation(&self) -> &RequestCancellation {
&self.request_cancellation
}
}

impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
let (request_cancellation, _) = cancellations();
self.stream.push_back(Ok(TrackedRequest {
request: Request {
context: context::Context {
Expand All @@ -104,6 +101,10 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
},
abort_registration,
span: Span::none(),
response_guard: ManuallyDrop::new(ResponseGuard {
request_cancellation,
request_id: id,
}),
}));
}
}
Expand Down

0 comments on commit 68863e3

Please sign in to comment.