Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqlalchemy_commithooks/commit_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from sqlalchemy import event
from sqlalchemy.orm import object_session
from sqlalchemy.orm.session import SessionTransaction
from sqlalchemy.orm.session import SessionTransaction, SessionTransactionOrigin


def _build_add_func(time, action):
Expand Down Expand Up @@ -245,6 +245,6 @@ def _tmp_transaction(session: SessionMixin):
Fix it by providing a temporary transaction.
"""
current_transaction = session.transaction
session.transaction = SessionTransaction(session)
session.transaction = SessionTransaction(session, SessionTransactionOrigin.BEGIN)
yield session
session.transaction = current_transaction
20 changes: 11 additions & 9 deletions sqlalchemy_commithooks/commit_mixin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from mock import Mock
from sqlalchemy import Column, Integer
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import declarative_base, sessionmaker

from . import commit_mixin
from .commit_mixin import _build_add_func, Session
Expand Down Expand Up @@ -102,7 +101,6 @@ def test_multiple_inheritance(self):
assert 'before_commit_from_delete' in self.Multiple._overridden_hooks()



@contextmanager
def _tmp_transaction_patch(session: Session):
"""needed for _do_after_commit"""
Expand Down Expand Up @@ -213,9 +211,9 @@ def after_commit_from_insert(self):
assert data.after_commit_counter == 1


Base = declarative_base()

class TestQueriesAtCommit:
Base = declarative_base()

class Foo(Base):
__tablename__ = "foo"
id = Column(Integer, primary_key=True)
Expand Down Expand Up @@ -258,9 +256,13 @@ def test_after_commit(self):

class TestCommitMixinHooks:
"""verify correct behavior with the hooks we have selected"""

Base = declarative_base()

class Data(Base, commit_mixin.CommitMixin):
__tablename__ = "data"
id = Column(Integer, primary_key=True)
value = Column(Integer, unique=True)

def __init__(self, *args, **kwargs):
self.before_commit_counter = 0
Expand Down Expand Up @@ -322,12 +324,12 @@ def get_session(self):

def test_nested_bad_flush(self):
session = self.get_session()
outer_data = self.Data(id=1)
outer_data = self.Data(value=1)
session.add(outer_data)

session.begin_nested()
with pytest.raises(Exception):
bad_flush_data = self.Data(id=1)
bad_flush_data = self.Data(value=1)
session.add(bad_flush_data)
session.commit()
# except
Expand Down Expand Up @@ -381,7 +383,7 @@ def test_multiple_bad_commits(self, monkeypatch):

data1 = self.Data()
session.add(data1)
with pytest.raises(AttributeError):
with pytest.raises(NotImplementedError):
monkeypatch.delattr('sqlalchemy.engine.base.Transaction.commit')
session.commit()
monkeypatch.undo()
Expand All @@ -390,7 +392,7 @@ def test_multiple_bad_commits(self, monkeypatch):

data2 = self.Data()
session.add(data2)
with pytest.raises(AttributeError):
with pytest.raises(NotImplementedError):
monkeypatch.delattr('sqlalchemy.engine.base.Transaction.commit')
session.commit()
monkeypatch.undo()
Expand Down