1 module lighttp.router;
2 
3 import std.algorithm : max;
4 import std.base64 : Base64;
5 import std.conv : to;
6 import std.digest.sha : sha1Of;
7 import std.regex : Regex, isRegexFor, matchAll;
8 import std.string : startsWith;
9 import std.traits : Parameters;
10 
11 import libasync : NetworkAddress, AsyncTCPConnection;
12 
13 import lighttp.server : WebSocketClient;
14 import lighttp.util;
15 
/// Outcome of dispatching one request through a `Router`.
struct HandleResult {

	/// Set to `true` by a route once it matched the request path and ran its handler.
	bool success;

	/// Client object created by a web socket route during the upgrade handshake; null otherwise.
	WebSocketClient webSocket;

	/// Callback stored by a web socket route (wraps the client's `onConnect`, or does nothing).
	/// NOTE(review): presumably invoked by lighttp.server after the upgrade response is sent — confirm.
	void delegate() callOnConnect;

}
23 
// Compile-time size report for HandleResult. This used to print on every build;
// it is now opt-in via `-version=LighttpRouterSizeReport`.
version(LighttpRouterSizeReport) pragma(msg, HandleResult.sizeof);
25 
/**
 * Dispatches incoming HTTP requests to registered routes, grouped by
 * request method.
 */
class Router {

	// Registered routes keyed by HTTP method; each list is tried in insertion order.
	private Route[][string] routes;

	/**
	 * Tries every route registered for `req.method`, in registration order,
	 * and stops at the first one that leaves `result.success` set.
	 *
	 * Sets `res.status` to `badRequest` when the path does not begin with '/',
	 * and to `notFound` when no route matched.
	 */
	void handle(ref HandleResult result, AsyncTCPConnection conn, Request req, Response res) {
		// Only absolute paths are routed.
		if(!req.path.startsWith("/")) {
			res.status = StatusCodes.badRequest;
			return;
		}
		if(auto candidates = req.method in this.routes) {
			foreach(candidate ; *candidates) {
				candidate.handle(result, conn, req, res);
				if(result.success) return;
			}
		}
		res.status = StatusCodes.notFound;
	}

	/**
	 * Replaces `res.body_` with a minimal HTML page that shows the current
	 * status text and the response's "Server" header (empty when absent).
	 */
	void error(Request req, Response res) {
		auto title = res.status.toString();
		auto server = res.headers.get("Server", "");
		res.body_ = "<!DOCTYPE html><html><head><title>" ~ title ~ "</title></head><body><center><h1>" ~ title ~ "</h1></center><hr><center>" ~ server ~ "</center></body></html>";
	}

	/// Registers `del` as the handler for `info.path` under `info.method`.
	void add(T, E...)(RouteInfo!T info, void delegate(E) del) {
		auto route = new RouteOf!(T, E)(info.path, del);
		this.routes[info.method] ~= route;
	}

	/// Registers `resource` so that matching requests call `resource.apply(req, res)`.
	void add(T)(RouteInfo!T info, Resource resource) {
		void delegate(Request, Response) forward = (Request req, Response res) { resource.apply(req, res); };
		this.add(info, forward);
	}

	/// Same as `add(RouteInfo!T, del)`, building the `RouteInfo` from `method` and `path`.
	void add(T, E...)(string method, T path, void delegate(E) del) {
		auto info = RouteInfo!T(method, path);
		this.add(info, del);
	}

	/// Same as `add(RouteInfo!T, resource)`, building the `RouteInfo` from `method` and `path`.
	void add(T)(string method, T path, Resource resource) {
		auto info = RouteInfo!T(method, path);
		this.add(info, resource);
	}

	/**
	 * Registers a web socket route whose clients are produced by `del`.
	 * When `W` declares `onConnect`, that method's parameter types become the
	 * route's handler arguments.
	 */
	void addWebSocket(W:WebSocketClient, T)(RouteInfo!T info, W delegate() del) {
		static if(__traits(hasMember, W, "onConnect")) {
			alias RouteType = WebSocketRouteOf!(W, T, Parameters!(W.onConnect));
		} else {
			alias RouteType = WebSocketRouteOf!(W, T);
		}
		this.routes[info.method] ~= new RouteType(info.path, del);
	}

	/// Registers a web socket route for a non-nested `W`, default-constructing one client per handshake.
	void addWebSocket(W:WebSocketClient, T)(RouteInfo!T info) if(!__traits(isNested, W)) {
		this.addWebSocket!(W, T)(info, () => new W());
	}

	/// Not implemented yet: currently does nothing.
	void remove(T, E...)(RouteInfo!T info, void delegate(E) del) {
		//TODO
	}

}
79 
/**
 * Base class of every route stored in a `Router`.
 */
abstract class Route {

	/**
	 * Examines `req` and, when this route applies, handles it and records the
	 * outcome in `result`. The implementations in this module set `result.success`.
	 */
	abstract void handle(ref HandleResult result, AsyncTCPConnection conn, Request req, Response res);

}
85 
/**
 * Common implementation for routes that match either an exact path
 * (`T == string`) or a regular expression over the path.
 *
 * `E` is the handler's parameter list. It may start with up to three
 * "context" parameters, each optional but always in this order:
 * `NetworkAddress`, `Request`, `Response`. The remaining parameters (`Match`)
 * receive the regex capture groups. Each group is converted with
 * `std.conv.to`, or all groups are passed together when `Match` is a single
 * `string[]`.
 */
class RouteImpl(T, E...) if(is(T == string) || isRegexFor!(T, string)) : Route {

	private T path;

	// Positions of the optional context parameters inside E (-1 when absent).
	static if(E.length) {
		static if(is(E[0] == NetworkAddress)) {
			enum __address = 0;
			static if(E.length > 1) {
				static if(is(E[1] == Request)) {
					// Record the request position. Previously it was missing here,
					// which left the Request argument unset.
					enum __request = 1;
					static if(E.length > 2 && is(E[2] == Response)) {
						enum __response = 2;
					}
				} else static if(is(E[1] == Response)) {
					enum __response = 1;
				}
			}
		} else static if(is(E[0] == Request)) {
			enum __request = 0;
			static if(E.length > 1 && is(E[1] == Response)) enum __response = 1;
		} else static if(is(E[0] == Response)) {
			enum __response = 0;
		}
	}

	static if(!is(typeof(__address))) enum __address = -1;
	static if(!is(typeof(__request))) enum __request = -1;
	static if(!is(typeof(__response))) enum __response = -1;

	// Args: the leading context parameters; Match: the parameters fed from capture groups.
	static if(__address == -1 && __request == -1 && __response == -1) {
		alias Args = E[0..0];
		alias Match = E[0..$];
	} else {
		enum _ = max(__address, __request, __response) + 1;
		alias Args = E[0.._];
		alias Match = E[_..$];
	}

	// An exact-string route has no capture groups to feed Match parameters.
	static assert(Match.length == 0 || !is(T == string));

	this(T path) {
		this.path = path;
	}

	/**
	 * Calls `del` with the context parameters filled from `conn`, `req` and
	 * `res`, followed by `match`.
	 */
	void callImpl(void delegate(E) del, AsyncTCPConnection conn, Request req, Response res, Match match) {
		Args args;
		// NOTE(review): this passes the connection's *local* address, not the peer's — confirm intent.
		static if(__address != -1) args[__address] = conn.local;
		static if(__request != -1) args[__request] = req;
		static if(__response != -1) args[__response] = res;
		del(args, match);
	}

	/// Implemented by subclasses to run the handler once the path matched; `match` holds the converted captures.
	abstract void call(ref HandleResult result, AsyncTCPConnection conn, Request req, Response res, Match match);

	/**
	 * Matches `req.path` against this route. On a match, calls `call` and sets
	 * `result.success`.
	 *
	 * Throws: `Exception` when the number of capture groups differs from the
	 * number of `Match` parameters. Also throws whatever `std.conv.to` raises
	 * when a capture cannot be converted.
	 */
	override void handle(ref HandleResult result, AsyncTCPConnection conn, Request req, Response res) {
		static if(is(T == string)) {
			if(req.path == this.path) {
				this.call(result, conn, req, res);
				result.success = true;
			}
		} else {
			auto match = req.path.matchAll(this.path);
			// The pattern must cover the whole path: nothing before or after the first match.
			if(match && match.pre.length == 0 && match.post.length == 0) {
				string[] matches;
				foreach(m ; match.front) matches ~= m;
				Match args;
				// Test Match (not E), so that context parameters before a string[] are allowed.
				static if(Match.length == 1 && is(Match[0] == string[])) {
					// Pass every capture group (without the whole match) at once.
					args[0] = matches[1..$];
				} else {
					if(matches.length != args.length + 1) throw new Exception("Arguments count mismatch"); //TODO do this check at compile time if possible
					static foreach(i ; 0..Match.length) {
						args[i] = to!(Match[i])(matches[i+1]);
					}
				}
				this.call(result, conn, req, res, args);
				result.success = true;
			}
		}
	}

}
166 
/**
 * Route that forwards every matching request to a user-supplied delegate.
 */
class RouteOf(T, E...) : RouteImpl!(T, E) {

	// Handler invoked for each matching request.
	private void delegate(E) del;

	/// Creates a route for `path` that calls `del` on a match.
	this(T path, void delegate(E) del) {
		super(path);
		this.del = del;
	}

	/// Runs the stored handler via `callImpl`; `result` is left untouched here.
	override void call(ref HandleResult result, AsyncTCPConnection conn, Request req, Response res, Match match) {
		callImpl(del, conn, req, res, match);
	}

}
181 
/**
 * Route that answers a web socket opening handshake and attaches the
 * connection to a freshly created `WebSocket` client.
 */
class WebSocketRouteOf(WebSocket, T, E...) : RouteImpl!(T, E) {

	// Factory producing one client per accepted handshake.
	private WebSocket delegate() createWebSocket;

	/// Creates a route for `path` that builds clients with `createWebSocket`.
	this(T path, WebSocket delegate() createWebSocket) {
		super(path);
		this.createWebSocket = createWebSocket;
	}

	/**
	 * Handles the handshake when the request has a "sec-websocket-key" header.
	 * Without that header, the status is set to `notFound`.
	 *
	 * On success, stores the new client in `result.webSocket`. A callback
	 * (wrapping `onConnect` when the client declares one) is stored in
	 * `result.callOnConnect`.
	 */
	override void call(ref HandleResult result, AsyncTCPConnection conn, Request req, Response res, Match match) {
		auto key = "sec-websocket-key" in req.headers;
		if(key is null) {
			res.status = StatusCodes.notFound;
			return;
		}
		// Fixed GUID from RFC 6455, appended to the client key before hashing.
		enum magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
		auto accept = Base64.encode(sha1Of(*key ~ magic)).idup;
		res.status = StatusCodes.switchingProtocols;
		res.headers["Sec-WebSocket-Accept"] = accept;
		res.headers["Connection"] = "upgrade";
		res.headers["Upgrade"] = "websocket";
		// Create the client now; its onConnect call is deferred into result.callOnConnect.
		WebSocket webSocket = this.createWebSocket();
		webSocket.conn = conn;
		result.webSocket = webSocket;
		static if(__traits(hasMember, WebSocket, "onConnect")) {
			result.callOnConnect = { this.callImpl(&webSocket.onConnect, conn, req, res, match); };
		} else {
			result.callOnConnect = {};
		}
	}

}
210 
/**
 * Pairs an HTTP method with a path. The path is either an exact string or a
 * regular expression over strings.
 */
struct RouteInfo(T)
	if(is(T == string) || is(T == Regex!char) || isRegexFor!(T, string))
{

	string method; /// HTTP method, e.g. "GET".
	T path;        /// Exact path or path pattern.

}
217 
// True when T is (or implicitly converts to) an instantiation of RouteInfo.
private template isRouteInfo(T) {
	enum isRouteInfo = is(T : RouteInfo!R, R);
}
219 
/// Builds a `RouteInfo` for an arbitrary HTTP `method` and `path`.
auto CustomMethod(R)(string method, R path) {
	RouteInfo!R info;
	info.method = method;
	info.path = path;
	return info;
}
221 
/// Builds a `RouteInfo` for a GET request on `path`.
auto Get(R)(R path) {
	return CustomMethod("GET", path);
}
223 
/// Builds a `RouteInfo` for a POST request on `path`.
auto Post(R)(R path) {
	return CustomMethod("POST", path);
}
225 
/**
 * Registers every public member of `router` that carries a `RouteInfo` UDA
 * (as produced by `Get`, `Post` or `CustomMethod`):
 *  - member functions are added through `Router.add` as delegate routes;
 *  - member classes are added through `Router.addWebSocket`. Nested classes
 *    are instantiated with `router` as their outer object;
 *  - any other member is passed to `Router.add` as is (e.g. a `Resource`).
 */
void registerRoutes(R:Router)(R router) {

	foreach(member ; __traits(allMembers, R)) {
		static if(__traits(getProtection, __traits(getMember, R, member)) == "public") {
			foreach(uda ; __traits(getAttributes, __traits(getMember, R, member))) {
				static if(isRouteInfo!(typeof(uda))) {
					mixin("alias M = router." ~ member ~ ";");
					static if(is(typeof(__traits(getMember, R, member)) == function)) {
						router.add(uda, mixin("&router." ~ member));
					} else static if(is(M == class)) {
						// A nested class needs an outer instance, so create it through `router`.
						// (Was `router..new M()`, which does not parse.)
						static if(__traits(isNested, M)) router.addWebSocket!M(uda, { return router.new M(); });
						else router.addWebSocket!M(uda);
					} else {
						router.add(uda, mixin("router." ~ member));
					}
				}
			}
		}
	}

}