Skip to content

Commit

Permalink
fix for review
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielc committed Nov 13, 2024
1 parent de08854 commit 5deb304
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 60 deletions.
42 changes: 30 additions & 12 deletions api/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ pub struct Server<C, I, M> {
marker: PhantomData<C>,
authentication: bool,

pipeline: SessionContext,
pipeline: Option<SessionContext>,
}

impl<C, I, M> Server<C, I, M>
Expand All @@ -381,7 +381,7 @@ where
network: Network,
interest: I,
model: Arc<M>,
pipeline: SessionContext,
pipeline: Option<SessionContext>,
shutdown_signal: broadcast::Receiver<()>,
) -> Self {
let (tx, event_rx) = tokio::sync::mpsc::channel::<EventInsert>(1024);
Expand Down Expand Up @@ -835,12 +835,10 @@ where

async fn get_stream_state(
&self,
stream_id: String,
pipeline: &SessionContext,
stream_id: StreamId,
) -> Result<ceramic_api_server::StreamsStreamIdGetResponse, ErrorResponse> {
let stream_id = StreamId::from_str(&stream_id)
.map_err(|err| ErrorResponse::new(format!("invalid stream_id: {err}")))?;
let state_batch = self
.pipeline
let state_batch = pipeline
.table("doc_state")
.await
.map_err(|err| ErrorResponse::new(format!("doc_state table not found: {err}")))?
Expand Down Expand Up @@ -894,7 +892,7 @@ where
ErrorResponse::new(format!("failed to execute pipeline query: {err}"))
})?;

if state_batch.len() == 0 {
if state_batch.is_empty() {
return Ok(
ceramic_api_server::StreamsStreamIdGetResponse::StreamNotFound(
stream_id.to_string(),
Expand All @@ -912,7 +910,7 @@ where
let state = as_string_array(
batch
.column_by_name("state")
.ok_or_else(|| ErrorResponse::new(format!("state column should exist")))?,
.ok_or_else(|| ErrorResponse::new("state column should exist".to_string()))?,
)
.map_err(|err| ErrorResponse::new(format!("state should be a string column: {err}")))?;
let state: State = serde_json::from_str(state.value(0)).map_err(|err| {
Expand Down Expand Up @@ -1182,9 +1180,29 @@ where
stream_id: String,
_context: &C,
) -> Result<ceramic_api_server::StreamsStreamIdGetResponse, ApiError> {
self.get_stream_state(stream_id).await.or_else(|err| {
Ok(ceramic_api_server::StreamsStreamIdGetResponse::InternalServerError(err))
})
let stream_id = match StreamId::from_str(&stream_id) {
Ok(stream_id) => stream_id,
Err(err) => {
return Ok(ceramic_api_server::StreamsStreamIdGetResponse::BadRequest(
models::BadRequestResponse {
message: format!("invalid stream id: {err}"),
},
))
}
};
if let Some(pipeline) = &self.pipeline {
self.get_stream_state(pipeline, stream_id)
.await
.or_else(|err| {
Ok(ceramic_api_server::StreamsStreamIdGetResponse::InternalServerError(err))
})
} else {
Ok(ceramic_api_server::StreamsStreamIdGetResponse::BadRequest(
models::BadRequestResponse {
message: "cannot use stream state API without enabling aggregator".to_string(),
},
))
}
}

/// cors
Expand Down
56 changes: 15 additions & 41 deletions api/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use ceramic_api_server::{
};
use ceramic_core::{Cid, Interest};
use ceramic_core::{EventId, Network, NodeId, PeerId, StreamId};
use ceramic_pipeline::ConclusionFeed;
use datafusion::arrow::array::{
BinaryBuilder, BinaryDictionaryBuilder, MapBuilder, MapFieldNames, RecordBatch, StringArray,
StringBuilder, UInt64Array, UInt8Array,
Expand Down Expand Up @@ -181,7 +180,7 @@ fn create_test_server<C, I, M>(
network: Network,
interest: I,
model: Arc<M>,
pipeline: SessionContext,
pipeline: Option<SessionContext>,
) -> Server<C, I, M>
where
I: InterestService,
Expand All @@ -191,32 +190,6 @@ where
Server::new(node_id, network, interest, model, pipeline, rx)
}

/// Build a pipeline `SessionContext` for tests: the conclusion feed is a
/// mock that always reports zero events, and the object store is purely
/// in-memory, so no test state touches disk or a real feed.
async fn empty_pipeline() -> SessionContext {
let mut feed = MockFeed::new();
// Any (highwater_mark, limit) query returns an empty event list.
feed.expect_conclusion_events_since()
.returning(|_, _| Ok(vec![]));
ceramic_pipeline::session_from_config(ceramic_pipeline::Config {
conclusion_feed: feed.into(),
// In-memory store; contents vanish when the test ends.
object_store: Arc::new(object_store::memory::InMemory::new()),
object_store_bucket_name: "test_bucket".to_string(),
})
.await
// NOTE(review): unwrap is acceptable here — a failed session build
// should abort the test immediately.
.unwrap()
}

// mockall macro: generates `MockFeed`, a scriptable test double for the
// `ConclusionFeed` trait so each test can set its own expectations on
// `conclusion_events_since`.
mock! {
#[derive(Debug)]
pub Feed {}
#[async_trait]
impl ConclusionFeed for Feed {
// Signature mirrors `ConclusionFeed::conclusion_events_since`;
// behavior is supplied per-test via `expect_conclusion_events_since`.
async fn conclusion_events_since(
&self,
highwater_mark: i64,
limit: i64,
) -> anyhow::Result<Vec<ceramic_pipeline::ConclusionEvent>>;
}
}

#[test(tokio::test)]
async fn create_event() {
let node_id = NodeId::random().0;
Expand Down Expand Up @@ -251,7 +224,7 @@ async fn create_event() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.events_post(
Expand Down Expand Up @@ -303,7 +276,7 @@ async fn create_event_fails() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.events_post(
Expand Down Expand Up @@ -359,7 +332,7 @@ async fn register_interest_sort_value() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let interest = models::Interest {
sep: "model".to_string(),
Expand Down Expand Up @@ -387,7 +360,7 @@ async fn register_interest_sort_value_bad_request() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let interest = models::Interest {
sep: "model".to_string(),
Expand Down Expand Up @@ -442,7 +415,7 @@ async fn register_interest_sort_value_controller() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.interests_sort_key_sort_value_post(
Expand Down Expand Up @@ -500,7 +473,7 @@ async fn register_interest_value_controller_stream() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.interests_sort_key_sort_value_post(
Expand Down Expand Up @@ -574,7 +547,7 @@ async fn get_interests() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.experimental_interests_get(None, &Context)
Expand Down Expand Up @@ -668,7 +641,7 @@ async fn get_interests_for_peer() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.experimental_interests_get(Some(peer_id_a.to_string()), &Context)
Expand Down Expand Up @@ -735,7 +708,7 @@ async fn get_events_for_interest_range() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let resp = server
.experimental_events_sep_sep_value_get(
Expand Down Expand Up @@ -790,7 +763,7 @@ async fn events_event_id_get_by_event_id_success() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let result = server.events_event_id_get(event_id_str, &Context).await;
let EventsEventIdGetResponse::Success(event) = result.unwrap() else {
Expand Down Expand Up @@ -823,7 +796,7 @@ async fn events_event_id_get_by_cid_success() {
network,
mock_interest,
Arc::new(mock_event_store),
empty_pipeline().await,
None,
);
let result = server
.events_event_id_get(event_cid.to_string(), &Context)
Expand Down Expand Up @@ -860,7 +833,7 @@ async fn stream_state() {
network,
mock_interest,
Arc::new(mock_event_store),
pipeline,
Some(pipeline),
);
let result = server
.streams_stream_id_get(
Expand All @@ -887,7 +860,8 @@ async fn stream_state() {
},
),
)
"#]].assert_debug_eq(&result);
"#]]
.assert_debug_eq(&result);
}
// helper function to generate some stream states
fn states() -> RecordBatch {
Expand Down
9 changes: 2 additions & 7 deletions one/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,14 +567,9 @@ pub async fn run(opts: DaemonOpts) -> Result<()> {
})
.await
});
(pipeline_ctx, aggregator_handle, Some(flight_handle))
(Some(pipeline_ctx), aggregator_handle, Some(flight_handle))
} else {
// TODO
(
datafusion::execution::context::SessionContext::default(),
None,
None,
)
(None, None, None)
};

// Start anchoring if remote anchor service URL is provided
Expand Down

0 comments on commit 5deb304

Please sign in to comment.