Index: impl/prod/raw_datastore.go |
diff --git a/impl/prod/raw_datastore.go b/impl/prod/raw_datastore.go |
index bc835c03bcca53d58e87991bbe4e8056804a983a..e70b31adb06670de9cbcdab4a238626571dbcb8b 100644 |
--- a/impl/prod/raw_datastore.go |
+++ b/impl/prod/raw_datastore.go |
@@ -15,8 +15,18 @@ import ( |
// useRDS adds a gae.RawDatastore implementation to context, accessible |
// by gae.GetDS(c) |
func useRDS(c context.Context) context.Context { |
- return ds.SetRawFactory(c, func(ci context.Context) ds.RawInterface { |
- return rdsImpl{ci, AEContext(ci), info.Get(ci).GetNamespace()} |
+ return ds.SetRawFactory(c, func(ci context.Context, wantTxn bool) ds.RawInterface { |
+ ns := info.Get(ci).GetNamespace() |
+ maybeTxnCtx := AEContext(ci) |
+ |
+ if wantTxn { |
+ return rdsImpl{ci, maybeTxnCtx, ns} |
+ } |
+ aeCtx := AEContextNoTxn(ci) |
+ if maybeTxnCtx != aeCtx { |
+ ci = context.WithValue(ci, prodContextKey, aeCtx) |
+ } |
+ return rdsImpl{ci, aeCtx, ns} |
}) |
} |
@@ -223,8 +233,8 @@ func (d rdsImpl) Count(fq *ds.FinalizedQuery) (int64, error) { |
func (d rdsImpl) RunInTransaction(f func(c context.Context) error, opts *ds.TransactionOptions) error { |
ropts := (*datastore.TransactionOptions)(opts) |
- return datastore.RunInTransaction(d.aeCtx, func(aeCtx context.Context) error { |
- return f(context.WithValue(d.userCtx, prodContextKey, aeCtx)) |
+ return datastore.RunInTransaction(d.aeCtx, func(c context.Context) error { |
+ return f(context.WithValue(d.userCtx, prodContextKey, c)) |
}, ropts) |
} |