Skip to content

Commit

Permalink
ARC-1370: Add content negotiation for JSON (#44)
Browse files Browse the repository at this point in the history
Co-authored-by: Eduard Thamm <[email protected]>
  • Loading branch information
thomasrichner-oviva and eduardOrthopy authored Mar 6, 2024
1 parent a0e3c2b commit 81f22d3
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,34 @@ public Response auth(
.build();
}

@GET
@Produces(MediaType.APPLICATION_JSON)
public Response authJson(
@QueryParam("scope") String scope,
@QueryParam("state") String state,
@QueryParam("response_type") String responseType,
@QueryParam("client_id") String clientId,
@QueryParam("redirect_uri") String redirectUri,
@QueryParam("nonce") String nonce) {

var uri = mustParse(redirectUri);

var res =
authService.auth(
new AuthorizationRequest(scope, state, responseType, clientId, uri, nonce));

var availableIdentityProviders =
res.identityProviders().stream()
.map(idp -> new IdpEntry(idp.iss(), idp.name(), idp.logoUrl()))
.toList();

var body = new AuthResponse(availableIdentityProviders);

return Response.ok(body, MediaType.APPLICATION_JSON_TYPE)
.cookie(createSessionCookie(res.sessionId()))
.build();
}

@NonNull
private URI mustParse(@Nullable String uri) {
if (uri == null || uri.isBlank()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.oviva.ehealthid.relyingparty.ws;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.oviva.ehealthid.relyingparty.svc.AuthenticationException;
import com.oviva.ehealthid.relyingparty.svc.ValidationException;
import com.oviva.ehealthid.relyingparty.ws.ui.Pages;
Expand All @@ -14,9 +15,12 @@
import jakarta.ws.rs.core.Response.StatusType;
import jakarta.ws.rs.core.UriInfo;
import jakarta.ws.rs.ext.ExceptionMapper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.jboss.resteasy.util.MediaTypeHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.spi.LoggingEventBuilder;
Expand All @@ -30,7 +34,9 @@ public class ThrowableExceptionMapper implements ExceptionMapper<Throwable> {
@Context Request request;
@Context HttpHeaders headers;

// Note: MUST be non-final for mocking
// Note: below fields MUST be non-final for mocking
private MediaTypeNegotiator mediaTypeNegotiator = new ResteasyMediaTypeNegotiator();

private Logger logger = LoggerFactory.getLogger(ThrowableExceptionMapper.class);

@Override
Expand Down Expand Up @@ -65,18 +71,22 @@ public Response toResponse(Throwable exception) {

private Response buildContentNegotiatedErrorResponse(String message, StatusType status) {

if (acceptsTextHtml()) {
var mediaType =
mediaTypeNegotiator.bestMatch(
headers.getAcceptableMediaTypes(),
List.of(MediaType.TEXT_HTML_TYPE, MediaType.APPLICATION_JSON_TYPE));

if (MediaType.TEXT_HTML_TYPE.equals(mediaType)) {
var body = pages.error(message);
return Response.status(status).entity(body).type(MediaType.TEXT_HTML_TYPE).build();
}

return Response.status(status).build();
}
if (MediaType.APPLICATION_JSON_TYPE.equals(mediaType)) {
var body = new Problem("/server_error", message);
return Response.status(status).entity(body).type(MediaType.APPLICATION_JSON_TYPE).build();
}

private boolean acceptsTextHtml() {
var acceptable = headers.getAcceptableMediaTypes();
return acceptable.contains(MediaType.WILDCARD_TYPE)
|| acceptable.contains(MediaType.TEXT_HTML_TYPE);
return Response.status(status).build();
}

private StatusType determineStatus(Throwable exception) {
Expand Down Expand Up @@ -137,6 +147,22 @@ private LoggingEventBuilder addTraceInfo(LoggingEventBuilder log) {
return log.addKeyValue("traceId", parsed.traceId()).addKeyValue("spanId", parsed.spanId());
}

interface MediaTypeNegotiator {
MediaType bestMatch(List<MediaType> desiredMediaType, List<MediaType> supportedMediaTypes);
}

private static class ResteasyMediaTypeNegotiator implements MediaTypeNegotiator {

@Override
public MediaType bestMatch(
List<MediaType> desiredMediaType, List<MediaType> supportedMediaTypes) {

// note: resteasy needs mutable lists
return MediaTypeHelper.getBestMatch(
new ArrayList<>(desiredMediaType), new ArrayList<>(supportedMediaTypes));
}
}

private record Traceparent(String spanId, String traceId) {

// https://www.w3.org/TR/trace-context/#traceparent-header-field-values
Expand All @@ -160,4 +186,6 @@ static Traceparent parse(String s) {
return new Traceparent(spanId, traceId);
}
}

public record Problem(@JsonProperty("type") String type, @JsonProperty("title") String title) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.oviva.ehealthid.fedclient.IdpEntry;
import com.oviva.ehealthid.relyingparty.svc.AuthService;
import com.oviva.ehealthid.relyingparty.svc.AuthService.AuthorizationRequest;
import com.oviva.ehealthid.relyingparty.svc.AuthService.AuthorizationResponse;
import com.oviva.ehealthid.relyingparty.svc.AuthService.CallbackRequest;
import com.oviva.ehealthid.relyingparty.svc.AuthService.SelectedIdpRequest;
import com.oviva.ehealthid.relyingparty.svc.ValidationException;
import com.oviva.ehealthid.relyingparty.util.IdGenerator;
import com.oviva.ehealthid.relyingparty.ws.AuthEndpoint.AuthResponse;
import jakarta.ws.rs.core.Response.Status;
import java.net.URI;
import java.util.List;
Expand Down Expand Up @@ -101,6 +103,44 @@ void auth_success() {
}
}

@Test
void authJson_success() {
var identityProviders = List.of(new IdpEntry("a", "A", null), new IdpEntry("b", "B", null));

var sessionId = IdGenerator.generateID();
var authService = mock(AuthService.class);
when(authService.auth(any()))
.thenReturn(new AuthorizationResponse(identityProviders, sessionId));
var sut = new AuthEndpoint(authService);

var scope = "openid";
var state = UUID.randomUUID().toString();
var nonce = UUID.randomUUID().toString();
var responseType = "code";
var clientId = "myapp";

// when
try (var res = sut.authJson(scope, state, responseType, clientId, REDIRECT_URI, nonce)) {

// then
assertEquals(Status.OK.getStatusCode(), res.getStatus());

var authResponse = res.readEntity(AuthResponse.class);
var actualIdentityProviders = authResponse.identityProviders();
assertEquals(identityProviders.size(), actualIdentityProviders.size());
for (int i = 0; i < identityProviders.size(); i++) {
var expected = identityProviders.get(i);
var actual = actualIdentityProviders.get(i);
assertEquals(expected.iss(), actual.iss());
assertEquals(expected.name(), actual.name());
assertEquals(expected.logoUrl(), actual.logoUrl());
}

var sessionCookie = res.getCookies().get("session_id");
assertEquals(sessionId, sessionCookie.getValue());
}
}

@Test
void callback_success() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.oviva.ehealthid.relyingparty.svc.AuthenticationException;
import com.oviva.ehealthid.relyingparty.svc.ValidationException;
import com.oviva.ehealthid.relyingparty.ws.ThrowableExceptionMapper.Problem;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.ServerErrorException;
import jakarta.ws.rs.core.HttpHeaders;
Expand Down Expand Up @@ -91,6 +94,16 @@ void toResponse_isLogged() {
verify(logger).atError();
}

@Test
void toResponse_authentication() {

// when
var res = mapper.toResponse(new AuthenticationException(null));

// then
assertEquals(401, res.getStatus());
}

@Test
void toResponse_withBody() {

Expand All @@ -106,4 +119,25 @@ void toResponse_withBody() {
assertEquals(MediaType.TEXT_HTML_TYPE, res.getMediaType());
assertNotNull(res.getEntity());
}

@Test
void toResponse_withJson() {

when(headers.getAcceptableMediaTypes())
.thenReturn(
List.of(
MediaType.APPLICATION_JSON_TYPE,
MediaType.TEXT_HTML_TYPE,
MediaType.WILDCARD_TYPE));

var msg = "Ooops! An error :/";

// when
var res = mapper.toResponse(new ValidationException(msg));

// then
assertEquals(400, res.getStatus());
assertEquals(MediaType.APPLICATION_JSON_TYPE, res.getMediaType());
assertEquals(new Problem("/server_error", msg), res.getEntity());
}
}

0 comments on commit 81f22d3

Please sign in to comment.