]> Pileus Git - ~andy/linux/blob - drivers/staging/hv/Connection.c
Staging: hv: rework use of workqueues in osd
[~andy/linux] / drivers / staging / hv / Connection.c
1 /*
2  *
3  * Copyright (c) 2009, Microsoft Corporation.
4  *
5  * This program is free software; you can redistribute it and/or modify it
6  * under the terms and conditions of the GNU General Public License,
7  * version 2, as published by the Free Software Foundation.
8  *
9  * This program is distributed in the hope it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
12  * more details.
13  *
14  * You should have received a copy of the GNU General Public License along with
15  * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
16  * Place - Suite 330, Boston, MA 02111-1307 USA.
17  *
18  * Authors:
19  *   Haiyang Zhang <haiyangz@microsoft.com>
20  *   Hank Janssen  <hjanssen@microsoft.com>
21  *
22  */
23
24
25 #include "include/logging.h"
26
27 #include "VmbusPrivate.h"
28
29 /* Globals */
30
31
32 struct VMBUS_CONNECTION gVmbusConnection = {
33         .ConnectState           = Disconnected,
34         .NextGpadlHandle        = 0xE1E10,
35 };
36
37
38 /*++
39
40 Name:
41         VmbusConnect()
42
43 Description:
44         Sends a connect request on the partition service connection
45
46 --*/
47 static int
48 VmbusConnect(void)
49 {
50         int ret=0;
51         VMBUS_CHANNEL_MSGINFO *msgInfo=NULL;
52         VMBUS_CHANNEL_INITIATE_CONTACT *msg;
53         unsigned long flags;
54
55         DPRINT_ENTER(VMBUS);
56
57         /* Make sure we are not connecting or connected */
58         if (gVmbusConnection.ConnectState != Disconnected)
59                 return -1;
60
61         /* Initialize the vmbus connection */
62         gVmbusConnection.ConnectState = Connecting;
63         gVmbusConnection.WorkQueue = create_workqueue("hv_vmbus_con");
64         if (!gVmbusConnection.WorkQueue)
65         {
66                 ret = -1;
67                 goto Cleanup;
68         }
69
70         INITIALIZE_LIST_HEAD(&gVmbusConnection.ChannelMsgList);
71         spin_lock_init(&gVmbusConnection.channelmsg_lock);
72
73         INITIALIZE_LIST_HEAD(&gVmbusConnection.ChannelList);
74         spin_lock_init(&gVmbusConnection.channel_lock);
75
76         /*
77          * Setup the vmbus event connection for channel interrupt
78          * abstraction stuff
79          */
80         gVmbusConnection.InterruptPage = PageAlloc(1);
81         if (gVmbusConnection.InterruptPage == NULL)
82         {
83                 ret = -1;
84                 goto Cleanup;
85         }
86
87         gVmbusConnection.RecvInterruptPage = gVmbusConnection.InterruptPage;
88         gVmbusConnection.SendInterruptPage = (void*)((unsigned long)gVmbusConnection.InterruptPage + (PAGE_SIZE >> 1));
89
90         /* Setup the monitor
91          * notification facility. The 1st page for parent->child and
92          * the 2nd page for child->parent
93          */
94         gVmbusConnection.MonitorPages = PageAlloc(2);
95         if (gVmbusConnection.MonitorPages == NULL)
96         {
97                 ret = -1;
98                 goto Cleanup;
99         }
100
101         msgInfo = kzalloc(sizeof(VMBUS_CHANNEL_MSGINFO) + sizeof(VMBUS_CHANNEL_INITIATE_CONTACT), GFP_KERNEL);
102         if (msgInfo == NULL)
103         {
104                 ret = -1;
105                 goto Cleanup;
106         }
107
108         msgInfo->WaitEvent = WaitEventCreate();
109         msg = (VMBUS_CHANNEL_INITIATE_CONTACT*)msgInfo->Msg;
110
111         msg->Header.MessageType = ChannelMessageInitiateContact;
112         msg->VMBusVersionRequested = VMBUS_REVISION_NUMBER;
113         msg->InterruptPage = GetPhysicalAddress(gVmbusConnection.InterruptPage);
114         msg->MonitorPage1 = GetPhysicalAddress(gVmbusConnection.MonitorPages);
115         msg->MonitorPage2 = GetPhysicalAddress((void *)((unsigned long)gVmbusConnection.MonitorPages + PAGE_SIZE));
116
117         /*
118          * Add to list before we send the request since we may
119          * receive the response before returning from this routine
120          */
121         spin_lock_irqsave(&gVmbusConnection.channelmsg_lock, flags);
122         INSERT_TAIL_LIST(&gVmbusConnection.ChannelMsgList, &msgInfo->MsgListEntry);
123         spin_unlock_irqrestore(&gVmbusConnection.channelmsg_lock, flags);
124
125         DPRINT_DBG(VMBUS, "Vmbus connection -  interrupt pfn %llx, monitor1 pfn %llx,, monitor2 pfn %llx",
126                 msg->InterruptPage, msg->MonitorPage1, msg->MonitorPage2);
127
128         DPRINT_DBG(VMBUS, "Sending channel initiate msg...");
129
130         ret = VmbusPostMessage(msg, sizeof(VMBUS_CHANNEL_INITIATE_CONTACT));
131         if (ret != 0)
132         {
133                 REMOVE_ENTRY_LIST(&msgInfo->MsgListEntry);
134                 goto Cleanup;
135         }
136
137         /* Wait for the connection response */
138         WaitEventWait(msgInfo->WaitEvent);
139
140         REMOVE_ENTRY_LIST(&msgInfo->MsgListEntry);
141
142         /* Check if successful */
143         if (msgInfo->Response.VersionResponse.VersionSupported)
144         {
145                 DPRINT_INFO(VMBUS, "Vmbus connected!!");
146                 gVmbusConnection.ConnectState = Connected;
147
148         }
149         else
150         {
151                 DPRINT_ERR(VMBUS, "Vmbus connection failed!!...current version (%d) not supported", VMBUS_REVISION_NUMBER);
152                 ret = -1;
153
154                 goto Cleanup;
155         }
156
157
158         WaitEventClose(msgInfo->WaitEvent);
159         kfree(msgInfo);
160         DPRINT_EXIT(VMBUS);
161
162         return 0;
163
164 Cleanup:
165
166         gVmbusConnection.ConnectState = Disconnected;
167
168         if (gVmbusConnection.WorkQueue)
169                 destroy_workqueue(gVmbusConnection.WorkQueue);
170
171         if (gVmbusConnection.InterruptPage)
172         {
173                 PageFree(gVmbusConnection.InterruptPage, 1);
174                 gVmbusConnection.InterruptPage = NULL;
175         }
176
177         if (gVmbusConnection.MonitorPages)
178         {
179                 PageFree(gVmbusConnection.MonitorPages, 2);
180                 gVmbusConnection.MonitorPages = NULL;
181         }
182
183         if (msgInfo)
184         {
185                 if (msgInfo->WaitEvent)
186                         WaitEventClose(msgInfo->WaitEvent);
187
188                 kfree(msgInfo);
189         }
190
191         DPRINT_EXIT(VMBUS);
192
193         return ret;
194 }
195
196
197 /*++
198
199 Name:
200         VmbusDisconnect()
201
202 Description:
203         Sends a disconnect request on the partition service connection
204
205 --*/
206 static int
207 VmbusDisconnect(
208         void
209         )
210 {
211         int ret=0;
212         VMBUS_CHANNEL_UNLOAD *msg;
213
214         DPRINT_ENTER(VMBUS);
215
216         /* Make sure we are connected */
217         if (gVmbusConnection.ConnectState != Connected)
218                 return -1;
219
220         msg = kzalloc(sizeof(VMBUS_CHANNEL_UNLOAD), GFP_KERNEL);
221
222         msg->MessageType = ChannelMessageUnload;
223
224         ret = VmbusPostMessage(msg, sizeof(VMBUS_CHANNEL_UNLOAD));
225
226         if (ret != 0)
227         {
228                 goto Cleanup;
229         }
230
231         PageFree(gVmbusConnection.InterruptPage, 1);
232
233         /* TODO: iterate thru the msg list and free up */
234
235         destroy_workqueue(gVmbusConnection.WorkQueue);
236
237         gVmbusConnection.ConnectState = Disconnected;
238
239         DPRINT_INFO(VMBUS, "Vmbus disconnected!!");
240
241 Cleanup:
242         if (msg)
243         {
244                 kfree(msg);
245         }
246
247         DPRINT_EXIT(VMBUS);
248
249         return ret;
250 }
251
252
253 /*++
254
255 Name:
256         GetChannelFromRelId()
257
258 Description:
259         Get the channel object given its child relative id (ie channel id)
260
261 --*/
262 static VMBUS_CHANNEL*
263 GetChannelFromRelId(
264         u32 relId
265         )
266 {
267         VMBUS_CHANNEL* channel;
268         VMBUS_CHANNEL* foundChannel=NULL;
269         LIST_ENTRY* anchor;
270         LIST_ENTRY* curr;
271         unsigned long flags;
272
273         spin_lock_irqsave(&gVmbusConnection.channel_lock, flags);
274         ITERATE_LIST_ENTRIES(anchor, curr, &gVmbusConnection.ChannelList)
275         {
276                 channel = CONTAINING_RECORD(curr, VMBUS_CHANNEL, ListEntry);
277
278                 if (channel->OfferMsg.ChildRelId == relId)
279                 {
280                         foundChannel = channel;
281                         break;
282                 }
283         }
284         spin_unlock_irqrestore(&gVmbusConnection.channel_lock, flags);
285
286         return foundChannel;
287 }
288
289
290
291 /*++
292
293 Name:
294         VmbusProcessChannelEvent()
295
296 Description:
297         Process a channel event notification
298
299 --*/
300 static void
301 VmbusProcessChannelEvent(
302         void * context
303         )
304 {
305         VMBUS_CHANNEL* channel;
306         u32 relId = (u32)(unsigned long)context;
307
308         ASSERT(relId > 0);
309
310         /*
311          * Find the channel based on this relid and invokes the
312          * channel callback to process the event
313          */
314         channel = GetChannelFromRelId(relId);
315
316         if (channel)
317         {
318                 VmbusChannelOnChannelEvent(channel);
319                 /* WorkQueueQueueWorkItem(channel->dataWorkQueue, VmbusChannelOnChannelEvent, (void*)channel); */
320         }
321         else
322         {
323         DPRINT_ERR(VMBUS, "channel not found for relid - %d.", relId);
324         }
325 }
326
327
328 /*++
329
330 Name:
331         VmbusOnEvents()
332
333 Description:
334         Handler for events
335
336 --*/
337 static void
338 VmbusOnEvents(
339   void
340         )
341 {
342         int dword;
343         /* int maxdword = PAGE_SIZE >> 3; // receive size is 1/2 page and divide that by 4 bytes */
344         int maxdword = MAX_NUM_CHANNELS_SUPPORTED >> 5;
345         int bit;
346         int relid;
347         u32* recvInterruptPage = gVmbusConnection.RecvInterruptPage;
348         /* VMBUS_CHANNEL_MESSAGE* receiveMsg; */
349
350         DPRINT_ENTER(VMBUS);
351
352         /* Check events */
353         if (recvInterruptPage)
354         {
355                 for (dword = 0; dword < maxdword; dword++)
356                 {
357                         if (recvInterruptPage[dword])
358                         {
359                                 for (bit = 0; bit < 32; bit++)
360                                 {
361                                         if (BitTestAndClear(&recvInterruptPage[dword], bit))
362                                         {
363                                                 relid = (dword << 5) + bit;
364
365                                                 DPRINT_DBG(VMBUS, "event detected for relid - %d", relid);
366
367                                                 if (relid == 0) /* special case - vmbus channel protocol msg */
368                                                 {
369                                                         DPRINT_DBG(VMBUS, "invalid relid - %d", relid);
370
371                                                         continue;                                               }
372                                                 else
373                                                 {
374                                                         /* QueueWorkItem(VmbusProcessEvent, (void*)relid); */
375                                                         /* ret = WorkQueueQueueWorkItem(gVmbusConnection.workQueue, VmbusProcessChannelEvent, (void*)relid); */
376                                                         VmbusProcessChannelEvent((void*)(unsigned long)relid);
377                                                 }
378                                         }
379                                 }
380                         }
381                  }
382         }
383         DPRINT_EXIT(VMBUS);
384
385         return;
386 }
387
388 /*++
389
390 Name:
391         VmbusPostMessage()
392
393 Description:
394         Send a msg on the vmbus's message connection
395
396 --*/
397 static int
398 VmbusPostMessage(
399         void *                  buffer,
400         size_t                  bufferLen
401         )
402 {
403         int ret=0;
404         HV_CONNECTION_ID connId;
405
406
407         connId.Asu32 =0;
408         connId.u.Id = VMBUS_MESSAGE_CONNECTION_ID;
409         ret = HvPostMessage(
410                         connId,
411                         1,
412                         buffer,
413                         bufferLen);
414
415         return  ret;
416 }
417
418 /*++
419
420 Name:
421         VmbusSetEvent()
422
423 Description:
424         Send an event notification to the parent
425
426 --*/
427 static int
428 VmbusSetEvent(u32 childRelId)
429 {
430         int ret=0;
431
432         DPRINT_ENTER(VMBUS);
433
434         /* Each u32 represents 32 channels */
435         BitSet((u32*)gVmbusConnection.SendInterruptPage + (childRelId >> 5), childRelId & 31);
436         ret = HvSignalEvent();
437
438         DPRINT_EXIT(VMBUS);
439
440         return ret;
441 }
442
443 /* EOF */