| Index: device/u2f/u2f_hid_device.cc
|
| diff --git a/device/u2f/u2f_hid_device.cc b/device/u2f/u2f_hid_device.cc
|
| index c313b96e4c165aab270e68e616c552fa4401d62c..62eda924f55c390d157d72e9a2d1e3b143001880 100644
|
| --- a/device/u2f/u2f_hid_device.cc
|
| +++ b/device/u2f/u2f_hid_device.cc
|
| @@ -7,6 +7,7 @@
|
| #include "base/bind.h"
|
| #include "base/bind_helpers.h"
|
| #include "base/command_line.h"
|
| +#include "base/threading/thread_task_runner_handle.h"
|
| #include "crypto/random.h"
|
| #include "device/base/device_client.h"
|
| #include "device/hid/hid_connection.h"
|
| @@ -29,7 +30,7 @@ U2fHidDevice::U2fHidDevice(scoped_refptr<HidDeviceInfo> device_info)
|
|
|
| U2fHidDevice::~U2fHidDevice() {
|
| // Cleanup connection
|
| - if (connection_)
|
| + if (connection_ && !connection_->closed())
|
| connection_->Close();
|
| }
|
|
|
| @@ -43,17 +44,22 @@ void U2fHidDevice::Transition(std::unique_ptr<U2fApduCommand> command,
|
| switch (state_) {
|
| case State::INIT:
|
| state_ = State::BUSY;
|
| + ArmTimeout(callback);
|
| Connect(base::Bind(&U2fHidDevice::OnConnect, weak_factory_.GetWeakPtr(),
|
| base::Passed(&command), callback));
|
| break;
|
| case State::CONNECTED:
|
| state_ = State::BUSY;
|
| + ArmTimeout(callback);
|
| AllocateChannel(std::move(command), callback);
|
| break;
|
| case State::IDLE: {
|
| state_ = State::BUSY;
|
| std::unique_ptr<U2fMessage> msg = U2fMessage::Create(
|
| channel_id_, U2fMessage::Type::CMD_MSG, command->GetEncodedCommand());
|
| +
|
| + ArmTimeout(callback);
|
| + // Write message to the device
|
| WriteMessage(std::move(msg), true,
|
| base::Bind(&U2fHidDevice::MessageReceived,
|
| weak_factory_.GetWeakPtr(), callback));
|
| @@ -86,6 +92,10 @@ void U2fHidDevice::Connect(const HidService::ConnectCallback& callback) {
|
| void U2fHidDevice::OnConnect(std::unique_ptr<U2fApduCommand> command,
|
| const DeviceCallback& callback,
|
| scoped_refptr<HidConnection> connection) {
|
| + if (state_ == State::DEVICE_ERROR)
|
| + return;
|
| + timeout_callback_.Cancel();
|
| +
|
| if (connection) {
|
| connection_ = connection;
|
| state_ = State::CONNECTED;
|
| @@ -114,6 +124,10 @@ void U2fHidDevice::OnAllocateChannel(std::vector<uint8_t> nonce,
|
| const DeviceCallback& callback,
|
| bool success,
|
| std::unique_ptr<U2fMessage> message) {
|
| + if (state_ == State::DEVICE_ERROR)
|
| + return;
|
| + timeout_callback_.Cancel();
|
| +
|
| if (!success || !message) {
|
| state_ = State::DEVICE_ERROR;
|
| Transition(nullptr, callback);
|
| @@ -254,11 +268,16 @@ void U2fHidDevice::OnReadContinuation(std::unique_ptr<U2fMessage> message,
|
| void U2fHidDevice::MessageReceived(const DeviceCallback& callback,
|
| bool success,
|
| std::unique_ptr<U2fMessage> message) {
|
| + if (state_ == State::DEVICE_ERROR)
|
| + return;
|
| + timeout_callback_.Cancel();
|
| +
|
| if (!success) {
|
| state_ = State::DEVICE_ERROR;
|
| Transition(nullptr, callback);
|
| return;
|
| }
|
| +
|
| std::unique_ptr<U2fApduResponse> response = nullptr;
|
| if (message)
|
| response = U2fApduResponse::CreateFromMessage(message->GetMessagePayload());
|
| @@ -297,8 +316,23 @@ void U2fHidDevice::OnWink(const WinkCallback& callback,
|
| callback.Run();
|
| }
|
|
|
| +void U2fHidDevice::ArmTimeout(const DeviceCallback& callback) {
|
| + DCHECK(timeout_callback_.IsCancelled());
|
| + timeout_callback_.Reset(base::Bind(&U2fHidDevice::OnTimeout,
|
| + weak_factory_.GetWeakPtr(), callback));
|
| + // Setup timeout task for 3 seconds
|
| + base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
|
| + FROM_HERE, timeout_callback_.callback(),
|
| + base::TimeDelta::FromMilliseconds(3000));
|
| +}
|
| +
|
| +void U2fHidDevice::OnTimeout(const DeviceCallback& callback) {
|
| + state_ = State::DEVICE_ERROR;
|
| + Transition(nullptr, callback);
|
| +}
|
| +
|
| std::string U2fHidDevice::GetId() {
|
| - std::ostringstream id("hid:");
|
| + std::ostringstream id("hid:", std::ios::ate);
|
| id << device_info_->device_id();
|
| return id.str();
|
| }
|
|
|