Netty: Set-Cookie in WebSocket Handshake
My pipeline looks like this:
// Starting pipeline, using Netty's stock WebSocketServerProtocolHandler.
ChannelPipeline pipeline = ch.pipeline();
// HTTP request decoder / response encoder.
pipeline.addLast(new HttpServerCodec());
// Aggregates HTTP message parts into full messages with at most 65536 bytes of content.
pipeline.addLast(new HttpObjectAggregator(65536));
// null: no subprotocols are offered; true: WebSocket extensions are allowed.
pipeline.addLast(new WebSocketServerProtocolHandler(WEBSOCKET_PATH, null, true));
I want to add a Set-Cookie HTTP header to the handshake response.
This is allowed by RFC 6455; the relevant part reads:
The handshake from the server looks as follows:
Connection: Upgrade
Sec-WebSocket-Accept: t1ugq4hht3dvlnq5yi+i/gfasi8=
Upgrade: websocket
Set-Cookie: ccc=22; Path=/; HttpOnly
An unordered set of header fields comes after the leading line in both cases. The meaning of these header fields is specified in Section 4 of this document. Additional header fields may also be present, such as cookies [RFC6265].
I couldn't find a supported way to do this, so I did it by invoking Netty's non-public methods via reflection.
Netty version: 4.1.2.Final
First, find the source code of the WebSocketServerProtocolHandshakeHandler class.
The class is not public, so make a copy of it and modify the copy:
class customwebsocketserverprotocolhandshakehandler extends channelinboundhandleradapter { private final string websocketpath; private final string subprotocols; private final boolean allowextensions; private final int maxframepayloadsize; private final boolean allowmaskmismatch; static final methodhandle sethandshakermethod = getsethandshakermethod(); static final methodhandle forbiddenhttprequestrespondermethod = getforbiddenhttprequestrespondermethod(); static methodhandle getsethandshakermethod(){ try { method method = websocketserverprotocolhandler.class.getdeclaredmethod("sethandshaker" , channel.class , websocketserverhandshaker.class ); method.setaccessible(true); return methodhandles.lookup().unreflect(method); } catch (throwable e) { // should never happen e.printstacktrace(); system.exit(5); return null; } } static methodhandle getforbiddenhttprequestrespondermethod(){ try { method method = websocketserverprotocolhandler.class.getdeclaredmethod("forbiddenhttprequestresponder"); method.setaccessible(true); return methodhandles.lookup().unreflect(method); } catch (throwable e) { // should never happen e.printstacktrace(); system.exit(6); return null; } } public customwebsocketserverprotocolhandshakehandler(string websocketpath, string subprotocols, boolean allowextensions, int maxframesize, boolean allowmaskmismatch) { this.websocketpath = websocketpath; this.subprotocols = subprotocols; this.allowextensions = allowextensions; maxframepayloadsize = maxframesize; this.allowmaskmismatch = allowmaskmismatch; } @override public void channelread(final channelhandlercontext ctx, object msg) throws exception { fullhttprequest req = (fullhttprequest) msg; if (!websocketpath.equals(req.uri())) { ctx.firechannelread(msg); return; } try { if (req.method() != get) { sendhttpresponse(ctx, req, new defaultfullhttpresponse(http_1_1, forbidden)); return; } final websocketserverhandshakerfactory wsfactory = new websocketserverhandshakerfactory( 
getwebsocketlocation(ctx.pipeline(), req, websocketpath), subprotocols, allowextensions, maxframepayloadsize, allowmaskmismatch); final websocketserverhandshaker handshaker = wsfactory.newhandshaker(req); if (handshaker == null) { websocketserverhandshakerfactory.sendunsupportedversionresponse(ctx.channel()); } else { channel channel = ctx.channel(); final channelfuture handshakefuture = handshaker.handshake(channel, req, getresponseheaders(req), channel.newpromise()); handshakefuture.addlistener(new channelfuturelistener() { @override public void operationcomplete(channelfuture future) throws exception { if (!future.issuccess()) { ctx.fireexceptioncaught(future.cause()); } else { ctx.fireusereventtriggered( websocketserverprotocolhandler.serverhandshakestateevent.handshake_complete); } } }); try { sethandshakermethod.invokeexact(ctx.channel(), handshaker); channelhandler handler = (channelhandler)forbiddenhttprequestrespondermethod.invokeexact(); ctx.pipeline().replace(this, "ws403responder", handler); } catch (throwable e) { // should never happen e.printstacktrace(); system.exit(7); } } } { req.release(); } } private static void sendhttpresponse(channelhandlercontext ctx, httprequest req, httpresponse res) { channelfuture f = ctx.channel().writeandflush(res); if (!iskeepalive(req) || res.status().code() != 200) { f.addlistener(channelfuturelistener.close); } } private static string getwebsocketlocation(channelpipeline cp, httprequest req, string path) { string protocol = "ws"; if (cp.get(sslhandler.class) != null) { // ssl in use use secure websockets protocol = "wss"; } return protocol + "://" + req.headers().get(httpheadernames.host) + path; } private static httpheaders getresponseheaders(fullhttprequest req){ final string cookiename = "cid"; final defaulthttpheaders httpheaders = new defaulthttpheaders(); string connectionid = null; string cookiestring = req.headers().get(httpheadernames.cookie); if( cookiestring != null && cookiestring.length() > 0 ) { 
set<cookie> cookies = servercookiedecoder.lax.decode(cookiestring); (cookie cookie : cookies) { if( cookiename.equalsignorecase(cookie.name())){ connectionid = cookie.value(); break; } } } if( connectionid == null || connectionid.length() < 16 || connectionid.length() > 50 ){ connectionid = uuid.randomuuid().tostring().replaceall("-", ""); } defaultcookie cookie = new defaultcookie("cid", connectionid); cookie.setpath("/"); cookie.sethttponly(true); cookie.setsecure(false); httpheaders.add(httpheadernames.set_cookie, servercookieencoder.lax.encode(cookie)); return httpheaders; } }
Then add a new class that extends WebSocketServerProtocolHandler:
class customwebsocketserverprotocolhandler extends websocketserverprotocolhandler { private final string websocketpath; private final string subprotocols; private final boolean allowextensions; private final int maxframepayloadlength; private final boolean allowmaskmismatch; public customwebsocketserverprotocolhandler(string websocketpath, string subprotocols, boolean allowextensions) { this(websocketpath, subprotocols, allowextensions, 65536, false); // todo auto-generated constructor stub } public customwebsocketserverprotocolhandler(string websocketpath, string subprotocols, boolean allowextensions, int maxframesize, boolean allowmaskmismatch) { super(websocketpath, subprotocols, allowextensions, maxframesize, allowmaskmismatch); this.websocketpath = websocketpath; this.subprotocols = subprotocols; this.allowextensions = allowextensions; maxframepayloadlength = maxframesize; this.allowmaskmismatch = allowmaskmismatch; } @override public void handleradded(channelhandlercontext ctx) { channelpipeline cp = ctx.pipeline(); if (cp.get(customwebsocketserverprotocolhandshakehandler.class) == null) { // add websockethandshakehandler before one. ctx.pipeline().addbefore(ctx.name(), customwebsocketserverprotocolhandshakehandler.class.getname(), new customwebsocketserverprotocolhandshakehandler(websocketpath, subprotocols, allowextensions, maxframepayloadlength, allowmaskmismatch)); } if (cp.get(utf8framevalidator.class) == null) { // add uft8 checking before one. ctx.pipeline().addbefore(ctx.name(), utf8framevalidator.class.getname(), new utf8framevalidator()); } } }
Finally, put them in the pipeline:
// HTTP request decoder / response encoder.
pipeline.addLast(new HttpServerCodec());
// Aggregates requests into FullHttpRequest objects (at most 65536 bytes of content); the custom
// handshake handler casts every incoming message to FullHttpRequest.
pipeline.addLast(new HttpObjectAggregator(65536));
// Negotiates per-message compression extensions.
pipeline.addLast(new WebSocketServerCompressionHandler());
// "*": accept whichever subprotocol the client requests.
// true: allow extensions, so the frame decoder accepts the RSV bits set by compressed frames.
pipeline.addLast(new CustomWebSocketServerProtocolHandler(WEBSOCKET_PATH, "*", true));
Comments
Post a Comment