Index: chrome_elf/nt_registry/nt_registry.cc |
diff --git a/chrome_elf/nt_registry/nt_registry.cc b/chrome_elf/nt_registry/nt_registry.cc |
index 4c65af8daff7c76ea5a9e4a0255da0225aeb58a4..aca2a074da99e9f98952a286a7fe9aa8eb627274 100644 |
--- a/chrome_elf/nt_registry/nt_registry.cc |
+++ b/chrome_elf/nt_registry/nt_registry.cc |
@@ -7,6 +7,8 @@ |
#include <assert.h> |
#include <stdlib.h> |
+#include <memory> |
+ |
namespace { |
// Function pointers used for registry access. |
@@ -542,6 +544,45 @@ bool ParseFullRegPath(const std::wstring& converted_root, |
return true; |
} |
+// String safety. |
+// - NOTE: only working with wchar_t here. |
+// - Also ensures the content of |value_bytes| is at least a terminator. |
+// - Pass "true" for |multi| for MULTISZ. |
+void EnsureTerminatedSZ(std::vector<BYTE>* value_bytes, bool multi) { |
+ DWORD terminator_size = sizeof(wchar_t); |
+ |
+ if (multi) |
+ terminator_size = 2 * sizeof(wchar_t); |
+ |
+ // Ensure content is at least the size of a terminator. |
+ if (value_bytes->size() < terminator_size) { |
+ value_bytes->insert(value_bytes->end(), |
+ terminator_size - value_bytes->size(), 0); |
+ } |
+ |
+ // Sanity check content size based on character size. |
+ DWORD modulo = value_bytes->size() % sizeof(wchar_t); |
+ value_bytes->insert(value_bytes->end(), modulo, 0); |
+ |
+ // Now finally check for trailing terminator. |
+ bool terminated = true; |
+ size_t last_element = value_bytes->size() - 1; |
+ for (size_t i = 0; i < terminator_size; i++) { |
+ if ((*value_bytes)[last_element - i] != 0) { |
+ terminated = false; |
+ break; |
+ } |
+ } |
+ |
+ if (terminated) |
+ return; |
+ |
+ // Append a full terminator to be safe. |
+ value_bytes->insert(value_bytes->end(), terminator_size, 0); |
+ |
+ return; |
+} |
+ |
//------------------------------------------------------------------------------ |
// Misc wrapper functions - LOCAL |
//------------------------------------------------------------------------------ |
@@ -749,8 +790,7 @@ void CloseRegKey(HANDLE key) { |
bool QueryRegKeyValue(HANDLE key, |
const wchar_t* value_name, |
ULONG* out_type, |
- BYTE** out_buffer, |
- DWORD* out_size) { |
+ std::vector<BYTE>* out_buffer) { |
if (!g_initialized) |
InitNativeRegApi(); |
@@ -758,7 +798,6 @@ bool QueryRegKeyValue(HANDLE key, |
UNICODE_STRING value_uni = {}; |
g_rtl_init_unicode_string(&value_uni, value_name); |
DWORD size_needed = 0; |
- bool success = false; |
// First call to find out how much room we need for the value! |
ntstatus = g_nt_query_value_key(key, &value_uni, KeyValueFullInformation, |
@@ -766,24 +805,28 @@ bool QueryRegKeyValue(HANDLE key, |
if (ntstatus != STATUS_BUFFER_TOO_SMALL) |
return false; |
+ std::unique_ptr<BYTE[]> buffer(new BYTE[size_needed]); |
KEY_VALUE_FULL_INFORMATION* value_info = |
- reinterpret_cast<KEY_VALUE_FULL_INFORMATION*>(new BYTE[size_needed]); |
+ reinterpret_cast<KEY_VALUE_FULL_INFORMATION*>(buffer.get()); |
// Second call to get the value. |
ntstatus = g_nt_query_value_key(key, &value_uni, KeyValueFullInformation, |
value_info, size_needed, &size_needed); |
- if (NT_SUCCESS(ntstatus)) { |
- *out_type = value_info->Type; |
- *out_size = value_info->DataLength; |
- *out_buffer = new BYTE[*out_size]; |
- ::memcpy(*out_buffer, |
- (reinterpret_cast<BYTE*>(value_info) + value_info->DataOffset), |
- *out_size); |
- success = true; |
+ if (!NT_SUCCESS(ntstatus)) |
+ return false; |
+ |
+ *out_type = value_info->Type; |
+ DWORD data_size = value_info->DataLength; |
+ |
+ if (data_size) { |
+ // Move the data into |out_buffer| vector. |
+ BYTE* data = reinterpret_cast<BYTE*>(value_info) + value_info->DataOffset; |
+ out_buffer->assign(data, data + data_size); |
+ } else { |
+ out_buffer->clear(); |
} |
- delete[] value_info; |
- return success; |
+ return true; |
} |
// wrapper function |
@@ -791,16 +834,18 @@ bool QueryRegValueDWORD(HANDLE key, |
const wchar_t* value_name, |
DWORD* out_dword) { |
ULONG type = REG_NONE; |
- BYTE* value_bytes = nullptr; |
- DWORD ret_size = 0; |
+ std::vector<BYTE> value_bytes; |
+ |
+ if (!QueryRegKeyValue(key, value_name, &type, &value_bytes) || |
+ type != REG_DWORD) { |
+ return false; |
+ } |
- if (!QueryRegKeyValue(key, value_name, &type, &value_bytes, &ret_size) || |
- type != REG_DWORD) |
+ if (value_bytes.size() < sizeof(*out_dword)) |
return false; |
- *out_dword = *(reinterpret_cast<DWORD*>(value_bytes)); |
+ *out_dword = *(reinterpret_cast<DWORD*>(value_bytes.data())); |
- delete[] value_bytes; |
return true; |
} |
@@ -828,17 +873,18 @@ bool QueryRegValueDWORD(ROOT_KEY root, |
bool QueryRegValueSZ(HANDLE key, |
const wchar_t* value_name, |
std::wstring* out_sz) { |
- BYTE* value_bytes = nullptr; |
- DWORD ret_size = 0; |
+ std::vector<BYTE> value_bytes; |
ULONG type = REG_NONE; |
- if (!QueryRegKeyValue(key, value_name, &type, &value_bytes, &ret_size) || |
- type != REG_SZ) |
+ if (!QueryRegKeyValue(key, value_name, &type, &value_bytes) || |
+ (type != REG_SZ && type != REG_EXPAND_SZ)) { |
return false; |
+ } |
- *out_sz = reinterpret_cast<wchar_t*>(value_bytes); |
+ EnsureTerminatedSZ(&value_bytes, false); |
+ |
+ *out_sz = reinterpret_cast<wchar_t*>(value_bytes.data()); |
- delete[] value_bytes; |
return true; |
} |
@@ -866,33 +912,29 @@ bool QueryRegValueSZ(ROOT_KEY root, |
bool QueryRegValueMULTISZ(HANDLE key, |
const wchar_t* value_name, |
std::vector<std::wstring>* out_multi_sz) { |
- BYTE* value_bytes = nullptr; |
- DWORD ret_size = 0; |
+ std::vector<BYTE> value_bytes; |
ULONG type = REG_NONE; |
- if (!QueryRegKeyValue(key, value_name, &type, &value_bytes, &ret_size) || |
- type != REG_MULTI_SZ) |
+ if (!QueryRegKeyValue(key, value_name, &type, &value_bytes) || |
+ type != REG_MULTI_SZ) { |
return false; |
+ } |
- // Make sure the vector is empty to start. |
- (*out_multi_sz).resize(0); |
+ EnsureTerminatedSZ(&value_bytes, true); |
- wchar_t* pointer = reinterpret_cast<wchar_t*>(value_bytes); |
+ // Make sure the out vector is empty to start. |
+ out_multi_sz->clear(); |
+ |
+ wchar_t* pointer = reinterpret_cast<wchar_t*>(value_bytes.data()); |
std::wstring temp = pointer; |
// Loop. Each string is separated by '\0'. Another '\0' at very end (so 2 in |
// a row). |
- while (temp.length() != 0) { |
- (*out_multi_sz).push_back(temp); |
- |
+ while (!temp.empty()) { |
pointer += temp.length() + 1; |
+ out_multi_sz->push_back(std::move(temp)); |
temp = pointer; |
} |
- // Handle the case of "empty multi_sz". |
- if (out_multi_sz->size() == 0) |
- out_multi_sz->push_back(L""); |
- |
- delete[] value_bytes; |
return true; |
} |