module lighttp.server.router;

import std.algorithm : max;
import std.base64 : Base64;
import std.conv : to, ConvException;
import std.digest.sha : sha1Of;
import std.regex : Regex, isRegexFor, regex, matchAll;
import std.socket : Address;
import std.string : startsWith, join;
import std.traits : Parameters, hasUDA;

import libasync : AsyncTCPConnection;

import lighttp.server.resource;
import lighttp.server.server : ServerOptions, Connection, MultipartConnection, WebSocketConnection;
import lighttp.util;

/**
 * Outcome of dispatching a request to the router.
 */
struct HandleResult {

	/// Set to `true` by a route once it has matched the request.
	bool success;

	/// Connection object created by a route (multipart or web socket
	/// routes assign it); `null` when the route did not create one.
	Connection connection = null;

}

/**
 * Router for handling requests.
 */
class Router {

	// Pages embedded at compile time through string import expressions.
	private static Resource indexPage;
	private static TemplatedResource errorPage;

	static this() {
		indexPage = new Resource("text/html", import("index.html"));
		errorPage = new TemplatedResource("text/html", import("error.html"));
	}

	// Registered routes, keyed by HTTP method name.
	private Route[][string] routes;

	private void delegate(ServerRequest, ServerResponse) _errorHandler;

	/**
	 * Creates a router that serves the embedded index page on
	 * `GET /` and uses the default error handler.
	 */
	this() {
		this.add(Get(), indexPage);
		_errorHandler = &this.defaultErrorHandler;
	}

	/**
	 * Handles a connection.
	 *
	 * The routes registered for the request's method are tried from
	 * the most recently added to the oldest; the first one that sets
	 * `result.success` ends the search.
	 * The status is set to `badRequest` when the path does not start
	 * with a slash and to `notFound` when no route matched.
	 */
	void handle(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res) {
		if(!req.url.path.startsWith("/")) {
			res.status = StatusCodes.badRequest;
		} else {
			auto routes = req.method in this.routes;
			if(routes) {
				// routes are appended on registration, so iterating in
				// reverse gives precedence to the newest route
				foreach_reverse(route ; *routes) {
					route.handle(options, result, client, req, res);
					if(result.success) return;
				}
			}
			res.status = StatusCodes.notFound;
		}
	}

	/**
	 * Handles a client or server error and displays an error
	 * page to the client.
	 */
	void handleError(ServerRequest req, ServerResponse res) {
		_errorHandler(req, res);
	}

	// Fills the embedded error template with the response's status and
	// `Server` header and applies it to the response.
	private void defaultErrorHandler(ServerRequest req, ServerResponse res) {
		errorPage.apply(["message": res.status.message, "error": res.status.toString(), "server": res.headers["Server"]]).apply(req, res);
	}

	/**
	 * Registers routes from a class's methods marked with the
	 * @Get, @Post and @CustomMethod attributes.
	 *
	 * Only public members are considered. A member function is
	 * registered as a delegate, a member class as a web socket
	 * route and any other member through the matching `add` overload.
	 */
	void add(T)(T routes) {
		foreach(member ; __traits(allMembers, T)) {
			static if(__traits(getProtection, __traits(getMember, T, member)) == "public") {
				foreach(uda ; __traits(getAttributes, __traits(getMember, T, member))) {
					static if(is(typeof(uda)) && isRouteInfo!(typeof(uda))) {
						mixin("alias M = routes." ~ member ~ ";");
						static if(is(typeof(__traits(getMember, T, member)) == function)) {
							// function
							this.add(uda, mixin("&routes." ~ member));
						} else static if(is(M == class)) {
							// websocket
							// a nested class needs its outer instance as context,
							// hence the `outer.new Inner()` form
							static if(__traits(isNested, M)) this.addWebSocket!M(uda, { return routes.new M(); });
							else this.addWebSocket!M(uda);
						} else {
							// member
							this.add(uda, mixin("routes." ~ member));
						}
					}
				}
			}
		}
	}

	/**
	 * Adds a route.
	 *
	 * A `MultipartRouteOf` is created when `info.hasBody` is set,
	 * a plain `RouteOf` otherwise.
	 */
	void add(T, E...)(RouteInfo!T info, void delegate(E) del) {
		if(info.hasBody) this.routes[info.method] ~= new MultipartRouteOf!(T, E)(info.path, del);
		else this.routes[info.method] ~= new RouteOf!(T, E)(info.path, del);
	}

	/**
	 * Adds a route that answers with the given resource.
	 */
	void add(T)(RouteInfo!T info, Resource resource) {
		this.add(info, (ServerRequest req, ServerResponse res){ resource.apply(req, res); });
	}

	/**
	 * Adds a web socket route whose connection objects are created
	 * by `del`. When `W` declares an `onConnect` member, its parameter
	 * types are used as the route's handler parameters.
	 */
	void addWebSocket(W:WebSocketConnection, T)(RouteInfo!T info, W delegate() del) {
		static if(__traits(hasMember, W, "onConnect")) this.routes[info.method] ~= new WebSocketRouteOf!(W, T, Parameters!(W.onConnect))(info.path, del);
		else this.routes[info.method] ~= new WebSocketRouteOf!(W, T)(info.path, del);
	}

	/**
	 * Adds a web socket route for a non-nested class, which is
	 * instantiated with its default constructor.
	 */
	void addWebSocket(W:WebSocketConnection, T)(RouteInfo!T info) if(!__traits(isNested, W)) {
		this.addWebSocket!(W, T)(info, { return new W(); });
	}

	/**
	 * Removes a route. Not implemented yet: the call has no effect.
	 */
	void remove(T, E...)(RouteInfo!T info, void delegate(E) del) {
		//TODO
	}

}

/**
 * Base class of every route stored in the router.
 */
class Route {

	/**
	 * Tries to handle the request; implementations set
	 * `result.success` when the route matched.
	 */
	abstract void handle(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res);

}

/**
 * Common implementation of path matching and handler invocation.
 * Params:
 *   T = type of the path, either `string` or a regular expression
 *   E = parameter types of the handler: optionally a `ServerRequest`
 *       and/or a `ServerResponse` (in any order) followed by the
 *       types that receive the regular expression's captures
 */
class RouteImpl(T, E...) if(is(T == string) || isRegexFor!(T, string)) : Route {

	private T path;

	// Positions of the request and the response in the handler's
	// parameter list, or -1 when the handler does not take them.
	static if(E.length) {
		static if(is(E[0] == ServerRequest)) {
			enum __request = 0;
			static if(E.length > 1 && is(E[1] == ServerResponse)) enum __response = 1;
		} else static if(is(E[0] == ServerResponse)) {
			enum __response = 0;
			static if(E.length > 1 && is(E[1] == ServerRequest)) enum __request = 1;
		}
	}

	static if(!is(typeof(__request))) enum __request = -1;
	static if(!is(typeof(__response))) enum __response = -1;

	// `Args` are the leading request/response parameters, `Match` the
	// remaining parameters, filled from the captures.
	static if(__request == -1 && __response == -1) {
		alias Args = E[0..0];
		alias Match = E[0..$];
	} else {
		enum _ = max(__request, __response) + 1;
		alias Args = E[0.._];
		alias Match = E[_..$];
	}

	// a plain string path has no captures to pass to the handler
	static assert(Match.length == 0 || !is(T : string));

	this(T path) {
		this.path = path;
	}

	/**
	 * Calls `del` with the request and/or response placed at the
	 * positions the handler declared, followed by the captures.
	 */
	void callImpl(void delegate(E) del, ServerOptions options, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		Args args;
		static if(__request != -1) args[__request] = req;
		static if(__response != -1) args[__response] = res;
		del(args, match);
	}

	/**
	 * Invoked by `handle` once the path has matched.
	 */
	abstract void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match);

	/**
	 * Compares the request's path (without the leading slash) with
	 * the route's path and, on a match, invokes `call` and sets
	 * `result.success` to `true`.
	 *
	 * A regular expression must consume the whole path. Its captures
	 * are either passed as a single `string[]` or converted one by
	 * one to the handler's trailing parameter types.
	 * Throws: `Exception` when the number of captures differs from
	 * the number of trailing parameters; conversion errors raised
	 * by `std.conv.to` are not caught here.
	 */
	override void handle(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res) {
		static if(is(T == string)) {
			if(req.url.path[1..$] == this.path) {
				this.call(options, result, client, req, res);
				result.success = true;
			}
		} else {
			auto match = req.url.path[1..$].matchAll(this.path);
			// an empty `post` means that nothing is left after the match
			if(match && match.post.length == 0) {
				string[] matches;
				foreach(m ; match.front) matches ~= m;
				Match args;
				// The test is done on `Match` and not on `E`, so that a handler
				// that also takes the request and/or the response can still
				// receive every capture as a single `string[]`.
				static if(Match.length == 1 && is(Match[0] == string[])) {
					// matches[0] is the whole match, the captures follow
					args[0] = matches[1..$];
				} else {
					if(matches.length != args.length + 1) throw new Exception("Arguments count mismatch");
					static foreach(i ; 0..Match.length) {
						args[i] = to!(Match[i])(matches[i+1]);
					}
				}
				this.call(options, result, client, req, res, args);
				result.success = true;
			}
		}
	}

}

/**
 * Route that forwards matching requests to a delegate.
 */
class RouteOf(T, E...) : RouteImpl!(T, E) {

	private void delegate(E) del;

	this(T path, void delegate(E) del) {
		super(path);
		this.del = del;
	}

	override void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		this.callImpl(this.del, options, client, req, res, match);
	}

}

/**
 * Route for requests that carry a body: the handler is called only
 * when the body is at least as long as the `content-length` header
 * announces.
 */
class MultipartRouteOf(T, E...) : RouteOf!(T, E) {

	this(T path, void delegate(E) del) {
		super(path, del);
	}

	/**
	 * Checks the `content-length` header before calling the handler.
	 *
	 * The status is set to `payloadTooLarge` when the length exceeds
	 * `options.max` and to `badRequest` when the header cannot be
	 * converted to a number. When the body is still incomplete a
	 * `MultipartConnection` is stored in `result.connection` and
	 * `res.ready` is set to `false`. Without the header the handler
	 * is called immediately.
	 */
	override void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		if(auto lstr = "content-length" in req.headers) {
			try {
				size_t length = to!size_t(*lstr);
				if(length > options.max) {
					// NOTE(review): RouteImpl.handle sets result.success to true
					// right after this method returns, so this assignment is
					// overwritten - confirm whether that is intended
					result.success = false;
					res.status = StatusCodes.payloadTooLarge;
				} else if(req.body_.length >= length) {
					return super.call(options, result, client, req, res, match);
				} else {
					// wait for full data
					// presumably the delegate is invoked by MultipartConnection
					// once the whole body has been received - verify in
					// lighttp.server.server
					result.connection = new MultipartConnection(client, length, req, res, { super.call(options, result, client, req, res, match); });
					res.ready = false;
					return;
				}
			} catch(ConvException) {
				result.success = false;
				res.status = StatusCodes.badRequest;
			}
		} else {
			// assuming body has no content
			super.call(options, result, client, req, res, match);
		}
	}

}

/**
 * Route that upgrades matching requests to a web socket connection.
 */
class WebSocketRouteOf(WebSocket, T, E...) : RouteImpl!(T, E) {

	private WebSocket delegate() createWebSocket;

	this(T path, WebSocket delegate() createWebSocket) {
		super(path);
		this.createWebSocket = createWebSocket;
	}

	/**
	 * Answers the handshake when the request has a `sec-websocket-key`
	 * header: the accept key is the base64-encoded SHA-1 of the key
	 * concatenated with the protocol's GUID (RFC 6455). The created
	 * web socket is stored in `result.connection`.
	 * Without the header the status is set to `notFound`.
	 */
	override void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		auto key = "sec-websocket-key" in req.headers;
		if(key) {
			res.status = StatusCodes.switchingProtocols;
			res.headers["Sec-WebSocket-Accept"] = Base64.encode(sha1Of(*key ~ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")).idup;
			res.headers["Connection"] = "upgrade";
			res.headers["Upgrade"] = "websocket";
			// create web socket and set callback for onConnect
			WebSocket webSocket = this.createWebSocket();
			webSocket.conn = client;
			result.connection = webSocket;
			static if(__traits(hasMember, WebSocket, "onConnect")) webSocket.onStartImpl = { this.callImpl(&webSocket.onConnect, options, client, req, res, match); };
		} else {
			res.status = StatusCodes.notFound;
		}
	}

}

/**
 * Description of a route: HTTP method, whether a body is expected
 * and the path, which is a string or a regular expression.
 */
struct RouteInfo(T) if(is(T : string) || is(T == Regex!char) || isRegexFor!(T, string)) {

	string method;
	bool hasBody;
	T path;

}

/**
 * Creates a `RouteInfo` from zero or more path segments.
 *
 * No segment is the same as a single empty one. A single segment
 * that is already a regular expression is used as it is and any
 * other single segment is compiled with `regex`. Several segments
 * are joined with an escaped slash and then compiled.
 */
auto routeInfo(E...)(string method, bool hasBody, E path) {
	static if(E.length == 0) {
		return routeInfo(method, hasBody, "");
	} else static if(E.length == 1) {
		static if(isRegexFor!(E[0], string)) return RouteInfo!E(method, hasBody, path);
		else return RouteInfo!(Regex!char)(method, hasBody, regex(path));
	} else {
		string[] p;
		foreach(pp ; path) p ~= pp;
		return RouteInfo!(Regex!char)(method, hasBody, regex(p.join(`\/`)));
	}
}

// Whether `T` is an instance of `RouteInfo`.
private enum isRouteInfo(T) = is(T : RouteInfo!R, R);

/// Route for an arbitrary HTTP method.
auto CustomMethod(R)(string method, bool hasBody, R path){ return routeInfo!R(method, hasBody, path); }

/// Route for the GET method, without body.
auto Get(R...)(R path){ return routeInfo!R("GET", false, path); }

/// Route for the POST method, with body.
auto Post(R...)(R path){ return routeInfo!R("POST", true, path); }

/// Route for the PUT method, with body.
auto Put(R...)(R path){ return routeInfo!R("PUT", true, path); }

/// Route for the DELETE method, without body.
auto Delete(R...)(R path){ return routeInfo!R("DELETE", false, path); }

/**
 * Registers in `register` the public members of `router` that are
 * marked with a route attribute: member functions as delegates,
 * member classes as web sockets and other members through
 * `Router.add`.
 */
void registerRoutes(R)(Router register, R router) {

	foreach(member ; __traits(allMembers, R)) {
		static if(__traits(getProtection, __traits(getMember, R, member)) == "public") {
			foreach(uda ; __traits(getAttributes, __traits(getMember, R, member))) {
				static if(is(typeof(uda)) && isRouteInfo!(typeof(uda))) {
					mixin("alias M = router." ~ member ~ ";");
					static if(is(typeof(__traits(getMember, R, member)) == function)) {
						// function
						// NOTE(review): `Multipart` and `Router.addMultipart` are not
						// defined in this module - confirm that they exist elsewhere,
						// otherwise this branch cannot be instantiated
						static if(hasUDA!(__traits(getMember, R, member), Multipart)) register.addMultipart(uda, mixin("&router." ~ member));
						else register.add(uda, mixin("&router." ~ member));
					} else static if(is(M == class)) {
						// websocket
						// a nested class needs its outer instance as context,
						// hence the `outer.new Inner()` form
						static if(__traits(isNested, M)) register.addWebSocket!M(uda, { return router.new M(); });
						else register.addWebSocket!M(uda);
					} else {
						// member
						register.add(uda, mixin("router." ~ member));
					}
				}
			}
		}
	}

}