Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 14 additions & 24 deletions ags/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,17 @@ def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback):
if exc_value and len(exc_value.args) == 1:
(arg,) = exc_value.args
if not isinstance(arg, ErrorContext):
arg = ErrorContext(arg)
exc_value.args = (arg,)
arg.prepend(self.context)


class ErrorContext:
def __init__(self, message):
self.message = message
self.context = ""

def prepend(self, context):
self.context = context + self.context

def __str__(self):
return f"in {self.context}: {self.message}"

def __repr__(self):
return f"in {self.context}: {self.message!r}"
if not exc_value:
return
if not hasattr(exc_value, '__notes__'):
notes = []
exc_value.__notes__ = notes
else:
notes = exc_value.__notes__
note = "In: " + self.context
if notes and notes[-1].startswith("In: "):
note += notes.pop()[4:]
notes.append(note)


def mismatch(expect, got):
Expand Down Expand Up @@ -150,12 +140,12 @@ def mapping_for(T) -> Mapping:
with context(f".{param.name}"):
if param.kind not in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):
raise TypeError("positional-only arguments are not supported")
if param.annotation != param.empty:
if param.annotation is not param.empty:
mapping = mapping_for(param.annotation)
if param.default != param.empty:
if param.default is not param.empty:
with context("(default)"):
mapping.lower(param.default, inject_none)
elif param.default != param.empty:
elif param.default is not param.empty:
mapping = mapping_for(type(param.default))
else:
raise TypeError(f"cannot establish type for parameter {param.name}")
Expand Down
34 changes: 25 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import date, time, datetime
from doctest import DocFileSuite
import sys
import traceback

from ags import _mapping

Expand Down Expand Up @@ -86,10 +87,15 @@ class A:
i: int = 10
s: str = 20

with self.assertRaisesRegex(
ValueError, r"in .s\(default\): expects str, got int"
):
with self.assertRaises(ValueError) as cm:
_mapping.mapping_for(A)
s = traceback.format_exception(cm.exception)
self.assertEqual(s, [
'ValueError: expects str, got int\n',
'In: .s(default)\n',
] if sys.version_info >= (3, 11) else [
'ValueError: expects str, got int\n',
])

def test_boundargs(self):
def f(i: int, s: str):
Expand All @@ -104,10 +110,15 @@ def f(i: int = 10, s: str = 20):
pass

sig = signature(f)
with self.assertRaisesRegex(
ValueError, r"in .s\(default\): expects str, got int"
):
with self.assertRaises(ValueError) as cm:
_mapping.mapping_for(sig)
s = traceback.format_exception(cm.exception)
self.assertEqual(s, [
'ValueError: expects str, got int\n',
'In: .s(default)\n',
] if sys.version_info >= (3, 11) else [
'ValueError: expects str, got int\n',
])

def test_union(self):
for modern in False, True:
Expand Down Expand Up @@ -191,10 +202,15 @@ def __eq__(self, other):
def test_exception(self):
T = dict[str, list[int]]
m = _mapping.mapping_for(T)
with self.assertRaisesRegex(
AssertionError, r"in \[b\]\[1\]: <class 'str'> is not <class 'int'>"
):
with self.assertRaises(AssertionError) as cm:
m.unlower({"a": [10, 20], "b": [30, "40", 50]}, self.mysurject)
s = traceback.format_exception(cm.exception)
self.assertEqual(s, [
"AssertionError: <class 'str'> is not <class 'int'>\n",
'In: [b][1]\n',
] if sys.version_info >= (3, 11) else [
"AssertionError: <class 'str'> is not <class 'int'>\n",
])


class Demo:
Expand Down
Loading