Skip to content

Commit

Permalink
Don't defer WebSocket opening, close #1348
Browse files Browse the repository at this point in the history
Motivation:

We currently buffer WebSocket opening until first LastHttpContent
reception with the UpgradeCallback.
This doesn't make sense, and forces us to buffer any frame that might
be sent along with the upgrade response.

Modifications:

* Drop UpgradeHandler that's never used as an abstraction
* Perform upgrade/abort as soon as response is received
* Ignore LastHttpContent
* No need to buffer any frame

Result:

More simple code
  • Loading branch information
slandelle committed Feb 7, 2017
1 parent bff695e commit 821a0f7
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
Expand All @@ -36,7 +37,6 @@
import org.asynchttpclient.HttpResponseStatus;
import org.asynchttpclient.netty.NettyResponseFuture;
import org.asynchttpclient.netty.NettyResponseStatus;
import org.asynchttpclient.netty.OnLastHttpContentCallback;
import org.asynchttpclient.netty.channel.ChannelManager;
import org.asynchttpclient.netty.channel.Channels;
import org.asynchttpclient.netty.request.NettyRequestSender;
Expand All @@ -52,71 +52,45 @@ public WebSocketHandler(AsyncHttpClientConfig config,//
super(config, channelManager, requestSender);
}

private class UpgradeCallback extends OnLastHttpContentCallback {

private final Channel channel;
private final HttpResponse response;
private final WebSocketUpgradeHandler handler;
private final HttpResponseStatus status;
private final HttpResponseHeaders responseHeaders;

public UpgradeCallback(NettyResponseFuture<?> future, Channel channel, HttpResponse response, WebSocketUpgradeHandler handler, HttpResponseStatus status,
HttpResponseHeaders responseHeaders) {
super(future);
this.channel = channel;
this.response = response;
this.handler = handler;
this.status = status;
this.responseHeaders = responseHeaders;
private void upgrade(Channel channel, NettyResponseFuture<?> future, WebSocketUpgradeHandler handler, HttpResponse response, HttpResponseHeaders responseHeaders)
throws Exception {
boolean validStatus = response.status().equals(SWITCHING_PROTOCOLS);
boolean validUpgrade = response.headers().get(UPGRADE) != null;
String connection = response.headers().get(CONNECTION);
boolean validConnection = HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(connection);
final boolean headerOK = handler.onHeadersReceived(responseHeaders) == State.CONTINUE;
if (!headerOK || !validStatus || !validUpgrade || !validConnection) {
requestSender.abort(channel, future, new IOException("Invalid handshake response"));
return;
}

// We don't need to synchronize as replacing the "ws-decoder" will
// process using the same thread.
private void invokeOnSucces(Channel channel, WebSocketUpgradeHandler h) {
try {
h.onSuccess(new NettyWebSocket(channel, responseHeaders.getHeaders()));
} catch (Exception ex) {
logger.warn("onSuccess unexpected exception", ex);
}
String accept = response.headers().get(SEC_WEBSOCKET_ACCEPT);
String key = getAcceptKey(future.getNettyRequest().getHttpRequest().headers().get(SEC_WEBSOCKET_KEY));
if (accept == null || !accept.equals(key)) {
requestSender.abort(channel, future, new IOException("Invalid challenge. Actual: " + accept + ". Expected: " + key));
}

@Override
public void call() throws Exception {
boolean validStatus = response.status().equals(SWITCHING_PROTOCOLS);
boolean validUpgrade = response.headers().get(UPGRADE) != null;
String connection = response.headers().get(CONNECTION);
boolean validConnection = HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(connection);
boolean statusReceived = handler.onStatusReceived(status) == State.CONTINUE;

if (!statusReceived) {
try {
handler.onCompleted();
} finally {
future.done();
}
return;
}

final boolean headerOK = handler.onHeadersReceived(responseHeaders) == State.CONTINUE;
if (!headerOK || !validStatus || !validUpgrade || !validConnection) {
requestSender.abort(channel, future, new IOException("Invalid handshake response"));
return;
}
// set back the future so the protocol gets notified of frames
// removing the HttpClientCodec from the pipeline might trigger a read with a WebSocket message
// if it comes in the same frame as the HTTP Upgrade response
Channels.setAttribute(channel, future);

String accept = response.headers().get(SEC_WEBSOCKET_ACCEPT);
String key = getAcceptKey(future.getNettyRequest().getHttpRequest().headers().get(SEC_WEBSOCKET_KEY));
if (accept == null || !accept.equals(key)) {
requestSender.abort(channel, future, new IOException(String.format("Invalid challenge. Actual: %s. Expected: %s", accept, key)));
}
channelManager.upgradePipelineForWebSockets(channel.pipeline());

// set back the future so the protocol gets notified of frames
// removing the HttpClientCodec from the pipeline might trigger a read with a WebSocket message
// if it comes in the same frame as the HTTP Upgrade response
Channels.setAttribute(channel, future);

channelManager.upgradePipelineForWebSockets(channel.pipeline());
// We don't need to synchronize as replacing the "ws-decoder" will
// process using the same thread.
try {
handler.openWebSocket(new NettyWebSocket(channel, responseHeaders.getHeaders()));
} catch (Exception ex) {
logger.warn("onSuccess unexpected exception", ex);
}
future.done();
}

invokeOnSucces(channel, handler);
private void abort(NettyResponseFuture<?> future, WebSocketUpgradeHandler handler, HttpResponseStatus status) throws Exception {
try {
handler.onThrowable(new IOException("Invalid Status code=" + status.getStatusCode() + " text=" + status.getStatusText()));
} finally {
future.done();
}
}
Expand All @@ -136,36 +110,23 @@ public void handleRead(Channel channel, NettyResponseFuture<?> future, Object e)
HttpResponseHeaders responseHeaders = new HttpResponseHeaders(response.headers());

if (!interceptors.exitAfterIntercept(channel, future, handler, response, status, responseHeaders)) {
Channels.setAttribute(channel, new UpgradeCallback(future, channel, response, handler, status, responseHeaders));
switch (handler.onStatusReceived(status)) {
case CONTINUE:
upgrade(channel, future, handler, response, responseHeaders);
break;
default:
abort(future, handler, status);
}
}

} else if (e instanceof WebSocketFrame) {
final WebSocketFrame frame = (WebSocketFrame) e;
WebSocketUpgradeHandler handler = (WebSocketUpgradeHandler) future.getAsyncHandler();
NettyWebSocket webSocket = (NettyWebSocket) handler.onCompleted();
handleFrame(channel, frame, handler, webSocket);

if (webSocket != null) {
handleFrame(channel, frame, handler, webSocket);
} else {
logger.debug("Frame received but WebSocket is not available yet, buffering frame");
frame.retain();
Runnable bufferedFrame = new Runnable() {
public void run() {
try {
// WebSocket is now not null
NettyWebSocket webSocket = (NettyWebSocket) handler.onCompleted();
handleFrame(channel, frame, handler, webSocket);
} catch (Exception e) {
logger.debug("Failure while handling buffered frame", e);
handler.onFailure(e);
} finally {
frame.release();
}
}
};
handler.bufferFrame(bufferedFrame);
}
} else {
} else if (!(e instanceof LastHttpContent)) {
// ignore, end of handshake response
logger.error("Invalid message {}", e);
}
}
Expand Down Expand Up @@ -197,7 +158,6 @@ public void handleException(NettyResponseFuture<?> future, Throwable e) {

try {
WebSocketUpgradeHandler h = (WebSocketUpgradeHandler) future.getAsyncHandler();

NettyWebSocket webSocket = NettyWebSocket.class.cast(h.onCompleted());
if (webSocket != null) {
webSocket.onError(e.getCause());
Expand Down
37 changes: 0 additions & 37 deletions client/src/main/java/org/asynchttpclient/ws/UpgradeHandler.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
/*
* Copyright (c) 2010-2012 Sonatype, Inc. All rights reserved.
* Copyright (c) 2017 AsyncHttpClient Project. All rights reserved.
*
* This program is licensed to you under the Apache License Version 2.0,
* and you may not use this file except in compliance with the Apache License Version 2.0.
* You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0.
* You may obtain a copy of the Apache License Version 2.0 at
* http://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the Apache License Version 2.0 is distributed on an
Expand All @@ -12,11 +13,8 @@
*/
package org.asynchttpclient.ws;

import static org.asynchttpclient.util.MiscUtils.isNonEmpty;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import org.asynchttpclient.AsyncHandler;
import org.asynchttpclient.HttpResponseBodyPart;
Expand All @@ -26,85 +24,52 @@
/**
* An {@link AsyncHandler} which is able to execute WebSocket upgrade. Use the Builder for configuring WebSocket options.
*/
public class WebSocketUpgradeHandler implements UpgradeHandler<WebSocket>, AsyncHandler<WebSocket> {
public class WebSocketUpgradeHandler implements AsyncHandler<WebSocket> {

private static final int SWITCHING_PROTOCOLS = io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS.code();

private WebSocket webSocket;
private final List<WebSocketListener> listeners;
private final AtomicBoolean ok = new AtomicBoolean(false);
private int status;
private List<Runnable> bufferedFrames;

public WebSocketUpgradeHandler(List<WebSocketListener> listeners) {
this.listeners = listeners;
}

public void bufferFrame(Runnable bufferedFrame) {
if (bufferedFrames == null) {
bufferedFrames = new ArrayList<>(1);
}
bufferedFrames.add(bufferedFrame);
}

@Override
public final void onThrowable(Throwable t) {
onFailure(t);
public final State onStatusReceived(HttpResponseStatus responseStatus) throws Exception {
return responseStatus.getStatusCode() == SWITCHING_PROTOCOLS ? State.CONTINUE : State.ABORT;
}


@Override
public final State onBodyPartReceived(HttpResponseBodyPart bodyPart) throws Exception {
public final State onHeadersReceived(HttpResponseHeaders headers) throws Exception {
return State.CONTINUE;
}

@Override
public final State onStatusReceived(HttpResponseStatus responseStatus) throws Exception {
status = responseStatus.getStatusCode();
return status == SWITCHING_PROTOCOLS ? State.CONTINUE : State.ABORT;
}

@Override
public final State onHeadersReceived(HttpResponseHeaders headers) throws Exception {
public final State onBodyPartReceived(HttpResponseBodyPart bodyPart) throws Exception {
return State.CONTINUE;
}

@Override
public final WebSocket onCompleted() throws Exception {
if (status != SWITCHING_PROTOCOLS) {
IllegalStateException e = new IllegalStateException("Invalid Status Code " + status);
for (WebSocketListener listener : listeners) {
listener.onError(e);
}
throw e;
}

return webSocket;
}

@Override
public final void onSuccess(WebSocket webSocket) {
this.webSocket = webSocket;
public final void onThrowable(Throwable t) {
for (WebSocketListener listener : listeners) {
webSocket.addWebSocketListener(listener);
listener.onOpen(webSocket);
}
if (isNonEmpty(bufferedFrames)) {
for (Runnable bufferedFrame : bufferedFrames) {
bufferedFrame.run();
if (webSocket != null) {
webSocket.addWebSocketListener(listener);
}
bufferedFrames = null;
listener.onError(t);
}
ok.set(true);
}

@Override
public final void onFailure(Throwable t) {
public final void openWebSocket(WebSocket webSocket) {
this.webSocket = webSocket;
for (WebSocketListener listener : listeners) {
if (!ok.get() && webSocket != null) {
webSocket.addWebSocketListener(listener);
}
listener.onError(t);
webSocket.addWebSocketListener(listener);
listener.onOpen(webSocket);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
*/
package org.asynchttpclient.ws;

import static org.asynchttpclient.Dsl.*;
import static org.asynchttpclient.Dsl.asyncHttpClient;
import static org.testng.Assert.*;

import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -156,7 +157,7 @@ public void onError(Throwable t) {
}
}

@Test(groups = "online", timeOut = 60000, expectedExceptions = IllegalStateException.class)
@Test(groups = "online", timeOut = 60000, expectedExceptions = IOException.class)
public void wrongProtocolCode() throws Throwable {
try (AsyncHttpClient c = asyncHttpClient()) {
final CountDownLatch latch = new CountDownLatch(1);
Expand Down

0 comments on commit 821a0f7

Please sign in to comment.