Index: impl/memory/taskqueue.go |
diff --git a/impl/memory/taskqueue.go b/impl/memory/taskqueue.go |
index 9ac3e5287d6b80ba66b140edee33f39d75a376ca..703c1057835d193adc2e263b4f25c84e56013d3c 100644 |
--- a/impl/memory/taskqueue.go |
+++ b/impl/memory/taskqueue.go |
@@ -18,20 +18,20 @@ import ( |
/////////////////////////////// public functions /////////////////////////////// |
func useTQ(c context.Context) context.Context { |
- return tq.SetRawFactory(c, func(ic context.Context) tq.RawInterface { |
- tqd := cur(ic).Get(memContextTQIdx) |
- if x, ok := tqd.(*taskQueueData); ok { |
- return &taskqueueImpl{ |
- x, |
- ic, |
- curGID(ic).namespace, |
- } |
+ return tq.SetRawFactory(c, func(ic context.Context, wantTxn bool) tq.RawInterface { |
+ ns := curGID(ic).namespace |
+ var tqd memContextObj |
+ |
+ if !wantTxn { |
+ tqd = curNoTxn(ic).Get(memContextTQIdx) |
+ } else { |
+ tqd = cur(ic).Get(memContextTQIdx) |
} |
- return &taskqueueTxnImpl{ |
- tqd.(*txnTaskQueueData), |
- ic, |
- curGID(ic).namespace, |
+ |
+ if x, ok := tqd.(*taskQueueData); ok { |
+ return &taskqueueImpl{x, ic, ns} |
} |
+ return &taskqueueTxnImpl{tqd.(*txnTaskQueueData), ic, ns} |
}) |
} |