8 #ifndef Sawyer_Database_H
9 #define Sawyer_Database_H
11 #if __cplusplus >= 201103L
13 #include <boost/iterator/iterator_facade.hpp>
14 #include <boost/lexical_cast.hpp>
15 #include <boost/numeric/conversion/cast.hpp>
17 #include <Sawyer/Assert.h>
18 #include <Sawyer/Map.h>
19 #include <Sawyer/Optional.h>
175 class ConnectionBase;
179 class Exception:
public std::runtime_error {
/** Construct an exception whose message is @p what; the text is passed straight to std::runtime_error and is therefore
 *  what the inherited what() reports. */
181 Exception(
const std::string &what)
182 :
std::runtime_error(what) {}
/** Destructor. Declared noexcept (and empty): it performs no work of its own beyond destroying the std::runtime_error base. */
184 ~Exception() noexcept {}
197 friend class ::Sawyer::Database::Statement;
198 friend class ::Sawyer::Database::Detail::ConnectionBase;
200 std::shared_ptr<Detail::ConnectionBase> pimpl_;
207 explicit Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl);
/** Destructor. Defaulted: destroying the handle destroys its members, including the shared pimpl_ pointer, which drops this
 *  handle's reference to the driver-level connection object. */
213 ~Connection() =
default;
239 Statement stmt(
const std::string &sql);
242 Connection& run(
const std::string &sql);
250 Optional<T>
get(
const std::string &sql);
255 std::string driverName()
const;
262 size_t lastInsert()
const;
265 void pimpl(
const std::shared_ptr<Detail::ConnectionBase> &p) {
279 friend class ::Sawyer::Database::Detail::ConnectionBase;
281 std::shared_ptr<Detail::StatementBase> pimpl_;
298 explicit Statement(
const std::shared_ptr<Detail::StatementBase> &stmt)
303 Connection connection()
const;
315 Statement& bind(
const std::string &name,
const T &value);
322 Statement& rebind(
const std::string &name,
const T &value);
358 friend class ::Sawyer::Database::Iterator;
360 std::shared_ptr<Detail::StatementBase> stmt_;
367 explicit Row(
const std::shared_ptr<Detail::StatementBase> &stmt);
372 Optional<T>
get(
size_t columnIdx)
const;
377 size_t rowNumber()
const;
390 class Iterator:
public boost::iterator_facade<Iterator, const Row, boost::forward_traversal_tag> {
391 friend class ::Sawyer::Database::Detail::StatementBase;
400 explicit Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt);
409 explicit operator bool()
const {
414 friend class boost::iterator_core_access;
415 const Row& dereference()
const;
416 bool equal(
const Iterator&)
const;
441 class ConnectionBase:
public std::enable_shared_from_this<ConnectionBase> {
442 friend class ::Sawyer::Database::Connection;
/** Virtual destructor. ConnectionBase declares pure virtual members (close, prepareStatement, lastInsert, driverName) and is
 *  used through std::shared_ptr<ConnectionBase>, so the destructor is virtual so that derived driver objects are destroyed
 *  correctly via a base pointer. */
448 virtual ~ConnectionBase() {}
452 virtual void close() = 0;
456 virtual Statement prepareStatement(
const std::string &sql) = 0;
460 virtual size_t lastInsert()
const = 0;
462 Statement makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail);
464 virtual std::string driverName()
const = 0;
471 friend class ::Sawyer::Database::Detail::StatementBase;
473 std::vector<size_t> indexes;
474 bool isBound =
false;
476 void append(
size_t idx) {
477 indexes.push_back(idx);
483 friend class ::Sawyer::Database::Detail::StatementBase;
484 Optional<T> operator()(StatementBase *stmt,
size_t idx);
495 class StatementBase:
public std::enable_shared_from_this<StatementBase> {
496 friend class ::Sawyer::Database::Iterator;
497 friend class ::Sawyer::Database::Row;
498 friend class ::Sawyer::Database::Statement;
499 template<
class T>
friend class ::Sawyer::Database::Detail::ColumnReader;
501 using Parameters = Container::Map<std::string, Parameter>;
503 std::shared_ptr<ConnectionBase> connection_;
504 std::weak_ptr<ConnectionBase> weakConnection_;
506 Statement::State state_ = Statement::DEAD;
507 size_t sequence_ = 0;
508 size_t rowNumber_ = 0;
/** Virtual destructor. StatementBase declares pure virtual driver hooks (bindLow, beginLow, nextLow, nColumns, getString,
 *  getBlob) and is held via std::shared_ptr<StatementBase>, so the destructor is virtual so that derived statement objects
 *  are destroyed correctly through a base pointer. */
511 virtual ~StatementBase() {}
514 explicit StatementBase(
const std::shared_ptr<ConnectionBase> &connection)
515 : weakConnection_(connection) {
516 ASSERT_not_null(connection);
523 std::pair<std::string, size_t> parseParameters(
const std::string &highSql) {
526 bool inString =
false;
527 size_t nLowParams = 0;
528 state(Statement::READY);
529 for (
size_t i = 0; i < highSql.size(); ++i) {
530 if (
'\'' == highSql[i]) {
531 inString = !inString;
532 lowSql += highSql[i];
533 }
else if (
'?' == highSql[i] && !inString) {
535 std::string paramName;
536 while (i+1 < highSql.size() && (::isalnum(highSql[i+1]) ||
'_' == highSql[i+1]))
537 paramName += highSql[++i];
538 if (paramName.empty())
539 throw Exception(
"invalid parameter name at character position " + boost::lexical_cast<std::string>(i));
540 Parameter ¶m = params_.insertMaybeDefault(paramName);
541 param.append(nLowParams++);
542 state(Statement::UNBOUND);
544 lowSql += highSql[i];
548 state(Statement::DEAD);
549 throw Exception(
"mismatched quotes in SQL statement");
551 return std::make_pair(lowSql, nLowParams);
555 void invalidateIteratorsAndRows() {
560 size_t sequence()
const {
566 bool lockConnection() {
567 return (connection_ = weakConnection_.lock()) !=
nullptr;
572 void unlockConnection() {
578 bool isConnectionLocked()
const {
579 return connection_ !=
nullptr;
583 std::shared_ptr<ConnectionBase> connection()
const {
584 return weakConnection_.lock();
588 Statement::State state()
const {
595 void state(Statement::State newState) {
597 case Statement::DEAD:
598 case Statement::FINISHED:
599 case Statement::UNBOUND:
600 case Statement::READY:
601 invalidateIteratorsAndRows();
604 case Statement::EXECUTING:
605 ASSERT_require(isConnectionLocked());
612 bool hasUnboundParameters()
const {
613 ASSERT_forbid(state() == Statement::DEAD);
614 for (
const Parameter ¶m: params_.values()) {
623 virtual void unbindAllParams() {
624 ASSERT_forbid(state() == Statement::DEAD);
625 for (Parameter ¶m: params_.values())
626 param.isBound =
false;
627 state(params_.isEmpty() ? Statement::READY : Statement::UNBOUND);
632 virtual void reset(
bool doUnbind) {
633 ASSERT_forbid(state() == Statement::DEAD);
634 invalidateIteratorsAndRows();
638 state(hasUnboundParameters() ? Statement::UNBOUND : Statement::READY);
645 void bind(
const std::string &name,
const T &value,
bool isRebind) {
647 throw Exception(
"connection is closed");
649 case Statement::DEAD:
650 throw Exception(
"statement is dead");
651 case Statement::FINISHED:
652 case Statement::EXECUTING:
655 case Statement::READY:
656 case Statement::UNBOUND: {
657 if (!params_.exists(name))
658 throw Exception(
"no such parameter \"" + name +
"\" in statement");
659 Parameter ¶m = params_[name];
660 bool wasUnbound = !param.isBound;
661 for (
size_t idx: param.indexes) {
664 }
catch (
const Exception &e) {
665 if (param.indexes.size() > 1)
666 state(Statement::DEAD);
670 param.isBound =
true;
672 if (wasUnbound && !hasUnboundParameters())
673 state(Statement::READY);
683 bind(name, *value, isRebind);
685 bind(name, Nothing(), isRebind);
690 virtual void bindLow(
size_t idx,
int value) = 0;
691 virtual void bindLow(
size_t idx, int64_t value) = 0;
692 virtual void bindLow(
size_t idx,
size_t value) = 0;
693 virtual void bindLow(
size_t idx,
double value) = 0;
694 virtual void bindLow(
size_t idx,
const std::string &value) = 0;
695 virtual void bindLow(
size_t idx,
const char *cstring) = 0;
696 virtual void bindLow(
size_t idx, Nothing) = 0;
697 virtual void bindLow(
size_t idx,
const std::vector<uint8_t> &data) = 0;
699 Iterator makeIterator() {
700 return Iterator(shared_from_this());
707 throw Exception(
"connection is closed");
709 case Statement::DEAD:
710 throw Exception(
"statement is dead");
711 case Statement::UNBOUND: {
713 for (Parameters::Node ¶m: params_.nodes()) {
714 if (!param.value().isBound)
715 s += (s.empty() ?
"" :
", ") + param.key();
717 ASSERT_forbid(s.empty());
718 throw Exception(
"unbound parameters: " + s);
720 case Statement::FINISHED:
721 case Statement::EXECUTING:
724 case Statement::READY: {
725 if (!lockConnection())
726 throw Exception(
"connection has been closed");
727 state(Statement::EXECUTING);
729 Iterator iter = beginLow();
734 ASSERT_not_reachable(
"invalid state");
739 virtual Iterator beginLow() = 0;
744 throw Exception(
"connection is closed");
745 ASSERT_require(state() == Statement::EXECUTING);
746 invalidateIteratorsAndRows();
752 size_t rowNumber()
const {
758 virtual Iterator nextLow() = 0;
762 Optional<T>
get(
size_t columnIdx) {
764 throw Exception(
"connection is closed");
765 ASSERT_require(state() == Statement::EXECUTING);
766 if (columnIdx >= nColumns())
767 throw Exception(
"column index " + boost::lexical_cast<std::string>(columnIdx) +
" is out of range");
768 return ColumnReader<T>()(
this, columnIdx);
772 virtual size_t nColumns()
const = 0;
775 virtual Optional<std::string> getString(
size_t idx) = 0;
776 virtual Optional<std::vector<std::uint8_t>> getBlob(
size_t idx) = 0;
781 ColumnReader<T>::operator()(StatementBase *stmt,
size_t idx) {
783 if (!stmt->getString(idx).assignTo(str))
785 return boost::lexical_cast<T>(str);
789 inline Optional<std::vector<uint8_t>>
790 ColumnReader<std::vector<uint8_t>>::operator()(StatementBase *stmt,
size_t idx) {
791 return stmt->getBlob(idx);
795 ConnectionBase::makeStatement(
const std::shared_ptr<Detail::StatementBase> &detail) {
796 return Statement(detail);
806 inline Connection::Connection(
const std::shared_ptr<Detail::ConnectionBase> &pimpl)
810 Connection::isOpen()
const {
811 return pimpl_ !=
nullptr;
815 Connection::close() {
821 Connection::driverName()
const {
823 return pimpl_->driverName();
830 Connection::stmt(
const std::string &sql) {
832 return pimpl_->prepareStatement(sql);
834 throw Exception(
"no active database connection");
839 Connection::run(
const std::string &sql) {
846 Connection::get(
const std::string &sql) {
847 for (
auto row: stmt(sql))
848 return row.get<T>(0);
853 Connection::lastInsert()
const {
855 return pimpl_->lastInsert();
857 throw Exception(
"no active database connection");
866 Statement::connection()
const {
868 return Connection(pimpl_->connection());
876 Statement::bind(
const std::string &name,
const T &value) {
878 pimpl_->bind(name, value,
false);
880 throw Exception(
"no active database connection");
887 Statement::rebind(
const std::string &name,
const T &value) {
889 pimpl_->bind(name, value,
true);
891 throw Exception(
"no active database connection");
899 return pimpl_->begin();
901 throw Exception(
"no active database connection");
919 Iterator row = begin();
921 throw Exception(
"query did not return a row");
922 return row->get<T>(0);
930 Iterator::Iterator(
const std::shared_ptr<Detail::StatementBase> &stmt)
934 Iterator::dereference()
const {
936 throw Exception(
"dereferencing the end iterator");
937 if (row_.sequence_ != row_.stmt_->sequence())
938 throw Exception(
"iterator has been invalidated");
943 Iterator::equal(
const Iterator &other)
const {
944 return row_.stmt_ == other.row_.stmt_ && row_.sequence_ == other.row_.sequence_;
948 Iterator::increment() {
950 throw Exception(
"incrementing the end iterator");
951 *
this = row_.stmt_->next();
/** Construct a row that refers to statement @p stmt.
 *
 *  The statement's current sequence number is recorded (0 when @p stmt is null). Row::get and Row::rowNumber compare this
 *  snapshot against the statement's sequence number and throw "row has been invalidated" when they no longer match. */
959 Row::Row(
const std::shared_ptr<Detail::StatementBase> &stmt)
960 : stmt_(stmt), sequence_(stmt ? stmt->sequence() : 0) {}
964 Row::get(
size_t columnIdx)
const {
965 ASSERT_not_null(stmt_);
966 if (sequence_ != stmt_->sequence())
967 throw Exception(
"row has been invalidated");
968 return stmt_->get<T>(columnIdx);
972 Row::rowNumber()
const {
973 ASSERT_not_null(stmt_);
974 if (sequence_ != stmt_->sequence())
975 throw Exception(
"row has been invalidated");
976 return stmt_->rowNumber();
Holds a value or nothing.
bool increment(Word *vec1, const BitRange &range1)
Increment.
Name space for the entire library.