Index: go/src/infra/gae/libs/gae/memcache.go |
diff --git a/go/src/infra/gae/libs/gae/memcache.go b/go/src/infra/gae/libs/gae/memcache.go |
index 63d4f1fd58c92289ce48f1ec90908fb3b6d73568..e2de7cb0f42e3f4029f3c38ab32ee9acf7f2968b 100644 |
--- a/go/src/infra/gae/libs/gae/memcache.go |
+++ b/go/src/infra/gae/libs/gae/memcache.go |
@@ -57,14 +57,32 @@ type Memcache interface { |
// SetMCFactory. |
type MCFactory func(context.Context) Memcache |
-// GetMC gets the current memcache implementation from the context. |
-func GetMC(c context.Context) Memcache { |
+// MCFilter is the function signature for a filter MC implementation. It |
+// gets the current MC implementation, and returns a new MC implementation |
+// backed by the one passed in. |
+type MCFilter func(context.Context, Memcache) Memcache |
+ |
+// GetMCUnfiltered gets gets the Memcache implementation from context without |
+// any of the filters applied. |
+func GetMCUnfiltered(c context.Context) Memcache { |
if f, ok := c.Value(memcacheKey).(MCFactory); ok && f != nil { |
return f(c) |
} |
return nil |
} |
+// GetMC gets the current memcache implementation from the context. |
+func GetMC(c context.Context) Memcache { |
+ ret := GetMCUnfiltered(c) |
+ if ret == nil { |
+ return nil |
+ } |
+ for _, f := range getCurMCFilters(c) { |
+ ret = f(c, ret) |
+ } |
+ return ret |
+} |
+ |
// SetMCFactory sets the function to produce Memcache instances, as returned by |
// the GetMC method. |
func SetMCFactory(c context.Context, mcf MCFactory) context.Context { |
@@ -77,3 +95,23 @@ func SetMCFactory(c context.Context, mcf MCFactory) context.Context { |
func SetMC(c context.Context, mc Memcache) context.Context { |
return SetMCFactory(c, func(context.Context) Memcache { return mc }) |
} |
+ |
+func getCurMCFilters(c context.Context) []MCFilter { |
+ curFiltsI := c.Value(memcacheFilterKey) |
+ if curFiltsI != nil { |
+ return curFiltsI.([]MCFilter) |
+ } |
+ return nil |
+} |
+ |
+// AddMCFilters adds Memcache filters to the context. |
+func AddMCFilters(c context.Context, filts ...MCFilter) context.Context { |
+ if len(filts) == 0 { |
+ return c |
+ } |
+ cur := getCurMCFilters(c) |
+ newFilts := make([]MCFilter, 0, len(cur)+len(filts)) |
+ newFilts = append(newFilts, getCurMCFilters(c)...) |
+ newFilts = append(newFilts, filts...) |
+ return context.WithValue(c, memcacheFilterKey, newFilts) |
+} |