Index: scm.py |
diff --git a/scm.py b/scm.py |
index eb5c52403c3e2595d14b133a7d0d00e289811d10..9bc96bc8055832707a749d3b9d256766b5a43956 100644 |
--- a/scm.py |
+++ b/scm.py |
@@ -353,11 +353,34 @@ class GIT(object): |
return remote, upstream_branch |
@staticmethod |
+ def RefToRemoteRef(ref, remote=None): |
+ """Convert a checkout ref to the equivalent remote ref. |
+ |
+ Returns: |
+ A tuple of the remote ref's (common prefix, unique suffix), or None if it |
+ doesn't appear to refer to a remote ref (e.g. it's a commit hash). |
+ """ |
+ # TODO(mmoss): This is just a brute-force mapping based of the expected git |
+ # config. It's a bit better than the even more brute-force replace('heads', |
+ # ...), but could still be smarter (like maybe actually using values gleaned |
+ # from the git config). |
+ m = re.match('^(refs/(remotes/)?)?branch-heads/', ref or '') |
+ if m: |
+ return ('refs/remotes/branch-heads/', ref.replace(m.group(0), '')) |
+ if remote: |
+ m = re.match('^((refs/)?remotes/)?%s/|(refs/)?heads/' % remote, ref or '') |
+ if m: |
+ return ('refs/remotes/%s/' % remote, ref.replace(m.group(0), '')) |
+ return None |
+ |
+ @staticmethod |
def GetUpstreamBranch(cwd): |
"""Gets the current branch's upstream branch.""" |
remote, upstream_branch = GIT.FetchUpstreamTuple(cwd) |
if remote != '.' and upstream_branch: |
- upstream_branch = upstream_branch.replace('heads', 'remotes/' + remote) |
+ remote_ref = GIT.RefToRemoteRef(upstream_branch, remote) |
+ if remote_ref: |
+ upstream_branch = ''.join(remote_ref) |
return upstream_branch |
@staticmethod |