1 module lighttp.server.router;
2 
3 import std.algorithm : max;
4 import std.base64 : Base64;
5 import std.conv : to, ConvException;
6 import std.digest.sha : sha1Of;
7 import std.regex : Regex, isRegexFor, regex, matchAll;
8 import std.socket : Address;
9 import std.string : startsWith, join;
10 import std.traits : Parameters, hasUDA;
11 
12 import libasync : AsyncTCPConnection;
13 
14 import lighttp.server.resource;
15 import lighttp.server.server : ServerOptions, Connection, MultipartConnection, WebSocketConnection;
16 import lighttp.util;
17 
/**
 * Outcome of passing a request through the router.
 *
 * `success` is set once a route has claimed the request; `connection` is
 * filled by routes that hand the client socket over to another connection
 * object (multipart body reader, websocket) and stays `null` otherwise.
 */
struct HandleResult {

	/// Set to `true` when a route matched and handled the request.
	bool success = false;

	/// Replacement connection for the client, or `null` when none is needed.
	Connection connection = null;

}
24 
/**
 * Router for handling requests.
 *
 * Routes are grouped by HTTP method. For a request, the routes registered
 * for its method are tried from the most recently added to the oldest, and
 * the first one that matches handles the request. Routes added later
 * therefore take precedence over earlier ones (including the built-in index
 * page registered by the constructor).
 */
class Router {

	private static Resource indexPage;
	private static TemplatedResource errorPage;

	static this() {
		indexPage = new Resource("text/html", import("index.html"));
		errorPage = new TemplatedResource("text/html", import("error.html"));
	}

	private Route[][string] routes;

	private void delegate(ServerRequest, ServerResponse) _errorHandler;

	/**
	 * Creates a router with the built-in index page registered for `GET`
	 * on the empty path (i.e. `/`) and the default error page handler.
	 */
	this() {
		this.add(Get(), indexPage);
		_errorHandler = &this.defaultErrorHandler;
	}

	/*
	 * Handles a connection.
	 *
	 * Sets a 400 status when the path does not start with `/`, and a 404
	 * status when no route registered for the request method matches.
	 * Otherwise the matching route fills `res` (and possibly
	 * `result.connection`) and `result.success` is set.
	 */
	void handle(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res) {
		if(!req.url.path.startsWith("/")) {
			res.status = StatusCodes.badRequest;
		} else {
			auto routes = req.method in this.routes;
			if(routes) {
				// newest routes first, so later registrations override earlier ones
				foreach_reverse(route ; *routes) {
					route.handle(options, result, client, req, res);
					if(result.success) return;
				}
			}
			res.status = StatusCodes.notFound;
		}
	}

	/*
	 * Handles a client or server error and displays an error
	 * page to the client.
	 */
	void handleError(ServerRequest req, ServerResponse res) {
		_errorHandler(req, res);
	}

	/*
	 * Renders the built-in error page template with the response status and
	 * the value of the response's "Server" header.
	 */
	private void defaultErrorHandler(ServerRequest req, ServerResponse res) {
		// NOTE(review): "Server" is indexed directly; if a response can reach this
		// handler without that header set, the lookup may fail depending on the
		// header container's opIndex — confirm the server always sets it.
		errorPage.apply(["message": res.status.message, "error": res.status.toString(), "server": res.headers["Server"]]).apply(req, res);
	}

	/**
	 * Registers routes from a class's methods marked with the
	 * @Get, @Post and @CustomMethod attributes.
	 *
	 * Every public member of `routes` carrying a `RouteInfo` attribute is
	 * registered: member functions become handler delegates, classes become
	 * WebSocket routes (nested classes are instantiated through `routes`),
	 * and any other member is passed to the matching `add` overload
	 * (e.g. a `Resource`).
	 */
	void add(T)(T routes) {
		foreach(member ; __traits(allMembers, T)) {
			static if(__traits(getProtection, __traits(getMember, T, member)) == "public") {
				foreach(uda ; __traits(getAttributes, __traits(getMember, T, member))) {
					static if(is(typeof(uda)) && isRouteInfo!(typeof(uda))) {
						mixin("alias M = routes." ~ member ~ ";");
						static if(is(typeof(__traits(getMember, T, member)) == function)) {
							// function
							this.add(uda, mixin("&routes." ~ member));
						} else static if(is(M == class)) {
							// websocket; a nested class needs its outer instance: `outer.new Inner()`
							static if(__traits(isNested, M)) this.addWebSocket!M(uda, { return routes.new M(); });
							else this.addWebSocket!M(uda);
						} else {
							// member
							this.add(uda, mixin("routes." ~ member));
						}
					}
				}
			}
		}
	}

	/**
	 * Adds a route.
	 *
	 * Routes whose info declares a body (`hasBody`) are wrapped so that the
	 * request body is checked / awaited before `del` is invoked.
	 */
	void add(T, E...)(RouteInfo!T info, void delegate(E) del) {
		if(info.hasBody) this.routes[info.method] ~= new MultipartRouteOf!(T, E)(info.path, del);
		else this.routes[info.method] ~= new RouteOf!(T, E)(info.path, del);
	}

	/**
	 * Adds a route that serves a resource.
	 */
	void add(T)(RouteInfo!T info, Resource resource) {
		this.add(info, (ServerRequest req, ServerResponse res){ resource.apply(req, res); });
	}

	/**
	 * Adds a WebSocket route; `del` creates a new socket object per upgrade.
	 * If `W` defines `onConnect`, its parameters are bound like a route handler's.
	 */
	void addWebSocket(W:WebSocketConnection, T)(RouteInfo!T info, W delegate() del) {
		static if(__traits(hasMember, W, "onConnect")) this.routes[info.method] ~= new WebSocketRouteOf!(W, T, Parameters!(W.onConnect))(info.path, del);
		else this.routes[info.method] ~= new WebSocketRouteOf!(W, T)(info.path, del);
	}

	/**
	 * Adds a WebSocket route for a non-nested class, instantiated with `new W()`.
	 */
	void addWebSocket(W:WebSocketConnection, T)(RouteInfo!T info) if(!__traits(isNested, W)) {
		this.addWebSocket!(W, T)(info, { return new W(); });
	}

	/**
	 * Not implemented yet: currently does nothing.
	 */
	void remove(T, E...)(RouteInfo!T info, void delegate(E) del) {
		//TODO
	}

}
130 
/**
 * Common base of every registered route.
 *
 * A route inspects the request and, when it matches, produces the response
 * and marks `result.success`.
 */
abstract class Route {

	/// Tries to handle the request; sets `result.success` when it matched.
	abstract void handle(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res);

}
136 
/**
 * Base for routes that match the request path (without its leading `/`)
 * against either an exact string or a regular expression.
 *
 * `E` is the handler's parameter list. A leading `ServerRequest` and/or
 * `ServerResponse` (in either order) form `Args` and are filled from the
 * request; the remaining parameters form `Match` and are filled from the
 * regex capture groups, converted with `std.conv.to`. A single `string[]`
 * match parameter receives all capture groups as-is.
 */
class RouteImpl(T, E...) if(is(T == string) || isRegexFor!(T, string)) : Route {

	private T path;

	static if(E.length) {
		static if(is(E[0] == ServerRequest)) {
			enum __request = 0;
			static if(E.length > 1 && is(E[1] == ServerResponse)) enum __response = 1;
		} else static if(is(E[0] == ServerResponse)) {
			enum __response = 0;
			static if(E.length > 1 && is(E[1] == ServerRequest)) enum __request = 1;
		}
	}

	static if(!is(typeof(__request))) enum __request = -1;
	static if(!is(typeof(__response))) enum __response = -1;

	static if(__request == -1 && __response == -1) {
		alias Args = E[0..0];
		alias Match = E[0..$];
	} else {
		enum _ = max(__request, __response) + 1;
		alias Args = E[0.._];
		alias Match = E[_..$];
	}

	// an exact string path has no capture groups to bind to parameters
	static assert(Match.length == 0 || !is(T : string));

	this(T path) {
		this.path = path;
	}

	/// Builds the request/response arguments and invokes `del` with them followed by `match`.
	void callImpl(void delegate(E) del, ServerOptions options, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		Args args;
		static if(__request != -1) args[__request] = req;
		static if(__response != -1) args[__response] = res;
		del(args, match);
	}

	abstract void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match);

	/**
	 * Matches the request path and, on success, calls `call` and sets
	 * `result.success`.
	 *
	 * Throws: `Exception` when the regex capture count does not fit the
	 * handler's parameters; `ConvException` when a capture cannot be
	 * converted to the parameter type.
	 */
	override void handle(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res) {
		// the router guarantees the path starts with "/", which routes do not include
		auto requested = req.url.path[1..$];
		static if(is(T == string)) {
			if(requested == this.path) {
				this.call(options, result, client, req, res);
				result.success = true;
			}
		} else {
			auto match = requested.matchAll(this.path);
			// the match must span the whole path: nothing before it and nothing after it
			if(match && match.pre.length == 0 && match.post.length == 0) {
				string[] matches;
				foreach(m ; match.front) matches ~= m;
				Match args;
				static if(Match.length == 1 && is(Match[0] == string[])) {
					// catch-all parameter: pass every capture group unconverted
					args[0] = matches[1..$];
				} else {
					// matches[0] is the whole match, groups start at index 1
					if(matches.length != args.length + 1) throw new Exception("Arguments count mismatch");
					static foreach(i ; 0..Match.length) {
						args[i] = to!(Match[i])(matches[i+1]);
					}
				}
				this.call(options, result, client, req, res, args);
				result.success = true;
			}
		}
	}

}
205 
/**
 * Route that forwards a matched request to a plain handler delegate.
 */
class RouteOf(T, E...) : RouteImpl!(T, E) {

	/// Handler invoked once the path has matched.
	private void delegate(E) del;

	this(T path, void delegate(E) del) {
		super(path);
		this.del = del;
	}

	/// Binds request/response/captures and runs the handler; `result` is left untouched.
	override void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		auto handler = this.del;
		callImpl(handler, options, client, req, res, match);
	}

}
220 
/**
 * Route for requests that carry a body (registered when `RouteInfo.hasBody`).
 *
 * Before running the handler the `content-length` header is inspected:
 * - missing: the body is assumed empty and the handler runs immediately;
 * - not a valid unsigned number: the status is set to 400;
 * - larger than `options.max`: the status is set to 413 (payload too large);
 * - already fully received: the handler runs immediately;
 * - otherwise a `MultipartConnection` is installed in `result.connection`
 *   to wait for the rest of the body, and the handler runs once it is done.
 */
class MultipartRouteOf(T, E...) : RouteOf!(T, E) {

	this(T path, void delegate(E) del) {
		super(path, del);
	}

	override void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		size_t length;
		if(auto lstr = "content-length" in req.headers) {
			// only the header parsing is guarded: a ConvException thrown by the
			// handler itself must not be reported as a malformed request
			try {
				length = to!size_t(*lstr);
			} catch(ConvException) {
				result.success = false;
				res.status = StatusCodes.badRequest;
				return;
			}
		} else {
			// assuming body has no content
			super.call(options, result, client, req, res, match);
			return;
		}

		if(length > options.max) {
			result.success = false;
			res.status = StatusCodes.payloadTooLarge;
		} else if(req.body_.length >= length) {
			super.call(options, result, client, req, res, match);
		} else {
			// wait for full data; the callback runs after this method has returned,
			// so it must not capture `result` (a ref into the caller's frame).
			// RouteOf.call does not read it, so a local placeholder is enough.
			result.connection = new MultipartConnection(client, length, req, res, {
				HandleResult deferred;
				super.call(options, deferred, client, req, res, match);
			});
			res.ready = false;
		}
	}

}
253 
/**
 * Route that upgrades a matching request to a WebSocket connection.
 *
 * Requests without a `sec-websocket-key` header get a 404 status. Otherwise
 * the response becomes a "switching protocols" handshake carrying the
 * `Sec-WebSocket-Accept` value, a new socket object is obtained from
 * `createWebSocket`, bound to the client and returned via
 * `result.connection`. When the socket type has `onConnect`, it is scheduled
 * through `onStartImpl` with its arguments bound like a route handler's.
 */
class WebSocketRouteOf(WebSocket, T, E...) : RouteImpl!(T, E) {

	private WebSocket delegate() createWebSocket;

	this(T path, WebSocket delegate() createWebSocket) {
		super(path);
		this.createWebSocket = createWebSocket;
	}

	override void call(ServerOptions options, ref HandleResult result, AsyncTCPConnection client, ServerRequest req, ServerResponse res, Match match) {
		auto clientKey = "sec-websocket-key" in req.headers;
		if(!clientKey) {
			res.status = StatusCodes.notFound;
			return;
		}

		// accept value: base64 of the SHA-1 of the client key followed by the fixed GUID
		auto acceptKey = Base64.encode(sha1Of(*clientKey ~ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")).idup;

		res.status = StatusCodes.switchingProtocols;
		res.headers["Sec-WebSocket-Accept"] = acceptKey;
		res.headers["Connection"] = "upgrade";
		res.headers["Upgrade"] = "websocket";

		// hand the client over to a fresh socket object
		WebSocket socket = this.createWebSocket();
		socket.conn = client;
		result.connection = socket;

		static if(__traits(hasMember, WebSocket, "onConnect")) {
			socket.onStartImpl = {
				this.callImpl(&socket.onConnect, options, client, req, res, match);
			};
		}
	}

}
281 
/**
 * Describes where a route is mounted: the HTTP method, whether the request
 * is expected to carry a body, and the path (a string or a regex).
 *
 * Used both as a UDA on route members and as the first argument of
 * `Router.add`.
 */
struct RouteInfo(T) if(is(T : string) || is(T == Regex!char) || isRegexFor!(T, string)) {

	/// HTTP method, e.g. "GET".
	string method;

	/// When `true` the route is registered as a body-reading (multipart) route.
	bool hasBody;

	/// Path to match, without the leading slash.
	T path;

}
289 
/*
 * Builds a RouteInfo from zero or more path arguments.
 *
 * - no argument: the empty path (which `/` resolves to);
 * - one regex: used unchanged;
 * - one string: compiled with `regex`;
 * - several strings: joined with an escaped slash and compiled as one regex.
 */
auto routeInfo(E...)(string method, bool hasBody, E path) {
	static if(E.length == 0) {
		return routeInfo(method, hasBody, "");
	} else static if(E.length > 1) {
		string[] segments;
		foreach(segment ; path) segments ~= segment;
		auto pattern = segments.join(`\/`);
		return RouteInfo!(Regex!char)(method, hasBody, regex(pattern));
	} else static if(isRegexFor!(E[0], string)) {
		return RouteInfo!E(method, hasBody, path);
	} else {
		return RouteInfo!(Regex!char)(method, hasBody, regex(path));
	}
}
302 
// true when T is (or implicitly converts to) some instantiation of RouteInfo
private template isRouteInfo(T) {
	enum isRouteInfo = is(T : RouteInfo!R, R);
}
304 
/**
 * Route attribute for an arbitrary HTTP method.
 *
 * Accepts the same path arguments as `Get`/`Post`: none (the empty path),
 * one string or regex, or several string segments joined with `/`.
 */
auto CustomMethod(R...)(string method, bool hasBody, R path){ return routeInfo!R(method, hasBody, path); }

/// Route attribute for GET requests (no body expected).
auto Get(R...)(R path){ return routeInfo!R("GET", false, path); }

/// Route attribute for POST requests (body expected).
auto Post(R...)(R path){ return routeInfo!R("POST", true, path); }

/// Route attribute for PUT requests (body expected).
auto Put(R...)(R path){ return routeInfo!R("PUT", true, path); }

/// Route attribute for DELETE requests (no body expected).
auto Delete(R...)(R path){ return routeInfo!R("DELETE", false, path); }
314 
/**
 * Registers on `register` every public member of `router` that carries a
 * `RouteInfo` attribute.
 *
 * Member functions become handler delegates; functions also marked with
 * `Multipart` go through `register.addMultipart`. Classes become WebSocket
 * routes (nested classes are instantiated through `router`). Any other
 * member is passed to the matching `Router.add` overload.
 */
void registerRoutes(R)(Router register, R router) {

	foreach(member ; __traits(allMembers, R)) {
		static if(__traits(getProtection, __traits(getMember, R, member)) == "public") {
			foreach(uda ; __traits(getAttributes, __traits(getMember, R, member))) {
				static if(is(typeof(uda)) && isRouteInfo!(typeof(uda))) {
					mixin("alias M = router." ~ member ~ ";");
					static if(is(typeof(__traits(getMember, R, member)) == function)) {
						// function
						// NOTE(review): `Multipart` and `addMultipart` are not defined in this
						// module (Router has no addMultipart); confirm they are provided
						// elsewhere, otherwise Router.add already picks the multipart route
						// from `uda.hasBody`.
						static if(hasUDA!(__traits(getMember, R, member), Multipart)) register.addMultipart(uda, mixin("&router." ~ member));
						else register.add(uda, mixin("&router." ~ member));
					} else static if(is(M == class)) {
						// websocket; a nested class needs its outer instance: `outer.new Inner()`
						static if(__traits(isNested, M)) register.addWebSocket!M(uda, { return router.new M(); });
						else register.addWebSocket!M(uda);
					} else {
						// member
						register.add(uda, mixin("router." ~ member));
					}
				}
			}
		}
	}

}