Index: go/src/infra/gae/libs/gae/raw_datastore.go |
diff --git a/go/src/infra/gae/libs/gae/raw_datastore.go b/go/src/infra/gae/libs/gae/raw_datastore.go |
index 17bf44577ac106e7d02f36ab91233143138bf6b4..809d958c57c3efc3fec217dcf8ea8878591b4e44 100644 |
--- a/go/src/infra/gae/libs/gae/raw_datastore.go |
+++ b/go/src/infra/gae/libs/gae/raw_datastore.go |
@@ -95,14 +95,32 @@ type RawDatastore interface { |
// SetRDSFactory. |
type RDSFactory func(context.Context) RawDatastore |
-// GetRDS gets the RawDatastore implementation from context. |
-func GetRDS(c context.Context) RawDatastore { |
+// RDSFilter is the function signature for a filter RDS implementation. It |
+// gets the current RDS implementation, and returns a new RDS implementation |
+// backed by the one passed in. |
+type RDSFilter func(context.Context, RawDatastore) RawDatastore |
+ |
+// GetRDSUnfiltered gets gets the RawDatastore implementation from context without |
+// any of the filters applied. |
+func GetRDSUnfiltered(c context.Context) RawDatastore { |
if f, ok := c.Value(rawDatastoreKey).(RDSFactory); ok && f != nil { |
return f(c) |
} |
return nil |
} |
+// GetRDS gets the RawDatastore implementation from context. |
+func GetRDS(c context.Context) RawDatastore { |
+ ret := GetRDSUnfiltered(c) |
+ if ret == nil { |
+ return nil |
+ } |
+ for _, f := range getCurRDSFilters(c) { |
+ ret = f(c, ret) |
+ } |
+ return ret |
+} |
+ |
// SetRDSFactory sets the function to produce Datastore instances, as returned by |
// the GetRDS method. |
func SetRDSFactory(c context.Context, rdsf RDSFactory) context.Context { |
@@ -115,3 +133,23 @@ func SetRDSFactory(c context.Context, rdsf RDSFactory) context.Context { |
func SetRDS(c context.Context, rds RawDatastore) context.Context { |
return SetRDSFactory(c, func(context.Context) RawDatastore { return rds }) |
} |
+ |
+func getCurRDSFilters(c context.Context) []RDSFilter { |
+ curFiltsI := c.Value(rawDatastoreFilterKey) |
+ if curFiltsI != nil { |
+ return curFiltsI.([]RDSFilter) |
+ } |
+ return nil |
+} |
+ |
+// AddRDSFilters adds RawDatastore filters to the context. |
+func AddRDSFilters(c context.Context, filts ...RDSFilter) context.Context { |
+ if len(filts) == 0 { |
+ return c |
+ } |
+ cur := getCurRDSFilters(c) |
+ newFilts := make([]RDSFilter, 0, len(cur)+len(filts)) |
+ newFilts = append(newFilts, getCurRDSFilters(c)...) |
+ newFilts = append(newFilts, filts...) |
+ return context.WithValue(c, rawDatastoreFilterKey, newFilts) |
+} |