1 module grpc.service;
2 import core.atomic;
3 import interop.headers;
4 import grpc.logger;
5 import core.thread;
6 import grpc.server : Server;
7 import grpc.common.call;
8 import google.rpc.status;
9 import grpc.core.tag;
10 import grpc.common.cq;
11 import grpc.core.utils;
12 import std.experimental.allocator : theAllocator, make, dispose;
13 
14 // Every Service template class is guaranteed to at least implement these functions
interface ServiceHandlerInterface {
    /// Binds the service to the given server and spawns its worker threads;
    /// blocks until every worker reports ready. Returns true on success.
    bool register(Server server);
    /// Blocks until all of the service's worker threads have terminated.
    void stop();
    /// Releases the worker threads from their spool-up wait so they begin
    /// requesting and servicing calls.
    void kickstart();
    /// Returns the number of worker threads currently running.
    ulong runners();
}
21 
22 // Since we dynamically generate the function handler through UDAs,
23 // we need some way to get the type that it is expecting (be it, through a ServerReader/ServerWriter interface, or a POD type)
24 // These two templates resolve that, and get us the type.
25 
/// Resolves the request type an RPC handler expects: for a template
/// instantiation (e.g. a ServerReader!U stream wrapper) this unwraps the
/// first template argument, otherwise the plain POD type is used directly.
mixin template Reader(T) {
    import std.traits : TemplateArgsOf, TemplateOf;
    static if(!is(TemplateOf!T == void)) {
        // T is an instantiated template — peel off its first argument.
        alias input = TemplateArgsOf!(T)[0];
    }
    else {
        // Not a template instantiation — use the type as-is.
        alias input = T;
    }
}
36 
/// Resolves the response type an RPC handler produces: for a template
/// instantiation (e.g. a ServerWriter!U stream wrapper) this unwraps the
/// first template argument, otherwise the plain POD type is used directly.
mixin template Writer(T) {
    import std.traits : TemplateArgsOf, TemplateOf;
    static if(!is(TemplateOf!T == void)) {
        // T is an instantiated template — peel off its first argument.
        alias output = TemplateArgsOf!(T)[0];
    }
    else {
        // Not a template instantiation — use the type as-is.
        alias output = T;
    }
}
/**
 * One worker thread servicing RPCs for an instance of the service
 * implementation T.
 *
 * Lifecycle (driven by Service!T):
 *   1. The spawning thread (Service!T.register) fills in workerIndex,
 *      registeredMethods and server, then start()s the thread.
 *   2. run() allocates per-thread state, registers its notification queue
 *      with the server, sets threadReady, and spins until the main thread
 *      sets threadStart (Service!T.kickstart).
 *   3. The thread requests one pending call per RPC method, then services
 *      completion events until its queue shuts down.
 */
class ServicerThread(T) : Thread {
    this() {
        super(&run);
    }
    
    ~this() {
    }
    
    /* Set by the main thread */
    shared bool threadReady;         // set by run() once queue registration is done
    shared bool threadStart;         // set by the main thread to release run()
    ulong workerIndex;               // this worker's index; stamped into tag metadata[2]
    void*[string] registeredMethods; // RPC method name -> grpc registered-method handle
    shared(Server) server;

private:
    /* Thread local things */
    shared(CompletionQueue!"Next") notificationCq; // delivers "call arrived" events
    CompletionQueue!"Next" callCq;                 // used for per-call batch operations
    Tag*[] tags;
    T instance;       // the user's service implementation, one per worker
    shared bool _run; // main-loop flag; always accessed via atomicLoad/atomicStore

    /// Dispatches a completed notification tag to the RPC handler encoded in
    /// its metadata. Layout: metadata[0] = 0xDE magic, metadata[1] = RPC
    /// function index, metadata[2] = owning worker index; tags that do not
    /// match this worker are ignored.
    void handleTag(Tag* tag) {
        import std.traits : Parameters, BaseTypeTuple, getSymbolsByUDA, hasUDA;
        import core.time : MonoTime;

        // presumably BaseTypeTuple!T[1] is the generated interface carrying
        // the @RPC UDAs (same lookup as in run()) — TODO confirm against the
        // code generator
        alias parent = BaseTypeTuple!T[1];
        if (!tag) return;
        if (tag.metadata[0] != 0xDE) return;
        if (tag.metadata[2] != workerIndex) return;

        tag.ctx.mutex.lock;
        scope(exit) tag.ctx.mutex.unlock;
        tag.ctx.timestamp = MonoTime.currTime;

        // One case is generated per @RPC-annotated method; the tag's function
        // index selects which handler runs.
        sw: switch (tag.metadata[1]) {
            static foreach(i, val; getSymbolsByUDA!(parent, RPC)) {{
                import grpc.stream.server.reader : ServerReader;
                import grpc.stream.server.writer : ServerWriter;
                mixin Reader!(Parameters!val[0]);
                mixin Writer!(Parameters!val[1]);
                alias SR = ServerReader!input;
                alias SW = ServerWriter!output;
                case i: {
                    enum ServerStream = hasUDA!(val, ServerStreaming);
                    enum ClientStream = hasUDA!(val, ClientStreaming);
                    SR reader = SR(tag, callCq);
                    SW writer = SW(tag, callCq);

                    Status stat;
                    input funcIn;
                    output funcOut;
                    try {
                        static if(!ServerStream && !ClientStream) {
                            // unary call
                            funcIn = reader.readOne();
                            stat = __traits(child, instance, val)(funcIn, funcOut);
                            writer.start();
                            writer.write(funcOut);
                        } else static if (ServerStream && ClientStream) {
                            // bidi call
                            stat = __traits(child, instance, val)(reader, writer);
                        } else static if (!ServerStream && ClientStream) {
                            // client streaming call
                            stat = __traits(child, instance, val)(reader, funcOut);
                            writer.start();
                            writer.write(funcOut);
                        } else static if (ServerStream && !ClientStream) {
                            // server streaming call
                            funcIn = reader.readOne();
                            writer.start();
                            stat = __traits(child, instance, val)(funcIn, writer);
                        }
                    } catch (Exception e) {
                        // The handler threw: cancel the call and surface the
                        // failure to the client as INTERNAL.
                        grpc_call_cancel(*tag.ctx.call, null);
                        stat.code = GRPC_STATUS_INTERNAL;
                        stat.message = e.msg;
                    }

                    writer.finish(stat);
                    reader.finish();

                    tag.ctx.metadata.cleanup;
                    if (tag.ctx.data.valid) {
                        tag.ctx.data.cleanup;
                    }

                    grpc_call_unref(*tag.ctx.call);
                    *tag.ctx.call = null;
                    break sw;
                }
            }}

            default:
                assert(0, "Received tag with function index out of bounds");
        }
    }

    /// Thread entry point: allocate per-thread state, register with the
    /// server, wait for kickstart, then service completion events until
    /// shutdown.
    void run() {
        import std.traits : getUDAs, getSymbolsByUDA, BaseTypeTuple;
        import std.experimental.allocator.mallocator: Mallocator;
        import std.experimental.allocator : theAllocator, allocatorObject;
        
        // Use malloc-backed allocation for this thread's long-lived objects.
        theAllocator = allocatorObject(Mallocator.instance);

        instance = theAllocator.make!T();
        notificationCq = cast(shared)theAllocator.make!(CompletionQueue!"Next")();
        callCq = theAllocator.make!(CompletionQueue!"Next")();

        DEBUG!"registering (thread: %d)"(workerIndex);
        server.registerQueue(notificationCq);
        DEBUG!"registered (thread: %d)"(workerIndex);

        // Block while the rest of the threads spool up, and wait for the server to signal
        // that it has started, and it is safe to request calls on our CQ
        atomicStore(threadReady, true);
        while (!atomicLoad(threadStart)) {
             Thread.sleep(1.msecs);
        }

        DEBUG!"beginning phase 2 (thread: %d)"(workerIndex);
        alias parent = BaseTypeTuple!T[1];
        // Phase 2: request one pending call per RPC method on our queue.
        static foreach(i, val; getSymbolsByUDA!(parent, RPC)) {{
            static if (i > ubyte.max) {
                // the function index must fit into a single metadata byte
                static assert(0, "Too many RPC functions!");
            }

            enum remoteName = getUDAs!(val, RPC)[0].methodName;
            Tag* tag = Tag();
            tags ~= tag;
            // magic number
            tag.metadata[0] = 0xDE;
            tag.metadata[1] = cast(ubyte)i;
            tag.metadata[2] = cast(ubyte)workerIndex;
            tag.method = registeredMethods[remoteName];
            tag.methodName = remoteName;
            callCq.requestCall(tag.method, tag, server, notificationCq);
        }}

        /*
            PHASE 3:
            Here, we begin the main loop to service requests. This continues,
            until the queue is shutdown, or _run is set to false.
        */

        atomicStore(_run, true);
        while (atomicLoad(_run)) {
            auto item = notificationCq.next(10.seconds);
            notificationCq.lock();
            scope(exit) notificationCq.unlock();
            if (item.type == GRPC_OP_COMPLETE) {
                DEBUG!"hello from task %d"(workerIndex);
                DEBUG!"hit something";
            } else if (item.type == GRPC_QUEUE_SHUTDOWN) {
                DEBUG!"shutdown";
                // was a plain (non-atomic) write to shared; keep it consistent
                // with the atomicLoad in the loop condition
                atomicStore(_run, false);
                continue;
            } else if(item.type == GRPC_QUEUE_TIMEOUT) {
                DEBUG!"timeout";
                continue;
            }

            if (notificationCq.inShutdownPath) {
                DEBUG!"we are in shutdown path";
                atomicStore(_run, false);
                break;
            }

            DEBUG!"grabbing tag";
            Tag* tag = cast(Tag*)item.tag;
            DEBUG!"got tag: %x"(tag);

            // xxx: should never happen
            if (tag == null) {
                ERROR!"got null tag?";
                continue;
            }

            handleTag(tag);

            // Re-arm: request the next incoming call for this method,
            // reusing the same tag.
            grpc_call_error error = callCq.requestCall(tag.method, tag, server, notificationCq);
            if (error != GRPC_CALL_OK) {
                ERROR!"could not request call %s"(error);
            }
        }

        // Tear down thread-local state in reverse order of construction.
        theAllocator.dispose(callCq);
        theAllocator.dispose(notificationCq);
        theAllocator.dispose(instance);
    }
}
/**
 * Binds the service implementation T to a Server, owning one ServicerThread
 * per worker. Implements ServiceHandlerInterface so the server can register,
 * start, and stop every service uniformly.
 */
class Service(T) : ServiceHandlerInterface 
if (is(T == class)) {
    private {
        void*[string] registeredMethods; // RPC method name -> grpc registered-method handle
        Server _server; // NOTE(review): never assigned or read in this class — candidate for removal
        ThreadGroup threads;
        ServicerThread!T[] _threads; // do not ever use
        immutable ulong workingThreads;
        immutable ulong _serviceId;
    }

    // This function may be called by a thread other than the main() thread;
    // make sure that doesn't muck things up.
    
    /// Returns the number of worker threads that are currently running.
    ulong runners() {
        ulong r = 0;
        foreach(thread; threads) {
            if (thread.isRunning()) r += 1;
        }
        
        return r;
    }

    /// Releases every worker from its spool-up spin loop (see
    /// ServicerThread.threadStart) so it begins requesting calls.
    void kickstart() {
        foreach(thread; _threads) {
            atomicStore(thread.threadStart, true);
        }
    }

    /// Blocks until every worker thread terminates.
    void stop() {
        // block while every thread terminates
        threads.joinAll();
    }

    // this function will always be called by main()

    /// Spawns workingThreads ServicerThreads, hands each the method table and
    /// the server handle, then blocks until all of them report ready.
    /// Always returns true.
    bool register(Server server) {
        /* Fork and spawn new threads for each worker */
        for (ulong i = 0; i < workingThreads; i++) {
            auto t = new ServicerThread!T();
            t.workerIndex = i;
            // We *should* block for this thread to stop execution at the end
            t.isDaemon = false;
            t.registeredMethods = registeredMethods;
            t.server = cast(shared)server;
            t.start();

            threads.add(t);
            _threads ~= t;
        }

        // avoid race condition while we wait for all threads to fully spool

        loop: while (true) {
            foreach(thread; _threads) {
                if (!atomicLoad(thread.threadReady)) {
                    Thread.sleep(1.msecs);
                    continue loop;
                }
            }
            break;
        }

        return true;
    }

    /// Params:
    ///   serviceId   = index identifying this service within the server
    ///   methodTable = RPC method name -> grpc registered-method handle
    ///                 (defensively copied with .dup)
    this(ulong serviceId, void*[string] methodTable) {
        debug import std.stdio;
        debug writefln("passed method table: %s", methodTable);
        registeredMethods = methodTable.dup;
        _serviceId = serviceId;
        threads = new ThreadGroup();

        // TODO: make this user-specifiable
        workingThreads = 1;
    }
        
}
320