Index: go/src/infra/gae/libs/gae/taskqueue.go |
diff --git a/go/src/infra/gae/libs/gae/taskqueue.go b/go/src/infra/gae/libs/gae/taskqueue.go |
index ab25eefcc4f342d2bef3554533dd1b0f55285a27..4af6d22e332a288d6b2b1d7fb5f69cbca146c3d8 100644 |
--- a/go/src/infra/gae/libs/gae/taskqueue.go |
+++ b/go/src/infra/gae/libs/gae/taskqueue.go |
@@ -29,14 +29,32 @@ type TaskQueue interface { |
// SetTQFactory. |
type TQFactory func(context.Context) TaskQueue |
-// GetTQ gets the TaskQueue implementation from context. |
-func GetTQ(c context.Context) TaskQueue { |
+// TQFilter is the function signature for a filter TQ implementation. It |
+// gets the current TQ implementation, and returns a new TQ implementation |
+// backed by the one passed in. |
+type TQFilter func(context.Context, TaskQueue) TaskQueue |
+ |
+// GetTQUnfiltered gets gets the TaskQueue implementation from context without |
+// any of the filters applied. |
+func GetTQUnfiltered(c context.Context) TaskQueue { |
if f, ok := c.Value(taskQueueKey).(TQFactory); ok && f != nil { |
return f(c) |
} |
return nil |
} |
+// GetTQ gets the TaskQueue implementation from context. |
+func GetTQ(c context.Context) TaskQueue { |
+ ret := GetTQUnfiltered(c) |
+ if ret == nil { |
+ return nil |
+ } |
+ for _, f := range getCurTQFilters(c) { |
+ ret = f(c, ret) |
+ } |
+ return ret |
+} |
+ |
// SetTQFactory sets the function to produce TaskQueue instances, as returned by |
// the GetTQ method. |
func SetTQFactory(c context.Context, tqf TQFactory) context.Context { |
@@ -49,3 +67,23 @@ func SetTQFactory(c context.Context, tqf TQFactory) context.Context { |
func SetTQ(c context.Context, tq TaskQueue) context.Context { |
return SetTQFactory(c, func(context.Context) TaskQueue { return tq }) |
} |
+ |
+func getCurTQFilters(c context.Context) []TQFilter { |
+ curFiltsI := c.Value(taskQueueFilterKey) |
+ if curFiltsI != nil { |
+ return curFiltsI.([]TQFilter) |
+ } |
+ return nil |
+} |
+ |
+// AddTQFilters adds TaskQueue filters to the context. |
+func AddTQFilters(c context.Context, filts ...TQFilter) context.Context { |
+ if len(filts) == 0 { |
+ return c |
+ } |
+ cur := getCurTQFilters(c) |
+ newFilts := make([]TQFilter, 0, len(cur)+len(filts)) |
+ newFilts = append(newFilts, getCurTQFilters(c)...) |
+ newFilts = append(newFilts, filts...) |
+ return context.WithValue(c, taskQueueFilterKey, newFilts) |
+} |