OLD | NEW |
1 // Copyright 2016 The Chromium Authors. All rights reserved. | 1 // Copyright 2016 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 package crimsondb | 5 package crimsondb |
6 | 6 |
7 import ( | 7 import ( |
8 "bytes" | 8 "bytes" |
9 "database/sql" | 9 "database/sql" |
10 "encoding/hex" | 10 "encoding/hex" |
(...skipping 257 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
268 | 268 |
269 var ipRanges []IPRange | 269 var ipRanges []IPRange |
270 ipRanges, err = scanIPRanges(ctx, rows) | 270 ipRanges, err = scanIPRanges(ctx, rows) |
271 if err != nil { | 271 if err != nil { |
272 logging.Errorf(ctx, "%s", err) | 272 logging.Errorf(ctx, "%s", err) |
273 return nil, err | 273 return nil, err |
274 } | 274 } |
275 return ipRanges, nil | 275 return ipRanges, nil |
276 } | 276 } |
277 | 277 |
| 278 // scanHosts is a low-level function to scan sql results. |
| 279 // Rows must contain site, hostname, mac_addr, ip, boot_class in that order. |
| 280 func scanHosts(ctx context.Context, rows *sql.Rows) (*crimson.HostList, error) { |
| 281 hostList := crimson.HostList{} |
| 282 |
| 283 for rows.Next() { |
| 284 var ipString, macString string |
| 285 var hw net.HardwareAddr |
| 286 var ip net.IP |
| 287 var bootClass sql.NullString |
| 288 |
| 289 host := crimson.Host{} |
| 290 err := rows.Scan(&host.Site, &host.Hostname, &macString, &ipStri
ng, |
| 291 &bootClass) |
| 292 if bootClass.Valid { |
| 293 host.BootClass = bootClass.String |
| 294 } |
| 295 if err != nil { // Users can't trigger that. |
| 296 logging.Errorf(ctx, "%s", err) |
| 297 return nil, err |
| 298 } |
| 299 if macString != "" { |
| 300 hw, err = HexStringToHardwareAddr(macString) |
| 301 if err != nil { |
| 302 return nil, err |
| 303 } |
| 304 host.MacAddr = hw.String() |
| 305 } |
| 306 |
| 307 if ipString != "" { |
| 308 ip, err = HexStringToIP(ipString) |
| 309 if err != nil { |
| 310 return nil, err |
| 311 } |
| 312 host.Ip = ip.String() |
| 313 } |
| 314 hostList.Hosts = append(hostList.Hosts, &host) |
| 315 } |
| 316 err := rows.Err() |
| 317 if err != nil { |
| 318 logging.Errorf(ctx, "%s", err) |
| 319 return nil, err |
| 320 } |
| 321 return &hostList, nil |
| 322 } |
| 323 |
| 324 // InsertHost adds new hosts in the corresponding table. |
| 325 func InsertHost(ctx context.Context, req *crimson.HostList) (err error) { |
| 326 db := DB(ctx) |
| 327 |
| 328 if len(req.Hosts) == 0 { |
| 329 logging.Errorf(ctx, "Received empty list of hosts to create.") |
| 330 return UserErrorf(InvalidArgument, |
| 331 "Received empty list of hosts to create.") |
| 332 } |
| 333 |
| 334 statement := bytes.Buffer{} |
| 335 params := []interface{}{} |
| 336 |
| 337 statement.WriteString("INSERT INTO host " + |
| 338 "(site, hostname, mac_addr, ip, boot_class) VALUES ") |
| 339 delimiter := "" |
| 340 |
| 341 // Check that all required fields have been provided. |
| 342 // TODO(pgervais): autogenerate missing values instead. |
| 343 for i, host := range req.Hosts { |
| 344 if host.Site == "" { |
| 345 err = UserErrorf(InvalidArgument, |
| 346 "Received empty host in entry #%s", i+1) |
| 347 return |
| 348 } |
| 349 if host.MacAddr == "" { |
| 350 err = UserErrorf(InvalidArgument, |
| 351 "Received empty MAC address in entry #%s", i+1) |
| 352 return |
| 353 } |
| 354 if host.Ip == "" { |
| 355 err = UserErrorf(InvalidArgument, |
| 356 "Received empty IP address in entry #%s", i+1) |
| 357 return |
| 358 } |
| 359 if host.Hostname == "" { |
| 360 err = UserErrorf(InvalidArgument, |
| 361 "Received empty hostname in entry #%s", i+1) |
| 362 return |
| 363 } |
| 364 |
| 365 // Compose query |
| 366 var ip, macAddr string |
| 367 statement.WriteString(delimiter) |
| 368 delimiter = ", \n" |
| 369 statement.WriteString("(?, ?, ?, ?, ?)") |
| 370 |
| 371 ip, err = IPStringToHexString(host.Ip) |
| 372 if err != nil { |
| 373 return |
| 374 } |
| 375 |
| 376 macAddr, err = MacAddrStringToHexString(host.MacAddr) |
| 377 if err != nil { |
| 378 return |
| 379 } |
| 380 |
| 381 if host.BootClass == "" { |
| 382 params = append( |
| 383 params, |
| 384 host.Site, host.Hostname, macAddr, ip, nil) |
| 385 } else { |
| 386 params = append( |
| 387 params, |
| 388 host.Site, host.Hostname, macAddr, ip, host.Boot
Class) |
| 389 } |
| 390 } |
| 391 |
| 392 _, err = db.Exec(statement.String(), params...) |
| 393 if err != nil { |
| 394 logging.Errorf(ctx, "Insertion of new hosts failed. %s", err) |
| 395 return |
| 396 } |
| 397 |
| 398 return |
| 399 } |
| 400 |
| 401 func SelectHost(ctx context.Context, req *crimson.HostQuery) (*crimson.HostList,
error) { |
| 402 var err error |
| 403 |
| 404 db := DB(ctx) |
| 405 delimiter := "" |
| 406 |
| 407 statement := bytes.Buffer{} |
| 408 params := []interface{}{} |
| 409 |
| 410 statement.WriteString("SELECT site, hostname, mac_addr, ip, boot_class F
ROM host") |
| 411 delimiter = "\nWHERE " |
| 412 |
| 413 if req.Site != "" { |
| 414 statement.WriteString(delimiter) |
| 415 delimiter = "\nAND " |
| 416 statement.WriteString("site=?") |
| 417 params = append(params, req.Site) |
| 418 } |
| 419 |
| 420 if req.Hostname != "" { |
| 421 statement.WriteString(delimiter) |
| 422 delimiter = "\nAND " |
| 423 statement.WriteString("hostname=?") |
| 424 params = append(params, req.Hostname) |
| 425 } |
| 426 |
| 427 if req.MacAddr != "" { |
| 428 statement.WriteString(delimiter) |
| 429 delimiter = "\nAND " |
| 430 hw, err := MacAddrStringToHexString(req.MacAddr) |
| 431 if err != nil { |
| 432 return nil, UserErrorf( |
| 433 InvalidArgument, |
| 434 "parsing of Mac address failed: %s", req.MacAddr
) |
| 435 } |
| 436 statement.WriteString("mac_addr=?") |
| 437 params = append(params, hw) |
| 438 } |
| 439 |
| 440 if req.Ip != "" { |
| 441 statement.WriteString(delimiter) |
| 442 delimiter = "\nAND " |
| 443 ip, err := IPStringToHexString(req.Ip) |
| 444 if err != nil { |
| 445 return nil, UserErrorf( |
| 446 InvalidArgument, |
| 447 "parsing of IP address failed: %s", req.Ip) |
| 448 } |
| 449 statement.WriteString("ip=?") |
| 450 params = append(params, ip) |
| 451 } |
| 452 |
| 453 if req.BootClass != "" { |
| 454 statement.WriteString(delimiter) |
| 455 delimiter = "\nAND " |
| 456 statement.WriteString("boot_class=?") |
| 457 params = append(params, req.BootClass) |
| 458 } |
| 459 |
| 460 if req.Limit > 0 { |
| 461 statement.WriteString("\nLIMIT ?") |
| 462 params = append(params, req.Limit) |
| 463 } |
| 464 |
| 465 sqlRows, err := db.Query(statement.String(), params...) |
| 466 |
| 467 if err != nil { |
| 468 logging.Errorf(ctx, "%s", err) |
| 469 return nil, err |
| 470 } |
| 471 defer sqlRows.Close() |
| 472 |
| 473 var rows *crimson.HostList |
| 474 rows, err = scanHosts(ctx, sqlRows) |
| 475 if err != nil { |
| 476 logging.Errorf(ctx, "%s", err) |
| 477 return nil, err |
| 478 } |
| 479 return rows, nil |
| 480 } |
| 481 |
278 // UseDB stores a db handle into a context. | 482 // UseDB stores a db handle into a context. |
279 func UseDB(ctx context.Context, db *sql.DB) context.Context { | 483 func UseDB(ctx context.Context, db *sql.DB) context.Context { |
280 return context.WithValue(ctx, "dbHandle", db) | 484 return context.WithValue(ctx, "dbHandle", db) |
281 } | 485 } |
282 | 486 |
283 // DB gets the current db handle from the context. | 487 // DB gets the current db handle from the context. |
284 func DB(ctx context.Context) *sql.DB { | 488 func DB(ctx context.Context) *sql.DB { |
285 return ctx.Value("dbHandle").(*sql.DB) | 489 return ctx.Value("dbHandle").(*sql.DB) |
286 } | 490 } |
287 | 491 |
288 // GetDBHandle returns a handle to the Cloud SQL instance used by this deploymen
t. | 492 // GetDBHandle returns a handle to the Cloud SQL instance used by this deploymen
t. |
289 func GetDBHandle() (*sql.DB, error) { | 493 func GetDBHandle() (*sql.DB, error) { |
290 // TODO(pgervais): do not hard-code the name of the database. | 494 // TODO(pgervais): do not hard-code the name of the database. |
291 return sql.Open("mysql", "root@cloudsql(crimson-staging:crimson-staging)
/crimson") | 495 return sql.Open("mysql", "root@cloudsql(crimson-staging:crimson-staging)
/crimson") |
292 } | 496 } |
OLD | NEW |