8#ifndef Sawyer_Database_H
9#define Sawyer_Database_H
11#if __cplusplus >= 201103L
13#include <boost/iterator/iterator_facade.hpp>
14#include <boost/lexical_cast.hpp>
15#include <boost/numeric/conversion/cast.hpp>
17#include <Sawyer/Assert.h>
18#include <Sawyer/Map.h>
19#include <Sawyer/Optional.h>
176 class ConnectionBase;
/** Exception type thrown by the database layer.
 *
 *  Carries a human-readable message. It derives from std::runtime_error, so callers may catch either this type or the
 *  standard base class. The members are public because code throughout this header constructs and throws it directly. */
class Exception: public std::runtime_error {
public:
    /** Construct an exception whose what() text is @p what. */
    Exception(const std::string &what)
        : std::runtime_error(what) {}

    ~Exception() noexcept {}
};
198 friend class ::Sawyer::Database::Statement;
199 friend class ::Sawyer::Database::Detail::ConnectionBase;
201 std::shared_ptr<Detail::ConnectionBase> pimpl_;
208 explicit Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl);
214 ~Connection() =
default;
217 static Connection fromUri(
const std::string &uri);
222 static std::string uriDocString();
248 Statement stmt(
const std::string &sql);
251 Connection& run(
const std::string &sql);
259 Optional<T>
get(
const std::string &sql);
264 std::string driverName()
const;
271 size_t lastInsert()
const;
274 void pimpl(
const std::shared_ptr<Detail::ConnectionBase> &p) {
288 friend class ::Sawyer::Database::Detail::ConnectionBase;
290 std::shared_ptr<Detail::StatementBase> pimpl_;
307 explicit Statement(
const std::shared_ptr<Detail::StatementBase> &stmt)
312 Connection connection()
const;
324 Statement& bind(
const std::string &name,
const T &value);
331 Statement& rebind(
const std::string &name,
const T &value);
367 friend class ::Sawyer::Database::Iterator;
369 std::shared_ptr<Detail::StatementBase> stmt_;
376 explicit Row(
const std::shared_ptr<Detail::StatementBase> &stmt);
381 Optional<T>
get(
size_t columnIdx)
const;
386 size_t rowNumber()
const;
399class Iterator:
public boost::iterator_facade<Iterator, const Row, boost::forward_traversal_tag> {
400 friend class ::Sawyer::Database::Detail::StatementBase;
409 explicit Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt);
418 explicit operator bool()
const {
423 friend class boost::iterator_core_access;
424 const Row& dereference()
const;
425 bool equal(
const Iterator&)
const;
// Driver-level base class that sits behind a Connection handle. It derives from std::enable_shared_from_this, and
// Connection is declared a friend. Everything a concrete driver must supply is declared pure virtual below.
// NOTE(review): the numbers fused to the start of each line (450, 451, ...) look like listing line numbers picked up
// during extraction rather than code; access specifiers and the closing brace are also not visible here -- confirm
// against the original header.
450class ConnectionBase:
public std::enable_shared_from_this<ConnectionBase> {
451 friend class ::Sawyer::Database::Connection;
// Virtual so that driver subclasses are destroyed correctly through a base-class pointer.
457 virtual ~ConnectionBase() {}
// Pure virtual; each driver supplies its own implementation.
461 virtual void close() = 0;
// Pure virtual; takes SQL text and returns a Statement for it.
465 virtual Statement prepareStatement(
const std::string &sql) = 0;
// Pure virtual. NOTE(review): presumably the identifier of the most recently inserted row -- confirm with a driver.
469 virtual size_t lastInsert()
const = 0;
// Wraps a driver-level statement object in a Statement handle (the definition later in this file is simply
// `return Statement(detail)`).
471 Statement makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail);
// Pure virtual; returns the driver's name as a string.
473 virtual std::string driverName()
const = 0;
480 friend class ::Sawyer::Database::Detail::StatementBase;
482 std::vector<size_t> indexes;
483 bool isBound =
false;
485 void append(
size_t idx) {
486 indexes.push_back(idx);
492 friend class ::Sawyer::Database::Detail::StatementBase;
493 Optional<T> operator()(StatementBase *stmt,
size_t idx);
504class StatementBase:
public std::enable_shared_from_this<StatementBase> {
505 friend class ::Sawyer::Database::Iterator;
506 friend class ::Sawyer::Database::Row;
507 friend class ::Sawyer::Database::Statement;
508 template<
class T>
friend class ::Sawyer::Database::Detail::ColumnReader;
510 using Parameters = Container::Map<std::string, Parameter>;
512 std::shared_ptr<ConnectionBase> connection_;
513 std::weak_ptr<ConnectionBase> weakConnection_;
515 Statement::State state_ = Statement::DEAD;
516 size_t sequence_ = 0;
517 size_t rowNumber_ = 0;
520 virtual ~StatementBase() {}
523 explicit StatementBase(
const std::shared_ptr<ConnectionBase> &connection)
524 : weakConnection_(connection) {
525 ASSERT_not_null(connection);
532 std::pair<std::string, size_t> parseParameters(
const std::string &highSql) {
535 bool inString =
false;
536 size_t nLowParams = 0;
537 state(Statement::READY);
538 for (
size_t i = 0; i < highSql.size(); ++i) {
539 if (
'\'' == highSql[i]) {
540 inString = !inString;
541 lowSql += highSql[i];
542 }
else if (
'?' == highSql[i] && !inString) {
544 std::string paramName;
545 while (i+1 < highSql.size() && (::isalnum(highSql[i+1]) ||
'_' == highSql[i+1]))
546 paramName += highSql[++i];
547 if (paramName.empty())
548 throw Exception(
"invalid parameter name at character position " + boost::lexical_cast<std::string>(i));
549 Parameter ¶m = params_.insertMaybeDefault(paramName);
550 param.append(nLowParams++);
551 state(Statement::UNBOUND);
553 lowSql += highSql[i];
557 state(Statement::DEAD);
558 throw Exception(
"mismatched quotes in SQL statement");
560 return std::make_pair(lowSql, nLowParams);
564 void invalidateIteratorsAndRows() {
569 size_t sequence()
const {
// Try to promote the weak connection reference to a strong one and keep the result in connection_.
// Returns true if a connection was obtained, false if the connection object no longer exists.
575 bool lockConnection() {
576 return (connection_ = weakConnection_.lock()) !=
nullptr;
581 void unlockConnection() {
// Reports whether this statement currently holds a strong reference to its connection (connection_ is non-null).
587 bool isConnectionLocked()
const {
588 return connection_ !=
nullptr;
// Returns a strong pointer obtained from the weak connection reference; the result is null when the connection
// object no longer exists. Unlike lockConnection(), the pointer is not stored in connection_.
592 std::shared_ptr<ConnectionBase> connection()
const {
593 return weakConnection_.lock();
597 Statement::State state()
const {
604 void state(Statement::State newState) {
606 case Statement::DEAD:
607 case Statement::FINISHED:
608 case Statement::UNBOUND:
609 case Statement::READY:
610 invalidateIteratorsAndRows();
613 case Statement::EXECUTING:
614 ASSERT_require(isConnectionLocked());
621 bool hasUnboundParameters()
const {
622 ASSERT_forbid(state() == Statement::DEAD);
623 for (
const Parameter ¶m: params_.values()) {
632 virtual void unbindAllParams() {
633 ASSERT_forbid(state() == Statement::DEAD);
634 for (Parameter ¶m: params_.values())
635 param.isBound = false;
636 state(params_.isEmpty() ? Statement::READY : Statement::UNBOUND);
641 virtual void reset(
bool doUnbind) {
642 ASSERT_forbid(state() == Statement::DEAD);
643 invalidateIteratorsAndRows();
647 state(hasUnboundParameters() ? Statement::UNBOUND : Statement::READY);
654 void bind(
const std::string &name,
const T &value,
bool isRebind) {
656 throw Exception(
"connection is closed");
658 case Statement::DEAD:
659 throw Exception(
"statement is dead");
660 case Statement::FINISHED:
661 case Statement::EXECUTING:
664 case Statement::READY:
665 case Statement::UNBOUND: {
666 if (!params_.exists(name))
667 throw Exception(
"no such parameter \"" + name +
"\" in statement");
668 Parameter ¶m = params_[name];
669 bool wasUnbound = !param.isBound;
670 for (
size_t idx: param.indexes) {
673 }
catch (
const Exception &e) {
674 if (param.indexes.size() > 1)
675 state(Statement::DEAD);
679 param.isBound =
true;
681 if (wasUnbound && !hasUnboundParameters())
682 state(Statement::READY);
692 bind(name, *value, isRebind);
694 bind(name, Nothing(), isRebind);
699 virtual void bindLow(
size_t idx,
int value) = 0;
700 virtual void bindLow(
size_t idx, int64_t value) = 0;
701 virtual void bindLow(
size_t idx,
size_t value) = 0;
702 virtual void bindLow(
size_t idx,
double value) = 0;
703 virtual void bindLow(
size_t idx,
const std::string &value) = 0;
704 virtual void bindLow(
size_t idx,
const char *cstring) = 0;
705 virtual void bindLow(
size_t idx, Nothing) = 0;
706 virtual void bindLow(
size_t idx,
const std::vector<uint8_t> &data) = 0;
708 Iterator makeIterator() {
709 return Iterator(shared_from_this());
716 throw Exception(
"connection is closed");
718 case Statement::DEAD:
719 throw Exception(
"statement is dead");
720 case Statement::UNBOUND: {
722 for (Parameters::Node ¶m: params_.nodes()) {
723 if (!param.value().isBound)
724 s += (s.empty() ?
"" :
", ") + param.key();
726 ASSERT_forbid(s.empty());
727 throw Exception(
"unbound parameters: " + s);
729 case Statement::FINISHED:
730 case Statement::EXECUTING:
733 case Statement::READY: {
734 if (!lockConnection())
735 throw Exception(
"connection has been closed");
736 state(Statement::EXECUTING);
738 Iterator iter = beginLow();
743 ASSERT_not_reachable(
"invalid state");
748 virtual Iterator beginLow() = 0;
753 throw Exception(
"connection is closed");
754 ASSERT_require(state() == Statement::EXECUTING);
755 invalidateIteratorsAndRows();
761 size_t rowNumber()
const {
767 virtual Iterator nextLow() = 0;
771 Optional<T>
get(
size_t columnIdx) {
773 throw Exception(
"connection is closed");
774 ASSERT_require(state() == Statement::EXECUTING);
775 if (columnIdx >= nColumns())
776 throw Exception(
"column index " + boost::lexical_cast<std::string>(columnIdx) +
" is out of range");
777 return ColumnReader<T>()(
this, columnIdx);
781 virtual size_t nColumns()
const = 0;
784 virtual Optional<std::string> getString(
size_t idx) = 0;
785 virtual Optional<std::vector<std::uint8_t>> getBlob(
size_t idx) = 0;
790ColumnReader<T>::operator()(StatementBase *stmt,
size_t idx) {
792 if (!stmt->getString(idx).assignTo(str))
794 return boost::lexical_cast<T>(str);
// Column reader for byte-vector columns: forwards straight to the statement's getBlob() for column idx and returns
// its result unchanged.
798inline Optional<std::vector<uint8_t>>
799ColumnReader<std::vector<uint8_t>>::operator()(StatementBase *stmt,
size_t idx) {
800 return stmt->getBlob(idx);
// Wraps the driver-level statement object in a Statement handle by passing it to Statement's constructor.
804ConnectionBase::makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail) {
805 return Statement(detail);
815inline Connection::Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl)
819Connection::isOpen()
const {
820 return pimpl_ !=
nullptr;
830Connection::driverName()
const {
832 return pimpl_->driverName();
839Connection::stmt(
const std::string &sql) {
841 return pimpl_->prepareStatement(sql);
843 throw Exception(
"no active database connection");
848Connection::run(
const std::string &sql) {
855Connection::get(
const std::string &sql) {
856 for (
auto row: stmt(sql))
857 return row.
get<T>(0);
862Connection::lastInsert()
const {
864 return pimpl_->lastInsert();
866 throw Exception(
"no active database connection");
875Statement::connection()
const {
877 return Connection(pimpl_->connection());
885Statement::bind(
const std::string &name,
const T &value) {
887 pimpl_->bind(name, value,
false);
889 throw Exception(
"no active database connection");
896Statement::rebind(
const std::string &name,
const T &value) {
898 pimpl_->bind(name, value,
true);
900 throw Exception(
"no active database connection");
908 return pimpl_->begin();
910 throw Exception(
"no active database connection");
928 Iterator row = begin();
930 throw Exception(
"query did not return a row");
931 return row->get<T>(0);
939Iterator::Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt)
943Iterator::dereference()
const {
945 throw Exception(
"dereferencing the end iterator");
946 if (row_.sequence_ != row_.stmt_->sequence())
947 throw Exception(
"iterator has been invalidated");
// Equality test used by boost::iterator_facade: two iterators are equal when their rows point to the same statement
// object and carry the same sequence number.
952Iterator::equal(
const Iterator &other)
const {
953 return row_.stmt_ == other.row_.stmt_ && row_.sequence_ == other.row_.sequence_;
957Iterator::increment() {
959 throw Exception(
"incrementing the end iterator");
960 *
this = row_.stmt_->next();
// Builds a row for the given statement. The statement's current sequence number is captured at construction time
// (0 when stmt is null) so that later accesses can detect that the statement has moved on.
968Row::Row(
const std::shared_ptr<Detail::StatementBase> &stmt)
969 : stmt_(stmt), sequence_(stmt ? stmt->sequence() : 0) {}
// Returns the value of column columnIdx by delegating to the statement's get<T>().
// Asserts that the row has a statement, and throws Exception("row has been invalidated") when the sequence number
// captured at construction no longer matches the statement's current sequence number.
973Row::get(
size_t columnIdx)
const {
974 ASSERT_not_null(stmt_);
975 if (sequence_ != stmt_->sequence())
976 throw Exception(
"row has been invalidated");
977 return stmt_->get<T>(columnIdx);
// Returns the statement's current row number.
// Performs the same validity checks as get(): the statement pointer must be non-null (asserted) and the captured
// sequence number must still match the statement's, otherwise Exception("row has been invalidated") is thrown.
981Row::rowNumber()
const {
982 ASSERT_not_null(stmt_);
983 if (sequence_ != stmt_->sequence())
984 throw Exception(
"row has been invalidated");
985 return stmt_->rowNumber();
Holds a value or nothing.
bool get(const Word *words, size_t idx)
Return a single bit.
bool increment(Word *vec1, const BitRange &range1)
Increment.