Skip to content

Commit

Permalink
Chain request headers in WSGI protocol (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro authored May 21, 2024
1 parent 14f2583 commit 1c118e1
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ futures = "0.3"
http-body-util = { version = "=0.1" }
hyper = { version = "=1.3", features = ["http1", "http2", "server"] }
hyper-util = { version = "=0.1", features = ["server-auto", "tokio"] }
itertools = "0.13"
log = "0.4"
percent-encoding = "=2.3"
pin-project = "1.1"
Expand Down
12 changes: 9 additions & 3 deletions src/wsgi/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use hyper::{
http::{request, uri::Authority},
Version,
};
use itertools::Itertools;
use percent_encoding::percent_decode_str;
use pyo3::{
prelude::*,
Expand Down Expand Up @@ -51,15 +52,20 @@ fn run_callback(
let content_type = parts.headers.remove(header::CONTENT_TYPE);
let content_len = parts.headers.remove(header::CONTENT_LENGTH);
let mut headers = Vec::with_capacity(parts.headers.len());
for (key, val) in &parts.headers {
for key in parts.headers.keys() {
headers.push((
format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()),
val.to_str().unwrap_or_default(),
parts
.headers
.get_all(key)
.iter()
.map(|v| v.to_str().unwrap_or_default())
.join(","),
));
}
if !parts.headers.contains_key(header::HOST) {
let host = parts.uri.authority().map_or("", Authority::as_str);
headers.push(("HTTP_HOST".to_string(), host));
headers.push(("HTTP_HOST".to_string(), host.to_string()));
}

Python::with_gil(|py| {
Expand Down
5 changes: 4 additions & 1 deletion tests/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
async def test_scope(wsgi_server, threading_mode):
payload = 'body_payload'
async with wsgi_server(threading_mode) as port:
res = httpx.post(f'http://localhost:{port}/info?test=true', content=payload)
res = httpx.post(
f'http://localhost:{port}/info?test=true', content=payload, headers=[('test', 'val1'), ('test', 'val2')]
)

assert res.status_code == 200
assert res.headers['content-type'] == 'application/json'
Expand All @@ -22,6 +24,7 @@ async def test_scope(wsgi_server, threading_mode):
assert data['query_string'] == 'test=true'
assert data['headers']['HTTP_HOST'] == f'localhost:{port}'
assert data['content_length'] == str(len(payload))
assert data['headers']['HTTP_TEST'] == 'val1,val2'


@pytest.mark.asyncio
Expand Down

0 comments on commit 1c118e1

Please sign in to comment.