From b5aa6068eaa67cfdfd7b5d8a41c532760c07c5a7 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 23 Jun 2021 12:31:44 +0500 Subject: [PATCH 001/125] Update TODO.md --- TODO.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TODO.md b/TODO.md index 40aaea2..71d555c 100644 --- a/TODO.md +++ b/TODO.md @@ -1,3 +1,4 @@ - Better __repr__, now they are too long and unreadable, - Test all __repr__ and __str__, -- Add testcoverage. \ No newline at end of file +- Add testcoverage, +- Add option to output DBML code for all classes. From ec94cd011b1d06b220f7593884caf38ceaf54a60 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 13:15:49 +0300 Subject: [PATCH 002/125] remove todo, update changelog --- TODO.md | 4 ---- changelog.md | 4 ++++ 2 files changed, 4 insertions(+), 4 deletions(-) delete mode 100644 TODO.md diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 71d555c..0000000 --- a/TODO.md +++ /dev/null @@ -1,4 +0,0 @@ -- Better __repr__, now they are too long and unreadable, -- Test all __repr__ and __str__, -- Add testcoverage, -- Add option to output DBML code for all classes. diff --git a/changelog.md b/changelog.md index d297cf8..dc87af4 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,7 @@ +# 0.4.1 + +- Reworked `__repr__` and `__str__` methods on all classes. + # 0.4.0 - New: Support composite references. **Breaks backward compatibility!** `col1`, `col2` attributes on `Reference` and `col`, `ref_col` attributes on `TableReference` are now lists of `Column` instead of `Column`. From 254570d4a73a42d385f94377964ed9088bd37995 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 18:33:02 +0300 Subject: [PATCH 003/125] add dbml property to all classes --- pydbml/classes.py | 219 +++++++++++++++++++++++++++-- pydbml/tools.py | 19 +++ test/test_classes.py | 328 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 557 insertions(+), 9 deletions(-) create mode 100644 pydbml/tools.py diff --git a/pydbml/classes.py b/pydbml/classes.py index 2b1294a..902c597 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -1,9 +1,9 @@ from __future__ import annotations from typing import Any +from typing import Collection from typing import Dict from typing import List -from typing import Collection from typing import Optional from typing import Tuple from typing import Union @@ -11,6 +11,10 @@ from .exceptions import AttributeMissingError from .exceptions import ColumnNotFoundError from .exceptions import DuplicateReferenceError +from .tools import comment_to_dbml +from .tools import indent +from .tools import note_option_to_dbml + class SQLOjbect: @@ -228,6 +232,40 @@ def sql(self): result += f' ON DELETE {self.on_delete.upper()}' return result + ';' + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += 'Ref' + if self.name: + result += f' {self.name}' + + if len(self.col1) == 1: + col1 = self.col1[0].name + else: + col1 = f'({", ".join(c.name for c in self.col1)})' + + if len(self.col2) == 1: + col2 = self.col2[0].name + else: + col2 = f'({", ".join(c.name for c in self.col2)})' + + options = [] + if self.on_update: + options.append(f'update: {self.on_update}') + if self.on_delete: + options.append(f'delete: {self.on_delete}') + + options_str = f' [{", ".join(options)}]' if options else '' + result += ( + ' {\n ' + f'{self.table1.name}.{col1} ' + f'{self.type} ' + f'{self.table2.name}.{col2}' + f'{options_str}' + '\n}' + ) + return result + class TableReference(SQLOjbect): ''' @@ -312,8 +350,8 @@ def sql(self): class Note: - def __init__(self, text: str): - self.text = text + def __init__(self, text: Any): + self.text = str(text) if text else '' def __str__(self): ''' @@ -341,6 +379,36 @@ def sql(self): else: return '' + @property + def dbml(self): + lines = [] + line = '' + for word in self.text.split(' '): + if len(line) > 80: + lines.append(line) + line = '' + if '\n' in word: + sublines = word.split('\n') + for sl in sublines[:-1]: + line += sl + lines.append(line) + line = '' + line = sublines[-1] + ' ' + else: + line += f'{word} ' + if line: + lines.append(line) + result = 'Note {\n ' + + if len(lines) > 1: + lines_str = '\n '.join(lines)[:-1] + '\n' + result += f"'''\n {lines_str} '''" + else: + result += f"'{lines[0][:-1]}'" + + result += '\n}' + return result + class Column(SQLOjbect): '''Class representing table column.''' @@ -368,7 +436,7 @@ def __init__(self, self.default = default - self.note = note or Note('') + self.note = Note(note) self.ref_blueprints = ref_blueprints or [] for ref in self.ref_blueprints: ref.col1 = self.name @@ -409,6 +477,43 @@ def sql(self): components.append(self.note.sql) return ' '.join(components) + @property + def dbml(self): + def default_to_str(val: str) -> str: + if val.lower() in ('null', 'true', 'false'): + return val.lower() + if val.isdigit(): + return val + try: + float(val) + return val + except ValueError: + pass + if val.startswith('(') and val.endswith(')'): + return f'`{val[1:-1]}`' + return f"'{val}'" + + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'{self.name} {self.type}' + + options = [] + if self.pk: + options.append('pk') + if self.autoinc: + options.append('increment') + if self.default: + options.append(f'default: {default_to_str(self.default)}') + if self.unique: + options.append('unique') + if self.not_null: + options.append('not null') + if self.note: + options.append(note_option_to_dbml(self.note)) + + if options: + result += f' [{", ".join(options)}]' + return result + def __repr__(self): ''' >>> Column('name', 'VARCHAR2') @@ -447,7 +552,7 @@ def __init__(self, self.unique = unique self.type = type_ self.pk = pk - self.note = note or Note('') + self.note = Note(note) self.comment = comment def __repr__(self): @@ -510,6 +615,37 @@ def sql(self): result += f' {self.note.sql}' return result + @property + def dbml(self): + def subject_to_str(val: str) -> str: + if val.startswith('(') and val.endswith(')'): + return f'`{val[1:-1]}`' + else: + return val + + result = comment_to_dbml(self.comment) if self.comment else '' + + if len(self.subject_names) > 1: + result += f'({", ".join(subject_to_str(sn) for sn in self.subject_names)})' + else: + result += self.subject_names[0] + + options = [] + if self.name: + options.append(f"name: '{self.name}'") + if self.pk: + options.append('pk') + if self.unique: + options.append('unique') + if self.type: + options.append(f'type: {self.type}') + if self.note: + options.append(note_option_to_dbml(self.note)) + + if options: + result += f' [{", ".join(options)}]' + return result + class Table(SQLOjbect): '''Class representing table.''' @@ -528,7 +664,7 @@ def __init__(self, self.indexes: List[Index] = [] self.column_dict: Dict[str, Column] = {} self.alias = alias if alias else None - self.note = note or Note('') + self.note = Note(note) self.header_color = header_color self.refs = refs or [] self.comment = comment @@ -630,16 +766,30 @@ def sql(self): components.extend(i.sql + '\n' for i in self.indexes if not i.pk) return '\n'.join(components) + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'Table {self.name} ' + if self.alias: + result += f'as {self.alias} ' + result += '{\n' + columns_str = '\n'.join(c.dbml for c in self.columns) + result += indent(columns_str) + '\n' + if self.note: + result += indent(self.note.dbml) + '\n' + result += '}' + return result + class EnumItem: - '''Single enum item. Does not translate into SQL''' + '''Single enum item''' def __init__(self, name: str, note: Optional[Note] = None, comment: Optional[str] = None): self.name = name - self.note = note or Note('') + self.note = Note(note) self.comment = comment def __repr__(self): @@ -665,6 +815,15 @@ def sql(self): components.append(self.note.sql) return ' '.join(components) + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += self.name + if self.note: + result += f' [{note_option_to_dbml(self.note)}]' + return result + + class Enum(SQLOjbect): required_attributes = ('name', 'items') @@ -726,6 +885,16 @@ def sql(self): '\n'.join(f' {i.sql}' for i in self.items) +\ '\n);' + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'enum {self.name} {{\n' + items_str = '\n'.join(i.dbml for i in self.items) + result += indent(items_str) + result += '\n}' + return result + + class EnumType(Enum): ''' @@ -778,6 +947,21 @@ def __getitem__(self, key) -> str: def __iter__(self): return iter(self.items) + @property + def dbml(self): + def item_to_str(val: Union[str, Table]) -> str: + if isinstance(val, Table): + return val.name + else: + return val + + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'TableGroup {self.name} {{\n' + for i in self.items: + result += f' {item_to_str(i)}\n' + result += '}' + return result + class Project: def __init__(self, @@ -787,7 +971,7 @@ def __init__(self, comment: Optional[str] = None): self.name = name self.items = items - self.note = note or Note('') + self.note = Note(note) self.comment = comment def __repr__(self): @@ -797,3 +981,20 @@ def __repr__(self): """ return f'' + + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'Project {self.name} {{\n' + if self.items: + items_str = '' + for k, v in self.items.items(): + if '\n' in v: + items_str += f"{k}: '''{v}'''\n" + else: + items_str += f"{k}: '{v}'\n" + result += indent(items_str[:-1]) + '\n' + if self.note: + result += indent(self.note.dbml) + '\n' + result += '}' + return result diff --git a/pydbml/tools.py b/pydbml/tools.py new file mode 100644 index 0000000..640ab12 --- /dev/null +++ b/pydbml/tools.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .classes import Note + +def comment_to_dbml(val: str) -> str: + return '\n'.join(f'// {cl}' for cl in val.split('\n')) + '\n' + + +def note_option_to_dbml(val: 'Note') -> str: + if '\n' in val.text: + return f"note: '''{val.text}'''" + else: + return f"note: '{val.text}'" + + +def indent(val: str, spaces=4) -> str: + if val == '': + return val + return ' ' * spaces + val.replace('\n', '\n' +' ' * spaces) \ No newline at end of file diff --git a/test/test_classes.py b/test/test_classes.py index 43e9fb0..863aa00 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -5,9 +5,12 @@ from pydbml.classes import EnumItem from pydbml.classes import Index from pydbml.classes import Note +from pydbml.classes import Project +from pydbml.classes import Reference from pydbml.classes import ReferenceBlueprint from pydbml.classes import SQLOjbect from pydbml.classes import Table +from pydbml.classes import TableGroup from pydbml.classes import TableReference from pydbml.exceptions import AttributeMissingError from pydbml.exceptions import ColumnNotFoundError @@ -166,6 +169,48 @@ def test_table_setter(self) -> None: self.assertEqual(c.ref_blueprints[0].table1, t.name) self.assertEqual(c.ref_blueprints[1].table1, t.name) + def test_dbml_simple(self): + c = Column( + name='order', + type_='integer' + ) + expected = 'order integer' + + self.assertEqual(c.dbml, expected) + + def test_dbml_full(self): + c = Column( + name='order', + type_='integer', + unique=True, + not_null=True, + pk=True, + autoinc=True, + default='Def_value', + note='Note on the column', + comment='Comment on the column' + ) + expected = \ +'''// Comment on the column +order integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' + + self.assertEqual(c.dbml, expected) + + def test_dbml_multiline_note(self): + c = Column( + name='order', + type_='integer', + not_null=True, + note='Note on the column\nmultiline', + comment='Comment on the column' + ) + expected = \ +"""// Comment on the column +order integer [not null, note: '''Note on the column +multiline''']""" + + self.assertEqual(c.dbml, expected) + class TestIndex(TestCase): def test_basic_sql(self) -> None: @@ -221,6 +266,36 @@ def test_composite_with_expression(self) -> None: expected = 'CREATE INDEX ON "products" ("id", (id*3));' self.assertEqual(r.sql, expected) + def test_dbml_simple(self): + i = Index( + ['id'] + ) + + expected = 'id' + self.assertEqual(i.dbml, expected) + + def test_dbml_composite(self): + i = Index( + ['id', '(id*3)'] + ) + + expected = '(id, `id*3`)' + self.assertEqual(i.dbml, expected) + + def test_dbml_full(self): + i = Index( + ['id', '(getdate())'], + name='Dated id', + unique=True, + type_='hash', + pk=True, + note='Note on the column', + comment='Comment on the index' + ) + expected = \ +'''// Comment on the index +(id, `getdate()`) [name: 'Dated id', pk, unique, type: hash, note: 'Note on the column']''' + class TestTable(TestCase): def test_one_column(self) -> None: @@ -347,6 +422,62 @@ def test_add_bad_index(self) -> None: with self.assertRaises(ColumnNotFoundError): t.add_index(i) + def test_dbml_simple(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + expected = \ +'''Table products { + id integer + name varchar2 +}''' + self.assertEqual(t.dbml, expected) + + def test_dbml_full(self): + t = Table( + 'products', + alias='pd', + note='My multiline\nnote', + comment='My multiline\ncomment' + ) + c0 = Column('zero', 'number') + c1 = Column('id', 'integer', unique=True, note='Multiline\ncomment note') + c2 = Column('name', 'varchar2') + t.add_column(c0) + t.add_column(c1) + t.add_column(c2) + expected = \ +"""// My multiline +// comment +Table products as pd { + zero number + id integer [unique, note: '''Multiline + comment note'''] + name varchar2 + Note { + ''' + My multiline + note + ''' + } +}""" + self.assertEqual(t.dbml, expected) + +class TestEnumItem(TestCase): + def test_dbml_simple(self): + ei = EnumItem('en-US') + expected = 'en-US' + self.assertEqual(ei.dbml, expected) + + def test_dbml_full(self): + ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') + expected = \ +'''// EnumItem comment +en-US [note: 'preferred']''' + self.assertEqual(ei.dbml, expected) + class TestEnum(TestCase): def test_simple_enum(self) -> None: @@ -383,3 +514,200 @@ def test_notes(self) -> None: 'failure', );''' self.assertEqual(e.sql, expected) + + def test_dbml_simple(self): + items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] + e = Enum('lang', items) + expected = \ +'''enum lang { + en-US + ru-RU + en-GB +}''' + self.assertEqual(e.dbml, expected) + + def test_dbml_full(self): + items = [ + EnumItem('en-US', note='preferred'), + EnumItem('ru-RU', comment='Multiline\ncomment'), + EnumItem('en-GB')] + e = Enum('lang', items, comment="Enum comment") + expected = \ +'''// Enum comment +enum lang { + en-US [note: 'preferred'] + // Multiline + // comment + ru-RU + en-GB +}''' + self.assertEqual(e.dbml, expected) + + +class TestReference(TestCase): + def test_dbml_simple(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference('>', t, c2, t2, c21) + + expected = \ +'''Ref { + products.name > names.name_val +}''' + self.assertEqual(ref.dbml, expected) + + def test_dbml_full(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + c3 = Column('country', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference( + '<', + t, + [c2, c3], + t2, + (c21, c22), + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL' + ) + + expected = \ +'''// Reference comment +// multiline +Ref nameref { + products.(name, country) < names.(name_val, country) [update: CASCADE, delete: SET NULL] +}''' + self.assertEqual(ref.dbml, expected) + + +class TestNote(TestCase): + def test_init_types(self): + n1 = Note('My note text') + n2 = Note(3) + n3 = Note([1, 2, 3]) + n4 = Note(None) + n5 = Note(n1) + + self.assertEqual(n1.text, 'My note text') + self.assertEqual(n2.text, '3') + self.assertEqual(n3.text, '[1, 2, 3]') + self.assertEqual(n4.text, '') + self.assertEqual(n5.text, 'My note text') + + def test_oneline(self): + note = Note('One line of note text') + expected = \ +'''Note { + 'One line of note text' +}''' + self.assertEqual(note.dbml, expected) + + def test_multiline(self): + note = Note('The number of spaces you use to indent a block string will be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + expected = \ +"""Note { + ''' + The number of spaces you use to indent a block string will be the minimum number + of leading spaces among all lines. The parser will automatically remove the number + of indentation spaces in the final output. + ''' +}""" + self.assertEqual(note.dbml, expected) + + + def test_forced_multiline(self): + note = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + expected = \ +"""Note { + ''' + The number of spaces you use to indent a block string + will + be the minimum number of leading spaces among all lines. The parser will automatically + remove the number of indentation spaces in the final output. + ''' +}""" + self.assertEqual(note.dbml, expected) + + +class TestTableGroup(TestCase): + def test_dbml(self): + tg = TableGroup('mytg', ['merchants', 'countries', 'customers']) + expected = \ +'''TableGroup mytg { + merchants + countries + customers +}''' + self.assertEqual(tg.dbml, expected) + + def test_dbml_with_comment_and_real_tables(self): + merchants = Table('merchants') + countries = Table('countries') + customers = Table('customers') + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + comment='My table group\nmultiline comment' + ) + expected = \ +'''// My table group +// multiline comment +TableGroup mytg { + merchants + countries + customers +}''' + self.assertEqual(tg.dbml, expected) + +class TestProject(TestCase): + def test_dbml_note(self): + p = Project('myproject', note='Project note') + expected = \ +'''Project myproject { + Note { + 'Project note' + } +}''' + self.assertEqual(p.dbml, expected) + + def test_dbml_full(self): + p = Project( + 'myproject', + items={ + 'database_type': 'PostgreSQL', + 'story': "One day I was eating my cantaloupe and\nI thought, why shouldn't I?\nWhy shouldn't I create a database?" + }, + comment='Multiline\nProject comment', + note='Multiline\nProject note') + expected = \ +"""// Multiline +// Project comment +Project myproject { + database_type: 'PostgreSQL' + story: '''One day I was eating my cantaloupe and + I thought, why shouldn't I? + Why shouldn't I create a database?''' + Note { + ''' + Multiline + Project note + ''' + } +}""" + self.assertEqual(p.dbml, expected) From acfe86428892c2b3802c9eaa283186378cc54722 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 18:33:11 +0300 Subject: [PATCH 004/125] udpate todo --- TODO.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 TODO.md diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..bf12adb --- /dev/null +++ b/TODO.md @@ -0,0 +1,3 @@ +- Notes should be converted to COMMENT ON +- comments should be converted to comments +- Docs for creating dbml schema in python From b4bd87a154dbfe50fb611a9ecab7a994a7b89727 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 19:03:29 +0300 Subject: [PATCH 005/125] add indexes to table dbml, fix error in index dbml --- pydbml/classes.py | 8 +++++++- pydbml/parser.py | 10 ++++++++++ test/test_classes.py | 10 ++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/pydbml/classes.py b/pydbml/classes.py index 902c597..8f2f8cb 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -628,7 +628,7 @@ def subject_to_str(val: str) -> str: if len(self.subject_names) > 1: result += f'({", ".join(subject_to_str(sn) for sn in self.subject_names)})' else: - result += self.subject_names[0] + result += subject_to_str(self.subject_names[0]) options = [] if self.name: @@ -777,6 +777,12 @@ def dbml(self): result += indent(columns_str) + '\n' if self.note: result += indent(self.note.dbml) + '\n' + if self.indexes: + result += '\n indexes {\n' + indexes_str = '\n'.join(i.dbml for i in self.indexes) + result += indent(indexes_str, 8) + '\n' + result += ' }\n' + result += '}' return result diff --git a/pydbml/parser.py b/pydbml/parser.py index 3b969ee..2cf431d 100644 --- a/pydbml/parser.py +++ b/pydbml/parser.py @@ -268,3 +268,13 @@ def sql(self): components = (i.sql for i in (*self.enums, *self.tables)) return '\n'.join(components) + + @property + def dbml(self): + '''Generates DBML code out of parsed results''' + items = [self.project] if self.project else [] + items.update((*self.tables, *self.refs, *self.enums, *self.table_groups)) + components = ( + i.dbml for i in () + ) + return '\n\n'.join(components) diff --git a/test/test_classes.py b/test/test_classes.py index 863aa00..118bbc2 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -448,6 +448,10 @@ def test_dbml_full(self): t.add_column(c0) t.add_column(c1) t.add_column(c2) + i1 = Index(['zero', 'id'], unique=True) + i2 = Index(['(capitalize(name))'], comment="index comment") + t.add_index(i1) + t.add_index(i2) expected = \ """// My multiline // comment @@ -462,6 +466,12 @@ def test_dbml_full(self): note ''' } + + indexes { + (zero, id) [unique] + // index comment + `capitalize(name)` + } }""" self.assertEqual(t.dbml, expected) From 7f58a25ef63c04f2b6161c73f1e05a6b8a545ae4 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 19:14:37 +0300 Subject: [PATCH 006/125] fix column default dbml --- pydbml/classes.py | 19 ++++++++----------- test/test_classes.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/pydbml/classes.py b/pydbml/classes.py index 8f2f8cb..b3f7d9c 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -480,18 +480,15 @@ def sql(self): @property def dbml(self): def default_to_str(val: str) -> str: - if val.lower() in ('null', 'true', 'false'): - return val.lower() - if val.isdigit(): - return val - try: - float(val) + if isinstance(val, str): + if val.lower() in ('null', 'true', 'false'): + return val.lower() + elif val.startswith('(') and val.endswith(')'): + return f'`{val[1:-1]}`' + else: + return f"'{val}'" + else: # int or float or bool return val - except ValueError: - pass - if val.startswith('(') and val.endswith(')'): - return f'`{val[1:-1]}`' - return f"'{val}'" result = comment_to_dbml(self.comment) if self.comment else '' result += f'{self.name} {self.type}' diff --git a/test/test_classes.py b/test/test_classes.py index 118bbc2..fc8e232 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -211,6 +211,38 @@ def test_dbml_multiline_note(self): self.assertEqual(c.dbml, expected) + def test_dbml_default(self): + c = Column( + name='order', + type_='integer', + default='String value' + ) + expected = "order integer [default: 'String value']" + self.assertEqual(c.dbml, expected) + + c.default = 3 + expected = 'order integer [default: 3]' + self.assertEqual(c.dbml, expected) + + c.default = 3.33 + expected = 'order integer [default: 3.33]' + self.assertEqual(c.dbml, expected) + + c.default = "(now() - interval '5 days')" + expected = "order integer [default: `now() - interval '5 days'`]" + self.assertEqual(c.dbml, expected) + + c.default = 'NULL' + expected = 'order integer [default: null]' + self.assertEqual(c.dbml, expected) + + c.default = 'TRue' + expected = 'order integer [default: true]' + self.assertEqual(c.dbml, expected) + + c.default = 'false' + expected = 'order integer [default: false]' + self.assertEqual(c.dbml, expected) class TestIndex(TestCase): def test_basic_sql(self) -> None: From f8f47fc5b5e7c4e3a39f69e9648336966f99db26 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 20:00:18 +0300 Subject: [PATCH 007/125] add double quotes for names --- pydbml/classes.py | 24 +++++++++--------- test/test_classes.py | 58 ++++++++++++++++++++++---------------------- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/pydbml/classes.py b/pydbml/classes.py index b3f7d9c..bab94e5 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -240,14 +240,16 @@ def dbml(self): result += f' {self.name}' if len(self.col1) == 1: - col1 = self.col1[0].name + col1 = f'"{self.col1[0].name}"' else: - col1 = f'({", ".join(c.name for c in self.col1)})' + names = (f'"{c.name}"' for c in self.col1) + col1 = f'({", ".join(names)})' if len(self.col2) == 1: - col2 = self.col2[0].name + col2 = f'"{self.col2[0].name}"' else: - col2 = f'({", ".join(c.name for c in self.col2)})' + names = (f'"{c.name}"' for c in self.col2) + col2 = f'({", ".join(names)})' options = [] if self.on_update: @@ -258,9 +260,9 @@ def dbml(self): options_str = f' [{", ".join(options)}]' if options else '' result += ( ' {\n ' - f'{self.table1.name}.{col1} ' + f'"{self.table1.name}".{col1} ' f'{self.type} ' - f'{self.table2.name}.{col2}' + f'"{self.table2.name}".{col2}' f'{options_str}' '\n}' ) @@ -491,7 +493,7 @@ def default_to_str(val: str) -> str: return val result = comment_to_dbml(self.comment) if self.comment else '' - result += f'{self.name} {self.type}' + result += f'"{self.name}" {self.type}' options = [] if self.pk: @@ -766,9 +768,9 @@ def sql(self): @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Table {self.name} ' + result += f'Table "{self.name}" ' if self.alias: - result += f'as {self.alias} ' + result += f'as "{self.alias}" ' result += '{\n' columns_str = '\n'.join(c.dbml for c in self.columns) result += indent(columns_str) + '\n' @@ -821,7 +823,7 @@ def sql(self): @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' - result += self.name + result += f'"{self.name}"' if self.note: result += f' [{note_option_to_dbml(self.note)}]' return result @@ -891,7 +893,7 @@ def sql(self): @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' - result += f'enum {self.name} {{\n' + result += f'Enum "{self.name}" {{\n' items_str = '\n'.join(i.dbml for i in self.items) result += indent(items_str) result += '\n}' diff --git a/test/test_classes.py b/test/test_classes.py index fc8e232..522ee47 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -174,7 +174,7 @@ def test_dbml_simple(self): name='order', type_='integer' ) - expected = 'order integer' + expected = '"order" integer' self.assertEqual(c.dbml, expected) @@ -192,7 +192,7 @@ def test_dbml_full(self): ) expected = \ '''// Comment on the column -order integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' +"order" integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' self.assertEqual(c.dbml, expected) @@ -206,7 +206,7 @@ def test_dbml_multiline_note(self): ) expected = \ """// Comment on the column -order integer [not null, note: '''Note on the column +"order" integer [not null, note: '''Note on the column multiline''']""" self.assertEqual(c.dbml, expected) @@ -217,31 +217,31 @@ def test_dbml_default(self): type_='integer', default='String value' ) - expected = "order integer [default: 'String value']" + expected = "\"order\" integer [default: 'String value']" self.assertEqual(c.dbml, expected) c.default = 3 - expected = 'order integer [default: 3]' + expected = '"order" integer [default: 3]' self.assertEqual(c.dbml, expected) c.default = 3.33 - expected = 'order integer [default: 3.33]' + expected = '"order" integer [default: 3.33]' self.assertEqual(c.dbml, expected) c.default = "(now() - interval '5 days')" - expected = "order integer [default: `now() - interval '5 days'`]" + expected = "\"order\" integer [default: `now() - interval '5 days'`]" self.assertEqual(c.dbml, expected) c.default = 'NULL' - expected = 'order integer [default: null]' + expected = '"order" integer [default: null]' self.assertEqual(c.dbml, expected) c.default = 'TRue' - expected = 'order integer [default: true]' + expected = '"order" integer [default: true]' self.assertEqual(c.dbml, expected) c.default = 'false' - expected = 'order integer [default: false]' + expected = '"order" integer [default: false]' self.assertEqual(c.dbml, expected) class TestIndex(TestCase): @@ -461,9 +461,9 @@ def test_dbml_simple(self): t.add_column(c1) t.add_column(c2) expected = \ -'''Table products { - id integer - name varchar2 +'''Table "products" { + "id" integer + "name" varchar2 }''' self.assertEqual(t.dbml, expected) @@ -487,11 +487,11 @@ def test_dbml_full(self): expected = \ """// My multiline // comment -Table products as pd { - zero number - id integer [unique, note: '''Multiline +Table "products" as "pd" { + "zero" number + "id" integer [unique, note: '''Multiline comment note'''] - name varchar2 + "name" varchar2 Note { ''' My multiline @@ -510,14 +510,14 @@ def test_dbml_full(self): class TestEnumItem(TestCase): def test_dbml_simple(self): ei = EnumItem('en-US') - expected = 'en-US' + expected = '"en-US"' self.assertEqual(ei.dbml, expected) def test_dbml_full(self): ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') expected = \ '''// EnumItem comment -en-US [note: 'preferred']''' +"en-US" [note: 'preferred']''' self.assertEqual(ei.dbml, expected) @@ -561,10 +561,10 @@ def test_dbml_simple(self): items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] e = Enum('lang', items) expected = \ -'''enum lang { - en-US - ru-RU - en-GB +'''Enum "lang" { + "en-US" + "ru-RU" + "en-GB" }''' self.assertEqual(e.dbml, expected) @@ -576,12 +576,12 @@ def test_dbml_full(self): e = Enum('lang', items, comment="Enum comment") expected = \ '''// Enum comment -enum lang { - en-US [note: 'preferred'] +Enum "lang" { + "en-US" [note: 'preferred'] // Multiline // comment - ru-RU - en-GB + "ru-RU" + "en-GB" }''' self.assertEqual(e.dbml, expected) @@ -600,7 +600,7 @@ def test_dbml_simple(self): expected = \ '''Ref { - products.name > names.name_val + "products"."name" > "names"."name_val" }''' self.assertEqual(ref.dbml, expected) @@ -633,7 +633,7 @@ def test_dbml_full(self): '''// Reference comment // multiline Ref nameref { - products.(name, country) < names.(name_val, country) [update: CASCADE, delete: SET NULL] + "products".("name", "country") < "names".("name_val", "country") [update: CASCADE, delete: SET NULL] }''' self.assertEqual(ref.dbml, expected) From a4f7713187b764893fb71a858d35621ebe511cd9 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 12 Jul 2021 20:04:42 +0300 Subject: [PATCH 008/125] add dbml property to parser --- pydbml/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydbml/parser.py b/pydbml/parser.py index 2cf431d..9310055 100644 --- a/pydbml/parser.py +++ b/pydbml/parser.py @@ -273,8 +273,8 @@ def sql(self): def dbml(self): '''Generates DBML code out of parsed results''' items = [self.project] if self.project else [] - items.update((*self.tables, *self.refs, *self.enums, *self.table_groups)) + items.extend((*self.tables, *self.refs, *self.enums, *self.table_groups)) components = ( - i.dbml for i in () + i.dbml for i in items ) return '\n\n'.join(components) From f467adc858a3bba70e4ef5af8f18d54d29bc9d25 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 09:46:37 +0300 Subject: [PATCH 009/125] update todo --- TODO.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/TODO.md b/TODO.md index bf12adb..a53be08 100644 --- a/TODO.md +++ b/TODO.md @@ -1,3 +1,5 @@ +- Docs for converting to dbml - Notes should be converted to COMMENT ON - comments should be converted to comments +- Creating dbml schema in python - Docs for creating dbml schema in python From 41f00a10f2a6fe0cfbb5e715cb829021be491a5c Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 10:12:11 +0300 Subject: [PATCH 010/125] fix and test Reference sql. Add comment into Reference sql --- pydbml/classes.py | 16 ++++++------ pydbml/tools.py | 4 +++ test/test_classes.py | 58 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/pydbml/classes.py b/pydbml/classes.py index bab94e5..ee4d19e 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -12,6 +12,7 @@ from .exceptions import ColumnNotFoundError from .exceptions import DuplicateReferenceError from .tools import comment_to_dbml +from .tools import comment_to_sql from .tools import indent from .tools import note_option_to_dbml @@ -213,18 +214,19 @@ def sql(self): if self.type in (self.MANY_TO_ONE, self.ONE_TO_ONE): t1 = self.table1 - c1 = ', '.join(self.col1) + c1 = ', '.join(f'"{c.name}"' for c in self.col1) t2 = self.table2 - c2 = ', '.join(self.col2) + c2 = ', '.join(f'"{c.name}"' for c in self.col2) else: t1 = self.table2 - c1 = ', '.join(self.col2) + c1 = ', '.join(f'"{c.name}"' for c in self.col2) t2 = self.table1 - c2 = ', '.join(self.col1) + c2 = ', '.join(f'"{c.name}"' for c in self.col1) - result = ( - f'ALTER TABLE "{t1.name}" ADD {c}FOREIGN KEY ("{c1.name}") ' - f'REFERENCES "{t2.name} ("{c2.name}")' + result = comment_to_sql(self.comment) if self.comment else '' + result += ( + f'ALTER TABLE "{t1.name}" ADD {c}FOREIGN KEY ({c1}) ' + f'REFERENCES "{t2.name}" ({c2})' ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' diff --git a/pydbml/tools.py b/pydbml/tools.py index 640ab12..f5e3579 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -6,6 +6,10 @@ def comment_to_dbml(val: str) -> str: return '\n'.join(f'// {cl}' for cl in val.split('\n')) + '\n' +def comment_to_sql(val: str) -> str: + return '\n'.join(f'-- {cl}' for cl in val.split('\n')) + '\n' + + def note_option_to_dbml(val: 'Note') -> str: if '\n' in val.text: return f"note: '''{val.text}'''" diff --git a/test/test_classes.py b/test/test_classes.py index 522ee47..3ebd3f9 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -587,6 +587,64 @@ def test_dbml_full(self): class TestReference(TestCase): + def test_sql_single(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', t, c1, t2, c2) + + expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' + self.assertEqual(ref.sql, expected) + + def test_sql_multiple(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('>', t, [c11, c12], t2, (c21, c22)) + + expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val");' + self.assertEqual(ref.sql, expected) + + def test_sql_full(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference( + '>', + t, + [c11, c12], + t2, + (c21, c22), + name="country_name", + comment="Multiline\ncomment for the constraint", + on_update="CASCADE", + on_delete="SET NULL" + ) + + expected = \ +'''-- Multiline +-- comment for the constraint +ALTER TABLE "products" ADD CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL;''' + + self.assertEqual(ref.sql, expected) + def test_dbml_simple(self): t = Table('products') c1 = Column('id', 'integer') From e4fb80f4b03c4fccc5977eac0d46c569a2d0ebf7 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 11:28:03 +0300 Subject: [PATCH 011/125] fix and test TableReference sql, update changelog --- changelog.md | 1 + pydbml/classes.py | 2 +- test/test_classes.py | 75 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/changelog.md b/changelog.md index dc87af4..2e5bd43 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,7 @@ # 0.4.1 - Reworked `__repr__` and `__str__` methods on all classes. +- Fix: sql for Reference and TableReference. # 0.4.0 diff --git a/pydbml/classes.py b/pydbml/classes.py index ee4d19e..aa2e189 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -344,7 +344,7 @@ def sql(self): ref_cols = '", "'.join(c.name for c in self.ref_col) result = ( f'{c}FOREIGN KEY ("{cols}") ' - f'REFERENCES "{self.ref_table.name} ("{ref_cols}")' + f'REFERENCES "{self.ref_table.name}" ("{ref_cols}")' ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' diff --git a/test/test_classes.py b/test/test_classes.py index 3ebd3f9..53a7fc3 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -7,6 +7,7 @@ from pydbml.classes import Note from pydbml.classes import Project from pydbml.classes import Reference +from pydbml.classes import TableReference from pydbml.classes import ReferenceBlueprint from pydbml.classes import SQLOjbect from pydbml.classes import Table @@ -352,7 +353,7 @@ def test_ref(self) -> None: '''CREATE TABLE "products" ( "id" integer, "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names ("name_val") + FOREIGN KEY ("name") REFERENCES "names" ("name_val") ); ''' self.assertEqual(t.sql, expected) @@ -398,7 +399,7 @@ def test_ref_index(self) -> None: '''CREATE TABLE "products" ( "id" integer, "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names ("name_val") + FOREIGN KEY ("name") REFERENCES "names" ("name_val") ); CREATE INDEX ON "products" ("id", "name"); @@ -586,6 +587,64 @@ def test_dbml_full(self): self.assertEqual(e.dbml, expected) + +class TestTableReference(TestCase): + def test_sql_single(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = TableReference( + c1, t2, c2) + + expected = 'FOREIGN KEY ("name") REFERENCES "names" ("name_val")' + self.assertEqual(ref.sql, expected) + + def test_sql_multiple(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = TableReference( + [c11, c12], + t2, + (c21, c22) + ) + + expected = 'FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val")' + self.assertEqual(ref.sql, expected) + + def test_sql_full(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = TableReference( + [c11, c12], + t2, + (c21, c22), + name="country_name", + on_delete='SET NULL', + on_update='CASCADE' + ) + + expected = 'CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL' + self.assertEqual(ref.sql, expected) + class TestReference(TestCase): def test_sql_single(self): t = Table('products') @@ -599,6 +658,18 @@ def test_sql_single(self): expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' self.assertEqual(ref.sql, expected) + def test_sql_reverse(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('<', t, c1, t2, c2) + + expected = 'ALTER TABLE "names" ADD FOREIGN KEY ("name_val") REFERENCES "products" ("name");' + self.assertEqual(ref.sql, expected) + def test_sql_multiple(self): t = Table('products') c11 = Column('name', 'varchar2') From 8e26ba4f58dee50b9f2745a62268ced4a22e3913 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 12:10:54 +0300 Subject: [PATCH 012/125] add notes and comments to the rest of the classes sql, add tests --- pydbml/classes.py | 55 ++++++++++++++--------- test/test_classes.py | 101 +++++++++++++++++++++++++++++++++---------- 2 files changed, 112 insertions(+), 44 deletions(-) diff --git a/pydbml/classes.py b/pydbml/classes.py index aa2e189..04e3fcc 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -477,9 +477,10 @@ def sql(self): components.append('NOT NULL') if self.default is not None: components.append('DEFAULT ' + str(self.default)) - if self.note: - components.append(self.note.sql) - return ' '.join(components) + + result = comment_to_sql(self.comment) if self.comment else '' + result += ' '.join(components) + return result @property def dbml(self): @@ -599,7 +600,9 @@ def sql(self): self.check_attributes_for_sql() keys = ', '.join(f'"{key.name}"' if isinstance(key, Column) else key for key in self.subjects) if self.pk: - return f'PRIMARY KEY ({keys})' + result = comment_to_sql(self.comment) if self.comment else '' + result += f'PRIMARY KEY ({keys})' + return result components = ['CREATE'] if self.unique: @@ -611,9 +614,8 @@ def sql(self): if self.type: components.append(f'USING {self.type.upper()}') components.append(f'({keys})') - result = ' '.join(components) + ';' - if self.note: - result += f' {self.note.sql}' + result = comment_to_sql(self.comment) if self.comment else '' + result += ' '.join(components) + ';' return result @property @@ -756,16 +758,28 @@ def sql(self): ''' self.check_attributes_for_sql() components = [f'CREATE TABLE "{self.name}" ('] - if self.note: - components.append(f' {self.note.sql}') body = [] - body.extend(' ' + c.sql for c in self.columns) - body.extend(' ' + i.sql for i in self.indexes if i.pk) - body.extend(' ' + r.sql for r in self.refs) + body.extend(indent(c.sql, 2) for c in self.columns) + body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) + body.extend(indent(r.sql, 2) for r in self.refs) components.append(',\n'.join(body)) components.append(');\n') components.extend(i.sql + '\n' for i in self.indexes if not i.pk) - return '\n'.join(components) + + result = comment_to_sql(self.comment) if self.comment else '' + result += '\n'.join(components) + + if self.note: + quoted_note = f"'{self.note.text}'" + note_sql = f'COMMENT ON TABLE "{self.name}" IS {quoted_note};' + result += f'\n\n{note_sql}' + + for col in self.columns: + if col.note: + quoted_note = f"'{col.note.text}'" + note_sql = f'COMMENT ON COLUMN "{self.name}"."{col.name}" IS {quoted_note};' + result += f'\n\n{note_sql}' + return result @property def dbml(self): @@ -817,10 +831,9 @@ def __str__(self): @property def sql(self): - components = [f"'{self.name}',"] - if self.note: - components.append(self.note.sql) - return ' '.join(components) + result = comment_to_sql(self.comment) if self.comment else '' + result += f"'{self.name}'," + return result @property def dbml(self): @@ -888,9 +901,11 @@ def sql(self): ''' self.check_attributes_for_sql() - return f'CREATE TYPE "{self.name}" AS ENUM (\n' +\ - '\n'.join(f' {i.sql}' for i in self.items) +\ - '\n);' + result = comment_to_sql(self.comment) if self.comment else '' + result += f'CREATE TYPE "{self.name}" AS ENUM (\n' + result +='\n'.join(f'{indent(i.sql, 2)}' for i in self.items) + result += '\n);' + return result @property def dbml(self): diff --git a/test/test_classes.py b/test/test_classes.py index 53a7fc3..3d554b3 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -110,14 +110,6 @@ def test_basic_sql(self) -> None: expected = '"id" integer' self.assertEqual(r.sql, expected) - def test_note(self) -> None: - n = Note('Column note') - r = Column(name='id', - type_='integer', - note=n) - expected = '"id" integer -- Column note' - self.assertEqual(r.sql, expected) - def test_pk_autoinc(self) -> None: r = Column(name='id', type_='integer', @@ -141,6 +133,17 @@ def test_default(self) -> None: expected = '"order" integer DEFAULT 0' self.assertEqual(r.sql, expected) + def test_comment(self) -> None: + r = Column(name='id', + type_='integer', + unique=True, + not_null=True, + comment="Column comment") + expected = \ +'''-- Column comment +"id" integer UNIQUE NOT NULL''' + self.assertEqual(r.sql, expected) + def test_table_setter(self) -> None: r1 = ReferenceBlueprint( ReferenceBlueprint.MANY_TO_ONE, @@ -255,15 +258,17 @@ def test_basic_sql(self) -> None: expected = 'CREATE INDEX ON "products" ("id");' self.assertEqual(r.sql, expected) - def test_note(self) -> None: + def test_comment(self) -> None: t = Table('products') t.add_column(Column('id', 'integer')) - n = Note('Index note') r = Index(subject_names=['id'], table=t, - note=n) + comment='Index comment') t.add_index(r) - expected = 'CREATE INDEX ON "products" ("id"); -- Index note' + expected = \ +'''-- Index comment +CREATE INDEX ON "products" ("id");''' + self.assertEqual(r.sql, expected) def test_unique_type_composite(self) -> None: @@ -374,12 +379,31 @@ def test_duplicate_ref(self) -> None: with self.assertRaises(DuplicateReferenceError): t.add_ref(r2) - def test_note(self) -> None: + def test_notes(self) -> None: n = Note('Table note') + nc1 = Note('First column note') + nc2 = Note('Another column\nmultiline note') t = Table('products', note=n) - c = Column('id', 'integer') - t.add_column(c) - expected = 'CREATE TABLE "products" (\n -- Table note\n "id" integer\n);\n' + c1 = Column('id', 'integer', note=nc1) + c2 = Column('name', 'varchar') + c3 = Column('country', 'varchar', note=nc2) + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar, + "country" varchar +); + + +COMMENT ON TABLE "products" IS 'Table note'; + +COMMENT ON COLUMN "products"."id" IS 'First column note'; + +COMMENT ON COLUMN "products"."country" IS 'Another column +multiline note';''' self.assertEqual(t.sql, expected) def test_ref_index(self) -> None: @@ -420,6 +444,27 @@ def test_index_inline(self) -> None: "name" varchar2, PRIMARY KEY ("id", "name") ); +''' + self.assertEqual(t.sql, expected) + + def test_index_inline_and_comments(self) -> None: + t = Table('products', comment='Multiline\ntable comment') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + i = Index(['id', 'name'], pk=True, comment='Multiline\nindex comment') + t.add_index(i) + expected = \ +'''-- Multiline +-- table comment +CREATE TABLE "products" ( + "id" integer, + "name" varchar2, + -- Multiline + -- index comment + PRIMARY KEY ("id", "name") +); ''' self.assertEqual(t.sql, expected) @@ -514,6 +559,11 @@ def test_dbml_simple(self): expected = '"en-US"' self.assertEqual(ei.dbml, expected) + def test_sql(self): + ei = EnumItem('en-US') + expected = "'en-US'," + self.assertEqual(ei.sql, expected) + def test_dbml_full(self): ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') expected = \ @@ -540,20 +590,23 @@ def test_simple_enum(self) -> None: );''' self.assertEqual(e.sql, expected) - def test_notes(self) -> None: - n = Note('EnumItem note') + def test_comments(self) -> None: items = [ - EnumItem('created', note=n), + EnumItem('created', comment='EnumItem comment'), EnumItem('running'), - EnumItem('donef', note=n), + EnumItem('donef', comment='EnumItem\nmultiline comment'), EnumItem('failure'), ] - e = Enum('job_status', items) + e = Enum('job_status', items, comment='Enum comment') expected = \ -'''CREATE TYPE "job_status" AS ENUM ( - 'created', -- EnumItem note +'''-- Enum comment +CREATE TYPE "job_status" AS ENUM ( + -- EnumItem comment + 'created', 'running', - 'donef', -- EnumItem note + -- EnumItem + -- multiline comment + 'donef', 'failure', );''' self.assertEqual(e.sql, expected) From 37a2b9d164ddc8298febfdb8c79865bc68747612 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 12:18:39 +0300 Subject: [PATCH 013/125] consistent line breaks in SQL --- README.md | 2 ++ pydbml/classes.py | 4 ++-- pydbml/parser.py | 2 +- test/test_classes.py | 15 +++++---------- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 7c384b4..31e62d7 100644 --- a/README.md +++ b/README.md @@ -82,10 +82,12 @@ CREATE TYPE "orders_status" AS ENUM ( 'done', 'failure', ); + CREATE TYPE "product status" AS ENUM ( 'Out of Stock', 'In Stock', ); + CREATE TABLE "orders" ( "id" int PRIMARY KEY AUTOINCREMENT, "user_id" int UNIQUE NOT NULL, diff --git a/pydbml/classes.py b/pydbml/classes.py index 04e3fcc..e596a71 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -763,8 +763,8 @@ def sql(self): body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) body.extend(indent(r.sql, 2) for r in self.refs) components.append(',\n'.join(body)) - components.append(');\n') - components.extend(i.sql + '\n' for i in self.indexes if not i.pk) + components.append(');') + components.extend('\n' + i.sql for i in self.indexes if not i.pk) result = comment_to_sql(self.comment) if self.comment else '' result += '\n'.join(components) diff --git a/pydbml/parser.py b/pydbml/parser.py index 9310055..d4c64f1 100644 --- a/pydbml/parser.py +++ b/pydbml/parser.py @@ -267,7 +267,7 @@ def sql(self): '''Returs SQL of the parsed results''' components = (i.sql for i in (*self.enums, *self.tables)) - return '\n'.join(components) + return '\n\n'.join(components) @property def dbml(self): diff --git a/test/test_classes.py b/test/test_classes.py index 3d554b3..818dd1f 100644 --- a/test/test_classes.py +++ b/test/test_classes.py @@ -340,7 +340,7 @@ def test_one_column(self) -> None: t = Table('products') c = Column('id', 'integer') t.add_column(c) - expected = 'CREATE TABLE "products" (\n "id" integer\n);\n' + expected = 'CREATE TABLE "products" (\n "id" integer\n);' self.assertEqual(t.sql, expected) def test_ref(self) -> None: @@ -359,8 +359,7 @@ def test_ref(self) -> None: "id" integer, "name" varchar2, FOREIGN KEY ("name") REFERENCES "names" ("name_val") -); -''' +);''' self.assertEqual(t.sql, expected) def test_duplicate_ref(self) -> None: @@ -397,7 +396,6 @@ def test_notes(self) -> None: "country" varchar ); - COMMENT ON TABLE "products" IS 'Table note'; COMMENT ON COLUMN "products"."id" IS 'First column note'; @@ -426,8 +424,7 @@ def test_ref_index(self) -> None: FOREIGN KEY ("name") REFERENCES "names" ("name_val") ); -CREATE INDEX ON "products" ("id", "name"); -''' +CREATE INDEX ON "products" ("id", "name");''' self.assertEqual(t.sql, expected) def test_index_inline(self) -> None: @@ -443,8 +440,7 @@ def test_index_inline(self) -> None: "id" integer, "name" varchar2, PRIMARY KEY ("id", "name") -); -''' +);''' self.assertEqual(t.sql, expected) def test_index_inline_and_comments(self) -> None: @@ -464,8 +460,7 @@ def test_index_inline_and_comments(self) -> None: -- Multiline -- index comment PRIMARY KEY ("id", "name") -); -''' +);''' self.assertEqual(t.sql, expected) def test_add_column(self) -> None: From 2589c92a98c6db5897484ba6feae948a87125e92 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 12:23:49 +0300 Subject: [PATCH 014/125] update readme --- README.md | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 31e62d7..13d573a 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ Other meaningful attributes are: * **table_groups** — list of all table groups, * **project** — the Project object, if was defined. -Finally, you can get the SQL for your DBML schema by accessing `sql` property: +You can get the SQL for your DBML schema by accessing `sql` property: ```python >>> print(parsed.sql) # doctest:+ELLIPSIS @@ -98,6 +98,34 @@ CREATE TABLE "orders" ( ``` +Finally, you can generate the DBML source from your schema with updated values from the classes: + +```python +>>> parsed.project.items['author'] = 'John Doe' +>>> print(parsed.dbml) # doctest:+ELLIPSIS +Project test_schema { + author: 'John Doe' + Note { + 'This schema is used for PyDBML doctest' + } +} + +Table "orders" { + "id" int [pk, increment] + "user_id" int [unique, not null] + "status" orders_status + "created_at" varchar +} + +Table "order_items" { + "order_id" int + "product_id" int + "quantity" int [default: 1] +} +... + +``` + # Docs ## Table class From 0832dfd19755c019b3846fab21242d4511b8e298 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 12:27:20 +0300 Subject: [PATCH 015/125] update changelog and bump version --- changelog.md | 5 ++++- setup.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/changelog.md b/changelog.md index 2e5bd43..13b3f6e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,9 @@ # 0.4.1 -- Reworked `__repr__` and `__str__` methods on all classes. +- Reworked `__repr__` and `__str__` methods on all classes. They are now much simplier and more readable. +- Comments on classes are now rendered as SQL comments in `sql` property (previously notes were rendered as comments on some classes). +- Notes on `Table` and `Column` classes are rendered as SQL comments in `sql` property: `COMMENT ON TABLE "x" is 'y'`. +- New: `dbml` property on most classes and on parsed results which returns the DBML code. - Fix: sql for Reference and TableReference. # 0.4.0 diff --git a/setup.py b/setup.py index b5b9989..72635f1 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='0.4.0', + version='0.4.1', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 4e2bec69232a6d751d616225e717e90901cd6ca9 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 13 Jul 2021 12:28:46 +0300 Subject: [PATCH 016/125] update readme and todo --- README.md | 2 +- TODO.md | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 13d573a..2c745e5 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ CREATE TABLE "orders" ( ``` -Finally, you can generate the DBML source from your schema with updated values from the classes: +Finally, you can generate the DBML source from your schema with updated values from the classes (added in **0.4.1**): ```python >>> parsed.project.items['author'] = 'John Doe' diff --git a/TODO.md b/TODO.md index a53be08..f99e053 100644 --- a/TODO.md +++ b/TODO.md @@ -1,5 +1 @@ -- Docs for converting to dbml -- Notes should be converted to COMMENT ON -- comments should be converted to comments - Creating dbml schema in python -- Docs for creating dbml schema in python From 1db1af4d600378e434fc7bc04ac60a949c9a564b Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 30 Jan 2022 11:37:56 +0100 Subject: [PATCH 017/125] fix subject names editing, fix table getitem after editing --- pydbml/classes.py | 40 +++++++++++------- test/test_data/editing.dbml | 35 ++++++++++++++++ test/test_editing.py | 83 +++++++++++++++++++++++++++++++++++++ test/test_index.py | 18 ++++---- 4 files changed, 153 insertions(+), 23 deletions(-) create mode 100644 test/test_data/editing.dbml create mode 100644 test/test_editing.py diff --git a/pydbml/classes.py b/pydbml/classes.py index e596a71..6bdb6a6 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -17,7 +17,6 @@ from .tools import note_option_to_dbml - class SQLOjbect: ''' Base class for all SQL objects. @@ -546,7 +545,7 @@ def __init__(self, pk: bool = False, note: Optional[Note] = None, comment: Optional[str] = None): - self.subject_names = subject_names + self._subject_names = subject_names self.subjects: List[Union[Column, str]] = [] self.name = name if name else None @@ -557,6 +556,16 @@ def __init__(self, self.note = Note(note) self.comment = comment + @property + def subject_names(self): + ''' + For backward compatibility. Returns updated list of subject names. + ''' + if self.subjects: + return [s.name if isinstance(s, Column) else s for s in self.subjects] + else: + return self._subject_names + def __repr__(self): ''' >>> Index(['name', 'type']) @@ -569,7 +578,6 @@ def __repr__(self): table_name = self.table.name if self.table else None return f"" - def __str__(self): ''' >>> print(Index(['name', 'type'])) @@ -580,7 +588,7 @@ def __str__(self): ''' table_name = self.table.name if self.table else '' - subjects = ', '.join(s for s in self.subject_names) + subjects = ', '.join(self.subject_names) return f"Index({table_name}[{subjects}])" @property @@ -628,10 +636,12 @@ def subject_to_str(val: str) -> str: result = comment_to_dbml(self.comment) if self.comment else '' - if len(self.subject_names) > 1: - result += f'({", ".join(subject_to_str(sn) for sn in self.subject_names)})' + subject_names = self.subject_names + + if len(subject_names) > 1: + result += f'({", ".join(subject_to_str(sn) for sn in subject_names)})' else: - result += subject_to_str(self.subject_names[0]) + result += subject_to_str(subject_names[0]) options = [] if self.name: @@ -665,7 +675,6 @@ def __init__(self, self.name = name self.columns: List[Column] = [] self.indexes: List[Index] = [] - self.column_dict: Dict[str, Column] = {} self.alias = alias if alias else None self.note = Note(note) self.header_color = header_color @@ -679,14 +688,13 @@ def add_column(self, c: Column) -> None: ''' c.table = self self.columns.append(c) - self.column_dict[c.name] = c def add_index(self, i: Index) -> None: ''' Adds index to self.indexes attribute and sets in this index the `table` attribute. ''' - for subj in i.subject_names: + for subj in i._subject_names: if subj.startswith('(') and subj.endswith(')'): # subject is an expression, add it as string i.subjects.append(subj) @@ -713,10 +721,16 @@ def __getitem__(self, k: Union[int, str]) -> Column: if isinstance(k, int): return self.columns[k] else: - return self.column_dict[k] + for c in self.columns: + if c.name == k: + return c + raise KeyError(k) def get(self, k, default=None): - return self.column_dict.get(k, default) + try: + return self.__getitem__(k) + except KeyError: + return default def __iter__(self): return iter(self.columns) @@ -844,7 +858,6 @@ def dbml(self): return result - class Enum(SQLOjbect): required_attributes = ('name', 'items') @@ -917,7 +930,6 @@ def dbml(self): return result - class EnumType(Enum): ''' Enum object, intended to be put in the `type` attribute of a column. diff --git a/test/test_data/editing.dbml b/test/test_data/editing.dbml new file mode 100644 index 0000000..2aa1ed9 --- /dev/null +++ b/test/test_data/editing.dbml @@ -0,0 +1,35 @@ +Table "products" { + "id" int [pk] + "name" varchar + "merchant_id" int [not null] + "price" int + "status" "product status" + "created_at" datetime [default: `now()`] + + + Indexes { + (merchant_id, status) [name: "product_status"] + id [type: hash, unique] + } +} + +Enum "product status" { + "Out of Stock" + "In Stock" +} + +Ref:"merchants"."id" < "products"."merchant_id" + + +Table "merchants" { + "id" int [pk] + "merchant_name" varchar + "country_code" int + "created_at" varchar + "admin_id" int +} + +TableGroup g1 { + products + merchants +} diff --git a/test/test_editing.py b/test/test_editing.py new file mode 100644 index 0000000..451039e --- /dev/null +++ b/test/test_editing.py @@ -0,0 +1,83 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pyparsing import ParseException +from pyparsing import ParseSyntaxException +from pyparsing import ParserElement + +from pydbml import PyDBML +from pydbml.definitions.table import alias +from pydbml.definitions.table import header_color +from pydbml.definitions.table import table +from pydbml.definitions.table import table_body +from pydbml.definitions.table import table_settings + + +ParserElement.setDefaultWhitespaceChars(' \t\r') + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class EditingTestCase(TestCase): + def setUp(self): + self.dbml = PyDBML(TEST_DATA_PATH / 'editing.dbml') + + +class TestEditTable(EditingTestCase): + def test_name(self) -> None: + products = self.dbml['products'] + products.name = 'changed_products' + self.assertIn('CREATE TABLE "changed_products"', products.sql) + self.assertIn('CREATE INDEX "product_status" ON "changed_products"', products.sql) + self.assertIn('Table "changed_products"', products.dbml) + + ref = self.dbml.refs[0] + self.assertIn('ALTER TABLE "changed_products"', ref.sql) + self.assertIn('"changed_products"."merchant_id"', ref.dbml) + + index = products.indexes[0] + self.assertIn('ON "changed_products"', index.sql) + + def test_alias(self) -> None: + products = self.dbml['products'] + products.alias = 'new_alias' + + self.assertIn('as "new_alias"', products.dbml) + + +class TestColumn(EditingTestCase): + def test_name(self) -> None: + products = self.dbml['products'] + col = products['name'] + col.name = 'new_name' + self.assertEqual(col.sql, '"new_name" varchar') + self.assertEqual(col.dbml, '"new_name" varchar') + self.assertIn('"new_name" varchar', products.sql) + self.assertIn('"new_name" varchar', products.dbml) + + self.assertEqual(col, products[col.name]) + + def test_name_index(self) -> None: + products = self.dbml['products'] + col = products['status'] + col.name = 'changed_status' + self.assertIn('"changed_status"', products.indexes[0].sql) + self.assertIn('changed_status', products.indexes[0].dbml) + self.assertIn( + 'CREATE INDEX "product_status" ON "products" ("merchant_id", "changed_status");', + products.sql + ) + self.assertIn( + "(merchant_id, changed_status) [name: 'product_status']", + products.dbml + ) + + def test_name_ref(self) -> None: + products = self.dbml['products'] + col = products['merchant_id'] + col.name = 'changed_merchant_id' + table_ref = products.refs[0] + self.assertIn('FOREIGN KEY ("changed_merchant_id")', table_ref.sql) diff --git a/test/test_index.py b/test/test_index.py index f60dda9..2067261 100644 --- a/test/test_index.py +++ b/test/test_index.py @@ -160,51 +160,51 @@ class TestIndex(TestCase): def test_single(self) -> None: val = 'my_column' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column']) + self.assertEqual(res[0]._subject_names, ['my_column']) def test_expression(self) -> None: val = '(`id*3`)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['(id*3)']) + self.assertEqual(res[0]._subject_names, ['(id*3)']) def test_composite(self) -> None: val = '(my_column, my_another_column)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) + self.assertEqual(res[0]._subject_names, ['my_column', 'my_another_column']) def test_composite_with_expression(self) -> None: val = '(`id*3`, fieldname)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['(id*3)', 'fieldname']) + self.assertEqual(res[0]._subject_names, ['(id*3)', 'fieldname']) def test_with_settings(self) -> None: val = '(my_column, my_another_column) [unique]' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) + self.assertEqual(res[0]._subject_names, ['my_column', 'my_another_column']) self.assertTrue(res[0].unique) def test_comment_above(self) -> None: val = '//comment above\nmy_column [unique]' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column']) + self.assertEqual(res[0]._subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment above') def test_comment_after(self) -> None: val = 'my_column [unique] //comment after' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column']) + self.assertEqual(res[0]._subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') val = 'my_column //comment after' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column']) + self.assertEqual(res[0]._subject_names, ['my_column']) self.assertEqual(res[0].comment, 'comment after') def test_both_comments(self) -> None: val = '//comment before\nmy_column [unique] //comment after' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['my_column']) + self.assertEqual(res[0]._subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') From 39850f5584a53c763b1dd4895c9b0b090751eed4 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 30 Jan 2022 17:47:11 +0100 Subject: [PATCH 018/125] fix enum --- changelog.md | 8 ++++++++ pydbml/classes.py | 3 --- pydbml/parser.py | 5 +++-- setup.py | 2 +- test/test_editing.py | 13 +++++++++++++ 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/changelog.md b/changelog.md index 13b3f6e..3f16e76 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,11 @@ +# 0.4.2 + +- Fix: after editing column name index dbml was not updated. +- Fix: enums with spaces in name were not applied. +- Fix: after editing column name table dict was not updated. +- Fix: after editing enum column type was not updated. +- Removed EnumType class. Only Enum is used now. + # 0.4.1 - Reworked `__repr__` and `__str__` methods on all classes. They are now much simplier and more readable. diff --git a/pydbml/classes.py b/pydbml/classes.py index 6bdb6a6..efcf62c 100644 --- a/pydbml/classes.py +++ b/pydbml/classes.py @@ -869,9 +869,6 @@ def __init__(self, self.items = items self.comment = comment - def get_type(self): - return EnumType(self.name, self.items) - def __getitem__(self, key) -> EnumItem: return self.items[key] diff --git a/pydbml/parser.py b/pydbml/parser.py index d4c64f1..412edc1 100644 --- a/pydbml/parser.py +++ b/pydbml/parser.py @@ -240,8 +240,9 @@ def _set_enum_types(self): enum_dict = {enum.name: enum for enum in self.enums} for table_ in self.tables: for col in table_: - if str(col.type) in enum_dict: - col.type = enum_dict[str(col.type)].get_type() + col_type = str(col.type).strip('"') + if col_type in enum_dict: + col.type = enum_dict[col_type] def _validate(self): self._validate_table_groups() diff --git a/setup.py b/setup.py index 72635f1..2497ce6 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='0.4.1', + version='0.4.2', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', diff --git a/test/test_editing.py b/test/test_editing.py index 451039e..a53a509 100644 --- a/test/test_editing.py +++ b/test/test_editing.py @@ -81,3 +81,16 @@ def test_name_ref(self) -> None: col.name = 'changed_merchant_id' table_ref = products.refs[0] self.assertIn('FOREIGN KEY ("changed_merchant_id")', table_ref.sql) + + +class TestEnum(EditingTestCase): + def test_enum_name(self): + products = self.dbml['products'] + enum = self.dbml.enums[0] + enum.name = 'changed product status' + self.assertIn('CREATE TYPE "changed product status"', enum.sql) + self.assertIn('Enum "changed product status"', enum.dbml) + + col = products['status'] + self.assertEqual(col.sql, '"status" changed product status') + self.assertEqual(col.dbml, '"status" changed product status') From 503fbe4f25b1fd8fb0eda9cf4575b81cee63dc58 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 8 May 2022 17:53:24 +0200 Subject: [PATCH 019/125] PyDBML 1.0 WIP: restructure, blueprints for everything, Schema draft, new parser class, tests rework WIP --- TODO.md | 2 + pydbml/__init__.py | 10 +- pydbml/classes.py | 1031 ----------------- pydbml/classes/__init__.py | 9 + pydbml/classes/_classes.py | 161 +++ pydbml/classes/base.py | 37 + pydbml/classes/column.py | 125 ++ pydbml/classes/enum.py | 122 ++ pydbml/classes/index.py | 134 +++ pydbml/classes/note.py | 62 + pydbml/classes/project.py | 45 + pydbml/classes/reference.py | 188 +++ pydbml/classes/table.py | 241 ++++ pydbml/classes/table_group.py | 59 + pydbml/definitions/column.py | 8 +- pydbml/definitions/common.py | 6 +- pydbml/definitions/enum.py | 8 +- pydbml/definitions/index.py | 6 +- pydbml/definitions/project.py | 8 +- pydbml/definitions/reference.py | 8 +- pydbml/definitions/table.py | 12 +- pydbml/definitions/table_group.py | 4 +- pydbml/exceptions.py | 11 + pydbml/parser.py | 281 ----- pydbml/parser/__init__.py | 1 + pydbml/parser/blueprints.py | 247 ++++ pydbml/parser/parser.py | 248 ++++ pydbml/schema.py | 211 ++++ pydbml/tools.py | 3 +- setup.py | 2 +- test/{test_classes.py => _test_classes.py} | 84 -- test/_test_parser.py | 96 ++ test/test_blueprints/test_column.py | 43 + test/test_blueprints/test_enum.py | 50 + test/test_blueprints/test_index.py | 37 + test/test_blueprints/test_note.py | 12 + test/test_blueprints/test_project.py | 36 + test/test_blueprints/test_reference.py | 68 ++ test/test_blueprints/test_table.py | 125 ++ test/test_blueprints/test_table_group.py | 26 + test/test_classes/test_base.py | 33 + test/test_classes/test_column.py | 120 ++ test/test_classes/test_enum.py | 90 ++ test/test_classes/test_index.py | 105 ++ test/test_classes/test_table.py | 220 ++++ test/test_create_schema.py | 88 ++ test/test_definitions/__init__.py | 0 test/{ => test_definitions}/test_column.py | 0 test/{ => test_definitions}/test_common.py | 0 test/{ => test_definitions}/test_enum.py | 0 test/{ => test_definitions}/test_index.py | 22 +- test/{ => test_definitions}/test_project.py | 0 test/{ => test_definitions}/test_reference.py | 0 test/{ => test_definitions}/test_table.py | 0 .../test_table_group.py | 0 test/test_parser.py | 88 +- 56 files changed, 3109 insertions(+), 1524 deletions(-) delete mode 100644 pydbml/classes.py create mode 100644 pydbml/classes/__init__.py create mode 100644 pydbml/classes/_classes.py create mode 100644 pydbml/classes/base.py create mode 100644 pydbml/classes/column.py create mode 100644 pydbml/classes/enum.py create mode 100644 pydbml/classes/index.py create mode 100644 pydbml/classes/note.py create mode 100644 pydbml/classes/project.py create mode 100644 pydbml/classes/reference.py create mode 100644 pydbml/classes/table.py create mode 100644 pydbml/classes/table_group.py delete mode 100644 pydbml/parser.py create mode 100644 pydbml/parser/__init__.py create mode 100644 pydbml/parser/blueprints.py create mode 100644 pydbml/parser/parser.py create mode 100644 pydbml/schema.py rename test/{test_classes.py => _test_classes.py} (92%) create mode 100644 test/_test_parser.py create mode 100644 test/test_blueprints/test_column.py create mode 100644 test/test_blueprints/test_enum.py create mode 100644 test/test_blueprints/test_index.py create mode 100644 test/test_blueprints/test_note.py create mode 100644 test/test_blueprints/test_project.py create mode 100644 test/test_blueprints/test_reference.py create mode 100644 test/test_blueprints/test_table.py create mode 100644 test/test_blueprints/test_table_group.py create mode 100644 test/test_classes/test_base.py create mode 100644 test/test_classes/test_column.py create mode 100644 test/test_classes/test_enum.py create mode 100644 test/test_classes/test_index.py create mode 100644 test/test_classes/test_table.py create mode 100644 test/test_create_schema.py create mode 100644 test/test_definitions/__init__.py rename test/{ => test_definitions}/test_column.py (100%) rename test/{ => test_definitions}/test_common.py (100%) rename test/{ => test_definitions}/test_enum.py (100%) rename test/{ => test_definitions}/test_index.py (91%) rename test/{ => test_definitions}/test_project.py (100%) rename test/{ => test_definitions}/test_reference.py (100%) rename test/{ => test_definitions}/test_table.py (100%) rename test/{ => test_definitions}/test_table_group.py (100%) diff --git a/TODO.md b/TODO.md index f99e053..d95bd48 100644 --- a/TODO.md +++ b/TODO.md @@ -1 +1,3 @@ - Creating dbml schema in python +- pyparsing new var names (+possibly new features) +- enum type diff --git a/pydbml/__init__.py b/pydbml/__init__.py index 879ea59..2be8095 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,9 +1,13 @@ -from pydbml.parser import PyDBML, PyDBMLParseResults -import unittest import doctest +import unittest + from . import classes +from pydbml.parser import PyDBML +from pydbml.parser.blueprints import MANY_TO_ONE +from pydbml.parser.blueprints import ONE_TO_MANY +from pydbml.parser.blueprints import ONE_TO_ONE def load_tests(loader, tests, ignore): tests.addTests(doctest.DocTestSuite(classes)) - return tests \ No newline at end of file + return tests diff --git a/pydbml/classes.py b/pydbml/classes.py deleted file mode 100644 index efcf62c..0000000 --- a/pydbml/classes.py +++ /dev/null @@ -1,1031 +0,0 @@ -from __future__ import annotations - -from typing import Any -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -from .exceptions import AttributeMissingError -from .exceptions import ColumnNotFoundError -from .exceptions import DuplicateReferenceError -from .tools import comment_to_dbml -from .tools import comment_to_sql -from .tools import indent -from .tools import note_option_to_dbml - - -class SQLOjbect: - ''' - Base class for all SQL objects. - ''' - required_attributes: Tuple[str, ...] = () - - def check_attributes_for_sql(self): - ''' - Check if all attributes, required for rendering SQL are set in the - instance. If some attribute is missing, raise AttributeMissingError - ''' - for attr in self.required_attributes: - if getattr(self, attr) is None: - raise AttributeMissingError( - f'Cannot render SQL. Missing required attribute "{attr}".' - ) - - def __setattr__(self, name: str, value: Any): - """ - Required for type testing with MyPy. - """ - super().__setattr__(name, value) - - def __eq__(self, other: object) -> bool: - """ - Two instances of the same SQLObject subclass are equal if all their - attributes are equal. - """ - - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return False - - -class ReferenceBlueprint: - ''' - Intermediate class for references during parsing. Table and columns are just - strings at this point, as we can't check their validity until all schema - is parsed. - - Note: `table2` and `col2` params are technically required (left optional for aesthetics). - ''' - ONE_TO_MANY = '<' - MANY_TO_ONE = '>' - ONE_TO_ONE = '-' - - def __init__(self, - type_: str, - name: Optional[str] = None, - table1: Optional[str] = None, - col1: Optional[Union[str, Collection[str]]] = None, - table2: Optional[str] = None, - col2: Optional[Union[str, Collection[str]]] = None, - comment: Optional[str] = None, - on_update: Optional[str] = None, - on_delete: Optional[str] = None): - self.type = type_ - self.name = name if name else None - self.table1 = table1 if table1 else None - self.col1 = col1 if col1 else None - self.table2 = table2 if table2 else None - self.col2 = col2 if col2 else None - self.comment = comment - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> ReferenceBlueprint('>', table1='t1', col1='c1', table2='t2', col2='c2') - ', 't1'.'c1', 't2'.'c2'> - >>> ReferenceBlueprint('<', table2='t2', col2='c2') - - >>> ReferenceBlueprint('>', table1='t1', col1=('c11', 'c12'), table2='t2', col2=['c21', 'c22']) - ', 't1'.('c11', 'c12'), 't2'.['c21', 'c22']> - ''' - - components = [f"' - - def __str__(self): - ''' - >>> r1 = ReferenceBlueprint('>', table1='t1', col1='c1', table2='t2', col2='c2') - >>> r2 = ReferenceBlueprint('<', table2='t2', col2='c2') - >>> r3 = ReferenceBlueprint('>', table1='t1', col1=('c11', 'c12'), table2='t2', col2=('c21', 'c22')) - >>> print(r1, r2) - ReferenceBlueprint(t1.c1 > t2.c2) ReferenceBlueprint(< t2.c2) - >>> print(r3) - ReferenceBlueprint(t1[c11, c12] > t2[c21, c22]) - ''' - - components = [f"ReferenceBlueprint("] - if self.table1: - components.append(self.table1) - if self.col1: - if isinstance(self.col1, str): - components.append(f'.{self.col1} ') - else: # list or tuple - components.append(f'[{", ".join(self.col1)}] ') - components.append(f'{self.type} ') - components.append(self.table2) - if isinstance(self.col2, str): - components.append(f'.{self.col2}') - else: # list or tuple - components.append(f'[{", ".join(self.col2)}]') - return ''.join(components) + ')' - - -class Reference(SQLOjbect): - ''' - Class, representing a foreign key constraint. - It is a separate object, which is not connected to Table or Column objects - and its `sql` property contains the ALTER TABLE clause. - ''' - required_attributes = ('type', 'table1', 'col1', 'table2', 'col2') - - ONE_TO_MANY = '<' - MANY_TO_ONE = '>' - ONE_TO_ONE = '-' - - def __init__(self, - type_: str, - table1: Table, - col1: Union[Column, Collection[Column]], - table2: Table, - col2: Union[Column, Collection[Column]], - name: Optional[str] = None, - comment: Optional[str] = None, - on_update: Optional[str] = None, - on_delete: Optional[str] = None): - self.type = type_ - self.table1 = table1 - self.col1 = [col1] if isinstance(col1, Column) else list(col1) - self.table2 = table2 - self.col2 = [col2] if isinstance(col2, Column) else list(col2) - self.name = name if name else None - self.comment = comment - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> Reference('>', table1=t1, col1=c1, table2=t2, col2=c2) - ', 't1'.['c1'], 't2'.['c2']> - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22)) - - ''' - - components = [f"' - - def __str__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> print(Reference('>', table1=t1, col1=c1, table2=t2, col2=c2)) - Reference(t1[c1] > t2[c2]) - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> print(Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22))) - Reference(t1[c1, c12] < t2[c2, c22]) - ''' - - components = [f"Reference("] - components.append(self.table1.name) - components.append(f'[{", ".join(c.name for c in self.col1)}]') - components.append(f' {self.type} ') - components.append(self.table2.name) - components.append(f'[{", ".join(c.name for c in self.col2)}]') - return ''.join(components) + ')' - - @property - def sql(self): - ''' - Returns SQL of the reference: - - ALTER TABLE "orders" ADD FOREIGN KEY ("customer_id") REFERENCES "customers ("id"); - - ''' - self.check_attributes_for_sql() - c = f'CONSTRAINT "{self.name}" ' if self.name else '' - - if self.type in (self.MANY_TO_ONE, self.ONE_TO_ONE): - t1 = self.table1 - c1 = ', '.join(f'"{c.name}"' for c in self.col1) - t2 = self.table2 - c2 = ', '.join(f'"{c.name}"' for c in self.col2) - else: - t1 = self.table2 - c1 = ', '.join(f'"{c.name}"' for c in self.col2) - t2 = self.table1 - c2 = ', '.join(f'"{c.name}"' for c in self.col1) - - result = comment_to_sql(self.comment) if self.comment else '' - result += ( - f'ALTER TABLE "{t1.name}" ADD {c}FOREIGN KEY ({c1}) ' - f'REFERENCES "{t2.name}" ({c2})' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result + ';' - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += 'Ref' - if self.name: - result += f' {self.name}' - - if len(self.col1) == 1: - col1 = f'"{self.col1[0].name}"' - else: - names = (f'"{c.name}"' for c in self.col1) - col1 = f'({", ".join(names)})' - - if len(self.col2) == 1: - col2 = f'"{self.col2[0].name}"' - else: - names = (f'"{c.name}"' for c in self.col2) - col2 = f'({", ".join(names)})' - - options = [] - if self.on_update: - options.append(f'update: {self.on_update}') - if self.on_delete: - options.append(f'delete: {self.on_delete}') - - options_str = f' [{", ".join(options)}]' if options else '' - result += ( - ' {\n ' - f'"{self.table1.name}".{col1} ' - f'{self.type} ' - f'"{self.table2.name}".{col2}' - f'{options_str}' - '\n}' - ) - return result - - -class TableReference(SQLOjbect): - ''' - Class, representing a foreign key constraint. - This object should be assigned to the `refs` attribute of a Table object. - Its `sql` property contains the inline definition of the FOREIGN KEY clause. - ''' - required_attributes = ('col', 'ref_table', 'ref_col') - - def __init__(self, - col: Union[Column, List[Column]], - ref_table: Table, - ref_col: Union[Column, List[Column]], - name: Optional[str] = None, - on_delete: Optional[str] = None, - on_update: Optional[str] = None): - self.col = [col] if isinstance(col, Column) else list(col) - self.ref_table = ref_table - self.ref_col = [ref_col] if isinstance(ref_col, Column) else list(ref_col) - self.name = name - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t2 = Table('t2') - >>> TableReference(col=c1, ref_table=t2, ref_col=c2) - - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22)) - - ''' - - col_names = [c.name for c in self.col] - ref_col_names = [c.name for c in self.ref_col] - return f"" - - def __str__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t2 = Table('t2') - >>> print(TableReference(col=c1, ref_table=t2, ref_col=c2)) - TableReference([c1] > t2[c2]) - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> print(TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22))) - TableReference([c1, c12] > t2[c2, c22]) - ''' - - components = [f"TableReference("] - components.append(f'[{", ".join(c.name for c in self.col)}]') - components.append(' > ') - components.append(self.ref_table.name) - components.append(f'[{", ".join(c.name for c in self.ref_col)}]') - return ''.join(components) + ')' - - @property - def sql(self): - ''' - Returns inline SQL of the reference, which should be a part of table definition: - - FOREIGN KEY ("order_id") REFERENCES "orders ("id") - - ''' - self.check_attributes_for_sql() - c = f'CONSTRAINT "{self.name}" ' if self.name else '' - cols = '", "'.join(c.name for c in self.col) - ref_cols = '", "'.join(c.name for c in self.ref_col) - result = ( - f'{c}FOREIGN KEY ("{cols}") ' - f'REFERENCES "{self.ref_table.name}" ("{ref_cols}")' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result - - -class Note: - def __init__(self, text: Any): - self.text = str(text) if text else '' - - def __str__(self): - ''' - >>> print(Note('Note text')) - Note text - ''' - - return self.text - - def __bool__(self): - return bool(self.text) - - def __repr__(self): - ''' - >>> Note('Note text') - Note('Note text') - ''' - - return f'Note({repr(self.text)})' - - @property - def sql(self): - if self.text: - return '\n'.join(f'-- {line}' for line in self.text.split('\n')) - else: - return '' - - @property - def dbml(self): - lines = [] - line = '' - for word in self.text.split(' '): - if len(line) > 80: - lines.append(line) - line = '' - if '\n' in word: - sublines = word.split('\n') - for sl in sublines[:-1]: - line += sl - lines.append(line) - line = '' - line = sublines[-1] + ' ' - else: - line += f'{word} ' - if line: - lines.append(line) - result = 'Note {\n ' - - if len(lines) > 1: - lines_str = '\n '.join(lines)[:-1] + '\n' - result += f"'''\n {lines_str} '''" - else: - result += f"'{lines[0][:-1]}'" - - result += '\n}' - return result - - -class Column(SQLOjbect): - '''Class representing table column.''' - - required_attributes = ('name', 'type') - - def __init__(self, - name: str, - type_: str, - unique: bool = False, - not_null: bool = False, - pk: bool = False, - autoinc: bool = False, - default: Optional[Union[str, int, bool, float]] = None, - note: Optional[Note] = None, - ref_blueprints: Optional[List[ReferenceBlueprint]] = None, - comment: Optional[str] = None): - self.name = name - self.type = type_ - self.unique = unique - self.not_null = not_null - self.pk = pk - self.autoinc = autoinc - self.comment = comment - - self.default = default - - self.note = Note(note) - self.ref_blueprints = ref_blueprints or [] - for ref in self.ref_blueprints: - ref.col1 = self.name - - self._table: Optional[Table] = None - - @property - def table(self) -> Optional[Table]: - return self._table - - @table.setter - def table(self, v: Table): - self._table = v - for ref in self.ref_blueprints: - ref.table1 = v.name - - @property - def sql(self): - ''' - Returns inline SQL of the column, which should be a part of table definition: - - "id" integer PRIMARY KEY AUTOINCREMENT - ''' - - self.check_attributes_for_sql() - components = [f'"{self.name}"', str(self.type)] - if self.pk: - components.append('PRIMARY KEY') - if self.autoinc: - components.append('AUTOINCREMENT') - if self.unique: - components.append('UNIQUE') - if self.not_null: - components.append('NOT NULL') - if self.default is not None: - components.append('DEFAULT ' + str(self.default)) - - result = comment_to_sql(self.comment) if self.comment else '' - result += ' '.join(components) - return result - - @property - def dbml(self): - def default_to_str(val: str) -> str: - if isinstance(val, str): - if val.lower() in ('null', 'true', 'false'): - return val.lower() - elif val.startswith('(') and val.endswith(')'): - return f'`{val[1:-1]}`' - else: - return f"'{val}'" - else: # int or float or bool - return val - - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'"{self.name}" {self.type}' - - options = [] - if self.pk: - options.append('pk') - if self.autoinc: - options.append('increment') - if self.default: - options.append(f'default: {default_to_str(self.default)}') - if self.unique: - options.append('unique') - if self.not_null: - options.append('not null') - if self.note: - options.append(note_option_to_dbml(self.note)) - - if options: - result += f' [{", ".join(options)}]' - return result - - def __repr__(self): - ''' - >>> Column('name', 'VARCHAR2') - - ''' - type_name = self.type if isinstance(self.type, str) else self.type.name - return f'' - - def __str__(self): - ''' - >>> print(Column('name', 'VARCHAR2')) - name[VARCHAR2] - ''' - - return f'{self.name}[{self.type}]' - - -class Index(SQLOjbect): - '''Class representing index.''' - required_attributes = ('subjects', 'table') - - def __init__(self, - subject_names: List[str], - name: Optional[str] = None, - table: Optional[Table] = None, - unique: bool = False, - type_: Optional[str] = None, - pk: bool = False, - note: Optional[Note] = None, - comment: Optional[str] = None): - self._subject_names = subject_names - self.subjects: List[Union[Column, str]] = [] - - self.name = name if name else None - self.table = table - self.unique = unique - self.type = type_ - self.pk = pk - self.note = Note(note) - self.comment = comment - - @property - def subject_names(self): - ''' - For backward compatibility. Returns updated list of subject names. - ''' - if self.subjects: - return [s.name if isinstance(s, Column) else s for s in self.subjects] - else: - return self._subject_names - - def __repr__(self): - ''' - >>> Index(['name', 'type']) - - >>> t = Table('t') - >>> Index(['name', 'type'], table=t) - - ''' - - table_name = self.table.name if self.table else None - return f"" - - def __str__(self): - ''' - >>> print(Index(['name', 'type'])) - Index([name, type]) - >>> t = Table('t') - >>> print(Index(['name', 'type'], table=t)) - Index(t[name, type]) - ''' - - table_name = self.table.name if self.table else '' - subjects = ', '.join(self.subject_names) - return f"Index({table_name}[{subjects}])" - - @property - def sql(self): - ''' - Returns inline SQL of the index to be created separately from table - definition: - - CREATE UNIQUE INDEX ON "products" USING HASH ("id"); - - But if it's a (composite) primary key index, returns an inline SQL for - composite primary key to be used inside table definition: - - PRIMARY KEY ("id", "name") - - ''' - self.check_attributes_for_sql() - keys = ', '.join(f'"{key.name}"' if isinstance(key, Column) else key for key in self.subjects) - if self.pk: - result = comment_to_sql(self.comment) if self.comment else '' - result += f'PRIMARY KEY ({keys})' - return result - - components = ['CREATE'] - if self.unique: - components.append('UNIQUE') - components.append('INDEX') - if self.name: - components.append(f'"{self.name}"') - components.append(f'ON "{self.table.name}"') - if self.type: - components.append(f'USING {self.type.upper()}') - components.append(f'({keys})') - result = comment_to_sql(self.comment) if self.comment else '' - result += ' '.join(components) + ';' - return result - - @property - def dbml(self): - def subject_to_str(val: str) -> str: - if val.startswith('(') and val.endswith(')'): - return f'`{val[1:-1]}`' - else: - return val - - result = comment_to_dbml(self.comment) if self.comment else '' - - subject_names = self.subject_names - - if len(subject_names) > 1: - result += f'({", ".join(subject_to_str(sn) for sn in subject_names)})' - else: - result += subject_to_str(subject_names[0]) - - options = [] - if self.name: - options.append(f"name: '{self.name}'") - if self.pk: - options.append('pk') - if self.unique: - options.append('unique') - if self.type: - options.append(f'type: {self.type}') - if self.note: - options.append(note_option_to_dbml(self.note)) - - if options: - result += f' [{", ".join(options)}]' - return result - - -class Table(SQLOjbect): - '''Class representing table.''' - - required_attributes = ('name',) - - def __init__(self, - name: str, - alias: Optional[str] = None, - note: Optional[Note] = None, - header_color: Optional[str] = None, - refs: Optional[List[TableReference]] = None, - comment: Optional[str] = None): - self.name = name - self.columns: List[Column] = [] - self.indexes: List[Index] = [] - self.alias = alias if alias else None - self.note = Note(note) - self.header_color = header_color - self.refs = refs or [] - self.comment = comment - - def add_column(self, c: Column) -> None: - ''' - Adds column to self.columns attribute and sets in this column the - `table` attribute. - ''' - c.table = self - self.columns.append(c) - - def add_index(self, i: Index) -> None: - ''' - Adds index to self.indexes attribute and sets in this index the - `table` attribute. - ''' - for subj in i._subject_names: - if subj.startswith('(') and subj.endswith(')'): - # subject is an expression, add it as string - i.subjects.append(subj) - else: - try: - col = self[subj] - i.subjects.append(col) - except KeyError: - raise ColumnNotFoundError(f'Cannot add index, column "{subj}" not defined in table "{self.name}".') - - i.table = self - self.indexes.append(i) - - def add_ref(self, r: TableReference) -> None: - ''' - Adds a reference to the table. If reference already present in the table, - raises DuplicateReferenceError. - ''' - if r in self.refs: - raise DuplicateReferenceError(f'Reference with same endpoints {r} already present in the table.') - self.refs.append(r) - - def __getitem__(self, k: Union[int, str]) -> Column: - if isinstance(k, int): - return self.columns[k] - else: - for c in self.columns: - if c.name == k: - return c - raise KeyError(k) - - def get(self, k, default=None): - try: - return self.__getitem__(k) - except KeyError: - return default - - def __iter__(self): - return iter(self.columns) - - def __repr__(self): - ''' - >>> table = Table('customers') - >>> table - - ''' - - return f'
' - - def __str__(self): - ''' - >>> table = Table('customers') - >>> table.add_column(Column('id', 'INTEGER')) - >>> table.add_column(Column('name', 'VARCHAR2')) - >>> print(table) - customers(id, name) - ''' - - return f'{self.name}({", ".join(c.name for c in self.columns)})' - - @property - def sql(self): - ''' - Returns full SQL for table definition: - - CREATE TABLE "countries" ( - "code" int PRIMARY KEY, - "name" varchar, - "continent_name" varchar - ); - - Also returns indexes if they were defined: - - CREATE INDEX ON "products" ("id", "name"); - ''' - self.check_attributes_for_sql() - components = [f'CREATE TABLE "{self.name}" ('] - body = [] - body.extend(indent(c.sql, 2) for c in self.columns) - body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) - body.extend(indent(r.sql, 2) for r in self.refs) - components.append(',\n'.join(body)) - components.append(');') - components.extend('\n' + i.sql for i in self.indexes if not i.pk) - - result = comment_to_sql(self.comment) if self.comment else '' - result += '\n'.join(components) - - if self.note: - quoted_note = f"'{self.note.text}'" - note_sql = f'COMMENT ON TABLE "{self.name}" IS {quoted_note};' - result += f'\n\n{note_sql}' - - for col in self.columns: - if col.note: - quoted_note = f"'{col.note.text}'" - note_sql = f'COMMENT ON COLUMN "{self.name}"."{col.name}" IS {quoted_note};' - result += f'\n\n{note_sql}' - return result - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Table "{self.name}" ' - if self.alias: - result += f'as "{self.alias}" ' - result += '{\n' - columns_str = '\n'.join(c.dbml for c in self.columns) - result += indent(columns_str) + '\n' - if self.note: - result += indent(self.note.dbml) + '\n' - if self.indexes: - result += '\n indexes {\n' - indexes_str = '\n'.join(i.dbml for i in self.indexes) - result += indent(indexes_str, 8) + '\n' - result += ' }\n' - - result += '}' - return result - - -class EnumItem: - '''Single enum item''' - - def __init__(self, - name: str, - note: Optional[Note] = None, - comment: Optional[str] = None): - self.name = name - self.note = Note(note) - self.comment = comment - - def __repr__(self): - ''' - >>> EnumItem('en-US') - - ''' - - return f'' - - def __str__(self): - ''' - >>> print(EnumItem('en-US')) - en-US - ''' - - return self.name - - @property - def sql(self): - result = comment_to_sql(self.comment) if self.comment else '' - result += f"'{self.name}'," - return result - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'"{self.name}"' - if self.note: - result += f' [{note_option_to_dbml(self.note)}]' - return result - - -class Enum(SQLOjbect): - required_attributes = ('name', 'items') - - def __init__(self, - name: str, - items: List[EnumItem], - comment: Optional[str] = None): - self.name = name - self.items = items - self.comment = comment - - def __getitem__(self, key) -> EnumItem: - return self.items[key] - - def __iter__(self): - return iter(self.items) - - def __repr__(self): - ''' - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> Enum('languages', [en, ru]) - - ''' - - item_names = [i.name for i in self.items] - classname = self.__class__.__name__ - return f'<{classname} {self.name!r}, {item_names!r}>' - - def __str__(self): - ''' - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> print(Enum('languages', [en, ru])) - languages - ''' - - return self.name - - @property - def sql(self): - ''' - Returns SQL for enum type: - - CREATE TYPE "job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', - ); - - ''' - self.check_attributes_for_sql() - result = comment_to_sql(self.comment) if self.comment else '' - result += f'CREATE TYPE "{self.name}" AS ENUM (\n' - result +='\n'.join(f'{indent(i.sql, 2)}' for i in self.items) - result += '\n);' - return result - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Enum "{self.name}" {{\n' - items_str = '\n'.join(i.dbml for i in self.items) - result += indent(items_str) - result += '\n}' - return result - - -class EnumType(Enum): - ''' - Enum object, intended to be put in the `type` attribute of a column. - - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> EnumType('languages', [en, ru]) - - >>> print(_) - languages - ''' - - pass - - -class TableGroup: - ''' - TableGroup `items` parameter initially holds just the names of the tables, - but after parsing the whole document, PyDBMLParseResults class replaces - them with references to actual tables. - ''' - - def __init__(self, - name: str, - items: Union[List[str], List[Table]], - comment: Optional[str] = None): - self.name = name - self.items = items - self.comment = comment - - def __repr__(self): - """ - >>> tg = TableGroup('mygroup', ['t1', 't2']) - >>> tg - - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> tg.items = [t1, t2] - >>> tg - - """ - - items = [i if isinstance(i, str) else i.name for i in self.items] - return f'' - - def __getitem__(self, key) -> str: - return self.items[key] - - def __iter__(self): - return iter(self.items) - - @property - def dbml(self): - def item_to_str(val: Union[str, Table]) -> str: - if isinstance(val, Table): - return val.name - else: - return val - - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'TableGroup {self.name} {{\n' - for i in self.items: - result += f' {item_to_str(i)}\n' - result += '}' - return result - - -class Project: - def __init__(self, - name: str, - items: Optional[Dict[str, str]] = None, - note: Optional[Note] = None, - comment: Optional[str] = None): - self.name = name - self.items = items - self.note = Note(note) - self.comment = comment - - def __repr__(self): - """ - >>> Project('myproject') - - """ - - return f'' - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Project {self.name} {{\n' - if self.items: - items_str = '' - for k, v in self.items.items(): - if '\n' in v: - items_str += f"{k}: '''{v}'''\n" - else: - items_str += f"{k}: '{v}'\n" - result += indent(items_str[:-1]) + '\n' - if self.note: - result += indent(self.note.dbml) + '\n' - result += '}' - return result diff --git a/pydbml/classes/__init__.py b/pydbml/classes/__init__.py new file mode 100644 index 0000000..4ff08f2 --- /dev/null +++ b/pydbml/classes/__init__.py @@ -0,0 +1,9 @@ +from .column import Column +from .table import Table +from .enum import Enum +from .enum import EnumItem +from .index import Index +from .note import Note +from .project import Project +from .reference import Reference +from .table_group import TableGroup diff --git a/pydbml/classes/_classes.py b/pydbml/classes/_classes.py new file mode 100644 index 0000000..dee3485 --- /dev/null +++ b/pydbml/classes/_classes.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import Any +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from pydbml.parser.blueprints import IndexBlueprint +from pydbml.parser.blueprints import ReferenceBlueprint +from pydbml.exceptions import AttributeMissingError +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import IndexNotFoundError +from pydbml.exceptions import UnknownSchemaError +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.tools import indent +from pydbml.tools import note_option_to_dbml + + +class SQLOjbect: + ''' + Base class for all SQL objects. + ''' + required_attributes: Tuple[str, ...] = () + + def check_attributes_for_sql(self): + ''' + Check if all attributes, required for rendering SQL are set in the + instance. If some attribute is missing, raise AttributeMissingError + ''' + for attr in self.required_attributes: + if getattr(self, attr) is None: + raise AttributeMissingError( + f'Cannot render SQL. Missing required attribute "{attr}".' + ) + + def __setattr__(self, name: str, value: Any): + """ + Required for type testing with MyPy. + """ + super().__setattr__(name, value) + + def __eq__(self, other: object) -> bool: + """ + Two instances of the same SQLObject subclass are equal if all their + attributes are equal. + """ + + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + +class TableReference(SQLOjbect): + ''' + Class, representing a foreign key constraint. + This object should be assigned to the `refs` attribute of a Table object. + Its `sql` property contains the inline definition of the FOREIGN KEY clause. + ''' + required_attributes = ('col', 'ref_table', 'ref_col') + + def __init__(self, + col: Union[Column, List[Column]], + ref_table: Table, + ref_col: Union[Column, List[Column]], + name: Optional[str] = None, + on_delete: Optional[str] = None, + on_update: Optional[str] = None): + self.col = [col] if isinstance(col, Column) else list(col) + self.ref_table = ref_table + self.ref_col = [ref_col] if isinstance(ref_col, Column) else list(ref_col) + self.name = name + self.on_update = on_update + self.on_delete = on_delete + + def __repr__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> t2 = Table('t2') + >>> TableReference(col=c1, ref_table=t2, ref_col=c2) + + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22)) + + ''' + + col_names = [c.name for c in self.col] + ref_col_names = [c.name for c in self.ref_col] + return f"" + + def __str__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> t2 = Table('t2') + >>> print(TableReference(col=c1, ref_table=t2, ref_col=c2)) + TableReference([c1] > t2[c2]) + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> print(TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22))) + TableReference([c1, c12] > t2[c2, c22]) + ''' + + components = [f"TableReference("] + components.append(f'[{", ".join(c.name for c in self.col)}]') + components.append(' > ') + components.append(self.ref_table.name) + components.append(f'[{", ".join(c.name for c in self.ref_col)}]') + return ''.join(components) + ')' + + @property + def sql(self): + ''' + Returns inline SQL of the reference, which should be a part of table definition: + + FOREIGN KEY ("order_id") REFERENCES "orders ("id") + + ''' + self.check_attributes_for_sql() + c = f'CONSTRAINT "{self.name}" ' if self.name else '' + cols = '", "'.join(c.name for c in self.col) + ref_cols = '", "'.join(c.name for c in self.ref_col) + result = ( + f'{c}FOREIGN KEY ("{cols}") ' + f'REFERENCES "{self.ref_table.name}" ("{ref_cols}")' + ) + if self.on_update: + result += f' ON UPDATE {self.on_update.upper()}' + if self.on_delete: + result += f' ON DELETE {self.on_delete.upper()}' + return result + + + + + + + + +class EnumType(Enum): + ''' + Enum object, intended to be put in the `type` attribute of a column. + + >>> en = EnumItem('en-US') + >>> ru = EnumItem('ru-RU') + >>> EnumType('languages', [en, ru]) + + >>> print(_) + languages + ''' + + pass + + + + diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py new file mode 100644 index 0000000..419adca --- /dev/null +++ b/pydbml/classes/base.py @@ -0,0 +1,37 @@ +from typing import Tuple +from typing import Any +from pydbml.exceptions import AttributeMissingError + + +class SQLOjbect: + ''' + Base class for all SQL objects. + ''' + required_attributes: Tuple[str, ...] = () + + def check_attributes_for_sql(self): + ''' + Check if all attributes, required for rendering SQL are set in the + instance. If some attribute is missing, raise AttributeMissingError + ''' + for attr in self.required_attributes: + if getattr(self, attr) is None: + raise AttributeMissingError( + f'Cannot render SQL. Missing required attribute "{attr}".' + ) + + def __setattr__(self, name: str, value: Any): + """ + Required for type testing with MyPy. + """ + super().__setattr__(name, value) + + def __eq__(self, other: object) -> bool: + """ + Two instances of the same SQLObject subclass are equal if all their + attributes are equal. + """ + + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py new file mode 100644 index 0000000..b6bd1d9 --- /dev/null +++ b/pydbml/classes/column.py @@ -0,0 +1,125 @@ +from typing import Optional +from typing import Union +from typing import List +from typing import TYPE_CHECKING + +from .base import SQLOjbect +from .note import Note +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.tools import note_option_to_dbml +from pydbml.exceptions import TableNotFoundError + +if TYPE_CHECKING: + from .table import Table + from .reference import Reference + + +class Column(SQLOjbect): + '''Class representing table column.''' + + required_attributes = ('name', 'type') + + def __init__(self, + name: str, + type_: str, + unique: bool = False, + not_null: bool = False, + pk: bool = False, + autoinc: bool = False, + default: Optional[Union[str, int, bool, float]] = None, + note: Optional[Union['Note', str]] = None, + # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, + comment: Optional[str] = None): + self.schema = None + self.name = name + self.type = type_ + self.unique = unique + self.not_null = not_null + self.pk = pk + self.autoinc = autoinc + self.comment = comment + self.note = Note(note) + + self.default = default + self.table: Optional['Table'] = None + + def get_refs(self) -> List['Reference']: + if not self.table: + raise TableNotFoundError('Table for the column is not set') + return [ref for ref in self.table.get_refs() if ref.col1 == self] + + @property + def sql(self): + ''' + Returns inline SQL of the column, which should be a part of table definition: + + "id" integer PRIMARY KEY AUTOINCREMENT + ''' + + self.check_attributes_for_sql() + components = [f'"{self.name}"', str(self.type)] + if self.pk: + components.append('PRIMARY KEY') + if self.autoinc: + components.append('AUTOINCREMENT') + if self.unique: + components.append('UNIQUE') + if self.not_null: + components.append('NOT NULL') + if self.default is not None: + components.append('DEFAULT ' + str(self.default)) + + result = comment_to_sql(self.comment) if self.comment else '' + result += ' '.join(components) + return result + + @property + def dbml(self): + def default_to_str(val: str) -> str: + if isinstance(val, str): + if val.lower() in ('null', 'true', 'false'): + return val.lower() + elif val.startswith('(') and val.endswith(')'): + return f'`{val[1:-1]}`' + else: + return f"'{val}'" + else: # int or float or bool + return val + + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'"{self.name}" {self.type}' + + options = [ref.dbml() for ref in self.get_refs() if ref.inline] + if self.pk: + options.append('pk') + if self.autoinc: + options.append('increment') + if self.default: + options.append(f'default: {default_to_str(self.default)}') + if self.unique: + options.append('unique') + if self.not_null: + options.append('not null') + if self.note: + options.append(note_option_to_dbml(self.note)) + + if options: + result += f' [{", ".join(options)}]' + return result + + def __repr__(self): + ''' + >>> Column('name', 'VARCHAR2') + + ''' + type_name = self.type if isinstance(self.type, str) else self.type.name + return f'' + + def __str__(self): + ''' + >>> print(Column('name', 'VARCHAR2')) + name[VARCHAR2] + ''' + + return f'{self.name}[{self.type}]' diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py new file mode 100644 index 0000000..7114819 --- /dev/null +++ b/pydbml/classes/enum.py @@ -0,0 +1,122 @@ +from typing import List +from typing import Optional +from typing import Union + +from .base import SQLOjbect +from .note import Note +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.tools import indent +from pydbml.tools import note_option_to_dbml + + +class EnumItem: + '''Single enum item''' + + def __init__(self, + name: str, + note: Optional[Union['Note', str]] = None, + comment: Optional[str] = None): + self.name = name + self.note = Note(note) + self.comment = comment + + def __repr__(self): + ''' + >>> EnumItem('en-US') + + ''' + + return f'' + + def __str__(self): + ''' + >>> print(EnumItem('en-US')) + en-US + ''' + + return self.name + + @property + def sql(self): + result = comment_to_sql(self.comment) if self.comment else '' + result += f"'{self.name}'," + return result + + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'"{self.name}"' + if self.note: + result += f' [{note_option_to_dbml(self.note)}]' + return result + + +class Enum(SQLOjbect): + required_attributes = ('name', 'items') + + def __init__(self, + name: str, + items: List['EnumItem'], + comment: Optional[str] = None): + self.schema = None + self.name = name + self.items = items + self.comment = comment + + def __getitem__(self, key) -> EnumItem: + return self.items[key] + + def __iter__(self): + return iter(self.items) + + def __repr__(self): + ''' + >>> en = EnumItem('en-US') + >>> ru = EnumItem('ru-RU') + >>> Enum('languages', [en, ru]) + + ''' + + item_names = [i.name for i in self.items] + classname = self.__class__.__name__ + return f'<{classname} {self.name!r}, {item_names!r}>' + + def __str__(self): + ''' + >>> en = EnumItem('en-US') + >>> ru = EnumItem('ru-RU') + >>> print(Enum('languages', [en, ru])) + languages + ''' + + return self.name + + @property + def sql(self): + ''' + Returns SQL for enum type: + + CREATE TYPE "job_status" AS ENUM ( + 'created', + 'running', + 'donef', + 'failure', + ); + + ''' + self.check_attributes_for_sql() + result = comment_to_sql(self.comment) if self.comment else '' + result += f'CREATE TYPE "{self.name}" AS ENUM (\n' + result += '\n'.join(f'{indent(i.sql, 2)}' for i in self.items) + result += '\n);' + return result + + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'Enum "{self.name}" {{\n' + items_str = '\n'.join(i.dbml for i in self.items) + result += indent(items_str) + result += '\n}' + return result diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py new file mode 100644 index 0000000..a09d63e --- /dev/null +++ b/pydbml/classes/index.py @@ -0,0 +1,134 @@ +from typing import Optional +from typing import Union +from typing import List + +from .base import SQLOjbect +from .note import Note +from .column import Column +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.tools import note_option_to_dbml + + +class Index(SQLOjbect): + '''Class representing index.''' + required_attributes = ('subjects', 'table') + + def __init__(self, + subjects: List[Union[str, 'Column']], + name: Optional[str] = None, + unique: bool = False, + type_: Optional[str] = None, + pk: bool = False, + note: Optional[Union['Note', str]] = None, + comment: Optional[str] = None): + self.schema = None + self.subjects = subjects + self.table = None + + self.name = name if name else None + self.unique = unique + self.type = type_ + self.pk = pk + self.note = Note(note) + self.comment = comment + + @property + def subject_names(self): + ''' + For backward compatibility. Returns updated list of subject names. + ''' + return [s.name if isinstance(s, Column) else s for s in self.subjects] + + def __repr__(self): + ''' + >>> Index(['name', 'type']) + + >>> t = Table('t') + >>> Index(['name', 'type'], table=t) + + ''' + + table_name = self.table.name if self.table else None + return f"" + + def __str__(self): + ''' + >>> print(Index(['name', 'type'])) + Index([name, type]) + >>> t = Table('t') + >>> print(Index(['name', 'type'], table=t)) + Index(t[name, type]) + ''' + + table_name = self.table.name if self.table else '' + subjects = ', '.join(self.subject_names) + return f"Index({table_name}[{subjects}])" + + @property + def sql(self): + ''' + Returns inline SQL of the index to be created separately from table + definition: + + CREATE UNIQUE INDEX ON "products" USING HASH ("id"); + + But if it's a (composite) primary key index, returns an inline SQL for + composite primary key to be used inside table definition: + + PRIMARY KEY ("id", "name") + + ''' + self.check_attributes_for_sql() + keys = ', '.join(f'"{key.name}"' if isinstance(key, Column) else key for key in self.subjects) + if self.pk: + result = comment_to_sql(self.comment) if self.comment else '' + result += f'PRIMARY KEY ({keys})' + return result + + components = ['CREATE'] + if self.unique: + components.append('UNIQUE') + components.append('INDEX') + if self.name: + components.append(f'"{self.name}"') + components.append(f'ON "{self.table.name}"') + if self.type: + components.append(f'USING {self.type.upper()}') + components.append(f'({keys})') + result = comment_to_sql(self.comment) if self.comment else '' + result += ' '.join(components) + ';' + return result + + @property + def dbml(self): + def subject_to_str(val: str) -> str: + if val.startswith('(') and val.endswith(')'): + return f'`{val[1:-1]}`' + else: + return val + + result = comment_to_dbml(self.comment) if self.comment else '' + + subject_names = self.subject_names + + if len(subject_names) > 1: + result += f'({", ".join(subject_to_str(sn) for sn in subject_names)})' + else: + result += subject_to_str(subject_names[0]) + + options = [] + if self.name: + options.append(f"name: '{self.name}'") + if self.pk: + options.append('pk') + if self.unique: + options.append('unique') + if self.type: + options.append(f'type: {self.type}') + if self.note: + options.append(note_option_to_dbml(self.note)) + + if options: + result += f' [{", ".join(options)}]' + return result diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py new file mode 100644 index 0000000..32bbae3 --- /dev/null +++ b/pydbml/classes/note.py @@ -0,0 +1,62 @@ +from typing import Any + + +class Note: + def __init__(self, text: Any): + self.text = str(text) if text else '' + + def __str__(self): + ''' + >>> print(Note('Note text')) + Note text + ''' + + return self.text + + def __bool__(self): + return bool(self.text) + + def __repr__(self): + ''' + >>> Note('Note text') + Note('Note text') + ''' + + return f'Note({repr(self.text)})' + + @property + def sql(self): + if self.text: + return '\n'.join(f'-- {line}' for line in self.text.split('\n')) + else: + return '' + + @property + def dbml(self): + lines = [] + line = '' + for word in self.text.split(' '): + if len(line) > 80: + lines.append(line) + line = '' + if '\n' in word: + sublines = word.split('\n') + for sl in sublines[:-1]: + line += sl + lines.append(line) + line = '' + line = sublines[-1] + ' ' + else: + line += f'{word} ' + if line: + lines.append(line) + result = 'Note {\n ' + + if len(lines) > 1: + lines_str = '\n '.join(lines)[:-1] + '\n' + result += f"'''\n {lines_str} '''" + else: + result += f"'{lines[0][:-1]}'" + + result += '\n}' + return result diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py new file mode 100644 index 0000000..5f70020 --- /dev/null +++ b/pydbml/classes/project.py @@ -0,0 +1,45 @@ +from typing import Dict +from typing import Optional +from typing import Union + +from .note import Note +from pydbml.tools import comment_to_dbml +from pydbml.tools import indent + + +class Project: + def __init__(self, + name: str, + items: Optional[Dict[str, str]] = None, + note: Optional[Union['Note', str]] = None, + comment: Optional[str] = None): + self.schema = None + self.name = name + self.items = items + self.note = Note(note) + self.comment = comment + + def __repr__(self): + """ + >>> Project('myproject') + + """ + + return f'' + + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'Project {self.name} {{\n' + if self.items: + items_str = '' + for k, v in self.items.items(): + if '\n' in v: + items_str += f"{k}: '''{v}'''\n" + else: + items_str += f"{k}: '{v}'\n" + result += indent(items_str[:-1]) + '\n' + if self.note: + result += indent(self.note.dbml) + '\n' + result += '}' + return result diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py new file mode 100644 index 0000000..3f08764 --- /dev/null +++ b/pydbml/classes/reference.py @@ -0,0 +1,188 @@ +from typing import Collection +from typing import Literal +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from .base import SQLOjbect +from .column import Column +from pydbml import MANY_TO_ONE +from pydbml import ONE_TO_MANY +from pydbml import ONE_TO_ONE +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.exceptions import DBMLError + + +if TYPE_CHECKING: + from .table import Table + + +class Reference(SQLOjbect): + ''' + Class, representing a foreign key constraint. + It is a separate object, which is not connected to Table or Column objects + and its `sql` property contains the ALTER TABLE clause. + ''' + required_attributes = ('type', 'table1', 'col1', 'table2', 'col2') + + def __init__(self, + type_: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE], + table1: 'Table', + col1: Union[Column, Collection[Column]], + table2: 'Table', + col2: Union[Column, Collection[Column]], + name: Optional[str] = None, + comment: Optional[str] = None, + on_update: Optional[str] = None, + on_delete: Optional[str] = None, + inline: bool = False): + self.schema = None + self.type = type_ + self.table1 = table1 + self.col1 = [col1] if isinstance(col1, Column) else list(col1) + self.table2 = table2 + self.col2 = [col2] if isinstance(col2, Column) else list(col2) + self.name = name if name else None + self.comment = comment + self.on_update = on_update + self.on_delete = on_delete + self.inline = inline + + def __repr__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> t1 = Table('t1') + >>> t2 = Table('t2') + >>> Reference('>', table1=t1, col1=c1, table2=t2, col2=c2) + ', 't1'.['c1'], 't2'.['c2']> + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22)) + + ''' + + components = [f"' + + def __str__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> t1 = Table('t1') + >>> t2 = Table('t2') + >>> print(Reference('>', table1=t1, col1=c1, table2=t2, col2=c2)) + Reference(t1[c1] > t2[c2]) + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> print(Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22))) + Reference(t1[c1, c12] < t2[c2, c22]) + ''' + + components = [f"Reference("] + components.append(self.table1.name) + components.append(f'[{", ".join(c.name for c in self.col1)}]') + components.append(f' {self.type} ') + components.append(self.table2.name) + components.append(f'[{", ".join(c.name for c in self.col2)}]') + return ''.join(components) + ')' + + @property + def sql(self): + ''' + Returns SQL of the reference: + + ALTER TABLE "orders" ADD FOREIGN KEY ("customer_id") REFERENCES "customers ("id"); + + ''' + self.check_attributes_for_sql() + c = f'CONSTRAINT "{self.name}" ' if self.name else '' + + if self.inline: + if self.type in (MANY_TO_ONE, ONE_TO_ONE): + source_col = self.col1 + ref_table = self.table2 + ref_col = self.col2 + else: + source_col = self.col2 + ref_table = self.table1 + ref_col = self.col1 + + cols = '", "'.join(c.name for c in source_col) + ref_cols = '", "'.join(c.name for c in ref_col) + result = comment_to_sql(self.comment) if self.comment else '' + result += ( + f'{c}FOREIGN KEY ("{cols}") ' + f'REFERENCES "{ref_table.name}" ("{ref_cols}")' + ) + if self.on_update: + result += f' ON UPDATE {self.on_update.upper()}' + if self.on_delete: + result += f' ON DELETE {self.on_delete.upper()}' + return result + else: + if self.type in (MANY_TO_ONE, ONE_TO_ONE): + t1 = self.table1 + c1 = ', '.join(f'"{c.name}"' for c in self.col1) + t2 = self.table2 + c2 = ', '.join(f'"{c.name}"' for c in self.col2) + else: + t1 = self.table2 + c1 = ', '.join(f'"{c.name}"' for c in self.col2) + t2 = self.table1 + c2 = ', '.join(f'"{c.name}"' for c in self.col1) + + result = comment_to_sql(self.comment) if self.comment else '' + result += ( + f'ALTER TABLE "{t1.name}" ADD {c}FOREIGN KEY ({c1}) ' + f'REFERENCES "{t2.name}" ({c2})' + ) + if self.on_update: + result += f' ON UPDATE {self.on_update.upper()}' + if self.on_delete: + result += f' ON DELETE {self.on_delete.upper()}' + return result + ';' + + @property + def dbml(self): + if self.inline: + if len(self.col2) > 1: + raise DBMLError('Cannot render DBML: composite ref cannot be inline') + return f'ref: {self.type} {self.table2.name}.{self.col2[0].name}' + else: + result = comment_to_dbml(self.comment) if self.comment else '' + result += 'Ref' + if self.name: + result += f' {self.name}' + + if len(self.col1) == 1: + col1 = f'"{self.col1[0].name}"' + else: + names = (f'"{c.name}"' for c in self.col1) + col1 = f'({", ".join(names)})' + + if len(self.col2) == 1: + col2 = f'"{self.col2[0].name}"' + else: + names = (f'"{c.name}"' for c in self.col2) + col2 = f'({", ".join(names)})' + + options = [] + if self.on_update: + options.append(f'update: {self.on_update}') + if self.on_delete: + options.append(f'delete: {self.on_delete}') + + options_str = f' [{", ".join(options)}]' if options else '' + result += ( + ' {\n ' + f'"{self.table1.name}".{col1} ' + f'{self.type} ' + f'"{self.table2.name}".{col2}' + f'{options_str}' + '\n}' + ) + return result diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py new file mode 100644 index 0000000..c4c72ec --- /dev/null +++ b/pydbml/classes/table.py @@ -0,0 +1,241 @@ +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from .base import SQLOjbect +from .column import Column +from .index import Index +from .note import Note +from .reference import Reference +from pydbml import MANY_TO_ONE +from pydbml import ONE_TO_MANY +from pydbml import ONE_TO_ONE +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import IndexNotFoundError +from pydbml.exceptions import UnknownSchemaError +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.tools import indent + + +if TYPE_CHECKING: + from pydbml.schema import Schema + + +class Table(SQLOjbect): + '''Class representing table.''' + + required_attributes = ('name',) + + def __init__(self, + name: str, + alias: Optional[str] = None, + note: Optional[Union['Note', str]] = None, + header_color: Optional[str] = None, + # refs: Optional[List[TableReference]] = None, + comment: Optional[str] = None): + self.schema = None + self.name = name + self.columns: List[Column] = [] + self.indexes: List[Index] = [] + self.alias = alias if alias else None + self.note = Note(note) + self.header_color = header_color + # self.refs = refs or [] + self.comment = comment + + # def _build_index_from_blueprint(self, blueprint: IndexBlueprint) -> None: + # subjects = [] + # for subj in blueprint.subject_names: + # if subj.startswith('(') and subj.endswith(')'): + # # subject is an expression, add it as string + # subjects.append(subj) + # else: + # try: + # col = self[subj] + # subjects.append(col) + # except KeyError: + # raise ColumnNotFoundError(f'Cannot add index, column "{subj}" not defined in table "{self.name}".') + # index = Index( + # subjects, + # name=blueprint.name, + # unique=blueprint.unique, + # type_=blueprint.type, + # pk=blueprint.pk, + # note=blueprint.note, + # comment=blueprint.comment + # ) + + # index.table = self + # self.indexes.append(index) + + def add_column(self, c: Column) -> None: + ''' + Adds column to self.columns attribute and sets in this column the + `table` attribute. + ''' + c.table = self + self.columns.append(c) + + def delete_column(self, c: Union[Column, int]) -> Column: + if isinstance(c, Column): + if c in self.columns: + c.table = None + return self.columns.pop(self.columns.index(c)) + else: + raise ColumnNotFoundError(f'Column {c} if missing in the table') + elif isinstance(c, int): + self.columns[c].table = None + return self.columns.pop(c) + + def add_index(self, i: Index) -> None: + ''' + Adds index to self.indexes attribute and sets in this index the + `table` attribute. + ''' + # for subj in i.subjects: + # if isinstance(subj, Column) and (subj not in self.columns): + # raise ColumnNotFoundError(f'Cannot add index, column "{subj}" not defined in table "{self.name}".') + + i.table = self + self.indexes.append(i) + + def delete_index(self, i: Union[Index, int]) -> Index: + if isinstance(i, Index): + if i in self.indexes: + i.table = None + return self.indexes.pop(self.indexes.index(i)) + else: + raise IndexNotFoundError(f'Index {i} if missing in the table') + elif isinstance(i, int): + self.indexes[i].table = None + return self.indexes.pop(i) + + def get_refs(self) -> List[Reference]: + if not self.schema: + raise UnknownSchemaError('Schema for the table is not set') + return [ref for ref in self.schema.refs if ref.table1 == self] + + def _get_references_for_sql(self) -> List[Reference]: + ''' + return inline references for this table sql definition + ''' + if not self.schema: + raise UnknownSchemaError('Schema for the table is not set') + result = [] + for ref in self.schema.refs: + if ref.inline: + if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ + (ref.table1 == self): + result.append(ref) + elif (ref.type == ONE_TO_MANY) and (ref.table2 == self): + result.append(ref) + return result + + + # def add_ref(self, r: TableReference) -> None: + # ''' + # Adds a reference to the table. If reference already present in the table, + # raises DuplicateReferenceError. + # ''' + # if r in self.refs: + # raise DuplicateReferenceError(f'Reference with same endpoints {r} already present in the table.') + # self.refs.append(r) + def __getitem__(self, k: Union[int, str]) -> Column: + if isinstance(k, int): + return self.columns[k] + else: + for c in self.columns: + if c.name == k: + return c + raise ColumnNotFoundError(f'Column {k} not present in table {self.name}') + + def get(self, k, default=None): + try: + return self.__getitem__(k) + except KeyError: + return default + + def __iter__(self): + return iter(self.columns) + + def __repr__(self): + ''' + >>> table = Table('customers') + >>> table +
+ ''' + + return f'
' + + def __str__(self): + ''' + >>> table = Table('customers') + >>> table.add_column(Column('id', 'INTEGER')) + >>> table.add_column(Column('name', 'VARCHAR2')) + >>> print(table) + customers(id, name) + ''' + + return f'{self.name}({", ".join(c.name for c in self.columns)})' + + @property + def sql(self): + ''' + Returns full SQL for table definition: + + CREATE TABLE "countries" ( + "code" int PRIMARY KEY, + "name" varchar, + "continent_name" varchar + ); + + Also returns indexes if they were defined: + + CREATE INDEX ON "products" ("id", "name"); + ''' + self.check_attributes_for_sql() + components = [f'CREATE TABLE "{self.name}" ('] + body = [] + body.extend(indent(c.sql, 2) for c in self.columns) + body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) + body.extend(indent(r.sql, 2) for r in self._get_references_for_sql()) + components.append(',\n'.join(body)) + components.append(');') + components.extend('\n' + i.sql for i in self.indexes if not i.pk) + + result = comment_to_sql(self.comment) if self.comment else '' + result += '\n'.join(components) + + if self.note: + quoted_note = f"'{self.note.text}'" + note_sql = f'COMMENT ON TABLE "{self.name}" IS {quoted_note};' + result += f'\n\n{note_sql}' + + for col in self.columns: + if col.note: + quoted_note = f"'{col.note.text}'" + note_sql = f'COMMENT ON COLUMN "{self.name}"."{col.name}" IS {quoted_note};' + result += f'\n\n{note_sql}' + return result + + @property + def dbml(self): + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'Table "{self.name}" ' + if self.alias: + result += f'as "{self.alias}" ' + result += '{\n' + columns_str = '\n'.join(c.dbml for c in self.columns) + result += indent(columns_str) + '\n' + if self.note: + result += indent(self.note.dbml) + '\n' + if self.indexes: + result += '\n indexes {\n' + indexes_str = '\n'.join(i.dbml for i in self.indexes) + result += indent(indexes_str, 8) + '\n' + result += ' }\n' + + result += '}' + return result diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py new file mode 100644 index 0000000..ba17b74 --- /dev/null +++ b/pydbml/classes/table_group.py @@ -0,0 +1,59 @@ +from typing import List +from typing import Optional +from typing import Union + +from .table import Table +from pydbml.tools import comment_to_dbml + + +class TableGroup: + ''' + TableGroup `items` parameter initially holds just the names of the tables, + but after parsing the whole document, PyDBMLParseResults class replaces + them with references to actual tables. + ''' + + def __init__(self, + name: str, + items: Union[List[str], List[Table]], + comment: Optional[str] = None): + self.schema = None + self.name = name + self.items = items + self.comment = comment + + def __repr__(self): + """ + >>> tg = TableGroup('mygroup', ['t1', 't2']) + >>> tg + + >>> t1 = Table('t1') + >>> t2 = Table('t2') + >>> tg.items = [t1, t2] + >>> tg + + """ + + items = [i if isinstance(i, str) else i.name for i in self.items] + return f'' + + def __getitem__(self, key) -> str: + return self.items[key] + + def __iter__(self): + return iter(self.items) + + @property + def dbml(self): + def item_to_str(val: Union[str, Table]) -> str: + if isinstance(val, Table): + return val.name + else: + return val + + result = comment_to_dbml(self.comment) if self.comment else '' + result += f'TableGroup {self.name} {{\n' + for i in self.items: + result += f' {item_to_str(i)}\n' + result += '}' + return result diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index 8dec3be..40ecfb4 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -1,6 +1,6 @@ import pyparsing as pp -from pydbml.classes import Column +from pydbml.parser.blueprints import ColumnBlueprint from .common import _ from .common import _c @@ -25,7 +25,7 @@ column_type = (type_name + type_args[0, 1]) -def parse_column_type(s, l, t): +def parse_column_type(s, l, t) -> str: ''' int or "mytype" or varchar(255) ''' @@ -115,7 +115,7 @@ def parse_column(s, l, t): ''' init_dict = { 'name': t['name'], - 'type_': t['type'], + 'type': t['type'], } # deprecated for constraint in t.get('constraints', []): @@ -134,7 +134,7 @@ def parse_column(s, l, t): comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - return Column(**init_dict) + return ColumnBlueprint(**init_dict) table_column.setParseAction(parse_column) diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index 23d874f..620f6db 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -1,6 +1,6 @@ import pyparsing as pp -from pydbml.classes import Note +from pydbml.parser.blueprints import NoteBlueprint from .generic import string_literal @@ -23,10 +23,10 @@ # n = pp.Suppress('\n')[1, ...] note = pp.CaselessLiteral("note:") + _ - string_literal('text') -note.setParseAction(lambda s, l, t: Note(t['text'])) +note.setParseAction(lambda s, l, t: NoteBlueprint(t['text'])) note_object = pp.CaselessLiteral('note') + _ - '{' + _ - string_literal('text') + _ - '}' -note_object.setParseAction(lambda s, l, t: Note(t['text'])) +note_object.setParseAction(lambda s, l, t: NoteBlueprint(t['text'])) pk = pp.CaselessLiteral("pk") unique = pp.CaselessLiteral("unique") diff --git a/pydbml/definitions/enum.py b/pydbml/definitions/enum.py index d8bc323..3ca9751 100644 --- a/pydbml/definitions/enum.py +++ b/pydbml/definitions/enum.py @@ -1,7 +1,7 @@ import pyparsing as pp -from pydbml.classes import Enum -from pydbml.classes import EnumItem +from pydbml.parser.blueprints import EnumBlueprint +from pydbml.parser.blueprints import EnumItemBlueprint from .common import _ from .common import _c @@ -48,7 +48,7 @@ def parse_enum_item(s, l, t): comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - return EnumItem(**init_dict) + return EnumItemBlueprint(**init_dict) enum_item.setParseAction(parse_enum_item) @@ -82,7 +82,7 @@ def parse_enum(s, l, t): comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - return Enum(**init_dict) + return EnumBlueprint(**init_dict) enum.setParseAction(parse_enum) diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index 69f704d..8ab13ce 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -1,6 +1,6 @@ import pyparsing as pp -from pydbml.classes import Index +from pydbml.parser.blueprints import IndexBlueprint from .common import _ from .common import _c @@ -41,7 +41,7 @@ def parse_index_settings(s, l, t): if 'pk' in t: result['pk'] = True if 'type' in t: - result['type_'] = t['type'] + result['type'] = t['type'] if 'note' in t: result['note'] = t['note'] if 'comment' in t: @@ -99,7 +99,7 @@ def parse_index(s, l, t): if 'comment' not in init_dict and 'comment_before' in t: comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - return Index(**init_dict) + return IndexBlueprint(**init_dict) index.setParseAction(parse_index) diff --git a/pydbml/definitions/project.py b/pydbml/definitions/project.py index 2f05b56..11e8991 100644 --- a/pydbml/definitions/project.py +++ b/pydbml/definitions/project.py @@ -1,7 +1,7 @@ import pyparsing as pp -from pydbml.classes import Note -from pydbml.classes import Project +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ProjectBlueprint from .common import _ from .common import _c @@ -37,7 +37,7 @@ def parse_project(s, l, t): init_dict = {'name': t['name']} items = {} for item in t.get('items', []): - if isinstance(item, Note): + if isinstance(item, NoteBlueprint): init_dict['note'] = item else: k, v = item @@ -47,7 +47,7 @@ def parse_project(s, l, t): if 'comment_before' in t: comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - return Project(**init_dict) + return ProjectBlueprint(**init_dict) project.setParseAction(parse_project) diff --git a/pydbml/definitions/reference.py b/pydbml/definitions/reference.py index 450c528..9a511c1 100644 --- a/pydbml/definitions/reference.py +++ b/pydbml/definitions/reference.py @@ -1,6 +1,6 @@ import pyparsing as pp -from pydbml.classes import ReferenceBlueprint +from pydbml.parser.blueprints import ReferenceBlueprint from .common import _ from .common import _c @@ -18,7 +18,8 @@ def parse_inline_relation(s, l, t): ''' ref: < table.column ''' - return ReferenceBlueprint(type_=t['type'], + return ReferenceBlueprint(type=t['type'], + inline=True, table2=t['table'], col2=t['field']) @@ -107,7 +108,8 @@ def parse_ref(s, l, t): } ''' init_dict = { - 'type_': t['type'], + 'type': t['type'], + 'inline': False, 'table1': t['table1'], 'col1': t['field1'], 'table2': t['table2'], diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index adcea9e..59622f8 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -1,6 +1,6 @@ import pyparsing as pp -from pydbml.classes import Table +from pydbml.parser.blueprints import TableBlueprint from .column import table_column from .common import _ @@ -79,14 +79,14 @@ def parse_table(s, l, t): if 'note' in t: # will override one from settings init_dict['note'] = t['note'][0] + if 'indexes' in t: + init_dict['indexes'] = t['indexes'] + if 'columns' in t: + init_dict['columns'] = t['columns'] if'comment_before' in t: comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - result = Table(**init_dict) - for column in t['columns']: - result.add_column(column) - for index_ in t.get('indexes', []): - result.add_index(index_) + result = TableBlueprint(**init_dict) return result diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index 63f6710..3929ab5 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -1,6 +1,6 @@ import pyparsing as pp -from pydbml.classes import TableGroup +from pydbml.parser.blueprints import TableGroupBlueprint from .common import _ from .common import _c @@ -33,7 +33,7 @@ def parse_table_group(s, l, t): if 'comment_before' in t: comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment - return TableGroup(**init_dict) + return TableGroupBlueprint(**init_dict) table_group.setParseAction(parse_table_group) diff --git a/pydbml/exceptions.py b/pydbml/exceptions.py index b1acf7e..413b309 100644 --- a/pydbml/exceptions.py +++ b/pydbml/exceptions.py @@ -6,9 +6,20 @@ class ColumnNotFoundError(Exception): pass +class IndexNotFoundError(Exception): + pass + + class AttributeMissingError(Exception): pass class DuplicateReferenceError(Exception): pass + + +class UnknownSchemaError(Exception): + pass + +class DBMLError(Exception): + pass diff --git a/pydbml/parser.py b/pydbml/parser.py deleted file mode 100644 index 412edc1..0000000 --- a/pydbml/parser.py +++ /dev/null @@ -1,281 +0,0 @@ -from __future__ import annotations - -import pyparsing as pp - -from io import TextIOWrapper -from pathlib import Path - -from typing import Dict -from typing import List -from typing import Optional -from typing import Union - -from .classes import Enum -from .classes import Project -from .classes import Reference -from .classes import ReferenceBlueprint -from .classes import Table -from .classes import TableGroup -from .classes import TableReference -from .definitions.common import _ -from .definitions.common import comment -from .definitions.enum import enum -from .definitions.project import project -from .definitions.reference import ref -from .definitions.table import table -from .definitions.table_group import table_group -from .exceptions import ColumnNotFoundError -from .exceptions import TableNotFoundError - -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') - - -class PyDBML: - ''' - PyDBML parser factory. If properly initiated, returns PyDBMLParseResults - which contains parse results in attributes. - - Usage option 1: - - >>> with open('schema.dbml') as f: - ... p = PyDBML(f) - ... # or - ... p = PyDBML(f.read()) - - Usage option 2: - >>> p = PyDBML.parse_file('schema.dbml') - >>> # or - >>> from pathlib import Path - >>> p = PyDBML(Path('schema.dbml')) - ''' - - def __new__(cls, - source_: Optional[Union[str, Path, TextIOWrapper]] = None): - if source_ is not None: - if isinstance(source_, str): - source = source_ - elif isinstance(source_, Path): - with open(source_, encoding='utf8') as f: - source = f.read() - else: # TextIOWrapper - source = source_.read() - if source[0] == '\ufeff': # removing BOM - source = source[1:] - return cls.parse(source) - else: - return super().__new__(cls) - - def __repr__(self): - return "" - - @staticmethod - def parse(text: str) -> PyDBMLParseResults: - if text[0] == '\ufeff': # removing BOM - text = text[1:] - return PyDBMLParseResults(text) - - @staticmethod - def parse_file(file: Union[str, Path, TextIOWrapper]): - if isinstance(file, TextIOWrapper): - source = file.read() - else: - with open(file, encoding='utf8') as f: - source = f.read() - if source[0] == '\ufeff': # removing BOM - source = source[1:] - return PyDBMLParseResults(source) - - -class PyDBMLParseResults: - def __init__(self, source: str): - self.tables: List[Table] = [] - self.table_dict: Dict[str, Table] = {} - self.refs: List[Reference] = [] - self.ref_blueprints: List[ReferenceBlueprint] = [] - self.enums: List[Enum] = [] - self.table_groups: List[TableGroup] = [] - self.project: Optional[Project] = None - self.source = source - - self._set_syntax() - self._syntax.parseString(self.source, parseAll=True) - self._validate() - self._process_refs() - self._process_table_groups() - self._set_enum_types() - - def __repr__(self): - return "" - - def _set_syntax(self): - table_expr = table.copy() - ref_expr = ref.copy() - enum_expr = enum.copy() - table_group_expr = table_group.copy() - project_expr = project.copy() - - table_expr.addParseAction(self._parse_table) - ref_expr.addParseAction(self._parse_ref_blueprint) - enum_expr.addParseAction(self._parse_enum) - table_group_expr.addParseAction(self._parse_table_group) - project_expr.addParseAction(self._parse_project) - - expr = ( - table_expr - | ref_expr - | enum_expr - | table_group_expr - | project_expr - ) - self._syntax = expr[...] + ('\n' | comment)[...] + pp.StringEnd() - - def __getitem__(self, k: Union[int, str]) -> Table: - if isinstance(k, int): - return self.tables[k] - else: - return self.table_dict[k] - - def __iter__(self): - return iter(self.tables) - - def _parse_table(self, s, l, t): - table = t[0] - self.tables.append(table) - for col in table.columns: - self.ref_blueprints.extend(col.ref_blueprints) - self.table_dict[table.name] = table - - def _parse_ref_blueprint(self, s, l, t): - self.ref_blueprints.append(t[0]) - - def _parse_enum(self, s, l, t): - self.enums.append(t[0]) - - def _parse_table_group(self, s, l, t): - self.table_groups.append(t[0]) - - def _parse_project(self, s, l, t): - if not self.project: - self.project = t[0] - else: - raise SyntaxError('Project redifinition not allowed') - - def _process_refs(self): - ''' - Fill up the `refs` attribute with Reference object, created from - reference blueprints; - Add TableReference objects to each table which has references. - Validate refs at the same time. - ''' - for ref_ in self.ref_blueprints: - for tb in self.tables: - if tb.name == ref_.table1 or tb.alias == ref_.table1: - table1 = tb - break - else: - raise TableNotFoundError('Error while parsing reference:' - f'table "{ref_.table1}"" is not defined.') - for tb in self.tables: - if tb.name == ref_.table2 or tb.alias == ref_.table2: - table2 = tb - break - else: - raise TableNotFoundError('Error while parsing reference:' - f'table "{ref_.table2}"" is not defined.') - col1_names = [c.strip('() ') for c in ref_.col1.split(',')] - col1 = [] - for col_name in col1_names: - try: - col1.append(table1[col_name]) - except KeyError: - raise ColumnNotFoundError('Error while parsing reference:' - f'column "{col_name} not defined in table "{table1.name}".') - col2_names = [c.strip('() ') for c in ref_.col2.split(',')] - col2 = [] - for col_name in col2_names: - try: - col2.append(table2[col_name]) - except KeyError: - raise ColumnNotFoundError('Error while parsing reference:' - f'column "{col_name} not defined in table "{table2.name}".') - self.refs.append( - Reference( - ref_.type, - table1, - col1, - table2, - col2, - name=ref_.name, - comment=ref_.comment, - on_update=ref_.on_update, - on_delete=ref_.on_delete - ) - ) - - if ref_.type in (Reference.MANY_TO_ONE, Reference.ONE_TO_ONE): - table = table1 - init_dict = { - 'col': col1, - 'ref_table': table2, - 'ref_col': col2, - 'name': ref_.name, - 'on_update': ref_.on_update, - 'on_delete': ref_.on_delete - } - else: - table = table2 - init_dict = { - 'col': col2, - 'ref_table': table1, - 'ref_col': col1, - 'name': ref_.name, - 'on_update': ref_.on_update, - 'on_delete': ref_.on_delete - } - table.add_ref( - TableReference(**init_dict) - ) - - def _set_enum_types(self): - enum_dict = {enum.name: enum for enum in self.enums} - for table_ in self.tables: - for col in table_: - col_type = str(col.type).strip('"') - if col_type in enum_dict: - col.type = enum_dict[col_type] - - def _validate(self): - self._validate_table_groups() - - def _validate_table_groups(self): - ''' - Check that all tables, mentioned in the table groups, exist - ''' - for tg in self.table_groups: - for table_name in tg: - if table_name not in self.table_dict: - raise TableNotFoundError(f'Cannot add Table Group "{tg.name}": table "{table_name}" not found.') - - def _process_table_groups(self): - ''' - Fill up each TableGroup's `item` attribute with references to actual tables. - ''' - for tg in self.table_groups: - tg.items = [self[i] for i in tg.items] - - @property - def sql(self): - '''Returs SQL of the parsed results''' - - components = (i.sql for i in (*self.enums, *self.tables)) - return '\n\n'.join(components) - - @property - def dbml(self): - '''Generates DBML code out of parsed results''' - items = [self.project] if self.project else [] - items.extend((*self.tables, *self.refs, *self.enums, *self.table_groups)) - components = ( - i.dbml for i in items - ) - return '\n\n'.join(components) diff --git a/pydbml/parser/__init__.py b/pydbml/parser/__init__.py new file mode 100644 index 0000000..db3c175 --- /dev/null +++ b/pydbml/parser/__init__.py @@ -0,0 +1 @@ +from .parser import PyDBML, parse diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py new file mode 100644 index 0000000..6d74d7d --- /dev/null +++ b/pydbml/parser/blueprints.py @@ -0,0 +1,247 @@ +from dataclasses import dataclass +from typing import Collection +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Union + +from pydbml.classes import Column +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Project +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.classes import EnumItem +from pydbml.classes import Enum +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import TableNotFoundError + + +ONE_TO_MANY = '<' +MANY_TO_ONE = '>' +ONE_TO_ONE = '-' + + +class Blueprint: + parser = None + + +@dataclass +class NoteBlueprint(Blueprint): + text: str + + def build(self) -> 'Note': + return Note(self.text) + + +@dataclass +class ReferenceBlueprint(Blueprint): + type: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE] + inline: bool + name: Optional[str] = None + table1: Optional[str] = None + col1: Optional[Union[str, Collection[str]]] = None + table2: Optional[str] = None + col2: Optional[Union[str, Collection[str]]] = None + comment: Optional[str] = None + on_update: Optional[str] = None + on_delete: Optional[str] = None + + def build(self) -> 'Reference': + ''' + both tables and columns should be present before build + ''' + if self.table1 is None: + raise TableNotFoundError("Can't build Reference, table1 unknown") + if self.table2 is None: + raise TableNotFoundError("Can't build Reference, table2 unknown") + if self.col1 is None: + raise ColumnNotFoundError("Can't build Reference, col1 unknown") + if self.col2 is None: + raise ColumnNotFoundError("Can't build Reference, col2 unknown") + + table1 = self.parser.locate_table(self.table1) + col1_list = [self.col1] if isinstance(self.col1, str) else self.col1 + col1 = [table1[col] for col in col1_list] + + table2 = self.parser.locate_table(self.table2) + col2_list = [self.col2] if isinstance(self.col2, str) else self.col2 + col2 = [table2[col] for col in col2_list] + + return Reference( + type_=self.type, + inline=self.inline, + table1=table1, + col1=col1, + table2=table2, + col2=col2, + name=self.name, + comment=self.comment, + on_update=self.on_update, + on_delete=self.on_delete + ) + + +@dataclass +class ColumnBlueprint(Blueprint): + name: str + type: str + unique: bool = False + not_null: bool = False + pk: bool = False + autoinc: bool = False + default: Optional[Union[str, int, bool, float]] = None + note: Optional[NoteBlueprint] = None + ref_blueprints: Optional[List[ReferenceBlueprint]] = None + comment: Optional[str] = None + + def build(self) -> 'Column': + return Column( + name=self.name, + type_=self.type, + unique=self.unique, + not_null=self.not_null, + pk=self.pk, + autoinc=self.autoinc, + default=self.default, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class IndexBlueprint(Blueprint): + subject_names: List[str] + name: Optional[str] = None + unique: bool = False + type: Optional[str] = None + pk: bool = False + note: Optional[NoteBlueprint] = None + comment: Optional[str] = None + + table = None + + def build(self) -> 'Index': + return Index( + # TableBlueprint will process subjects + subjects=list(self.subject_names), + name=self.name, + unique=self.unique, + type_=self.type, + pk=self.pk, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class TableBlueprint(Blueprint): + name: str + columns: Optional[List[ColumnBlueprint]] = None # TODO: should it be optional? + indexes: Optional[List[IndexBlueprint]] = None + alias: Optional[str] = None + note: Optional[NoteBlueprint] = None + header_color: Optional[str] = None + comment: Optional[str] = None + + def build(self) -> 'Table': + result = Table( + name=self.name, + alias=self.alias, + note=self.note.build() if self.note else None, + header_color=self.header_color, + comment=self.comment + ) + columns = self.columns or [] + indexes = self.indexes or [] + for col_bp in columns: + result.add_column(col_bp.build()) + for index_bp in indexes: + index = index_bp.build() + new_subjects = [] + for subj in index.subjects: + if subj.startswith('(') and subj.endswith(')'): + new_subjects.append(subj) + else: + for col in result.columns: + if col.name == subj: + new_subjects.append(col) + break + else: + raise ColumnNotFoundError( + f'Cannot add index, column "{subj}" not defined in' + ' table "{self.name}".' + ) + index.subjects = new_subjects + result.add_index(index) + return result + + def get_reference_blueprints(self): + ''' the inline ones ''' + result = [] + for col in self.columns: + for ref_bp in col.ref_blueprints: + ref_bp.table1 = self.name + ref_bp.col1 = col.name + result.append(ref_bp) + return result + + +@dataclass +class EnumItemBlueprint(Blueprint): + name: str + note: Optional[NoteBlueprint] = None + comment: Optional[str] = None + + def build(self) -> 'EnumItem': + return EnumItem( + name=self.name, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class EnumBlueprint(Blueprint): + name: str + items: List[EnumItemBlueprint] + comment: Optional[str] = None + + def build(self) -> 'Enum': + return Enum( + name=self.name, + items=[ei.build() for ei in self.items], + comment=self.comment + ) + + +@dataclass +class ProjectBlueprint(Blueprint): + name: str + items: Optional[Dict[str, str]] = None + note: Optional[NoteBlueprint] = None + comment: Optional[str] = None + + def build(self) -> 'Project': + return Project( + name=self.name, + items=dict(self.items) if self.items else {}, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class TableGroupBlueprint(Blueprint): + name: str + items: List[str] + comment: Optional[str] = None + + def build(self) -> 'TableGroup': + return TableGroup( + name=self.name, + items=[self.parser.locate_table(table) for table in self.items], + comment=self.comment + ) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py new file mode 100644 index 0000000..c234799 --- /dev/null +++ b/pydbml/parser/parser.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import pyparsing as pp + +from io import TextIOWrapper +from pathlib import Path + +from typing import List +from typing import Optional +from typing import Union + +from .blueprints import EnumBlueprint +from .blueprints import ProjectBlueprint +from .blueprints import ReferenceBlueprint +from .blueprints import TableBlueprint +from .blueprints import TableGroupBlueprint +from pydbml.classes import Table +from pydbml.definitions.common import comment +from pydbml.definitions.enum import enum +from pydbml.definitions.project import project +from pydbml.definitions.reference import ref +from pydbml.definitions.table import table +from pydbml.definitions.table_group import table_group +from pydbml.exceptions import TableNotFoundError +from pydbml.schema import Schema + +pp.ParserElement.setDefaultWhitespaceChars(' \t\r') + + +class PyDBML: + ''' + PyDBML parser factory. If properly initiated, returns PyDBMLParseResults + which contains parse results in attributes. + + Usage option 1: + + >>> with open('schema.dbml') as f: + ... p = PyDBML(f) + ... # or + ... p = PyDBML(f.read()) + + Usage option 2: + >>> p = PyDBML.parse_file('schema.dbml') + >>> # or + >>> from pathlib import Path + >>> p = PyDBML(Path('schema.dbml')) + ''' + + def __new__(cls, + source_: Optional[Union[str, Path, TextIOWrapper]] = None): + if source_ is not None: + if isinstance(source_, str): + source = source_ + elif isinstance(source_, Path): + with open(source_, encoding='utf8') as f: + source = f.read() + else: # TextIOWrapper + source = source_.read() + if source[0] == '\ufeff': # removing BOM + source = source[1:] + return cls.parse(source) + else: + return super().__new__(cls) + + def __repr__(self): + return "" + + @staticmethod + def parse(text: str) -> PyDBMLParser: + if text[0] == '\ufeff': # removing BOM + text = text[1:] + + return PyDBMLParser(text) + + @staticmethod + def parse_file(file: Union[str, Path, TextIOWrapper]) -> PyDBMLParser: + if isinstance(file, TextIOWrapper): + source = file.read() + else: + with open(file, encoding='utf8') as f: + source = f.read() + if source[0] == '\ufeff': # removing BOM + source = source[1:] + parser = PyDBMLParser(source) + parser.parse() + return parser + + +def parse(source: str): + parser = PyDBMLParser(source) + return parser.parse() + + +class PyDBMLParser: + def __init__(self, source: str): + self.schema = None + + self.ref_blueprints: List[ReferenceBlueprint] = [] + self.table_groups = [] + self.source = source + self.tables = [] + self.refs = [] + self.enums = [] + self.table_groups = [] + self.project = None + + def parse(self): + self._set_syntax() + self._syntax.parseString(self.source, parseAll=True) + self.build_schema() + + def __repr__(self): + return "" + + def _set_syntax(self): + table_expr = table.copy() + ref_expr = ref.copy() + enum_expr = enum.copy() + table_group_expr = table_group.copy() + project_expr = project.copy() + + table_expr.addParseAction(self.parse_blueprint) + ref_expr.addParseAction(self.parse_blueprint) + enum_expr.addParseAction(self.parse_blueprint) + table_group_expr.addParseAction(self.parse_blueprint) + project_expr.addParseAction(self.parse_blueprint) + + expr = ( + table_expr + | ref_expr + | enum_expr + | table_group_expr + | project_expr + ) + self._syntax = expr[...] + ('\n' | comment)[...] + pp.StringEnd() + + def parse_blueprint(self, s, l, t): + blueprint = t[0] + if isinstance(blueprint, TableBlueprint): + self.tables.append(blueprint) + elif isinstance(blueprint, ReferenceBlueprint): + self.refs.append(blueprint) + elif isinstance(blueprint, EnumBlueprint): + self.enums.append(blueprint) + elif isinstance(blueprint, TableGroupBlueprint): + self.table_groups.append(blueprint) + elif isinstance(blueprint, ProjectBlueprint): + self.project = blueprint + else: + raise RuntimeError(f'type unknown: {blueprint}') + blueprint.parser = self + + def locate_table(self, name: str) -> 'Table': + if not self.schema: + raise RuntimeError('Schema is not ready') + try: + result = self.schema[name] + except KeyError: + raise TableNotFoundError(f'Table {name} not present in the schema') + return result + + def build_schema(self): + self.schema = Schema() + for table_bp in self.tables: + self.schema.add(table_bp.build()) + self.ref_blueprints.extend(table_bp.get_reference_blueprints()) + for enum_bp in self.enums: + self.schema.add(enum_bp.build()) + for table_group_bp in self.table_groups: + self.schema.add(table_group_bp.build()) + if self.project: + self.schema.add(project.build()) + for ref_bp in self.refs: + self.schema.add(ref_bp.build()) + + +# class Temp: +# def _parse_table(self, s, l, t): +# table = t[0] +# self.schema.add_table(table) +# for col in table.columns: +# self.ref_blueprints.extend(col.ref_blueprints) + +# def _parse_ref_blueprint(self, s, l, t): +# self.ref_blueprints.append(t[0]) + +# def _parse_enum(self, s, l, t): +# self.schema.add_enum(t[0]) + +# def _parse_table_group(self, s, l, t): +# self.table_groups.append(t[0]) + +# def _parse_project(self, s, l, t): +# self.schema.add_project(t[0]) + +# def _process_refs(self): +# ''' +# Fill up the `refs` attribute with Reference object, created from +# reference blueprints; +# Add TableReference objects to each table which has references. +# Validate refs at the same time. +# ''' +# self.schema._build_refs_from_blueprints(self.ref_blueprints) + +# def _set_enum_types(self): +# enum_dict = {enum.name: enum for enum in self.schema.enums} +# for table_ in self.schema.tables: +# for col in table_: +# col_type = str(col.type).strip('"') +# if col_type in enum_dict: +# col.type = enum_dict[col_type] + +# def _validate(self): +# self._validate_table_groups() + +# def _validate_table_groups(self): +# ''' +# Check that all tables, mentioned in the table groups, exist +# ''' +# for tg in self.table_groups: +# for table_name in tg: +# if table_name not in self.schema.tables_dict: +# raise TableNotFoundError(f'Cannot add Table Group "{tg.name}": table "{table_name}" not found.') + +# def _process_table_groups(self): +# ''' +# Fill up each TableGroup's `item` attribute with references to actual tables. +# ''' +# for tg in self.table_groups: +# tg.items = [self.schema.tables_dict[i] for i in tg.items] +# self.schema.add_table_group(tg) + +# @property +# def sql(self): +# '''Returs SQL of the parsed results''' + +# components = (i.sql for i in (*self.enums, *self.tables)) +# return '\n\n'.join(components) + +# @property +# def dbml(self): +# '''Generates DBML code out of parsed results''' +# items = [self.project] if self.project else [] +# items.extend((*self.tables, *self.refs, *self.enums, *self.table_groups)) +# components = ( +# i.dbml for i in items +# ) +# return '\n\n'.join(components) diff --git a/pydbml/schema.py b/pydbml/schema.py new file mode 100644 index 0000000..f2c6e3b --- /dev/null +++ b/pydbml/schema.py @@ -0,0 +1,211 @@ +from .classes import Enum +from .classes import Project +from .classes import Reference +from .classes import Table +from .classes import TableGroup +from pydbml.parser.blueprints import ReferenceBlueprint +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import TableNotFoundError +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + + +class SchemaValidationError(Exception): + pass + + +class Schema: + def __init__(self) -> None: + self.tables: List[Table] = [] + self.tables_dict: Dict[str, Table] = {} + self.refs: List[Reference] = [] + # self.ref_blueprints: List[ReferenceBlueprint] = [] + self.enums: List[Enum] = [] + self.table_groups: List[TableGroup] = [] + self.project: Optional[Project] = None + + def __repr__(self) -> str: + return f"" + + # def _build_refs_from_blueprints(self, blueprints: List[ReferenceBlueprint]): + # ''' + # Fill up the `refs` attribute with Reference object, created from + # reference blueprints; + # Add TableReference objects to each table which has references. + # Validate refs at the same time. + # ''' + # for ref_ in blueprints: + # for table_ in self.tables: + # if table_.name == ref_.table1 or table_.alias == ref_.table1: + # table1 = table_ + # break + # else: + # raise TableNotFoundError('Error while parsing reference:' + # f'table "{ref_.table1}"" is not defined.') + # for table_ in self.tables: + # if table_.name == ref_.table2 or table_.alias == ref_.table2: + # table2 = table_ + # break + # else: + # raise TableNotFoundError('Error while parsing reference:' + # f'table "{ref_.table2}"" is not defined.') + # col1_names = [c.strip('() ') for c in ref_.col1.split(',')] + # col1 = [] + # for col_name in col1_names: + # try: + # col1.append(table1[col_name]) + # except KeyError: + # raise ColumnNotFoundError('Error while parsing reference:' + # f'column "{col_name} not defined in table "{table1.name}".') + # col2_names = [c.strip('() ') for c in ref_.col2.split(',')] + # col2 = [] + # for col_name in col2_names: + # try: + # col2.append(table2[col_name]) + # except KeyError: + # raise ColumnNotFoundError('Error while parsing reference:' + # f'column "{col_name} not defined in table "{table2.name}".') + # self.add_reference( + # Reference( + # ref_.type, + # table1, + # col1, + # table2, + # col2, + # name=ref_.name, + # comment=ref_.comment, + # on_update=ref_.on_update, + # on_delete=ref_.on_delete + # ) + # ) + + def _set_schema(self, obj: Any) -> None: + obj.schema = self + + def add(self, obj: Any) -> Any: + if isinstance(obj, Table): + return self.add_table(obj) + elif isinstance(obj, Reference): + return self.add_reference(obj) + elif isinstance(obj, Enum): + return self.add_enum(obj) + elif isinstance(obj, TableGroup): + return self.add_table_group(obj) + elif isinstance(obj, Project): + return self.add_project(obj) + else: + raise SchemaValidationError(f'Unsupported type {type(obj)}.') + + def add_table(self, obj: Table) -> Table: + if obj.name in self.tables_dict: + raise SchemaValidationError(f'Table {obj.name} is already in the schema.') + if obj in self.tables: + raise SchemaValidationError(f'{obj} is already in the schema.') + + self._set_schema(obj) + + self.tables.append(obj) + self.tables_dict[obj.name] = obj + return obj + + def add_reference(self, obj: Reference): + for table in (obj.table1, obj.table2): + if table in self.tables: + break + else: + raise SchemaValidationError( + 'Cannot add reference. At least one of the referenced tables' + ' should belong to this schema' + ) + if obj in self.refs: + raise SchemaValidationError(f'{obj} is already in the schema.') + + self._set_schema(obj) + self.refs.append(obj) + return obj + + def add_enum(self, obj: Enum) -> Enum: + for enum in self.enums: + if enum.name == obj.name: + raise SchemaValidationError(f'Enum {obj.name} is already in the schema.') + if obj in self.enums: + raise SchemaValidationError(f'{obj} is already in the schema.') + + self._set_schema(obj) + self.enums.append(obj) + return obj + + def add_table_group(self, obj: TableGroup) -> TableGroup: + for table_group in self.table_groups: + if table_group.name == obj.name: + raise SchemaValidationError(f'TableGroup {obj.name} is already in the schema.') + if obj in self.table_groups: + raise SchemaValidationError(f'{obj} is already in the schema.') + + self._set_schema(obj) + self.table_groups.append(obj) + return obj + + def add_project(self, obj: Project) -> Project: + self._set_schema(obj) + self.project = obj + return obj + + def delete(self, obj: Any) -> Any: + if isinstance(obj, Table): + return self.delete_table(obj) + elif isinstance(obj, Reference): + return self.delete_reference(obj) + elif isinstance(obj, Enum): + return self.delete_enum(obj) + elif isinstance(obj, TableGroup): + return self.delete_table_group(obj) + elif isinstance(obj, Project): + return self.delete_project() + else: + raise SchemaValidationError(f'Unsupported type {type(obj)}.') + + def delete_table(self, obj: Table) -> Table: + try: + index = self.tables.index(obj) + except ValueError: + raise SchemaValidationError(f'{obj} is not in the schema.') + self.tables.pop(index).schema = None + return self.tables_dict.pop(obj.name) + + def delete_reference(self, obj: Reference) -> Reference: + try: + index = self.refs.index(obj) + except ValueError: + raise SchemaValidationError(f'{obj} is not in the schema.') + result = self.refs.pop(index) + result.schema = None + return result + + def delete_enum(self, obj: Enum) -> Enum: + try: + index = self.enums.index(obj) + except ValueError: + raise SchemaValidationError(f'{obj} is not in the schema.') + result = self.enums.pop(index) + result.schema = None + return result + + def delete_table_group(self, obj: TableGroup) -> TableGroup: + try: + index = self.tables_groups.index(obj) + except ValueError: + raise SchemaValidationError(f'{obj} is not in the schema.') + result = self.table_groups.pop(index) + result.schema = None + return result + + def delete_project(self) -> Project: + if self.Project is None: + raise SchemaValidationError(f'Project is not set.') + result = self.Project + self.Project = None + result.schema = None + return result diff --git a/pydbml/tools.py b/pydbml/tools.py index f5e3579..a1128a8 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -2,6 +2,7 @@ if TYPE_CHECKING: from .classes import Note + def comment_to_dbml(val: str) -> str: return '\n'.join(f'// {cl}' for cl in val.split('\n')) + '\n' @@ -20,4 +21,4 @@ def note_option_to_dbml(val: 'Note') -> str: def indent(val: str, spaces=4) -> str: if val == '': return val - return ' ' * spaces + val.replace('\n', '\n' +' ' * spaces) \ No newline at end of file + return ' ' * spaces + val.replace('\n', '\n' +' ' * spaces) diff --git a/setup.py b/setup.py index 2497ce6..2c6deec 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setup( name='pydbml', - python_requires='>=3.5', + python_requires='>=3.8', description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', diff --git a/test/test_classes.py b/test/_test_classes.py similarity index 92% rename from test/test_classes.py rename to test/_test_classes.py index 818dd1f..7c483ff 100644 --- a/test/test_classes.py +++ b/test/_test_classes.py @@ -548,91 +548,7 @@ def test_dbml_full(self): }""" self.assertEqual(t.dbml, expected) -class TestEnumItem(TestCase): - def test_dbml_simple(self): - ei = EnumItem('en-US') - expected = '"en-US"' - self.assertEqual(ei.dbml, expected) - - def test_sql(self): - ei = EnumItem('en-US') - expected = "'en-US'," - self.assertEqual(ei.sql, expected) - - def test_dbml_full(self): - ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') - expected = \ -'''// EnumItem comment -"en-US" [note: 'preferred']''' - self.assertEqual(ei.dbml, expected) - - -class TestEnum(TestCase): - def test_simple_enum(self) -> None: - items = [ - EnumItem('created'), - EnumItem('running'), - EnumItem('donef'), - EnumItem('failure'), - ] - e = Enum('job_status', items) - expected = \ -'''CREATE TYPE "job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', -);''' - self.assertEqual(e.sql, expected) - - def test_comments(self) -> None: - items = [ - EnumItem('created', comment='EnumItem comment'), - EnumItem('running'), - EnumItem('donef', comment='EnumItem\nmultiline comment'), - EnumItem('failure'), - ] - e = Enum('job_status', items, comment='Enum comment') - expected = \ -'''-- Enum comment -CREATE TYPE "job_status" AS ENUM ( - -- EnumItem comment - 'created', - 'running', - -- EnumItem - -- multiline comment - 'donef', - 'failure', -);''' - self.assertEqual(e.sql, expected) - - def test_dbml_simple(self): - items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] - e = Enum('lang', items) - expected = \ -'''Enum "lang" { - "en-US" - "ru-RU" - "en-GB" -}''' - self.assertEqual(e.dbml, expected) - def test_dbml_full(self): - items = [ - EnumItem('en-US', note='preferred'), - EnumItem('ru-RU', comment='Multiline\ncomment'), - EnumItem('en-GB')] - e = Enum('lang', items, comment="Enum comment") - expected = \ -'''// Enum comment -Enum "lang" { - "en-US" [note: 'preferred'] - // Multiline - // comment - "ru-RU" - "en-GB" -}''' - self.assertEqual(e.dbml, expected) diff --git a/test/_test_parser.py b/test/_test_parser.py new file mode 100644 index 0000000..f6b76f7 --- /dev/null +++ b/test/_test_parser.py @@ -0,0 +1,96 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pydbml import PyDBML +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import TableNotFoundError + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class TestParser(TestCase): + def setUp(self): + self.results = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') + + def test_table_refs(self) -> None: + p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') + r = p['order_items'].refs + self.assertEqual(r[0].col[0].name, 'order_id') + self.assertEqual(r[0].ref_table.name, 'orders') + self.assertEqual(r[0].ref_col[0].name, 'id') + r = p['products'].refs + self.assertEqual(r[0].col[0].name, 'merchant_id') + self.assertEqual(r[0].ref_table.name, 'merchants') + self.assertEqual(r[0].ref_col[0].name, 'id') + r = p['users'].refs + self.assertEqual(r[0].col[0].name, 'country_code') + self.assertEqual(r[0].ref_table.name, 'countries') + self.assertEqual(r[0].ref_col[0].name, 'code') + + def test_refs(self) -> None: + p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') + r = p.refs + self.assertEqual(r[0].table1.name, 'orders') + self.assertEqual(r[0].col1[0].name, 'id') + self.assertEqual(r[0].table2.name, 'order_items') + self.assertEqual(r[0].col2[0].name, 'order_id') + self.assertEqual(r[2].table1.name, 'users') + self.assertEqual(r[2].col1[0].name, 'country_code') + self.assertEqual(r[2].table2.name, 'countries') + self.assertEqual(r[2].col2[0].name, 'code') + self.assertEqual(r[4].table1.name, 'products') + self.assertEqual(r[4].col1[0].name, 'merchant_id') + self.assertEqual(r[4].table2.name, 'merchants') + self.assertEqual(r[4].col2[0].name, 'id') + + +class TestRefs(TestCase): + def test_reference_aliases(self): + results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') + posts, reviews, users = results['posts'], results['reviews'], results['users'] + posts2, reviews2, users2 = results['posts2'], results['reviews2'], results['users2'] + + rs = results.refs + self.assertEqual(rs[0].table1, users) + self.assertEqual(rs[0].table2, posts) + self.assertEqual(rs[1].table1, users) + self.assertEqual(rs[1].table2, reviews) + + self.assertEqual(rs[2].table1, posts2) + self.assertEqual(rs[2].table2, users2) + self.assertEqual(rs[3].table1, reviews2) + self.assertEqual(rs[3].table2, users2) + + def test_composite_references(self): + results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_composite.dbml') + self.assertEqual(len(results.tables), 4) + posts, reviews = results['posts'], results['reviews'] + posts2, reviews2 = results['posts2'], results['reviews2'] + + rs = results.refs + self.assertEqual(len(rs), 2) + + self.assertEqual(rs[0].table1, posts) + self.assertEqual(rs[0].col1, [posts['id'], posts['tag']]) + self.assertEqual(rs[0].table2, reviews) + self.assertEqual(rs[0].col2, [reviews['post_id'], reviews['tag']]) + + self.assertEqual(rs[1].table1, posts2) + self.assertEqual(rs[1].col1, [posts2['id'], posts2['tag']]) + self.assertEqual(rs[1].table2, reviews2) + self.assertEqual(rs[1].col2, [reviews2['post_id'], reviews2['tag']]) + + +class TestFaulty(TestCase): + def test_bad_reference(self) -> None: + with self.assertRaises(TableNotFoundError): + PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_table.dbml') + with self.assertRaises(ColumnNotFoundError): + PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_column.dbml') + + def test_bad_index(self) -> None: + with self.assertRaises(ColumnNotFoundError): + PyDBML(TEST_DATA_PATH / 'wrong_index.dbml') diff --git a/test/test_blueprints/test_column.py b/test/test_blueprints/test_column.py new file mode 100644 index 0000000..78a5b1e --- /dev/null +++ b/test/test_blueprints/test_column.py @@ -0,0 +1,43 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.classes import Column +from pydbml.parser.blueprints import ColumnBlueprint +from pydbml.parser.blueprints import NoteBlueprint + + +class TestColumn(TestCase): + def test_build_minimal(self) -> None: + bp = ColumnBlueprint( + name='testcol', + type='varchar' + ) + result = bp.build() + self.assertIsInstance(result, Column) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.type, bp.type) + + def test_build_full(self) -> None: + bp = ColumnBlueprint( + name='id', + type='number', + unique=True, + not_null=True, + pk=True, + autoinc=True, + default=0, + note=NoteBlueprint(text='note text'), + comment='Col commment' + ) + result = bp.build() + self.assertIsInstance(result, Column) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.type, bp.type) + self.assertEqual(result.unique, bp.unique) + self.assertEqual(result.not_null, bp.not_null) + self.assertEqual(result.pk, bp.pk) + self.assertEqual(result.autoinc, bp.autoinc) + self.assertEqual(result.default, bp.default) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) diff --git a/test/test_blueprints/test_enum.py b/test/test_blueprints/test_enum.py new file mode 100644 index 0000000..63a9fcf --- /dev/null +++ b/test/test_blueprints/test_enum.py @@ -0,0 +1,50 @@ +from unittest import TestCase + +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Note +from pydbml.parser.blueprints import EnumBlueprint +from pydbml.parser.blueprints import EnumItemBlueprint +from pydbml.parser.blueprints import NoteBlueprint + + +class TestEnumItemBlueprint(TestCase): + def test_build_minimal(self) -> None: + bp = EnumItemBlueprint( + name='Red' + ) + result = bp.build() + self.assertIsInstance(result, EnumItem) + self.assertEqual(result.name, bp.name) + + def test_build_full(self) -> None: + bp = EnumItemBlueprint( + name='Red', + note=NoteBlueprint(text='Note text'), + comment='Comment text' + ) + result = bp.build() + self.assertIsInstance(result, EnumItem) + self.assertEqual(result.name, bp.name) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) + + +class TestEnumBlueprint(TestCase): + def test_build(self) -> None: + bp = EnumBlueprint( + name='Colors', + items=[ + EnumItemBlueprint(name='Red'), + EnumItemBlueprint(name='Green'), + EnumItemBlueprint(name='Blue') + ], + comment='Comment text' + ) + result = bp.build() + self.assertIsInstance(result, Enum) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.comment, bp.comment) + for ei in result.items: + self.assertIsInstance(ei, EnumItem) diff --git a/test/test_blueprints/test_index.py b/test/test_blueprints/test_index.py new file mode 100644 index 0000000..64f5df5 --- /dev/null +++ b/test/test_blueprints/test_index.py @@ -0,0 +1,37 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.classes import Index +from pydbml.parser.blueprints import IndexBlueprint +from pydbml.parser.blueprints import NoteBlueprint + + +class TestIndex(TestCase): + def test_build_minimal(self) -> None: + bp = IndexBlueprint( + subject_names=['a', 'b', 'c'] + ) + result = bp.build() + self.assertIsInstance(result, Index) + self.assertEqual(result.subjects, bp.subject_names) + + def test_build_full(self) -> None: + bp = IndexBlueprint( + subject_names=['a', 'b', 'c'], + name='MyIndex', + unique=True, + type='hash', + pk=True, + note=NoteBlueprint(text='Note text'), + comment='Comment text' + ) + result = bp.build() + self.assertIsInstance(result, Index) + self.assertEqual(result.subject_names, bp.subject_names) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.unique, bp.unique) + self.assertEqual(result.type, bp.type) + self.assertEqual(result.pk, bp.pk) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) diff --git a/test/test_blueprints/test_note.py b/test/test_blueprints/test_note.py new file mode 100644 index 0000000..4ae20d0 --- /dev/null +++ b/test/test_blueprints/test_note.py @@ -0,0 +1,12 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.parser.blueprints import NoteBlueprint + + +class TestNote(TestCase): + def test_build(self) -> None: + bp = NoteBlueprint(text='Note text') + result = bp.build() + self.assertIsInstance(result, Note) + self.assertEqual(result.text, bp.text) diff --git a/test/test_blueprints/test_project.py b/test/test_blueprints/test_project.py new file mode 100644 index 0000000..eebff86 --- /dev/null +++ b/test/test_blueprints/test_project.py @@ -0,0 +1,36 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.classes import Project +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ProjectBlueprint + + +class TestProjectBlueprint(TestCase): + def test_build_minimal(self) -> None: + bp = ProjectBlueprint( + name='MyProject' + ) + result = bp.build() + self.assertIsInstance(result, Project) + self.assertEqual(result.name, bp.name) + + def test_build_full(self) -> None: + bp = ProjectBlueprint( + name='MyProject', + items={ + 'author': 'John Wick', + 'nickname': 'Baba Yaga', + 'reason': 'revenge' + }, + note=NoteBlueprint(text='note text'), + comment='comment text' + ) + result = bp.build() + self.assertIsInstance(result, Project) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.items, bp.items) + self.assertIsNot(result.items, bp.items) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) diff --git a/test/test_blueprints/test_reference.py b/test/test_blueprints/test_reference.py new file mode 100644 index 0000000..9678875 --- /dev/null +++ b/test/test_blueprints/test_reference.py @@ -0,0 +1,68 @@ +from unittest import TestCase +from unittest.mock import Mock + +from pydbml.classes import Column +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import TableNotFoundError +from pydbml.parser.blueprints import ReferenceBlueprint + + +class TestReferenceBlueprint(TestCase): + def test_build_minimal(self) -> None: + bp = ReferenceBlueprint( + type='>', + inline=True, + table1='table1', + col1='col1', + table2='table2', + col2='col2', + ) + + t1 = Table( + name='table1' + ) + c1 = Column(name='col1', type_='Number') + t1.add_column(c1) + t2 = Table( + name='table2' + ) + c2 = Column(name='col2', type_='Varchar') + t2.add_column(c2) + + parserMock = Mock() + parserMock.locate_table.side_effect = [t1, t2] + bp.parser = parserMock + result = bp.build() + self.assertIsInstance(result, Reference) + self.assertEqual(result.type, bp.type) + self.assertEqual(result.inline, bp.inline) + self.assertEqual(parserMock.locate_table.call_count, 2) + self.assertEqual(result.table1, t1) + self.assertEqual(result.col1, [c1]) + self.assertEqual(result.table2, t2) + self.assertEqual(result.col2, [c2]) + + def test_tables_and_cols_are_set(self) -> None: + bp = ReferenceBlueprint( + type='>', + inline=True, + table1=None, + col1='col1', + table2='table2', + col2='col2' + ) + with self.assertRaises(TableNotFoundError): + bp.build() + + bp = ReferenceBlueprint( + type='>', + inline=True, + table1='table1', + col1='col1', + table2='table2', + col2=None + ) + with self.assertRaises(ColumnNotFoundError): + bp.build() diff --git a/test/test_blueprints/test_table.py b/test/test_blueprints/test_table.py new file mode 100644 index 0000000..890ab33 --- /dev/null +++ b/test/test_blueprints/test_table.py @@ -0,0 +1,125 @@ +from unittest import TestCase + +from pydbml.exceptions import ColumnNotFoundError +from pydbml.classes import Note +from pydbml.classes import Table +from pydbml.classes import Index +from pydbml.classes import Column +from pydbml.parser.blueprints import IndexBlueprint +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ColumnBlueprint +from pydbml.parser.blueprints import TableBlueprint +from pydbml.parser.blueprints import ReferenceBlueprint + + +class TestTable(TestCase): + def test_build_minimal(self) -> None: + bp = TableBlueprint(name='TestTable') + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + + def test_build_full_simple(self) -> None: + bp = TableBlueprint( + name='TestTable', + alias='TestAlias', + note=NoteBlueprint(text='Note text'), + header_color='#ccc', + comment='comment text' + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.alias, bp.alias) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.header_color, bp.header_color) + self.assertEqual(result.comment, bp.comment) + + def test_with_columns(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint(name='id', type='Integer', not_null=True, autoinc=True), + ColumnBlueprint(name='name', type='Varchar') + ] + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + for col in result.columns: + self.assertIsInstance(col, Column) + + def test_with_indexes(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint(name='id', type='Integer', not_null=True, autoinc=True), + ColumnBlueprint(name='name', type='Varchar') + ], + indexes=[ + IndexBlueprint(subject_names=['name', 'id'], unique=True), + IndexBlueprint(subject_names=['id', '(name*2)'], name='ExprIndex') + ] + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + for col in result.columns: + self.assertIsInstance(col, Column) + for ind in result.indexes: + self.assertIsInstance(ind, Index) + + def test_bad_index(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint(name='id', type='Integer', not_null=True, autoinc=True), + ColumnBlueprint(name='name', type='Varchar') + ], + indexes=[ + IndexBlueprint(subject_names=['name', 'id'], unique=True), + IndexBlueprint(subject_names=['wrong', '(name*2)'], name='ExprIndex') + ] + ) + with self.assertRaises(ColumnNotFoundError): + bp.build() + + def test_get_reference_blueprints(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint( + name='id', + type='Integer', + not_null=True, + autoinc=True, + ref_blueprints=[ + ReferenceBlueprint( + type='<', + inline=True, + table2='AnotherTable', + col2=['AnotherCol']) + ] + ), + ColumnBlueprint( + name='name', + type='Varchar', + ref_blueprints=[ + ReferenceBlueprint( + type='>', + inline=True, + table2='YetAnotherTable', + col2=['YetAnotherCol']) + ] + ) + ] + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + ref_bps = bp.get_reference_blueprints() + self.assertEqual(ref_bps[0].table1, result.name) + self.assertEqual(ref_bps[0].col1, 'id') + self.assertEqual(ref_bps[1].table1, result.name) + self.assertEqual(ref_bps[1].col1, 'name') diff --git a/test/test_blueprints/test_table_group.py b/test/test_blueprints/test_table_group.py new file mode 100644 index 0000000..f76a823 --- /dev/null +++ b/test/test_blueprints/test_table_group.py @@ -0,0 +1,26 @@ +from unittest import TestCase +from unittest.mock import Mock + +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.parser.blueprints import TableGroupBlueprint + + +class TestTableGroupBlueprint(TestCase): + def test_build(self) -> None: + bp = TableGroupBlueprint( + name='TestTableGroup', + items=['table1', 'table2'], + comment='Comment text' + ) + parserMock = Mock() + parserMock.locate_table.side_effect = [ + Table(name='table1'), + Table(name='table2') + ] + bp.parser = parserMock + result = bp.build() + self.assertIsInstance(result, TableGroup) + self.assertEqual(parserMock.locate_table.call_count, 2) + for i in result.items: + self.assertIsInstance(i, Table) diff --git a/test/test_classes/test_base.py b/test/test_classes/test_base.py new file mode 100644 index 0000000..7c0d228 --- /dev/null +++ b/test/test_classes/test_base.py @@ -0,0 +1,33 @@ +from unittest import TestCase + +from pydbml.classes.base import SQLOjbect +from pydbml.exceptions import AttributeMissingError + + +class TestDBMLObject(TestCase): + def test_check_attributes_for_sql(self) -> None: + o = SQLOjbect() + o.a1 = None + o.b1 = None + o.c1 = None + o.required_attributes = ('a1', 'b1') + with self.assertRaises(AttributeMissingError): + o.check_attributes_for_sql() + o.a1 = 1 + with self.assertRaises(AttributeMissingError): + o.check_attributes_for_sql() + o.b1 = 'a2' + o.check_attributes_for_sql() + + def test_comparison(self) -> None: + o1 = SQLOjbect() + o1.a1 = None + o1.b1 = 'c' + o1.c1 = 123 + o2 = SQLOjbect() + o2.a1 = None + o2.b1 = 'c' + o2.c1 = 123 + self.assertTrue(o1 == o2) + o1.a2 = True + self.assertFalse(o1 == o2) diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py new file mode 100644 index 0000000..e71f273 --- /dev/null +++ b/test/test_classes/test_column.py @@ -0,0 +1,120 @@ +from unittest import TestCase + +from pydbml.classes import Column + + +class TestColumn(TestCase): + def test_basic_sql(self) -> None: + r = Column(name='id', + type_='integer') + expected = '"id" integer' + self.assertEqual(r.sql, expected) + + def test_pk_autoinc(self) -> None: + r = Column(name='id', + type_='integer', + pk=True, + autoinc=True) + expected = '"id" integer PRIMARY KEY AUTOINCREMENT' + self.assertEqual(r.sql, expected) + + def test_unique_not_null(self) -> None: + r = Column(name='id', + type_='integer', + unique=True, + not_null=True) + expected = '"id" integer UNIQUE NOT NULL' + self.assertEqual(r.sql, expected) + + def test_default(self) -> None: + r = Column(name='order', + type_='integer', + default=0) + expected = '"order" integer DEFAULT 0' + self.assertEqual(r.sql, expected) + + def test_comment(self) -> None: + r = Column(name='id', + type_='integer', + unique=True, + not_null=True, + comment="Column comment") + expected = \ +'''-- Column comment +"id" integer UNIQUE NOT NULL''' + self.assertEqual(r.sql, expected) + + def test_dbml_simple(self): + c = Column( + name='order', + type_='integer' + ) + expected = '"order" integer' + + self.assertEqual(c.dbml, expected) + + def test_dbml_full(self): + c = Column( + name='order', + type_='integer', + unique=True, + not_null=True, + pk=True, + autoinc=True, + default='Def_value', + note='Note on the column', + comment='Comment on the column' + ) + expected = \ +'''// Comment on the column +"order" integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' + + self.assertEqual(c.dbml, expected) + + def test_dbml_multiline_note(self): + c = Column( + name='order', + type_='integer', + not_null=True, + note='Note on the column\nmultiline', + comment='Comment on the column' + ) + expected = \ +"""// Comment on the column +"order" integer [not null, note: '''Note on the column +multiline''']""" + + self.assertEqual(c.dbml, expected) + + def test_dbml_default(self): + c = Column( + name='order', + type_='integer', + default='String value' + ) + expected = "\"order\" integer [default: 'String value']" + self.assertEqual(c.dbml, expected) + + c.default = 3 + expected = '"order" integer [default: 3]' + self.assertEqual(c.dbml, expected) + + c.default = 3.33 + expected = '"order" integer [default: 3.33]' + self.assertEqual(c.dbml, expected) + + c.default = "(now() - interval '5 days')" + expected = "\"order\" integer [default: `now() - interval '5 days'`]" + self.assertEqual(c.dbml, expected) + + c.default = 'NULL' + expected = '"order" integer [default: null]' + self.assertEqual(c.dbml, expected) + + c.default = 'TRue' + expected = '"order" integer [default: true]' + self.assertEqual(c.dbml, expected) + + c.default = 'false' + expected = '"order" integer [default: false]' + self.assertEqual(c.dbml, expected) diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py new file mode 100644 index 0000000..a0f8563 --- /dev/null +++ b/test/test_classes/test_enum.py @@ -0,0 +1,90 @@ +from pydbml.classes import EnumItem +from pydbml.classes import Enum +from unittest import TestCase + + +class TestEnumItem(TestCase): + def test_dbml_simple(self): + ei = EnumItem('en-US') + expected = '"en-US"' + self.assertEqual(ei.dbml, expected) + + def test_sql(self): + ei = EnumItem('en-US') + expected = "'en-US'," + self.assertEqual(ei.sql, expected) + + def test_dbml_full(self): + ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') + expected = \ +'''// EnumItem comment +"en-US" [note: 'preferred']''' + self.assertEqual(ei.dbml, expected) + + +class TestEnum(TestCase): + def test_simple_enum(self) -> None: + items = [ + EnumItem('created'), + EnumItem('running'), + EnumItem('donef'), + EnumItem('failure'), + ] + e = Enum('job_status', items) + expected = \ +'''CREATE TYPE "job_status" AS ENUM ( + 'created', + 'running', + 'donef', + 'failure', +);''' + self.assertEqual(e.sql, expected) + + def test_comments(self) -> None: + items = [ + EnumItem('created', comment='EnumItem comment'), + EnumItem('running'), + EnumItem('donef', comment='EnumItem\nmultiline comment'), + EnumItem('failure'), + ] + e = Enum('job_status', items, comment='Enum comment') + expected = \ +'''-- Enum comment +CREATE TYPE "job_status" AS ENUM ( + -- EnumItem comment + 'created', + 'running', + -- EnumItem + -- multiline comment + 'donef', + 'failure', +);''' + self.assertEqual(e.sql, expected) + + def test_dbml_simple(self): + items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] + e = Enum('lang', items) + expected = \ +'''Enum "lang" { + "en-US" + "ru-RU" + "en-GB" +}''' + self.assertEqual(e.dbml, expected) + + def test_dbml_full(self): + items = [ + EnumItem('en-US', note='preferred'), + EnumItem('ru-RU', comment='Multiline\ncomment'), + EnumItem('en-GB')] + e = Enum('lang', items, comment="Enum comment") + expected = \ +'''// Enum comment +Enum "lang" { + "en-US" [note: 'preferred'] + // Multiline + // comment + "ru-RU" + "en-GB" +}''' + self.assertEqual(e.dbml, expected) diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py new file mode 100644 index 0000000..f2bd5ee --- /dev/null +++ b/test/test_classes/test_index.py @@ -0,0 +1,105 @@ +from unittest import TestCase + +from pydbml.classes import Index +from pydbml.classes import Table +from pydbml.classes import Column + + +class TestIndex(TestCase): + def test_basic_sql(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + r = Index(subjects=[t.columns[0]]) + t.add_index(r) + self.assertIs(r.table, t) + expected = 'CREATE INDEX ON "products" ("id");' + self.assertEqual(r.sql, expected) + + def test_comment(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + r = Index(subjects=[t.columns[0]], + comment='Index comment') + t.add_index(r) + self.assertIs(r.table, t) + expected = \ +'''-- Index comment +CREATE INDEX ON "products" ("id");''' + + self.assertEqual(r.sql, expected) + + def test_unique_type_composite(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + t.add_column(Column('name', 'varchar')) + r = Index( + subjects=[ + t.columns[0], + t.columns[1] + ], + type_='hash', + unique=True + ) + t.add_index(r) + expected = 'CREATE UNIQUE INDEX ON "products" USING HASH ("id", "name");' + self.assertEqual(r.sql, expected) + + def test_pk(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + t.add_column(Column('name', 'varchar')) + r = Index( + subjects=[ + t.columns[0], + t.columns[1] + ], + pk=True + ) + t.add_index(r) + expected = 'PRIMARY KEY ("id", "name")' + self.assertEqual(r.sql, expected) + + def test_composite_with_expression(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + r = Index(subjects=[t.columns[0], '(id*3)']) + t.add_index(r) + self.assertEqual(r.subjects, [t['id'], '(id*3)']) + expected = 'CREATE INDEX ON "products" ("id", (id*3));' + self.assertEqual(r.sql, expected) + + def test_dbml_simple(self): + t = Table('products') + t.add_column(Column('id', 'integer')) + i = Index(subjects=[t.columns[0]]) + t.add_index(i) + + expected = 'id' + self.assertEqual(i.dbml, expected) + + def test_dbml_composite(self): + t = Table('products') + t.add_column(Column('id', 'integer')) + i = Index(subjects=[t.columns[0], '(id*3)']) + t.add_index(i) + + expected = '(id, `id*3`)' + self.assertEqual(i.dbml, expected) + + def test_dbml_full(self): + t = Table('products') + t.add_column(Column('id', 'integer')) + i = Index( + subjects=[t.columns[0], '(getdate())'], + name='Dated id', + unique=True, + type_='hash', + pk=True, + note='Note on the column', + comment='Comment on the index' + ) + t.add_index(i) + expected = \ +'''// Comment on the index +(id, `getdate()`) [name: 'Dated id', pk, unique, type: hash, note: 'Note on the column']''' + self.assertEqual(i.dbml, expected) diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py new file mode 100644 index 0000000..d4f30ab --- /dev/null +++ b/test/test_classes/test_table.py @@ -0,0 +1,220 @@ +from unittest import TestCase + +from pydbml.classes import Index +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.classes import Column + + +class TestTable(TestCase): + def test_one_column(self) -> None: + t = Table('products') + c = Column('id', 'integer') + t.add_column(c) + expected = 'CREATE TABLE "products" (\n "id" integer\n);' + self.assertEqual(t.sql, expected) + +# def test_ref(self) -> None: +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# t2 = Table('names') +# c21 = Column('name_val', 'varchar2') +# t2.add_column(c21) +# r = TableReference(c2, t2, c21) +# t.add_ref(r) +# expected = \ +# '''CREATE TABLE "products" ( +# "id" integer, +# "name" varchar2, +# FOREIGN KEY ("name") REFERENCES "names" ("name_val") +# );''' +# self.assertEqual(t.sql, expected) + +# def test_duplicate_ref(self) -> None: +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# t2 = Table('names') +# c21 = Column('name_val', 'varchar2') +# t2.add_column(c21) +# r1 = TableReference(c2, t2, c21) +# t.add_ref(r1) +# r2 = TableReference(c2, t2, c21) +# self.assertEqual(r1, r2) +# with self.assertRaises(DuplicateReferenceError): +# t.add_ref(r2) + +# def test_notes(self) -> None: +# n = Note('Table note') +# nc1 = Note('First column note') +# nc2 = Note('Another column\nmultiline note') +# t = Table('products', note=n) +# c1 = Column('id', 'integer', note=nc1) +# c2 = Column('name', 'varchar') +# c3 = Column('country', 'varchar', note=nc2) +# t.add_column(c1) +# t.add_column(c2) +# t.add_column(c3) +# expected = \ +# '''CREATE TABLE "products" ( +# "id" integer, +# "name" varchar, +# "country" varchar +# ); + +# COMMENT ON TABLE "products" IS 'Table note'; + +# COMMENT ON COLUMN "products"."id" IS 'First column note'; + +# COMMENT ON COLUMN "products"."country" IS 'Another column +# multiline note';''' +# self.assertEqual(t.sql, expected) + +# def test_ref_index(self) -> None: +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# t2 = Table('names') +# c21 = Column('name_val', 'varchar2') +# t2.add_column(c21) +# r = TableReference(c2, t2, c21) +# t.add_ref(r) +# i = Index(['id', 'name']) +# t.add_index(i) +# expected = \ +# '''CREATE TABLE "products" ( +# "id" integer, +# "name" varchar2, +# FOREIGN KEY ("name") REFERENCES "names" ("name_val") +# ); + +# CREATE INDEX ON "products" ("id", "name");''' +# self.assertEqual(t.sql, expected) + +# def test_index_inline(self) -> None: +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# i = Index(['id', 'name'], pk=True) +# t.add_index(i) +# expected = \ +# '''CREATE TABLE "products" ( +# "id" integer, +# "name" varchar2, +# PRIMARY KEY ("id", "name") +# );''' +# self.assertEqual(t.sql, expected) + +# def test_index_inline_and_comments(self) -> None: +# t = Table('products', comment='Multiline\ntable comment') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# i = Index(['id', 'name'], pk=True, comment='Multiline\nindex comment') +# t.add_index(i) +# expected = \ +# '''-- Multiline +# -- table comment +# CREATE TABLE "products" ( +# "id" integer, +# "name" varchar2, +# -- Multiline +# -- index comment +# PRIMARY KEY ("id", "name") +# );''' +# self.assertEqual(t.sql, expected) + +# def test_add_column(self) -> None: +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# self.assertEqual(c1.table, t) +# self.assertEqual(c2.table, t) +# self.assertEqual(t.columns, [c1, c2]) + +# def test_add_index(self) -> None: +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# i1 = Index(['id']) +# i2 = Index(['name']) +# t.add_column(c1) +# t.add_column(c2) +# t.add_index(i1) +# t.add_index(i2) +# self.assertEqual(i1.table, t) +# self.assertEqual(i2.table, t) +# self.assertEqual(t.indexes, [i1, i2]) + +# def test_add_bad_index(self) -> None: +# t = Table('products') +# c = Column('id', 'integer') +# i = Index(['id', 'name']) +# t.add_column(c) +# with self.assertRaises(ColumnNotFoundError): +# t.add_index(i) + +# def test_dbml_simple(self): +# t = Table('products') +# c1 = Column('id', 'integer') +# c2 = Column('name', 'varchar2') +# t.add_column(c1) +# t.add_column(c2) +# expected = \ +# '''Table "products" { +# "id" integer +# "name" varchar2 +# }''' +# self.assertEqual(t.dbml, expected) + +# def test_dbml_full(self): +# t = Table( +# 'products', +# alias='pd', +# note='My multiline\nnote', +# comment='My multiline\ncomment' +# ) +# c0 = Column('zero', 'number') +# c1 = Column('id', 'integer', unique=True, note='Multiline\ncomment note') +# c2 = Column('name', 'varchar2') +# t.add_column(c0) +# t.add_column(c1) +# t.add_column(c2) +# i1 = Index(['zero', 'id'], unique=True) +# i2 = Index(['(capitalize(name))'], comment="index comment") +# t.add_index(i1) +# t.add_index(i2) +# expected = \ +# """// My multiline +# // comment +# Table "products" as "pd" { +# "zero" number +# "id" integer [unique, note: '''Multiline +# comment note'''] +# "name" varchar2 +# Note { +# ''' +# My multiline +# note +# ''' +# } + +# indexes { +# (zero, id) [unique] +# // index comment +# `capitalize(name)` +# } +# }""" +# self.assertEqual(t.dbml, expected) diff --git a/test/test_create_schema.py b/test/test_create_schema.py new file mode 100644 index 0000000..1c98ece --- /dev/null +++ b/test/test_create_schema.py @@ -0,0 +1,88 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pydbml.schema import Schema +from pydbml.classes import Table, Column, Index + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class TestCreateTable(TestCase): + def test_one_column(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + self.assertEqual(c.table, t) + schema = Schema() + schema.add(t) + self.assertEqual(t.schema, schema) + self.assertEqual(schema.tables[0], t) + self.assertEqual(schema.tables[0].name, 'test_table') + self.assertEqual(schema.tables[0].columns[0].name, 'test') + + def test_delete_column(self) -> None: + c1 = Column('col1', 'varchar', True) + c2 = Column('col2', 'number', False) + t = Table('test_table') + t.add_column(c1) + t.add_column(c2) + result = t.delete_column(1) + self.assertEqual(result, c2) + self.assertIsNone(result.table) + + def test_delete_table(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add(t) + self.assertEqual(t.schema, schema) + self.assertEqual(schema.tables[0], t) + schema.delete(t) + self.assertIsNone(t.schema) + self.assertEqual(schema.tables, []) + + +class TestCreateIndex(TestCase): + def test_simple_index(self): + c1 = Column('col1', 'varchar', True) + c2 = Column('col2', 'number', False) + t = Table('test_table') + t.add_column(c1) + t.add_column(c2) + i = Index([c1], 'IndexName', True) + self.assertIsNone(i.table) + t.add_index(i) + self.assertEqual(i.table, t) + + def test_complex_index(self): + c1 = Column('col1', 'varchar', True) + c2 = Column('col2', 'number', False) + t = Table('test_table') + t.add_column(c1) + t.add_column(c2) + i1 = Index([c1, c2], 'Compound', True) + self.assertIsNone(i1.table) + t.add_index(i1) + self.assertEqual(i1.table, t) + i2 = Index([c1, '(c2 * 3)'], 'Compound expression', True) + self.assertIsNone(i2.table) + t.add_index(i2) + self.assertEqual(i2.table, t) + + def test_delete_index(self): + c1 = Column('col1', 'varchar', True) + c2 = Column('col2', 'number', False) + t = Table('test_table') + t.add_column(c1) + t.add_column(c2) + i = Index([c1], 'IndexName', True) + self.assertIsNone(i.table) + t.add_index(i) + self.assertEqual(i.table, t) + t.delete_index(0) + self.assertIsNone(i.table) + self.assertEqual(t.indexes, []) diff --git a/test/test_definitions/__init__.py b/test/test_definitions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_column.py b/test/test_definitions/test_column.py similarity index 100% rename from test/test_column.py rename to test/test_definitions/test_column.py diff --git a/test/test_common.py b/test/test_definitions/test_common.py similarity index 100% rename from test/test_common.py rename to test/test_definitions/test_common.py diff --git a/test/test_enum.py b/test/test_definitions/test_enum.py similarity index 100% rename from test/test_enum.py rename to test/test_definitions/test_enum.py diff --git a/test/test_index.py b/test/test_definitions/test_index.py similarity index 91% rename from test/test_index.py rename to test/test_definitions/test_index.py index 2067261..0dc432b 100644 --- a/test/test_index.py +++ b/test/test_definitions/test_index.py @@ -71,7 +71,7 @@ def test_unique(self) -> None: def test_name_type_multiline(self) -> None: val = '[\nname: "index name"\n,\ntype:\nbtree\n]' res = index_settings.parseString(val, parseAll=True) - self.assertEqual(res[0]['type_'], 'btree') + self.assertEqual(res[0]['type'], 'btree') self.assertEqual(res[0]['name'], 'index name') def test_pk(self) -> None: @@ -90,7 +90,7 @@ def test_wrong_pk(self) -> None: def test_all(self) -> None: val = '[type: hash, name: "index name", note: "index note", unique]' res = index_settings.parseString(val, parseAll=True) - self.assertEqual(res[0]['type_'], 'hash') + self.assertEqual(res[0]['type'], 'hash') self.assertEqual(res[0]['name'], 'index name') self.assertEqual(res[0]['note'].text, 'index note') self.assertTrue(res[0]['unique']) @@ -160,51 +160,51 @@ class TestIndex(TestCase): def test_single(self) -> None: val = 'my_column' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column']) + self.assertEqual(res[0].subject_names, ['my_column']) def test_expression(self) -> None: val = '(`id*3`)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['(id*3)']) + self.assertEqual(res[0].subject_names, ['(id*3)']) def test_composite(self) -> None: val = '(my_column, my_another_column)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column', 'my_another_column']) + self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) def test_composite_with_expression(self) -> None: val = '(`id*3`, fieldname)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['(id*3)', 'fieldname']) + self.assertEqual(res[0].subject_names, ['(id*3)', 'fieldname']) def test_with_settings(self) -> None: val = '(my_column, my_another_column) [unique]' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column', 'my_another_column']) + self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) self.assertTrue(res[0].unique) def test_comment_above(self) -> None: val = '//comment above\nmy_column [unique]' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column']) + self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment above') def test_comment_after(self) -> None: val = 'my_column [unique] //comment after' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column']) + self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') val = 'my_column //comment after' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column']) + self.assertEqual(res[0].subject_names, ['my_column']) self.assertEqual(res[0].comment, 'comment after') def test_both_comments(self) -> None: val = '//comment before\nmy_column [unique] //comment after' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0]._subject_names, ['my_column']) + self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') diff --git a/test/test_project.py b/test/test_definitions/test_project.py similarity index 100% rename from test/test_project.py rename to test/test_definitions/test_project.py diff --git a/test/test_reference.py b/test/test_definitions/test_reference.py similarity index 100% rename from test/test_reference.py rename to test/test_definitions/test_reference.py diff --git a/test/test_table.py b/test/test_definitions/test_table.py similarity index 100% rename from test/test_table.py rename to test/test_definitions/test_table.py diff --git a/test/test_table_group.py b/test/test_definitions/test_table_group.py similarity index 100% rename from test/test_table_group.py rename to test/test_definitions/test_table_group.py diff --git a/test/test_parser.py b/test/test_parser.py index f6b76f7..6228c65 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -3,7 +3,9 @@ from pathlib import Path from unittest import TestCase -from pydbml import PyDBML +from pydbml.parser.blueprints import EnumItemBlueprint +from pydbml.parser.blueprints import EnumBlueprint +from pydbml import PyDBMLParser from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError @@ -12,85 +14,5 @@ class TestParser(TestCase): - def setUp(self): - self.results = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - - def test_table_refs(self) -> None: - p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - r = p['order_items'].refs - self.assertEqual(r[0].col[0].name, 'order_id') - self.assertEqual(r[0].ref_table.name, 'orders') - self.assertEqual(r[0].ref_col[0].name, 'id') - r = p['products'].refs - self.assertEqual(r[0].col[0].name, 'merchant_id') - self.assertEqual(r[0].ref_table.name, 'merchants') - self.assertEqual(r[0].ref_col[0].name, 'id') - r = p['users'].refs - self.assertEqual(r[0].col[0].name, 'country_code') - self.assertEqual(r[0].ref_table.name, 'countries') - self.assertEqual(r[0].ref_col[0].name, 'code') - - def test_refs(self) -> None: - p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - r = p.refs - self.assertEqual(r[0].table1.name, 'orders') - self.assertEqual(r[0].col1[0].name, 'id') - self.assertEqual(r[0].table2.name, 'order_items') - self.assertEqual(r[0].col2[0].name, 'order_id') - self.assertEqual(r[2].table1.name, 'users') - self.assertEqual(r[2].col1[0].name, 'country_code') - self.assertEqual(r[2].table2.name, 'countries') - self.assertEqual(r[2].col2[0].name, 'code') - self.assertEqual(r[4].table1.name, 'products') - self.assertEqual(r[4].col1[0].name, 'merchant_id') - self.assertEqual(r[4].table2.name, 'merchants') - self.assertEqual(r[4].col2[0].name, 'id') - - -class TestRefs(TestCase): - def test_reference_aliases(self): - results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') - posts, reviews, users = results['posts'], results['reviews'], results['users'] - posts2, reviews2, users2 = results['posts2'], results['reviews2'], results['users2'] - - rs = results.refs - self.assertEqual(rs[0].table1, users) - self.assertEqual(rs[0].table2, posts) - self.assertEqual(rs[1].table1, users) - self.assertEqual(rs[1].table2, reviews) - - self.assertEqual(rs[2].table1, posts2) - self.assertEqual(rs[2].table2, users2) - self.assertEqual(rs[3].table1, reviews2) - self.assertEqual(rs[3].table2, users2) - - def test_composite_references(self): - results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_composite.dbml') - self.assertEqual(len(results.tables), 4) - posts, reviews = results['posts'], results['reviews'] - posts2, reviews2 = results['posts2'], results['reviews2'] - - rs = results.refs - self.assertEqual(len(rs), 2) - - self.assertEqual(rs[0].table1, posts) - self.assertEqual(rs[0].col1, [posts['id'], posts['tag']]) - self.assertEqual(rs[0].table2, reviews) - self.assertEqual(rs[0].col2, [reviews['post_id'], reviews['tag']]) - - self.assertEqual(rs[1].table1, posts2) - self.assertEqual(rs[1].col1, [posts2['id'], posts2['tag']]) - self.assertEqual(rs[1].table2, reviews2) - self.assertEqual(rs[1].col2, [reviews2['post_id'], reviews2['tag']]) - - -class TestFaulty(TestCase): - def test_bad_reference(self) -> None: - with self.assertRaises(TableNotFoundError): - PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_table.dbml') - with self.assertRaises(ColumnNotFoundError): - PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_column.dbml') - - def test_bad_index(self) -> None: - with self.assertRaises(ColumnNotFoundError): - PyDBML(TEST_DATA_PATH / 'wrong_index.dbml') + def test_build_enums(self) -> None: + i1 = EnumItemBlueprint() \ No newline at end of file From 06a544c31c3e8fabdc86acd187f8f50bf8b4dd3e Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 10 May 2022 21:27:51 +0200 Subject: [PATCH 020/125] Tests for reference class, create constants module, fix bugs caught by tests --- pydbml/__init__.py | 8 +- pydbml/classes/column.py | 3 + pydbml/classes/reference.py | 9 +- pydbml/classes/table.py | 6 +- pydbml/constants.py | 3 + pydbml/exceptions.py | 1 + pydbml/parser/blueprints.py | 12 +- test/_test_classes.py | 380 ---------------------------- test/test_classes/test_column.py | 2 + test/test_classes/test_reference.py | 269 ++++++++++++++++++++ 10 files changed, 295 insertions(+), 398 deletions(-) create mode 100644 pydbml/constants.py create mode 100644 test/test_classes/test_reference.py diff --git a/pydbml/__init__.py b/pydbml/__init__.py index 2be8095..e7d44ce 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -2,10 +2,10 @@ import unittest from . import classes -from pydbml.parser import PyDBML -from pydbml.parser.blueprints import MANY_TO_ONE -from pydbml.parser.blueprints import ONE_TO_MANY -from pydbml.parser.blueprints import ONE_TO_ONE + +from pydbml.constants import MANY_TO_ONE +from pydbml.constants import ONE_TO_MANY +from pydbml.constants import ONE_TO_ONE def load_tests(loader, tests, ignore): diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index b6bd1d9..eb61825 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -45,6 +45,9 @@ def __init__(self, self.table: Optional['Table'] = None def get_refs(self) -> List['Reference']: + ''' + get all references related to this column (where this col is col1 in ref) + ''' if not self.table: raise TableNotFoundError('Table for the column is not set') return [ref for ref in self.table.get_refs() if ref.col1 == self] diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 3f08764..957e81f 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -6,9 +6,9 @@ from .base import SQLOjbect from .column import Column -from pydbml import MANY_TO_ONE -from pydbml import ONE_TO_MANY -from pydbml import ONE_TO_ONE +from pydbml.constants import MANY_TO_ONE +from pydbml.constants import ONE_TO_MANY +from pydbml.constants import ONE_TO_ONE from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.exceptions import DBMLError @@ -149,9 +149,10 @@ def sql(self): @property def dbml(self): if self.inline: + # settings are ignored for inline ref if len(self.col2) > 1: raise DBMLError('Cannot render DBML: composite ref cannot be inline') - return f'ref: {self.type} {self.table2.name}.{self.col2[0].name}' + return f'ref: {self.type} "{self.table2.name}"."{self.col2[0].name}"' else: result = comment_to_dbml(self.comment) if self.comment else '' result += 'Ref' diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index c4c72ec..e847f2c 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -8,9 +8,9 @@ from .index import Index from .note import Note from .reference import Reference -from pydbml import MANY_TO_ONE -from pydbml import ONE_TO_MANY -from pydbml import ONE_TO_ONE +from pydbml.constants import MANY_TO_ONE +from pydbml.constants import ONE_TO_MANY +from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import IndexNotFoundError from pydbml.exceptions import UnknownSchemaError diff --git a/pydbml/constants.py b/pydbml/constants.py new file mode 100644 index 0000000..e4fa877 --- /dev/null +++ b/pydbml/constants.py @@ -0,0 +1,3 @@ +ONE_TO_MANY = '<' +MANY_TO_ONE = '>' +ONE_TO_ONE = '-' diff --git a/pydbml/exceptions.py b/pydbml/exceptions.py index 413b309..c8e61b5 100644 --- a/pydbml/exceptions.py +++ b/pydbml/exceptions.py @@ -21,5 +21,6 @@ class DuplicateReferenceError(Exception): class UnknownSchemaError(Exception): pass + class DBMLError(Exception): pass diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 6d74d7d..a7837d6 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -6,24 +6,22 @@ from typing import Optional from typing import Union +from pydbml.constants import MANY_TO_ONE +from pydbml.constants import ONE_TO_MANY +from pydbml.constants import ONE_TO_ONE from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem from pydbml.classes import Index from pydbml.classes import Note from pydbml.classes import Project from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup -from pydbml.classes import EnumItem -from pydbml.classes import Enum from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError -ONE_TO_MANY = '<' -MANY_TO_ONE = '>' -ONE_TO_ONE = '-' - - class Blueprint: parser = None diff --git a/test/_test_classes.py b/test/_test_classes.py index 7c483ff..b28235b 100644 --- a/test/_test_classes.py +++ b/test/_test_classes.py @@ -18,34 +18,6 @@ from pydbml.exceptions import DuplicateReferenceError -class TestDBMLObject(TestCase): - def test_check_attributes_for_sql(self) -> None: - o = SQLOjbect() - o.a1 = None - o.b1 = None - o.c1 = None - o.required_attributes = ('a1', 'b1') - with self.assertRaises(AttributeMissingError): - o.check_attributes_for_sql() - o.a1 = 1 - with self.assertRaises(AttributeMissingError): - o.check_attributes_for_sql() - o.b1 = 'a2' - o.check_attributes_for_sql() - - def test_comparison(self) -> None: - o1 = SQLOjbect() - o1.a1 = None - o1.b1 = 'c' - o1.c1 = 123 - o2 = SQLOjbect() - o2.a1 = None - o2.b1 = 'c' - o2.c1 = 123 - self.assertTrue(o1 == o2) - o1.a2 = True - self.assertFalse(o1 == o2) - # class TestReferenceBlueprint(TestCase): # def test_basic_sql(self) -> None: @@ -103,238 +75,6 @@ def test_comparison(self) -> None: # self.assertEqual(r.sql, expected) -class TestColumn(TestCase): - def test_basic_sql(self) -> None: - r = Column(name='id', - type_='integer') - expected = '"id" integer' - self.assertEqual(r.sql, expected) - - def test_pk_autoinc(self) -> None: - r = Column(name='id', - type_='integer', - pk=True, - autoinc=True) - expected = '"id" integer PRIMARY KEY AUTOINCREMENT' - self.assertEqual(r.sql, expected) - - def test_unique_not_null(self) -> None: - r = Column(name='id', - type_='integer', - unique=True, - not_null=True) - expected = '"id" integer UNIQUE NOT NULL' - self.assertEqual(r.sql, expected) - - def test_default(self) -> None: - r = Column(name='order', - type_='integer', - default=0) - expected = '"order" integer DEFAULT 0' - self.assertEqual(r.sql, expected) - - def test_comment(self) -> None: - r = Column(name='id', - type_='integer', - unique=True, - not_null=True, - comment="Column comment") - expected = \ -'''-- Column comment -"id" integer UNIQUE NOT NULL''' - self.assertEqual(r.sql, expected) - - def test_table_setter(self) -> None: - r1 = ReferenceBlueprint( - ReferenceBlueprint.MANY_TO_ONE, - name='refname', - table1='bookings', - col1='order_id', - table2='orders', - col2='order', - ) - r2 = ReferenceBlueprint( - ReferenceBlueprint.MANY_TO_ONE, - name='refname', - table1='purchases', - col1='order_id', - table2='orders', - col2='order', - ) - c = Column( - name='order', - type_='integer', - default=0, - ref_blueprints=[r1, r2] - ) - t = Table('orders') - c.table = t - self.assertEqual(c.table, t) - self.assertEqual(c.ref_blueprints[0].table1, t.name) - self.assertEqual(c.ref_blueprints[1].table1, t.name) - - def test_dbml_simple(self): - c = Column( - name='order', - type_='integer' - ) - expected = '"order" integer' - - self.assertEqual(c.dbml, expected) - - def test_dbml_full(self): - c = Column( - name='order', - type_='integer', - unique=True, - not_null=True, - pk=True, - autoinc=True, - default='Def_value', - note='Note on the column', - comment='Comment on the column' - ) - expected = \ -'''// Comment on the column -"order" integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' - - self.assertEqual(c.dbml, expected) - - def test_dbml_multiline_note(self): - c = Column( - name='order', - type_='integer', - not_null=True, - note='Note on the column\nmultiline', - comment='Comment on the column' - ) - expected = \ -"""// Comment on the column -"order" integer [not null, note: '''Note on the column -multiline''']""" - - self.assertEqual(c.dbml, expected) - - def test_dbml_default(self): - c = Column( - name='order', - type_='integer', - default='String value' - ) - expected = "\"order\" integer [default: 'String value']" - self.assertEqual(c.dbml, expected) - - c.default = 3 - expected = '"order" integer [default: 3]' - self.assertEqual(c.dbml, expected) - - c.default = 3.33 - expected = '"order" integer [default: 3.33]' - self.assertEqual(c.dbml, expected) - - c.default = "(now() - interval '5 days')" - expected = "\"order\" integer [default: `now() - interval '5 days'`]" - self.assertEqual(c.dbml, expected) - - c.default = 'NULL' - expected = '"order" integer [default: null]' - self.assertEqual(c.dbml, expected) - - c.default = 'TRue' - expected = '"order" integer [default: true]' - self.assertEqual(c.dbml, expected) - - c.default = 'false' - expected = '"order" integer [default: false]' - self.assertEqual(c.dbml, expected) - -class TestIndex(TestCase): - def test_basic_sql(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subject_names=['id'], - table=t) - t.add_index(r) - expected = 'CREATE INDEX ON "products" ("id");' - self.assertEqual(r.sql, expected) - - def test_comment(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subject_names=['id'], - table=t, - comment='Index comment') - t.add_index(r) - expected = \ -'''-- Index comment -CREATE INDEX ON "products" ("id");''' - - self.assertEqual(r.sql, expected) - - def test_unique_type_composite(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - t.add_column(Column('name', 'varchar')) - r = Index(subject_names=['id', 'name'], - table=t, - type_='hash', - unique=True) - t.add_index(r) - expected = 'CREATE UNIQUE INDEX ON "products" USING HASH ("id", "name");' - self.assertEqual(r.sql, expected) - - def test_pk(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - t.add_column(Column('name', 'varchar')) - r = Index(subject_names=['id', 'name'], - table=t, - pk=True) - t.add_index(r) - expected = 'PRIMARY KEY ("id", "name")' - self.assertEqual(r.sql, expected) - - def test_composite_with_expression(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subject_names=['id', '(id*3)'], - table=t) - t.add_index(r) - self.assertEqual(r.subjects, [t['id'], '(id*3)']) - expected = 'CREATE INDEX ON "products" ("id", (id*3));' - self.assertEqual(r.sql, expected) - - def test_dbml_simple(self): - i = Index( - ['id'] - ) - - expected = 'id' - self.assertEqual(i.dbml, expected) - - def test_dbml_composite(self): - i = Index( - ['id', '(id*3)'] - ) - - expected = '(id, `id*3`)' - self.assertEqual(i.dbml, expected) - - def test_dbml_full(self): - i = Index( - ['id', '(getdate())'], - name='Dated id', - unique=True, - type_='hash', - pk=True, - note='Note on the column', - comment='Comment on the index' - ) - expected = \ -'''// Comment on the index -(id, `getdate()`) [name: 'Dated id', pk, unique, type: hash, note: 'Note on the column']''' - - class TestTable(TestCase): def test_one_column(self) -> None: t = Table('products') @@ -609,126 +349,6 @@ def test_sql_full(self): expected = 'CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL' self.assertEqual(ref.sql, expected) -class TestReference(TestCase): - def test_sql_single(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('>', t, c1, t2, c2) - - expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' - self.assertEqual(ref.sql, expected) - - def test_sql_reverse(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('<', t, c1, t2, c2) - - expected = 'ALTER TABLE "names" ADD FOREIGN KEY ("name_val") REFERENCES "products" ("name");' - self.assertEqual(ref.sql, expected) - - def test_sql_multiple(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('>', t, [c11, c12], t2, (c21, c22)) - - expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val");' - self.assertEqual(ref.sql, expected) - - def test_sql_full(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference( - '>', - t, - [c11, c12], - t2, - (c21, c22), - name="country_name", - comment="Multiline\ncomment for the constraint", - on_update="CASCADE", - on_delete="SET NULL" - ) - - expected = \ -'''-- Multiline --- comment for the constraint -ALTER TABLE "products" ADD CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL;''' - - self.assertEqual(ref.sql, expected) - - def test_dbml_simple(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - ref = Reference('>', t, c2, t2, c21) - - expected = \ -'''Ref { - "products"."name" > "names"."name_val" -}''' - self.assertEqual(ref.dbml, expected) - - def test_dbml_full(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - c3 = Column('country', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t.add_column(c3) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference( - '<', - t, - [c2, c3], - t2, - (c21, c22), - name='nameref', - comment='Reference comment\nmultiline', - on_update='CASCADE', - on_delete='SET NULL' - ) - - expected = \ -'''// Reference comment -// multiline -Ref nameref { - "products".("name", "country") < "names".("name_val", "country") [update: CASCADE, delete: SET NULL] -}''' - self.assertEqual(ref.dbml, expected) class TestNote(TestCase): diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index e71f273..2168f74 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -118,3 +118,5 @@ def test_dbml_default(self): c.default = 'false' expected = '"order" integer [default: false]' self.assertEqual(c.dbml, expected) + +# TODO: test ref inline diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py new file mode 100644 index 0000000..6a85a6c --- /dev/null +++ b/test/test_classes/test_reference.py @@ -0,0 +1,269 @@ +from unittest import TestCase +from pydbml.classes import Column +from pydbml.classes import Table +from pydbml.classes import Reference +from pydbml.exceptions import DBMLError + + +class TestReference(TestCase): + def test_sql_single(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', t, c1, t2, c2) + + expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' + self.assertEqual(ref.sql, expected) + + def test_sql_reverse(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('<', t, c1, t2, c2) + + expected = 'ALTER TABLE "names" ADD FOREIGN KEY ("name_val") REFERENCES "products" ("name");' + self.assertEqual(ref.sql, expected) + + def test_sql_multiple(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('>', t, [c11, c12], t2, (c21, c22)) + + expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val");' + self.assertEqual(ref.sql, expected) + + def test_sql_full(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference( + '>', + t, + [c11, c12], + t2, + (c21, c22), + name="country_name", + comment="Multiline\ncomment for the constraint", + on_update="CASCADE", + on_delete="SET NULL" + ) + + expected = \ +'''-- Multiline +-- comment for the constraint +ALTER TABLE "products" ADD CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL;''' + + self.assertEqual(ref.sql, expected) + + def test_dbml_simple(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference('>', t, c2, t2, c21) + + expected = \ +'''Ref { + "products"."name" > "names"."name_val" +}''' + self.assertEqual(ref.dbml, expected) + + def test_dbml_full(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + c3 = Column('country', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference( + '<', + t, + [c2, c3], + t2, + (c21, c22), + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL' + ) + + expected = \ +'''// Reference comment +// multiline +Ref nameref { + "products".("name", "country") < "names".("name_val", "country") [update: CASCADE, delete: SET NULL] +}''' + self.assertEqual(ref.dbml, expected) + + +class TestReferenceInline(TestCase): + def test_sql_single(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', t, c1, t2, c2, inline=True) + + expected = 'FOREIGN KEY ("name") REFERENCES "names" ("name_val")' + self.assertEqual(ref.sql, expected) + + def test_sql_reverse(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('<', t, c1, t2, c2, inline=True) + + expected = 'FOREIGN KEY ("name_val") REFERENCES "products" ("name")' + self.assertEqual(ref.sql, expected) + + def test_sql_multiple(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('>', t, [c11, c12], t2, (c21, c22), inline=True) + + expected = 'FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val")' + self.assertEqual(ref.sql, expected) + + def test_sql_full(self): + t = Table('products') + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference( + '>', + t, + [c11, c12], + t2, + (c21, c22), + name="country_name", + comment="Multiline\ncomment for the constraint", + on_update="CASCADE", + on_delete="SET NULL", + inline=True + ) + + expected = \ +'''-- Multiline +-- comment for the constraint +CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL''' + + self.assertEqual(ref.sql, expected) + + def test_dbml_simple(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference('>', t, c2, t2, c21, inline=True) + + expected = 'ref: > "names"."name_val"' + self.assertEqual(ref.dbml, expected) + + def test_dbml_settings_ignored(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference( + '<', + t, + c2, + t2, + c21, + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL', + inline=True + ) + + expected = 'ref: < "names"."name_val"' + self.assertEqual(ref.dbml, expected) + + def test_dbml_composite_inline_ref_forbidden(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + c3 = Column('country', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + c22 = Column('country', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference( + '<', + t, + [c2, c3], + t2, + (c21, c22), + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL', + inline=True + ) + + with self.assertRaises(DBMLError): + ref.dbml + From decc5ec2ba1ed07b788dab9306e99568b9d05b6a Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 11 May 2022 19:16:36 +0200 Subject: [PATCH 021/125] Finish tests for classes, fix bugs --- pydbml/classes/column.py | 11 +- pydbml/classes/reference.py | 75 ++-- pydbml/classes/table.py | 11 +- pydbml/parser/blueprints.py | 2 - pydbml/schema.py | 29 +- test/_test_classes.py | 468 ------------------------- test/test_blueprints/test_reference.py | 2 - test/test_classes/test_column.py | 91 ++++- test/test_classes/test_note.py | 50 +++ test/test_classes/test_project.py | 41 +++ test/test_classes/test_reference.py | 69 +++- test/test_classes/test_table.py | 427 ++++++++++++---------- test/test_classes/test_table_group.py | 35 ++ test/test_create_schema.py | 4 +- 14 files changed, 581 insertions(+), 734 deletions(-) delete mode 100644 test/_test_classes.py create mode 100644 test/test_classes/test_note.py create mode 100644 test/test_classes/test_project.py create mode 100644 test/test_classes/test_table_group.py diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index eb61825..13e8dab 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -31,7 +31,6 @@ def __init__(self, note: Optional[Union['Note', str]] = None, # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, comment: Optional[str] = None): - self.schema = None self.name = name self.type = type_ self.unique = unique @@ -46,11 +45,15 @@ def __init__(self, def get_refs(self) -> List['Reference']: ''' - get all references related to this column (where this col is col1 in ref) + get all references related to this column (where this col is col1 in) ''' if not self.table: raise TableNotFoundError('Table for the column is not set') - return [ref for ref in self.table.get_refs() if ref.col1 == self] + return [ref for ref in self.table.get_refs() if self in ref.col1] + + @property + def schema(self): + return self.table.schema if self.table else None @property def sql(self): @@ -93,7 +96,7 @@ def default_to_str(val: str) -> str: result = comment_to_dbml(self.comment) if self.comment else '' result += f'"{self.name}" {self.type}' - options = [ref.dbml() for ref in self.get_refs() if ref.inline] + options = [ref.dbml for ref in self.get_refs() if ref.inline] if self.pk: options.append('pk') if self.autoinc: diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 957e81f..896299e 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -12,10 +12,7 @@ from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.exceptions import DBMLError - - -if TYPE_CHECKING: - from .table import Table +from pydbml.exceptions import TableNotFoundError class Reference(SQLOjbect): @@ -24,13 +21,13 @@ class Reference(SQLOjbect): It is a separate object, which is not connected to Table or Column objects and its `sql` property contains the ALTER TABLE clause. ''' - required_attributes = ('type', 'table1', 'col1', 'table2', 'col2') + required_attributes = ('type', 'col1', 'col2') def __init__(self, type_: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE], - table1: 'Table', + # table1: 'Table', col1: Union[Column, Collection[Column]], - table2: 'Table', + # table2: 'Table', col2: Union[Column, Collection[Column]], name: Optional[str] = None, comment: Optional[str] = None, @@ -39,9 +36,9 @@ def __init__(self, inline: bool = False): self.schema = None self.type = type_ - self.table1 = table1 + # self.table1 = table1 self.col1 = [col1] if isinstance(col1, Column) else list(col1) - self.table2 = table2 + # self.table2 = table2 self.col2 = [col2] if isinstance(col2, Column) else list(col2) self.name = name if name else None self.comment = comment @@ -53,20 +50,17 @@ def __repr__(self): ''' >>> c1 = Column('c1', 'int') >>> c2 = Column('c2', 'int') - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> Reference('>', table1=t1, col1=c1, table2=t2, col2=c2) - ', 't1'.['c1'], 't2'.['c2']> + >>> Reference('>', col1=c1, col2=c2) + ', ['c1'], ['c2']> >>> c12 = Column('c12', 'int') >>> c22 = Column('c22', 'int') - >>> Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22)) - + >>> Reference('<', col1=[c1, c12], col2=(c2, c22)) + ''' - components = [f"' + col1 = ', '.join(f'{c.name!r}' for c in self.col1) + col2 = ', '.join(f'{c.name!r}' for c in self.col2) + return f"" def __str__(self): ''' @@ -82,13 +76,22 @@ def __str__(self): Reference(t1[c1, c12] < t2[c2, c22]) ''' - components = [f"Reference("] - components.append(self.table1.name) - components.append(f'[{", ".join(c.name for c in self.col1)}]') - components.append(f' {self.type} ') - components.append(self.table2.name) - components.append(f'[{", ".join(c.name for c in self.col2)}]') - return ''.join(components) + ')' + col1 = ', '.join(f'{c.name!r}' for c in self.col1) + col2 = ', '.join(f'{c.name!r}' for c in self.col2) + return f"Reference([{col1}] {self.type} [{col2}]" + + def _validate(self): + table1 = self.col1[0].table + if any(c.table != table1 for c in self.col1): + raise DBMLError('Columns in col1 are from different tables') + if table1 is None: + raise TableNotFoundError('Table on col1 is not set') + + table2 = self.col2[0].table + if any(c.table != table2 for c in self.col2): + raise DBMLError('Columns in col2 are from different tables') + if table2 is None: + raise TableNotFoundError('Table on col2 is not set') @property def sql(self): @@ -99,16 +102,17 @@ def sql(self): ''' self.check_attributes_for_sql() + self._validate() c = f'CONSTRAINT "{self.name}" ' if self.name else '' if self.inline: if self.type in (MANY_TO_ONE, ONE_TO_ONE): source_col = self.col1 - ref_table = self.table2 + ref_table = self.col2[0].table ref_col = self.col2 else: source_col = self.col2 - ref_table = self.table1 + ref_table = self.col1[0].table ref_col = self.col1 cols = '", "'.join(c.name for c in source_col) @@ -125,14 +129,14 @@ def sql(self): return result else: if self.type in (MANY_TO_ONE, ONE_TO_ONE): - t1 = self.table1 + t1 = self.col1[0].table c1 = ', '.join(f'"{c.name}"' for c in self.col1) - t2 = self.table2 + t2 = self.col2[0].table c2 = ', '.join(f'"{c.name}"' for c in self.col2) else: - t1 = self.table2 + t1 = self.col2[0].table c1 = ', '.join(f'"{c.name}"' for c in self.col2) - t2 = self.table1 + t2 = self.col1[0].table c2 = ', '.join(f'"{c.name}"' for c in self.col1) result = comment_to_sql(self.comment) if self.comment else '' @@ -148,11 +152,12 @@ def sql(self): @property def dbml(self): + self._validate() if self.inline: # settings are ignored for inline ref if len(self.col2) > 1: raise DBMLError('Cannot render DBML: composite ref cannot be inline') - return f'ref: {self.type} "{self.table2.name}"."{self.col2[0].name}"' + return f'ref: {self.type} "{self.col2[0].table.name}"."{self.col2[0].name}"' else: result = comment_to_dbml(self.comment) if self.comment else '' result += 'Ref' @@ -180,9 +185,9 @@ def dbml(self): options_str = f' [{", ".join(options)}]' if options else '' result += ( ' {\n ' - f'"{self.table1.name}".{col1} ' + f'"{self.col1[0].table.name}".{col1} ' f'{self.type} ' - f'"{self.table2.name}".{col2}' + f'"{self.col2[0].table.name}".{col2}' f'{options_str}' '\n}' ) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index e847f2c..95b4feb 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -94,9 +94,6 @@ def add_index(self, i: Index) -> None: Adds index to self.indexes attribute and sets in this index the `table` attribute. ''' - # for subj in i.subjects: - # if isinstance(subj, Column) and (subj not in self.columns): - # raise ColumnNotFoundError(f'Cannot add index, column "{subj}" not defined in table "{self.name}".') i.table = self self.indexes.append(i) @@ -115,21 +112,21 @@ def delete_index(self, i: Union[Index, int]) -> Index: def get_refs(self) -> List[Reference]: if not self.schema: raise UnknownSchemaError('Schema for the table is not set') - return [ref for ref in self.schema.refs if ref.table1 == self] + return [ref for ref in self.schema.refs if ref.col1[0].table == self] def _get_references_for_sql(self) -> List[Reference]: ''' return inline references for this table sql definition ''' if not self.schema: - raise UnknownSchemaError('Schema for the table is not set') + raise UnknownSchemaError(f'Schema for the table {self} is not set') result = [] for ref in self.schema.refs: if ref.inline: if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ - (ref.table1 == self): + (ref.col1[0].table == self): result.append(ref) - elif (ref.type == ONE_TO_MANY) and (ref.table2 == self): + elif (ref.type == ONE_TO_MANY) and (ref.col2[0].table == self): result.append(ref) return result diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index a7837d6..690b322 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -71,9 +71,7 @@ def build(self) -> 'Reference': return Reference( type_=self.type, inline=self.inline, - table1=table1, col1=col1, - table2=table2, col2=col2, name=self.name, comment=self.comment, diff --git a/pydbml/schema.py b/pydbml/schema.py index f2c6e3b..dc0e797 100644 --- a/pydbml/schema.py +++ b/pydbml/schema.py @@ -1,15 +1,14 @@ +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from itertools import chain + from .classes import Enum from .classes import Project from .classes import Reference from .classes import Table from .classes import TableGroup -from pydbml.parser.blueprints import ReferenceBlueprint -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import TableNotFoundError -from typing import Any -from typing import Dict -from typing import List -from typing import Optional class SchemaValidationError(Exception): @@ -18,13 +17,13 @@ class SchemaValidationError(Exception): class Schema: def __init__(self) -> None: - self.tables: List[Table] = [] - self.tables_dict: Dict[str, Table] = {} - self.refs: List[Reference] = [] + self.tables: List['Table'] = [] + self.tables_dict: Dict[str, 'Table'] = {} + self.refs: List['Reference'] = [] # self.ref_blueprints: List[ReferenceBlueprint] = [] - self.enums: List[Enum] = [] - self.table_groups: List[TableGroup] = [] - self.project: Optional[Project] = None + self.enums: List['Enum'] = [] + self.table_groups: List['TableGroup'] = [] + self.project: Optional['Project'] = None def __repr__(self) -> str: return f"" @@ -111,8 +110,8 @@ def add_table(self, obj: Table) -> Table: return obj def add_reference(self, obj: Reference): - for table in (obj.table1, obj.table2): - if table in self.tables: + for col in chain(obj.col1, obj.col2): + if col.schema == self: break else: raise SchemaValidationError( diff --git a/test/_test_classes.py b/test/_test_classes.py deleted file mode 100644 index b28235b..0000000 --- a/test/_test_classes.py +++ /dev/null @@ -1,468 +0,0 @@ -from unittest import TestCase - -from pydbml.classes import Column -from pydbml.classes import Enum -from pydbml.classes import EnumItem -from pydbml.classes import Index -from pydbml.classes import Note -from pydbml.classes import Project -from pydbml.classes import Reference -from pydbml.classes import TableReference -from pydbml.classes import ReferenceBlueprint -from pydbml.classes import SQLOjbect -from pydbml.classes import Table -from pydbml.classes import TableGroup -from pydbml.classes import TableReference -from pydbml.exceptions import AttributeMissingError -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import DuplicateReferenceError - - - -# class TestReferenceBlueprint(TestCase): -# def test_basic_sql(self) -> None: -# r = ReferenceBlueprint( -# ReferenceBlueprint.MANY_TO_ONE, -# table1='bookings', -# col1='country', -# table2='ids', -# col2='id' -# ) -# expected = 'ALTER TABLE "bookings" ADD FOREIGN KEY ("country") REFERENCES "ids ("id");' -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_ONE -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_MANY -# expected2 = 'ALTER TABLE "ids" ADD FOREIGN KEY ("id") REFERENCES "bookings ("country");' -# self.assertEqual(r.sql, expected2) - -# def test_full(self) -> None: -# r = ReferenceBlueprint( -# ReferenceBlueprint.MANY_TO_ONE, -# name='refname', -# table1='bookings', -# col1='country', -# table2='ids', -# col2='id', -# on_update='cascade', -# on_delete='restrict' -# ) -# expected = 'ALTER TABLE "bookings" ADD CONSTRAINT "refname" FOREIGN KEY ("country") REFERENCES "ids ("id") ON UPDATE CASCADE ON DELETE RESTRICT;' -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_ONE -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_MANY -# expected2 = 'ALTER TABLE "ids" ADD CONSTRAINT "refname" FOREIGN KEY ("id") REFERENCES "bookings ("country") ON UPDATE CASCADE ON DELETE RESTRICT;' -# self.assertEqual(r.sql, expected2) - - -# class TestTableReference(TestCase): -# def test_basic_sql(self) -> None: -# r = TableReference(col='order_id', -# ref_table='orders', -# ref_col='id') -# expected = 'FOREIGN KEY ("order_id") REFERENCES "orders ("id")' -# self.assertEqual(r.sql, expected) - -# def test_full(self) -> None: -# r = TableReference(col='order_id', -# ref_table='orders', -# ref_col='id', -# name='refname', -# on_delete='set null', -# on_update='no action') -# expected = 'CONSTRAINT "refname" FOREIGN KEY ("order_id") REFERENCES "orders ("id") ON UPDATE NO ACTION ON DELETE SET NULL' -# self.assertEqual(r.sql, expected) - - -class TestTable(TestCase): - def test_one_column(self) -> None: - t = Table('products') - c = Column('id', 'integer') - t.add_column(c) - expected = 'CREATE TABLE "products" (\n "id" integer\n);' - self.assertEqual(t.sql, expected) - - def test_ref(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - r = TableReference(c2, t2, c21) - t.add_ref(r) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names" ("name_val") -);''' - self.assertEqual(t.sql, expected) - - def test_duplicate_ref(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - r1 = TableReference(c2, t2, c21) - t.add_ref(r1) - r2 = TableReference(c2, t2, c21) - self.assertEqual(r1, r2) - with self.assertRaises(DuplicateReferenceError): - t.add_ref(r2) - - def test_notes(self) -> None: - n = Note('Table note') - nc1 = Note('First column note') - nc2 = Note('Another column\nmultiline note') - t = Table('products', note=n) - c1 = Column('id', 'integer', note=nc1) - c2 = Column('name', 'varchar') - c3 = Column('country', 'varchar', note=nc2) - t.add_column(c1) - t.add_column(c2) - t.add_column(c3) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar, - "country" varchar -); - -COMMENT ON TABLE "products" IS 'Table note'; - -COMMENT ON COLUMN "products"."id" IS 'First column note'; - -COMMENT ON COLUMN "products"."country" IS 'Another column -multiline note';''' - self.assertEqual(t.sql, expected) - - def test_ref_index(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - r = TableReference(c2, t2, c21) - t.add_ref(r) - i = Index(['id', 'name']) - t.add_index(i) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names" ("name_val") -); - -CREATE INDEX ON "products" ("id", "name");''' - self.assertEqual(t.sql, expected) - - def test_index_inline(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - i = Index(['id', 'name'], pk=True) - t.add_index(i) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - PRIMARY KEY ("id", "name") -);''' - self.assertEqual(t.sql, expected) - - def test_index_inline_and_comments(self) -> None: - t = Table('products', comment='Multiline\ntable comment') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - i = Index(['id', 'name'], pk=True, comment='Multiline\nindex comment') - t.add_index(i) - expected = \ -'''-- Multiline --- table comment -CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - -- Multiline - -- index comment - PRIMARY KEY ("id", "name") -);''' - self.assertEqual(t.sql, expected) - - def test_add_column(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - self.assertEqual(c1.table, t) - self.assertEqual(c2.table, t) - self.assertEqual(t.columns, [c1, c2]) - - def test_add_index(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - i1 = Index(['id']) - i2 = Index(['name']) - t.add_column(c1) - t.add_column(c2) - t.add_index(i1) - t.add_index(i2) - self.assertEqual(i1.table, t) - self.assertEqual(i2.table, t) - self.assertEqual(t.indexes, [i1, i2]) - - def test_add_bad_index(self) -> None: - t = Table('products') - c = Column('id', 'integer') - i = Index(['id', 'name']) - t.add_column(c) - with self.assertRaises(ColumnNotFoundError): - t.add_index(i) - - def test_dbml_simple(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - expected = \ -'''Table "products" { - "id" integer - "name" varchar2 -}''' - self.assertEqual(t.dbml, expected) - - def test_dbml_full(self): - t = Table( - 'products', - alias='pd', - note='My multiline\nnote', - comment='My multiline\ncomment' - ) - c0 = Column('zero', 'number') - c1 = Column('id', 'integer', unique=True, note='Multiline\ncomment note') - c2 = Column('name', 'varchar2') - t.add_column(c0) - t.add_column(c1) - t.add_column(c2) - i1 = Index(['zero', 'id'], unique=True) - i2 = Index(['(capitalize(name))'], comment="index comment") - t.add_index(i1) - t.add_index(i2) - expected = \ -"""// My multiline -// comment -Table "products" as "pd" { - "zero" number - "id" integer [unique, note: '''Multiline - comment note'''] - "name" varchar2 - Note { - ''' - My multiline - note - ''' - } - - indexes { - (zero, id) [unique] - // index comment - `capitalize(name)` - } -}""" - self.assertEqual(t.dbml, expected) - - - - - -class TestTableReference(TestCase): - def test_sql_single(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = TableReference( - c1, t2, c2) - - expected = 'FOREIGN KEY ("name") REFERENCES "names" ("name_val")' - self.assertEqual(ref.sql, expected) - - def test_sql_multiple(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = TableReference( - [c11, c12], - t2, - (c21, c22) - ) - - expected = 'FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val")' - self.assertEqual(ref.sql, expected) - - def test_sql_full(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = TableReference( - [c11, c12], - t2, - (c21, c22), - name="country_name", - on_delete='SET NULL', - on_update='CASCADE' - ) - - expected = 'CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL' - self.assertEqual(ref.sql, expected) - - - -class TestNote(TestCase): - def test_init_types(self): - n1 = Note('My note text') - n2 = Note(3) - n3 = Note([1, 2, 3]) - n4 = Note(None) - n5 = Note(n1) - - self.assertEqual(n1.text, 'My note text') - self.assertEqual(n2.text, '3') - self.assertEqual(n3.text, '[1, 2, 3]') - self.assertEqual(n4.text, '') - self.assertEqual(n5.text, 'My note text') - - def test_oneline(self): - note = Note('One line of note text') - expected = \ -'''Note { - 'One line of note text' -}''' - self.assertEqual(note.dbml, expected) - - def test_multiline(self): - note = Note('The number of spaces you use to indent a block string will be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') - expected = \ -"""Note { - ''' - The number of spaces you use to indent a block string will be the minimum number - of leading spaces among all lines. The parser will automatically remove the number - of indentation spaces in the final output. - ''' -}""" - self.assertEqual(note.dbml, expected) - - - def test_forced_multiline(self): - note = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') - expected = \ -"""Note { - ''' - The number of spaces you use to indent a block string - will - be the minimum number of leading spaces among all lines. The parser will automatically - remove the number of indentation spaces in the final output. - ''' -}""" - self.assertEqual(note.dbml, expected) - - -class TestTableGroup(TestCase): - def test_dbml(self): - tg = TableGroup('mytg', ['merchants', 'countries', 'customers']) - expected = \ -'''TableGroup mytg { - merchants - countries - customers -}''' - self.assertEqual(tg.dbml, expected) - - def test_dbml_with_comment_and_real_tables(self): - merchants = Table('merchants') - countries = Table('countries') - customers = Table('customers') - tg = TableGroup( - 'mytg', - [merchants, countries, customers], - comment='My table group\nmultiline comment' - ) - expected = \ -'''// My table group -// multiline comment -TableGroup mytg { - merchants - countries - customers -}''' - self.assertEqual(tg.dbml, expected) - -class TestProject(TestCase): - def test_dbml_note(self): - p = Project('myproject', note='Project note') - expected = \ -'''Project myproject { - Note { - 'Project note' - } -}''' - self.assertEqual(p.dbml, expected) - - def test_dbml_full(self): - p = Project( - 'myproject', - items={ - 'database_type': 'PostgreSQL', - 'story': "One day I was eating my cantaloupe and\nI thought, why shouldn't I?\nWhy shouldn't I create a database?" - }, - comment='Multiline\nProject comment', - note='Multiline\nProject note') - expected = \ -"""// Multiline -// Project comment -Project myproject { - database_type: 'PostgreSQL' - story: '''One day I was eating my cantaloupe and - I thought, why shouldn't I? - Why shouldn't I create a database?''' - Note { - ''' - Multiline - Project note - ''' - } -}""" - self.assertEqual(p.dbml, expected) diff --git a/test/test_blueprints/test_reference.py b/test/test_blueprints/test_reference.py index 9678875..809c092 100644 --- a/test/test_blueprints/test_reference.py +++ b/test/test_blueprints/test_reference.py @@ -39,9 +39,7 @@ def test_build_minimal(self) -> None: self.assertEqual(result.type, bp.type) self.assertEqual(result.inline, bp.inline) self.assertEqual(parserMock.locate_table.call_count, 2) - self.assertEqual(result.table1, t1) self.assertEqual(result.col1, [c1]) - self.assertEqual(result.table2, t2) self.assertEqual(result.col2, [c2]) def test_tables_and_cols_are_set(self) -> None: diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index 2168f74..1889226 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -1,6 +1,9 @@ from unittest import TestCase +from pydbml.schema import Schema from pydbml.classes import Column +from pydbml.classes import Table +from pydbml.classes import Reference class TestColumn(TestCase): @@ -49,6 +52,10 @@ def test_dbml_simple(self): name='order', type_='integer' ) + t = Table(name='Test') + t.add_column(c) + s = Schema() + s.add(t) expected = '"order" integer' self.assertEqual(c.dbml, expected) @@ -65,6 +72,10 @@ def test_dbml_full(self): note='Note on the column', comment='Comment on the column' ) + t = Table(name='Test') + t.add_column(c) + s = Schema() + s.add(t) expected = \ '''// Comment on the column "order" integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' @@ -79,6 +90,10 @@ def test_dbml_multiline_note(self): note='Note on the column\nmultiline', comment='Comment on the column' ) + t = Table(name='Test') + t.add_column(c) + s = Schema() + s.add(t) expected = \ """// Comment on the column "order" integer [not null, note: '''Note on the column @@ -92,6 +107,11 @@ def test_dbml_default(self): type_='integer', default='String value' ) + t = Table(name='Test') + t.add_column(c) + s = Schema() + s.add(t) + expected = "\"order\" integer [default: 'String value']" self.assertEqual(c.dbml, expected) @@ -119,4 +139,73 @@ def test_dbml_default(self): expected = '"order" integer [default: false]' self.assertEqual(c.dbml, expected) -# TODO: test ref inline + def test_schema(self): + c1 = Column(name='client_id', type_='integer') + t1 = Table(name='products') + + self.assertIsNone(c1.schema) + t1.add_column(c1) + self.assertIsNone(c1.schema) + s = Schema() + s.add(t1) + self.assertIs(c1.schema, s) + + def test_get_refs(self) -> None: + c1 = Column(name='client_id', type_='integer') + t1 = Table(name='products') + t1.add_column(c1) + c2 = Column(name='id', type_='integer', autoinc=True, pk=True) + t2 = Table(name='clients') + t2.add_column(c2) + + ref = Reference(type_='>', col1=c1, col2=c2, inline=True) + s = Schema() + s.add(t1) + s.add(t2) + s.add(ref) + + self.assertEqual(c1.get_refs(), [ref]) + + def test_dbml_with_ref(self) -> None: + c1 = Column(name='client_id', type_='integer') + t1 = Table(name='products') + t1.add_column(c1) + c2 = Column(name='id', type_='integer', autoinc=True, pk=True) + t2 = Table(name='clients') + t2.add_column(c2) + + ref = Reference(type_='>', col1=c1, col2=c2) + s = Schema() + s.add(t1) + s.add(t2) + s.add(ref) + + expected = '"client_id" integer' + self.assertEqual(c1.dbml, expected) + ref.inline = True + expected = '"client_id" integer [ref: > "clients"."id"]' + self.assertEqual(c1.dbml, expected) + expected = '"id" integer [pk, increment]' + self.assertEqual(c2.dbml, expected) + + def test_dbml_with_ref_and_properties(self) -> None: + c1 = Column(name='client_id', type_='integer') + t1 = Table(name='products') + t1.add_column(c1) + c2 = Column(name='id', type_='integer', autoinc=True, pk=True) + t2 = Table(name='clients') + t2.add_column(c2) + + ref = Reference(type_='<', col1=c2, col2=c1) + s = Schema() + s.add(t1) + s.add(t2) + s.add(ref) + + expected = '"id" integer [pk, increment]' + self.assertEqual(c2.dbml, expected) + ref.inline = True + expected = '"id" integer [ref: < "products"."client_id", pk, increment]' + self.assertEqual(c2.dbml, expected) + expected = '"client_id" integer' + self.assertEqual(c1.dbml, expected) diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py new file mode 100644 index 0000000..00fa5c4 --- /dev/null +++ b/test/test_classes/test_note.py @@ -0,0 +1,50 @@ +from pydbml.classes import Note +from unittest import TestCase + + +class TestNote(TestCase): + def test_init_types(self): + n1 = Note('My note text') + n2 = Note(3) + n3 = Note([1, 2, 3]) + n4 = Note(None) + n5 = Note(n1) + + self.assertEqual(n1.text, 'My note text') + self.assertEqual(n2.text, '3') + self.assertEqual(n3.text, '[1, 2, 3]') + self.assertEqual(n4.text, '') + self.assertEqual(n5.text, 'My note text') + + def test_oneline(self): + note = Note('One line of note text') + expected = \ +'''Note { + 'One line of note text' +}''' + self.assertEqual(note.dbml, expected) + + def test_multiline(self): + note = Note('The number of spaces you use to indent a block string will be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + expected = \ +"""Note { + ''' + The number of spaces you use to indent a block string will be the minimum number + of leading spaces among all lines. The parser will automatically remove the number + of indentation spaces in the final output. + ''' +}""" + self.assertEqual(note.dbml, expected) + + def test_forced_multiline(self): + note = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + expected = \ +"""Note { + ''' + The number of spaces you use to indent a block string + will + be the minimum number of leading spaces among all lines. The parser will automatically + remove the number of indentation spaces in the final output. + ''' +}""" + self.assertEqual(note.dbml, expected) diff --git a/test/test_classes/test_project.py b/test/test_classes/test_project.py new file mode 100644 index 0000000..cb68829 --- /dev/null +++ b/test/test_classes/test_project.py @@ -0,0 +1,41 @@ +from pydbml.classes import Project + +from unittest import TestCase + + +class TestProject(TestCase): + def test_dbml_note(self): + p = Project('myproject', note='Project note') + expected = \ +'''Project myproject { + Note { + 'Project note' + } +}''' + self.assertEqual(p.dbml, expected) + + def test_dbml_full(self): + p = Project( + 'myproject', + items={ + 'database_type': 'PostgreSQL', + 'story': "One day I was eating my cantaloupe and\nI thought, why shouldn't I?\nWhy shouldn't I create a database?" + }, + comment='Multiline\nProject comment', + note='Multiline\nProject note') + expected = \ +"""// Multiline +// Project comment +Project myproject { + database_type: 'PostgreSQL' + story: '''One day I was eating my cantaloupe and + I thought, why shouldn't I? + Why shouldn't I create a database?''' + Note { + ''' + Multiline + Project note + ''' + } +}""" + self.assertEqual(p.dbml, expected) diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 6a85a6c..c2cd798 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -3,6 +3,7 @@ from pydbml.classes import Table from pydbml.classes import Reference from pydbml.exceptions import DBMLError +from pydbml.exceptions import TableNotFoundError class TestReference(TestCase): @@ -13,7 +14,7 @@ def test_sql_single(self): t2 = Table('names') c2 = Column('name_val', 'varchar2') t2.add_column(c2) - ref = Reference('>', t, c1, t2, c2) + ref = Reference('>', c1, c2) expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' self.assertEqual(ref.sql, expected) @@ -25,7 +26,7 @@ def test_sql_reverse(self): t2 = Table('names') c2 = Column('name_val', 'varchar2') t2.add_column(c2) - ref = Reference('<', t, c1, t2, c2) + ref = Reference('<', c1, c2) expected = 'ALTER TABLE "names" ADD FOREIGN KEY ("name_val") REFERENCES "products" ("name");' self.assertEqual(ref.sql, expected) @@ -41,7 +42,7 @@ def test_sql_multiple(self): c22 = Column('country_val', 'varchar2') t2.add_column(c21) t2.add_column(c22) - ref = Reference('>', t, [c11, c12], t2, (c21, c22)) + ref = Reference('>', [c11, c12], (c21, c22)) expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val");' self.assertEqual(ref.sql, expected) @@ -59,9 +60,7 @@ def test_sql_full(self): t2.add_column(c22) ref = Reference( '>', - t, [c11, c12], - t2, (c21, c22), name="country_name", comment="Multiline\ncomment for the constraint", @@ -85,7 +84,7 @@ def test_dbml_simple(self): t2 = Table('names') c21 = Column('name_val', 'varchar2') t2.add_column(c21) - ref = Reference('>', t, c2, t2, c21) + ref = Reference('>', c2, c21) expected = \ '''Ref { @@ -108,9 +107,7 @@ def test_dbml_full(self): t2.add_column(c22) ref = Reference( '<', - t, [c2, c3], - t2, (c21, c22), name='nameref', comment='Reference comment\nmultiline', @@ -135,7 +132,7 @@ def test_sql_single(self): t2 = Table('names') c2 = Column('name_val', 'varchar2') t2.add_column(c2) - ref = Reference('>', t, c1, t2, c2, inline=True) + ref = Reference('>', c1, c2, inline=True) expected = 'FOREIGN KEY ("name") REFERENCES "names" ("name_val")' self.assertEqual(ref.sql, expected) @@ -147,7 +144,7 @@ def test_sql_reverse(self): t2 = Table('names') c2 = Column('name_val', 'varchar2') t2.add_column(c2) - ref = Reference('<', t, c1, t2, c2, inline=True) + ref = Reference('<', c1, c2, inline=True) expected = 'FOREIGN KEY ("name_val") REFERENCES "products" ("name")' self.assertEqual(ref.sql, expected) @@ -163,7 +160,7 @@ def test_sql_multiple(self): c22 = Column('country_val', 'varchar2') t2.add_column(c21) t2.add_column(c22) - ref = Reference('>', t, [c11, c12], t2, (c21, c22), inline=True) + ref = Reference('>', [c11, c12], (c21, c22), inline=True) expected = 'FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val")' self.assertEqual(ref.sql, expected) @@ -181,9 +178,7 @@ def test_sql_full(self): t2.add_column(c22) ref = Reference( '>', - t, [c11, c12], - t2, (c21, c22), name="country_name", comment="Multiline\ncomment for the constraint", @@ -208,7 +203,7 @@ def test_dbml_simple(self): t2 = Table('names') c21 = Column('name_val', 'varchar2') t2.add_column(c21) - ref = Reference('>', t, c2, t2, c21, inline=True) + ref = Reference('>', c2, c21, inline=True) expected = 'ref: > "names"."name_val"' self.assertEqual(ref.dbml, expected) @@ -224,9 +219,7 @@ def test_dbml_settings_ignored(self): t2.add_column(c21) ref = Reference( '<', - t, c2, - t2, c21, name='nameref', comment='Reference comment\nmultiline', @@ -253,9 +246,7 @@ def test_dbml_composite_inline_ref_forbidden(self): t2.add_column(c22) ref = Reference( '<', - t, [c2, c3], - t2, (c21, c22), name='nameref', comment='Reference comment\nmultiline', @@ -267,3 +258,45 @@ def test_dbml_composite_inline_ref_forbidden(self): with self.assertRaises(DBMLError): ref.dbml + def test_validate_different_tables(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference( + '<', + [c2, c21], + [c1], + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL', + inline=True + ) + with self.assertRaises(DBMLError): + ref._validate() + + def test_validate_no_table(self): + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + c3 = Column('age', 'number') + c4 = Column('active', 'boolean') + ref1 = Reference( + '<', + c1, + c2 + ) + with self.assertRaises(TableNotFoundError): + ref1._validate() + + ref2 = Reference( + '<', + [c1, c2], + [c3, c4] + ) + with self.assertRaises(TableNotFoundError): + ref2._validate() diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index d4f30ab..1474751 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -1,9 +1,12 @@ from unittest import TestCase +from pydbml.classes import Column from pydbml.classes import Index +from pydbml.classes import Note from pydbml.classes import Reference from pydbml.classes import Table -from pydbml.classes import Column +from pydbml.exceptions import ColumnNotFoundError +from pydbml.schema import Schema class TestTable(TestCase): @@ -11,27 +14,39 @@ def test_one_column(self) -> None: t = Table('products') c = Column('id', 'integer') t.add_column(c) + s = Schema() + s.add(t) expected = 'CREATE TABLE "products" (\n "id" integer\n);' self.assertEqual(t.sql, expected) -# def test_ref(self) -> None: -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# t2 = Table('names') -# c21 = Column('name_val', 'varchar2') -# t2.add_column(c21) -# r = TableReference(c2, t2, c21) -# t.add_ref(r) -# expected = \ -# '''CREATE TABLE "products" ( -# "id" integer, -# "name" varchar2, -# FOREIGN KEY ("name") REFERENCES "names" ("name_val") -# );''' -# self.assertEqual(t.sql, expected) + def test_ref(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + s = Schema() + s.add(t) + s.add(t2) + r = Reference('>', c2, c21) + s.add(r) + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar2 +);''' + self.assertEqual(t.sql, expected) + r.inline = True + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar2, + FOREIGN KEY ("name") REFERENCES "names" ("name_val") +);''' + self.assertEqual(t.sql, expected) # def test_duplicate_ref(self) -> None: # t = Table('products') @@ -49,172 +64,222 @@ def test_one_column(self) -> None: # with self.assertRaises(DuplicateReferenceError): # t.add_ref(r2) -# def test_notes(self) -> None: -# n = Note('Table note') -# nc1 = Note('First column note') -# nc2 = Note('Another column\nmultiline note') -# t = Table('products', note=n) -# c1 = Column('id', 'integer', note=nc1) -# c2 = Column('name', 'varchar') -# c3 = Column('country', 'varchar', note=nc2) -# t.add_column(c1) -# t.add_column(c2) -# t.add_column(c3) -# expected = \ -# '''CREATE TABLE "products" ( -# "id" integer, -# "name" varchar, -# "country" varchar -# ); + def test_notes(self) -> None: + n = Note('Table note') + nc1 = Note('First column note') + nc2 = Note('Another column\nmultiline note') + t = Table('products', note=n) + c1 = Column('id', 'integer', note=nc1) + c2 = Column('name', 'varchar') + c3 = Column('country', 'varchar', note=nc2) + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + s = Schema() + s.add(t) + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar, + "country" varchar +); -# COMMENT ON TABLE "products" IS 'Table note'; +COMMENT ON TABLE "products" IS 'Table note'; -# COMMENT ON COLUMN "products"."id" IS 'First column note'; +COMMENT ON COLUMN "products"."id" IS 'First column note'; -# COMMENT ON COLUMN "products"."country" IS 'Another column -# multiline note';''' -# self.assertEqual(t.sql, expected) +COMMENT ON COLUMN "products"."country" IS 'Another column +multiline note';''' + self.assertEqual(t.sql, expected) -# def test_ref_index(self) -> None: -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# t2 = Table('names') -# c21 = Column('name_val', 'varchar2') -# t2.add_column(c21) -# r = TableReference(c2, t2, c21) -# t.add_ref(r) -# i = Index(['id', 'name']) -# t.add_index(i) -# expected = \ -# '''CREATE TABLE "products" ( -# "id" integer, -# "name" varchar2, -# FOREIGN KEY ("name") REFERENCES "names" ("name_val") -# ); - -# CREATE INDEX ON "products" ("id", "name");''' -# self.assertEqual(t.sql, expected) - -# def test_index_inline(self) -> None: -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# i = Index(['id', 'name'], pk=True) -# t.add_index(i) -# expected = \ -# '''CREATE TABLE "products" ( -# "id" integer, -# "name" varchar2, -# PRIMARY KEY ("id", "name") -# );''' -# self.assertEqual(t.sql, expected) - -# def test_index_inline_and_comments(self) -> None: -# t = Table('products', comment='Multiline\ntable comment') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# i = Index(['id', 'name'], pk=True, comment='Multiline\nindex comment') -# t.add_index(i) -# expected = \ -# '''-- Multiline -# -- table comment -# CREATE TABLE "products" ( -# "id" integer, -# "name" varchar2, -# -- Multiline -# -- index comment -# PRIMARY KEY ("id", "name") -# );''' -# self.assertEqual(t.sql, expected) - -# def test_add_column(self) -> None: -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# self.assertEqual(c1.table, t) -# self.assertEqual(c2.table, t) -# self.assertEqual(t.columns, [c1, c2]) + def test_ref_index(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + s = Schema() + s.add(t) -# def test_add_index(self) -> None: -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# i1 = Index(['id']) -# i2 = Index(['name']) -# t.add_column(c1) -# t.add_column(c2) -# t.add_index(i1) -# t.add_index(i2) -# self.assertEqual(i1.table, t) -# self.assertEqual(i2.table, t) -# self.assertEqual(t.indexes, [i1, i2]) + r = Reference('>', c2, c21, inline=True) + s.add(r) + i = Index(subjects=[c1, c2]) + t.add_index(i) + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar2, + FOREIGN KEY ("name") REFERENCES "names" ("name_val") +); -# def test_add_bad_index(self) -> None: -# t = Table('products') -# c = Column('id', 'integer') -# i = Index(['id', 'name']) -# t.add_column(c) -# with self.assertRaises(ColumnNotFoundError): -# t.add_index(i) +CREATE INDEX ON "products" ("id", "name");''' + self.assertEqual(t.sql, expected) -# def test_dbml_simple(self): -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# expected = \ -# '''Table "products" { -# "id" integer -# "name" varchar2 -# }''' -# self.assertEqual(t.dbml, expected) - -# def test_dbml_full(self): -# t = Table( -# 'products', -# alias='pd', -# note='My multiline\nnote', -# comment='My multiline\ncomment' -# ) -# c0 = Column('zero', 'number') -# c1 = Column('id', 'integer', unique=True, note='Multiline\ncomment note') -# c2 = Column('name', 'varchar2') -# t.add_column(c0) -# t.add_column(c1) -# t.add_column(c2) -# i1 = Index(['zero', 'id'], unique=True) -# i2 = Index(['(capitalize(name))'], comment="index comment") -# t.add_index(i1) -# t.add_index(i2) -# expected = \ -# """// My multiline -# // comment -# Table "products" as "pd" { -# "zero" number -# "id" integer [unique, note: '''Multiline -# comment note'''] -# "name" varchar2 -# Note { -# ''' -# My multiline -# note -# ''' -# } - -# indexes { -# (zero, id) [unique] -# // index comment -# `capitalize(name)` -# } -# }""" -# self.assertEqual(t.dbml, expected) + def test_index_inline(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + i = Index(subjects=[c1, c2], pk=True) + t.add_index(i) + s = Schema() + s.add(t) + + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar2, + PRIMARY KEY ("id", "name") +);''' + self.assertEqual(t.sql, expected) + + def test_index_inline_and_comments(self) -> None: + t = Table('products', comment='Multiline\ntable comment') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + i = Index(subjects=[c1, c2], pk=True, comment='Multiline\nindex comment') + t.add_index(i) + s = Schema() + s.add(t) + + expected = \ +'''-- Multiline +-- table comment +CREATE TABLE "products" ( + "id" integer, + "name" varchar2, + -- Multiline + -- index comment + PRIMARY KEY ("id", "name") +);''' + self.assertEqual(t.sql, expected) + + def test_add_column(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + self.assertEqual(c1.table, t) + self.assertEqual(c2.table, t) + self.assertEqual(t.columns, [c1, c2]) + + def test_add_index(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + i1 = Index([c1]) + i2 = Index([c2]) + t.add_column(c1) + t.add_column(c2) + t.add_index(i1) + t.add_index(i2) + self.assertEqual(i1.table, t) + self.assertEqual(i2.table, t) + self.assertEqual(t.indexes, [i1, i2]) + + # def test_add_bad_index(self) -> None: + # t = Table('products') + # c = Column('id', 'integer') + # i = Index(['id', 'name']) + # t.add_column(c) + # with self.assertRaises(ColumnNotFoundError): + # t.add_index(i) + + def test_dbml_simple(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + s = Schema() + s.add(t) + + expected = \ +'''Table "products" { + "id" integer + "name" varchar2 +}''' + self.assertEqual(t.dbml, expected) + + def test_dbml_reference(self): + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + s = Schema() + s.add(t) + s.add(t2) + r = Reference('>', c2, c21) + s.add(r) + expected = \ +'''Table "products" { + "id" integer + "name" varchar2 +}''' + self.assertEqual(t.dbml, expected) + r.inline = True + expected = \ +'''Table "products" { + "id" integer + "name" varchar2 [ref: > "names"."name_val"] +}''' + self.assertEqual(t.dbml, expected) + expected = \ +'''Table "names" { + "name_val" varchar2 +}''' + self.assertEqual(t2.dbml, expected) + + def test_dbml_full(self): + t = Table( + 'products', + alias='pd', + note='My multiline\nnote', + comment='My multiline\ncomment' + ) + c0 = Column('zero', 'number') + c1 = Column('id', 'integer', unique=True, note='Multiline\ncomment note') + c2 = Column('name', 'varchar2') + t.add_column(c0) + t.add_column(c1) + t.add_column(c2) + i1 = Index(['zero', 'id'], unique=True) + i2 = Index(['(capitalize(name))'], comment="index comment") + t.add_index(i1) + t.add_index(i2) + s = Schema() + s.add(t) + + expected = \ +"""// My multiline +// comment +Table "products" as "pd" { + "zero" number + "id" integer [unique, note: '''Multiline + comment note'''] + "name" varchar2 + Note { + ''' + My multiline + note + ''' + } + + indexes { + (zero, id) [unique] + // index comment + `capitalize(name)` + } +}""" + self.assertEqual(t.dbml, expected) diff --git a/test/test_classes/test_table_group.py b/test/test_classes/test_table_group.py new file mode 100644 index 0000000..033582b --- /dev/null +++ b/test/test_classes/test_table_group.py @@ -0,0 +1,35 @@ +from unittest import TestCase + +from pydbml.classes import Table +from pydbml.classes import TableGroup + + +class TestTableGroup(TestCase): + def test_dbml(self): + tg = TableGroup('mytg', ['merchants', 'countries', 'customers']) + expected = \ +'''TableGroup mytg { + merchants + countries + customers +}''' + self.assertEqual(tg.dbml, expected) + + def test_dbml_with_comment_and_real_tables(self): + merchants = Table('merchants') + countries = Table('countries') + customers = Table('customers') + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + comment='My table group\nmultiline comment' + ) + expected = \ +'''// My table group +// multiline comment +TableGroup mytg { + merchants + countries + customers +}''' + self.assertEqual(tg.dbml, expected) diff --git a/test/test_create_schema.py b/test/test_create_schema.py index 1c98ece..1f98a5c 100644 --- a/test/test_create_schema.py +++ b/test/test_create_schema.py @@ -3,8 +3,10 @@ from pathlib import Path from unittest import TestCase +from pydbml.classes import Column +from pydbml.classes import Index +from pydbml.classes import Table from pydbml.schema import Schema -from pydbml.classes import Table, Column, Index TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' From a3292b6d35df083829e323f8d1f4f596d7dedb0a Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 12 May 2022 10:40:18 +0200 Subject: [PATCH 022/125] Add more tests for classes, fix issues, start testing and changing Schema --- TODO.md | 3 +- pydbml/classes/table.py | 2 +- pydbml/exceptions.py | 4 + pydbml/parser/blueprints.py | 5 + pydbml/parser/parser.py | 4 +- pydbml/schema.py | 148 ++++++---------- test/test_blueprints/test_column.py | 27 ++- test/test_classes/test_table.py | 90 +++++++++- test/test_create_schema.py | 90 ---------- test/test_schema.py | 263 ++++++++++++++++++++++++++++ 10 files changed, 439 insertions(+), 197 deletions(-) delete mode 100644 test/test_create_schema.py create mode 100644 test/test_schema.py diff --git a/TODO.md b/TODO.md index d95bd48..0d14f62 100644 --- a/TODO.md +++ b/TODO.md @@ -1,3 +1,4 @@ - Creating dbml schema in python - pyparsing new var names (+possibly new features) -- enum type +* - enum type +- `_type` -> `type` diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 95b4feb..c6eb17b 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -35,7 +35,7 @@ def __init__(self, header_color: Optional[str] = None, # refs: Optional[List[TableReference]] = None, comment: Optional[str] = None): - self.schema = None + self.schema: Schema = None self.name = name self.columns: List[Column] = [] self.indexes: List[Index] = [] diff --git a/pydbml/exceptions.py b/pydbml/exceptions.py index c8e61b5..b1c8cab 100644 --- a/pydbml/exceptions.py +++ b/pydbml/exceptions.py @@ -24,3 +24,7 @@ class UnknownSchemaError(Exception): class DBMLError(Exception): pass + + +class SchemaValidationError(Exception): + pass diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 690b322..ee8dd65 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -94,6 +94,11 @@ class ColumnBlueprint(Blueprint): comment: Optional[str] = None def build(self) -> 'Column': + if self.parser: + for enum in self.parser.schema.enums: + if enum.name == self.type: + self.type = enum + break return Column( name=self.name, type_=self.type, diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index c234799..3165301 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -161,11 +161,11 @@ def locate_table(self, name: str) -> 'Table': def build_schema(self): self.schema = Schema() + for enum_bp in self.enums: + self.schema.add(enum_bp.build()) for table_bp in self.tables: self.schema.add(table_bp.build()) self.ref_blueprints.extend(table_bp.get_reference_blueprints()) - for enum_bp in self.enums: - self.schema.add(enum_bp.build()) for table_group_bp in self.table_groups: self.schema.add(table_group_bp.build()) if self.project: diff --git a/pydbml/schema.py b/pydbml/schema.py index dc0e797..deb46e4 100644 --- a/pydbml/schema.py +++ b/pydbml/schema.py @@ -2,25 +2,22 @@ from typing import Dict from typing import List from typing import Optional -from itertools import chain from .classes import Enum from .classes import Project from .classes import Reference from .classes import Table from .classes import TableGroup - - -class SchemaValidationError(Exception): - pass +from .exceptions import SchemaValidationError class Schema: + _supported_types = (Table, Reference, Enum, TableGroup, Project) + def __init__(self) -> None: self.tables: List['Table'] = [] self.tables_dict: Dict[str, 'Table'] = {} self.refs: List['Reference'] = [] - # self.ref_blueprints: List[ReferenceBlueprint] = [] self.enums: List['Enum'] = [] self.table_groups: List['TableGroup'] = [] self.project: Optional['Project'] = None @@ -28,74 +25,29 @@ def __init__(self) -> None: def __repr__(self) -> str: return f"" - # def _build_refs_from_blueprints(self, blueprints: List[ReferenceBlueprint]): - # ''' - # Fill up the `refs` attribute with Reference object, created from - # reference blueprints; - # Add TableReference objects to each table which has references. - # Validate refs at the same time. - # ''' - # for ref_ in blueprints: - # for table_ in self.tables: - # if table_.name == ref_.table1 or table_.alias == ref_.table1: - # table1 = table_ - # break - # else: - # raise TableNotFoundError('Error while parsing reference:' - # f'table "{ref_.table1}"" is not defined.') - # for table_ in self.tables: - # if table_.name == ref_.table2 or table_.alias == ref_.table2: - # table2 = table_ - # break - # else: - # raise TableNotFoundError('Error while parsing reference:' - # f'table "{ref_.table2}"" is not defined.') - # col1_names = [c.strip('() ') for c in ref_.col1.split(',')] - # col1 = [] - # for col_name in col1_names: - # try: - # col1.append(table1[col_name]) - # except KeyError: - # raise ColumnNotFoundError('Error while parsing reference:' - # f'column "{col_name} not defined in table "{table1.name}".') - # col2_names = [c.strip('() ') for c in ref_.col2.split(',')] - # col2 = [] - # for col_name in col2_names: - # try: - # col2.append(table2[col_name]) - # except KeyError: - # raise ColumnNotFoundError('Error while parsing reference:' - # f'column "{col_name} not defined in table "{table2.name}".') - # self.add_reference( - # Reference( - # ref_.type, - # table1, - # col1, - # table2, - # col2, - # name=ref_.name, - # comment=ref_.comment, - # on_update=ref_.on_update, - # on_delete=ref_.on_delete - # ) - # ) - def _set_schema(self, obj: Any) -> None: obj.schema = self - def add(self, obj: Any) -> Any: - if isinstance(obj, Table): - return self.add_table(obj) - elif isinstance(obj, Reference): - return self.add_reference(obj) - elif isinstance(obj, Enum): - return self.add_enum(obj) - elif isinstance(obj, TableGroup): - return self.add_table_group(obj) - elif isinstance(obj, Project): - return self.add_project(obj) - else: - raise SchemaValidationError(f'Unsupported type {type(obj)}.') + def _unset_schema(self, obj: Any) -> None: + obj.schema = None + + def add(self, *objs: Any) -> List[Any]: + for obj in objs: + if not any(map(lambda t: isinstance(obj, t), self._supported_types)): + raise SchemaValidationError(f'Unsupported type {type(obj)}.') + result = [] + for obj in objs: + if isinstance(obj, Table): + result.append(self.add_table(obj)) + elif isinstance(obj, Reference): + result.append(self.add_reference(obj)) + elif isinstance(obj, Enum): + result.append(self.add_enum(obj)) + elif isinstance(obj, TableGroup): + result.append(self.add_table_group(obj)) + elif isinstance(obj, Project): + result.append(self.add_project(obj)) + return result def add_table(self, obj: Table) -> Table: if obj.name in self.tables_dict: @@ -110,8 +62,8 @@ def add_table(self, obj: Table) -> Table: return obj def add_reference(self, obj: Reference): - for col in chain(obj.col1, obj.col2): - if col.schema == self: + for col in (*obj.col1, *obj.col2): + if col.table.schema == self: break else: raise SchemaValidationError( @@ -148,30 +100,36 @@ def add_table_group(self, obj: TableGroup) -> TableGroup: return obj def add_project(self, obj: Project) -> Project: + if self.project: + self.delete_project() self._set_schema(obj) self.project = obj return obj - def delete(self, obj: Any) -> Any: - if isinstance(obj, Table): - return self.delete_table(obj) - elif isinstance(obj, Reference): - return self.delete_reference(obj) - elif isinstance(obj, Enum): - return self.delete_enum(obj) - elif isinstance(obj, TableGroup): - return self.delete_table_group(obj) - elif isinstance(obj, Project): - return self.delete_project() - else: - raise SchemaValidationError(f'Unsupported type {type(obj)}.') + def delete(self, *objs: Any) -> List[Any]: + for obj in objs: + if not any(map(lambda t: isinstance(obj, t), self._supported_types)): + raise SchemaValidationError(f'Unsupported type {type(obj)}.') + result = [] + for obj in objs: + if isinstance(obj, Table): + result.append(self.delete_table(obj)) + elif isinstance(obj, Reference): + result.append(self.delete_reference(obj)) + elif isinstance(obj, Enum): + result.append(self.delete_enum(obj)) + elif isinstance(obj, TableGroup): + result.append(self.delete_table_group(obj)) + elif isinstance(obj, Project): + result.append(self.delete_project()) + return result def delete_table(self, obj: Table) -> Table: try: index = self.tables.index(obj) except ValueError: raise SchemaValidationError(f'{obj} is not in the schema.') - self.tables.pop(index).schema = None + self._unset_schema(self.tables.pop(index)) return self.tables_dict.pop(obj.name) def delete_reference(self, obj: Reference) -> Reference: @@ -180,7 +138,7 @@ def delete_reference(self, obj: Reference) -> Reference: except ValueError: raise SchemaValidationError(f'{obj} is not in the schema.') result = self.refs.pop(index) - result.schema = None + self._unset_schema(result) return result def delete_enum(self, obj: Enum) -> Enum: @@ -189,22 +147,22 @@ def delete_enum(self, obj: Enum) -> Enum: except ValueError: raise SchemaValidationError(f'{obj} is not in the schema.') result = self.enums.pop(index) - result.schema = None + self._unset_schema(result) return result def delete_table_group(self, obj: TableGroup) -> TableGroup: try: - index = self.tables_groups.index(obj) + index = self.table_groups.index(obj) except ValueError: raise SchemaValidationError(f'{obj} is not in the schema.') result = self.table_groups.pop(index) - result.schema = None + self._unset_schema(result) return result def delete_project(self) -> Project: - if self.Project is None: + if self.project is None: raise SchemaValidationError(f'Project is not set.') - result = self.Project - self.Project = None - result.schema = None + result = self.project + self.project = None + self._unset_schema(result) return result diff --git a/test/test_blueprints/test_column.py b/test/test_blueprints/test_column.py index 78a5b1e..cd3e2e8 100644 --- a/test/test_blueprints/test_column.py +++ b/test/test_blueprints/test_column.py @@ -1,9 +1,13 @@ from unittest import TestCase +from unittest.mock import Mock -from pydbml.classes import Note from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Note from pydbml.parser.blueprints import ColumnBlueprint from pydbml.parser.blueprints import NoteBlueprint +from pydbml.schema import Schema class TestColumn(TestCase): @@ -41,3 +45,24 @@ def test_build_full(self) -> None: self.assertIsInstance(result.note, Note) self.assertEqual(result.note.text, bp.note.text) self.assertEqual(result.comment, bp.comment) + + def test_enum_type(self) -> None: + s = Schema() + e = Enum( + 'myenum', + items=[ + EnumItem('i1'), + EnumItem('i2') + ] + ) + s.add(e) + parser = Mock() + parser.schema = s + + bp = ColumnBlueprint( + name='testcol', + type='myenum' + ) + bp.parser = parser + result = bp.build() + self.assertIs(result.type, e) diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 1474751..2356230 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -6,6 +6,7 @@ from pydbml.classes import Reference from pydbml.classes import Table from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import IndexNotFoundError from pydbml.schema import Schema @@ -170,6 +171,21 @@ def test_add_column(self) -> None: self.assertEqual(c2.table, t) self.assertEqual(t.columns, [c1, c2]) + def test_delete_column(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t.delete_column(c1) + self.assertIsNone(c1.table) + self.assertNotIn(c1, t.columns) + t.delete_column(0) + self.assertIsNone(c2.table) + self.assertNotIn(c2, t.columns) + with self.assertRaises(ColumnNotFoundError): + t.delete_column(c2) + def test_add_index(self) -> None: t = Table('products') c1 = Column('id', 'integer') @@ -184,13 +200,73 @@ def test_add_index(self) -> None: self.assertEqual(i2.table, t) self.assertEqual(t.indexes, [i1, i2]) - # def test_add_bad_index(self) -> None: - # t = Table('products') - # c = Column('id', 'integer') - # i = Index(['id', 'name']) - # t.add_column(c) - # with self.assertRaises(ColumnNotFoundError): - # t.add_index(i) + def test_delete_index(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + i1 = Index([c1]) + i2 = Index([c2]) + t.add_column(c1) + t.add_column(c2) + t.add_index(i1) + t.add_index(i2) + t.delete_index(0) + self.assertIsNone(i1.table) + self.assertNotIn(i1, t.indexes) + t.delete_index(i2) + self.assertIsNone(i2.table) + self.assertNotIn(i2, t.indexes) + with self.assertRaises(IndexNotFoundError): + t.delete_index(i1) + + def test_get_references_for_sql(self): + t = Table('products') + c11 = Column('id', 'integer') + c12 = Column('name', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('id', 'integer') + c22 = Column('name_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + s = Schema() + s.add(t) + s.add(t2) + r1 = Reference('>', c12, c22) + r2 = Reference('-', c11, c21) + r3 = Reference('<', c11, c22) + s.add(r1) + s.add(r2) + s.add(r3) + self.assertEqual(t._get_references_for_sql(), []) + self.assertEqual(t2._get_references_for_sql(), []) + r1.inline = r2.inline = r3.inline = True + self.assertEqual(t._get_references_for_sql(), [r1, r2]) + self.assertEqual(t2._get_references_for_sql(), [r3]) + + def test_get_refs(self): + t = Table('products') + c11 = Column('id', 'integer') + c12 = Column('name', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('id', 'integer') + c22 = Column('name_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + s = Schema() + s.add(t) + s.add(t2) + r1 = Reference('>', c12, c22) + r2 = Reference('-', c11, c21) + r3 = Reference('<', c11, c22) + s.add(r1) + s.add(r2) + s.add(r3) + self.assertEqual(t.get_refs(), [r1, r2, r3]) + self.assertEqual(t2.get_refs(), []) def test_dbml_simple(self): t = Table('products') diff --git a/test/test_create_schema.py b/test/test_create_schema.py deleted file mode 100644 index 1f98a5c..0000000 --- a/test/test_create_schema.py +++ /dev/null @@ -1,90 +0,0 @@ -import os - -from pathlib import Path -from unittest import TestCase - -from pydbml.classes import Column -from pydbml.classes import Index -from pydbml.classes import Table -from pydbml.schema import Schema - - -TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' - - -class TestCreateTable(TestCase): - def test_one_column(self) -> None: - c = Column('test', 'varchar', True) - t = Table('test_table') - t.add_column(c) - self.assertEqual(c.table, t) - schema = Schema() - schema.add(t) - self.assertEqual(t.schema, schema) - self.assertEqual(schema.tables[0], t) - self.assertEqual(schema.tables[0].name, 'test_table') - self.assertEqual(schema.tables[0].columns[0].name, 'test') - - def test_delete_column(self) -> None: - c1 = Column('col1', 'varchar', True) - c2 = Column('col2', 'number', False) - t = Table('test_table') - t.add_column(c1) - t.add_column(c2) - result = t.delete_column(1) - self.assertEqual(result, c2) - self.assertIsNone(result.table) - - def test_delete_table(self) -> None: - c = Column('test', 'varchar', True) - t = Table('test_table') - t.add_column(c) - schema = Schema() - schema.add(t) - self.assertEqual(t.schema, schema) - self.assertEqual(schema.tables[0], t) - schema.delete(t) - self.assertIsNone(t.schema) - self.assertEqual(schema.tables, []) - - -class TestCreateIndex(TestCase): - def test_simple_index(self): - c1 = Column('col1', 'varchar', True) - c2 = Column('col2', 'number', False) - t = Table('test_table') - t.add_column(c1) - t.add_column(c2) - i = Index([c1], 'IndexName', True) - self.assertIsNone(i.table) - t.add_index(i) - self.assertEqual(i.table, t) - - def test_complex_index(self): - c1 = Column('col1', 'varchar', True) - c2 = Column('col2', 'number', False) - t = Table('test_table') - t.add_column(c1) - t.add_column(c2) - i1 = Index([c1, c2], 'Compound', True) - self.assertIsNone(i1.table) - t.add_index(i1) - self.assertEqual(i1.table, t) - i2 = Index([c1, '(c2 * 3)'], 'Compound expression', True) - self.assertIsNone(i2.table) - t.add_index(i2) - self.assertEqual(i2.table, t) - - def test_delete_index(self): - c1 = Column('col1', 'varchar', True) - c2 = Column('col2', 'number', False) - t = Table('test_table') - t.add_column(c1) - t.add_column(c2) - i = Index([c1], 'IndexName', True) - self.assertIsNone(i.table) - t.add_index(i) - self.assertEqual(i.table, t) - t.delete_index(0) - self.assertIsNone(i.table) - self.assertEqual(t.indexes, []) diff --git a/test/test_schema.py b/test/test_schema.py new file mode 100644 index 0000000..799a5f3 --- /dev/null +++ b/test/test_schema.py @@ -0,0 +1,263 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Project +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.exceptions import SchemaValidationError +from pydbml.schema import Schema + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class TestSchema(TestCase): + def test_add_table(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + res = schema.add_table(t) + self.assertEqual(t.schema, schema) + self.assertIs(res, t) + self.assertIn(t, schema.tables) + + def test_add_table_bad(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add_table(t) + with self.assertRaises(SchemaValidationError): + schema.add_table(t) + t2 = Table('test_table') + with self.assertRaises(SchemaValidationError): + schema.add_table(t2) + + def test_delete_table(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add_table(t) + res = schema.delete_table(t) + self.assertIsNone(t.schema, schema) + self.assertIs(res, t) + self.assertNotIn(t, schema.tables) + + def test_delete_missing_table(self) -> None: + t = Table('test_table') + schema = Schema() + with self.assertRaises(SchemaValidationError): + schema.delete_table(t) + self.assertIsNone(t.schema, schema) + + def test_add_reference(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + schema.add_table(t2) + ref = Reference('>', c, c2) + res = schema.add_reference(ref) + self.assertEqual(ref.schema, schema) + self.assertIs(res, ref) + self.assertIn(ref, schema.refs) + + def test_add_reference_bad(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + schema.add_table(t2) + ref = Reference('>', c, c2) + schema.add_reference(ref) + with self.assertRaises(SchemaValidationError): + schema.add_reference(ref) + + c3 = Column('test', 'varchar', True) + t3 = Table('test_table') + t3.add_column(c3) + schema3 = Schema() + schema3.add_table(t3) + c32 = Column('test2', 'integer') + t32 = Table('test_table2') + t32.add_column(c32) + schema3.add_table(t32) + ref3 = Reference('>', c3, c32) + with self.assertRaises(SchemaValidationError): + schema.add_reference(ref3) + + def test_delete_reference(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + schema.add_table(t2) + ref = Reference('>', c, c2) + res = schema.add_reference(ref) + res = schema.delete_reference(ref) + self.assertIsNone(ref.schema, schema) + self.assertIs(res, ref) + self.assertNotIn(ref, schema.refs) + + def test_delete_missing_reference(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + schema = Schema() + schema.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + schema.add_table(t2) + ref = Reference('>', c, c2) + with self.assertRaises(SchemaValidationError): + schema.delete_reference(ref) + self.assertIsNone(ref.schema) + + def test_add_enum(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + schema = Schema() + res = schema.add_enum(e) + self.assertEqual(e.schema, schema) + self.assertIs(res, e) + self.assertIn(e, schema.enums) + + def test_add_enum_bad(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + schema = Schema() + schema.add_enum(e) + with self.assertRaises(SchemaValidationError): + schema.add_enum(e) + e2 = Enum('myenum', [EnumItem('a2'), EnumItem('b2')]) + with self.assertRaises(SchemaValidationError): + schema.add_enum(e2) + + def test_delete_enum(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + schema = Schema() + schema.add_enum(e) + res = schema.delete_enum(e) + self.assertIsNone(e.schema) + self.assertIs(res, e) + self.assertNotIn(e, schema.enums) + + def test_delete_missing_enum(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + schema = Schema() + with self.assertRaises(SchemaValidationError): + schema.delete_enum(e) + self.assertIsNone(e.schema) + + def test_add_table_group(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + schema = Schema() + res = schema.add_table_group(tg) + self.assertEqual(tg.schema, schema) + self.assertIs(res, tg) + self.assertIn(tg, schema.table_groups) + + def test_add_table_group_bad(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + schema = Schema() + schema.add_table_group(tg) + with self.assertRaises(SchemaValidationError): + schema.add_table_group(tg) + tg2 = TableGroup('mytablegroup', [t2]) + with self.assertRaises(SchemaValidationError): + schema.add_table_group(tg2) + + def test_delete_table_group(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + schema = Schema() + schema.add_table_group(tg) + res = schema.delete_table_group(tg) + self.assertIsNone(tg.schema) + self.assertIs(res, tg) + self.assertNotIn(tg, schema.table_groups) + + def test_delete_missing_table_group(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + schema = Schema() + with self.assertRaises(SchemaValidationError): + schema.delete_table_group(tg) + self.assertIsNone(tg.schema) + + def test_add_project(self) -> None: + p = Project('myproject') + schema = Schema() + res = schema.add_project(p) + self.assertEqual(p.schema, schema) + self.assertIs(res, p) + self.assertIs(schema.project, p) + + def test_add_another_project(self) -> None: + p = Project('myproject') + schema = Schema() + schema.add_project(p) + p2 = Project('anotherproject') + res = schema.add_project(p2) + self.assertEqual(p2.schema, schema) + self.assertIs(res, p2) + self.assertIs(schema.project, p2) + self.assertIsNone(p.schema) + + def test_delete_project(self) -> None: + p = Project('myproject') + schema = Schema() + schema.add_project(p) + res = schema.delete_project() + self.assertIsNone(p.schema, schema) + self.assertIs(res, p) + self.assertIsNone(schema.project) + + def delete_missing_project(self) -> None: + schema = Schema() + with self.assertRaises(SchemaValidationError): + schema.delete_project() + + def test_add(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + schema = Schema() + res = schema.add(t, t2, e) + self.assertEqual(res, [t, t2, e]) + self.assertEqual(t.schema, schema) + self.assertIn(t, schema.tables) + self.assertEqual(t2.schema, schema) + self.assertIn(t2, schema.tables) + self.assertEqual(e.schema, schema) + self.assertIn(e, schema.enums) +# TODO: unset schema if error From d92324ada561b6c047128081add83c46e2f3e86b Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Fri, 13 May 2022 09:59:43 +0200 Subject: [PATCH 023/125] Finish tests for schema, fix test_docs and other bugs --- TODO.md | 1 + pydbml/__init__.py | 1 + pydbml/parser/blueprints.py | 7 ++- pydbml/parser/parser.py | 30 ++++------ pydbml/schema.py | 105 +++++++++++++++++++------------- test/test_docs.py | 29 ++++----- test/test_schema.py | 116 +++++++++++++++++++++++++++++++----- 7 files changed, 198 insertions(+), 91 deletions(-) diff --git a/TODO.md b/TODO.md index 0d14f62..80b71c4 100644 --- a/TODO.md +++ b/TODO.md @@ -2,3 +2,4 @@ - pyparsing new var names (+possibly new features) * - enum type - `_type` -> `type` +- schema.add and .delete to support multiple arguments (handle errors properly) \ No newline at end of file diff --git a/pydbml/__init__.py b/pydbml/__init__.py index e7d44ce..1ea4a15 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -2,6 +2,7 @@ import unittest from . import classes +from .parser import PyDBML from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_MANY diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index ee8dd65..b078a0e 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -61,11 +61,12 @@ def build(self) -> 'Reference': raise ColumnNotFoundError("Can't build Reference, col2 unknown") table1 = self.parser.locate_table(self.table1) - col1_list = [self.col1] if isinstance(self.col1, str) else self.col1 + + col1_list = [c.strip('() ') for c in self.col1.split(',')] col1 = [table1[col] for col in col1_list] table2 = self.parser.locate_table(self.table2) - col2_list = [self.col2] if isinstance(self.col2, str) else self.col2 + col2_list = [c.strip('() ') for c in self.col2.split(',')] col2 = [table2[col] for col in col2_list] return Reference( @@ -183,7 +184,7 @@ def get_reference_blueprints(self): ''' the inline ones ''' result = [] for col in self.columns: - for ref_bp in col.ref_blueprints: + for ref_bp in col.ref_blueprints or []: ref_bp.table1 = self.name ref_bp.col1 = col.name result.append(ref_bp) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 3165301..48409ac 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -70,7 +70,8 @@ def parse(text: str) -> PyDBMLParser: if text[0] == '\ufeff': # removing BOM text = text[1:] - return PyDBMLParser(text) + parser = PyDBMLParser(text) + return parser.parse() @staticmethod def parse_file(file: Union[str, Path, TextIOWrapper]) -> PyDBMLParser: @@ -82,8 +83,7 @@ def parse_file(file: Union[str, Path, TextIOWrapper]) -> PyDBMLParser: if source[0] == '\ufeff': # removing BOM source = source[1:] parser = PyDBMLParser(source) - parser.parse() - return parser + return parser.parse() def parse(source: str): @@ -108,6 +108,7 @@ def parse(self): self._set_syntax() self._syntax.parseString(self.source, parseAll=True) self.build_schema() + return self.schema def __repr__(self): return "" @@ -138,6 +139,10 @@ def parse_blueprint(self, s, l, t): blueprint = t[0] if isinstance(blueprint, TableBlueprint): self.tables.append(blueprint) + ref_bps = blueprint.get_reference_blueprints() + for ref_bp in ref_bps: + self.refs.append(ref_bp) + ref_bp.parser = self elif isinstance(blueprint, ReferenceBlueprint): self.refs.append(blueprint) elif isinstance(blueprint, EnumBlueprint): @@ -169,7 +174,7 @@ def build_schema(self): for table_group_bp in self.table_groups: self.schema.add(table_group_bp.build()) if self.project: - self.schema.add(project.build()) + self.schema.add(self.project.build()) for ref_bp in self.refs: self.schema.add(ref_bp.build()) @@ -230,19 +235,4 @@ def build_schema(self): # tg.items = [self.schema.tables_dict[i] for i in tg.items] # self.schema.add_table_group(tg) -# @property -# def sql(self): -# '''Returs SQL of the parsed results''' - -# components = (i.sql for i in (*self.enums, *self.tables)) -# return '\n\n'.join(components) - -# @property -# def dbml(self): -# '''Generates DBML code out of parsed results''' -# items = [self.project] if self.project else [] -# items.extend((*self.tables, *self.refs, *self.enums, *self.table_groups)) -# components = ( -# i.dbml for i in items -# ) -# return '\n\n'.join(components) + diff --git a/pydbml/schema.py b/pydbml/schema.py index deb46e4..a0e9538 100644 --- a/pydbml/schema.py +++ b/pydbml/schema.py @@ -2,6 +2,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Union from .classes import Enum from .classes import Project @@ -12,11 +13,9 @@ class Schema: - _supported_types = (Table, Reference, Enum, TableGroup, Project) - def __init__(self) -> None: self.tables: List['Table'] = [] - self.tables_dict: Dict[str, 'Table'] = {} + self.table_dict: Dict[str, 'Table'] = {} self.refs: List['Reference'] = [] self.enums: List['Enum'] = [] self.table_groups: List['TableGroup'] = [] @@ -25,40 +24,49 @@ def __init__(self) -> None: def __repr__(self) -> str: return f"" + def __getitem__(self, k: Union[int, str]) -> Table: + if isinstance(k, int): + return self.tables[k] + else: + return self.table_dict[k] + + def __iter__(self): + return iter(self.tables) + def _set_schema(self, obj: Any) -> None: obj.schema = self def _unset_schema(self, obj: Any) -> None: obj.schema = None - def add(self, *objs: Any) -> List[Any]: - for obj in objs: - if not any(map(lambda t: isinstance(obj, t), self._supported_types)): - raise SchemaValidationError(f'Unsupported type {type(obj)}.') - result = [] - for obj in objs: - if isinstance(obj, Table): - result.append(self.add_table(obj)) - elif isinstance(obj, Reference): - result.append(self.add_reference(obj)) - elif isinstance(obj, Enum): - result.append(self.add_enum(obj)) - elif isinstance(obj, TableGroup): - result.append(self.add_table_group(obj)) - elif isinstance(obj, Project): - result.append(self.add_project(obj)) - return result + def add(self, obj: Any) -> Any: + if isinstance(obj, Table): + return self.add_table(obj) + elif isinstance(obj, Reference): + return self.add_reference(obj) + elif isinstance(obj, Enum): + return self.add_enum(obj) + elif isinstance(obj, TableGroup): + return self.add_table_group(obj) + elif isinstance(obj, Project): + return self.add_project(obj) + else: + raise SchemaValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: - if obj.name in self.tables_dict: + if obj.name in self.table_dict: raise SchemaValidationError(f'Table {obj.name} is already in the schema.') + if obj.alias and obj.alias in self.table_dict: + raise SchemaValidationError(f'Table {obj.alias} is already in the schema.') if obj in self.tables: raise SchemaValidationError(f'{obj} is already in the schema.') self._set_schema(obj) self.tables.append(obj) - self.tables_dict[obj.name] = obj + self.table_dict[obj.name] = obj + if obj.alias: + self.table_dict[obj.alias] = obj return obj def add_reference(self, obj: Reference): @@ -106,23 +114,19 @@ def add_project(self, obj: Project) -> Project: self.project = obj return obj - def delete(self, *objs: Any) -> List[Any]: - for obj in objs: - if not any(map(lambda t: isinstance(obj, t), self._supported_types)): - raise SchemaValidationError(f'Unsupported type {type(obj)}.') - result = [] - for obj in objs: - if isinstance(obj, Table): - result.append(self.delete_table(obj)) - elif isinstance(obj, Reference): - result.append(self.delete_reference(obj)) - elif isinstance(obj, Enum): - result.append(self.delete_enum(obj)) - elif isinstance(obj, TableGroup): - result.append(self.delete_table_group(obj)) - elif isinstance(obj, Project): - result.append(self.delete_project()) - return result + def delete(self, obj: Any) -> Any: + if isinstance(obj, Table): + return self.delete_table(obj) + elif isinstance(obj, Reference): + return self.delete_reference(obj) + elif isinstance(obj, Enum): + return self.delete_enum(obj) + elif isinstance(obj, TableGroup): + return self.delete_table_group(obj) + elif isinstance(obj, Project): + return self.delete_project() + else: + raise SchemaValidationError(f'Unsupported type {type(obj)}.') def delete_table(self, obj: Table) -> Table: try: @@ -130,7 +134,10 @@ def delete_table(self, obj: Table) -> Table: except ValueError: raise SchemaValidationError(f'{obj} is not in the schema.') self._unset_schema(self.tables.pop(index)) - return self.tables_dict.pop(obj.name) + result = self.table_dict.pop(obj.name) + if obj.alias: + self.table_dict.pop(obj.alias) + return result def delete_reference(self, obj: Reference) -> Reference: try: @@ -166,3 +173,21 @@ def delete_project(self) -> Project: self.project = None self._unset_schema(result) return result + + @property + def sql(self): + '''Returs SQL of the parsed results''' + refs = (ref for ref in self.refs if not ref.inline) + components = (i.sql for i in (*self.enums, *self.tables, *refs)) + return '\n\n'.join(components) + + @property + def dbml(self): + '''Generates DBML code out of parsed results''' + items = (self.project) if self.project else () + refs = (ref for ref in self.refs if not ref.inline) + items.extend((*self.enums, *self.tables, *refs, *self.table_groups)) + components = ( + i.dbml for i in items + ) + return '\n\n'.join(components) diff --git a/test/test_docs.py b/test/test_docs.py index f3fae25..1816a1c 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -14,6 +14,7 @@ TEST_DOCS_PATH = Path(os.path.abspath(__file__)).parent / 'test_data/docs' +TestCase.maxDiff = None class TestDocs(TestCase): @@ -25,7 +26,7 @@ def test_example(self) -> None: self.assertEqual(len(results.refs), 1) ref = results.refs[0] - self.assertEqual((posts, users), (ref.table1, ref.table2)) + self.assertEqual((posts, users), (ref.col1[0].table, ref.col2[0].table)) def test_project(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'project.dbml') @@ -52,7 +53,7 @@ def test_table_alias(self) -> None: self.assertEqual(len(results.refs), 1) ref = results.refs[0] - self.assertEqual((u, posts), (ref.table1, ref.table2)) + self.assertEqual((u, posts), (ref.col1[0].table, ref.col2[0].table)) def test_table_notes(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'table_notes.dbml') @@ -135,12 +136,12 @@ def test_relationships(self) -> None: rf = results.refs - self.assertEqual(rf[0].table1, posts) - self.assertEqual(rf[0].table2, users) + self.assertEqual(rf[0].col1[0].table, posts) + self.assertEqual(rf[0].col2[0].table, users) self.assertEqual(rf[0].type, '>') - self.assertEqual(rf[1].table1, reviews) - self.assertEqual(rf[1].table2, users) + self.assertEqual(rf[1].col1[0].table, reviews) + self.assertEqual(rf[1].col2[0].table, users) self.assertEqual(rf[1].type, '>') results = PyDBML.parse_file(TEST_DOCS_PATH / 'relationships_2.dbml') @@ -148,12 +149,12 @@ def test_relationships(self) -> None: rf = results.refs - self.assertEqual(rf[0].table1, users) - self.assertEqual(rf[0].table2, posts) + self.assertEqual(rf[0].col1[0].table, users) + self.assertEqual(rf[0].col2[0].table, posts) self.assertEqual(rf[0].type, '<') - self.assertEqual(rf[1].table1, users) - self.assertEqual(rf[1].table2, reviews) + self.assertEqual(rf[1].col1[0].table, users) + self.assertEqual(rf[1].col2[0].table, reviews) self.assertEqual(rf[1].type, '<') def test_relationships_composite(self) -> None: @@ -164,8 +165,8 @@ def test_relationships_composite(self) -> None: self.assertEqual(len(rf), 1) - self.assertEqual(rf[0].table1, merchant_periods) - self.assertEqual(rf[0].table2, merchants) + self.assertEqual(rf[0].col1[0].table, merchant_periods) + self.assertEqual(rf[0].col2[0].table, merchants) self.assertEqual(rf[0].type, '>') self.assertEqual( rf[0].col1, @@ -190,8 +191,8 @@ def test_relationship_settings(self) -> None: self.assertEqual(len(rf), 1) - self.assertEqual(rf[0].table1, merchant_periods) - self.assertEqual(rf[0].table2, merchants) + self.assertEqual(rf[0].col1[0].table, merchant_periods) + self.assertEqual(rf[0].col2[0].table, merchants) self.assertEqual(rf[0].type, '>') self.assertEqual(rf[0].on_delete, 'cascade') self.assertEqual(rf[0].on_update, 'no action') diff --git a/test/test_schema.py b/test/test_schema.py index 799a5f3..c45653d 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -28,6 +28,25 @@ def test_add_table(self) -> None: self.assertIs(res, t) self.assertIn(t, schema.tables) + def test_add_table_alias(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table', alias='myalias') + t.add_column(c) + schema = Schema() + schema.add_table(t) + self.assertIs(schema[t.alias], t) + + def test_add_table_alias_bad(self) -> None: + c = Column('test', 'varchar', True) + t = Table('myalias') + t.add_column(c) + schema = Schema() + schema.add_table(t) + t2 = Table('test_table', alias='myalias') + with self.assertRaises(SchemaValidationError): + schema.add_table(t2) + self.assertIsNone(t2.schema) + def test_add_table_bad(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') @@ -42,7 +61,7 @@ def test_add_table_bad(self) -> None: def test_delete_table(self) -> None: c = Column('test', 'varchar', True) - t = Table('test_table') + t = Table('test_table', alias='myalias') t.add_column(c) schema = Schema() schema.add_table(t) @@ -50,6 +69,8 @@ def test_delete_table(self) -> None: self.assertIsNone(t.schema, schema) self.assertIs(res, t) self.assertNotIn(t, schema.tables) + self.assertNotIn('test_table', schema.table_dict) + self.assertNotIn('myalias', schema.table_dict) def test_delete_missing_table(self) -> None: t = Table('test_table') @@ -243,21 +264,88 @@ def delete_missing_project(self) -> None: with self.assertRaises(SchemaValidationError): schema.delete_project() + def test_geititem(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + schema = Schema() + schema.add_table(t1) + schema.add_table(t2) + self.assertIs(schema['table1'], t1) + self.assertIs(schema['table2'], t2) + self.assertIs(schema[0], t1) + self.assertIs(schema[1], t2) + with self.assertRaises(IndexError): + schema[2] + with self.assertRaises(KeyError): + schema['wrong'] + + def test_iter(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + schema = Schema() + schema.add_table(t1) + schema.add_table(t2) + self.assertEqual(list(iter(schema)), [t1, t2]) + def test_add(self) -> None: - c = Column('test', 'varchar', True) - t = Table('test_table') - t.add_column(c) - c2 = Column('test2', 'integer') - t2 = Table('test_table2') - t2.add_column(c2) + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) schema = Schema() - res = schema.add(t, t2, e) - self.assertEqual(res, [t, t2, e]) - self.assertEqual(t.schema, schema) - self.assertIn(t, schema.tables) - self.assertEqual(t2.schema, schema) + schema.add(t1) + schema.add(t2) + schema.add(e) + schema.add(tg) + self.assertIs(t1.schema, schema) + self.assertIs(t2.schema, schema) + self.assertIs(e.schema, schema) + self.assertIs(tg.schema, schema) + self.assertIn(t1, schema.tables) self.assertIn(t2, schema.tables) - self.assertEqual(e.schema, schema) + self.assertIn(tg, schema.table_groups) self.assertIn(e, schema.enums) -# TODO: unset schema if error + + def test_add_bad(self) -> None: + class Test: + pass + t = Test() + schema = Schema() + with self.assertRaises(SchemaValidationError): + schema.add(t) + with self.assertRaises(AttributeError): + t.schema + + def test_delete(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + schema = Schema() + schema.add(t1) + schema.add(t2) + schema.add(e) + schema.add(tg) + + schema.delete(t1) + schema.delete(t2) + schema.delete(e) + schema.delete(tg) + self.assertIsNone(t1.schema) + self.assertIsNone(t2.schema) + self.assertIsNone(e.schema) + self.assertIsNone(tg.schema) + self.assertNotIn(t1, schema.tables) + self.assertNotIn(t2, schema.tables) + self.assertNotIn(tg, schema.table_groups) + self.assertNotIn(e, schema.enums) + + def test_delete_bad(self) -> None: + class Test: + pass + t = Test() + schema = Schema() + with self.assertRaises(SchemaValidationError): + schema.delete(t) + with self.assertRaises(AttributeError): + t.schema From 1ebf701c9795cee6cdb8eab25d8c039b878004cd Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 14 May 2022 12:00:59 +0200 Subject: [PATCH 024/125] Finish old tests, fix found bugs --- pydbml/definitions/column.py | 2 +- pydbml/parser/parser.py | 6 ++ test/_test_parser.py | 96 ---------------------------- test/test_definitions/test_column.py | 8 +-- test/test_editing.py | 3 +- test/test_parser.py | 88 +++++++++++++++++++++++-- 6 files changed, 96 insertions(+), 107 deletions(-) delete mode 100644 test/_test_parser.py diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index 40ecfb4..144274f 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -21,7 +21,7 @@ pp.ParserElement.setDefaultWhitespaceChars(' \t\r') type_args = ("(" + pp.originalTextFor(expression)('args') + ")") -type_name = (pp.Word(pp.alphanums + '_') | pp.dblQuotedString())('name') +type_name = (pp.Word(pp.alphanums + '_') | pp.QuotedString('"'))('name') column_type = (type_name + type_args[0, 1]) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 48409ac..a4dc65a 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -140,9 +140,15 @@ def parse_blueprint(self, s, l, t): if isinstance(blueprint, TableBlueprint): self.tables.append(blueprint) ref_bps = blueprint.get_reference_blueprints() + col_bps = blueprint.columns or [] + index_bps = blueprint.indexes or [] for ref_bp in ref_bps: self.refs.append(ref_bp) ref_bp.parser = self + for col_bp in col_bps: + col_bp.parser = self + for index_bp in index_bps: + index_bp.parser = self elif isinstance(blueprint, ReferenceBlueprint): self.refs.append(blueprint) elif isinstance(blueprint, EnumBlueprint): diff --git a/test/_test_parser.py b/test/_test_parser.py deleted file mode 100644 index f6b76f7..0000000 --- a/test/_test_parser.py +++ /dev/null @@ -1,96 +0,0 @@ -import os - -from pathlib import Path -from unittest import TestCase - -from pydbml import PyDBML -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import TableNotFoundError - - -TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' - - -class TestParser(TestCase): - def setUp(self): - self.results = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - - def test_table_refs(self) -> None: - p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - r = p['order_items'].refs - self.assertEqual(r[0].col[0].name, 'order_id') - self.assertEqual(r[0].ref_table.name, 'orders') - self.assertEqual(r[0].ref_col[0].name, 'id') - r = p['products'].refs - self.assertEqual(r[0].col[0].name, 'merchant_id') - self.assertEqual(r[0].ref_table.name, 'merchants') - self.assertEqual(r[0].ref_col[0].name, 'id') - r = p['users'].refs - self.assertEqual(r[0].col[0].name, 'country_code') - self.assertEqual(r[0].ref_table.name, 'countries') - self.assertEqual(r[0].ref_col[0].name, 'code') - - def test_refs(self) -> None: - p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - r = p.refs - self.assertEqual(r[0].table1.name, 'orders') - self.assertEqual(r[0].col1[0].name, 'id') - self.assertEqual(r[0].table2.name, 'order_items') - self.assertEqual(r[0].col2[0].name, 'order_id') - self.assertEqual(r[2].table1.name, 'users') - self.assertEqual(r[2].col1[0].name, 'country_code') - self.assertEqual(r[2].table2.name, 'countries') - self.assertEqual(r[2].col2[0].name, 'code') - self.assertEqual(r[4].table1.name, 'products') - self.assertEqual(r[4].col1[0].name, 'merchant_id') - self.assertEqual(r[4].table2.name, 'merchants') - self.assertEqual(r[4].col2[0].name, 'id') - - -class TestRefs(TestCase): - def test_reference_aliases(self): - results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') - posts, reviews, users = results['posts'], results['reviews'], results['users'] - posts2, reviews2, users2 = results['posts2'], results['reviews2'], results['users2'] - - rs = results.refs - self.assertEqual(rs[0].table1, users) - self.assertEqual(rs[0].table2, posts) - self.assertEqual(rs[1].table1, users) - self.assertEqual(rs[1].table2, reviews) - - self.assertEqual(rs[2].table1, posts2) - self.assertEqual(rs[2].table2, users2) - self.assertEqual(rs[3].table1, reviews2) - self.assertEqual(rs[3].table2, users2) - - def test_composite_references(self): - results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_composite.dbml') - self.assertEqual(len(results.tables), 4) - posts, reviews = results['posts'], results['reviews'] - posts2, reviews2 = results['posts2'], results['reviews2'] - - rs = results.refs - self.assertEqual(len(rs), 2) - - self.assertEqual(rs[0].table1, posts) - self.assertEqual(rs[0].col1, [posts['id'], posts['tag']]) - self.assertEqual(rs[0].table2, reviews) - self.assertEqual(rs[0].col2, [reviews['post_id'], reviews['tag']]) - - self.assertEqual(rs[1].table1, posts2) - self.assertEqual(rs[1].col1, [posts2['id'], posts2['tag']]) - self.assertEqual(rs[1].table2, reviews2) - self.assertEqual(rs[1].col2, [reviews2['post_id'], reviews2['tag']]) - - -class TestFaulty(TestCase): - def test_bad_reference(self) -> None: - with self.assertRaises(TableNotFoundError): - PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_table.dbml') - with self.assertRaises(ColumnNotFoundError): - PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_column.dbml') - - def test_bad_index(self) -> None: - with self.assertRaises(ColumnNotFoundError): - PyDBML(TEST_DATA_PATH / 'wrong_index.dbml') diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index 1b795af..98ef3e7 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -22,7 +22,7 @@ def test_simple(self) -> None: self.assertEqual(res[0], val) def test_quoted(self) -> None: - val = '"mytype"' + val = 'mytype' res = column_type.parseString(val, parseAll=True) self.assertEqual(res[0], val) @@ -181,7 +181,7 @@ def test_with_settings(self) -> None: val = "_test_ \"mytype\" [unique, not null, note: 'to include unit number']\n" res = table_column.parseString(val, parseAll=True) self.assertEqual(res[0].name, '_test_') - self.assertEqual(res[0].type, '\"mytype\"') + self.assertEqual(res[0].type, 'mytype') self.assertTrue(res[0].unique) self.assertTrue(res[0].not_null) self.assertTrue(res[0].note is not None) @@ -190,7 +190,7 @@ def test_settings_and_constraints(self) -> None: val = "_test_ \"mytype\" unique pk [not null]\n" res = table_column.parseString(val, parseAll=True) self.assertEqual(res[0].name, '_test_') - self.assertEqual(res[0].type, '\"mytype\"') + self.assertEqual(res[0].type, 'mytype') self.assertTrue(res[0].unique) self.assertTrue(res[0].not_null) self.assertTrue(res[0].pk) @@ -218,7 +218,7 @@ def test_comment_after(self) -> None: val3 = "_test_ \"mytype\" unique pk [not null] //comment after\n" res3 = table_column.parseString(val3, parseAll=True) self.assertEqual(res3[0].name, '_test_') - self.assertEqual(res3[0].type, '\"mytype\"') + self.assertEqual(res3[0].type, 'mytype') self.assertTrue(res3[0].unique) self.assertTrue(res3[0].not_null) self.assertTrue(res3[0].pk) diff --git a/test/test_editing.py b/test/test_editing.py index a53a509..8b20866 100644 --- a/test/test_editing.py +++ b/test/test_editing.py @@ -79,7 +79,8 @@ def test_name_ref(self) -> None: products = self.dbml['products'] col = products['merchant_id'] col.name = 'changed_merchant_id' - table_ref = products.refs[0] + merchants = self.dbml['merchants'] + table_ref = merchants.get_refs()[0] self.assertIn('FOREIGN KEY ("changed_merchant_id")', table_ref.sql) diff --git a/test/test_parser.py b/test/test_parser.py index 6228c65..2c35025 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -3,9 +3,7 @@ from pathlib import Path from unittest import TestCase -from pydbml.parser.blueprints import EnumItemBlueprint -from pydbml.parser.blueprints import EnumBlueprint -from pydbml import PyDBMLParser +from pydbml import PyDBML from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError @@ -14,5 +12,85 @@ class TestParser(TestCase): - def test_build_enums(self) -> None: - i1 = EnumItemBlueprint() \ No newline at end of file + def setUp(self): + self.results = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') + + def test_table_refs(self) -> None: + p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') + r = p['orders'].get_refs() + self.assertEqual(r[0].col2[0].name, 'order_id') + self.assertEqual(r[0].col1[0].table.name, 'orders') + self.assertEqual(r[0].col1[0].name, 'id') + r = p['products'].get_refs() + self.assertEqual(r[1].col1[0].name, 'merchant_id') + self.assertEqual(r[1].col2[0].table.name, 'merchants') + self.assertEqual(r[1].col2[0].name, 'id') + r = p['users'].get_refs() + self.assertEqual(r[0].col1[0].name, 'country_code') + self.assertEqual(r[0].col2[0].table.name, 'countries') + self.assertEqual(r[0].col2[0].name, 'code') + + def test_refs(self) -> None: + p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') + r = p.refs + self.assertEqual(r[0].col1[0].table.name, 'orders') + self.assertEqual(r[0].col1[0].name, 'id') + self.assertEqual(r[0].col2[0].table.name, 'order_items') + self.assertEqual(r[0].col2[0].name, 'order_id') + self.assertEqual(r[2].col1[0].table.name, 'users') + self.assertEqual(r[2].col1[0].name, 'country_code') + self.assertEqual(r[2].col2[0].table.name, 'countries') + self.assertEqual(r[2].col2[0].name, 'code') + self.assertEqual(r[4].col1[0].table.name, 'products') + self.assertEqual(r[4].col1[0].name, 'merchant_id') + self.assertEqual(r[4].col2[0].table.name, 'merchants') + self.assertEqual(r[4].col2[0].name, 'id') + + +class TestRefs(TestCase): + def test_reference_aliases(self): + results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') + posts, reviews, users = results['posts'], results['reviews'], results['users'] + posts2, reviews2, users2 = results['posts2'], results['reviews2'], results['users2'] + + rs = results.refs + self.assertEqual(rs[0].col1[0].table, users) + self.assertEqual(rs[0].col2[0].table, posts) + self.assertEqual(rs[1].col1[0].table, users) + self.assertEqual(rs[1].col2[0].table, reviews) + + self.assertEqual(rs[2].col1[0].table, posts2) + self.assertEqual(rs[2].col2[0].table, users2) + self.assertEqual(rs[3].col1[0].table, reviews2) + self.assertEqual(rs[3].col2[0].table, users2) + + def test_composite_references(self): + results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_composite.dbml') + self.assertEqual(len(results.tables), 4) + posts, reviews = results['posts'], results['reviews'] + posts2, reviews2 = results['posts2'], results['reviews2'] + + rs = results.refs + self.assertEqual(len(rs), 2) + + self.assertEqual(rs[0].col1[0].table, posts) + self.assertEqual(rs[0].col1, [posts['id'], posts['tag']]) + self.assertEqual(rs[0].col2[0].table, reviews) + self.assertEqual(rs[0].col2, [reviews['post_id'], reviews['tag']]) + + self.assertEqual(rs[1].col1[0].table, posts2) + self.assertEqual(rs[1].col1, [posts2['id'], posts2['tag']]) + self.assertEqual(rs[1].col2[0].table, reviews2) + self.assertEqual(rs[1].col2, [reviews2['post_id'], reviews2['tag']]) + + +class TestFaulty(TestCase): + def test_bad_reference(self) -> None: + with self.assertRaises(TableNotFoundError): + PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_table.dbml') + with self.assertRaises(ColumnNotFoundError): + PyDBML(TEST_DATA_PATH / 'wrong_inline_ref_column.dbml') + + def test_bad_index(self) -> None: + with self.assertRaises(ColumnNotFoundError): + PyDBML(TEST_DATA_PATH / 'wrong_index.dbml') From d0db6a91d0d5de335ab67dd0cd4f80743d7b3197 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 14 May 2022 15:25:38 +0200 Subject: [PATCH 025/125] more tests, fix bugs, cleanup --- .gitignore | 1 + TODO.md | 1 + changelog.md | 6 + pydbml/classes/_classes.py | 161 ------------------------- pydbml/classes/enum.py | 2 +- pydbml/classes/note.py | 4 +- pydbml/classes/reference.py | 5 - pydbml/classes/table.py | 42 +------ pydbml/classes/table_group.py | 2 +- pydbml/definitions/enum.py | 7 +- pydbml/parser/parser.py | 59 --------- pydbml/schema.py | 12 +- pydbml/tools.py | 10 +- test/test_blueprints/__init__.py | 0 test/test_blueprints/test_reference.py | 18 ++- test/test_classes/__init__.py | 0 test/test_classes/test_base.py | 1 + test/test_classes/test_column.py | 45 +++++++ test/test_classes/test_enum.py | 31 +++++ test/test_classes/test_note.py | 12 ++ test/test_classes/test_reference.py | 38 ++++-- test/test_classes/test_table.py | 64 +++++++--- test/test_classes/test_table_group.py | 25 ++++ test/test_generate_dbml.py | 75 ++++++++++++ test/test_schema.py | 16 ++- test/test_tools.py | 64 ++++++++++ 26 files changed, 398 insertions(+), 303 deletions(-) delete mode 100644 pydbml/classes/_classes.py create mode 100644 test/test_blueprints/__init__.py create mode 100644 test/test_classes/__init__.py create mode 100644 test/test_generate_dbml.py create mode 100644 test/test_tools.py diff --git a/.gitignore b/.gitignore index fae75a3..00e1adf 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build dist pydbml.egg-info .mypy_cache +.coverage diff --git a/TODO.md b/TODO.md index 80b71c4..1349156 100644 --- a/TODO.md +++ b/TODO.md @@ -2,4 +2,5 @@ - pyparsing new var names (+possibly new features) * - enum type - `_type` -> `type` +- expression class - schema.add and .delete to support multiple arguments (handle errors properly) \ No newline at end of file diff --git a/changelog.md b/changelog.md index 3f16e76..51abe17 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,9 @@ +# 1.0.0 + +- refs don't have tables, only columns +- tables don't have refs +- col1 col2 in ref are as they were in dbml + # 0.4.2 - Fix: after editing column name index dbml was not updated. diff --git a/pydbml/classes/_classes.py b/pydbml/classes/_classes.py deleted file mode 100644 index dee3485..0000000 --- a/pydbml/classes/_classes.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import annotations - -from typing import Any -from typing import Collection -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -from pydbml.parser.blueprints import IndexBlueprint -from pydbml.parser.blueprints import ReferenceBlueprint -from pydbml.exceptions import AttributeMissingError -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import IndexNotFoundError -from pydbml.exceptions import UnknownSchemaError -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql -from pydbml.tools import indent -from pydbml.tools import note_option_to_dbml - - -class SQLOjbect: - ''' - Base class for all SQL objects. - ''' - required_attributes: Tuple[str, ...] = () - - def check_attributes_for_sql(self): - ''' - Check if all attributes, required for rendering SQL are set in the - instance. If some attribute is missing, raise AttributeMissingError - ''' - for attr in self.required_attributes: - if getattr(self, attr) is None: - raise AttributeMissingError( - f'Cannot render SQL. Missing required attribute "{attr}".' - ) - - def __setattr__(self, name: str, value: Any): - """ - Required for type testing with MyPy. - """ - super().__setattr__(name, value) - - def __eq__(self, other: object) -> bool: - """ - Two instances of the same SQLObject subclass are equal if all their - attributes are equal. - """ - - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return False - - -class TableReference(SQLOjbect): - ''' - Class, representing a foreign key constraint. - This object should be assigned to the `refs` attribute of a Table object. - Its `sql` property contains the inline definition of the FOREIGN KEY clause. - ''' - required_attributes = ('col', 'ref_table', 'ref_col') - - def __init__(self, - col: Union[Column, List[Column]], - ref_table: Table, - ref_col: Union[Column, List[Column]], - name: Optional[str] = None, - on_delete: Optional[str] = None, - on_update: Optional[str] = None): - self.col = [col] if isinstance(col, Column) else list(col) - self.ref_table = ref_table - self.ref_col = [ref_col] if isinstance(ref_col, Column) else list(ref_col) - self.name = name - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t2 = Table('t2') - >>> TableReference(col=c1, ref_table=t2, ref_col=c2) - - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22)) - - ''' - - col_names = [c.name for c in self.col] - ref_col_names = [c.name for c in self.ref_col] - return f"" - - def __str__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t2 = Table('t2') - >>> print(TableReference(col=c1, ref_table=t2, ref_col=c2)) - TableReference([c1] > t2[c2]) - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> print(TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22))) - TableReference([c1, c12] > t2[c2, c22]) - ''' - - components = [f"TableReference("] - components.append(f'[{", ".join(c.name for c in self.col)}]') - components.append(' > ') - components.append(self.ref_table.name) - components.append(f'[{", ".join(c.name for c in self.ref_col)}]') - return ''.join(components) + ')' - - @property - def sql(self): - ''' - Returns inline SQL of the reference, which should be a part of table definition: - - FOREIGN KEY ("order_id") REFERENCES "orders ("id") - - ''' - self.check_attributes_for_sql() - c = f'CONSTRAINT "{self.name}" ' if self.name else '' - cols = '", "'.join(c.name for c in self.col) - ref_cols = '", "'.join(c.name for c in self.ref_col) - result = ( - f'{c}FOREIGN KEY ("{cols}") ' - f'REFERENCES "{self.ref_table.name}" ("{ref_cols}")' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result - - - - - - - - -class EnumType(Enum): - ''' - Enum object, intended to be put in the `type` attribute of a column. - - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> EnumType('languages', [en, ru]) - - >>> print(_) - languages - ''' - - pass - - - - diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index 7114819..74284d5 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -64,7 +64,7 @@ def __init__(self, self.items = items self.comment = comment - def __getitem__(self, key) -> EnumItem: + def __getitem__(self, key: int) -> EnumItem: return self.items[key] def __iter__(self): diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index 32bbae3..0964806 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -1,7 +1,9 @@ from typing import Any +from .base import SQLOjbect -class Note: + +class Note(SQLOjbect): def __init__(self, text: Any): self.text = str(text) if text else '' diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 896299e..c8279f7 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -1,7 +1,6 @@ from typing import Collection from typing import Literal from typing import Optional -from typing import TYPE_CHECKING from typing import Union from .base import SQLOjbect @@ -25,9 +24,7 @@ class Reference(SQLOjbect): def __init__(self, type_: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE], - # table1: 'Table', col1: Union[Column, Collection[Column]], - # table2: 'Table', col2: Union[Column, Collection[Column]], name: Optional[str] = None, comment: Optional[str] = None, @@ -36,9 +33,7 @@ def __init__(self, inline: bool = False): self.schema = None self.type = type_ - # self.table1 = table1 self.col1 = [col1] if isinstance(col1, Column) else list(col1) - # self.table2 = table2 self.col2 = [col2] if isinstance(col2, Column) else list(col2) self.name = name if name else None self.comment = comment diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index c6eb17b..3bb151c 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -45,31 +45,6 @@ def __init__(self, # self.refs = refs or [] self.comment = comment - # def _build_index_from_blueprint(self, blueprint: IndexBlueprint) -> None: - # subjects = [] - # for subj in blueprint.subject_names: - # if subj.startswith('(') and subj.endswith(')'): - # # subject is an expression, add it as string - # subjects.append(subj) - # else: - # try: - # col = self[subj] - # subjects.append(col) - # except KeyError: - # raise ColumnNotFoundError(f'Cannot add index, column "{subj}" not defined in table "{self.name}".') - # index = Index( - # subjects, - # name=blueprint.name, - # unique=blueprint.unique, - # type_=blueprint.type, - # pk=blueprint.pk, - # note=blueprint.note, - # comment=blueprint.comment - # ) - - # index.table = self - # self.indexes.append(index) - def add_column(self, c: Column) -> None: ''' Adds column to self.columns attribute and sets in this column the @@ -130,28 +105,21 @@ def _get_references_for_sql(self) -> List[Reference]: result.append(ref) return result - - # def add_ref(self, r: TableReference) -> None: - # ''' - # Adds a reference to the table. If reference already present in the table, - # raises DuplicateReferenceError. - # ''' - # if r in self.refs: - # raise DuplicateReferenceError(f'Reference with same endpoints {r} already present in the table.') - # self.refs.append(r) def __getitem__(self, k: Union[int, str]) -> Column: if isinstance(k, int): return self.columns[k] - else: + elif isinstance(k, str): for c in self.columns: if c.name == k: return c raise ColumnNotFoundError(f'Column {k} not present in table {self.name}') + else: + raise TypeError('indeces must be str or int') - def get(self, k, default=None): + def get(self, k, default: Optional[Column] = None) -> Optional[Column]: try: return self.__getitem__(k) - except KeyError: + except (IndexError, ColumnNotFoundError): return default def __iter__(self): diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index ba17b74..a34fdd4 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -37,7 +37,7 @@ def __repr__(self): items = [i if isinstance(i, str) else i.name for i in self.items] return f'' - def __getitem__(self, key) -> str: + def __getitem__(self, key: int) -> str: return self.items[key] def __iter__(self): diff --git a/pydbml/definitions/enum.py b/pydbml/definitions/enum.py index 3ca9751..b6fd993 100644 --- a/pydbml/definitions/enum.py +++ b/pydbml/definitions/enum.py @@ -40,10 +40,9 @@ def parse_enum_item(s, l, t): init_dict = {'name': t['name']} if 'settings' in t: init_dict.update(t['settings']) - - # comments after settings have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] + # comments after settings have priority + if 'comment' in t['settings']: + init_dict['comment'] = t['settings']['comment'] if 'comment' not in init_dict and 'comment_before' in t: comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index a4dc65a..06cc38e 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -183,62 +183,3 @@ def build_schema(self): self.schema.add(self.project.build()) for ref_bp in self.refs: self.schema.add(ref_bp.build()) - - -# class Temp: -# def _parse_table(self, s, l, t): -# table = t[0] -# self.schema.add_table(table) -# for col in table.columns: -# self.ref_blueprints.extend(col.ref_blueprints) - -# def _parse_ref_blueprint(self, s, l, t): -# self.ref_blueprints.append(t[0]) - -# def _parse_enum(self, s, l, t): -# self.schema.add_enum(t[0]) - -# def _parse_table_group(self, s, l, t): -# self.table_groups.append(t[0]) - -# def _parse_project(self, s, l, t): -# self.schema.add_project(t[0]) - -# def _process_refs(self): -# ''' -# Fill up the `refs` attribute with Reference object, created from -# reference blueprints; -# Add TableReference objects to each table which has references. -# Validate refs at the same time. -# ''' -# self.schema._build_refs_from_blueprints(self.ref_blueprints) - -# def _set_enum_types(self): -# enum_dict = {enum.name: enum for enum in self.schema.enums} -# for table_ in self.schema.tables: -# for col in table_: -# col_type = str(col.type).strip('"') -# if col_type in enum_dict: -# col.type = enum_dict[col_type] - -# def _validate(self): -# self._validate_table_groups() - -# def _validate_table_groups(self): -# ''' -# Check that all tables, mentioned in the table groups, exist -# ''' -# for tg in self.table_groups: -# for table_name in tg: -# if table_name not in self.schema.tables_dict: -# raise TableNotFoundError(f'Cannot add Table Group "{tg.name}": table "{table_name}" not found.') - -# def _process_table_groups(self): -# ''' -# Fill up each TableGroup's `item` attribute with references to actual tables. -# ''' -# for tg in self.table_groups: -# tg.items = [self.schema.tables_dict[i] for i in tg.items] -# self.schema.add_table_group(tg) - - diff --git a/pydbml/schema.py b/pydbml/schema.py index a0e9538..0b01bf1 100644 --- a/pydbml/schema.py +++ b/pydbml/schema.py @@ -54,12 +54,12 @@ def add(self, obj: Any) -> Any: raise SchemaValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: + if obj in self.tables: + raise SchemaValidationError(f'{obj} is already in the schema.') if obj.name in self.table_dict: raise SchemaValidationError(f'Table {obj.name} is already in the schema.') if obj.alias and obj.alias in self.table_dict: raise SchemaValidationError(f'Table {obj.alias} is already in the schema.') - if obj in self.tables: - raise SchemaValidationError(f'{obj} is already in the schema.') self._set_schema(obj) @@ -86,22 +86,22 @@ def add_reference(self, obj: Reference): return obj def add_enum(self, obj: Enum) -> Enum: + if obj in self.enums: + raise SchemaValidationError(f'{obj} is already in the schema.') for enum in self.enums: if enum.name == obj.name: raise SchemaValidationError(f'Enum {obj.name} is already in the schema.') - if obj in self.enums: - raise SchemaValidationError(f'{obj} is already in the schema.') self._set_schema(obj) self.enums.append(obj) return obj def add_table_group(self, obj: TableGroup) -> TableGroup: + if obj in self.table_groups: + raise SchemaValidationError(f'{obj} is already in the schema.') for table_group in self.table_groups: if table_group.name == obj.name: raise SchemaValidationError(f'TableGroup {obj.name} is already in the schema.') - if obj in self.table_groups: - raise SchemaValidationError(f'{obj} is already in the schema.') self._set_schema(obj) self.table_groups.append(obj) diff --git a/pydbml/tools.py b/pydbml/tools.py index a1128a8..2e07f64 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -3,12 +3,16 @@ from .classes import Note +def comment(val: str, comb: str) -> str: + return '\n'.join(f'{comb} {cl}' for cl in val.split('\n')) + '\n' + + def comment_to_dbml(val: str) -> str: - return '\n'.join(f'// {cl}' for cl in val.split('\n')) + '\n' + return comment(val, '//') def comment_to_sql(val: str) -> str: - return '\n'.join(f'-- {cl}' for cl in val.split('\n')) + '\n' + return comment(val, '--') def note_option_to_dbml(val: 'Note') -> str: @@ -21,4 +25,4 @@ def note_option_to_dbml(val: 'Note') -> str: def indent(val: str, spaces=4) -> str: if val == '': return val - return ' ' * spaces + val.replace('\n', '\n' +' ' * spaces) + return ' ' * spaces + val.replace('\n', '\n' + ' ' * spaces) diff --git a/test/test_blueprints/__init__.py b/test/test_blueprints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_blueprints/test_reference.py b/test/test_blueprints/test_reference.py index 809c092..493a17b 100644 --- a/test/test_blueprints/test_reference.py +++ b/test/test_blueprints/test_reference.py @@ -42,7 +42,7 @@ def test_build_minimal(self) -> None: self.assertEqual(result.col1, [c1]) self.assertEqual(result.col2, [c2]) - def test_tables_and_cols_are_set(self) -> None: + def test_tables_and_cols_are_not_set(self) -> None: bp = ReferenceBlueprint( type='>', inline=True, @@ -54,6 +54,22 @@ def test_tables_and_cols_are_set(self) -> None: with self.assertRaises(TableNotFoundError): bp.build() + bp.table1 = 'table1' + bp.table2 = None + with self.assertRaises(TableNotFoundError): + bp.build() + + bp.table2 = 'table2' + bp.col1 = None + with self.assertRaises(ColumnNotFoundError): + bp.build() + + bp.col1 = 'col1' + bp.col2 = None + with self.assertRaises(ColumnNotFoundError): + bp.build() + + def test_tables_and_cols_are_set(self) -> None: bp = ReferenceBlueprint( type='>', inline=True, diff --git a/test/test_classes/__init__.py b/test/test_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_classes/test_base.py b/test/test_classes/test_base.py index 7c0d228..fb5b6a1 100644 --- a/test/test_classes/test_base.py +++ b/test/test_classes/test_base.py @@ -31,3 +31,4 @@ def test_comparison(self) -> None: self.assertTrue(o1 == o2) o1.a2 = True self.assertFalse(o1 == o2) + self.assertFalse(o1 == 123) diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index 1889226..42a5420 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -4,9 +4,52 @@ from pydbml.classes import Column from pydbml.classes import Table from pydbml.classes import Reference +from pydbml.classes import Note +from pydbml.exceptions import TableNotFoundError class TestColumn(TestCase): + def test_attributes(self) -> None: + name = 'name' + type_ = 'type' + unique = True + not_null = True + pk = True + autoinc = True + default = '1' + note = Note('note') + comment = 'comment' + col = Column( + name=name, + type_=type_, + unique=unique, + not_null=not_null, + pk=pk, + autoinc=autoinc, + default=default, + note=note, + comment=comment, + ) + self.assertEqual(col.name, name) + self.assertEqual(col.type, type_) + self.assertEqual(col.unique, unique) + self.assertEqual(col.not_null, not_null) + self.assertEqual(col.pk, pk) + self.assertEqual(col.autoinc, autoinc) + self.assertEqual(col.default, default) + self.assertEqual(col.note, note) + self.assertEqual(col.comment, comment) + + def test_schema_set(self) -> None: + col = Column('name', 'int') + table = Table('name') + self.assertIsNone(col.schema) + table.add_column(col) + self.assertIsNone(col.schema) + schema = Schema() + schema.add(table) + self.assertIs(col.schema, schema) + def test_basic_sql(self) -> None: r = Column(name='id', type_='integer') @@ -152,6 +195,8 @@ def test_schema(self): def test_get_refs(self) -> None: c1 = Column(name='client_id', type_='integer') + with self.assertRaises(TableNotFoundError): + c1.get_refs() t1 = Table(name='products') t1.add_column(c1) c2 = Column(name='id', type_='integer', autoinc=True, pk=True) diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py index a0f8563..c49f51e 100644 --- a/test/test_classes/test_enum.py +++ b/test/test_classes/test_enum.py @@ -88,3 +88,34 @@ def test_dbml_full(self): "en-GB" }''' self.assertEqual(e.dbml, expected) + + def test_getitem(self) -> None: + ei = EnumItem('created') + items = [ + EnumItem('running'), + ei, + EnumItem('donef'), + EnumItem('failure'), + ] + e = Enum('job_status', items) + self.assertIs(e[1], ei) + with self.assertRaises(IndexError): + e[22] + with self.assertRaises(TypeError): + e['abc'] + + def test_iter(self) -> None: + ei1 = EnumItem('created') + ei2 = EnumItem('running') + ei3 = EnumItem('donef') + ei4 = EnumItem('failure') + items = [ + ei1, + ei2, + ei3, + ei4, + ] + e = Enum('job_status', items) + + for i1, i2 in zip(e, [ei1, ei2, ei3, ei4]): + self.assertIs(i1, i2) diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py index 00fa5c4..8e3c6b1 100644 --- a/test/test_classes/test_note.py +++ b/test/test_classes/test_note.py @@ -48,3 +48,15 @@ def test_forced_multiline(self): ''' }""" self.assertEqual(note.dbml, expected) + + def test_sql(self) -> None: + note1 = Note(None) + self.assertEqual(note1.sql, '') + note2 = Note('One line of note text') + self.assertEqual(note2.sql, '-- One line of note text') + note3 = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + expected = \ +"""-- The number of spaces you use to indent a block string +-- will +-- be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.""" + self.assertEqual(note3.sql, expected) diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index c2cd798..511369a 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -259,18 +259,32 @@ def test_dbml_composite_inline_ref_forbidden(self): ref.dbml def test_validate_different_tables(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) + t1 = Table('products') + c11 = Column('id', 'integer') + c12 = Column('name', 'varchar2') + t1.add_column(c11) + t1.add_column(c12) t2 = Table('names') c21 = Column('name_val', 'varchar2') + c22 = Column('product', 'varchar2') t2.add_column(c21) ref = Reference( '<', - [c2, c21], - [c1], + [c12, c21], + [c21], + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL', + inline=True + ) + with self.assertRaises(DBMLError): + ref._validate() + + ref = Reference( + '<', + [c11, c12], + [c21, c12], name='nameref', comment='Reference comment\nmultiline', on_update='CASCADE', @@ -292,6 +306,11 @@ def test_validate_no_table(self): ) with self.assertRaises(TableNotFoundError): ref1._validate() + table = Table('name') + table.add_column(c1) + with self.assertRaises(TableNotFoundError): + ref1._validate() + table.delete_column(c1) ref2 = Reference( '<', @@ -300,3 +319,8 @@ def test_validate_no_table(self): ) with self.assertRaises(TableNotFoundError): ref2._validate() + table = Table('name') + table.add_column(c1) + table.add_column(c2) + with self.assertRaises(TableNotFoundError): + ref2._validate() diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 2356230..37b4805 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -7,6 +7,7 @@ from pydbml.classes import Table from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import IndexNotFoundError +from pydbml.exceptions import UnknownSchemaError from pydbml.schema import Schema @@ -20,6 +21,49 @@ def test_one_column(self) -> None: expected = 'CREATE TABLE "products" (\n "id" integer\n);' self.assertEqual(t.sql, expected) + def test_getitem(self) -> None: + t = Table('products') + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + self.assertIs(t['col1'], c1) + self.assertIs(t[1], c2) + with self.assertRaises(IndexError): + t[22] + with self.assertRaises(TypeError): + t[None] + with self.assertRaises(ColumnNotFoundError): + t['wrong'] + + def test_get(self) -> None: + t = Table('products') + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t.add_column(c1) + t.add_column(c2) + self.assertIs(t.get(0), c1) + self.assertIs(t.get('col2'), c2) + self.assertIsNone(t.get('wrong')) + self.assertIsNone(t.get(22)) + self.assertIs(t.get('wrong', c2), c2) + self.assertIs(t.get(22, c2), c2) + self.assertIs(t.get('wrong', c3), c3) + + def test_iter(self) -> None: + t = Table('products') + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + for i1, i2 in zip(t, [c1, c2, c3]): + self.assertIs(i1, i2) + def test_ref(self) -> None: t = Table('products') c1 = Column('id', 'integer') @@ -49,22 +93,6 @@ def test_ref(self) -> None: );''' self.assertEqual(t.sql, expected) -# def test_duplicate_ref(self) -> None: -# t = Table('products') -# c1 = Column('id', 'integer') -# c2 = Column('name', 'varchar2') -# t.add_column(c1) -# t.add_column(c2) -# t2 = Table('names') -# c21 = Column('name_val', 'varchar2') -# t2.add_column(c21) -# r1 = TableReference(c2, t2, c21) -# t.add_ref(r1) -# r2 = TableReference(c2, t2, c21) -# self.assertEqual(r1, r2) -# with self.assertRaises(DuplicateReferenceError): -# t.add_ref(r2) - def test_notes(self) -> None: n = Note('Table note') nc1 = Note('First column note') @@ -221,6 +249,8 @@ def test_delete_index(self) -> None: def test_get_references_for_sql(self): t = Table('products') + with self.assertRaises(UnknownSchemaError): + t._get_references_for_sql() c11 = Column('id', 'integer') c12 = Column('name', 'varchar2') t.add_column(c11) @@ -247,6 +277,8 @@ def test_get_references_for_sql(self): def test_get_refs(self): t = Table('products') + with self.assertRaises(UnknownSchemaError): + t.get_refs() c11 = Column('id', 'integer') c12 = Column('name', 'varchar2') t.add_column(c11) diff --git a/test/test_classes/test_table_group.py b/test/test_classes/test_table_group.py index 033582b..75d05a1 100644 --- a/test/test_classes/test_table_group.py +++ b/test/test_classes/test_table_group.py @@ -33,3 +33,28 @@ def test_dbml_with_comment_and_real_tables(self): customers }''' self.assertEqual(tg.dbml, expected) + + def test_getitem(self) -> None: + merchants = Table('merchants') + countries = Table('countries') + customers = Table('customers') + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + comment='My table group\nmultiline comment' + ) + self.assertIs(tg[1], countries) + with self.assertRaises(IndexError): + tg[22] + + def test_iter(self) -> None: + merchants = Table('merchants') + countries = Table('countries') + customers = Table('customers') + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + comment='My table group\nmultiline comment' + ) + for i1, i2 in zip(tg, [merchants, countries, customers]): + self.assertIs(i1, i2) diff --git a/test/test_generate_dbml.py b/test/test_generate_dbml.py new file mode 100644 index 0000000..84d9e52 --- /dev/null +++ b/test/test_generate_dbml.py @@ -0,0 +1,75 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Project +from pydbml.classes import Reference +from pydbml.classes import Index +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.classes import Note +from pydbml.exceptions import SchemaValidationError +from pydbml.schema import Schema + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class TestGenerateDBML(TestCase): + def test_generate_dbml(self) -> None: + schema = Schema() + emp_level = Enum( + 'level', + [ + EnumItem('junior'), + EnumItem('middle'), + EnumItem('senior'), + ] + ) + schema.add(emp_level) + + t1 = Table('Employees', alias='emp') + c11 = Column('id', 'integer', pk=True, autoinc=True) + c12 = Column('name', 'varchar', note=Note('Full employee name')) + c13 = Column('age', 'number', default=0) + c14 = Column('level', 'level') + c15 = Column('favorite_book_id', 'integer') + t1.add_column(c11) + t1.add_column(c12) + t1.add_column(c13) + t1.add_column(c14) + t1.add_column(c15) + schema.add(t1) + + t2 = Table('books') + c21 = Column('id', 'integer', pk=True, autoinc=True) + c22 = Column('title', 'varchar') + c23 = Column('author', 'varchar') + c24 = Column('country_id', 'integer') + t2.add_column(c21) + t2.add_column(c22) + t2.add_column(c23) + t2.add_column(c24) + schema.add(t2) + + t3 = Table('countries') + c31 = Column('id', 'integer', pk=True, autoinc=True) + c32 = Column('name', 'varchar2', unique=True) + t3.add_column(c31) + t3.add_column(c32) + i31 = Index([c32]) + t3.add_index(i31) + # TODO: expression class + # i32 = Index(['']) + schema.add(t3) + + ref1 = Reference('>', c15, c21) + schema.add(ref1) + + ref2 = Reference('<', c31, c24, name='Country Reference') + schema.add(ref2) + diff --git a/test/test_schema.py b/test/test_schema.py index c45653d..213e3ee 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -259,7 +259,7 @@ def test_delete_project(self) -> None: self.assertIs(res, p) self.assertIsNone(schema.project) - def delete_missing_project(self) -> None: + def test_delete_missing_project(self) -> None: schema = Schema() with self.assertRaises(SchemaValidationError): schema.delete_project() @@ -318,27 +318,41 @@ class Test: def test_delete(self) -> None: t1 = Table('table1') + c1 = Column('col1', 'int') + t1.add_column(c1) t2 = Table('table2') + c2 = Column('col2', 'int') + t2.add_column(c2) + ref = Reference('>', [c1], [c2]) tg = TableGroup('mytablegroup', [t1, t2]) e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + p = Project('myproject') schema = Schema() schema.add(t1) schema.add(t2) schema.add(e) schema.add(tg) + schema.add(ref) + schema.add(p) schema.delete(t1) schema.delete(t2) schema.delete(e) schema.delete(tg) + schema.delete(ref) + schema.delete(p) self.assertIsNone(t1.schema) self.assertIsNone(t2.schema) self.assertIsNone(e.schema) self.assertIsNone(tg.schema) + self.assertIsNone(ref.schema) + self.assertIsNone(p.schema) + self.assertIsNone(schema.project) self.assertNotIn(t1, schema.tables) self.assertNotIn(t2, schema.tables) self.assertNotIn(tg, schema.table_groups) self.assertNotIn(e, schema.enums) + self.assertNotIn(ref, schema.refs) def test_delete_bad(self) -> None: class Test: diff --git a/test/test_tools.py b/test/test_tools.py new file mode 100644 index 0000000..6ba4abc --- /dev/null +++ b/test/test_tools.py @@ -0,0 +1,64 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql +from pydbml.tools import note_option_to_dbml +from pydbml.tools import indent + + +class TestCommentToDBML(TestCase): + def test_comment(self) -> None: + oneline = 'comment' + self.assertEqual(f'// {oneline}\n', comment_to_dbml(oneline)) + + expected = \ +'''// +// line1 +// line2 +// line3 +// +''' + source = '\nline1\nline2\nline3\n' + self.assertEqual(comment_to_dbml(source), expected) + + +class TestCommentToSQL(TestCase): + def test_comment(self) -> None: + oneline = 'comment' + self.assertEqual(f'-- {oneline}\n', comment_to_sql(oneline)) + + expected = \ +'''-- +-- line1 +-- line2 +-- line3 +-- +''' + source = '\nline1\nline2\nline3\n' + self.assertEqual(comment_to_sql(source), expected) + + +class TestNoteOptionToDBML(TestCase): + def test_oneline(self) -> None: + note = Note('one line note') + self.assertEqual(f"note: 'one line note'", note_option_to_dbml(note)) + + def test_multiline(self) -> None: + note = Note('line1\nline2\nline3') + expected = "note: '''line1\nline2\nline3'''" + self.assertEqual(expected, note_option_to_dbml(note)) + + +class TestIndent(TestCase): + def test_empty(self) -> None: + self.assertEqual(indent(''), '') + + def test_nonempty(self) -> None: + oneline = 'one line text' + self.assertEqual(indent(oneline), f' {oneline}') + source = 'line1\nline2\nline3' + expected = ' line1\n line2\n line3' + self.assertEqual(indent(source), expected) + expected2 = ' line1\n line2\n line3' + self.assertEqual(indent(source, 2), expected2) From e9a5413c896ef32428f443f24f639d5c810032c0 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 14 May 2022 18:25:24 +0200 Subject: [PATCH 026/125] add doctests to test suite --- pydbml/__init__.py | 10 +++++----- pydbml/classes/index.py | 26 ++++++++++++++++---------- pydbml/classes/reference.py | 14 ++++++-------- pydbml/schema.py | 18 ++++++++++++++++-- test/test_doctest.py | 25 +++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 25 deletions(-) create mode 100644 test/test_doctest.py diff --git a/pydbml/__init__.py b/pydbml/__init__.py index 1ea4a15..558c211 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,5 +1,5 @@ -import doctest -import unittest +# import doctest +# import unittest from . import classes from .parser import PyDBML @@ -9,6 +9,6 @@ from pydbml.constants import ONE_TO_ONE -def load_tests(loader, tests, ignore): - tests.addTests(doctest.DocTestSuite(classes)) - return tests +# def load_tests(loader, tests, ignore): +# tests.addTests(doctest.DocTestSuite(classes)) +# return tests diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index a09d63e..01b56b0 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -42,11 +42,14 @@ def subject_names(self): def __repr__(self): ''' - >>> Index(['name', 'type']) - - >>> t = Table('t') - >>> Index(['name', 'type'], table=t) - + >>> c = Column('col', 'int') + >>> i = Index([c, '(c*2)']) + >>> i + + >>> from .table import Table + >>> Table('test').add_index(i) + >>> i + ''' table_name = self.table.name if self.table else None @@ -54,11 +57,14 @@ def __repr__(self): def __str__(self): ''' - >>> print(Index(['name', 'type'])) - Index([name, type]) - >>> t = Table('t') - >>> print(Index(['name', 'type'], table=t)) - Index(t[name, type]) + >>> c = Column('col', 'int') + >>> i = Index([c, '(c*2)']) + >>> print(i) + Index([col, (c*2)]) + >>> from .table import Table + >>> Table('test').add_index(i) + >>> print(i) + Index(test[col, (c*2)]) ''' table_name = self.table.name if self.table else '' diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index c8279f7..356d3e7 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -61,18 +61,16 @@ def __str__(self): ''' >>> c1 = Column('c1', 'int') >>> c2 = Column('c2', 'int') - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> print(Reference('>', table1=t1, col1=c1, table2=t2, col2=c2)) - Reference(t1[c1] > t2[c2]) + >>> print(Reference('>', col1=c1, col2=c2)) + Reference([c1] > [c2] >>> c12 = Column('c12', 'int') >>> c22 = Column('c22', 'int') - >>> print(Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22))) - Reference(t1[c1, c12] < t2[c2, c22]) + >>> print(Reference('<', col1=[c1, c12], col2=(c2, c22))) + Reference([c1, c12] < [c2, c22] ''' - col1 = ', '.join(f'{c.name!r}' for c in self.col1) - col2 = ', '.join(f'{c.name!r}' for c in self.col2) + col1 = ', '.join(f'{c.name}' for c in self.col1) + col2 = ', '.join(f'{c.name}' for c in self.col2) return f"Reference([{col1}] {self.type} [{col2}]" def _validate(self): diff --git a/pydbml/schema.py b/pydbml/schema.py index 0b01bf1..824ca99 100644 --- a/pydbml/schema.py +++ b/pydbml/schema.py @@ -13,7 +13,8 @@ class Schema: - def __init__(self) -> None: + def __init__(self, name: str = 'public') -> None: + self.name = name self.tables: List['Table'] = [] self.table_dict: Dict[str, 'Table'] = {} self.refs: List['Reference'] = [] @@ -22,7 +23,20 @@ def __init__(self) -> None: self.project: Optional['Project'] = None def __repr__(self) -> str: - return f"" + """ + >>> Schema("private") + + """ + + return f"" + + def __str__(self) -> str: + """ + >>> print(Schema("private")) + + """ + + return f"" def __getitem__(self, k: Union[int, str]) -> Table: if isinstance(k, int): diff --git a/test/test_doctest.py b/test/test_doctest.py new file mode 100644 index 0000000..a7a5341 --- /dev/null +++ b/test/test_doctest.py @@ -0,0 +1,25 @@ +import doctest +import unittest + +from pydbml import schema +from pydbml.classes import column +from pydbml.classes import enum +from pydbml.classes import index +from pydbml.classes import note +from pydbml.classes import project +from pydbml.classes import reference +from pydbml.classes import table +from pydbml.classes import table_group + + +def load_tests(loader, tests, ignore): + tests.addTests(doctest.DocTestSuite(column)) + tests.addTests(doctest.DocTestSuite(enum)) + tests.addTests(doctest.DocTestSuite(index)) + tests.addTests(doctest.DocTestSuite(project)) + tests.addTests(doctest.DocTestSuite(note)) + tests.addTests(doctest.DocTestSuite(reference)) + tests.addTests(doctest.DocTestSuite(schema)) + tests.addTests(doctest.DocTestSuite(table)) + tests.addTests(doctest.DocTestSuite(table_group)) + return tests From d68dc97ef18fc644f9b7b0144aca2172817d5a3f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 14 May 2022 19:14:51 +0200 Subject: [PATCH 027/125] Expression is now class --- changelog.md | 1 + pydbml/classes/__init__.py | 3 ++- pydbml/classes/expression.py | 30 +++++++++++++++++++++++++ pydbml/definitions/generic.py | 4 +++- pydbml/parser/blueprints.py | 25 +++++++++++++++------ test/test_blueprints/test_expression.py | 12 ++++++++++ test/test_blueprints/test_table.py | 5 ++++- test/test_classes/test_expression.py | 12 ++++++++++ test/test_definitions/test_column.py | 10 ++++++--- test/test_definitions/test_generic.py | 17 ++++++++++++++ test/test_definitions/test_index.py | 11 ++++++--- test/test_docs.py | 18 ++++++++++----- test/test_doctest.py | 2 ++ 13 files changed, 128 insertions(+), 22 deletions(-) create mode 100644 pydbml/classes/expression.py create mode 100644 test/test_blueprints/test_expression.py create mode 100644 test/test_classes/test_expression.py create mode 100644 test/test_definitions/test_generic.py diff --git a/changelog.md b/changelog.md index 51abe17..9c9b880 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,7 @@ - refs don't have tables, only columns - tables don't have refs - col1 col2 in ref are as they were in dbml +- Expression class # 0.4.2 diff --git a/pydbml/classes/__init__.py b/pydbml/classes/__init__.py index 4ff08f2..bfdc6e2 100644 --- a/pydbml/classes/__init__.py +++ b/pydbml/classes/__init__.py @@ -1,9 +1,10 @@ from .column import Column -from .table import Table from .enum import Enum from .enum import EnumItem +from .expression import Expression from .index import Index from .note import Note from .project import Project from .reference import Reference +from .table import Table from .table_group import TableGroup diff --git a/pydbml/classes/expression.py b/pydbml/classes/expression.py new file mode 100644 index 0000000..5dcd90a --- /dev/null +++ b/pydbml/classes/expression.py @@ -0,0 +1,30 @@ +from .base import SQLOjbect + + +class Expression(SQLOjbect): + def __init__(self, text: str): + self.text = text + + def __str__(self) -> str: + ''' + >>> print(Expression('sum(amount)')) + sum(amount) + ''' + + return self.text + + def __repr__(self) -> str: + ''' + >>> Expression('sum(amount)') + Expression('sum(amount)') + ''' + + return f'Expression({repr(self.text)})' + + @property + def sql(self) -> str: + return f'({self.text})' + + @property + def dbml(self) -> str: + return f'`{self.text}`' diff --git a/pydbml/definitions/generic.py b/pydbml/definitions/generic.py index 1bb5650..ca05311 100644 --- a/pydbml/definitions/generic.py +++ b/pydbml/definitions/generic.py @@ -1,5 +1,7 @@ import pyparsing as pp +from pydbml.parser.blueprints import ExpressionBlueprint + pp.ParserElement.setDefaultWhitespaceChars(' \t\r') name = pp.Word(pp.alphanums + '_') | pp.QuotedString('"') @@ -15,7 +17,7 @@ pp.Suppress('`') + pp.CharsNotIn('`')[...] + pp.Suppress('`') -).setParseAction(lambda s, l, t: f'({t[0]})') +).setParseAction(lambda s, l, t: ExpressionBlueprint(t[0])) boolean_literal = ( pp.CaselessLiteral('true') diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index b078a0e..fa33964 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -6,18 +6,19 @@ from typing import Optional from typing import Union -from pydbml.constants import MANY_TO_ONE -from pydbml.constants import ONE_TO_MANY -from pydbml.constants import ONE_TO_ONE from pydbml.classes import Column from pydbml.classes import Enum from pydbml.classes import EnumItem +from pydbml.classes import Expression from pydbml.classes import Index from pydbml.classes import Note from pydbml.classes import Project from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup +from pydbml.constants import MANY_TO_ONE +from pydbml.constants import ONE_TO_MANY +from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError @@ -34,6 +35,14 @@ def build(self) -> 'Note': return Note(self.text) +@dataclass +class ExpressionBlueprint(Blueprint): + text: str + + def build(self) -> Expression: + return Expression(self.text) + + @dataclass class ReferenceBlueprint(Blueprint): type: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE] @@ -89,12 +98,14 @@ class ColumnBlueprint(Blueprint): not_null: bool = False pk: bool = False autoinc: bool = False - default: Optional[Union[str, int, bool, float]] = None + default: Optional[Union[str, int, bool, float, ExpressionBlueprint]] = None note: Optional[NoteBlueprint] = None ref_blueprints: Optional[List[ReferenceBlueprint]] = None comment: Optional[str] = None def build(self) -> 'Column': + if isinstance(self.default, ExpressionBlueprint): + self.default = self.default.build() if self.parser: for enum in self.parser.schema.enums: if enum.name == self.type: @@ -115,7 +126,7 @@ def build(self) -> 'Column': @dataclass class IndexBlueprint(Blueprint): - subject_names: List[str] + subject_names: List[Union[str, ExpressionBlueprint]] name: Optional[str] = None unique: bool = False type: Optional[str] = None @@ -164,8 +175,8 @@ def build(self) -> 'Table': index = index_bp.build() new_subjects = [] for subj in index.subjects: - if subj.startswith('(') and subj.endswith(')'): - new_subjects.append(subj) + if isinstance(subj, ExpressionBlueprint): + new_subjects.append(subj.build()) else: for col in result.columns: if col.name == subj: diff --git a/test/test_blueprints/test_expression.py b/test/test_blueprints/test_expression.py new file mode 100644 index 0000000..18ac8dc --- /dev/null +++ b/test/test_blueprints/test_expression.py @@ -0,0 +1,12 @@ +from unittest import TestCase + +from pydbml.classes import Expression +from pydbml.parser.blueprints import ExpressionBlueprint + + +class TestNote(TestCase): + def test_build(self) -> None: + bp = ExpressionBlueprint(text='amount*2') + result = bp.build() + self.assertIsInstance(result, Expression) + self.assertEqual(result.text, bp.text) diff --git a/test/test_blueprints/test_table.py b/test/test_blueprints/test_table.py index 890ab33..4a566a9 100644 --- a/test/test_blueprints/test_table.py +++ b/test/test_blueprints/test_table.py @@ -5,11 +5,13 @@ from pydbml.classes import Table from pydbml.classes import Index from pydbml.classes import Column +from pydbml.classes import Expression from pydbml.parser.blueprints import IndexBlueprint from pydbml.parser.blueprints import NoteBlueprint from pydbml.parser.blueprints import ColumnBlueprint from pydbml.parser.blueprints import TableBlueprint from pydbml.parser.blueprints import ReferenceBlueprint +from pydbml.parser.blueprints import ExpressionBlueprint class TestTable(TestCase): @@ -59,7 +61,7 @@ def test_with_indexes(self) -> None: ], indexes=[ IndexBlueprint(subject_names=['name', 'id'], unique=True), - IndexBlueprint(subject_names=['id', '(name*2)'], name='ExprIndex') + IndexBlueprint(subject_names=['id', ExpressionBlueprint('name*2')], name='ExprIndex') ] ) result = bp.build() @@ -69,6 +71,7 @@ def test_with_indexes(self) -> None: self.assertIsInstance(col, Column) for ind in result.indexes: self.assertIsInstance(ind, Index) + self.assertIsInstance(result.indexes[1].subjects[1], Expression) def test_bad_index(self) -> None: bp = TableBlueprint( diff --git a/test/test_classes/test_expression.py b/test/test_classes/test_expression.py new file mode 100644 index 0000000..5990f3e --- /dev/null +++ b/test/test_classes/test_expression.py @@ -0,0 +1,12 @@ +from pydbml.classes import Expression +from unittest import TestCase + + +class TestNote(TestCase): + def test_sql(self): + e = Expression('SUM(amount)') + self.assertEqual(e.sql, '(SUM(amount))') + + def test_dbml(self): + e = Expression('SUM(amount)') + self.assertEqual(e.dbml, '`SUM(amount)`') diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index 98ef3e7..963be43 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -4,6 +4,7 @@ from pyparsing import ParseSyntaxException from pyparsing import ParserElement +from pydbml.parser.blueprints import ExpressionBlueprint from pydbml.definitions.column import column_setting from pydbml.definitions.column import column_settings from pydbml.definitions.column import column_type @@ -59,11 +60,14 @@ def test_expression(self) -> None: val2 = f"default: `{expr2}`" val3 = f"default: ``" res = default.parseString(val, parseAll=True) - self.assertEqual(res[0], f'({expr1})') + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, expr1) res = default.parseString(val2, parseAll=True) - self.assertEqual(res[0], f'({expr2})') + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, expr2) res = default.parseString(val3, parseAll=True) - self.assertEqual(res[0], '()') + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, '') def test_bool(self) -> None: vals = ['true', 'false', 'null'] diff --git a/test/test_definitions/test_generic.py b/test/test_definitions/test_generic.py new file mode 100644 index 0000000..8570515 --- /dev/null +++ b/test/test_definitions/test_generic.py @@ -0,0 +1,17 @@ +from unittest import TestCase + +from pyparsing import ParserElement + +from pydbml.definitions.generic import expression_literal +from pydbml.parser.blueprints import ExpressionBlueprint + + +ParserElement.setDefaultWhitespaceChars(' \t\r') + + +class TestExpressionLiteral(TestCase): + def test_expression_literal(self) -> None: + val = '`SUM(amount)`' + res = expression_literal.parseString(val) + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, 'SUM(amount)') diff --git a/test/test_definitions/test_index.py b/test/test_definitions/test_index.py index 0dc432b..318d192 100644 --- a/test/test_definitions/test_index.py +++ b/test/test_definitions/test_index.py @@ -12,6 +12,7 @@ from pydbml.definitions.index import indexes from pydbml.definitions.index import single_index_syntax from pydbml.definitions.index import subject +from pydbml.parser.blueprints import ExpressionBlueprint ParserElement.setDefaultWhitespaceChars(' \t\r') @@ -105,7 +106,8 @@ def test_name(self) -> None: def test_expression(self) -> None: val = '`id*3`' res = subject.parseString(val, parseAll=True) - self.assertEqual(res[0], '(id*3)') + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, 'id*3') def test_wrong(self) -> None: val = '12d*(' @@ -165,7 +167,8 @@ def test_single(self) -> None: def test_expression(self) -> None: val = '(`id*3`)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['(id*3)']) + self.assertIsInstance(res[0].subject_names[0], ExpressionBlueprint) + self.assertEqual(res[0].subject_names[0].text, 'id*3') def test_composite(self) -> None: val = '(my_column, my_another_column)' @@ -175,7 +178,9 @@ def test_composite(self) -> None: def test_composite_with_expression(self) -> None: val = '(`id*3`, fieldname)' res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['(id*3)', 'fieldname']) + self.assertIsInstance(res[0].subject_names[0], ExpressionBlueprint) + self.assertEqual(res[0].subject_names[0].text, 'id*3') + self.assertEqual(res[0].subject_names[1], 'fieldname') def test_with_settings(self) -> None: val = '(my_column, my_another_column) [unique]' diff --git a/test/test_docs.py b/test/test_docs.py index 1816a1c..66191cb 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -9,8 +9,7 @@ from unittest import TestCase from pydbml import PyDBML -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import TableNotFoundError +from pydbml.classes import Expression TEST_DOCS_PATH = Path(os.path.abspath(__file__)).parent / 'test_data/docs' @@ -95,7 +94,8 @@ def test_default_value(self) -> None: self.assertEqual([c.name for c in table.columns], ['id', 'username', 'full_name', 'gender', 'created_at', 'rating']) *_, gender, created_at, rating = table.columns self.assertEqual(gender.default, 'm') - self.assertEqual(created_at.default, '(now())') + self.assertIsInstance(created_at.default, Expression) + self.assertEqual(created_at.default.text, 'now()') self.assertEqual(rating.default, 10) def test_index_definition(self) -> None: @@ -124,11 +124,17 @@ def test_index_definition(self) -> None: self.assertEqual(ix[4].subjects, [table['booking_date']]) self.assertEqual(ix[4].type, 'hash') - self.assertEqual(ix[5].subjects, ['(id*2)']) + self.assertEqual(len(ix[5].subjects), 1) + self.assertIsInstance(ix[5].subjects[0], Expression) + self.assertEqual(ix[5].subjects[0].text, 'id*2') - self.assertEqual(ix[6].subjects, ['(id*3)', '(getdate())']) + self.assertEqual(len(ix[6].subjects), 2) + self.assertIsInstance(ix[6].subjects[0], Expression) + self.assertIsInstance(ix[6].subjects[1], Expression) + self.assertEqual(ix[6].subjects[0].text, 'id*3') + self.assertEqual(ix[6].subjects[1].text, 'getdate()') - self.assertEqual(ix[7].subjects, ['(id*3)', table['id']]) + self.assertEqual(ix[7].subjects, [Expression('id*3'), table['id']]) def test_relationships(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'relationships_1.dbml') diff --git a/test/test_doctest.py b/test/test_doctest.py index a7a5341..b4122a1 100644 --- a/test/test_doctest.py +++ b/test/test_doctest.py @@ -4,6 +4,7 @@ from pydbml import schema from pydbml.classes import column from pydbml.classes import enum +from pydbml.classes import expression from pydbml.classes import index from pydbml.classes import note from pydbml.classes import project @@ -15,6 +16,7 @@ def load_tests(loader, tests, ignore): tests.addTests(doctest.DocTestSuite(column)) tests.addTests(doctest.DocTestSuite(enum)) + tests.addTests(doctest.DocTestSuite(expression)) tests.addTests(doctest.DocTestSuite(index)) tests.addTests(doctest.DocTestSuite(project)) tests.addTests(doctest.DocTestSuite(note)) From b3ecfdf6cd76160e10d6f6c6ffec8d948fc568ad Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 15 May 2022 10:47:39 +0200 Subject: [PATCH 028/125] 100% test coverage --- pydbml/classes/column.py | 15 +++-- pydbml/classes/index.py | 44 +++++++++----- pydbml/classes/project.py | 2 +- pydbml/classes/reference.py | 3 +- pydbml/classes/table.py | 4 +- pydbml/classes/table_group.py | 4 +- pydbml/definitions/index.py | 3 +- pydbml/parser/__init__.py | 2 +- pydbml/parser/blueprints.py | 21 ++++--- pydbml/parser/parser.py | 56 +++++++++-------- pydbml/schema.py | 8 ++- pydbml/tools.py | 8 ++- test.sh | 2 +- test/test_blueprints/test_index.py | 4 +- test/test_blueprints/test_reference.py | 3 + test/test_blueprints/test_table_group.py | 3 + test/test_classes/test_column.py | 3 +- test/test_classes/test_index.py | 18 ++++-- test/test_classes/test_project.py | 12 +++- test/test_classes/test_table.py | 3 +- test/test_data/integration1.dbml | 44 ++++++++++++++ test/test_data/integration1.sql | 34 +++++++++++ test/test_doctest.py | 2 + ...t_generate_dbml.py => test_integration.py} | 60 +++++++++++++++++-- test/test_parser.py | 10 ++++ test/test_schema.py | 3 + 26 files changed, 286 insertions(+), 85 deletions(-) create mode 100644 test/test_data/integration1.dbml create mode 100644 test/test_data/integration1.sql rename test/{test_generate_dbml.py => test_integration.py} (52%) diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 13e8dab..9f39c74 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -5,12 +5,13 @@ from .base import SQLOjbect from .note import Note +from .expression import Expression from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.tools import note_option_to_dbml from pydbml.exceptions import TableNotFoundError -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from .table import Table from .reference import Reference @@ -27,7 +28,7 @@ def __init__(self, not_null: bool = False, pk: bool = False, autoinc: bool = False, - default: Optional[Union[str, int, bool, float]] = None, + default: Optional[Union[str, int, bool, float, Expression]] = None, note: Optional[Union['Note', str]] = None, # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, comment: Optional[str] = None): @@ -74,7 +75,9 @@ def sql(self): if self.not_null: components.append('NOT NULL') if self.default is not None: - components.append('DEFAULT ' + str(self.default)) + default = self.default.sql \ + if isinstance(self.default, Expression) else self.default + components.append(f'DEFAULT {default}') result = comment_to_sql(self.comment) if self.comment else '' result += ' '.join(components) @@ -82,14 +85,14 @@ def sql(self): @property def dbml(self): - def default_to_str(val: str) -> str: + def default_to_str(val: Union[Expression, str]) -> str: if isinstance(val, str): if val.lower() in ('null', 'true', 'false'): return val.lower() - elif val.startswith('(') and val.endswith(')'): - return f'`{val[1:-1]}`' else: return f"'{val}'" + elif isinstance(val, Expression): + return val.dbml else: # int or float or bool return val diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index 01b56b0..ef5022f 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -1,21 +1,26 @@ from typing import Optional from typing import Union from typing import List +from typing import TYPE_CHECKING from .base import SQLOjbect from .note import Note from .column import Column +from .expression import Expression from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.tools import note_option_to_dbml +if TYPE_CHECKING: # pragma: no cover + from .table import Table + class Index(SQLOjbect): '''Class representing index.''' required_attributes = ('subjects', 'table') def __init__(self, - subjects: List[Union[str, 'Column']], + subjects: List[Union[str, 'Column', 'Expression']], name: Optional[str] = None, unique: bool = False, type_: Optional[str] = None, @@ -24,7 +29,7 @@ def __init__(self, comment: Optional[str] = None): self.schema = None self.subjects = subjects - self.table = None + self.table: Optional['Table'] = None self.name = name if name else None self.unique = unique @@ -36,9 +41,9 @@ def __init__(self, @property def subject_names(self): ''' - For backward compatibility. Returns updated list of subject names. + Returns updated list of subject names. ''' - return [s.name if isinstance(s, Column) else s for s in self.subjects] + return [s.name if isinstance(s, Column) else str(s) for s in self.subjects] def __repr__(self): ''' @@ -86,7 +91,16 @@ def sql(self): ''' self.check_attributes_for_sql() - keys = ', '.join(f'"{key.name}"' if isinstance(key, Column) else key for key in self.subjects) + subjects = [] + + for subj in self.subjects: + if isinstance(subj, Column): + subjects.append(f'"{subj.name}"') + elif isinstance(subj, Expression): + subjects.append(subj.sql) + else: + subjects.append(subj) + keys = ', '.join(subj for subj in subjects) if self.pk: result = comment_to_sql(self.comment) if self.comment else '' result += f'PRIMARY KEY ({keys})' @@ -108,20 +122,22 @@ def sql(self): @property def dbml(self): - def subject_to_str(val: str) -> str: - if val.startswith('(') and val.endswith(')'): - return f'`{val[1:-1]}`' + subjects = [] + + for subj in self.subjects: + if isinstance(subj, Column): + subjects.append(subj.name) + elif isinstance(subj, Expression): + subjects.append(subj.dbml) else: - return val + subjects.append(subj) result = comment_to_dbml(self.comment) if self.comment else '' - subject_names = self.subject_names - - if len(subject_names) > 1: - result += f'({", ".join(subject_to_str(sn) for sn in subject_names)})' + if len(subjects) > 1: + result += f'({", ".join(subj for subj in subjects)})' else: - result += subject_to_str(subject_names[0]) + result += subjects[0] options = [] if self.name: diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py index 5f70020..0661813 100644 --- a/pydbml/classes/project.py +++ b/pydbml/classes/project.py @@ -30,7 +30,7 @@ def __repr__(self): @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Project {self.name} {{\n' + result += f'Project "{self.name}" {{\n' if self.items: items_str = '' for k, v in self.items.items(): diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 356d3e7..fdc7a22 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -6,7 +6,6 @@ from .base import SQLOjbect from .column import Column from pydbml.constants import MANY_TO_ONE -from pydbml.constants import ONE_TO_MANY from pydbml.constants import ONE_TO_ONE from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql @@ -23,7 +22,7 @@ class Reference(SQLOjbect): required_attributes = ('type', 'col1', 'col2') def __init__(self, - type_: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE], + type_: Literal['>', '<', '-'], col1: Union[Column, Collection[Column]], col2: Union[Column, Collection[Column]], name: Optional[str] = None, diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 3bb151c..53676e3 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -19,7 +19,7 @@ from pydbml.tools import indent -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from pydbml.schema import Schema @@ -35,7 +35,7 @@ def __init__(self, header_color: Optional[str] = None, # refs: Optional[List[TableReference]] = None, comment: Optional[str] = None): - self.schema: Schema = None + self.schema: Optional[Schema] = None self.name = name self.columns: List[Column] = [] self.indexes: List[Index] = [] diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index a34fdd4..450caf7 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -15,7 +15,7 @@ class TableGroup: def __init__(self, name: str, - items: Union[List[str], List[Table]], + items: List[Table], comment: Optional[str] = None): self.schema = None self.name = name @@ -37,7 +37,7 @@ def __repr__(self): items = [i if isinstance(i, str) else i.name for i in self.items] return f'' - def __getitem__(self, key: int) -> str: + def __getitem__(self, key: int) -> Table: return self.items[key] def __iter__(self): diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index 8ab13ce..176d3ef 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -1,6 +1,7 @@ import pyparsing as pp from pydbml.parser.blueprints import IndexBlueprint +from pydbml.parser.blueprints import ExpressionBlueprint from .common import _ from .common import _c @@ -84,7 +85,7 @@ def parse_index(s, l, t): ] ''' init_dict = {} - if isinstance(t['subject'], str): + if isinstance(t['subject'], (str, ExpressionBlueprint)): subjects = [t['subject']] else: subjects = list(t['subject']) diff --git a/pydbml/parser/__init__.py b/pydbml/parser/__init__.py index db3c175..aa03f88 100644 --- a/pydbml/parser/__init__.py +++ b/pydbml/parser/__init__.py @@ -1 +1 @@ -from .parser import PyDBML, parse +from .parser import PyDBML diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index fa33964..7192c34 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -4,6 +4,7 @@ from typing import List from typing import Literal from typing import Optional +from typing import Any from typing import Union from pydbml.classes import Column @@ -16,9 +17,6 @@ from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup -from pydbml.constants import MANY_TO_ONE -from pydbml.constants import ONE_TO_MANY -from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError @@ -45,7 +43,7 @@ def build(self) -> Expression: @dataclass class ReferenceBlueprint(Blueprint): - type: Literal[MANY_TO_ONE, ONE_TO_MANY, ONE_TO_ONE] + type: Literal['>', '<', '-'] inline: bool name: Optional[str] = None table1: Optional[str] = None @@ -69,7 +67,10 @@ def build(self) -> 'Reference': if self.col2 is None: raise ColumnNotFoundError("Can't build Reference, col2 unknown") - table1 = self.parser.locate_table(self.table1) + if self.parser: + table1 = self.parser.locate_table(self.table1) + else: + raise RuntimeError('Parser is not set') col1_list = [c.strip('() ') for c in self.col1.split(',')] col1 = [table1[col] for col in col1_list] @@ -98,7 +99,7 @@ class ColumnBlueprint(Blueprint): not_null: bool = False pk: bool = False autoinc: bool = False - default: Optional[Union[str, int, bool, float, ExpressionBlueprint]] = None + default: Optional[Any] = None note: Optional[NoteBlueprint] = None ref_blueprints: Optional[List[ReferenceBlueprint]] = None comment: Optional[str] = None @@ -139,7 +140,7 @@ class IndexBlueprint(Blueprint): def build(self) -> 'Index': return Index( # TableBlueprint will process subjects - subjects=list(self.subject_names), + subjects=[], name=self.name, unique=self.unique, type_=self.type, @@ -173,8 +174,8 @@ def build(self) -> 'Table': result.add_column(col_bp.build()) for index_bp in indexes: index = index_bp.build() - new_subjects = [] - for subj in index.subjects: + new_subjects: List[Union[str, Column, Expression]] = [] + for subj in index_bp.subject_names: if isinstance(subj, ExpressionBlueprint): new_subjects.append(subj.build()) else: @@ -253,6 +254,8 @@ class TableGroupBlueprint(Blueprint): comment: Optional[str] = None def build(self) -> 'TableGroup': + if not self.parser: + raise RuntimeError('Parser is not set') return TableGroup( name=self.name, items=[self.parser.locate_table(table) for table in self.items], diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 06cc38e..a6797d9 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -23,27 +23,28 @@ from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError from pydbml.schema import Schema +from pydbml.tools import remove_bom + pp.ParserElement.setDefaultWhitespaceChars(' \t\r') class PyDBML: ''' - PyDBML parser factory. If properly initiated, returns PyDBMLParseResults - which contains parse results in attributes. + PyDBML parser factory. If properly initiated, returns parsed Schema. Usage option 1: - >>> with open('schema.dbml') as f: + >>> with open('test_schema.dbml') as f: ... p = PyDBML(f) ... # or ... p = PyDBML(f.read()) Usage option 2: - >>> p = PyDBML.parse_file('schema.dbml') + >>> p = PyDBML.parse_file('test_schema.dbml') >>> # or >>> from pathlib import Path - >>> p = PyDBML(Path('schema.dbml')) + >>> p = PyDBML(Path('test_schema.dbml')) ''' def __new__(cls, @@ -54,55 +55,53 @@ def __new__(cls, elif isinstance(source_, Path): with open(source_, encoding='utf8') as f: source = f.read() - else: # TextIOWrapper + elif isinstance(source_, TextIOWrapper): source = source_.read() - if source[0] == '\ufeff': # removing BOM - source = source[1:] + else: + raise TypeError('Source must be str, path or file stream') + + source = remove_bom(source) return cls.parse(source) else: return super().__new__(cls) def __repr__(self): + """ + >>> PyDBML() + + """ + return "" @staticmethod - def parse(text: str) -> PyDBMLParser: - if text[0] == '\ufeff': # removing BOM - text = text[1:] - + def parse(text: str) -> Schema: + text = remove_bom(text) parser = PyDBMLParser(text) return parser.parse() @staticmethod - def parse_file(file: Union[str, Path, TextIOWrapper]) -> PyDBMLParser: + def parse_file(file: Union[str, Path, TextIOWrapper]) -> Schema: if isinstance(file, TextIOWrapper): source = file.read() else: with open(file, encoding='utf8') as f: source = f.read() - if source[0] == '\ufeff': # removing BOM - source = source[1:] + source = remove_bom(source) parser = PyDBMLParser(source) return parser.parse() -def parse(source: str): - parser = PyDBMLParser(source) - return parser.parse() - - class PyDBMLParser: def __init__(self, source: str): self.schema = None self.ref_blueprints: List[ReferenceBlueprint] = [] - self.table_groups = [] + self.table_groups: List[TableGroupBlueprint] = [] self.source = source - self.tables = [] - self.refs = [] - self.enums = [] - self.table_groups = [] - self.project = None + self.tables: List[TableGroupBlueprint] = [] + self.refs: List[ReferenceBlueprint] = [] + self.enums: List[EnumBlueprint] = [] + self.project: Optional[ProjectBlueprint] = None def parse(self): self._set_syntax() @@ -111,6 +110,11 @@ def parse(self): return self.schema def __repr__(self): + """ + >>> PyDBMLParser('') + + """ + return "" def _set_syntax(self): diff --git a/pydbml/schema.py b/pydbml/schema.py index 824ca99..ea25d1a 100644 --- a/pydbml/schema.py +++ b/pydbml/schema.py @@ -41,8 +41,10 @@ def __str__(self) -> str: def __getitem__(self, k: Union[int, str]) -> Table: if isinstance(k, int): return self.tables[k] - else: + elif isinstance(k, str): return self.table_dict[k] + else: + raise TypeError('indeces must be str or int') def __iter__(self): return iter(self.tables) @@ -85,7 +87,7 @@ def add_table(self, obj: Table) -> Table: def add_reference(self, obj: Reference): for col in (*obj.col1, *obj.col2): - if col.table.schema == self: + if col.table and col.table.schema == self: break else: raise SchemaValidationError( @@ -198,7 +200,7 @@ def sql(self): @property def dbml(self): '''Generates DBML code out of parsed results''' - items = (self.project) if self.project else () + items = [self.project] if self.project else [] refs = (ref for ref in self.refs if not ref.inline) items.extend((*self.enums, *self.tables, *refs, *self.table_groups)) components = ( diff --git a/pydbml/tools.py b/pydbml/tools.py index 2e07f64..48df6da 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -1,5 +1,5 @@ from typing import TYPE_CHECKING -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from .classes import Note @@ -26,3 +26,9 @@ def indent(val: str, spaces=4) -> str: if val == '': return val return ' ' * spaces + val.replace('\n', '\n' + ' ' * spaces) + + +def remove_bom(source: str) -> str: + if source and source[0] == '\ufeff': + source = source[1:] + return source diff --git a/test.sh b/test.sh index 8fd1cdf..3b88011 100755 --- a/test.sh +++ b/test.sh @@ -1,3 +1,3 @@ python3 -m doctest README.md &&\ python3 -m unittest discover &&\ - mypy . --ignore-missing-imports + mypy pydbml --ignore-missing-imports diff --git a/test/test_blueprints/test_index.py b/test/test_blueprints/test_index.py index 64f5df5..6c09e26 100644 --- a/test/test_blueprints/test_index.py +++ b/test/test_blueprints/test_index.py @@ -13,7 +13,7 @@ def test_build_minimal(self) -> None: ) result = bp.build() self.assertIsInstance(result, Index) - self.assertEqual(result.subjects, bp.subject_names) + self.assertEqual(result.subject_names, []) def test_build_full(self) -> None: bp = IndexBlueprint( @@ -27,7 +27,7 @@ def test_build_full(self) -> None: ) result = bp.build() self.assertIsInstance(result, Index) - self.assertEqual(result.subject_names, bp.subject_names) + self.assertEqual(result.subject_names, []) self.assertEqual(result.name, bp.name) self.assertEqual(result.unique, bp.unique) self.assertEqual(result.type, bp.type) diff --git a/test/test_blueprints/test_reference.py b/test/test_blueprints/test_reference.py index 493a17b..7bf3b10 100644 --- a/test/test_blueprints/test_reference.py +++ b/test/test_blueprints/test_reference.py @@ -31,6 +31,9 @@ def test_build_minimal(self) -> None: c2 = Column(name='col2', type_='Varchar') t2.add_column(c2) + with self.assertRaises(RuntimeError): + bp.build() + parserMock = Mock() parserMock.locate_table.side_effect = [t1, t2] bp.parser = parserMock diff --git a/test/test_blueprints/test_table_group.py b/test/test_blueprints/test_table_group.py index f76a823..69331dc 100644 --- a/test/test_blueprints/test_table_group.py +++ b/test/test_blueprints/test_table_group.py @@ -13,6 +13,9 @@ def test_build(self) -> None: items=['table1', 'table2'], comment='Comment text' ) + with self.assertRaises(RuntimeError): + bp.build() + parserMock = Mock() parserMock.locate_table.side_effect = [ Table(name='table1'), diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index 42a5420..49a71ff 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -2,6 +2,7 @@ from pydbml.schema import Schema from pydbml.classes import Column +from pydbml.classes import Expression from pydbml.classes import Table from pydbml.classes import Reference from pydbml.classes import Note @@ -166,7 +167,7 @@ def test_dbml_default(self): expected = '"order" integer [default: 3.33]' self.assertEqual(c.dbml, expected) - c.default = "(now() - interval '5 days')" + c.default = Expression("now() - interval '5 days'") expected = "\"order\" integer [default: `now() - interval '5 days'`]" self.assertEqual(c.dbml, expected) diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py index f2bd5ee..324d993 100644 --- a/test/test_classes/test_index.py +++ b/test/test_classes/test_index.py @@ -3,6 +3,7 @@ from pydbml.classes import Index from pydbml.classes import Table from pydbml.classes import Column +from pydbml.classes import Expression class TestIndex(TestCase): @@ -15,6 +16,15 @@ def test_basic_sql(self) -> None: expected = 'CREATE INDEX ON "products" ("id");' self.assertEqual(r.sql, expected) + def test_basic_sql_str(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + r = Index(subjects=['id']) + t.add_index(r) + self.assertIs(r.table, t) + expected = 'CREATE INDEX ON "products" (id);' + self.assertEqual(r.sql, expected) + def test_comment(self) -> None: t = Table('products') t.add_column(Column('id', 'integer')) @@ -62,9 +72,9 @@ def test_pk(self) -> None: def test_composite_with_expression(self) -> None: t = Table('products') t.add_column(Column('id', 'integer')) - r = Index(subjects=[t.columns[0], '(id*3)']) + r = Index(subjects=[t.columns[0], Expression('id*3')]) t.add_index(r) - self.assertEqual(r.subjects, [t['id'], '(id*3)']) + self.assertEqual(r.subjects, [t['id'], Expression('id*3')]) expected = 'CREATE INDEX ON "products" ("id", (id*3));' self.assertEqual(r.sql, expected) @@ -80,7 +90,7 @@ def test_dbml_simple(self): def test_dbml_composite(self): t = Table('products') t.add_column(Column('id', 'integer')) - i = Index(subjects=[t.columns[0], '(id*3)']) + i = Index(subjects=[t.columns[0], Expression('id*3')]) t.add_index(i) expected = '(id, `id*3`)' @@ -90,7 +100,7 @@ def test_dbml_full(self): t = Table('products') t.add_column(Column('id', 'integer')) i = Index( - subjects=[t.columns[0], '(getdate())'], + subjects=[t.columns[0], Expression('getdate()')], name='Dated id', unique=True, type_='hash', diff --git a/test/test_classes/test_project.py b/test/test_classes/test_project.py index cb68829..df86fa5 100644 --- a/test/test_classes/test_project.py +++ b/test/test_classes/test_project.py @@ -7,7 +7,7 @@ class TestProject(TestCase): def test_dbml_note(self): p = Project('myproject', note='Project note') expected = \ -'''Project myproject { +'''Project "myproject" { Note { 'Project note' } @@ -26,7 +26,7 @@ def test_dbml_full(self): expected = \ """// Multiline // Project comment -Project myproject { +Project "myproject" { database_type: 'PostgreSQL' story: '''One day I was eating my cantaloupe and I thought, why shouldn't I? @@ -39,3 +39,11 @@ def test_dbml_full(self): } }""" self.assertEqual(p.dbml, expected) + + def test_dbml_space(self) -> None: + p = Project('My project', {'a': 'b'}) + expected = \ +'''Project "My project" { + a: 'b' +}''' + self.assertEqual(p.dbml, expected) diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 37b4805..461bb59 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -1,6 +1,7 @@ from unittest import TestCase from pydbml.classes import Column +from pydbml.classes import Expression from pydbml.classes import Index from pydbml.classes import Note from pydbml.classes import Reference @@ -363,7 +364,7 @@ def test_dbml_full(self): t.add_column(c1) t.add_column(c2) i1 = Index(['zero', 'id'], unique=True) - i2 = Index(['(capitalize(name))'], comment="index comment") + i2 = Index([Expression('capitalize(name)')], comment="index comment") t.add_index(i1) t.add_index(i2) s = Schema() diff --git a/test/test_data/integration1.dbml b/test/test_data/integration1.dbml new file mode 100644 index 0000000..e0a99f9 --- /dev/null +++ b/test/test_data/integration1.dbml @@ -0,0 +1,44 @@ +Project "my project" { + author: 'me' + reason: 'testing' +} + +Enum "level" { + "junior" + "middle" + "senior" +} + +Table "Employees" as "emp" { + "id" integer [pk, increment] + "name" varchar [note: 'Full employee name'] + "age" number + "level" level + "favorite_book_id" integer +} + +Table "books" { + "id" integer [pk, increment] + "title" varchar + "author" varchar + "country_id" integer +} + +Table "countries" { + "id" integer [ref: < "books"."country_id", pk, increment] + "name" varchar2 [unique] + + indexes { + name [unique] + `UPPER(name)` + } +} + +Ref { + "Employees"."favorite_book_id" > "books"."id" +} + +TableGroup Unanimate { + books + countries +} \ No newline at end of file diff --git a/test/test_data/integration1.sql b/test/test_data/integration1.sql new file mode 100644 index 0000000..726f25a --- /dev/null +++ b/test/test_data/integration1.sql @@ -0,0 +1,34 @@ +CREATE TYPE "level" AS ENUM ( + 'junior', + 'middle', + 'senior', +); + +CREATE TABLE "Employees" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "name" varchar, + "age" number DEFAULT 0, + "level" level, + "favorite_book_id" integer +); + +COMMENT ON COLUMN "Employees"."name" IS 'Full employee name'; + +CREATE TABLE "books" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "title" varchar, + "author" varchar, + "country_id" integer, + CONSTRAINT "Country Reference" FOREIGN KEY ("country_id") REFERENCES "countries" ("id") +); + +CREATE TABLE "countries" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "name" varchar2 UNIQUE +); + +CREATE UNIQUE INDEX ON "countries" ("name"); + +CREATE INDEX ON "countries" ((UPPER(name))); + +ALTER TABLE "Employees" ADD FOREIGN KEY ("favorite_book_id") REFERENCES "books" ("id"); \ No newline at end of file diff --git a/test/test_doctest.py b/test/test_doctest.py index b4122a1..2ba1272 100644 --- a/test/test_doctest.py +++ b/test/test_doctest.py @@ -11,6 +11,7 @@ from pydbml.classes import reference from pydbml.classes import table from pydbml.classes import table_group +from pydbml.parser import parser def load_tests(loader, tests, ignore): @@ -24,4 +25,5 @@ def load_tests(loader, tests, ignore): tests.addTests(doctest.DocTestSuite(schema)) tests.addTests(doctest.DocTestSuite(table)) tests.addTests(doctest.DocTestSuite(table_group)) + tests.addTests(doctest.DocTestSuite(parser)) return tests diff --git a/test/test_generate_dbml.py b/test/test_integration.py similarity index 52% rename from test/test_generate_dbml.py rename to test/test_integration.py index 84d9e52..e0e3f1e 100644 --- a/test/test_generate_dbml.py +++ b/test/test_integration.py @@ -2,6 +2,7 @@ from pathlib import Path from unittest import TestCase +from unittest.mock import patch, Mock from pydbml.classes import Column from pydbml.classes import Enum @@ -9,18 +10,19 @@ from pydbml.classes import Project from pydbml.classes import Reference from pydbml.classes import Index +from pydbml.classes import Expression from pydbml.classes import Table from pydbml.classes import TableGroup from pydbml.classes import Note -from pydbml.exceptions import SchemaValidationError from pydbml.schema import Schema +from pydbml import PyDBML TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' class TestGenerateDBML(TestCase): - def test_generate_dbml(self) -> None: + def create_schema(self) -> Schema: schema = Schema() emp_level = Enum( 'level', @@ -61,15 +63,61 @@ def test_generate_dbml(self) -> None: c32 = Column('name', 'varchar2', unique=True) t3.add_column(c31) t3.add_column(c32) - i31 = Index([c32]) + i31 = Index([c32], unique=True) t3.add_index(i31) - # TODO: expression class - # i32 = Index(['']) + i32 = Index([Expression('UPPER(name)')]) + t3.add_index(i32) schema.add(t3) ref1 = Reference('>', c15, c21) schema.add(ref1) - ref2 = Reference('<', c31, c24, name='Country Reference') + ref2 = Reference('<', c31, c24, name='Country Reference', inline=True) schema.add(ref2) + tg = TableGroup('Unanimate', [t2, t3]) + schema.add(tg) + + p = Project('my project', {'author': 'me', 'reason': 'testing'}) + schema.add(p) + return schema + + def test_generate_dbml(self) -> None: + schema = self.create_schema() + with open(TEST_DATA_PATH / 'integration1.dbml') as f: + expected = f.read() + self.assertEqual(schema.dbml, expected) + + def test_generate_sql(self) -> None: + schema = self.create_schema() + with open(TEST_DATA_PATH / 'integration1.sql') as f: + expected = f.read() + self.assertEqual(schema.sql, expected) + + def test_parser(self): + source_path = TEST_DATA_PATH / 'integration1.dbml' + with self.assertRaises(TypeError): + PyDBML(2) + res1 = PyDBML(source_path) + self.assertIsInstance(res1, Schema) + with open(source_path) as f: + res2 = PyDBML(f) + self.assertIsInstance(res2, Schema) + with open(source_path) as f: + source = f.read() + res3 = PyDBML(source) + self.assertIsInstance(res3, Schema) + res4 = PyDBML('\ufeff' + source) + self.assertIsInstance(res4, Schema) + + pydbml = PyDBML() + self.assertIsInstance(pydbml, PyDBML) + res5 = pydbml.parse(source) + self.assertIsInstance(res5, Schema) + res6 = PyDBML.parse('\ufeff' + source) + self.assertIsInstance(res6, Schema) + res7 = PyDBML.parse_file(str(source_path)) + self.assertIsInstance(res7, Schema) + with open(source_path) as f: + res8 = PyDBML.parse_file(f) + self.assertIsInstance(res8, Schema) diff --git a/test/test_parser.py b/test/test_parser.py index 2c35025..cf3f15a 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -4,6 +4,7 @@ from unittest import TestCase from pydbml import PyDBML +from pydbml.parser.parser import PyDBMLParser from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError @@ -94,3 +95,12 @@ def test_bad_reference(self) -> None: def test_bad_index(self) -> None: with self.assertRaises(ColumnNotFoundError): PyDBML(TEST_DATA_PATH / 'wrong_index.dbml') + + +class TestPyDBMLParser(TestCase): + def test_edge(self) -> None: + p = PyDBMLParser('') + with self.assertRaises(RuntimeError): + p.locate_table('test') + with self.assertRaises(RuntimeError): + p.parse_blueprint(1, 1, [1]) diff --git a/test/test_schema.py b/test/test_schema.py index 213e3ee..a94ab2d 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -34,6 +34,7 @@ def test_add_table_alias(self) -> None: t.add_column(c) schema = Schema() schema.add_table(t) + self.assertIsInstance(t.alias, str) self.assertIs(schema[t.alias], t) def test_add_table_alias_bad(self) -> None: @@ -274,6 +275,8 @@ def test_geititem(self) -> None: self.assertIs(schema['table2'], t2) self.assertIs(schema[0], t1) self.assertIs(schema[1], t2) + with self.assertRaises(TypeError): + schema[None] with self.assertRaises(IndexError): schema[2] with self.assertRaises(KeyError): From 0b1e71f70d2b92871d270e6d7b8f151458e9b63f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 17 May 2022 08:14:26 +0200 Subject: [PATCH 029/125] Rename Schema -> Database --- TODO.md | 9 +- changelog.md | 1 + pydbml/classes/column.py | 4 +- pydbml/classes/enum.py | 2 +- pydbml/classes/index.py | 2 +- pydbml/classes/project.py | 2 +- pydbml/classes/reference.py | 2 +- pydbml/classes/table.py | 18 +- pydbml/classes/table_group.py | 2 +- pydbml/{schema.py => database.py} | 85 +++---- pydbml/definitions/common.py | 5 +- pydbml/definitions/table.py | 6 +- pydbml/exceptions.py | 4 +- pydbml/parser/blueprints.py | 4 +- pydbml/parser/parser.py | 36 +-- test/test_blueprints/test_column.py | 6 +- test/test_classes/test_column.py | 38 +-- test/test_classes/test_table.py | 30 +-- test/test_definitions/test_common.py | 9 + test/test_doctest.py | 4 +- test/test_integration.py | 48 ++-- test/test_schema.py | 358 +++++++++++++-------------- 22 files changed, 343 insertions(+), 332 deletions(-) rename pydbml/{schema.py => database.py} (68%) diff --git a/TODO.md b/TODO.md index 1349156..57ada15 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,9 @@ -- Creating dbml schema in python +* - Creating dbml schema in python - pyparsing new var names (+possibly new features) * - enum type - `_type` -> `type` -- expression class -- schema.add and .delete to support multiple arguments (handle errors properly) \ No newline at end of file +* - expression class +- schema.add and .delete to support multiple arguments (handle errors properly) +- 2.3.1 Multiline comment /* ... */ +- 2.4 Multiple Schemas +- validation on "add_index", "add_table" etc \ No newline at end of file diff --git a/changelog.md b/changelog.md index 9c9b880..7721a20 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ - tables don't have refs - col1 col2 in ref are as they were in dbml - Expression class +- add multiline comment # 0.4.2 diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 9f39c74..55024b2 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -53,8 +53,8 @@ def get_refs(self) -> List['Reference']: return [ref for ref in self.table.get_refs() if self in ref.col1] @property - def schema(self): - return self.table.schema if self.table else None + def database(self): + return self.table.database if self.table else None @property def sql(self): diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index 74284d5..ac84c2d 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -59,7 +59,7 @@ def __init__(self, name: str, items: List['EnumItem'], comment: Optional[str] = None): - self.schema = None + self.database = None self.name = name self.items = items self.comment = comment diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index ef5022f..354c3a1 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -27,7 +27,7 @@ def __init__(self, pk: bool = False, note: Optional[Union['Note', str]] = None, comment: Optional[str] = None): - self.schema = None + self.database = None self.subjects = subjects self.table: Optional['Table'] = None diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py index 0661813..582f8c4 100644 --- a/pydbml/classes/project.py +++ b/pydbml/classes/project.py @@ -13,7 +13,7 @@ def __init__(self, items: Optional[Dict[str, str]] = None, note: Optional[Union['Note', str]] = None, comment: Optional[str] = None): - self.schema = None + self.database = None self.name = name self.items = items self.note = Note(note) diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index fdc7a22..4ed997e 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -30,7 +30,7 @@ def __init__(self, on_update: Optional[str] = None, on_delete: Optional[str] = None, inline: bool = False): - self.schema = None + self.database = None self.type = type_ self.col1 = [col1] if isinstance(col1, Column) else list(col1) self.col2 = [col2] if isinstance(col2, Column) else list(col2) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 53676e3..c5847f6 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -13,14 +13,14 @@ from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import IndexNotFoundError -from pydbml.exceptions import UnknownSchemaError +from pydbml.exceptions import UnknownDatabaseError from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.tools import indent if TYPE_CHECKING: # pragma: no cover - from pydbml.schema import Schema + from pydbml.database import Database class Table(SQLOjbect): @@ -35,7 +35,7 @@ def __init__(self, header_color: Optional[str] = None, # refs: Optional[List[TableReference]] = None, comment: Optional[str] = None): - self.schema: Optional[Schema] = None + self.database: Optional[Database] = None self.name = name self.columns: List[Column] = [] self.indexes: List[Index] = [] @@ -85,18 +85,18 @@ def delete_index(self, i: Union[Index, int]) -> Index: return self.indexes.pop(i) def get_refs(self) -> List[Reference]: - if not self.schema: - raise UnknownSchemaError('Schema for the table is not set') - return [ref for ref in self.schema.refs if ref.col1[0].table == self] + if not self.database: + raise UnknownDatabaseError('Database for the table is not set') + return [ref for ref in self.database.refs if ref.col1[0].table == self] def _get_references_for_sql(self) -> List[Reference]: ''' return inline references for this table sql definition ''' - if not self.schema: - raise UnknownSchemaError(f'Schema for the table {self} is not set') + if not self.database: + raise UnknownDatabaseError(f'Database for the table {self} is not set') result = [] - for ref in self.schema.refs: + for ref in self.database.refs: if ref.inline: if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ (ref.col1[0].table == self): diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index 450caf7..dbba53a 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -17,7 +17,7 @@ def __init__(self, name: str, items: List[Table], comment: Optional[str] = None): - self.schema = None + self.database = None self.name = name self.items = items self.comment = comment diff --git a/pydbml/schema.py b/pydbml/database.py similarity index 68% rename from pydbml/schema.py rename to pydbml/database.py index ea25d1a..ab7b4e1 100644 --- a/pydbml/schema.py +++ b/pydbml/database.py @@ -9,12 +9,11 @@ from .classes import Reference from .classes import Table from .classes import TableGroup -from .exceptions import SchemaValidationError +from .exceptions import DatabaseValidationError -class Schema: - def __init__(self, name: str = 'public') -> None: - self.name = name +class Database: + def __init__(self) -> None: self.tables: List['Table'] = [] self.table_dict: Dict[str, 'Table'] = {} self.refs: List['Reference'] = [] @@ -24,19 +23,11 @@ def __init__(self, name: str = 'public') -> None: def __repr__(self) -> str: """ - >>> Schema("private") - + >>> Database() + """ - return f"" - - def __str__(self) -> str: - """ - >>> print(Schema("private")) - - """ - - return f"" + return f"" def __getitem__(self, k: Union[int, str]) -> Table: if isinstance(k, int): @@ -49,11 +40,11 @@ def __getitem__(self, k: Union[int, str]) -> Table: def __iter__(self): return iter(self.tables) - def _set_schema(self, obj: Any) -> None: - obj.schema = self + def _set_database(self, obj: Any) -> None: + obj.database = self - def _unset_schema(self, obj: Any) -> None: - obj.schema = None + def _unset_database(self, obj: Any) -> None: + obj.database = None def add(self, obj: Any) -> Any: if isinstance(obj, Table): @@ -67,17 +58,17 @@ def add(self, obj: Any) -> Any: elif isinstance(obj, Project): return self.add_project(obj) else: - raise SchemaValidationError(f'Unsupported type {type(obj)}.') + raise DatabaseValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: if obj in self.tables: - raise SchemaValidationError(f'{obj} is already in the schema.') + raise DatabaseValidationError(f'{obj} is already in the database.') if obj.name in self.table_dict: - raise SchemaValidationError(f'Table {obj.name} is already in the schema.') + raise DatabaseValidationError(f'Table {obj.name} is already in the database.') if obj.alias and obj.alias in self.table_dict: - raise SchemaValidationError(f'Table {obj.alias} is already in the schema.') + raise DatabaseValidationError(f'Table {obj.alias} is already in the database.') - self._set_schema(obj) + self._set_database(obj) self.tables.append(obj) self.table_dict[obj.name] = obj @@ -87,46 +78,46 @@ def add_table(self, obj: Table) -> Table: def add_reference(self, obj: Reference): for col in (*obj.col1, *obj.col2): - if col.table and col.table.schema == self: + if col.table and col.table.database == self: break else: - raise SchemaValidationError( + raise DatabaseValidationError( 'Cannot add reference. At least one of the referenced tables' - ' should belong to this schema' + ' should belong to this database' ) if obj in self.refs: - raise SchemaValidationError(f'{obj} is already in the schema.') + raise DatabaseValidationError(f'{obj} is already in the database.') - self._set_schema(obj) + self._set_database(obj) self.refs.append(obj) return obj def add_enum(self, obj: Enum) -> Enum: if obj in self.enums: - raise SchemaValidationError(f'{obj} is already in the schema.') + raise DatabaseValidationError(f'{obj} is already in the database.') for enum in self.enums: if enum.name == obj.name: - raise SchemaValidationError(f'Enum {obj.name} is already in the schema.') + raise DatabaseValidationError(f'Enum {obj.name} is already in the database.') - self._set_schema(obj) + self._set_database(obj) self.enums.append(obj) return obj def add_table_group(self, obj: TableGroup) -> TableGroup: if obj in self.table_groups: - raise SchemaValidationError(f'{obj} is already in the schema.') + raise DatabaseValidationError(f'{obj} is already in the database.') for table_group in self.table_groups: if table_group.name == obj.name: - raise SchemaValidationError(f'TableGroup {obj.name} is already in the schema.') + raise DatabaseValidationError(f'TableGroup {obj.name} is already in the database.') - self._set_schema(obj) + self._set_database(obj) self.table_groups.append(obj) return obj def add_project(self, obj: Project) -> Project: if self.project: self.delete_project() - self._set_schema(obj) + self._set_database(obj) self.project = obj return obj @@ -142,14 +133,14 @@ def delete(self, obj: Any) -> Any: elif isinstance(obj, Project): return self.delete_project() else: - raise SchemaValidationError(f'Unsupported type {type(obj)}.') + raise DatabaseValidationError(f'Unsupported type {type(obj)}.') def delete_table(self, obj: Table) -> Table: try: index = self.tables.index(obj) except ValueError: - raise SchemaValidationError(f'{obj} is not in the schema.') - self._unset_schema(self.tables.pop(index)) + raise DatabaseValidationError(f'{obj} is not in the database.') + self._unset_database(self.tables.pop(index)) result = self.table_dict.pop(obj.name) if obj.alias: self.table_dict.pop(obj.alias) @@ -159,35 +150,35 @@ def delete_reference(self, obj: Reference) -> Reference: try: index = self.refs.index(obj) except ValueError: - raise SchemaValidationError(f'{obj} is not in the schema.') + raise DatabaseValidationError(f'{obj} is not in the database.') result = self.refs.pop(index) - self._unset_schema(result) + self._unset_database(result) return result def delete_enum(self, obj: Enum) -> Enum: try: index = self.enums.index(obj) except ValueError: - raise SchemaValidationError(f'{obj} is not in the schema.') + raise DatabaseValidationError(f'{obj} is not in the database.') result = self.enums.pop(index) - self._unset_schema(result) + self._unset_database(result) return result def delete_table_group(self, obj: TableGroup) -> TableGroup: try: index = self.table_groups.index(obj) except ValueError: - raise SchemaValidationError(f'{obj} is not in the schema.') + raise DatabaseValidationError(f'{obj} is not in the database.') result = self.table_groups.pop(index) - self._unset_schema(result) + self._unset_database(result) return result def delete_project(self) -> Project: if self.project is None: - raise SchemaValidationError(f'Project is not set.') + raise DatabaseValidationError(f'Project is not set.') result = self.project self.project = None - self._unset_schema(result) + self._unset_database(result) return result @property diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index 620f6db..bdde161 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -6,7 +6,10 @@ pp.ParserElement.setDefaultWhitespaceChars(' \t\r') -comment = pp.Suppress("//") + pp.SkipTo(pp.LineEnd()) +comment = ( + pp.Suppress("//") + pp.SkipTo(pp.LineEnd()) + | pp.Suppress('/*') + ... + pp.Suppress('*/') +) # optional comment or newline _ = ('\n' | comment)[...].suppress() diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index 59622f8..c5377a0 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -47,9 +47,11 @@ def parse_table_settings(s, l, t): table_body = table_column[1, ...]('columns') + _ + table_element[...] +table_name = (name('schema') + '.' + name('name')) | (name('name')) + table = _c + ( pp.CaselessLiteral("table").suppress() - + name('name') + + table_name + alias('alias')[0, 1] + table_settings('settings')[0, 1] + _ + '{' - table_body + _ + '}' @@ -72,6 +74,8 @@ def parse_table(s, l, t): init_dict = { 'name': t['name'], } + # if 'schema' in t: + # init_dict['schema'] = t['schema'] if 'settings' in t: init_dict.update(t['settings']) if 'alias' in t: diff --git a/pydbml/exceptions.py b/pydbml/exceptions.py index b1c8cab..757c434 100644 --- a/pydbml/exceptions.py +++ b/pydbml/exceptions.py @@ -18,7 +18,7 @@ class DuplicateReferenceError(Exception): pass -class UnknownSchemaError(Exception): +class UnknownDatabaseError(Exception): pass @@ -26,5 +26,5 @@ class DBMLError(Exception): pass -class SchemaValidationError(Exception): +class DatabaseValidationError(Exception): pass diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 7192c34..bc34ed7 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -108,7 +108,7 @@ def build(self) -> 'Column': if isinstance(self.default, ExpressionBlueprint): self.default = self.default.build() if self.parser: - for enum in self.parser.schema.enums: + for enum in self.parser.database.enums: if enum.name == self.type: self.type = enum break @@ -153,7 +153,7 @@ def build(self) -> 'Index': @dataclass class TableBlueprint(Blueprint): name: str - columns: Optional[List[ColumnBlueprint]] = None # TODO: should it be optional? + columns: List[ColumnBlueprint] = None indexes: Optional[List[IndexBlueprint]] = None alias: Optional[str] = None note: Optional[NoteBlueprint] = None diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index a6797d9..c4f18be 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -22,7 +22,7 @@ from pydbml.definitions.table import table from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError -from pydbml.schema import Schema +from pydbml.database import Database from pydbml.tools import remove_bom @@ -31,7 +31,7 @@ class PyDBML: ''' - PyDBML parser factory. If properly initiated, returns parsed Schema. + PyDBML parser factory. If properly initiated, returns parsed Database. Usage option 1: @@ -74,13 +74,13 @@ def __repr__(self): return "" @staticmethod - def parse(text: str) -> Schema: + def parse(text: str) -> Database: text = remove_bom(text) parser = PyDBMLParser(text) return parser.parse() @staticmethod - def parse_file(file: Union[str, Path, TextIOWrapper]) -> Schema: + def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: if isinstance(file, TextIOWrapper): source = file.read() else: @@ -93,7 +93,7 @@ def parse_file(file: Union[str, Path, TextIOWrapper]) -> Schema: class PyDBMLParser: def __init__(self, source: str): - self.schema = None + self.database = None self.ref_blueprints: List[ReferenceBlueprint] = [] self.table_groups: List[TableGroupBlueprint] = [] @@ -106,8 +106,8 @@ def __init__(self, source: str): def parse(self): self._set_syntax() self._syntax.parseString(self.source, parseAll=True) - self.build_schema() - return self.schema + self.build_database() + return self.database def __repr__(self): """ @@ -166,24 +166,24 @@ def parse_blueprint(self, s, l, t): blueprint.parser = self def locate_table(self, name: str) -> 'Table': - if not self.schema: - raise RuntimeError('Schema is not ready') + if not self.database: + raise RuntimeError('Database is not ready') try: - result = self.schema[name] + result = self.database[name] except KeyError: - raise TableNotFoundError(f'Table {name} not present in the schema') + raise TableNotFoundError(f'Table {name} not present in the database') return result - def build_schema(self): - self.schema = Schema() + def build_database(self): + self.database = Database() for enum_bp in self.enums: - self.schema.add(enum_bp.build()) + self.database.add(enum_bp.build()) for table_bp in self.tables: - self.schema.add(table_bp.build()) + self.database.add(table_bp.build()) self.ref_blueprints.extend(table_bp.get_reference_blueprints()) for table_group_bp in self.table_groups: - self.schema.add(table_group_bp.build()) + self.database.add(table_group_bp.build()) if self.project: - self.schema.add(self.project.build()) + self.database.add(self.project.build()) for ref_bp in self.refs: - self.schema.add(ref_bp.build()) + self.database.add(ref_bp.build()) diff --git a/test/test_blueprints/test_column.py b/test/test_blueprints/test_column.py index cd3e2e8..0f51a85 100644 --- a/test/test_blueprints/test_column.py +++ b/test/test_blueprints/test_column.py @@ -7,7 +7,7 @@ from pydbml.classes import Note from pydbml.parser.blueprints import ColumnBlueprint from pydbml.parser.blueprints import NoteBlueprint -from pydbml.schema import Schema +from pydbml.database import Database class TestColumn(TestCase): @@ -47,7 +47,7 @@ def test_build_full(self) -> None: self.assertEqual(result.comment, bp.comment) def test_enum_type(self) -> None: - s = Schema() + s = Database() e = Enum( 'myenum', items=[ @@ -57,7 +57,7 @@ def test_enum_type(self) -> None: ) s.add(e) parser = Mock() - parser.schema = s + parser.database = s bp = ColumnBlueprint( name='testcol', diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index 49a71ff..c95da1c 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -1,6 +1,6 @@ from unittest import TestCase -from pydbml.schema import Schema +from pydbml.database import Database from pydbml.classes import Column from pydbml.classes import Expression from pydbml.classes import Table @@ -41,15 +41,15 @@ def test_attributes(self) -> None: self.assertEqual(col.note, note) self.assertEqual(col.comment, comment) - def test_schema_set(self) -> None: + def test_database_set(self) -> None: col = Column('name', 'int') table = Table('name') - self.assertIsNone(col.schema) + self.assertIsNone(col.database) table.add_column(col) - self.assertIsNone(col.schema) - schema = Schema() - schema.add(table) - self.assertIs(col.schema, schema) + self.assertIsNone(col.database) + database = Database() + database.add(table) + self.assertIs(col.database, database) def test_basic_sql(self) -> None: r = Column(name='id', @@ -98,7 +98,7 @@ def test_dbml_simple(self): ) t = Table(name='Test') t.add_column(c) - s = Schema() + s = Database() s.add(t) expected = '"order" integer' @@ -118,7 +118,7 @@ def test_dbml_full(self): ) t = Table(name='Test') t.add_column(c) - s = Schema() + s = Database() s.add(t) expected = \ '''// Comment on the column @@ -136,7 +136,7 @@ def test_dbml_multiline_note(self): ) t = Table(name='Test') t.add_column(c) - s = Schema() + s = Database() s.add(t) expected = \ """// Comment on the column @@ -153,7 +153,7 @@ def test_dbml_default(self): ) t = Table(name='Test') t.add_column(c) - s = Schema() + s = Database() s.add(t) expected = "\"order\" integer [default: 'String value']" @@ -183,16 +183,16 @@ def test_dbml_default(self): expected = '"order" integer [default: false]' self.assertEqual(c.dbml, expected) - def test_schema(self): + def test_database(self): c1 = Column(name='client_id', type_='integer') t1 = Table(name='products') - self.assertIsNone(c1.schema) + self.assertIsNone(c1.database) t1.add_column(c1) - self.assertIsNone(c1.schema) - s = Schema() + self.assertIsNone(c1.database) + s = Database() s.add(t1) - self.assertIs(c1.schema, s) + self.assertIs(c1.database, s) def test_get_refs(self) -> None: c1 = Column(name='client_id', type_='integer') @@ -205,7 +205,7 @@ def test_get_refs(self) -> None: t2.add_column(c2) ref = Reference(type_='>', col1=c1, col2=c2, inline=True) - s = Schema() + s = Database() s.add(t1) s.add(t2) s.add(ref) @@ -221,7 +221,7 @@ def test_dbml_with_ref(self) -> None: t2.add_column(c2) ref = Reference(type_='>', col1=c1, col2=c2) - s = Schema() + s = Database() s.add(t1) s.add(t2) s.add(ref) @@ -243,7 +243,7 @@ def test_dbml_with_ref_and_properties(self) -> None: t2.add_column(c2) ref = Reference(type_='<', col1=c2, col2=c1) - s = Schema() + s = Database() s.add(t1) s.add(t2) s.add(ref) diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 461bb59..ea5c0ca 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -8,8 +8,8 @@ from pydbml.classes import Table from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import IndexNotFoundError -from pydbml.exceptions import UnknownSchemaError -from pydbml.schema import Schema +from pydbml.exceptions import UnknownDatabaseError +from pydbml.database import Database class TestTable(TestCase): @@ -17,7 +17,7 @@ def test_one_column(self) -> None: t = Table('products') c = Column('id', 'integer') t.add_column(c) - s = Schema() + s = Database() s.add(t) expected = 'CREATE TABLE "products" (\n "id" integer\n);' self.assertEqual(t.sql, expected) @@ -74,7 +74,7 @@ def test_ref(self) -> None: t2 = Table('names') c21 = Column('name_val', 'varchar2') t2.add_column(c21) - s = Schema() + s = Database() s.add(t) s.add(t2) r = Reference('>', c2, c21) @@ -105,7 +105,7 @@ def test_notes(self) -> None: t.add_column(c1) t.add_column(c2) t.add_column(c3) - s = Schema() + s = Database() s.add(t) expected = \ '''CREATE TABLE "products" ( @@ -131,7 +131,7 @@ def test_ref_index(self) -> None: t2 = Table('names') c21 = Column('name_val', 'varchar2') t2.add_column(c21) - s = Schema() + s = Database() s.add(t) r = Reference('>', c2, c21, inline=True) @@ -156,7 +156,7 @@ def test_index_inline(self) -> None: t.add_column(c2) i = Index(subjects=[c1, c2], pk=True) t.add_index(i) - s = Schema() + s = Database() s.add(t) expected = \ @@ -175,7 +175,7 @@ def test_index_inline_and_comments(self) -> None: t.add_column(c2) i = Index(subjects=[c1, c2], pk=True, comment='Multiline\nindex comment') t.add_index(i) - s = Schema() + s = Database() s.add(t) expected = \ @@ -250,7 +250,7 @@ def test_delete_index(self) -> None: def test_get_references_for_sql(self): t = Table('products') - with self.assertRaises(UnknownSchemaError): + with self.assertRaises(UnknownDatabaseError): t._get_references_for_sql() c11 = Column('id', 'integer') c12 = Column('name', 'varchar2') @@ -261,7 +261,7 @@ def test_get_references_for_sql(self): c22 = Column('name_val', 'varchar2') t2.add_column(c21) t2.add_column(c22) - s = Schema() + s = Database() s.add(t) s.add(t2) r1 = Reference('>', c12, c22) @@ -278,7 +278,7 @@ def test_get_references_for_sql(self): def test_get_refs(self): t = Table('products') - with self.assertRaises(UnknownSchemaError): + with self.assertRaises(UnknownDatabaseError): t.get_refs() c11 = Column('id', 'integer') c12 = Column('name', 'varchar2') @@ -289,7 +289,7 @@ def test_get_refs(self): c22 = Column('name_val', 'varchar2') t2.add_column(c21) t2.add_column(c22) - s = Schema() + s = Database() s.add(t) s.add(t2) r1 = Reference('>', c12, c22) @@ -307,7 +307,7 @@ def test_dbml_simple(self): c2 = Column('name', 'varchar2') t.add_column(c1) t.add_column(c2) - s = Schema() + s = Database() s.add(t) expected = \ @@ -326,7 +326,7 @@ def test_dbml_reference(self): t2 = Table('names') c21 = Column('name_val', 'varchar2') t2.add_column(c21) - s = Schema() + s = Database() s.add(t) s.add(t2) r = Reference('>', c2, c21) @@ -367,7 +367,7 @@ def test_dbml_full(self): i2 = Index([Expression('capitalize(name)')], comment="index comment") t.add_index(i1) t.add_index(i2) - s = Schema() + s = Database() s.add(t) expected = \ diff --git a/test/test_definitions/test_common.py b/test/test_definitions/test_common.py index 3376d8f..fa3b27c 100644 --- a/test/test_definitions/test_common.py +++ b/test/test_definitions/test_common.py @@ -23,6 +23,15 @@ def test_comment_endline(self) -> None: res = comment.parseString(val) self.assertEqual(res[0], 'test comment') + def test_multiline_comment(self) -> None: + val = '/*test comment*/' + res = comment.parseString(val) + self.assertEqual(res[0], 'test comment') + + val2 = '/*\nline1\nline2\nline3\n*/' + res2 = comment.parseString(val2) + self.assertEqual(res2[0], '\nline1\nline2\nline3\n') + class Test_c(TestCase): def test_comment(self) -> None: diff --git a/test/test_doctest.py b/test/test_doctest.py index 2ba1272..48688f9 100644 --- a/test/test_doctest.py +++ b/test/test_doctest.py @@ -1,7 +1,7 @@ import doctest import unittest -from pydbml import schema +from pydbml import database from pydbml.classes import column from pydbml.classes import enum from pydbml.classes import expression @@ -22,7 +22,7 @@ def load_tests(loader, tests, ignore): tests.addTests(doctest.DocTestSuite(project)) tests.addTests(doctest.DocTestSuite(note)) tests.addTests(doctest.DocTestSuite(reference)) - tests.addTests(doctest.DocTestSuite(schema)) + tests.addTests(doctest.DocTestSuite(database)) tests.addTests(doctest.DocTestSuite(table)) tests.addTests(doctest.DocTestSuite(table_group)) tests.addTests(doctest.DocTestSuite(parser)) diff --git a/test/test_integration.py b/test/test_integration.py index e0e3f1e..7e45f66 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -14,7 +14,7 @@ from pydbml.classes import Table from pydbml.classes import TableGroup from pydbml.classes import Note -from pydbml.schema import Schema +from pydbml.database import Database from pydbml import PyDBML @@ -22,8 +22,8 @@ class TestGenerateDBML(TestCase): - def create_schema(self) -> Schema: - schema = Schema() + def create_database(self) -> Database: + database = Database() emp_level = Enum( 'level', [ @@ -32,7 +32,7 @@ def create_schema(self) -> Schema: EnumItem('senior'), ] ) - schema.add(emp_level) + database.add(emp_level) t1 = Table('Employees', alias='emp') c11 = Column('id', 'integer', pk=True, autoinc=True) @@ -45,7 +45,7 @@ def create_schema(self) -> Schema: t1.add_column(c13) t1.add_column(c14) t1.add_column(c15) - schema.add(t1) + database.add(t1) t2 = Table('books') c21 = Column('id', 'integer', pk=True, autoinc=True) @@ -56,7 +56,7 @@ def create_schema(self) -> Schema: t2.add_column(c22) t2.add_column(c23) t2.add_column(c24) - schema.add(t2) + database.add(t2) t3 = Table('countries') c31 = Column('id', 'integer', pk=True, autoinc=True) @@ -67,57 +67,57 @@ def create_schema(self) -> Schema: t3.add_index(i31) i32 = Index([Expression('UPPER(name)')]) t3.add_index(i32) - schema.add(t3) + database.add(t3) ref1 = Reference('>', c15, c21) - schema.add(ref1) + database.add(ref1) ref2 = Reference('<', c31, c24, name='Country Reference', inline=True) - schema.add(ref2) + database.add(ref2) tg = TableGroup('Unanimate', [t2, t3]) - schema.add(tg) + database.add(tg) p = Project('my project', {'author': 'me', 'reason': 'testing'}) - schema.add(p) - return schema + database.add(p) + return database def test_generate_dbml(self) -> None: - schema = self.create_schema() + database = self.create_database() with open(TEST_DATA_PATH / 'integration1.dbml') as f: expected = f.read() - self.assertEqual(schema.dbml, expected) + self.assertEqual(database.dbml, expected) def test_generate_sql(self) -> None: - schema = self.create_schema() + database = self.create_database() with open(TEST_DATA_PATH / 'integration1.sql') as f: expected = f.read() - self.assertEqual(schema.sql, expected) + self.assertEqual(database.sql, expected) def test_parser(self): source_path = TEST_DATA_PATH / 'integration1.dbml' with self.assertRaises(TypeError): PyDBML(2) res1 = PyDBML(source_path) - self.assertIsInstance(res1, Schema) + self.assertIsInstance(res1, Database) with open(source_path) as f: res2 = PyDBML(f) - self.assertIsInstance(res2, Schema) + self.assertIsInstance(res2, Database) with open(source_path) as f: source = f.read() res3 = PyDBML(source) - self.assertIsInstance(res3, Schema) + self.assertIsInstance(res3, Database) res4 = PyDBML('\ufeff' + source) - self.assertIsInstance(res4, Schema) + self.assertIsInstance(res4, Database) pydbml = PyDBML() self.assertIsInstance(pydbml, PyDBML) res5 = pydbml.parse(source) - self.assertIsInstance(res5, Schema) + self.assertIsInstance(res5, Database) res6 = PyDBML.parse('\ufeff' + source) - self.assertIsInstance(res6, Schema) + self.assertIsInstance(res6, Database) res7 = PyDBML.parse_file(str(source_path)) - self.assertIsInstance(res7, Schema) + self.assertIsInstance(res7, Database) with open(source_path) as f: res8 = PyDBML.parse_file(f) - self.assertIsInstance(res8, Schema) + self.assertIsInstance(res8, Database) diff --git a/test/test_schema.py b/test/test_schema.py index a94ab2d..d8c3210 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -10,314 +10,314 @@ from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup -from pydbml.exceptions import SchemaValidationError -from pydbml.schema import Schema +from pydbml.exceptions import DatabaseValidationError +from pydbml.database import Database TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' -class TestSchema(TestCase): +class TestDatabase(TestCase): def test_add_table(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') t.add_column(c) - schema = Schema() - res = schema.add_table(t) - self.assertEqual(t.schema, schema) + database = Database() + res = database.add_table(t) + self.assertEqual(t.database, database) self.assertIs(res, t) - self.assertIn(t, schema.tables) + self.assertIn(t, database.tables) def test_add_table_alias(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table', alias='myalias') t.add_column(c) - schema = Schema() - schema.add_table(t) + database = Database() + database.add_table(t) self.assertIsInstance(t.alias, str) - self.assertIs(schema[t.alias], t) + self.assertIs(database[t.alias], t) def test_add_table_alias_bad(self) -> None: c = Column('test', 'varchar', True) t = Table('myalias') t.add_column(c) - schema = Schema() - schema.add_table(t) + database = Database() + database.add_table(t) t2 = Table('test_table', alias='myalias') - with self.assertRaises(SchemaValidationError): - schema.add_table(t2) - self.assertIsNone(t2.schema) + with self.assertRaises(DatabaseValidationError): + database.add_table(t2) + self.assertIsNone(t2.database) def test_add_table_bad(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') t.add_column(c) - schema = Schema() - schema.add_table(t) - with self.assertRaises(SchemaValidationError): - schema.add_table(t) + database = Database() + database.add_table(t) + with self.assertRaises(DatabaseValidationError): + database.add_table(t) t2 = Table('test_table') - with self.assertRaises(SchemaValidationError): - schema.add_table(t2) + with self.assertRaises(DatabaseValidationError): + database.add_table(t2) def test_delete_table(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table', alias='myalias') t.add_column(c) - schema = Schema() - schema.add_table(t) - res = schema.delete_table(t) - self.assertIsNone(t.schema, schema) + database = Database() + database.add_table(t) + res = database.delete_table(t) + self.assertIsNone(t.database, database) self.assertIs(res, t) - self.assertNotIn(t, schema.tables) - self.assertNotIn('test_table', schema.table_dict) - self.assertNotIn('myalias', schema.table_dict) + self.assertNotIn(t, database.tables) + self.assertNotIn('test_table', database.table_dict) + self.assertNotIn('myalias', database.table_dict) def test_delete_missing_table(self) -> None: t = Table('test_table') - schema = Schema() - with self.assertRaises(SchemaValidationError): - schema.delete_table(t) - self.assertIsNone(t.schema, schema) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_table(t) + self.assertIsNone(t.database, database) def test_add_reference(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') t.add_column(c) - schema = Schema() - schema.add_table(t) + database = Database() + database.add_table(t) c2 = Column('test2', 'integer') t2 = Table('test_table2') t2.add_column(c2) - schema.add_table(t2) + database.add_table(t2) ref = Reference('>', c, c2) - res = schema.add_reference(ref) - self.assertEqual(ref.schema, schema) + res = database.add_reference(ref) + self.assertEqual(ref.database, database) self.assertIs(res, ref) - self.assertIn(ref, schema.refs) + self.assertIn(ref, database.refs) def test_add_reference_bad(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') t.add_column(c) - schema = Schema() - schema.add_table(t) + database = Database() + database.add_table(t) c2 = Column('test2', 'integer') t2 = Table('test_table2') t2.add_column(c2) - schema.add_table(t2) + database.add_table(t2) ref = Reference('>', c, c2) - schema.add_reference(ref) - with self.assertRaises(SchemaValidationError): - schema.add_reference(ref) + database.add_reference(ref) + with self.assertRaises(DatabaseValidationError): + database.add_reference(ref) c3 = Column('test', 'varchar', True) t3 = Table('test_table') t3.add_column(c3) - schema3 = Schema() - schema3.add_table(t3) + database3 = Database() + database3.add_table(t3) c32 = Column('test2', 'integer') t32 = Table('test_table2') t32.add_column(c32) - schema3.add_table(t32) + database3.add_table(t32) ref3 = Reference('>', c3, c32) - with self.assertRaises(SchemaValidationError): - schema.add_reference(ref3) + with self.assertRaises(DatabaseValidationError): + database.add_reference(ref3) def test_delete_reference(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') t.add_column(c) - schema = Schema() - schema.add_table(t) + database = Database() + database.add_table(t) c2 = Column('test2', 'integer') t2 = Table('test_table2') t2.add_column(c2) - schema.add_table(t2) + database.add_table(t2) ref = Reference('>', c, c2) - res = schema.add_reference(ref) - res = schema.delete_reference(ref) - self.assertIsNone(ref.schema, schema) + res = database.add_reference(ref) + res = database.delete_reference(ref) + self.assertIsNone(ref.database, database) self.assertIs(res, ref) - self.assertNotIn(ref, schema.refs) + self.assertNotIn(ref, database.refs) def test_delete_missing_reference(self) -> None: c = Column('test', 'varchar', True) t = Table('test_table') t.add_column(c) - schema = Schema() - schema.add_table(t) + database = Database() + database.add_table(t) c2 = Column('test2', 'integer') t2 = Table('test_table2') t2.add_column(c2) - schema.add_table(t2) + database.add_table(t2) ref = Reference('>', c, c2) - with self.assertRaises(SchemaValidationError): - schema.delete_reference(ref) - self.assertIsNone(ref.schema) + with self.assertRaises(DatabaseValidationError): + database.delete_reference(ref) + self.assertIsNone(ref.database) def test_add_enum(self) -> None: e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) - schema = Schema() - res = schema.add_enum(e) - self.assertEqual(e.schema, schema) + database = Database() + res = database.add_enum(e) + self.assertEqual(e.database, database) self.assertIs(res, e) - self.assertIn(e, schema.enums) + self.assertIn(e, database.enums) def test_add_enum_bad(self) -> None: e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) - schema = Schema() - schema.add_enum(e) - with self.assertRaises(SchemaValidationError): - schema.add_enum(e) + database = Database() + database.add_enum(e) + with self.assertRaises(DatabaseValidationError): + database.add_enum(e) e2 = Enum('myenum', [EnumItem('a2'), EnumItem('b2')]) - with self.assertRaises(SchemaValidationError): - schema.add_enum(e2) + with self.assertRaises(DatabaseValidationError): + database.add_enum(e2) def test_delete_enum(self) -> None: e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) - schema = Schema() - schema.add_enum(e) - res = schema.delete_enum(e) - self.assertIsNone(e.schema) + database = Database() + database.add_enum(e) + res = database.delete_enum(e) + self.assertIsNone(e.database) self.assertIs(res, e) - self.assertNotIn(e, schema.enums) + self.assertNotIn(e, database.enums) def test_delete_missing_enum(self) -> None: e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) - schema = Schema() - with self.assertRaises(SchemaValidationError): - schema.delete_enum(e) - self.assertIsNone(e.schema) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_enum(e) + self.assertIsNone(e.database) def test_add_table_group(self) -> None: t1 = Table('table1') t2 = Table('table2') tg = TableGroup('mytablegroup', [t1, t2]) - schema = Schema() - res = schema.add_table_group(tg) - self.assertEqual(tg.schema, schema) + database = Database() + res = database.add_table_group(tg) + self.assertEqual(tg.database, database) self.assertIs(res, tg) - self.assertIn(tg, schema.table_groups) + self.assertIn(tg, database.table_groups) def test_add_table_group_bad(self) -> None: t1 = Table('table1') t2 = Table('table2') tg = TableGroup('mytablegroup', [t1, t2]) - schema = Schema() - schema.add_table_group(tg) - with self.assertRaises(SchemaValidationError): - schema.add_table_group(tg) + database = Database() + database.add_table_group(tg) + with self.assertRaises(DatabaseValidationError): + database.add_table_group(tg) tg2 = TableGroup('mytablegroup', [t2]) - with self.assertRaises(SchemaValidationError): - schema.add_table_group(tg2) + with self.assertRaises(DatabaseValidationError): + database.add_table_group(tg2) def test_delete_table_group(self) -> None: t1 = Table('table1') t2 = Table('table2') tg = TableGroup('mytablegroup', [t1, t2]) - schema = Schema() - schema.add_table_group(tg) - res = schema.delete_table_group(tg) - self.assertIsNone(tg.schema) + database = Database() + database.add_table_group(tg) + res = database.delete_table_group(tg) + self.assertIsNone(tg.database) self.assertIs(res, tg) - self.assertNotIn(tg, schema.table_groups) + self.assertNotIn(tg, database.table_groups) def test_delete_missing_table_group(self) -> None: t1 = Table('table1') t2 = Table('table2') tg = TableGroup('mytablegroup', [t1, t2]) - schema = Schema() - with self.assertRaises(SchemaValidationError): - schema.delete_table_group(tg) - self.assertIsNone(tg.schema) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_table_group(tg) + self.assertIsNone(tg.database) def test_add_project(self) -> None: p = Project('myproject') - schema = Schema() - res = schema.add_project(p) - self.assertEqual(p.schema, schema) + database = Database() + res = database.add_project(p) + self.assertEqual(p.database, database) self.assertIs(res, p) - self.assertIs(schema.project, p) + self.assertIs(database.project, p) def test_add_another_project(self) -> None: p = Project('myproject') - schema = Schema() - schema.add_project(p) + database = Database() + database.add_project(p) p2 = Project('anotherproject') - res = schema.add_project(p2) - self.assertEqual(p2.schema, schema) + res = database.add_project(p2) + self.assertEqual(p2.database, database) self.assertIs(res, p2) - self.assertIs(schema.project, p2) - self.assertIsNone(p.schema) + self.assertIs(database.project, p2) + self.assertIsNone(p.database) def test_delete_project(self) -> None: p = Project('myproject') - schema = Schema() - schema.add_project(p) - res = schema.delete_project() - self.assertIsNone(p.schema, schema) + database = Database() + database.add_project(p) + res = database.delete_project() + self.assertIsNone(p.database, database) self.assertIs(res, p) - self.assertIsNone(schema.project) + self.assertIsNone(database.project) def test_delete_missing_project(self) -> None: - schema = Schema() - with self.assertRaises(SchemaValidationError): - schema.delete_project() + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_project() def test_geititem(self) -> None: t1 = Table('table1') t2 = Table('table2') - schema = Schema() - schema.add_table(t1) - schema.add_table(t2) - self.assertIs(schema['table1'], t1) - self.assertIs(schema['table2'], t2) - self.assertIs(schema[0], t1) - self.assertIs(schema[1], t2) + database = Database() + database.add_table(t1) + database.add_table(t2) + self.assertIs(database['table1'], t1) + self.assertIs(database['table2'], t2) + self.assertIs(database[0], t1) + self.assertIs(database[1], t2) with self.assertRaises(TypeError): - schema[None] + database[None] with self.assertRaises(IndexError): - schema[2] + database[2] with self.assertRaises(KeyError): - schema['wrong'] + database['wrong'] def test_iter(self) -> None: t1 = Table('table1') t2 = Table('table2') - schema = Schema() - schema.add_table(t1) - schema.add_table(t2) - self.assertEqual(list(iter(schema)), [t1, t2]) + database = Database() + database.add_table(t1) + database.add_table(t2) + self.assertEqual(list(iter(database)), [t1, t2]) def test_add(self) -> None: t1 = Table('table1') t2 = Table('table2') tg = TableGroup('mytablegroup', [t1, t2]) e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) - schema = Schema() - schema.add(t1) - schema.add(t2) - schema.add(e) - schema.add(tg) - self.assertIs(t1.schema, schema) - self.assertIs(t2.schema, schema) - self.assertIs(e.schema, schema) - self.assertIs(tg.schema, schema) - self.assertIn(t1, schema.tables) - self.assertIn(t2, schema.tables) - self.assertIn(tg, schema.table_groups) - self.assertIn(e, schema.enums) + database = Database() + database.add(t1) + database.add(t2) + database.add(e) + database.add(tg) + self.assertIs(t1.database, database) + self.assertIs(t2.database, database) + self.assertIs(e.database, database) + self.assertIs(tg.database, database) + self.assertIn(t1, database.tables) + self.assertIn(t2, database.tables) + self.assertIn(tg, database.table_groups) + self.assertIn(e, database.enums) def test_add_bad(self) -> None: class Test: pass t = Test() - schema = Schema() - with self.assertRaises(SchemaValidationError): - schema.add(t) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.add(t) with self.assertRaises(AttributeError): - t.schema + t.database def test_delete(self) -> None: t1 = Table('table1') @@ -330,39 +330,39 @@ def test_delete(self) -> None: tg = TableGroup('mytablegroup', [t1, t2]) e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) p = Project('myproject') - schema = Schema() - schema.add(t1) - schema.add(t2) - schema.add(e) - schema.add(tg) - schema.add(ref) - schema.add(p) - - schema.delete(t1) - schema.delete(t2) - schema.delete(e) - schema.delete(tg) - schema.delete(ref) - schema.delete(p) - self.assertIsNone(t1.schema) - self.assertIsNone(t2.schema) - self.assertIsNone(e.schema) - self.assertIsNone(tg.schema) - self.assertIsNone(ref.schema) - self.assertIsNone(p.schema) - self.assertIsNone(schema.project) - self.assertNotIn(t1, schema.tables) - self.assertNotIn(t2, schema.tables) - self.assertNotIn(tg, schema.table_groups) - self.assertNotIn(e, schema.enums) - self.assertNotIn(ref, schema.refs) + database = Database() + database.add(t1) + database.add(t2) + database.add(e) + database.add(tg) + database.add(ref) + database.add(p) + + database.delete(t1) + database.delete(t2) + database.delete(e) + database.delete(tg) + database.delete(ref) + database.delete(p) + self.assertIsNone(t1.database) + self.assertIsNone(t2.database) + self.assertIsNone(e.database) + self.assertIsNone(tg.database) + self.assertIsNone(ref.database) + self.assertIsNone(p.database) + self.assertIsNone(database.project) + self.assertNotIn(t1, database.tables) + self.assertNotIn(t2, database.tables) + self.assertNotIn(tg, database.table_groups) + self.assertNotIn(e, database.enums) + self.assertNotIn(ref, database.refs) def test_delete_bad(self) -> None: class Test: pass t = Test() - schema = Schema() - with self.assertRaises(SchemaValidationError): - schema.delete(t) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete(t) with self.assertRaises(AttributeError): - t.schema + t.database From 76c45399276408f1f41f4c492814ee8e057a6f6b Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 19 May 2022 08:45:35 +0200 Subject: [PATCH 030/125] multiple schemas for tables and enums --- TODO.md | 7 +- pydbml/classes/enum.py | 12 ++- pydbml/classes/reference.py | 13 ++-- pydbml/classes/table.py | 31 ++++++-- pydbml/classes/table_group.py | 7 +- pydbml/database.py | 8 +- pydbml/definitions/column.py | 30 +++++--- pydbml/definitions/enum.py | 7 +- pydbml/definitions/reference.py | 88 ++++++++++++++++++---- pydbml/definitions/table.py | 4 +- pydbml/definitions/table_group.py | 4 +- pydbml/parser/blueprints.py | 23 +++++- pydbml/parser/parser.py | 13 ++-- test/test_blueprints/test_column.py | 22 ++++++ test/test_classes/test_enum.py | 28 +++++++ test/test_classes/test_reference.py | 92 ++++++++++++++++++++++- test/test_classes/test_table.py | 44 +++++++++++ test/test_classes/test_table_group.py | 41 +++++++--- test/test_data/integration1.dbml | 4 +- test/{test_schema.py => test_database.py} | 8 +- test/test_definitions/test_column.py | 12 ++- test/test_definitions/test_enum.py | 10 ++- test/test_definitions/test_reference.py | 28 +++++++ test/test_definitions/test_table.py | 12 +++ test/test_docs.py | 6 +- test/test_editing.py | 21 ++---- test/test_parser.py | 16 ++-- 27 files changed, 474 insertions(+), 117 deletions(-) rename test/{test_schema.py => test_database.py} (98%) diff --git a/TODO.md b/TODO.md index 57ada15..f2805f9 100644 --- a/TODO.md +++ b/TODO.md @@ -4,6 +4,7 @@ - `_type` -> `type` * - expression class - schema.add and .delete to support multiple arguments (handle errors properly) -- 2.3.1 Multiline comment /* ... */ -- 2.4 Multiple Schemas -- validation on "add_index", "add_table" etc \ No newline at end of file +* - 2.3.1 Multiline comment /* ... */ +* - 2.4 Multiple Schemas +- validation on "add_index", "add_table" etc +* - enum type in table definition with schema \ No newline at end of file diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index ac84c2d..6c3ab5f 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -58,9 +58,11 @@ class Enum(SQLOjbect): def __init__(self, name: str, items: List['EnumItem'], + schema: str = 'public', comment: Optional[str] = None): self.database = None self.name = name + self.schema = schema self.items = items self.comment = comment @@ -92,6 +94,12 @@ def __str__(self): return self.name + def _get_full_name_for_sql(self) -> str: + if self.schema == 'public': + return f'"{self.name}"' + else: + return f'"{self.schema}"."{self.name}"' + @property def sql(self): ''' @@ -107,7 +115,7 @@ def sql(self): ''' self.check_attributes_for_sql() result = comment_to_sql(self.comment) if self.comment else '' - result += f'CREATE TYPE "{self.name}" AS ENUM (\n' + result += f'CREATE TYPE {self._get_full_name_for_sql()} AS ENUM (\n' result += '\n'.join(f'{indent(i.sql, 2)}' for i in self.items) result += '\n);' return result @@ -115,7 +123,7 @@ def sql(self): @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Enum "{self.name}" {{\n' + result += f'Enum {self._get_full_name_for_sql()} {{\n' items_str = '\n'.join(i.dbml for i in self.items) result += indent(items_str) result += '\n}' diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 4ed997e..ac8dd98 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -112,7 +112,7 @@ def sql(self): result = comment_to_sql(self.comment) if self.comment else '' result += ( f'{c}FOREIGN KEY ("{cols}") ' - f'REFERENCES "{ref_table.name}" ("{ref_cols}")' + f'REFERENCES {ref_table._get_full_name_for_sql()} ("{ref_cols}")' ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' @@ -133,8 +133,8 @@ def sql(self): result = comment_to_sql(self.comment) if self.comment else '' result += ( - f'ALTER TABLE "{t1.name}" ADD {c}FOREIGN KEY ({c1}) ' - f'REFERENCES "{t2.name}" ({c2})' + f'ALTER TABLE {t1._get_full_name_for_sql()} ADD {c}FOREIGN KEY ({c1}) ' + f'REFERENCES {t2._get_full_name_for_sql()} ({c2})' ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' @@ -149,7 +149,8 @@ def dbml(self): # settings are ignored for inline ref if len(self.col2) > 1: raise DBMLError('Cannot render DBML: composite ref cannot be inline') - return f'ref: {self.type} "{self.col2[0].table.name}"."{self.col2[0].name}"' + table_name = self.col2[0].table._get_full_name_for_sql() + return f'ref: {self.type} {table_name}."{self.col2[0].name}"' else: result = comment_to_dbml(self.comment) if self.comment else '' result += 'Ref' @@ -177,9 +178,9 @@ def dbml(self): options_str = f' [{", ".join(options)}]' if options else '' result += ( ' {\n ' - f'"{self.col1[0].table.name}".{col1} ' + f'{self.col1[0].table._get_full_name_for_sql()}.{col1} ' f'{self.type} ' - f'"{self.col2[0].table.name}".{col2}' + f'{self.col2[0].table._get_full_name_for_sql()}.{col2}' f'{options_str}' '\n}' ) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index c5847f6..f79b72a 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -30,21 +30,25 @@ class Table(SQLOjbect): def __init__(self, name: str, + schema: str = 'public', alias: Optional[str] = None, note: Optional[Union['Note', str]] = None, header_color: Optional[str] = None, - # refs: Optional[List[TableReference]] = None, comment: Optional[str] = None): self.database: Optional[Database] = None self.name = name + self.schema = schema self.columns: List[Column] = [] self.indexes: List[Index] = [] self.alias = alias if alias else None self.note = Note(note) self.header_color = header_color - # self.refs = refs or [] self.comment = comment + @property + def full_name(self) -> str: + return f'{self.schema}.{self.name}' + def add_column(self, c: Column) -> None: ''' Adds column to self.columns attribute and sets in this column the @@ -105,6 +109,12 @@ def _get_references_for_sql(self) -> List[Reference]: result.append(ref) return result + def _get_full_name_for_sql(self) -> str: + if self.schema == 'public': + return f'"{self.name}"' + else: + return f'"{self.schema}"."{self.name}"' + def __getitem__(self, k: Union[int, str]) -> Column: if isinstance(k, int): return self.columns[k] @@ -129,10 +139,10 @@ def __repr__(self): ''' >>> table = Table('customers') >>> table -
+
''' - return f'
' + return f'
' def __str__(self): ''' @@ -140,10 +150,10 @@ def __str__(self): >>> table.add_column(Column('id', 'INTEGER')) >>> table.add_column(Column('name', 'VARCHAR2')) >>> print(table) - customers(id, name) + public.customers(id, name) ''' - return f'{self.name}({", ".join(c.name for c in self.columns)})' + return f'{self.schema}.{self.name}({", ".join(c.name for c in self.columns)})' @property def sql(self): @@ -161,7 +171,9 @@ def sql(self): CREATE INDEX ON "products" ("id", "name"); ''' self.check_attributes_for_sql() - components = [f'CREATE TABLE "{self.name}" ('] + name = self._get_full_name_for_sql() + components = [f'CREATE TABLE {name} ('] + body = [] body.extend(indent(c.sql, 2) for c in self.columns) body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) @@ -188,7 +200,10 @@ def sql(self): @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Table "{self.name}" ' + + name = self._get_full_name_for_sql() + + result += f'Table {name} ' if self.alias: result += f'as "{self.alias}" ' result += '{\n' diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index dbba53a..06fe718 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -45,15 +45,10 @@ def __iter__(self): @property def dbml(self): - def item_to_str(val: Union[str, Table]) -> str: - if isinstance(val, Table): - return val.name - else: - return val result = comment_to_dbml(self.comment) if self.comment else '' result += f'TableGroup {self.name} {{\n' for i in self.items: - result += f' {item_to_str(i)}\n' + result += f' {i._get_full_name_for_sql()}\n' result += '}' return result diff --git a/pydbml/database.py b/pydbml/database.py index ab7b4e1..3f3b922 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -63,15 +63,15 @@ def add(self, obj: Any) -> Any: def add_table(self, obj: Table) -> Table: if obj in self.tables: raise DatabaseValidationError(f'{obj} is already in the database.') - if obj.name in self.table_dict: - raise DatabaseValidationError(f'Table {obj.name} is already in the database.') + if obj.full_name in self.table_dict: + raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.') if obj.alias and obj.alias in self.table_dict: raise DatabaseValidationError(f'Table {obj.alias} is already in the database.') self._set_database(obj) self.tables.append(obj) - self.table_dict[obj.name] = obj + self.table_dict[obj.full_name] = obj if obj.alias: self.table_dict[obj.alias] = obj return obj @@ -141,7 +141,7 @@ def delete_table(self, obj: Table) -> Table: except ValueError: raise DatabaseValidationError(f'{obj} is not in the database.') self._unset_database(self.tables.pop(index)) - result = self.table_dict.pop(obj.name) + result = self.table_dict.pop(obj.full_name) if obj.alias: self.table_dict.pop(obj.alias) return result diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index 144274f..d80bf98 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -20,22 +20,28 @@ pp.ParserElement.setDefaultWhitespaceChars(' \t\r') -type_args = ("(" + pp.originalTextFor(expression)('args') + ")") -type_name = (pp.Word(pp.alphanums + '_') | pp.QuotedString('"'))('name') -column_type = (type_name + type_args[0, 1]) +type_args = ("(" + pp.originalTextFor(expression) + ")") +# column type is parsed as a single string, it will be split by blueprint +column_type = pp.Combine((name + '.' + name) | ((name) + type_args[0, 1])) -def parse_column_type(s, l, t) -> str: - ''' - int or "mytype" or varchar(255) - ''' - result = t['name'] - args = t.get('args') - result += '(' + args + ')' if args else '' - return result + +# def parse_column_type(s, l, t) -> str: +# ''' +# int or "mytype" or varchar(255) or +# ''' +# result = {} +# if '.' in t['name']: +# result['schema'], result['name'] = t['name'].split('.') +# else: +# result['name'] = t['name'] + +# if 'args' in t: +# result['args'] = f'({t["args"]})' +# return result -column_type.setParseAction(parse_column_type) +# column_type.setParseAction(parse_column_type) default = pp.CaselessLiteral('default:').suppress() + _ - ( diff --git a/pydbml/definitions/enum.py b/pydbml/definitions/enum.py index b6fd993..83a85db 100644 --- a/pydbml/definitions/enum.py +++ b/pydbml/definitions/enum.py @@ -54,9 +54,11 @@ def parse_enum_item(s, l, t): enum_body = enum_item[1, ...] +enum_name = pp.Combine(name("schema") + '.' + name("name")) | name("name") + enum = _c + ( pp.CaselessLiteral('enum') - - name('name') + _ + - enum_name + _ - '{' + enum_body('items') + n - '}' @@ -77,6 +79,9 @@ def parse_enum(s, l, t): 'items': list(t['items']) } + if 'schema' in t: + init_dict['schema'] = t['schema'] + if 'comment_before' in t: comment = '\n'.join(c[0] for c in t['comment_before']) init_dict['comment'] = comment diff --git a/pydbml/definitions/reference.py b/pydbml/definitions/reference.py index 9a511c1..a428b67 100644 --- a/pydbml/definitions/reference.py +++ b/pydbml/definitions/reference.py @@ -11,17 +11,33 @@ pp.ParserElement.setDefaultWhitespaceChars(' \t\r') relation = pp.oneOf("> - <") -ref_inline = pp.Literal("ref:") - relation('type') - name('table') - '.' - name('field') + +col_name = ( + ( + name('schema') + '.' + name('table') + '.' - name('field') + ) | ( + name('table') + '.' + name('field') + ) +) + +ref_inline = pp.Literal("ref:") - relation('type') - col_name def parse_inline_relation(s, l, t): ''' ref: < table.column + or + ref: < schema1.table.column ''' - return ReferenceBlueprint(type=t['type'], - inline=True, - table2=t['table'], - col2=t['field']) + result = { + 'type': t['type'], + 'inline': True, + 'table2': t['table'], + 'col2': t['field'] + } + if 'schema' in t: + result['schema2'] = t['schema'] + return ReferenceBlueprint(**result) ref_inline.setParseAction(parse_inline_relation) @@ -77,16 +93,53 @@ def parse_ref_settings(s, l, t): ) name_or_composite = name | pp.Combine(composite_name) +ref_cols = ( + ( + name('schema') + + pp.Suppress('.') + name('table') + + pp.Suppress('.') + name_or_composite('field') + ) | ( + name('table') + + pp.Suppress('.') + name_or_composite('field') + ) +) + + +def parse_ref_cols(s, l, t): + ''' + table1.col1 + or + schema1.table1.col1 + or + schema1.table1.(col1, col2) + ''' + result = { + 'table': t['table'], + 'field': t['field'], + } + if 'schema' in t: + result['schema'] = t['schema'] + return result + + +ref_cols.setParseAction(parse_ref_cols) + ref_body = ( - name('table1') - - '.' - - name_or_composite('field1') + ref_cols('col1') - relation('type') - - name('table2') - - '.' - - name_or_composite('field2') + c + - ref_cols('col2') + c + ref_settings('settings')[0, 1] ) +# ref_body = ( +# table_name('table1') +# - '.' +# - name_or_composite('field1') +# - relation('type') +# - table_name('table2') +# - '.' +# - name_or_composite('field2') + c +# + ref_settings('settings')[0, 1] +# ) ref_short = _c + pp.CaselessLiteral('ref') + name('name')[0, 1] + ':' - ref_body @@ -110,11 +163,16 @@ def parse_ref(s, l, t): init_dict = { 'type': t['type'], 'inline': False, - 'table1': t['table1'], - 'col1': t['field1'], - 'table2': t['table2'], - 'col2': t['field2'] + 'table1': t['col1']['table'], + 'col1': t['col1']['field'], + 'table2': t['col2']['table'], + 'col2': t['col2']['field'], } + + if 'schema' in t['col1']: + init_dict['schema1'] = t['col1']['schema'] + if 'schema' in t['col2']: + init_dict['schema2'] = t['col2']['schema'] if 'name' in t: init_dict['name'] = t['name'] if 'settings' in t: diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index c5377a0..5c21540 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -74,8 +74,8 @@ def parse_table(s, l, t): init_dict = { 'name': t['name'], } - # if 'schema' in t: - # init_dict['schema'] = t['schema'] + if 'schema' in t: + init_dict['schema'] = t['schema'] if 'settings' in t: init_dict.update(t['settings']) if 'alias' in t: diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index 3929ab5..422067b 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -9,11 +9,13 @@ pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +table_name = pp.Combine(name + '.' + name) | name + table_group = _c + ( pp.CaselessLiteral('TableGroup') - name('name') + _ - '{' + _ - - (name + _)[...]('items') + _ + - (table_name + _)[...]('items') + _ - '}' ) + end diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index bc34ed7..da87685 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -46,8 +46,10 @@ class ReferenceBlueprint(Blueprint): type: Literal['>', '<', '-'] inline: bool name: Optional[str] = None + schema1: str = 'public' table1: Optional[str] = None col1: Optional[Union[str, Collection[str]]] = None + schema2: str = 'public' table2: Optional[str] = None col2: Optional[Union[str, Collection[str]]] = None comment: Optional[str] = None @@ -68,14 +70,14 @@ def build(self) -> 'Reference': raise ColumnNotFoundError("Can't build Reference, col2 unknown") if self.parser: - table1 = self.parser.locate_table(self.table1) + table1 = self.parser.locate_table(self.schema1, self.table1) else: raise RuntimeError('Parser is not set') col1_list = [c.strip('() ') for c in self.col1.split(',')] col1 = [table1[col] for col in col1_list] - table2 = self.parser.locate_table(self.table2) + table2 = self.parser.locate_table(self.schema2, self.table2) col2_list = [c.strip('() ') for c in self.col2.split(',')] col2 = [table2[col] for col in col2_list] @@ -108,8 +110,12 @@ def build(self) -> 'Column': if isinstance(self.default, ExpressionBlueprint): self.default = self.default.build() if self.parser: + if '.' in self.type: + schema, name = self.type.split('.') + else: + schema, name = 'public', self.type for enum in self.parser.database.enums: - if enum.name == self.type: + if (enum.schema, enum.name) == (schema, name): self.type = enum break return Column( @@ -153,6 +159,7 @@ def build(self) -> 'Index': @dataclass class TableBlueprint(Blueprint): name: str + schema: str = 'public' columns: List[ColumnBlueprint] = None indexes: Optional[List[IndexBlueprint]] = None alias: Optional[str] = None @@ -163,6 +170,7 @@ class TableBlueprint(Blueprint): def build(self) -> 'Table': result = Table( name=self.name, + schema=self.schema, alias=self.alias, note=self.note.build() if self.note else None, header_color=self.header_color, @@ -221,12 +229,14 @@ def build(self) -> 'EnumItem': class EnumBlueprint(Blueprint): name: str items: List[EnumItemBlueprint] + schema: str = 'public' comment: Optional[str] = None def build(self) -> 'Enum': return Enum( name=self.name, items=[ei.build() for ei in self.items], + schema=self.schema, comment=self.comment ) @@ -256,8 +266,13 @@ class TableGroupBlueprint(Blueprint): def build(self) -> 'TableGroup': if not self.parser: raise RuntimeError('Parser is not set') + items = [] + for table_name in self.items: + components = table_name.split('.') + schema, table = components if len(components) == 2 else 'public', components[0] + items.append(self.parser.locate_table(schema, table)) return TableGroup( name=self.name, - items=[self.parser.locate_table(table) for table in self.items], + items=items, comment=self.comment ) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index c4f18be..26ab2d0 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -165,13 +165,16 @@ def parse_blueprint(self, s, l, t): raise RuntimeError(f'type unknown: {blueprint}') blueprint.parser = self - def locate_table(self, name: str) -> 'Table': + def locate_table(self, schema: str, name: str) -> 'Table': if not self.database: raise RuntimeError('Database is not ready') - try: - result = self.database[name] - except KeyError: - raise TableNotFoundError(f'Table {name} not present in the database') + # first by alias + result = self.database.table_dict.get(name) + if result is None: + full_name = f'{schema}.{name}' + result = self.database.table_dict.get(full_name) + if result is None: + raise TableNotFoundError(f'Table {full_name} not present in the database') return result def build_database(self): diff --git a/test/test_blueprints/test_column.py b/test/test_blueprints/test_column.py index 0f51a85..622bbd5 100644 --- a/test/test_blueprints/test_column.py +++ b/test/test_blueprints/test_column.py @@ -66,3 +66,25 @@ def test_enum_type(self) -> None: bp.parser = parser result = bp.build() self.assertIs(result.type, e) + + def test_enum_type_schema(self) -> None: + s = Database() + e = Enum( + 'myenum', + schema='myschema', + items=[ + EnumItem('i1'), + EnumItem('i2') + ] + ) + s.add(e) + parser = Mock() + parser.database = s + + bp = ColumnBlueprint( + name='testcol', + type='myschema.myenum' + ) + bp.parser = parser + result = bp.build() + self.assertIs(result.type, e) diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py index c49f51e..66c8361 100644 --- a/test/test_classes/test_enum.py +++ b/test/test_classes/test_enum.py @@ -40,6 +40,23 @@ def test_simple_enum(self) -> None: );''' self.assertEqual(e.sql, expected) + def test_schema(self) -> None: + items = [ + EnumItem('created'), + EnumItem('running'), + EnumItem('donef'), + EnumItem('failure'), + ] + e = Enum('job_status', items, schema="myschema") + expected = \ +'''CREATE TYPE "myschema"."job_status" AS ENUM ( + 'created', + 'running', + 'donef', + 'failure', +);''' + self.assertEqual(e.sql, expected) + def test_comments(self) -> None: items = [ EnumItem('created', comment='EnumItem comment'), @@ -72,6 +89,17 @@ def test_dbml_simple(self): }''' self.assertEqual(e.dbml, expected) + def test_dbml_schema(self): + items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] + e = Enum('lang', items, schema="myschema") + expected = \ +'''Enum "myschema"."lang" { + "en-US" + "ru-RU" + "en-GB" +}''' + self.assertEqual(e.dbml, expected) + def test_dbml_full(self): items = [ EnumItem('en-US', note='preferred'), diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 511369a..4193899 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -19,6 +19,18 @@ def test_sql_single(self): expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' self.assertEqual(ref.sql, expected) + def test_sql_schema_single(self): + t = Table('products', schema='myschema1') + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names', schema='myschema2') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', c1, c2) + + expected = 'ALTER TABLE "myschema1"."products" ADD FOREIGN KEY ("name") REFERENCES "myschema2"."names" ("name_val");' + self.assertEqual(ref.sql, expected) + def test_sql_reverse(self): t = Table('products') c1 = Column('name', 'varchar2') @@ -47,6 +59,22 @@ def test_sql_multiple(self): expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val");' self.assertEqual(ref.sql, expected) + def test_sql_schema_multiple(self): + t = Table('products', schema="myschema1") + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names', schema="myschema2") + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<', [c11, c12], (c21, c22)) + + expected = 'ALTER TABLE "myschema2"."names" ADD FOREIGN KEY ("name_val", "country_val") REFERENCES "myschema1"."products" ("name", "country");' + self.assertEqual(ref.sql, expected) + def test_sql_full(self): t = Table('products') c11 = Column('name', 'varchar2') @@ -92,6 +120,23 @@ def test_dbml_simple(self): }''' self.assertEqual(ref.dbml, expected) + def test_dbml_schema(self): + t = Table('products', schema="myschema1") + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names', schema="myschema2") + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference('>', c2, c21) + + expected = \ +'''Ref { + "myschema1"."products"."name" > "myschema2"."names"."name_val" +}''' + self.assertEqual(ref.dbml, expected) + def test_dbml_full(self): t = Table('products') c1 = Column('id', 'integer') @@ -100,7 +145,7 @@ def test_dbml_full(self): t.add_column(c1) t.add_column(c2) t.add_column(c3) - t2 = Table('names') + t2 = Table('names', schema="myschema") c21 = Column('name_val', 'varchar2') c22 = Column('country', 'varchar2') t2.add_column(c21) @@ -119,7 +164,7 @@ def test_dbml_full(self): '''// Reference comment // multiline Ref nameref { - "products".("name", "country") < "names".("name_val", "country") [update: CASCADE, delete: SET NULL] + "products".("name", "country") < "myschema"."names".("name_val", "country") [update: CASCADE, delete: SET NULL] }''' self.assertEqual(ref.dbml, expected) @@ -137,6 +182,18 @@ def test_sql_single(self): expected = 'FOREIGN KEY ("name") REFERENCES "names" ("name_val")' self.assertEqual(ref.sql, expected) + def test_sql_schema_single(self): + t = Table('products', schema="myschema1") + c1 = Column('name', 'varchar2') + t.add_column(c1) + t2 = Table('names', schema="myschema2") + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', c1, c2, inline=True) + + expected = 'FOREIGN KEY ("name") REFERENCES "myschema2"."names" ("name_val")' + self.assertEqual(ref.sql, expected) + def test_sql_reverse(self): t = Table('products') c1 = Column('name', 'varchar2') @@ -165,6 +222,22 @@ def test_sql_multiple(self): expected = 'FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val")' self.assertEqual(ref.sql, expected) + def test_sql_schema_multiple(self): + t = Table('products', schema="myschema1") + c11 = Column('name', 'varchar2') + c12 = Column('country', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names', schema="myschema2") + c21 = Column('name_val', 'varchar2') + c22 = Column('country_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<', [c11, c12], (c21, c22), inline=True) + + expected = 'FOREIGN KEY ("name_val", "country_val") REFERENCES "myschema1"."products" ("name", "country")' + self.assertEqual(ref.sql, expected) + def test_sql_full(self): t = Table('products') c11 = Column('name', 'varchar2') @@ -208,6 +281,20 @@ def test_dbml_simple(self): expected = 'ref: > "names"."name_val"' self.assertEqual(ref.dbml, expected) + def test_dbml_schema(self): + t = Table('products', schema="myschema1") + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t2 = Table('names', schema="myschema2") + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference('>', c2, c21, inline=True) + + expected = 'ref: > "myschema2"."names"."name_val"' + self.assertEqual(ref.dbml, expected) + def test_dbml_settings_ignored(self): t = Table('products') c1 = Column('id', 'integer') @@ -266,7 +353,6 @@ def test_validate_different_tables(self): t1.add_column(c12) t2 = Table('names') c21 = Column('name_val', 'varchar2') - c22 = Column('product', 'varchar2') t2.add_column(c21) ref = Reference( '<', diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index ea5c0ca..1f77a97 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -13,6 +13,12 @@ class TestTable(TestCase): + def test_schema(self) -> None: + t = Table('test') + self.assertEqual(t.schema, 'public') + t2 = Table('test', 'schema1') + self.assertEqual(t2.schema, 'schema1') + def test_one_column(self) -> None: t = Table('products') c = Column('id', 'integer') @@ -167,6 +173,28 @@ def test_index_inline(self) -> None: );''' self.assertEqual(t.sql, expected) + def test_schema_sql(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + s = Database() + s.add(t) + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar2 +);''' + self.assertEqual(t.sql, expected) + t.schema = 'myschema' + expected = \ +'''CREATE TABLE "myschema"."products" ( + "id" integer, + "name" varchar2 +);''' + self.assertEqual(t.sql, expected) + def test_index_inline_and_comments(self) -> None: t = Table('products', comment='Multiline\ntable comment') c1 = Column('id', 'integer') @@ -317,6 +345,22 @@ def test_dbml_simple(self): }''' self.assertEqual(t.dbml, expected) + def test_schema_dbml(self): + t = Table('products', schema="myschema") + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + s = Database() + s.add(t) + + expected = \ +'''Table "myschema"."products" { + "id" integer + "name" varchar2 +}''' + self.assertEqual(t.dbml, expected) + def test_dbml_reference(self): t = Table('products') c1 = Column('id', 'integer') diff --git a/test/test_classes/test_table_group.py b/test/test_classes/test_table_group.py index 75d05a1..d7c8a9e 100644 --- a/test/test_classes/test_table_group.py +++ b/test/test_classes/test_table_group.py @@ -5,15 +5,16 @@ class TestTableGroup(TestCase): - def test_dbml(self): - tg = TableGroup('mytg', ['merchants', 'countries', 'customers']) - expected = \ -'''TableGroup mytg { - merchants - countries - customers -}''' - self.assertEqual(tg.dbml, expected) +# string items no longer supported +# def test_dbml(self): +# tg = TableGroup('mytg', ['merchants', 'countries', 'customers']) +# expected = \ +# '''TableGroup mytg { +# merchants +# countries +# customers +# }''' +# self.assertEqual(tg.dbml, expected) def test_dbml_with_comment_and_real_tables(self): merchants = Table('merchants') @@ -28,9 +29,25 @@ def test_dbml_with_comment_and_real_tables(self): '''// My table group // multiline comment TableGroup mytg { - merchants - countries - customers + "merchants" + "countries" + "customers" +}''' + self.assertEqual(tg.dbml, expected) + + def test_dbml_schema(self): + merchants = Table('merchants', schema="myschema1") + countries = Table('countries', schema="myschema2") + customers = Table('customers', schema="myschema3") + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + ) + expected = \ +'''TableGroup mytg { + "myschema1"."merchants" + "myschema2"."countries" + "myschema3"."customers" }''' self.assertEqual(tg.dbml, expected) diff --git a/test/test_data/integration1.dbml b/test/test_data/integration1.dbml index e0a99f9..9c59723 100644 --- a/test/test_data/integration1.dbml +++ b/test/test_data/integration1.dbml @@ -39,6 +39,6 @@ Ref { } TableGroup Unanimate { - books - countries + "books" + "countries" } \ No newline at end of file diff --git a/test/test_schema.py b/test/test_database.py similarity index 98% rename from test/test_schema.py rename to test/test_database.py index d8c3210..86c6dc9 100644 --- a/test/test_schema.py +++ b/test/test_database.py @@ -39,7 +39,7 @@ def test_add_table_alias(self) -> None: def test_add_table_alias_bad(self) -> None: c = Column('test', 'varchar', True) - t = Table('myalias') + t = Table('test', alias='myalias') t.add_column(c) database = Database() database.add_table(t) @@ -267,12 +267,12 @@ def test_delete_missing_project(self) -> None: def test_geititem(self) -> None: t1 = Table('table1') - t2 = Table('table2') + t2 = Table('table2', schema='myschema') database = Database() database.add_table(t1) database.add_table(t2) - self.assertIs(database['table1'], t1) - self.assertIs(database['table2'], t2) + self.assertIs(database['public.table1'], t1) + self.assertIs(database['myschema.table2'], t2) self.assertIs(database[0], t1) self.assertIs(database[1], t2) with self.assertRaises(TypeError): diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index 963be43..3ab6350 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -23,7 +23,12 @@ def test_simple(self) -> None: self.assertEqual(res[0], val) def test_quoted(self) -> None: - val = 'mytype' + val = '"mytype"' + res = column_type.parseString(val, parseAll=True) + self.assertEqual(res[0], 'mytype') + + def test_with_schema(self) -> None: + val = 'myschema.mytype' res = column_type.parseString(val, parseAll=True) self.assertEqual(res[0], val) @@ -190,6 +195,11 @@ def test_with_settings(self) -> None: self.assertTrue(res[0].not_null) self.assertTrue(res[0].note is not None) + def test_enum_type_bad(self) -> None: + val = "_test_ myschema.mytype(12) [unique]\n" + with self.assertRaises(ParseException): + table_column.parseString(val, parseAll=True) + def test_settings_and_constraints(self) -> None: val = "_test_ \"mytype\" unique pk [not null]\n" res = table_column.parseString(val, parseAll=True) diff --git a/test/test_definitions/test_enum.py b/test/test_definitions/test_enum.py index ab2a38b..f345929 100644 --- a/test/test_definitions/test_enum.py +++ b/test/test_definitions/test_enum.py @@ -11,7 +11,7 @@ ParserElement.setDefaultWhitespaceChars(' \t\r') -class TestTableSettings(TestCase): +class TestEnumSettings(TestCase): def test_note(self) -> None: val = '[note: "note content"]' enum_settings.parseString(val, parseAll=True) @@ -69,6 +69,14 @@ def test_several_items(self) -> None: self.assertEqual(len(res[0].items), 4) self.assertEqual(res[0].name, 'members') + def test_schema(self) -> None: + val1 = 'enum members {janitor teacher\nstudent\nheadmaster\n}' + res1 = enum.parseString(val1, parseAll=True) + self.assertEqual(res1[0].schema, 'public') + val2 = 'enum myschema.members {janitor teacher\nstudent\nheadmaster\n}' + res2 = enum.parseString(val2, parseAll=True) + self.assertEqual(res2[0].schema, 'myschema') + def test_comment(self) -> None: val = '//comment before\nenum members {janitor teacher\nstudent\nheadmaster\n}' res = enum.parseString(val, parseAll=True) diff --git a/test/test_definitions/test_reference.py b/test/test_definitions/test_reference.py index a1a5dc9..2d1d82d 100644 --- a/test/test_definitions/test_reference.py +++ b/test/test_definitions/test_reference.py @@ -37,6 +37,14 @@ def test_ok(self) -> None: self.assertIsNone(res[0].table1) self.assertIsNone(res[0].col1) + def test_schema(self) -> None: + val1 = 'ref: < table.column' + res1 = ref_inline.parseString(val1, parseAll=True) + self.assertEqual(res1[0].schema2, 'public') + val2 = 'ref: < myschema.table.column' + res2 = ref_inline.parseString(val2, parseAll=True) + self.assertEqual(res2[0].schema2, 'myschema') + def test_nok(self) -> None: vals = [ 'ref:\n< table.column', @@ -90,6 +98,16 @@ def test_no_name(self) -> None: self.assertEqual(res[0].table2, 'table2') self.assertEqual(res[0].col2, 'col2') + def test_schema(self) -> None: + val1 = 'ref: table1.col1 > table2.col2' + res1 = ref_short.parseString(val1, parseAll=True) + self.assertEqual(res1[0].schema1, 'public') + self.assertEqual(res1[0].schema2, 'public') + val2 = 'ref: myschema1.table1.col1 > myschema2.table2.col2' + res2 = ref_short.parseString(val2, parseAll=True) + self.assertEqual(res2[0].schema1, 'myschema1') + self.assertEqual(res2[0].schema2, 'myschema2') + def test_name(self) -> None: val = 'ref name: table1.col1 > table2.col2' res = ref_short.parseString(val, parseAll=True) @@ -196,6 +214,16 @@ def test_no_name(self) -> None: self.assertEqual(res[0].table2, 'table2') self.assertEqual(res[0].col2, 'col2') + def test_schema(self) -> None: + val1 = 'ref {table1.col1 > table2.col2}' + res1 = ref_long.parseString(val1, parseAll=True) + self.assertEqual(res1[0].schema1, 'public') + self.assertEqual(res1[0].schema2, 'public') + val2 = 'ref {myschema1.table1.col1 > myschema2.table2.col2}' + res2 = ref_long.parseString(val2, parseAll=True) + self.assertEqual(res2[0].schema1, 'myschema1') + self.assertEqual(res2[0].schema2, 'myschema2') + def test_name(self) -> None: val = 'ref\nname\n{\ntable1.col1 > table2.col2\n}' res = ref_long.parseString(val, parseAll=True) diff --git a/test/test_definitions/test_table.py b/test/test_definitions/test_table.py index 02983a6..8028292 100644 --- a/test/test_definitions/test_table.py +++ b/test/test_definitions/test_table.py @@ -135,6 +135,18 @@ def test_with_alias(self) -> None: self.assertEqual(res[0].alias, 'ii') self.assertEqual(len(res[0].columns), 1) + def test_schema(self) -> None: + val = 'table ids as ii {\nid integer\n}' + res = table.parseString(val, parseAll=True) + self.assertEqual(res[0].name, 'ids') + self.assertEqual(res[0].schema, 'public') # default + self.assertEqual(len(res[0].columns), 1) + + val = 'table myschema.ids as ii {\nid integer\n}' + res = table.parseString(val, parseAll=True) + self.assertEqual(res[0].name, 'ids') + self.assertEqual(res[0].schema, 'myschema') + def test_with_settings(self) -> None: val = 'table ids as ii [headercolor: #ccc, note: "headernote"] {\nid integer\n}' res = table.parseString(val, parseAll=True) diff --git a/test/test_docs.py b/test/test_docs.py index 66191cb..f186545 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -206,7 +206,7 @@ def test_relationship_settings(self) -> None: def test_note_definition(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'note_definition.dbml') self.assertEqual(len(results.tables), 1) - users = results['users'] + users = results['public.users'] self.assertEqual(users.note.text, 'This is a note of this table') def test_project_notes(self) -> None: @@ -218,7 +218,7 @@ def test_project_notes(self) -> None: def test_column_notes(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'column_notes.dbml') - users = results['users'] + users = results['public.users'] self.assertEqual(users.note.text, 'Stores user data') self.assertEqual(users['column_name'].note.text, 'replace text here') @@ -226,7 +226,7 @@ def test_column_notes(self) -> None: def test_enum_definition(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'enum_definition.dbml') - jobs = results['jobs'] + jobs = results['public.jobs'] jobs['status'].type == 'job_status' self.assertEqual(len(results.enums), 1) diff --git a/test/test_editing.py b/test/test_editing.py index 8b20866..68120b7 100644 --- a/test/test_editing.py +++ b/test/test_editing.py @@ -3,16 +3,9 @@ from pathlib import Path from unittest import TestCase -from pyparsing import ParseException -from pyparsing import ParseSyntaxException from pyparsing import ParserElement from pydbml import PyDBML -from pydbml.definitions.table import alias -from pydbml.definitions.table import header_color -from pydbml.definitions.table import table -from pydbml.definitions.table import table_body -from pydbml.definitions.table import table_settings ParserElement.setDefaultWhitespaceChars(' \t\r') @@ -28,7 +21,7 @@ def setUp(self): class TestEditTable(EditingTestCase): def test_name(self) -> None: - products = self.dbml['products'] + products = self.dbml['public.products'] products.name = 'changed_products' self.assertIn('CREATE TABLE "changed_products"', products.sql) self.assertIn('CREATE INDEX "product_status" ON "changed_products"', products.sql) @@ -42,7 +35,7 @@ def test_name(self) -> None: self.assertIn('ON "changed_products"', index.sql) def test_alias(self) -> None: - products = self.dbml['products'] + products = self.dbml['public.products'] products.alias = 'new_alias' self.assertIn('as "new_alias"', products.dbml) @@ -50,7 +43,7 @@ def test_alias(self) -> None: class TestColumn(EditingTestCase): def test_name(self) -> None: - products = self.dbml['products'] + products = self.dbml['public.products'] col = products['name'] col.name = 'new_name' self.assertEqual(col.sql, '"new_name" varchar') @@ -61,7 +54,7 @@ def test_name(self) -> None: self.assertEqual(col, products[col.name]) def test_name_index(self) -> None: - products = self.dbml['products'] + products = self.dbml['public.products'] col = products['status'] col.name = 'changed_status' self.assertIn('"changed_status"', products.indexes[0].sql) @@ -76,17 +69,17 @@ def test_name_index(self) -> None: ) def test_name_ref(self) -> None: - products = self.dbml['products'] + products = self.dbml['public.products'] col = products['merchant_id'] col.name = 'changed_merchant_id' - merchants = self.dbml['merchants'] + merchants = self.dbml['public.merchants'] table_ref = merchants.get_refs()[0] self.assertIn('FOREIGN KEY ("changed_merchant_id")', table_ref.sql) class TestEnum(EditingTestCase): def test_enum_name(self): - products = self.dbml['products'] + products = self.dbml['public.products'] enum = self.dbml.enums[0] enum.name = 'changed product status' self.assertIn('CREATE TYPE "changed product status"', enum.sql) diff --git a/test/test_parser.py b/test/test_parser.py index cf3f15a..f045b3f 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -18,15 +18,15 @@ def setUp(self): def test_table_refs(self) -> None: p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - r = p['orders'].get_refs() + r = p['public.orders'].get_refs() self.assertEqual(r[0].col2[0].name, 'order_id') self.assertEqual(r[0].col1[0].table.name, 'orders') self.assertEqual(r[0].col1[0].name, 'id') - r = p['products'].get_refs() + r = p['public.products'].get_refs() self.assertEqual(r[1].col1[0].name, 'merchant_id') self.assertEqual(r[1].col2[0].table.name, 'merchants') self.assertEqual(r[1].col2[0].name, 'id') - r = p['users'].get_refs() + r = p['public.users'].get_refs() self.assertEqual(r[0].col1[0].name, 'country_code') self.assertEqual(r[0].col2[0].table.name, 'countries') self.assertEqual(r[0].col2[0].name, 'code') @@ -51,8 +51,8 @@ def test_refs(self) -> None: class TestRefs(TestCase): def test_reference_aliases(self): results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') - posts, reviews, users = results['posts'], results['reviews'], results['users'] - posts2, reviews2, users2 = results['posts2'], results['reviews2'], results['users2'] + posts, reviews, users = results['public.posts'], results['public.reviews'], results['public.users'] + posts2, reviews2, users2 = results['public.posts2'], results['public.reviews2'], results['public.users2'] rs = results.refs self.assertEqual(rs[0].col1[0].table, users) @@ -68,8 +68,8 @@ def test_reference_aliases(self): def test_composite_references(self): results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_composite.dbml') self.assertEqual(len(results.tables), 4) - posts, reviews = results['posts'], results['reviews'] - posts2, reviews2 = results['posts2'], results['reviews2'] + posts, reviews = results['public.posts'], results['public.reviews'] + posts2, reviews2 = results['public.posts2'], results['public.reviews2'] rs = results.refs self.assertEqual(len(rs), 2) @@ -101,6 +101,6 @@ class TestPyDBMLParser(TestCase): def test_edge(self) -> None: p = PyDBMLParser('') with self.assertRaises(RuntimeError): - p.locate_table('test') + p.locate_table('myschema', 'test') with self.assertRaises(RuntimeError): p.parse_blueprint(1, 1, [1]) From cebf5f16293d37ef9211ab59e46ae10078793d76 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 19 May 2022 08:58:36 +0200 Subject: [PATCH 031/125] pythonic variable and attribute names --- pydbml/definitions/column.py | 91 +++++++++-------------- pydbml/definitions/common.py | 6 +- pydbml/definitions/enum.py | 48 ++++++------ pydbml/definitions/generic.py | 4 +- pydbml/definitions/index.py | 46 ++++++------ pydbml/definitions/project.py | 14 ++-- pydbml/definitions/reference.py | 84 ++++++++++----------- pydbml/definitions/table.py | 48 ++++++------ pydbml/definitions/table_group.py | 14 ++-- pydbml/parser/parser.py | 8 +- test/test_definitions/test_column.py | 84 ++++++++++----------- test/test_definitions/test_common.py | 32 ++++---- test/test_definitions/test_enum.py | 28 +++---- test/test_definitions/test_generic.py | 4 +- test/test_definitions/test_index.py | 74 +++++++++--------- test/test_definitions/test_project.py | 14 ++-- test/test_definitions/test_reference.py | 68 ++++++++--------- test/test_definitions/test_table.py | 42 +++++------ test/test_definitions/test_table_group.py | 8 +- test/test_editing.py | 2 +- 20 files changed, 350 insertions(+), 369 deletions(-) diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index d80bf98..7aaf4d5 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -18,54 +18,35 @@ from .reference import ref_inline -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') -type_args = ("(" + pp.originalTextFor(expression) + ")") +type_args = ("(" + pp.original_text_for(expression) + ")") # column type is parsed as a single string, it will be split by blueprint column_type = pp.Combine((name + '.' + name) | ((name) + type_args[0, 1])) - -# def parse_column_type(s, l, t) -> str: -# ''' -# int or "mytype" or varchar(255) or -# ''' -# result = {} -# if '.' in t['name']: -# result['schema'], result['name'] = t['name'].split('.') -# else: -# result['name'] = t['name'] - -# if 'args' in t: -# result['args'] = f'({t["args"]})' -# return result - - -# column_type.setParseAction(parse_column_type) - - default = pp.CaselessLiteral('default:').suppress() + _ - ( string_literal | expression_literal - | boolean_literal.setParseAction( - lambda s, l, t: { + | boolean_literal.set_parse_action( + lambda s, loc, tok: { 'true': True, 'false': False, 'NULL': None - }[t[0]] + }[tok[0]] ) - | number_literal.setParseAction( - lambda s, l, t: float(''.join(t[0])) if '.' in t[0] else int(t[0]) + | number_literal.set_parse_action( + lambda s, loc, tok: float(''.join(tok[0])) if '.' in tok[0] else int(tok[0]) ) ) column_setting = _ + ( - pp.CaselessLiteral("not null").setParseAction( - lambda s, l, t: True + pp.CaselessLiteral("not null").set_parse_action( + lambda s, loc, tok: True )('notnull') - | pp.CaselessLiteral("null").setParseAction( - lambda s, l, t: False + | pp.CaselessLiteral("null").set_parse_action( + lambda s, loc, tok: False )('notnull') | pp.CaselessLiteral("primary key")('pk') | pk('pk') @@ -78,31 +59,31 @@ column_settings = '[' - column_setting + ("," + column_setting)[...] + ']' + c -def parse_column_settings(s, l, t): +def parse_column_settings(s, loc, tok): ''' [ NOT NULL, increment, default: `now()`] ''' result = {} - if t.get('notnull'): + if tok.get('notnull'): result['not_null'] = True - if 'pk' in t: + if 'pk' in tok: result['pk'] = True - if 'unique' in t: + if 'unique' in tok: result['unique'] = True - if 'increment' in t: + if 'increment' in tok: result['autoinc'] = True - if 'note' in t: - result['note'] = t['note'] - if 'default' in t: - result['default'] = t['default'][0] - if 'ref' in t: - result['ref_blueprints'] = list(t['ref']) - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'note' in tok: + result['note'] = tok['note'] + if 'default' in tok: + result['default'] = tok['default'][0] + if 'ref' in tok: + result['ref_blueprints'] = list(tok['ref']) + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -column_settings.setParseAction(parse_column_settings) +column_settings.set_parse_action(parse_column_settings) constraint = pp.CaselessLiteral("unique") | pp.CaselessLiteral("pk") @@ -115,32 +96,32 @@ def parse_column_settings(s, l, t): ) + n -def parse_column(s, l, t): +def parse_column(s, loc, tok): ''' address varchar(255) [unique, not null, note: 'to include unit number'] ''' init_dict = { - 'name': t['name'], - 'type': t['type'], + 'name': tok['name'], + 'type': tok['type'], } # deprecated - for constraint in t.get('constraints', []): + for constraint in tok.get('constraints', []): if constraint == 'pk': init_dict['pk'] = True elif constraint == 'unique': init_dict['unique'] = True - if 'settings' in t: - init_dict.update(t['settings']) + if 'settings' in tok: + init_dict.update(tok['settings']) # comments after column definition have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok: + init_dict['comment'] = tok['comment'][0] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment return ColumnBlueprint(**init_dict) -table_column.setParseAction(parse_column) +table_column.set_parse_action(parse_column) diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index bdde161..58191e0 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -4,7 +4,7 @@ from .generic import string_literal -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') comment = ( pp.Suppress("//") + pp.SkipTo(pp.LineEnd()) @@ -26,10 +26,10 @@ # n = pp.Suppress('\n')[1, ...] note = pp.CaselessLiteral("note:") + _ - string_literal('text') -note.setParseAction(lambda s, l, t: NoteBlueprint(t['text'])) +note.set_parse_action(lambda s, loc, tok: NoteBlueprint(tok['text'])) note_object = pp.CaselessLiteral('note') + _ - '{' + _ - string_literal('text') + _ - '}' -note_object.setParseAction(lambda s, l, t: NoteBlueprint(t['text'])) +note_object.set_parse_action(lambda s, loc, tok: NoteBlueprint(tok['text'])) pk = pp.CaselessLiteral("pk") unique = pp.CaselessLiteral("unique") diff --git a/pydbml/definitions/enum.py b/pydbml/definitions/enum.py index 83a85db..f722f4e 100644 --- a/pydbml/definitions/enum.py +++ b/pydbml/definitions/enum.py @@ -11,46 +11,46 @@ from .common import note from .generic import name -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') enum_settings = '[' + _ - note('note') + _ - ']' + c -def parse_enum_settings(s, l, t): +def parse_enum_settings(s, loc, tok): ''' [note: "note content"] // comment ''' result = {} - if 'note' in t: - result['note'] = t['note'] - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'note' in tok: + result['note'] = tok['note'] + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -enum_settings.setParseAction(parse_enum_settings) +enum_settings.set_parse_action(parse_enum_settings) enum_item = _c + (name('name') + c + enum_settings('settings')[0, 1]) -def parse_enum_item(s, l, t): +def parse_enum_item(s, loc, tok): ''' student [note: "is stupid"] ''' - init_dict = {'name': t['name']} - if 'settings' in t: - init_dict.update(t['settings']) + init_dict = {'name': tok['name']} + if 'settings' in tok: + init_dict.update(tok['settings']) # comments after settings have priority - if 'comment' in t['settings']: - init_dict['comment'] = t['settings']['comment'] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok['settings']: + init_dict['comment'] = tok['settings']['comment'] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment return EnumItemBlueprint(**init_dict) -enum_item.setParseAction(parse_enum_item) +enum_item.set_parse_action(parse_enum_item) enum_body = enum_item[1, ...] @@ -65,7 +65,7 @@ def parse_enum_item(s, l, t): ) + end -def parse_enum(s, l, t): +def parse_enum(s, loc, tok): ''' enum members { janitor @@ -75,18 +75,18 @@ def parse_enum(s, l, t): } ''' init_dict = { - 'name': t['name'], - 'items': list(t['items']) + 'name': tok['name'], + 'items': list(tok['items']) } - if 'schema' in t: - init_dict['schema'] = t['schema'] + if 'schema' in tok: + init_dict['schema'] = tok['schema'] - if 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment return EnumBlueprint(**init_dict) -enum.setParseAction(parse_enum) +enum.set_parse_action(parse_enum) diff --git a/pydbml/definitions/generic.py b/pydbml/definitions/generic.py index ca05311..b29962f 100644 --- a/pydbml/definitions/generic.py +++ b/pydbml/definitions/generic.py @@ -2,7 +2,7 @@ from pydbml.parser.blueprints import ExpressionBlueprint -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') name = pp.Word(pp.alphanums + '_') | pp.QuotedString('"') @@ -17,7 +17,7 @@ pp.Suppress('`') + pp.CharsNotIn('`')[...] + pp.Suppress('`') -).setParseAction(lambda s, l, t: ExpressionBlueprint(t[0])) +).set_parse_action(lambda s, lok, tok: ExpressionBlueprint(tok[0])) boolean_literal = ( pp.CaselessLiteral('true') diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index 176d3ef..b014834 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -13,7 +13,7 @@ from .generic import name from .generic import string_literal -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') index_type = pp.CaselessLiteral("type:").suppress() + _ - ( pp.CaselessLiteral("btree")('type') | pp.CaselessLiteral("hash")('type') @@ -30,27 +30,27 @@ ) -def parse_index_settings(s, l, t): +def parse_index_settings(s, lok, tok): ''' [type: btree, name: 'name', unique, note: 'note'] ''' result = {} - if 'unique' in t: + if 'unique' in tok: result['unique'] = True - if 'name' in t: - result['name'] = t['name'] - if 'pk' in t: + if 'name' in tok: + result['name'] = tok['name'] + if 'pk' in tok: result['pk'] = True - if 'type' in t: - result['type'] = t['type'] - if 'note' in t: - result['note'] = t['note'] - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'type' in tok: + result['type'] = tok['type'] + if 'note' in tok: + result['note'] = tok['note'] + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -index_settings.setParseAction(parse_index_settings) +index_settings.set_parse_action(parse_index_settings) subject = name | expression_literal composite_index_syntax = ( @@ -73,7 +73,7 @@ def parse_index_settings(s, l, t): ) -def parse_index(s, l, t): +def parse_index(s, lok, tok): ''' (id, country) [pk] // composite primary key or @@ -85,22 +85,22 @@ def parse_index(s, l, t): ] ''' init_dict = {} - if isinstance(t['subject'], (str, ExpressionBlueprint)): - subjects = [t['subject']] + if isinstance(tok['subject'], (str, ExpressionBlueprint)): + subjects = [tok['subject']] else: - subjects = list(t['subject']) + subjects = list(tok['subject']) init_dict['subject_names'] = subjects - settings = t.get('settings', {}) + settings = tok.get('settings', {}) init_dict.update(settings) # comments after settings have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok: + init_dict['comment'] = tok['comment'][0] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment return IndexBlueprint(**init_dict) -index.setParseAction(parse_index) +index.set_parse_action(parse_index) diff --git a/pydbml/definitions/project.py b/pydbml/definitions/project.py index 11e8991..352ce36 100644 --- a/pydbml/definitions/project.py +++ b/pydbml/definitions/project.py @@ -10,7 +10,7 @@ from .generic import name from .generic import string_literal -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') project_field = pp.Group(name + _ + pp.Suppress(':') + _ - string_literal) @@ -27,16 +27,16 @@ ) + (n | pp.StringEnd()) -def parse_project(s, l, t): +def parse_project(s, loc, tok): ''' Project project_name { database_type: 'PostgreSQL' Note: 'Description of the project' } ''' - init_dict = {'name': t['name']} + init_dict = {'name': tok['name']} items = {} - for item in t.get('items', []): + for item in tok.get('items', []): if isinstance(item, NoteBlueprint): init_dict['note'] = item else: @@ -44,10 +44,10 @@ def parse_project(s, l, t): items[k] = v if items: init_dict['items'] = items - if 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment return ProjectBlueprint(**init_dict) -project.setParseAction(parse_project) +project.set_parse_action(parse_project) diff --git a/pydbml/definitions/reference.py b/pydbml/definitions/reference.py index a428b67..4c1ac8d 100644 --- a/pydbml/definitions/reference.py +++ b/pydbml/definitions/reference.py @@ -8,7 +8,7 @@ from .common import n from .generic import name -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') relation = pp.oneOf("> - <") @@ -23,24 +23,24 @@ ref_inline = pp.Literal("ref:") - relation('type') - col_name -def parse_inline_relation(s, l, t): +def parse_inline_relation(s, loc, tok): ''' ref: < table.column or ref: < schema1.table.column ''' result = { - 'type': t['type'], + 'type': tok['type'], 'inline': True, - 'table2': t['table'], - 'col2': t['field'] + 'table2': tok['table'], + 'col2': tok['field'] } - if 'schema' in t: - result['schema2'] = t['schema'] + if 'schema' in tok: + result['schema2'] = tok['schema'] return ReferenceBlueprint(**result) -ref_inline.setParseAction(parse_inline_relation) +ref_inline.set_parse_action(parse_inline_relation) on_option = ( pp.CaselessLiteral('no action') @@ -65,21 +65,21 @@ def parse_inline_relation(s, l, t): ) -def parse_ref_settings(s, l, t): +def parse_ref_settings(s, loc, tok): ''' [delete: cascade] ''' result = {} - if 'update' in t: - result['on_update'] = t['update'][0] - if 'delete' in t: - result['on_delete'] = t['delete'][0] - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'update' in tok: + result['on_update'] = tok['update'][0] + if 'delete' in tok: + result['on_delete'] = tok['delete'][0] + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -ref_settings.setParseAction(parse_ref_settings) +ref_settings.set_parse_action(parse_ref_settings) composite_name = ( '(' + pp.White()[...] @@ -105,7 +105,7 @@ def parse_ref_settings(s, l, t): ) -def parse_ref_cols(s, l, t): +def parse_ref_cols(s, loc, tok): ''' table1.col1 or @@ -114,15 +114,15 @@ def parse_ref_cols(s, l, t): schema1.table1.(col1, col2) ''' result = { - 'table': t['table'], - 'field': t['field'], + 'table': tok['table'], + 'field': tok['field'], } - if 'schema' in t: - result['schema'] = t['schema'] + if 'schema' in tok: + result['schema'] = tok['schema'] return result -ref_cols.setParseAction(parse_ref_cols) +ref_cols.set_parse_action(parse_ref_cols) ref_body = ( ref_cols('col1') @@ -152,7 +152,7 @@ def parse_ref_cols(s, l, t): ) -def parse_ref(s, l, t): +def parse_ref(s, loc, tok): ''' ref name: table1.col1 > table2.col2 or @@ -161,35 +161,35 @@ def parse_ref(s, l, t): } ''' init_dict = { - 'type': t['type'], + 'type': tok['type'], 'inline': False, - 'table1': t['col1']['table'], - 'col1': t['col1']['field'], - 'table2': t['col2']['table'], - 'col2': t['col2']['field'], + 'table1': tok['col1']['table'], + 'col1': tok['col1']['field'], + 'table2': tok['col2']['table'], + 'col2': tok['col2']['field'], } - if 'schema' in t['col1']: - init_dict['schema1'] = t['col1']['schema'] - if 'schema' in t['col2']: - init_dict['schema2'] = t['col2']['schema'] - if 'name' in t: - init_dict['name'] = t['name'] - if 'settings' in t: - init_dict.update(t['settings']) + if 'schema' in tok['col1']: + init_dict['schema1'] = tok['col1']['schema'] + if 'schema' in tok['col2']: + init_dict['schema2'] = tok['col2']['schema'] + if 'name' in tok: + init_dict['name'] = tok['name'] + if 'settings' in tok: + init_dict.update(tok['settings']) # comments after settings have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok: + init_dict['comment'] = tok['comment'][0] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment ref = ReferenceBlueprint(**init_dict) return ref -ref_short.setParseAction(parse_ref) -ref_long.setParseAction(parse_ref) +ref_short.set_parse_action(parse_ref) +ref_long.set_parse_action(parse_ref) ref = ref_short | ref_long + (n | pp.StringEnd()) diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index 5c21540..aad22f9 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -11,7 +11,7 @@ from .generic import name from .index import indexes -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') alias = pp.WordStart() + pp.Literal('as').suppress() - pp.WordEnd() - name @@ -26,19 +26,19 @@ table_settings = '[' + table_setting + (',' + table_setting)[...] + ']' -def parse_table_settings(s, l, t): +def parse_table_settings(s, loc, tok): ''' [headercolor: #cccccc, note: 'note'] ''' result = {} - if 'note' in t: - result['note'] = t['note'] - if 'header_color' in t: - result['header_color'] = t['header_color'] + if 'note' in tok: + result['note'] = tok['note'] + if 'header_color' in tok: + result['header_color'] = tok['header_color'] return result -table_settings.setParseAction(parse_table_settings) +table_settings.set_parse_action(parse_table_settings) note_element = note | note_object @@ -58,7 +58,7 @@ def parse_table_settings(s, l, t): ) + end -def parse_table(s, l, t): +def parse_table(s, loc, tok): ''' Table bookings as bb [headercolor: #cccccc] { id integer @@ -72,27 +72,27 @@ def parse_table(s, l, t): } ''' init_dict = { - 'name': t['name'], + 'name': tok['name'], } - if 'schema' in t: - init_dict['schema'] = t['schema'] - if 'settings' in t: - init_dict.update(t['settings']) - if 'alias' in t: - init_dict['alias'] = t['alias'][0] - if 'note' in t: + if 'schema' in tok: + init_dict['schema'] = tok['schema'] + if 'settings' in tok: + init_dict.update(tok['settings']) + if 'alias' in tok: + init_dict['alias'] = tok['alias'][0] + if 'note' in tok: # will override one from settings - init_dict['note'] = t['note'][0] - if 'indexes' in t: - init_dict['indexes'] = t['indexes'] - if 'columns' in t: - init_dict['columns'] = t['columns'] - if'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + init_dict['note'] = tok['note'][0] + if 'indexes' in tok: + init_dict['indexes'] = tok['indexes'] + if 'columns' in tok: + init_dict['columns'] = tok['columns'] + if'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment result = TableBlueprint(**init_dict) return result -table.setParseAction(parse_table) +table.set_parse_action(parse_table) diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index 422067b..8b92338 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -7,7 +7,7 @@ from .common import end from .generic import name -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') table_name = pp.Combine(name + '.' + name) | name @@ -20,7 +20,7 @@ ) + end -def parse_table_group(s, l, t): +def parse_table_group(s, loc, tok): ''' TableGroup tablegroup_name { table1 @@ -29,13 +29,13 @@ def parse_table_group(s, l, t): } ''' init_dict = { - 'name': t['name'], - 'items': list(t.get('items', [])) + 'name': tok['name'], + 'items': list(tok.get('items', [])) } - if 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment return TableGroupBlueprint(**init_dict) -table_group.setParseAction(parse_table_group) +table_group.set_parse_action(parse_table_group) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 26ab2d0..17d941c 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -26,7 +26,7 @@ from pydbml.tools import remove_bom -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') class PyDBML: @@ -105,7 +105,7 @@ def __init__(self, source: str): def parse(self): self._set_syntax() - self._syntax.parseString(self.source, parseAll=True) + self._syntax.parse_string(self.source, parseAll=True) self.build_database() return self.database @@ -139,8 +139,8 @@ def _set_syntax(self): ) self._syntax = expr[...] + ('\n' | comment)[...] + pp.StringEnd() - def parse_blueprint(self, s, l, t): - blueprint = t[0] + def parse_blueprint(self, s, loc, tok): + blueprint = tok[0] if isinstance(blueprint, TableBlueprint): self.tables.append(blueprint) ref_bps = blueprint.get_reference_blueprints() diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index 3ab6350..a81f16d 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -13,39 +13,39 @@ from pydbml.definitions.column import table_column -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestColumnType(TestCase): def test_simple(self) -> None: val = 'int' - res = column_type.parseString(val, parseAll=True) + res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_quoted(self) -> None: val = '"mytype"' - res = column_type.parseString(val, parseAll=True) + res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], 'mytype') def test_with_schema(self) -> None: val = 'myschema.mytype' - res = column_type.parseString(val, parseAll=True) + res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_expression(self) -> None: val = 'varchar(255)' - res = column_type.parseString(val, parseAll=True) + res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_symbols(self) -> None: val = '(*#^)' with self.assertRaises(ParseException): - column_type.parseString(val, parseAll=True) + column_type.parse_string(val, parseAll=True) def test_string(self) -> None: val = "'mytype'" with self.assertRaises(ParseException): - column_type.parseString(val, parseAll=True) + column_type.parse_string(val, parseAll=True) class TestDefault(TestCase): @@ -53,9 +53,9 @@ def test_string(self) -> None: val = "default: 'string'" val2 = "default: \n\n'string'" expected = 'string' - res = default.parseString(val, parseAll=True) + res = default.parse_string(val, parseAll=True) self.assertEqual(res[0], expected) - res = default.parseString(val2, parseAll=True) + res = default.parse_string(val2, parseAll=True) self.assertEqual(res[0], expected) def test_expression(self) -> None: @@ -64,13 +64,13 @@ def test_expression(self) -> None: val = f"default: `{expr1}`" val2 = f"default: `{expr2}`" val3 = f"default: ``" - res = default.parseString(val, parseAll=True) + res = default.parse_string(val, parseAll=True) self.assertIsInstance(res[0], ExpressionBlueprint) self.assertEqual(res[0].text, expr1) - res = default.parseString(val2, parseAll=True) + res = default.parse_string(val2, parseAll=True) self.assertIsInstance(res[0], ExpressionBlueprint) self.assertEqual(res[0].text, expr2) - res = default.parseString(val3, parseAll=True) + res = default.parse_string(val3, parseAll=True) self.assertIsInstance(res[0], ExpressionBlueprint) self.assertEqual(res[0].text, '') @@ -78,20 +78,20 @@ def test_bool(self) -> None: vals = ['true', 'false', 'null'] exps = [True, False, 'NULL'] while len(vals) > 0: - res = default.parseString(f'default: {vals.pop()}', parseAll=True) + res = default.parse_string(f'default: {vals.pop()}', parseAll=True) self.assertEqual(exps.pop(), res[0]) def test_numbers(self) -> None: vals = [0, 17, 13.3, 2.0] while len(vals) > 0: cur = vals.pop() - res = default.parseString(f'default: {cur}', parseAll=True) + res = default.parse_string(f'default: {cur}', parseAll=True) self.assertEqual((cur), res[0]) def test_wrong(self) -> None: val = "default: now" with self.assertRaises(ParseSyntaxException): - default.parseString(val, parseAll=True) + default.parse_string(val, parseAll=True) class TestColumnSetting(TestCase): @@ -104,7 +104,7 @@ def test_pass(self) -> None: 'default: 123', 'ref: > table.column'] for val in vals: - column_setting.parseString(val, parseAll=True) + column_setting.parse_string(val, parseAll=True) def test_fail(self) -> None: vals = ['wrong', @@ -112,75 +112,75 @@ def test_fail(self) -> None: '"pk"'] for val in vals: with self.assertRaises(ParseException): - column_setting.parseString(val, parseAll=True) + column_setting.parse_string(val, parseAll=True) class TestColumnSettings(TestCase): def test_nulls(self) -> None: - res = column_settings.parseString('[NULL]', parseAll=True) + res = column_settings.parse_string('[NULL]', parseAll=True) self.assertNotIn('not_null', res[0]) - res = column_settings.parseString('[NOT NULL]', parseAll=True) + res = column_settings.parse_string('[NOT NULL]', parseAll=True) self.assertTrue(res[0]['not_null']) - res = column_settings.parseString('[NULL, NOT NULL]', parseAll=True) + res = column_settings.parse_string('[NULL, NOT NULL]', parseAll=True) self.assertTrue(res[0]['not_null']) - res = column_settings.parseString('[NOT NULL, NULL]', parseAll=True) + res = column_settings.parse_string('[NOT NULL, NULL]', parseAll=True) self.assertNotIn('not_null', res[0]) def test_pk(self) -> None: - res = column_settings.parseString('[pk]', parseAll=True) + res = column_settings.parse_string('[pk]', parseAll=True) self.assertTrue(res[0]['pk']) - res = column_settings.parseString('[primary key]', parseAll=True) + res = column_settings.parse_string('[primary key]', parseAll=True) self.assertTrue(res[0]['pk']) - res = column_settings.parseString('[primary key, pk]', parseAll=True) + res = column_settings.parse_string('[primary key, pk]', parseAll=True) self.assertTrue(res[0]['pk']) def test_unique_increment(self) -> None: - res = column_settings.parseString('[unique, increment]', parseAll=True) + res = column_settings.parse_string('[unique, increment]', parseAll=True) self.assertTrue(res[0]['unique']) self.assertTrue(res[0]['autoinc']) def test_refs(self) -> None: - res = column_settings.parseString('[ref: > table.column]', parseAll=True) + res = column_settings.parse_string('[ref: > table.column]', parseAll=True) self.assertEqual(len(res[0]['ref_blueprints']), 1) - res = column_settings.parseString('[ref: - table.column, ref: < table2.column2]', parseAll=True) + res = column_settings.parse_string('[ref: - table.column, ref: < table2.column2]', parseAll=True) self.assertEqual(len(res[0]['ref_blueprints']), 2) def test_note_default(self) -> None: - res = column_settings.parseString('[default: 123, note: "mynote"]', parseAll=True) + res = column_settings.parse_string('[default: 123, note: "mynote"]', parseAll=True) self.assertIn('note', res[0]) self.assertEqual(res[0]['default'], 123) def test_wrong(self) -> None: val = "[wrong]" with self.assertRaises(ParseSyntaxException): - column_settings.parseString(val, parseAll=True) + column_settings.parse_string(val, parseAll=True) class TestConstraint(TestCase): def test_should_parse(self) -> None: - constraint.parseString('unique', parseAll=True) - constraint.parseString('pk', parseAll=True) + constraint.parse_string('unique', parseAll=True) + constraint.parse_string('pk', parseAll=True) def test_should_fail(self) -> None: with self.assertRaises(ParseException): - constraint.parseString('wrong', parseAll=True) + constraint.parse_string('wrong', parseAll=True) class TestColumn(TestCase): def test_no_settings(self) -> None: val = 'address varchar(255)\n' - res = table_column.parseString(val, parseAll=True) + res = table_column.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'address') self.assertEqual(res[0].type, 'varchar(255)') def test_with_constraint(self) -> None: val = 'user_id integer unique\n' - res = table_column.parseString(val, parseAll=True) + res = table_column.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'user_id') self.assertEqual(res[0].type, 'integer') self.assertTrue(res[0].unique) val2 = 'user_id integer pk unique\n' - res2 = table_column.parseString(val2, parseAll=True) + res2 = table_column.parse_string(val2, parseAll=True) self.assertEqual(res2[0].name, 'user_id') self.assertEqual(res2[0].type, 'integer') self.assertTrue(res2[0].unique) @@ -188,7 +188,7 @@ def test_with_constraint(self) -> None: def test_with_settings(self) -> None: val = "_test_ \"mytype\" [unique, not null, note: 'to include unit number']\n" - res = table_column.parseString(val, parseAll=True) + res = table_column.parse_string(val, parseAll=True) self.assertEqual(res[0].name, '_test_') self.assertEqual(res[0].type, 'mytype') self.assertTrue(res[0].unique) @@ -198,11 +198,11 @@ def test_with_settings(self) -> None: def test_enum_type_bad(self) -> None: val = "_test_ myschema.mytype(12) [unique]\n" with self.assertRaises(ParseException): - table_column.parseString(val, parseAll=True) + table_column.parse_string(val, parseAll=True) def test_settings_and_constraints(self) -> None: val = "_test_ \"mytype\" unique pk [not null]\n" - res = table_column.parseString(val, parseAll=True) + res = table_column.parse_string(val, parseAll=True) self.assertEqual(res[0].name, '_test_') self.assertEqual(res[0].type, 'mytype') self.assertTrue(res[0].unique) @@ -211,26 +211,26 @@ def test_settings_and_constraints(self) -> None: def test_comment_above(self) -> None: val = '//comment above\naddress varchar\n' - res = table_column.parseString(val, parseAll=True) + res = table_column.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'address') self.assertEqual(res[0].type, 'varchar') self.assertEqual(res[0].comment, 'comment above') def test_comment_after(self) -> None: val = 'address varchar //comment after\n' - res = table_column.parseString(val, parseAll=True) + res = table_column.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'address') self.assertEqual(res[0].type, 'varchar') self.assertEqual(res[0].comment, 'comment after') val2 = 'user_id integer pk unique //comment after\n' - res2 = table_column.parseString(val2, parseAll=True) + res2 = table_column.parse_string(val2, parseAll=True) self.assertEqual(res2[0].name, 'user_id') self.assertEqual(res2[0].type, 'integer') self.assertTrue(res2[0].unique) self.assertTrue(res2[0].pk) self.assertEqual(res2[0].comment, 'comment after') val3 = "_test_ \"mytype\" unique pk [not null] //comment after\n" - res3 = table_column.parseString(val3, parseAll=True) + res3 = table_column.parse_string(val3, parseAll=True) self.assertEqual(res3[0].name, '_test_') self.assertEqual(res3[0].type, 'mytype') self.assertTrue(res3[0].unique) diff --git a/test/test_definitions/test_common.py b/test/test_definitions/test_common.py index fa3b27c..1107fb0 100644 --- a/test/test_definitions/test_common.py +++ b/test/test_definitions/test_common.py @@ -9,86 +9,86 @@ from pydbml.definitions.common import note_object -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestComment(TestCase): def test_comment_endstring(self) -> None: val = '//test comment' - res = comment.parseString(val, parseAll=True) + res = comment.parse_string(val, parseAll=True) self.assertEqual(res[0], 'test comment') def test_comment_endline(self) -> None: val = '//test comment\n\n\n\n\n' - res = comment.parseString(val) + res = comment.parse_string(val) self.assertEqual(res[0], 'test comment') def test_multiline_comment(self) -> None: val = '/*test comment*/' - res = comment.parseString(val) + res = comment.parse_string(val) self.assertEqual(res[0], 'test comment') val2 = '/*\nline1\nline2\nline3\n*/' - res2 = comment.parseString(val2) + res2 = comment.parse_string(val2) self.assertEqual(res2[0], '\nline1\nline2\nline3\n') class Test_c(TestCase): def test_comment(self) -> None: val = '\n\n\n\n//comment line 1\n\n//comment line 2' - res = _c.parseString(val, parseAll=True) + res = _c.parse_string(val, parseAll=True) self.assertEqual(list(res), ['comment line 1', 'comment line 2']) class TestNote(TestCase): def test_single_quote(self) -> None: val = "note: 'test note'" - res = note.parseString(val, parseAll=True) + res = note.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_double_quote(self) -> None: val = 'note: \n "test note"' - res = note.parseString(val, parseAll=True) + res = note.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_multiline(self) -> None: val = "note: '''line1\nline2\nline3'''" - res = note.parseString(val, parseAll=True) + res = note.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'line1\nline2\nline3') def test_unclosed_quote(self) -> None: val = 'note: "test note' with self.assertRaises(ParseSyntaxException): - note.parseString(val, parseAll=True) + note.parse_string(val, parseAll=True) def test_not_allowed_multiline(self) -> None: val = "note: 'line1\nline2\nline3'" with self.assertRaises(ParseSyntaxException): - note.parseString(val, parseAll=True) + note.parse_string(val, parseAll=True) class TestNoteObject(TestCase): def test_single_quote(self) -> None: val = "note {'test note'}" - res = note_object.parseString(val, parseAll=True) + res = note_object.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_double_quote(self) -> None: val = 'note \n\n {\n\n"test note"\n\n}' - res = note_object.parseString(val, parseAll=True) + res = note_object.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_multiline(self) -> None: val = "note\n{ '''line1\nline2\nline3'''}" - res = note_object.parseString(val, parseAll=True) + res = note_object.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'line1\nline2\nline3') def test_unclosed_quote(self) -> None: val = 'note{ "test note}' with self.assertRaises(ParseSyntaxException): - note_object.parseString(val, parseAll=True) + note_object.parse_string(val, parseAll=True) def test_not_allowed_multiline(self) -> None: val = "note { 'line1\nline2\nline3' }" with self.assertRaises(ParseSyntaxException): - note_object.parseString(val, parseAll=True) + note_object.parse_string(val, parseAll=True) diff --git a/test/test_definitions/test_enum.py b/test/test_definitions/test_enum.py index f345929..0d33d07 100644 --- a/test/test_definitions/test_enum.py +++ b/test/test_definitions/test_enum.py @@ -8,49 +8,49 @@ from pydbml.definitions.enum import enum_settings -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestEnumSettings(TestCase): def test_note(self) -> None: val = '[note: "note content"]' - enum_settings.parseString(val, parseAll=True) + enum_settings.parse_string(val, parseAll=True) def test_wrong(self) -> None: val = '[wrong]' with self.assertRaises(ParseSyntaxException): - enum_settings.parseString(val, parseAll=True) + enum_settings.parse_string(val, parseAll=True) class TestEnumItem(TestCase): def test_no_settings(self) -> None: val = 'student' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') def test_settings(self) -> None: val = 'student [note: "our future, help us God"]' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') def test_comment_before(self) -> None: val = '//comment before\nstudent [note: "our future, help us God"]' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') self.assertEqual(res[0].comment, 'comment before') def test_comment_after(self) -> None: val = 'student [note: "our future, help us God"] //comment after' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') self.assertEqual(res[0].comment, 'comment after') def test_comment_both(self) -> None: val = '//comment before\nstudent [note: "our future, help us God"] //comment after' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') self.assertEqual(res[0].comment, 'comment after') @@ -59,27 +59,27 @@ def test_comment_both(self) -> None: class TestEnum(TestCase): def test_singe_item(self) -> None: val = 'enum members {\nstudent\n}' - res = enum.parseString(val, parseAll=True) + res = enum.parse_string(val, parseAll=True) self.assertEqual(len(res[0].items), 1) self.assertEqual(res[0].name, 'members') def test_several_items(self) -> None: val = 'enum members {janitor teacher\nstudent\nheadmaster\n}' - res = enum.parseString(val, parseAll=True) + res = enum.parse_string(val, parseAll=True) self.assertEqual(len(res[0].items), 4) self.assertEqual(res[0].name, 'members') def test_schema(self) -> None: val1 = 'enum members {janitor teacher\nstudent\nheadmaster\n}' - res1 = enum.parseString(val1, parseAll=True) + res1 = enum.parse_string(val1, parseAll=True) self.assertEqual(res1[0].schema, 'public') val2 = 'enum myschema.members {janitor teacher\nstudent\nheadmaster\n}' - res2 = enum.parseString(val2, parseAll=True) + res2 = enum.parse_string(val2, parseAll=True) self.assertEqual(res2[0].schema, 'myschema') def test_comment(self) -> None: val = '//comment before\nenum members {janitor teacher\nstudent\nheadmaster\n}' - res = enum.parseString(val, parseAll=True) + res = enum.parse_string(val, parseAll=True) self.assertEqual(len(res[0].items), 4) self.assertEqual(res[0].name, 'members') self.assertEqual(res[0].comment, 'comment before') @@ -87,4 +87,4 @@ def test_comment(self) -> None: def test_oneline(self) -> None: val = 'enum members {student}' with self.assertRaises(ParseSyntaxException): - enum.parseString(val, parseAll=True) + enum.parse_string(val, parseAll=True) diff --git a/test/test_definitions/test_generic.py b/test/test_definitions/test_generic.py index 8570515..25afe6e 100644 --- a/test/test_definitions/test_generic.py +++ b/test/test_definitions/test_generic.py @@ -6,12 +6,12 @@ from pydbml.parser.blueprints import ExpressionBlueprint -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestExpressionLiteral(TestCase): def test_expression_literal(self) -> None: val = '`SUM(amount)`' - res = expression_literal.parseString(val) + res = expression_literal.parse_string(val) self.assertIsInstance(res[0], ExpressionBlueprint) self.assertEqual(res[0].text, 'SUM(amount)') diff --git a/test/test_definitions/test_index.py b/test/test_definitions/test_index.py index 318d192..00a7ee8 100644 --- a/test/test_definitions/test_index.py +++ b/test/test_definitions/test_index.py @@ -15,82 +15,82 @@ from pydbml.parser.blueprints import ExpressionBlueprint -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestIndexType(TestCase): def test_correct(self) -> None: val = 'Type: BTREE' - res = index_type.parseString(val, parseAll=True) + res = index_type.parse_string(val, parseAll=True) self.assertEqual(res['type'], 'btree') val2 = 'type:\nhash' - res2 = index_type.parseString(val2, parseAll=True) + res2 = index_type.parse_string(val2, parseAll=True) self.assertEqual(res2['type'], 'hash') def test_incorrect(self) -> None: val = 'type: wrong' with self.assertRaises(ParseSyntaxException): - index_type.parseString(val, parseAll=True) + index_type.parse_string(val, parseAll=True) class TestIndexSetting(TestCase): def test_unique(self) -> None: val = 'unique' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['unique'], 'unique') def test_type(self) -> None: val = 'type: btree' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['type'], 'btree') def test_name(self) -> None: val = 'name: "index name"' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['name'], 'index name') def test_wrong_name(self) -> None: val = 'name: index name' with self.assertRaises(ParseSyntaxException): - index_setting.parseString(val, parseAll=True) + index_setting.parse_string(val, parseAll=True) val2 = 'name:,' with self.assertRaises(ParseSyntaxException): - index_setting.parseString(val2, parseAll=True) + index_setting.parse_string(val2, parseAll=True) def test_note(self) -> None: val = 'note: "note text"' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['note'].text, 'note text') class TestIndexSettings(TestCase): def test_unique(self) -> None: val = '[unique]' - res = index_settings.parseString(val, parseAll=True) + res = index_settings.parse_string(val, parseAll=True) self.assertTrue(res[0]['unique']) def test_name_type_multiline(self) -> None: val = '[\nname: "index name"\n,\ntype:\nbtree\n]' - res = index_settings.parseString(val, parseAll=True) + res = index_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['type'], 'btree') self.assertEqual(res[0]['name'], 'index name') def test_pk(self) -> None: val = '[\npk\n]' - res = index_settings.parseString(val, parseAll=True) + res = index_settings.parse_string(val, parseAll=True) self.assertTrue(res[0]['pk']) def test_wrong_pk(self) -> None: val = '[pk, name: "not allowed"]' with self.assertRaises(ParseSyntaxException): - index_settings.parseString(val, parseAll=True) + index_settings.parse_string(val, parseAll=True) val2 = '[note: "pk not allowed", pk]' with self.assertRaises(ParseSyntaxException): - index_settings.parseString(val2, parseAll=True) + index_settings.parse_string(val2, parseAll=True) def test_all(self) -> None: val = '[type: hash, name: "index name", note: "index note", unique]' - res = index_settings.parseString(val, parseAll=True) + res = index_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['type'], 'hash') self.assertEqual(res[0]['name'], 'index name') self.assertEqual(res[0]['note'].text, 'index note') @@ -100,50 +100,50 @@ def test_all(self) -> None: class TestSubject(TestCase): def test_name(self) -> None: val = 'my_column' - res = subject.parseString(val, parseAll=True) + res = subject.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_expression(self) -> None: val = '`id*3`' - res = subject.parseString(val, parseAll=True) + res = subject.parse_string(val, parseAll=True) self.assertIsInstance(res[0], ExpressionBlueprint) self.assertEqual(res[0].text, 'id*3') def test_wrong(self) -> None: val = '12d*(' with self.assertRaises(ParseException): - subject.parseString(val, parseAll=True) + subject.parse_string(val, parseAll=True) class TestSingleIndex(TestCase): def test_no_settings(self) -> None: val = 'my_column' - res = single_index_syntax.parseString(val, parseAll=True) + res = single_index_syntax.parse_string(val, parseAll=True) self.assertEqual(res['subject'], val) def test_settings(self) -> None: val = 'my_column [unique]' - res = single_index_syntax.parseString(val, parseAll=True) + res = single_index_syntax.parse_string(val, parseAll=True) self.assertEqual(res['subject'], 'my_column') self.assertTrue(res['settings']['unique']) def test_settings_on_new_line(self) -> None: val = 'my_column\n[unique]' with self.assertRaises(ParseException): - single_index_syntax.parseString(val, parseAll=True) + single_index_syntax.parse_string(val, parseAll=True) class TestCompositeIndex(TestCase): def test_no_settings(self) -> None: val = '(my_column, my_another_column)' - res = composite_index_syntax.parseString(val, parseAll=True) + res = composite_index_syntax.parse_string(val, parseAll=True) self.assertIn('my_column', list(res['subject'])) self.assertIn('my_another_column', list(res['subject'])) self.assertEqual(len(res['subject']), 2) def test_settings(self) -> None: val = '(my_column, my_another_column) [unique]' - res = composite_index_syntax.parseString(val, parseAll=True) + res = composite_index_syntax.parse_string(val, parseAll=True) self.assertIn('my_column', list(res['subject'])) self.assertIn('my_another_column', list(res['subject'])) self.assertEqual(len(res['subject']), 2) @@ -152,63 +152,63 @@ def test_settings(self) -> None: def test_new_line(self) -> None: val = '(my_column,\nmy_another_column) [unique]' with self.assertRaises(ParseException): - composite_index_syntax.parseString(val, parseAll=True) + composite_index_syntax.parse_string(val, parseAll=True) val2 = '(my_column, my_another_column)\n[unique]' with self.assertRaises(ParseException): - composite_index_syntax.parseString(val2, parseAll=True) + composite_index_syntax.parse_string(val2, parseAll=True) class TestIndex(TestCase): def test_single(self) -> None: val = 'my_column' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) def test_expression(self) -> None: val = '(`id*3`)' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertIsInstance(res[0].subject_names[0], ExpressionBlueprint) self.assertEqual(res[0].subject_names[0].text, 'id*3') def test_composite(self) -> None: val = '(my_column, my_another_column)' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) def test_composite_with_expression(self) -> None: val = '(`id*3`, fieldname)' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertIsInstance(res[0].subject_names[0], ExpressionBlueprint) self.assertEqual(res[0].subject_names[0].text, 'id*3') self.assertEqual(res[0].subject_names[1], 'fieldname') def test_with_settings(self) -> None: val = '(my_column, my_another_column) [unique]' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) self.assertTrue(res[0].unique) def test_comment_above(self) -> None: val = '//comment above\nmy_column [unique]' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment above') def test_comment_after(self) -> None: val = 'my_column [unique] //comment after' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') val = 'my_column //comment after' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertEqual(res[0].comment, 'comment after') def test_both_comments(self) -> None: val = '//comment before\nmy_column [unique] //comment after' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') @@ -229,10 +229,10 @@ def test_valid(self) -> None: (`id*3`,`getdate()`) (`id*3`,id) }''' - res = indexes.parseString(val) + res = indexes.parse_string(val) self.assertEqual(len(res), 8) def test_invalid(self) -> None: val = 'indexes {my_column' with self.assertRaises(ParseSyntaxException): - indexes.parseString(val) + indexes.parse_string(val) diff --git a/test/test_definitions/test_project.py b/test/test_definitions/test_project.py index 650965a..c049a0a 100644 --- a/test/test_definitions/test_project.py +++ b/test/test_definitions/test_project.py @@ -7,36 +7,36 @@ from pydbml.definitions.project import project_field -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestProjectField(TestCase): def test_ok(self) -> None: val = "field: 'value'" - project_field.parseString(val, parseAll=True) + project_field.parse_string(val, parseAll=True) def test_nok(self) -> None: val = "field: value" with self.assertRaises(ParseSyntaxException): - project_field.parseString(val, parseAll=True) + project_field.parse_string(val, parseAll=True) class TestProject(TestCase): def test_empty(self) -> None: val = 'project name {}' - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') def test_fields(self) -> None: val = "project name {field1: 'value1' field2: 'value2'}" - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items['field1'], 'value1') self.assertEqual(res[0].items['field2'], 'value2') def test_fields_and_note(self) -> None: val = "project name {\nfield1: 'value1'\nfield2: 'value2'\nnote: 'note value'}" - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items['field1'], 'value1') self.assertEqual(res[0].items['field2'], 'value2') @@ -44,7 +44,7 @@ def test_fields_and_note(self) -> None: def test_comment(self) -> None: val = "//comment before\nproject name {\nfield1: 'value1'\nfield2: 'value2'\nnote: 'note value'}" - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items['field1'], 'value1') self.assertEqual(res[0].items['field2'], 'value2') diff --git a/test/test_definitions/test_reference.py b/test/test_definitions/test_reference.py index 2d1d82d..cbcc96c 100644 --- a/test/test_definitions/test_reference.py +++ b/test/test_definitions/test_reference.py @@ -12,25 +12,25 @@ from pydbml.definitions.reference import relation -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestRelation(TestCase): def test_ok(self) -> None: vals = ['>', '-', '<'] for v in vals: - relation.parseString(v, parseAll=True) + relation.parse_string(v, parseAll=True) def test_nok(self) -> None: val = 'wrong' with self.assertRaises(ParseException): - relation.parseString(val, parseAll=True) + relation.parse_string(val, parseAll=True) class TestInlineRelation(TestCase): def test_ok(self) -> None: val = 'ref: < table.column' - res = ref_inline.parseString(val, parseAll=True) + res = ref_inline.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '<') self.assertEqual(res[0].table2, 'table') self.assertEqual(res[0].col2, 'column') @@ -39,10 +39,10 @@ def test_ok(self) -> None: def test_schema(self) -> None: val1 = 'ref: < table.column' - res1 = ref_inline.parseString(val1, parseAll=True) + res1 = ref_inline.parse_string(val1, parseAll=True) self.assertEqual(res1[0].schema2, 'public') val2 = 'ref: < myschema.table.column' - res2 = ref_inline.parseString(val2, parseAll=True) + res2 = ref_inline.parse_string(val2, parseAll=True) self.assertEqual(res2[0].schema2, 'myschema') def test_nok(self) -> None: @@ -54,7 +54,7 @@ def test_nok(self) -> None: ] for v in vals: with self.assertRaises(ParseSyntaxException): - ref_inline.parseString(v) + ref_inline.parse_string(v) class TestOnOption(TestCase): @@ -67,23 +67,23 @@ def test_ok(self) -> None: 'set default' ] for v in vals: - on_option.parseString(v, parseAll=True) + on_option.parse_string(v, parseAll=True) def test_nok(self) -> None: val = 'wrong' with self.assertRaises(ParseException): - on_option.parseString(val, parseAll=True) + on_option.parse_string(val, parseAll=True) class TestRefSettings(TestCase): def test_one_setting(self) -> None: val = '[delete: cascade]' - res = ref_settings.parseString(val, parseAll=True) + res = ref_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['on_delete'], 'cascade') def test_two_settings_multiline(self) -> None: val = '[\ndelete:\ncascade\n,\nupdate:\nrestrict\n]' - res = ref_settings.parseString(val, parseAll=True) + res = ref_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['on_delete'], 'cascade') self.assertEqual(res[0]['on_update'], 'restrict') @@ -91,7 +91,7 @@ def test_two_settings_multiline(self) -> None: class TestRefShort(TestCase): def test_no_name(self) -> None: val = 'ref: table1.col1 > table2.col2' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -100,17 +100,17 @@ def test_no_name(self) -> None: def test_schema(self) -> None: val1 = 'ref: table1.col1 > table2.col2' - res1 = ref_short.parseString(val1, parseAll=True) + res1 = ref_short.parse_string(val1, parseAll=True) self.assertEqual(res1[0].schema1, 'public') self.assertEqual(res1[0].schema2, 'public') val2 = 'ref: myschema1.table1.col1 > myschema2.table2.col2' - res2 = ref_short.parseString(val2, parseAll=True) + res2 = ref_short.parse_string(val2, parseAll=True) self.assertEqual(res2[0].schema1, 'myschema1') self.assertEqual(res2[0].schema2, 'myschema2') def test_name(self) -> None: val = 'ref name: table1.col1 > table2.col2' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -120,7 +120,7 @@ def test_name(self) -> None: def test_composite_with_name(self) -> None: val = 'ref name: table1.(col1 , col2,col3) > table2.(col11 , col21,col31)' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, '(col1 , col2,col3)') @@ -130,7 +130,7 @@ def test_composite_with_name(self) -> None: def test_with_settings(self) -> None: val = 'ref name: table1.col1 > table2.col2 [update: cascade, delete: restrict]' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -143,14 +143,14 @@ def test_with_settings(self) -> None: def test_newline(self) -> None: val = 'ref\nname: table1.col1 > table2.col2' with self.assertRaises(ParseException): - ref_short.parseString(val, parseAll=True) + ref_short.parse_string(val, parseAll=True) val2 = 'ref name: table1.col1\n> table2.col2' with self.assertRaises(ParseSyntaxException): - ref_short.parseString(val2, parseAll=True) + ref_short.parse_string(val2, parseAll=True) def test_comment_above(self) -> None: val = '//comment above\nref name: table1.col1 > table2.col2' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -161,7 +161,7 @@ def test_comment_above(self) -> None: def test_comment_after(self) -> None: val = 'ref name: table1.col1 > table2.col2 //comment after' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -170,7 +170,7 @@ def test_comment_after(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = 'ref name: table1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after' - res2 = ref_short.parseString(val2, parseAll=True) + res2 = ref_short.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') @@ -183,7 +183,7 @@ def test_comment_after(self) -> None: def test_comment_both(self) -> None: val = '//comment above\nref name: table1.col1 > table2.col2 //comment after' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -192,7 +192,7 @@ def test_comment_both(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = '//comment above\nref name: table1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after' - res2 = ref_short.parseString(val2, parseAll=True) + res2 = ref_short.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') @@ -207,7 +207,7 @@ def test_comment_both(self) -> None: class TestRefLong(TestCase): def test_no_name(self) -> None: val = 'ref {table1.col1 > table2.col2}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -216,17 +216,17 @@ def test_no_name(self) -> None: def test_schema(self) -> None: val1 = 'ref {table1.col1 > table2.col2}' - res1 = ref_long.parseString(val1, parseAll=True) + res1 = ref_long.parse_string(val1, parseAll=True) self.assertEqual(res1[0].schema1, 'public') self.assertEqual(res1[0].schema2, 'public') val2 = 'ref {myschema1.table1.col1 > myschema2.table2.col2}' - res2 = ref_long.parseString(val2, parseAll=True) + res2 = ref_long.parse_string(val2, parseAll=True) self.assertEqual(res2[0].schema1, 'myschema1') self.assertEqual(res2[0].schema2, 'myschema2') def test_name(self) -> None: val = 'ref\nname\n{\ntable1.col1 > table2.col2\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -236,7 +236,7 @@ def test_name(self) -> None: def test_with_settings(self) -> None: val = 'ref name {\ntable1.col1 > table2.col2 [update: cascade, delete: restrict]\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -248,7 +248,7 @@ def test_with_settings(self) -> None: def test_comment_above(self) -> None: val = '//comment above\nref name {\ntable1.col1 > table2.col2\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -259,7 +259,7 @@ def test_comment_above(self) -> None: def test_comment_after(self) -> None: val = 'ref name {\ntable1.col1 > table2.col2 //comment after\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -268,7 +268,7 @@ def test_comment_after(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = 'ref name {\ntable1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after\n}' - res2 = ref_long.parseString(val2, parseAll=True) + res2 = ref_long.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') @@ -281,7 +281,7 @@ def test_comment_after(self) -> None: def test_comment_both(self) -> None: val = '//comment above\nref name {\ntable1.col1 > table2.col2 //comment after\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -290,7 +290,7 @@ def test_comment_both(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = '//comment above\nref name {\ntable1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after\n}' - res2 = ref_long.parseString(val2, parseAll=True) + res2 = ref_long.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') diff --git a/test/test_definitions/test_table.py b/test/test_definitions/test_table.py index 8028292..9bcfcfa 100644 --- a/test/test_definitions/test_table.py +++ b/test/test_definitions/test_table.py @@ -11,41 +11,41 @@ from pydbml.definitions.table import table_settings -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestAlias(TestCase): def test_ok(self) -> None: val = 'as Alias' - alias.parseString(val, parseAll=True) + alias.parse_string(val, parseAll=True) def test_nok(self) -> None: val = 'asalias' with self.assertRaises(ParseSyntaxException): - alias.parseString(val, parseAll=True) + alias.parse_string(val, parseAll=True) class TestHeaderColor(TestCase): def test_oneline(self) -> None: val = 'headercolor: #CCCCCC' - res = header_color.parseString(val, parseAll=True) + res = header_color.parse_string(val, parseAll=True) self.assertEqual(res['header_color'], '#CCCCCC') def test_multiline(self) -> None: val = 'headercolor:\n\n#E02' - res = header_color.parseString(val, parseAll=True) + res = header_color.parse_string(val, parseAll=True) self.assertEqual(res['header_color'], '#E02') class TestTableSettings(TestCase): def test_one(self) -> None: val = '[headercolor: #E024DF]' - res = table_settings.parseString(val, parseAll=True) + res = table_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['header_color'], '#E024DF') def test_both(self) -> None: val = '[note: "note content", headercolor: #E024DF]' - res = table_settings.parseString(val, parseAll=True) + res = table_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['header_color'], '#E024DF') self.assertIn('note', res[0]) @@ -53,12 +53,12 @@ def test_both(self) -> None: class TestTableBody(TestCase): def test_one_column(self) -> None: val = 'id integer [pk, increment]\n' - res = table_body.parseString(val, parseAll=True) + res = table_body.parse_string(val, parseAll=True) self.assertEqual(len(res['columns']), 1) def test_two_columns(self) -> None: val = 'id integer [pk, increment]\nname string\n' - res = table_body.parseString(val, parseAll=True) + res = table_body.parse_string(val, parseAll=True) self.assertEqual(len(res['columns']), 2) def test_columns_indexes(self) -> None: @@ -69,7 +69,7 @@ def test_columns_indexes(self) -> None: indexes { (id, country) [pk] // composite primary key }''' - res = table_body.parseString(val, parseAll=True) + res = table_body.parse_string(val, parseAll=True) self.assertEqual(len(res['columns']), 3) self.assertEqual(len(res['indexes']), 1) @@ -82,7 +82,7 @@ def test_columns_indexes_note(self) -> None: indexes { (id, country) [pk] // composite primary key }''' - res = table_body.parseString(val, parseAll=True) + res = table_body.parse_string(val, parseAll=True) self.assertEqual(len(res['columns']), 3) self.assertEqual(len(res['indexes']), 1) self.assertIsNotNone(res['note']) @@ -96,7 +96,7 @@ def test_columns_indexes_note(self) -> None: indexes { (id, country) [pk] // composite primary key }''' - res2 = table_body.parseString(val2, parseAll=True) + res2 = table_body.parse_string(val2, parseAll=True) self.assertEqual(len(res2['columns']), 3) self.assertEqual(len(res2['indexes']), 1) self.assertIsNotNone(res2['note']) @@ -108,7 +108,7 @@ def test_no_columns(self) -> None: (id, country) [pk] // composite primary key }''' with self.assertRaises(ParseException): - table_body.parseString(val, parseAll=True) + table_body.parse_string(val, parseAll=True) def test_columns_after_indexes(self) -> None: val = ''' @@ -118,38 +118,38 @@ def test_columns_after_indexes(self) -> None: } id integer''' with self.assertRaises(ParseException): - table_body.parseString(val, parseAll=True) + table_body.parse_string(val, parseAll=True) class TestTable(TestCase): def test_simple(self) -> None: val = 'table ids {\nid integer\n}' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(len(res[0].columns), 1) def test_with_alias(self) -> None: val = 'table ids as ii {\nid integer\n}' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(res[0].alias, 'ii') self.assertEqual(len(res[0].columns), 1) def test_schema(self) -> None: val = 'table ids as ii {\nid integer\n}' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(res[0].schema, 'public') # default self.assertEqual(len(res[0].columns), 1) val = 'table myschema.ids as ii {\nid integer\n}' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(res[0].schema, 'myschema') def test_with_settings(self) -> None: val = 'table ids as ii [headercolor: #ccc, note: "headernote"] {\nid integer\n}' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(res[0].alias, 'ii') self.assertEqual(res[0].header_color, '#ccc') @@ -165,7 +165,7 @@ def test_with_body_note(self) -> None: id integer note: "bodynote" }''' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(res[0].alias, 'ii') self.assertEqual(res[0].header_color, '#ccc') @@ -185,7 +185,7 @@ def test_with_indexes(self) -> None: (id, country) [pk] // composite primary key } }''' - res = table.parseString(val, parseAll=True) + res = table.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'ids') self.assertEqual(res[0].alias, 'ii') self.assertEqual(res[0].header_color, '#ccc') diff --git a/test/test_definitions/test_table_group.py b/test/test_definitions/test_table_group.py index 7e81251..d8e2a7a 100644 --- a/test/test_definitions/test_table_group.py +++ b/test/test_definitions/test_table_group.py @@ -5,24 +5,24 @@ from pydbml.definitions.table_group import table_group -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestProject(TestCase): def test_empty(self) -> None: val = 'TableGroup name {}' - res = table_group.parseString(val, parseAll=True) + res = table_group.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') def test_fields(self) -> None: val = "TableGroup name {table1 table2}" - res = table_group.parseString(val, parseAll=True) + res = table_group.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items, ['table1', 'table2']) def test_comment(self) -> None: val = "//comment before\nTableGroup name\n{\ntable1\ntable2\n}" - res = table_group.parseString(val, parseAll=True) + res = table_group.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items, ['table1', 'table2']) self.assertEqual(res[0].comment, 'comment before') diff --git a/test/test_editing.py b/test/test_editing.py index 68120b7..da6d518 100644 --- a/test/test_editing.py +++ b/test/test_editing.py @@ -8,7 +8,7 @@ from pydbml import PyDBML -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' From aa8b9de52bc92c1861611f6f28bbda32f115b4e5 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 19 May 2022 08:59:26 +0200 Subject: [PATCH 032/125] type_ -> type --- TODO.md | 4 +-- pydbml/classes/column.py | 4 +-- pydbml/classes/index.py | 4 +-- pydbml/classes/reference.py | 4 +-- pydbml/parser/blueprints.py | 6 ++-- test/test_blueprints/test_reference.py | 4 +-- test/test_classes/test_column.py | 44 +++++++++++++------------- test/test_classes/test_index.py | 4 +-- 8 files changed, 37 insertions(+), 37 deletions(-) diff --git a/TODO.md b/TODO.md index f2805f9..ba88ef7 100644 --- a/TODO.md +++ b/TODO.md @@ -1,7 +1,7 @@ * - Creating dbml schema in python -- pyparsing new var names (+possibly new features) +* - pyparsing new var names (+possibly new features) * - enum type -- `_type` -> `type` +* - `_type` -> `type` * - expression class - schema.add and .delete to support multiple arguments (handle errors properly) * - 2.3.1 Multiline comment /* ... */ diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 55024b2..cab29cf 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -23,7 +23,7 @@ class Column(SQLOjbect): def __init__(self, name: str, - type_: str, + type: str, unique: bool = False, not_null: bool = False, pk: bool = False, @@ -33,7 +33,7 @@ def __init__(self, # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, comment: Optional[str] = None): self.name = name - self.type = type_ + self.type = type self.unique = unique self.not_null = not_null self.pk = pk diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index 354c3a1..2c311db 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -23,7 +23,7 @@ def __init__(self, subjects: List[Union[str, 'Column', 'Expression']], name: Optional[str] = None, unique: bool = False, - type_: Optional[str] = None, + type: Optional[str] = None, pk: bool = False, note: Optional[Union['Note', str]] = None, comment: Optional[str] = None): @@ -33,7 +33,7 @@ def __init__(self, self.name = name if name else None self.unique = unique - self.type = type_ + self.type = type self.pk = pk self.note = Note(note) self.comment = comment diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index ac8dd98..c2e9989 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -22,7 +22,7 @@ class Reference(SQLOjbect): required_attributes = ('type', 'col1', 'col2') def __init__(self, - type_: Literal['>', '<', '-'], + type: Literal['>', '<', '-'], col1: Union[Column, Collection[Column]], col2: Union[Column, Collection[Column]], name: Optional[str] = None, @@ -31,7 +31,7 @@ def __init__(self, on_delete: Optional[str] = None, inline: bool = False): self.database = None - self.type = type_ + self.type = type self.col1 = [col1] if isinstance(col1, Column) else list(col1) self.col2 = [col2] if isinstance(col2, Column) else list(col2) self.name = name if name else None diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index da87685..df9c7d5 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -82,7 +82,7 @@ def build(self) -> 'Reference': col2 = [table2[col] for col in col2_list] return Reference( - type_=self.type, + type=self.type, inline=self.inline, col1=col1, col2=col2, @@ -120,7 +120,7 @@ def build(self) -> 'Column': break return Column( name=self.name, - type_=self.type, + type=self.type, unique=self.unique, not_null=self.not_null, pk=self.pk, @@ -149,7 +149,7 @@ def build(self) -> 'Index': subjects=[], name=self.name, unique=self.unique, - type_=self.type, + type=self.type, pk=self.pk, note=self.note.build() if self.note else None, comment=self.comment diff --git a/test/test_blueprints/test_reference.py b/test/test_blueprints/test_reference.py index 7bf3b10..dd9bfd7 100644 --- a/test/test_blueprints/test_reference.py +++ b/test/test_blueprints/test_reference.py @@ -23,12 +23,12 @@ def test_build_minimal(self) -> None: t1 = Table( name='table1' ) - c1 = Column(name='col1', type_='Number') + c1 = Column(name='col1', type='Number') t1.add_column(c1) t2 = Table( name='table2' ) - c2 = Column(name='col2', type_='Varchar') + c2 = Column(name='col2', type='Varchar') t2.add_column(c2) with self.assertRaises(RuntimeError): diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index c95da1c..874593e 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -12,7 +12,7 @@ class TestColumn(TestCase): def test_attributes(self) -> None: name = 'name' - type_ = 'type' + type = 'type' unique = True not_null = True pk = True @@ -22,7 +22,7 @@ def test_attributes(self) -> None: comment = 'comment' col = Column( name=name, - type_=type_, + type=type, unique=unique, not_null=not_null, pk=pk, @@ -32,7 +32,7 @@ def test_attributes(self) -> None: comment=comment, ) self.assertEqual(col.name, name) - self.assertEqual(col.type, type_) + self.assertEqual(col.type, type) self.assertEqual(col.unique, unique) self.assertEqual(col.not_null, not_null) self.assertEqual(col.pk, pk) @@ -53,13 +53,13 @@ def test_database_set(self) -> None: def test_basic_sql(self) -> None: r = Column(name='id', - type_='integer') + type='integer') expected = '"id" integer' self.assertEqual(r.sql, expected) def test_pk_autoinc(self) -> None: r = Column(name='id', - type_='integer', + type='integer', pk=True, autoinc=True) expected = '"id" integer PRIMARY KEY AUTOINCREMENT' @@ -67,7 +67,7 @@ def test_pk_autoinc(self) -> None: def test_unique_not_null(self) -> None: r = Column(name='id', - type_='integer', + type='integer', unique=True, not_null=True) expected = '"id" integer UNIQUE NOT NULL' @@ -75,14 +75,14 @@ def test_unique_not_null(self) -> None: def test_default(self) -> None: r = Column(name='order', - type_='integer', + type='integer', default=0) expected = '"order" integer DEFAULT 0' self.assertEqual(r.sql, expected) def test_comment(self) -> None: r = Column(name='id', - type_='integer', + type='integer', unique=True, not_null=True, comment="Column comment") @@ -94,7 +94,7 @@ def test_comment(self) -> None: def test_dbml_simple(self): c = Column( name='order', - type_='integer' + type='integer' ) t = Table(name='Test') t.add_column(c) @@ -107,7 +107,7 @@ def test_dbml_simple(self): def test_dbml_full(self): c = Column( name='order', - type_='integer', + type='integer', unique=True, not_null=True, pk=True, @@ -129,7 +129,7 @@ def test_dbml_full(self): def test_dbml_multiline_note(self): c = Column( name='order', - type_='integer', + type='integer', not_null=True, note='Note on the column\nmultiline', comment='Comment on the column' @@ -148,7 +148,7 @@ def test_dbml_multiline_note(self): def test_dbml_default(self): c = Column( name='order', - type_='integer', + type='integer', default='String value' ) t = Table(name='Test') @@ -184,7 +184,7 @@ def test_dbml_default(self): self.assertEqual(c.dbml, expected) def test_database(self): - c1 = Column(name='client_id', type_='integer') + c1 = Column(name='client_id', type='integer') t1 = Table(name='products') self.assertIsNone(c1.database) @@ -195,16 +195,16 @@ def test_database(self): self.assertIs(c1.database, s) def test_get_refs(self) -> None: - c1 = Column(name='client_id', type_='integer') + c1 = Column(name='client_id', type='integer') with self.assertRaises(TableNotFoundError): c1.get_refs() t1 = Table(name='products') t1.add_column(c1) - c2 = Column(name='id', type_='integer', autoinc=True, pk=True) + c2 = Column(name='id', type='integer', autoinc=True, pk=True) t2 = Table(name='clients') t2.add_column(c2) - ref = Reference(type_='>', col1=c1, col2=c2, inline=True) + ref = Reference(type='>', col1=c1, col2=c2, inline=True) s = Database() s.add(t1) s.add(t2) @@ -213,14 +213,14 @@ def test_get_refs(self) -> None: self.assertEqual(c1.get_refs(), [ref]) def test_dbml_with_ref(self) -> None: - c1 = Column(name='client_id', type_='integer') + c1 = Column(name='client_id', type='integer') t1 = Table(name='products') t1.add_column(c1) - c2 = Column(name='id', type_='integer', autoinc=True, pk=True) + c2 = Column(name='id', type='integer', autoinc=True, pk=True) t2 = Table(name='clients') t2.add_column(c2) - ref = Reference(type_='>', col1=c1, col2=c2) + ref = Reference(type='>', col1=c1, col2=c2) s = Database() s.add(t1) s.add(t2) @@ -235,14 +235,14 @@ def test_dbml_with_ref(self) -> None: self.assertEqual(c2.dbml, expected) def test_dbml_with_ref_and_properties(self) -> None: - c1 = Column(name='client_id', type_='integer') + c1 = Column(name='client_id', type='integer') t1 = Table(name='products') t1.add_column(c1) - c2 = Column(name='id', type_='integer', autoinc=True, pk=True) + c2 = Column(name='id', type='integer', autoinc=True, pk=True) t2 = Table(name='clients') t2.add_column(c2) - ref = Reference(type_='<', col1=c2, col2=c1) + ref = Reference(type='<', col1=c2, col2=c1) s = Database() s.add(t1) s.add(t2) diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py index 324d993..b10ca02 100644 --- a/test/test_classes/test_index.py +++ b/test/test_classes/test_index.py @@ -47,7 +47,7 @@ def test_unique_type_composite(self) -> None: t.columns[0], t.columns[1] ], - type_='hash', + type='hash', unique=True ) t.add_index(r) @@ -103,7 +103,7 @@ def test_dbml_full(self): subjects=[t.columns[0], Expression('getdate()')], name='Dated id', unique=True, - type_='hash', + type='hash', pk=True, note='Note on the column', comment='Comment on the index' From bde47c362f285eb19a0f9b93f6bdbc1fada8404d Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 19 May 2022 09:30:45 +0200 Subject: [PATCH 033/125] validate add_index, cleanup imports --- TODO.md | 5 ++++- changelog.md | 2 ++ pydbml/__init__.py | 9 --------- pydbml/classes/base.py | 3 ++- pydbml/classes/column.py | 8 ++++---- pydbml/classes/index.py | 8 ++++++-- pydbml/classes/reference.py | 4 ++-- pydbml/classes/table.py | 4 +++- pydbml/classes/table_group.py | 1 - pydbml/definitions/column.py | 3 +-- pydbml/definitions/common.py | 3 +-- pydbml/definitions/enum.py | 5 ++--- pydbml/definitions/index.py | 5 ++--- pydbml/definitions/project.py | 5 ++--- pydbml/definitions/reference.py | 3 +-- pydbml/definitions/table.py | 3 +-- pydbml/definitions/table_group.py | 3 +-- pydbml/parser/blueprints.py | 4 ++-- pydbml/parser/parser.py | 8 +++----- pydbml/tools.py | 1 + test/test_blueprints/test_column.py | 2 +- test/test_blueprints/test_index.py | 2 +- test/test_blueprints/test_table.py | 14 +++++++------- test/test_classes/test_column.py | 6 +++--- test/test_classes/test_enum.py | 2 +- test/test_classes/test_expression.py | 3 ++- test/test_classes/test_index.py | 19 +++++++++++++++++-- test/test_classes/test_reference.py | 3 ++- test/test_classes/test_table.py | 2 +- test/test_database.py | 2 +- test/test_definitions/test_column.py | 2 +- test/test_doctest.py | 1 - test/test_integration.py | 9 ++++----- test/test_parser.py | 2 +- test/test_tools.py | 2 +- 35 files changed, 83 insertions(+), 75 deletions(-) diff --git a/TODO.md b/TODO.md index ba88ef7..4066deb 100644 --- a/TODO.md +++ b/TODO.md @@ -7,4 +7,7 @@ * - 2.3.1 Multiline comment /* ... */ * - 2.4 Multiple Schemas - validation on "add_index", "add_table" etc -* - enum type in table definition with schema \ No newline at end of file +* - enum type in table definition with schema +- add coverage badge +- add docstrings +- new docs diff --git a/changelog.md b/changelog.md index 7721a20..0e1e2ca 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,8 @@ - col1 col2 in ref are as they were in dbml - Expression class - add multiline comment +- support multiple schemas +- type_ -> type # 0.4.2 diff --git a/pydbml/__init__.py b/pydbml/__init__.py index 558c211..5da1d6e 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,14 +1,5 @@ -# import doctest -# import unittest - from . import classes from .parser import PyDBML - from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_MANY from pydbml.constants import ONE_TO_ONE - - -# def load_tests(loader, tests, ignore): -# tests.addTests(doctest.DocTestSuite(classes)) -# return tests diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py index 419adca..61c55bc 100644 --- a/pydbml/classes/base.py +++ b/pydbml/classes/base.py @@ -1,5 +1,6 @@ -from typing import Tuple from typing import Any +from typing import Tuple + from pydbml.exceptions import AttributeMissingError diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index cab29cf..d220ae2 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -1,15 +1,15 @@ -from typing import Optional -from typing import Union from typing import List +from typing import Optional from typing import TYPE_CHECKING +from typing import Union from .base import SQLOjbect -from .note import Note from .expression import Expression +from .note import Note +from pydbml.exceptions import TableNotFoundError from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.tools import note_option_to_dbml -from pydbml.exceptions import TableNotFoundError if TYPE_CHECKING: # pragma: no cover from .table import Table diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index 2c311db..595285b 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -52,7 +52,9 @@ def __repr__(self): >>> i >>> from .table import Table - >>> Table('test').add_index(i) + >>> t = Table('test') + >>> t.add_column(c) + >>> t.add_index(i) >>> i ''' @@ -67,7 +69,9 @@ def __str__(self): >>> print(i) Index([col, (c*2)]) >>> from .table import Table - >>> Table('test').add_index(i) + >>> t = Table('test') + >>> t.add_column(c) + >>> t.add_index(i) >>> print(i) Index(test[col, (c*2)]) ''' diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index c2e9989..2cdf3d6 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -7,10 +7,10 @@ from .column import Column from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_ONE -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql from pydbml.exceptions import DBMLError from pydbml.exceptions import TableNotFoundError +from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_sql class Reference(SQLOjbect): diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index f79b72a..820200a 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -73,7 +73,9 @@ def add_index(self, i: Index) -> None: Adds index to self.indexes attribute and sets in this index the `table` attribute. ''' - + for subject in i.subjects: + if isinstance(subject, Column) and subject.table != self: + raise ColumnNotFoundError(f'Column {subject} not in the table') i.table = self self.indexes.append(i) diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index 06fe718..3386fa0 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -1,6 +1,5 @@ from typing import List from typing import Optional -from typing import Union from .table import Table from pydbml.tools import comment_to_dbml diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index 7aaf4d5..a9fe0c1 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -1,7 +1,5 @@ import pyparsing as pp -from pydbml.parser.blueprints import ColumnBlueprint - from .common import _ from .common import _c from .common import c @@ -16,6 +14,7 @@ from .generic import number_literal from .generic import string_literal from .reference import ref_inline +from pydbml.parser.blueprints import ColumnBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index 58191e0..99d3601 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -1,8 +1,7 @@ import pyparsing as pp -from pydbml.parser.blueprints import NoteBlueprint - from .generic import string_literal +from pydbml.parser.blueprints import NoteBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/enum.py b/pydbml/definitions/enum.py index f722f4e..0c6d6fa 100644 --- a/pydbml/definitions/enum.py +++ b/pydbml/definitions/enum.py @@ -1,8 +1,5 @@ import pyparsing as pp -from pydbml.parser.blueprints import EnumBlueprint -from pydbml.parser.blueprints import EnumItemBlueprint - from .common import _ from .common import _c from .common import c @@ -10,6 +7,8 @@ from .common import n from .common import note from .generic import name +from pydbml.parser.blueprints import EnumBlueprint +from pydbml.parser.blueprints import EnumItemBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index b014834..f08090b 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -1,8 +1,5 @@ import pyparsing as pp -from pydbml.parser.blueprints import IndexBlueprint -from pydbml.parser.blueprints import ExpressionBlueprint - from .common import _ from .common import _c from .common import c @@ -12,6 +9,8 @@ from .generic import expression_literal from .generic import name from .generic import string_literal +from pydbml.parser.blueprints import ExpressionBlueprint +from pydbml.parser.blueprints import IndexBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/project.py b/pydbml/definitions/project.py index 352ce36..16bd2e0 100644 --- a/pydbml/definitions/project.py +++ b/pydbml/definitions/project.py @@ -1,14 +1,13 @@ import pyparsing as pp -from pydbml.parser.blueprints import NoteBlueprint -from pydbml.parser.blueprints import ProjectBlueprint - from .common import _ from .common import _c from .common import n from .common import note from .generic import name from .generic import string_literal +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ProjectBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/reference.py b/pydbml/definitions/reference.py index 4c1ac8d..c8eebf8 100644 --- a/pydbml/definitions/reference.py +++ b/pydbml/definitions/reference.py @@ -1,12 +1,11 @@ import pyparsing as pp -from pydbml.parser.blueprints import ReferenceBlueprint - from .common import _ from .common import _c from .common import c from .common import n from .generic import name +from pydbml.parser.blueprints import ReferenceBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index aad22f9..1b4a58e 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -1,7 +1,5 @@ import pyparsing as pp -from pydbml.parser.blueprints import TableBlueprint - from .column import table_column from .common import _ from .common import _c @@ -10,6 +8,7 @@ from .common import note_object from .generic import name from .index import indexes +from pydbml.parser.blueprints import TableBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index 8b92338..cf56938 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -1,11 +1,10 @@ import pyparsing as pp -from pydbml.parser.blueprints import TableGroupBlueprint - from .common import _ from .common import _c from .common import end from .generic import name +from pydbml.parser.blueprints import TableGroupBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index df9c7d5..773d3e3 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -1,10 +1,10 @@ from dataclasses import dataclass +from typing import Any from typing import Collection from typing import Dict from typing import List from typing import Literal from typing import Optional -from typing import Any from typing import Union from pydbml.classes import Column @@ -160,7 +160,7 @@ def build(self) -> 'Index': class TableBlueprint(Blueprint): name: str schema: str = 'public' - columns: List[ColumnBlueprint] = None + columns: Optional[List[ColumnBlueprint]] = None indexes: Optional[List[IndexBlueprint]] = None alias: Optional[str] = None note: Optional[NoteBlueprint] = None diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 17d941c..705b9dc 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -1,20 +1,19 @@ from __future__ import annotations - -import pyparsing as pp - from io import TextIOWrapper from pathlib import Path - from typing import List from typing import Optional from typing import Union +import pyparsing as pp + from .blueprints import EnumBlueprint from .blueprints import ProjectBlueprint from .blueprints import ReferenceBlueprint from .blueprints import TableBlueprint from .blueprints import TableGroupBlueprint from pydbml.classes import Table +from pydbml.database import Database from pydbml.definitions.common import comment from pydbml.definitions.enum import enum from pydbml.definitions.project import project @@ -22,7 +21,6 @@ from pydbml.definitions.table import table from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError -from pydbml.database import Database from pydbml.tools import remove_bom diff --git a/pydbml/tools.py b/pydbml/tools.py index 48df6da..63f2b81 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING + if TYPE_CHECKING: # pragma: no cover from .classes import Note diff --git a/test/test_blueprints/test_column.py b/test/test_blueprints/test_column.py index 622bbd5..ab9a7cf 100644 --- a/test/test_blueprints/test_column.py +++ b/test/test_blueprints/test_column.py @@ -5,9 +5,9 @@ from pydbml.classes import Enum from pydbml.classes import EnumItem from pydbml.classes import Note +from pydbml.database import Database from pydbml.parser.blueprints import ColumnBlueprint from pydbml.parser.blueprints import NoteBlueprint -from pydbml.database import Database class TestColumn(TestCase): diff --git a/test/test_blueprints/test_index.py b/test/test_blueprints/test_index.py index 6c09e26..dc17440 100644 --- a/test/test_blueprints/test_index.py +++ b/test/test_blueprints/test_index.py @@ -1,7 +1,7 @@ from unittest import TestCase -from pydbml.classes import Note from pydbml.classes import Index +from pydbml.classes import Note from pydbml.parser.blueprints import IndexBlueprint from pydbml.parser.blueprints import NoteBlueprint diff --git a/test/test_blueprints/test_table.py b/test/test_blueprints/test_table.py index 4a566a9..b12ce42 100644 --- a/test/test_blueprints/test_table.py +++ b/test/test_blueprints/test_table.py @@ -1,17 +1,17 @@ from unittest import TestCase -from pydbml.exceptions import ColumnNotFoundError -from pydbml.classes import Note -from pydbml.classes import Table -from pydbml.classes import Index from pydbml.classes import Column from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Table +from pydbml.exceptions import ColumnNotFoundError +from pydbml.parser.blueprints import ColumnBlueprint +from pydbml.parser.blueprints import ExpressionBlueprint from pydbml.parser.blueprints import IndexBlueprint from pydbml.parser.blueprints import NoteBlueprint -from pydbml.parser.blueprints import ColumnBlueprint -from pydbml.parser.blueprints import TableBlueprint from pydbml.parser.blueprints import ReferenceBlueprint -from pydbml.parser.blueprints import ExpressionBlueprint +from pydbml.parser.blueprints import TableBlueprint class TestTable(TestCase): diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index 874593e..8802917 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -1,11 +1,11 @@ from unittest import TestCase -from pydbml.database import Database from pydbml.classes import Column from pydbml.classes import Expression -from pydbml.classes import Table -from pydbml.classes import Reference from pydbml.classes import Note +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.database import Database from pydbml.exceptions import TableNotFoundError diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py index 66c8361..b28e84d 100644 --- a/test/test_classes/test_enum.py +++ b/test/test_classes/test_enum.py @@ -1,5 +1,5 @@ -from pydbml.classes import EnumItem from pydbml.classes import Enum +from pydbml.classes import EnumItem from unittest import TestCase diff --git a/test/test_classes/test_expression.py b/test/test_classes/test_expression.py index 5990f3e..e61fdf3 100644 --- a/test/test_classes/test_expression.py +++ b/test/test_classes/test_expression.py @@ -1,6 +1,7 @@ -from pydbml.classes import Expression from unittest import TestCase +from pydbml.classes import Expression + class TestNote(TestCase): def test_sql(self): diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py index b10ca02..9c77571 100644 --- a/test/test_classes/test_index.py +++ b/test/test_classes/test_index.py @@ -1,9 +1,10 @@ from unittest import TestCase -from pydbml.classes import Index -from pydbml.classes import Table from pydbml.classes import Column from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Table +from pydbml.exceptions import ColumnNotFoundError class TestIndex(TestCase): @@ -25,6 +26,20 @@ def test_basic_sql_str(self) -> None: expected = 'CREATE INDEX ON "products" (id);' self.assertEqual(r.sql, expected) + def test_column_not_in_table(self) -> None: + t = Table('products') + c = Column('id', 'integer') + i = Index(subjects=[c]) + with self.assertRaises(ColumnNotFoundError): + t.add_index(i) + self.assertIsNone(i.table) + t2 = Table('customers') + t2.add_column(c) + i2 = Index(subjects=[c]) + with self.assertRaises(ColumnNotFoundError): + t.add_index(i2) + self.assertIsNone(i2.table) + def test_comment(self) -> None: t = Table('products') t.add_column(Column('id', 'integer')) diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 4193899..7cd25d3 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -1,7 +1,8 @@ from unittest import TestCase + from pydbml.classes import Column -from pydbml.classes import Table from pydbml.classes import Reference +from pydbml.classes import Table from pydbml.exceptions import DBMLError from pydbml.exceptions import TableNotFoundError diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 1f77a97..69e4152 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -6,10 +6,10 @@ from pydbml.classes import Note from pydbml.classes import Reference from pydbml.classes import Table +from pydbml.database import Database from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import IndexNotFoundError from pydbml.exceptions import UnknownDatabaseError -from pydbml.database import Database class TestTable(TestCase): diff --git a/test/test_database.py b/test/test_database.py index 86c6dc9..552512f 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -10,8 +10,8 @@ from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup -from pydbml.exceptions import DatabaseValidationError from pydbml.database import Database +from pydbml.exceptions import DatabaseValidationError TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index a81f16d..262699b 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -4,13 +4,13 @@ from pyparsing import ParseSyntaxException from pyparsing import ParserElement -from pydbml.parser.blueprints import ExpressionBlueprint from pydbml.definitions.column import column_setting from pydbml.definitions.column import column_settings from pydbml.definitions.column import column_type from pydbml.definitions.column import constraint from pydbml.definitions.column import default from pydbml.definitions.column import table_column +from pydbml.parser.blueprints import ExpressionBlueprint ParserElement.set_default_whitespace_chars(' \t\r') diff --git a/test/test_doctest.py b/test/test_doctest.py index 48688f9..877cca4 100644 --- a/test/test_doctest.py +++ b/test/test_doctest.py @@ -1,5 +1,4 @@ import doctest -import unittest from pydbml import database from pydbml.classes import column diff --git a/test/test_integration.py b/test/test_integration.py index 7e45f66..7b134bb 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -2,20 +2,19 @@ from pathlib import Path from unittest import TestCase -from unittest.mock import patch, Mock +from pydbml import PyDBML from pydbml.classes import Column from pydbml.classes import Enum from pydbml.classes import EnumItem +from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Note from pydbml.classes import Project from pydbml.classes import Reference -from pydbml.classes import Index -from pydbml.classes import Expression from pydbml.classes import Table from pydbml.classes import TableGroup -from pydbml.classes import Note from pydbml.database import Database -from pydbml import PyDBML TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' diff --git a/test/test_parser.py b/test/test_parser.py index f045b3f..9b53057 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -4,9 +4,9 @@ from unittest import TestCase from pydbml import PyDBML -from pydbml.parser.parser import PyDBMLParser from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError +from pydbml.parser.parser import PyDBMLParser TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' diff --git a/test/test_tools.py b/test/test_tools.py index 6ba4abc..db29063 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -3,8 +3,8 @@ from pydbml.classes import Note from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql -from pydbml.tools import note_option_to_dbml from pydbml.tools import indent +from pydbml.tools import note_option_to_dbml class TestCommentToDBML(TestCase): From 1d3c9f73dff4604e733aba78a21e36ce06691ec9 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Fri, 20 May 2022 09:03:32 +0200 Subject: [PATCH 034/125] fix enum type in column, adding enum items by string --- pydbml/classes/column.py | 16 ++++++++++++--- pydbml/classes/enum.py | 14 ++++++++++--- test/test_classes/test_column.py | 34 ++++++++++++++++++++++++++++++++ test/test_editing.py | 4 ++-- 4 files changed, 60 insertions(+), 8 deletions(-) diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index d220ae2..68ac695 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -5,6 +5,7 @@ from .base import SQLOjbect from .expression import Expression +from .enum import Enum from .note import Note from pydbml.exceptions import TableNotFoundError from pydbml.tools import comment_to_dbml @@ -23,7 +24,7 @@ class Column(SQLOjbect): def __init__(self, name: str, - type: str, + type: Union[str, Enum], unique: bool = False, not_null: bool = False, pk: bool = False, @@ -65,7 +66,12 @@ def sql(self): ''' self.check_attributes_for_sql() - components = [f'"{self.name}"', str(self.type)] + components = [f'"{self.name}"'] + if isinstance(self.type, Enum): + components.append(self.type._get_full_name_for_sql()) + else: + components.append(str(self.type)) + if self.pk: components.append('PRIMARY KEY') if self.autoinc: @@ -97,7 +103,11 @@ def default_to_str(val: Union[Expression, str]) -> str: return val result = comment_to_dbml(self.comment) if self.comment else '' - result += f'"{self.name}" {self.type}' + result += f'"{self.name}" ' + if isinstance(self.type, Enum): + result += self.type._get_full_name_for_sql() + else: + result += self.type options = [ref.dbml for ref in self.get_refs() if ref.inline] if self.pk: diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index 6c3ab5f..fdbb4d9 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Iterable from typing import Optional from typing import Union @@ -57,14 +57,22 @@ class Enum(SQLOjbect): def __init__(self, name: str, - items: List['EnumItem'], + items: Iterable[Union['EnumItem', str]], schema: str = 'public', comment: Optional[str] = None): self.database = None self.name = name self.schema = schema - self.items = items self.comment = comment + self.items = [] + for item in items: + self.add_item(item) + + def add_item(self, item: Union['EnumItem', str]) -> None: + if isinstance(item, EnumItem): + self.items.append(item) + elif isinstance(item, str): + self.items.append(EnumItem(item)) def __getitem__(self, key: int) -> EnumItem: return self.items[key] diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index 8802917..cfd6b61 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -5,6 +5,7 @@ from pydbml.classes import Note from pydbml.classes import Reference from pydbml.classes import Table +from pydbml.classes import Enum from pydbml.database import Database from pydbml.exceptions import TableNotFoundError @@ -57,6 +58,21 @@ def test_basic_sql(self) -> None: expected = '"id" integer' self.assertEqual(r.sql, expected) + def test_sql_enum_type(self) -> None: + et = Enum('product status', ('production', 'development')) + db = Database() + db.add_enum(et) + r = Column(name='id', + type=et, + pk=True, + autoinc=True) + expected = '"id" "product status" PRIMARY KEY AUTOINCREMENT' + self.assertEqual(r.sql, expected) + + et.schema = 'myschema' + expected = '"id" "myschema"."product status" PRIMARY KEY AUTOINCREMENT' + self.assertEqual(r.sql, expected) + def test_pk_autoinc(self) -> None: r = Column(name='id', type='integer', @@ -104,6 +120,24 @@ def test_dbml_simple(self): self.assertEqual(c.dbml, expected) + def test_dbml_enum_type(self) -> None: + et = Enum('product status', ('production', 'development')) + db = Database() + db.add_enum(et) + r = Column(name='id', + type=et, + pk=True, + autoinc=True) + t = Table('products') + t.add_column(r) + db.add_table(t) + expected = '"id" "product status" [pk, increment]' + self.assertEqual(r.dbml, expected) + + et.schema = 'myschema' + expected = '"id" "myschema"."product status" [pk, increment]' + self.assertEqual(r.dbml, expected) + def test_dbml_full(self): c = Column( name='order', diff --git a/test/test_editing.py b/test/test_editing.py index da6d518..05b9bc9 100644 --- a/test/test_editing.py +++ b/test/test_editing.py @@ -86,5 +86,5 @@ def test_enum_name(self): self.assertIn('Enum "changed product status"', enum.dbml) col = products['status'] - self.assertEqual(col.sql, '"status" changed product status') - self.assertEqual(col.dbml, '"status" changed product status') + self.assertEqual(col.sql, '"status" "changed product status"') + self.assertEqual(col.dbml, '"status" "changed product status"') From 6132f5aab3d6d060d7beb623a34f17652db99ba5 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Fri, 20 May 2022 10:01:45 +0200 Subject: [PATCH 035/125] fix minor bugs, start docs --- README.md | 251 ++++++-------------------------------- TODO.md | 3 +- docs/classes.md | 259 ++++++++++++++++++++++++++++++++++++++++ pydbml/classes/enum.py | 3 +- pydbml/classes/index.py | 10 +- pydbml/classes/table.py | 2 +- 6 files changed, 306 insertions(+), 222 deletions(-) create mode 100644 docs/classes.md diff --git a/README.md b/README.md index 2c745e5..d3c5dd4 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,13 @@ # DBML parser for Python -*Compliant with DBML **v2.3.0** syntax* +*Compliant with DBML **v2.4.1** syntax* -PyDBML is a Python parser for [DBML](https://www.dbml.org) syntax. +PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. + +**Docs:** + +* [Class Reference](docs/classes.md) ## Installation @@ -16,7 +20,7 @@ pip3 install pydbml ## Quick start -Import the `PyDBML` class and initialize it with path to DBML-file: +To parse a DBML file, import the `PyDBML` class and initialize it with Path object ```python >>> from pydbml import PyDBML @@ -25,21 +29,27 @@ Import the `PyDBML` class and initialize it with path to DBML-file: ``` -or with file stream: +or with file stream + ```python >>> with open('test_schema.dbml') as f: ... parsed = PyDBML(f) ``` -or with entire source string: +or with entire source string + ```python >>> with open('test_schema.dbml') as f: ... source = f.read() >>> parsed = PyDBML(source) +>>> parsed + ``` +Parser returns a Database object that is a container for the parsed DBML entities. + You can access tables inside the `tables` attribute: ```python @@ -55,24 +65,24 @@ countries ``` -Or just by getting items by index or table name: +Or just by getting items by index or full table name: ```python ->>> parsed['countries'] -
>>> parsed[1] -
+
+>>> parsed['public.countries'] +
``` -Other meaningful attributes are: +Other attributes are: * **refs** — list of all references, * **enums** — list of all enums, * **table_groups** — list of all table groups, * **project** — the Project object, if was defined. -You can get the SQL for your DBML schema by accessing `sql` property: +Generate SQL for your DBML Database by accessing the `sql` property: ```python >>> print(parsed.sql) # doctest:+ELLIPSIS @@ -91,29 +101,41 @@ CREATE TYPE "product status" AS ENUM ( CREATE TABLE "orders" ( "id" int PRIMARY KEY AUTOINCREMENT, "user_id" int UNIQUE NOT NULL, - "status" orders_status, + "status" "orders_status", "created_at" varchar ); ... ``` -Finally, you can generate the DBML source from your schema with updated values from the classes (added in **0.4.1**): +Generate DBML for your Database by accessing the `dbml` property: ```python >>> parsed.project.items['author'] = 'John Doe' >>> print(parsed.dbml) # doctest:+ELLIPSIS -Project test_schema { +Project "test_schema" { author: 'John Doe' Note { 'This schema is used for PyDBML doctest' } } +Enum "orders_status" { + "created" + "running" + "done" + "failure" +} + +Enum "product status" { + "Out of Stock" + "In Stock" +} + Table "orders" { "id" int [pk, increment] "user_id" int [unique, not null] - "status" orders_status + "status" "orders_status" "created_at" varchar } @@ -125,202 +147,3 @@ Table "order_items" { ... ``` - -# Docs - -## Table class - -After running parser all tables from the schema are stored in `tables` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> table = parsed.tables[0] ->>> table -
- -``` - -Important attributes of the `Table` object are: - -* **name** (str) — table name, -* **refs** (list of `TableReference`) — all foreign keys, defined for the table, -* **columns** (list of `Column`) — table columns, -* **indexes** (list of `Index`) — indexes, defined for the table. -* **alias** (str) — table alias, if defined. -* **note** (str) — note for table, if defined. -* **header_color** (str) — the header_color param, if defined. -* **comment** (str) — comment, if it was added just before table definition. - -`Table` object may act as a list or a dictionary of columns: - -```python ->>> table[0] - ->>> table['status'] - - -``` - -## Column class - -Table columns are stored in the `columns` attribute of a `Table` object. - -Important attributes of the `Column` object are: - -* **name** (str) — column name, -* **table** (Table)— link to `Table` object, which holds this column. -* **type** (str or `Enum`) — column type. If type is a enum, defined in the same schema, this attribute will hold a link to corresponding `Enum` object. -* **unique** (bool) — is column unique. -* **not_null** (bool) — is column not null. -* **pk** (bool) — is column a primary key. -* **autoinc** (bool) — is an autoincrement column. -* **default** (str or int or float) — column's default value. -* **note** (Note) — column's note if was defined. -* **comment** (str) — comment, if it was added just before column definition or right after it on the same line. - -## Index class - -Indexes are stored in the `indexes` attribute of a `Table` object. - -Important attributes of the `Index` object are: - -* **subjects** (list of `Column` or `str`) — list subjects which are indexed. Columns are represented by `Column` objects, expressions (`getdate()`) are stored as strings `(getdate())`. Expressions are supported since **0.3.5**. -* **table** (`Table`) — table, for which this index is defined. -* **name** (str) — index name, if defined. -* **unique** (bool) — is index unique. -* **type** (str) — index type, if defined. Can be either `hash` or `btree`. -* **pk** (bool) — is this a primary key index. -* **note** (note) — index note, if defined. -* **comment** (str) — comment, if it was added just before index definition. - -## Reference class - -After running parser all references from the schema are stored in `refs` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> parsed.refs[0] - - -``` - -Important attributes of the `Reference` object are: - -* **type** (str) — reference type, in DBML syntax: - * `<` — one to many; - * `>` — many to one; - * `-` — one to one. -* **table1** (`Table`) — link to the first table of the reference. -* **col1** (list os `Column`) — list of Column objects of the first table of the reference. Changed in **0.4.0**, previously was plain `Column`. -* **table2** (`Table`) — link to the second table of the reference. -* **col2** (list of `Column`) — list of Column objects of the second table of the reference. Changed in **0.4.0**, previously was plain `Column`. -* **name** (str) — reference name, if defined. -* **on_update** (str) — reference's on update setting, if defined. -* **on_delete** (str) — reference's on delete setting, if defined. -* **comment** (str) — comment, if it was added before reference definition. - -## TableReference class - -Apart from `Reference` objects, parser also creates `TableReference` objects, which are stored in each table, where the foreign key should be defined. These objects don't have types. List of references is stored in `refs` attribute of a Table object: - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> order_items_refs = parsed.tables[1].refs ->>> order_items_refs[0] - - -``` - -Important attributes of the `TableReference` object are: - -* **col** (list[`Column`]) — list of Column objects, which are referenced in this table. Changed in **0.4.0**, previously was plain `Column`. -* **ref_table** (`Table`) — link to the second table of the reference. -* **ref_col** (list[`Column`]) — list of Column objects, which are referenced by this table. Changed in **0.4.0**, previously was plain `Column`. -* **name** (str) — reference name, if defined. -* **on_update** (str) — reference's on update setting, if defined. -* **on_delete** (str) — reference's on delete setting, if defined. - -## Enum class - -After running parser all enums from the schema are stored in `enums` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> enum = parsed.enums[0] ->>> enum - - -``` - -`Enum` object contains three attributes: - -* **name** (str) — enum name, -* **items** (list of `EnumItem`) — list of items. -* **comment** (str) — comment, which was defined before enum definition. - -Enum objects also act as a list of items: - -```python ->>> enum[0] - - -``` - -### EnumItem class - -Enum items are stored in the `items` property of a `Enum` class. - -`EnumItem` object contains following attributes: - -* **name** (str) — enum item name, -* **note** (`Note`) — enum item note, if was defined. -* **comment** (str) — comment, which was defined before enum item definition or right after it on the same line. - -## Note class - -Note is a basic class, which may appear in some other classes' `note` attribute. It has just one meaningful attribute: - -**text** (str) — note text. - -## Project class - -After running parser the project info is stored in the `project` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> parsed.project - - -``` - -Attributes of the `Project` object: - -* **name** (str) — project name, -* **items** (str) — dictionary with project items, -* **note** (`Note`) — note, if was defined, -* **comment** (str) — comment, if was added before project definition. - -## TableGroup class - -After running parser the project info is stored in the `project` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> parsed.table_groups -[, ] - -``` - -Attributes of the `TableGroup` object: - -* **name** (str) — table group name, -* **items** (str) — dictionary with tables in the group, -* **comment** (str) — comment, if was added before table group definition. - -> TableGroup `items` parameter initially holds just the names of the tables, but after parsing the whole document, `PyDBMLParseResults` class replaces them with references to actual tables. diff --git a/TODO.md b/TODO.md index 4066deb..462a12d 100644 --- a/TODO.md +++ b/TODO.md @@ -6,8 +6,9 @@ - schema.add and .delete to support multiple arguments (handle errors properly) * - 2.3.1 Multiline comment /* ... */ * - 2.4 Multiple Schemas -- validation on "add_index", "add_table" etc +* - validation on "add_index", "add_table" etc * - enum type in table definition with schema - add coverage badge - add docstrings - new docs +- comment class diff --git a/docs/classes.md b/docs/classes.md new file mode 100644 index 0000000..4fa1531 --- /dev/null +++ b/docs/classes.md @@ -0,0 +1,259 @@ +# Class Reference + +Import pydbml classes from the `pydbml.classes` package. + +```python +>>> from pydbml.classes import Table, Column, Reference + +``` + +Each class represents a database entity. + +## Table + +`Table` class represents a database table. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> table = parsed.tables[0] +>>> table +
+ +``` + +`Table` object may act as a list or a dictionary of columns: + +```python +>>> table[0] + +>>> table['status'] + + +``` + +### Attributes + +* **database** (`Database`) — link to the table's database object, if it was set. +* **name** (str) — table name. +* **schema** (str) — table schema name. +* **full_name** (str) — table name with schema prefix. +* **columns** (list of `Column`) — table columns. +* **indexes** (list of `Index`) — indexes, defined for the table. +* **alias** (str) — table alias, if defined. +* **note** (str) — note for table, if defined. +* **header_color** (str) — the header_color param, if defined. +* **comment** (str) — comment, if it was added just before table definition. +* **sql** (str) — SQL definition for this table. +* **dbml** (str) — DBML definition for this table. + +### Methods + +* **add_column** (c: `Column`) — add a column to the table, +* **delete_column** (c: `Column` or int) — delete a column from the table by Column object or column index. +* **add_index** (i: `Index`) — add an index to the table, +* **delete_index** (i: Index or int) — delete an index from the table by Index object or index number. +* **get_refs** — get list of references, defined for this table. + +## Column + +`Column` class represents a column of a database table. + +Table columns are stored in the `columns` attribute of a `Table` object. + +### Attributes + +* **database** (`Database`) — link to the database object of this column's table, if it was set. +* **name** (str) — column name, +* **table** (`Table`) — link to `Table` object, which holds this column. +* **type** (str or `Enum`) — column type. If type is a enum, this attribute will hold a link to corresponding `Enum` object. +* **unique** (bool) — indicates whether the column is unique. +* **not_null** (bool) — indicates whether the column is not null. +* **pk** (bool) — indicates whether the column is a primary key. +* **autoinc** (bool) — indicates whether this is an autoincrement column. +* **default** (str or bool or int or float or Expression) — column's default value. +* **note** (Note) — column's note if was defined. +* **comment** (str) — comment, if it was added just before column definition or right after it on the same line. +* **sql** (str) — SQL definition for this column. +* **dbml** (str) — DBML definition for this column. + +### Methods + +* **get_refs** — get list of references, defined for this column. + +## Index + +`Index` class represents an index of a database table. + +Indexes are stored in the `indexes` attribute of a `Table` object. + +### Attributes + +* **subjects** (list of `Column` or `Expression`) — list subjects which are indexed. Columns are represented by `Column` objects or `Expression` objects. +* **subject_names** (list of str) — list of index subject names. +* **table** (`Table`) — link to table, for which this index is defined. +* **name** (str) — index name, if defined. +* **unique** (bool) — indicates whether the index is unique. +* **type** (str) — index type, if defined. Can be either `hash` or `btree`. +* **pk** (bool) — indicates whether this a primary key index. +* **note** (note) — index note, if defined. +* **comment** (str) — comment, if it was added just before index definition. +* **sql** (str) — SQL definition for this index. +* **dbml** (str) — DBML definition for this index. + +## Reference + +`Index` class represents a database relation. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> parsed.refs[0] + + +``` + +### Attributes + +database +type +col1 +col2 +name +comment +on_update +on_delete +inline + +* **database** (`Database`) — link to the reference's database object, if it was set. +* **type** (str) — reference type, in DBML syntax: + * `<` — one to many; + * `>` — many to one; + * `-` — one to one. +* **col1** (list os `Column`) — list of Column objects of the left side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **col2** (list of `Column`) — list of Column objects of the right side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **name** (str) — reference name, if defined. +* **on_update** (str) — reference's on update setting, if defined. +* **on_delete** (str) — reference's on delete setting, if defined. +* **comment** (str) — comment, if it was added before reference definition. +* **inline** (bool) — indicates whether this reference should be rendered inside SQL or DBML definition of the table. +* **sql** (str) — SQL definition for this reference. +* **dbml** (str) — DBML definition for this reference. + +## Enum + +`Enum` class represents a enum type in the database. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> enum = parsed.enums[0] +>>> enum + + +``` + +Enum objects also act as a list of items: + +```python +>>> enum[0] + + +``` + +### Attributes + +database +name +schema +comment +items + +* **database** (`Database`) — link to the enum's database object, if it was set. +* **schema** (str) — enum schema name. +* **name** (str) — enum name, +* **items** (list of `EnumItem`) — list of items. +* **comment** (str) — comment, which was defined before enum definition. +* **sql** (str) — SQL definition for this enum. +* **dbml** (str) — DBML definition for this enum. + +### Methods + +* **add_item** (item: `EnumItem` or str) — add an item to this enum. + +### EnumItem + +`EnumItem` class represents an item of a enum type in the database. + +Enum items are stored in the `items` property of a `Enum` class. + +### Attributes + +* **name** (str) — enum item name, +* **note** (`Note`) — enum item note, if was defined. +* **comment** (str) — comment, which was defined before enum item definition or right after it on the same line. +* **sql** (str) — SQL definition for this enum item. +* **dbml** (str) — DBML definition for this enum item. + +## Note + +Note is a basic class, which may appear in some other classes' `note` attribute. Mainly used for documentation of a DBML database. + +### Attributes + +**text** (str) — note text. +* **sql** (str) — SQL definition for this note. +* **dbml** (str) — DBML definition for this note. + +## Expression + +**new in PyDBML 1.0.0** + +`Expression` class represents an SQL expression. Expressions may appear in `Index` subjects or `Column` default values. + +### Attributes + +**text** (str) — expression text. +* **sql** (str) — SQL definition for this expression. +* **dbml** (str) — DBML definition for this expression. + +## Project + +`Project` class holds DBML project metadata. Project is not present in SQL. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> parsed.project + + +``` + +### Attributes + +* **database** (`Database`) — link to the project's database object, if it was set. +* **name** (str) — project name, +* **items** (str) — dictionary with project metadata, +* **note** (`Note`) — note, if was defined, +* **comment** (str) — comment, if was added before project definition. +* **dbml** (str) — DBML definition for this project. + +## TableGroup + +`TableGroup` class represents a table group in the DBML database. TableGroups are not present in SQL. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> parsed.table_groups +[, ] + +``` + +### Attributes + +* **database** (`Database`) — link to the tableg group's database object, if it was set. +* **name** (str) — table group name, +* **items** (str) — dictionary with tables in the group, +* **comment** (str) — comment, if was added before table group definition. +* **dbml** (str) — DBML definition for this table group. diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index fdbb4d9..3597234 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -1,4 +1,5 @@ from typing import Iterable +from typing import List from typing import Optional from typing import Union @@ -64,7 +65,7 @@ def __init__(self, self.name = name self.schema = schema self.comment = comment - self.items = [] + self.items: List[EnumItem] = [] for item in items: self.add_item(item) diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index 595285b..de58ec5 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -1,12 +1,13 @@ -from typing import Optional -from typing import Union from typing import List +from typing import Literal +from typing import Optional from typing import TYPE_CHECKING +from typing import Union from .base import SQLOjbect -from .note import Note from .column import Column from .expression import Expression +from .note import Note from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql from pydbml.tools import note_option_to_dbml @@ -23,11 +24,10 @@ def __init__(self, subjects: List[Union[str, 'Column', 'Expression']], name: Optional[str] = None, unique: bool = False, - type: Optional[str] = None, + type: Literal['hash', 'btree'] = None, pk: bool = False, note: Optional[Union['Note', str]] = None, comment: Optional[str] = None): - self.database = None self.subjects = subjects self.table: Optional['Table'] = None diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 820200a..3c13b94 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -26,7 +26,7 @@ class Table(SQLOjbect): '''Class representing table.''' - required_attributes = ('name',) + required_attributes = ('name', 'schema') def __init__(self, name: str, From abd89ec7721ce9ec278bb4d67ddbf0bfa92e4736 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 21 May 2022 21:55:55 +0200 Subject: [PATCH 036/125] continue docs --- TODO.md | 1 + docs/creating_schema.md | 78 +++++++++++++++++++++++++++++++++++++++++ pydbml/__init__.py | 1 + 3 files changed, 80 insertions(+) create mode 100644 docs/creating_schema.md diff --git a/TODO.md b/TODO.md index 462a12d..5ac6493 100644 --- a/TODO.md +++ b/TODO.md @@ -8,6 +8,7 @@ * - 2.4 Multiple Schemas * - validation on "add_index", "add_table" etc * - enum type in table definition with schema +- table columns parameter - add coverage badge - add docstrings - new docs diff --git a/docs/creating_schema.md b/docs/creating_schema.md new file mode 100644 index 0000000..af0ff07 --- /dev/null +++ b/docs/creating_schema.md @@ -0,0 +1,78 @@ +# Creating DBML schema + +You can use PyDBML not only for parsing DBML files, but also for creating schema from scratch in Python. + +## Database object + +You always start by creating a Database object. It will connect all other entities of the database for us. + +```python +>>> from pydbml import Database +>>> db = Database() + +``` + +Now let's create a table and add it to the database. + +```python +>>> from pydbml.classes import Table +>>> table1 = Table(name='products') +>>> db.add(table1) + +``` + +To add columns to the table you have to use the `add_column` method of the Table object. + +```python +>>> from pydbml.classes import Column +>>> col1 = Column(name='id', type='Integer', pk=True, autoinc=True) +>>> table1.add_column(col1) +>>> col2 = Column(name='product_name', type='Varchar', unique=True) +>>> table1.add_column(col2) +>>> col3 = Column(name='manufacturer_id', type='Integer') +>>> table1.add_column(col3) + +``` + +Index is also a part of a table, so you have to add it similarly, using `add_index` method: + +```python +>>> from pydbml.classes import Index +>>> index1 = Index([col2], unique=True) +>>> table1.add_index(index1) + +``` + +The table's third column, `manufacturer_id`. looks like it should be a foreign key. Let's create another table, called `manufacturers`, so that we could create a relation. + +```python +>>> table2 = Table( +... 'products', +... columns=[ +... Column('id', type='Integer', pk=True, autoinc=True), +... Column('manufacturer_name', type='Varchar'), +... Column('manufacturer_country', type='Varchar') +... ] +... ) +>>> db.add(table2) +
+ +``` + +Now to the relation: + +```python +from pydbml.classes import Reference +>>> ref = Reference('>', table1['manufacturer_id'], table2['id']) +>>> db.add(ref) + +``` + +You noticed that we are calling the `add` method on the Database after creating each object. While objects can somewhat function without being added to a database, DBML/SQL generation and some other useful methods won't work properly. + +Now let's generate DBML code for our schema. This is done by just calling the `dbml` property of the Database object: + +```python +>>> print(db.dbml) +>>> #breakpoint() +``` diff --git a/pydbml/__init__.py b/pydbml/__init__.py index 5da1d6e..df94b5c 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,5 +1,6 @@ from . import classes from .parser import PyDBML +from .database import Database from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_MANY from pydbml.constants import ONE_TO_ONE From 4ac528ad3ebabf1af40004e7834114643ada2023 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 10:17:02 +0200 Subject: [PATCH 037/125] init table with columns and indexes, finish main docs --- docs/classes.md | 54 +++++++++++++++++++++++++++++++-- docs/creating_schema.md | 51 ++++++++++++++++++++++++++++--- pydbml/classes/table.py | 11 +++++++ test/test_classes/test_table.py | 28 +++++++++++++++++ 4 files changed, 138 insertions(+), 6 deletions(-) diff --git a/docs/classes.md b/docs/classes.md index 4fa1531..ae2edb9 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -1,13 +1,63 @@ # Class Reference -Import pydbml classes from the `pydbml.classes` package. +PyDBML classes represent database entities. They live in the `pydbml.classes` package. ```python >>> from pydbml.classes import Table, Column, Reference ``` -Each class represents a database entity. +The `Database` class represents a PyDBML database. You can import it from the `pydbml` package. + +```python +>>> from pydbml import Database + +``` + +## Database + +`Database` is the main class, representing a PyDBML database. When PyDBML parses a .dbml file, it returns a `Database` object. This object holds all objects of the database and makes sure they are properly connected. You can access the `Database` object by calling the `database` property of each class (except child classes like `Column` or `Index`). + +When you are creating PyDBML schema from scratch, you have to add each created object to the database by calling `Database.add`. + +`Database` object may act as a list or a dictionary of tables: + +```python +>>> from pydbml import PyDBML +>>> db = PyDBML.parse_file('test_schema.dbml') +>>> table = db.tables[0] +>>> db['public.orders'] +
+>>> db[0] +
+ +``` + +### Attributes + +* **tables** (list of `Table`) — list of all `Table` objects, defined in this database. +* **table_dict** (dict of `Table`) — dictionary holding database `Table` objects. The key is full table name (with schema: `public.mytable`) or a table alias (`myalias`). +* **refs** (list of `Reference`) — list of all `Reference` objects, defined in this database. +* **enums** (list of `Enum`) — list of all `Enum` objects, defined in this database. +* **table_groups** (list of `TableGroup`) — list of all `TableGroup` objects, defined in this database. +* **project** (`Project`) — database `Project`. +* **sql** () — SQL definition for this database. +* **dbml** () — DBML definition for this table. + +### Methods + +* **add** (PyDBML object) — add a PyDBML object to the database. +* **add_table** (`Table`) — add a `Table` object to the database. +* **add_reference** (`Reference`) — add a `Reference` object to the database. +* **add_enum** (`Enum`) — add a `Enum` object to the database. +* **add_table_group** (`TableGroup`) — add a `TableGroup` object to the database. +* **add_project** (`Project`) — add a `Project` object to the database. +* **delete** (PyDBML object) — delete a PyDBML object from the database. +* **delete_table** (`Table`) — delete a `Table` object from the database. +* **delete_reference** (`Reference`) — delete a `Reference` object from the database. +* **delete_enum** (`Enum`) — delete a `Enum` object from the database. +* **delete_table_group** (`TableGroup`) — delete a `TableGroup` object from the database. +* **delete_project** (`Project`) — delete a `Project` object from the database. ## Table diff --git a/docs/creating_schema.md b/docs/creating_schema.md index af0ff07..f4f37d8 100644 --- a/docs/creating_schema.md +++ b/docs/creating_schema.md @@ -18,6 +18,7 @@ Now let's create a table and add it to the database. >>> from pydbml.classes import Table >>> table1 = Table(name='products') >>> db.add(table1) +
``` @@ -47,7 +48,7 @@ The table's third column, `manufacturer_id`. looks like it should be a foreign k ```python >>> table2 = Table( -... 'products', +... 'manufacturers', ... columns=[ ... Column('id', type='Integer', pk=True, autoinc=True), ... Column('manufacturer_name', type='Varchar'), @@ -55,16 +56,17 @@ The table's third column, `manufacturer_id`. looks like it should be a foreign k ... ] ... ) >>> db.add(table2) -
+
``` Now to the relation: ```python -from pydbml.classes import Reference +>>> from pydbml.classes import Reference >>> ref = Reference('>', table1['manufacturer_id'], table2['id']) >>> db.add(ref) +', ['manufacturer_id'], ['id']> ``` @@ -74,5 +76,46 @@ Now let's generate DBML code for our schema. This is done by just calling the `d ```python >>> print(db.dbml) ->>> #breakpoint() +Table "products" { + "id" Integer [pk, increment] + "product_name" Varchar [unique] + "manufacturer_id" Integer + + indexes { + product_name [unique] + } +} + +Table "manufacturers" { + "id" Integer [pk, increment] + "manufacturer_name" Varchar + "manufacturer_country" Varchar +} + +Ref { + "products"."manufacturer_id" > "manufacturers"."id" +} + +``` + +We can generate SQL for the schema in a similar way, by calling the `sql` property: + +```python +>>> print(db.sql) +CREATE TABLE "products" ( + "id" Integer PRIMARY KEY AUTOINCREMENT, + "product_name" Varchar UNIQUE, + "manufacturer_id" Integer +); + +CREATE UNIQUE INDEX ON "products" ("product_name"); + +CREATE TABLE "manufacturers" ( + "id" Integer PRIMARY KEY AUTOINCREMENT, + "manufacturer_name" Varchar, + "manufacturer_country" Varchar +); + +ALTER TABLE "products" ADD FOREIGN KEY ("manufacturer_id") REFERENCES "manufacturers" ("id"); + ``` diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 3c13b94..bf037bf 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -2,6 +2,7 @@ from typing import Optional from typing import TYPE_CHECKING from typing import Union +from typing import Iterable from .base import SQLOjbect from .column import Column @@ -32,6 +33,8 @@ def __init__(self, name: str, schema: str = 'public', alias: Optional[str] = None, + columns: Optional[Iterable[Column]] = None, + indexes: Optional[Iterable[Index]] = None, note: Optional[Union['Note', str]] = None, header_color: Optional[str] = None, comment: Optional[str] = None): @@ -39,7 +42,11 @@ def __init__(self, self.name = name self.schema = schema self.columns: List[Column] = [] + for column in columns or []: + self.add_column(column) self.indexes: List[Index] = [] + for index in indexes or []: + self.add_index(index) self.alias = alias if alias else None self.note = Note(note) self.header_color = header_color @@ -54,6 +61,8 @@ def add_column(self, c: Column) -> None: Adds column to self.columns attribute and sets in this column the `table` attribute. ''' + if not isinstance(c, Column): + raise TypeError('Columns must be of type Column') c.table = self self.columns.append(c) @@ -73,6 +82,8 @@ def add_index(self, i: Index) -> None: Adds index to self.indexes attribute and sets in this index the `table` attribute. ''' + if not isinstance(i, Index): + raise TypeError('Indexes must be of type Index') for subject in i.subjects: if isinstance(subject, Column) and subject.table != self: raise ColumnNotFoundError(f'Column {subject} not in the table') diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 69e4152..422a1fc 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -45,6 +45,30 @@ def test_getitem(self) -> None: with self.assertRaises(ColumnNotFoundError): t['wrong'] + def test_init_with_columns(self) -> None: + t = Table( + 'products', + columns=( + Column('col1', 'integer'), + Column('col2', 'integer'), + Column('col3', 'integer'), + ) + ) + self.assertIs(t['col1'].table, t) + self.assertIs(t['col2'].table, t) + self.assertIs(t['col3'].table, t) + + def test_init_with_indexes(self) -> None: + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t = Table( + 'products', + columns=[c1, c2, c3], + indexes=[Index(subjects=[c1])] + ) + self.assertIs(t.indexes[0].table, t) + def test_get(self) -> None: t = Table('products') c1 = Column('col1', 'integer') @@ -227,6 +251,8 @@ def test_add_column(self) -> None: self.assertEqual(c1.table, t) self.assertEqual(c2.table, t) self.assertEqual(t.columns, [c1, c2]) + with self.assertRaises(TypeError): + t.add_column('wrong type') def test_delete_column(self) -> None: t = Table('products') @@ -256,6 +282,8 @@ def test_add_index(self) -> None: self.assertEqual(i1.table, t) self.assertEqual(i2.table, t) self.assertEqual(t.indexes, [i1, i2]) + with self.assertRaises(TypeError): + t.add_index('wrong_type') def test_delete_index(self) -> None: t = Table('products') From 7a16e891b99a97083fa341bc4b25544c53f4cb31 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 11:28:34 +0200 Subject: [PATCH 038/125] table1 and table2 properties on reference --- README.md | 1 + docs/classes.md | 12 +-- docs/upgrading.md | 109 ++++++++++++++++++++++++++++ pydbml/classes/enum.py | 2 +- pydbml/classes/reference.py | 42 +++++++---- pydbml/classes/table.py | 6 +- test/test_classes/test_reference.py | 19 ++++- 7 files changed, 160 insertions(+), 31 deletions(-) create mode 100644 docs/upgrading.md diff --git a/README.md b/README.md index d3c5dd4..5447bd4 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. **Docs:** * [Class Reference](docs/classes.md) +* [Creating DBML schema](docs/creating_schema.md) ## Installation diff --git a/docs/classes.md b/docs/classes.md index ae2edb9..172f27d 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -165,23 +165,15 @@ Indexes are stored in the `indexes` attribute of a `Table` object. ### Attributes -database -type -col1 -col2 -name -comment -on_update -on_delete -inline - * **database** (`Database`) — link to the reference's database object, if it was set. * **type** (str) — reference type, in DBML syntax: * `<` — one to many; * `>` — many to one; * `-` — one to one. * **col1** (list os `Column`) — list of Column objects of the left side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **table1** (`Table` or `None`) — link to the left `Table` object of the reference or `None` of it was not set. * **col2** (list of `Column`) — list of Column objects of the right side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **table2** (`Table` or `None`) — link to the right `Table` object of the reference or `None` of it was not set. * **name** (str) — reference name, if defined. * **on_update** (str) — reference's on update setting, if defined. * **on_delete** (str) — reference's on delete setting, if defined. diff --git a/docs/upgrading.md b/docs/upgrading.md new file mode 100644 index 0000000..9bc4352 --- /dev/null +++ b/docs/upgrading.md @@ -0,0 +1,109 @@ +# Upgrading to PyDBML 1.0.0 + +When I created PyDBML back in April 2020, I just needed a DBML parser, and it was written as a parser. When people started using it, they wanted to also be able to edit DBML schema in Python and create it from scratch. While it worked to some extent, the project architecture was not completely ready for such usage. + +In May 2022 I've rewritten PyDBML from scratch and released version 1.0.0. Now you can not only parse DBML files, but also create them in Python and edit parsed schema. Sadly, it made the new version completely incompatible with the old one. This article will help you upgrade to PyDBML 1.0.0 and adapt your code to work with the new version. + +## Getting Tables From Parse Results by Name + +Previously the parser returned the `PyDBMLParseResults` object, now it returnes a `Database` object. While they mostly can be operated similarly, now you can't get a table just by name. + +Since v2.4 DBML supports multiple schemas for tables and enums. PyDBML 1.0.0 also supports multiple schemas, but this means that there may be tables with the same name in different schemas. So now you can't get a table from the parse results just by name, you have to specify the schema too: + +```python +>>> from pydbml import PyDBML +>>> db = PyDBML.parse_file('test_schema.dbml') +>>> db['orders'] +Traceback (most recent call last): +... +KeyError: 'orders' +>>> db['public.orders'] +
+ +``` + +## New Table Object + +Previously the `Table` object had a `refs` attribute which holded a list of `TableReference` objects. `TableReference` represented a table relation and duplicated the `Reference` object of `PyDBMLParseResults` container. + +**In 1.0.0 the `TableReference` class is removed, and there's not `Table.refs` attribute.** + +Now all relations are represented by a single `Reference` object. You can still access `Table` references by calling the `get_refs` method. + +`Table.get_refs` will return a list of References for this table, but only if this table is on the left side of DBML relation. + +Here's an example DBML reference definition: + +```python +>>> source = ''' +... Table posts { +... id integer [primary key] +... user_id integer +... } +... +... Table users { +... id integer +... } +... +... Ref name_optional: posts.user_id > users.id +... ''' +>>> db = PyDBML(source) + +``` + +Here the many-to-one (`>`) relation is defined with the **posts** table on the left side, so calling `get_refs` on the **posts** table will return you this reference: + +```python +>>> db['public.posts'].get_refs() +[', ['user_id'], ['id']>] + +``` + +But calling `get_refs` on the **users** table won't give you the reference, because **users** is on the right side of the relation: + +```python +>>> db['public.users'].get_refs() +[] + +``` + +This depends on the side the table was referenced on, not on the type of the reference. So, if we modify the previous example to use one-to-many relation instead of many-to-one: + +```python +>>> source = ''' +... Table posts { +... id integer [primary key] +... user_id integer +... } +... +... Table users { +... id integer +... } +... +... Ref name_optional: users.id < posts.user_id +... ''' +>>> db = PyDBML(source) + +``` + +Now the **users** table is on the left, and we can only get the reference from the **users** table: + +```python +>>> db['public.users'].get_refs() +[] +>>> db['public.posts'].get_refs() +[] + +``` + +You can still get all the references for the database by accessing `Database.refs` property: + +```python +>>> db.refs +[] + +``` + +## New Reference Object + +Previously the `Reference` object had links to referenced tables in `table1` and `table2` attributes. Now maybe too?? \ No newline at end of file diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index 3597234..e2e835f 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -54,7 +54,7 @@ def dbml(self): class Enum(SQLOjbect): - required_attributes = ('name', 'items') + required_attributes = ('name', 'schema', 'items') def __init__(self, name: str, diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 2cdf3d6..6b845ec 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -2,6 +2,7 @@ from typing import Literal from typing import Optional from typing import Union +from typing import TYPE_CHECKING from .base import SQLOjbect from .column import Column @@ -12,6 +13,9 @@ from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql +if TYPE_CHECKING: # pragma: no cover + from .table import Table + class Reference(SQLOjbect): ''' @@ -40,6 +44,16 @@ def __init__(self, self.on_delete = on_delete self.inline = inline + @property + def table1(self) -> Optional['Table']: + self._validate() + return self.col1[0].table if self.col1 else None + + @property + def table2(self) -> Optional['Table']: + self._validate() + return self.col2[0].table if self.col2 else None + def __repr__(self): ''' >>> c1 = Column('c1', 'int') @@ -76,13 +90,15 @@ def _validate(self): table1 = self.col1[0].table if any(c.table != table1 for c in self.col1): raise DBMLError('Columns in col1 are from different tables') - if table1 is None: - raise TableNotFoundError('Table on col1 is not set') table2 = self.col2[0].table if any(c.table != table2 for c in self.col2): raise DBMLError('Columns in col2 are from different tables') - if table2 is None: + + def _validate_for_sql(self): + if self.table1 is None: + raise TableNotFoundError('Table on col1 is not set') + if self.table2 is None: raise TableNotFoundError('Table on col2 is not set') @property @@ -94,17 +110,17 @@ def sql(self): ''' self.check_attributes_for_sql() - self._validate() + self._validate_for_sql() c = f'CONSTRAINT "{self.name}" ' if self.name else '' if self.inline: if self.type in (MANY_TO_ONE, ONE_TO_ONE): source_col = self.col1 - ref_table = self.col2[0].table + ref_table = self.table2 ref_col = self.col2 else: source_col = self.col2 - ref_table = self.col1[0].table + ref_table = self.table1 ref_col = self.col1 cols = '", "'.join(c.name for c in source_col) @@ -121,14 +137,14 @@ def sql(self): return result else: if self.type in (MANY_TO_ONE, ONE_TO_ONE): - t1 = self.col1[0].table + t1 = self.table1 c1 = ', '.join(f'"{c.name}"' for c in self.col1) - t2 = self.col2[0].table + t2 = self.table2 c2 = ', '.join(f'"{c.name}"' for c in self.col2) else: - t1 = self.col2[0].table + t1 = self.table2 c1 = ', '.join(f'"{c.name}"' for c in self.col2) - t2 = self.col1[0].table + t2 = self.table1 c2 = ', '.join(f'"{c.name}"' for c in self.col1) result = comment_to_sql(self.comment) if self.comment else '' @@ -144,7 +160,7 @@ def sql(self): @property def dbml(self): - self._validate() + self._validate_for_sql() if self.inline: # settings are ignored for inline ref if len(self.col2) > 1: @@ -178,9 +194,9 @@ def dbml(self): options_str = f' [{", ".join(options)}]' if options else '' result += ( ' {\n ' - f'{self.col1[0].table._get_full_name_for_sql()}.{col1} ' + f'{self.table1._get_full_name_for_sql()}.{col1} ' f'{self.type} ' - f'{self.col2[0].table._get_full_name_for_sql()}.{col2}' + f'{self.table2._get_full_name_for_sql()}.{col2}' f'{options_str}' '\n}' ) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index bf037bf..f2d4865 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -104,7 +104,7 @@ def delete_index(self, i: Union[Index, int]) -> Index: def get_refs(self) -> List[Reference]: if not self.database: raise UnknownDatabaseError('Database for the table is not set') - return [ref for ref in self.database.refs if ref.col1[0].table == self] + return [ref for ref in self.database.refs if ref.table1 == self] def _get_references_for_sql(self) -> List[Reference]: ''' @@ -116,9 +116,9 @@ def _get_references_for_sql(self) -> List[Reference]: for ref in self.database.refs: if ref.inline: if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ - (ref.col1[0].table == self): + (ref.table1 == self): result.append(ref) - elif (ref.type == ONE_TO_MANY) and (ref.col2[0].table == self): + elif (ref.type == ONE_TO_MANY) and (ref.table2 == self): result.append(ref) return result diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 7cd25d3..9a32533 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -20,6 +20,17 @@ def test_sql_single(self): expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' self.assertEqual(ref.sql, expected) + def test_table1(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', c1, c2) + self.assertIsNone(ref.table1) + t.add_column(c1) + self.assertIs(ref.table1, t) + def test_sql_schema_single(self): t = Table('products', schema='myschema1') c1 = Column('name', 'varchar2') @@ -392,11 +403,11 @@ def test_validate_no_table(self): c2 ) with self.assertRaises(TableNotFoundError): - ref1._validate() + ref1._validate_for_sql() table = Table('name') table.add_column(c1) with self.assertRaises(TableNotFoundError): - ref1._validate() + ref1._validate_for_sql() table.delete_column(c1) ref2 = Reference( @@ -405,9 +416,9 @@ def test_validate_no_table(self): [c3, c4] ) with self.assertRaises(TableNotFoundError): - ref2._validate() + ref2._validate_for_sql() table = Table('name') table.add_column(c1) table.add_column(c2) with self.assertRaises(TableNotFoundError): - ref2._validate() + ref2._validate_for_sql() From 9b9798dd461281295f7523e09c7714ff447d88dd Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 11:55:35 +0200 Subject: [PATCH 039/125] finish docs, update changelog --- changelog.md | 12 +++---- docs/upgrading.md | 87 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 8 deletions(-) diff --git a/changelog.md b/changelog.md index 0e1e2ca..8cbc773 100644 --- a/changelog.md +++ b/changelog.md @@ -1,12 +1,10 @@ # 1.0.0 -- refs don't have tables, only columns -- tables don't have refs -- col1 col2 in ref are as they were in dbml -- Expression class -- add multiline comment -- support multiple schemas -- type_ -> type +- New project architecture, full support for creating and editing DBML. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md) +- New Expression class +- Support DBML 2.4.1 syntax: + - Multiline comments + - Multiple schemas # 0.4.2 diff --git a/docs/upgrading.md b/docs/upgrading.md index 9bc4352..7e606f3 100644 --- a/docs/upgrading.md +++ b/docs/upgrading.md @@ -106,4 +106,89 @@ You can still get all the references for the database by accessing `Database.ref ## New Reference Object -Previously the `Reference` object had links to referenced tables in `table1` and `table2` attributes. Now maybe too?? \ No newline at end of file +Reference now can be explicitly inline. This is defined by the `Reference.inline` attribute. The `inline` attribute only affects how the reference will be rendered in table's SQL or DBML. + +Let's define an inline reference. + +```python +>>> from pydbml import Database +>>> from pydbml.classes import Table, Column, Reference +>>> db = Database() +>>> table1 = Table('products') +>>> db.add(table1) +
+>>> c1 = Column('name', 'varchar2') +>>> table1.add_column(c1) +>>> table2 = Table('names') +>>> db.add(table2) +
+>>> c2 = Column('name_val', 'varchar2') +>>> table2.add_column(c2) +>>> ref = Reference('>', c1, c2, inline=True) +>>> db.add(ref) +', ['name'], ['name_val']> +>>> print(table1.sql) +CREATE TABLE "products" ( + "name" varchar2, + FOREIGN KEY ("name") REFERENCES "names" ("name_val") +); + +``` + +If the reference is not inline, it won't appear in the Table SQL definition, otherwise it will le rendered separately as an `ALTER TABLE` clause: + +```python +>>> ref.inline = False +>>> print(table1.sql) +CREATE TABLE "products" ( + "name" varchar2 +); +>>> print(ref.sql) +ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val"); + +``` + +## `type_` -> `type` + +Previously you would initialize a `Column`, `Index` and `Reference` type with `type_` parameter. Now this parameter is renamed to simply `type`. + +```python +>>> from pydbml.classes import Index, Column +>>> c = Column(name='name', type='varchar') +>>> c + +>>> t = Table('names') +>>> t.add_column(c) +>>> i = Index(subjects=[c], type='btree') +>>> t.add_index(i) +>>> i + +>>> t2 = Table('names_caps', columns=[Column('name_caps', 'varchar')]) +>>> ref = Reference(type='-', col1=t['name'], col2=t2['name_caps']) +>>> ref + + +``` + +## New Expression Class + +SQL expressions are allowed in column's `default` value definition and in index's subject definition. Previously you defined expressions as parentesised strings: `"(upper(name))"`. Now you have to use the `Expression` class. This will make sure the expression will be rendered properly in SQL and DBML. + +```python +>>> from pydbml.classes import Expression +>>> c = Column( +... name='upper_name', +... type='varchar', +... default=Expression('upper(name)') +... ) +>>> t = Table('names') +>>> t.add_column(c) +>>> db = Database() +>>> db.add(t) +
+>>> print(c.sql) +"upper_name" varchar DEFAULT (upper(name)) +>>> print(c.dbml) +"upper_name" varchar [default: `upper(name)`] + +``` \ No newline at end of file From bda2d7f10dfee1b916fe2ff627aac0cefe86241e Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 11:58:45 +0200 Subject: [PATCH 040/125] SQLOjbect -> SQLObject --- pydbml/classes/base.py | 2 +- pydbml/classes/column.py | 4 ++-- pydbml/classes/enum.py | 4 ++-- pydbml/classes/expression.py | 4 ++-- pydbml/classes/index.py | 4 ++-- pydbml/classes/note.py | 4 ++-- pydbml/classes/reference.py | 4 ++-- pydbml/classes/table.py | 4 ++-- test/test_classes/test_base.py | 8 ++++---- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py index 61c55bc..4a1b844 100644 --- a/pydbml/classes/base.py +++ b/pydbml/classes/base.py @@ -4,7 +4,7 @@ from pydbml.exceptions import AttributeMissingError -class SQLOjbect: +class SQLObject: ''' Base class for all SQL objects. ''' diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 68ac695..97a22df 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from typing import Union -from .base import SQLOjbect +from .base import SQLObject from .expression import Expression from .enum import Enum from .note import Note @@ -17,7 +17,7 @@ from .reference import Reference -class Column(SQLOjbect): +class Column(SQLObject): '''Class representing table column.''' required_attributes = ('name', 'type') diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index e2e835f..63b0bfe 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -3,7 +3,7 @@ from typing import Optional from typing import Union -from .base import SQLOjbect +from .base import SQLObject from .note import Note from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql @@ -53,7 +53,7 @@ def dbml(self): return result -class Enum(SQLOjbect): +class Enum(SQLObject): required_attributes = ('name', 'schema', 'items') def __init__(self, diff --git a/pydbml/classes/expression.py b/pydbml/classes/expression.py index 5dcd90a..30d0d6b 100644 --- a/pydbml/classes/expression.py +++ b/pydbml/classes/expression.py @@ -1,7 +1,7 @@ -from .base import SQLOjbect +from .base import SQLObject -class Expression(SQLOjbect): +class Expression(SQLObject): def __init__(self, text: str): self.text = text diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index de58ec5..9d19188 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from typing import Union -from .base import SQLOjbect +from .base import SQLObject from .column import Column from .expression import Expression from .note import Note @@ -16,7 +16,7 @@ from .table import Table -class Index(SQLOjbect): +class Index(SQLObject): '''Class representing index.''' required_attributes = ('subjects', 'table') diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index 0964806..354638e 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -1,9 +1,9 @@ from typing import Any -from .base import SQLOjbect +from .base import SQLObject -class Note(SQLOjbect): +class Note(SQLObject): def __init__(self, text: Any): self.text = str(text) if text else '' diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 6b845ec..48ac163 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -4,7 +4,7 @@ from typing import Union from typing import TYPE_CHECKING -from .base import SQLOjbect +from .base import SQLObject from .column import Column from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_ONE @@ -17,7 +17,7 @@ from .table import Table -class Reference(SQLOjbect): +class Reference(SQLObject): ''' Class, representing a foreign key constraint. It is a separate object, which is not connected to Table or Column objects diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index f2d4865..82515f6 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -4,7 +4,7 @@ from typing import Union from typing import Iterable -from .base import SQLOjbect +from .base import SQLObject from .column import Column from .index import Index from .note import Note @@ -24,7 +24,7 @@ from pydbml.database import Database -class Table(SQLOjbect): +class Table(SQLObject): '''Class representing table.''' required_attributes = ('name', 'schema') diff --git a/test/test_classes/test_base.py b/test/test_classes/test_base.py index fb5b6a1..4305665 100644 --- a/test/test_classes/test_base.py +++ b/test/test_classes/test_base.py @@ -1,12 +1,12 @@ from unittest import TestCase -from pydbml.classes.base import SQLOjbect +from pydbml.classes.base import SQLObject from pydbml.exceptions import AttributeMissingError class TestDBMLObject(TestCase): def test_check_attributes_for_sql(self) -> None: - o = SQLOjbect() + o = SQLObject() o.a1 = None o.b1 = None o.c1 = None @@ -20,11 +20,11 @@ def test_check_attributes_for_sql(self) -> None: o.check_attributes_for_sql() def test_comparison(self) -> None: - o1 = SQLOjbect() + o1 = SQLObject() o1.a1 = None o1.b1 = 'c' o1.c1 = 123 - o2 = SQLOjbect() + o2 = SQLObject() o2.a1 = None o2.b1 = 'c' o2.c1 = 123 From 6781ccfb1fc636d8076af09d2098c44609ce7dd0 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 12:03:32 +0200 Subject: [PATCH 041/125] review docs --- README.md | 5 ++++- docs/creating_schema.md | 6 +++--- docs/upgrading.md | 8 ++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 5447bd4..11f609b 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,13 @@ PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. +> The project was rewritten in May 2022, the new version 1.0.0 is not compatible with the previous ones. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md). + **Docs:** * [Class Reference](docs/classes.md) * [Creating DBML schema](docs/creating_schema.md) +* [Upgrading to PyDBML 1.0.0](docs/upgrading.md) ## Installation @@ -49,7 +52,7 @@ or with entire source string ``` -Parser returns a Database object that is a container for the parsed DBML entities. +The parser returns a Database object that is a container for the parsed DBML entities. You can access tables inside the `tables` attribute: diff --git a/docs/creating_schema.md b/docs/creating_schema.md index f4f37d8..4b90dc9 100644 --- a/docs/creating_schema.md +++ b/docs/creating_schema.md @@ -22,7 +22,7 @@ Now let's create a table and add it to the database. ``` -To add columns to the table you have to use the `add_column` method of the Table object. +To add columns to the table, you have to use the `add_column` method of the Table object. ```python >>> from pydbml.classes import Column @@ -44,7 +44,7 @@ Index is also a part of a table, so you have to add it similarly, using `add_ind ``` -The table's third column, `manufacturer_id`. looks like it should be a foreign key. Let's create another table, called `manufacturers`, so that we could create a relation. +The table's third column, `manufacturer_id` looks like it should be a foreign key. Let's create another table, called `manufacturers`, so that we could create a relation. ```python >>> table2 = Table( @@ -98,7 +98,7 @@ Ref { ``` -We can generate SQL for the schema in a similar way, by calling the `sql` property: +We can generate SQL for the schema similarly, by calling the `sql` property: ```python >>> print(db.sql) diff --git a/docs/upgrading.md b/docs/upgrading.md index 7e606f3..97bb58a 100644 --- a/docs/upgrading.md +++ b/docs/upgrading.md @@ -6,7 +6,7 @@ In May 2022 I've rewritten PyDBML from scratch and released version 1.0.0. Now y ## Getting Tables From Parse Results by Name -Previously the parser returned the `PyDBMLParseResults` object, now it returnes a `Database` object. While they mostly can be operated similarly, now you can't get a table just by name. +Previously the parser returned the `PyDBMLParseResults` object, now it returns a `Database` object. While they mostly can be operated similarly, now you can't get a table just by name. Since v2.4 DBML supports multiple schemas for tables and enums. PyDBML 1.0.0 also supports multiple schemas, but this means that there may be tables with the same name in different schemas. So now you can't get a table from the parse results just by name, you have to specify the schema too: @@ -135,7 +135,7 @@ CREATE TABLE "products" ( ``` -If the reference is not inline, it won't appear in the Table SQL definition, otherwise it will le rendered separately as an `ALTER TABLE` clause: +If the reference is not inline, it won't appear in the Table SQL definition, otherwise it will be rendered separately as an `ALTER TABLE` clause: ```python >>> ref.inline = False @@ -150,7 +150,7 @@ ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val"); ## `type_` -> `type` -Previously you would initialize a `Column`, `Index` and `Reference` type with `type_` parameter. Now this parameter is renamed to simply `type`. +Previously you would initialize a `Column`, `Index` and `Reference` type with `type_` parameter. Now, this parameter is renamed to simply `type`. ```python >>> from pydbml.classes import Index, Column @@ -172,7 +172,7 @@ Previously you would initialize a `Column`, `Index` and `Reference` type with `t ## New Expression Class -SQL expressions are allowed in column's `default` value definition and in index's subject definition. Previously you defined expressions as parentesised strings: `"(upper(name))"`. Now you have to use the `Expression` class. This will make sure the expression will be rendered properly in SQL and DBML. +SQL expressions are allowed in column's `default` value definition and in index's subject definition. Previously, you defined expressions as parenthesized strings: `"(upper(name))"`. Now you have to use the `Expression` class. This will make sure the expression will be rendered properly in SQL and DBML. ```python >>> from pydbml.classes import Expression From 25c5673af3e3f9715bd9ee2ea45a0a61bf42caf2 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 12:07:57 +0200 Subject: [PATCH 042/125] bump version --- .gitignore | 1 + setup.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 00e1adf..7b02e9b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ dist pydbml.egg-info .mypy_cache .coverage +.eggs diff --git a/setup.py b/setup.py index 2c6deec..c80d992 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup -SHORT_DESCRIPTION = 'DBML syntax parser for Python' +SHORT_DESCRIPTION = 'Python parser and builder for DBML' try: with open('README.md', encoding='utf8') as readme: @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='0.4.2', + version='1.0.0', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From c6b44fd3223794ccf08818f238e9b98333b16fb6 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 12:14:31 +0200 Subject: [PATCH 043/125] add coverage badge --- README.md | 2 +- coverage.svg | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 coverage.svg diff --git a/README.md b/README.md index 11f609b..b2f3148 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![](https://img.shields.io/pypi/v/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/pypi/dm/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/github/v/tag/Vanderhoof/PyDBML.svg?label=GitHub)](https://github.com/Vanderhoof/PyDBML) +[![](https://img.shields.io/pypi/v/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/pypi/dm/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/github/v/tag/Vanderhoof/PyDBML.svg?label=GitHub)](https://github.com/Vanderhoof/PyDBML) ![](coverage.svg) # DBML parser for Python diff --git a/coverage.svg b/coverage.svg new file mode 100644 index 0000000..e5db27c --- /dev/null +++ b/coverage.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + coverage + coverage + 100% + 100% + + From e115ae93ffc1ff7e74ca305161412b7c2f168eff Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 12:17:26 +0200 Subject: [PATCH 044/125] fix Index type arg, update todo and tests.sh --- TODO.md | 13 +------------ pydbml/classes/index.py | 2 +- pydbml/parser/blueprints.py | 2 +- test.sh | 3 +++ 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/TODO.md b/TODO.md index 5ac6493..825ba3a 100644 --- a/TODO.md +++ b/TODO.md @@ -1,15 +1,4 @@ -* - Creating dbml schema in python -* - pyparsing new var names (+possibly new features) -* - enum type -* - `_type` -> `type` -* - expression class - schema.add and .delete to support multiple arguments (handle errors properly) -* - 2.3.1 Multiline comment /* ... */ -* - 2.4 Multiple Schemas -* - validation on "add_index", "add_table" etc -* - enum type in table definition with schema -- table columns parameter -- add coverage badge - add docstrings -- new docs - comment class +- add CI diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index 9d19188..c2c0b62 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -24,7 +24,7 @@ def __init__(self, subjects: List[Union[str, 'Column', 'Expression']], name: Optional[str] = None, unique: bool = False, - type: Literal['hash', 'btree'] = None, + type: Optional[Literal['hash', 'btree']] = None, pk: bool = False, note: Optional[Union['Note', str]] = None, comment: Optional[str] = None): diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 773d3e3..dbfb76f 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -136,7 +136,7 @@ class IndexBlueprint(Blueprint): subject_names: List[Union[str, ExpressionBlueprint]] name: Optional[str] = None unique: bool = False - type: Optional[str] = None + type: Optional[Literal['hash', 'btree']] = None pk: bool = False note: Optional[NoteBlueprint] = None comment: Optional[str] = None diff --git a/test.sh b/test.sh index 3b88011..54ae6b8 100755 --- a/test.sh +++ b/test.sh @@ -1,3 +1,6 @@ python3 -m doctest README.md &&\ + python3 -m doctest docs/classes.md &&\ + python3 -m doctest docs/upgrading.md &&\ + python3 -m doctest docs/creating_schema.md &&\ python3 -m unittest discover &&\ mypy pydbml --ignore-missing-imports From 78898104423753869a90ef9ce2217193a37ba600 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 22 May 2022 12:18:42 +0200 Subject: [PATCH 045/125] rename changelog, update license --- changelog.md => CHANGELOG.md | 0 LICENSE | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename changelog.md => CHANGELOG.md (100%) diff --git a/changelog.md b/CHANGELOG.md similarity index 100% rename from changelog.md rename to CHANGELOG.md diff --git a/LICENSE b/LICENSE index 1637db2..4d20996 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 +Copyright (c) 2022 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From 551ae3b6be7a424b0dd848c1fc20402694baee64 Mon Sep 17 00:00:00 2001 From: Daniel Minukhin Date: Tue, 24 May 2022 06:22:53 +0000 Subject: [PATCH 046/125] Update upgrading.md --- docs/upgrading.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/upgrading.md b/docs/upgrading.md index 97bb58a..31ce7bc 100644 --- a/docs/upgrading.md +++ b/docs/upgrading.md @@ -26,9 +26,9 @@ KeyError: 'orders' Previously the `Table` object had a `refs` attribute which holded a list of `TableReference` objects. `TableReference` represented a table relation and duplicated the `Reference` object of `PyDBMLParseResults` container. -**In 1.0.0 the `TableReference` class is removed, and there's not `Table.refs` attribute.** +**In 1.0.0 the `TableReference` class is removed, and there's no `Table.refs` attribute.** -Now all relations are represented by a single `Reference` object. You can still access `Table` references by calling the `get_refs` method. +Now each relation is represented by a single `Reference` object. You can still access `Table` references by calling the `get_refs` method. `Table.get_refs` will return a list of References for this table, but only if this table is on the left side of DBML relation. @@ -191,4 +191,4 @@ SQL expressions are allowed in column's `default` value definition and in index' >>> print(c.dbml) "upper_name" varchar [default: `upper(name)`] -``` \ No newline at end of file +``` From 02f779daeb72c18a13703373332617d17253621e Mon Sep 17 00:00:00 2001 From: Matthew V Date: Fri, 27 May 2022 16:03:51 -0500 Subject: [PATCH 047/125] Added classes and parser directories to setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c80d992..fc633d1 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', - packages=['pydbml', 'pydbml.definitions'], + packages=['pydbml', 'pydbml.classes', 'pydbml.definitions', 'pydbml.parser'], license='MIT', platforms='any', install_requires=[ From 204d648dd099ca728d7a5c1ca8077bebecbc354e Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 28 May 2022 10:42:22 +0200 Subject: [PATCH 048/125] Fix setup.py --- CHANGELOG.md | 4 ++++ setup.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cbc773..f4fc624 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.1 + +- Fixed setup.py, thanks to @vosskj03. + # 1.0.0 - New project architecture, full support for creating and editing DBML. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md) diff --git a/setup.py b/setup.py index fc633d1..079f363 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.0', + version='1.0.1', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 7fa8f8ec1299a39ef5ec3a96bec7a9ce72d9c8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jens=20K=C3=B6ster?= Date: Wed, 6 Jul 2022 16:04:02 +0200 Subject: [PATCH 049/125] set schema1 on inline refs --- pydbml/parser/blueprints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index dbfb76f..bf15f76 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -205,6 +205,7 @@ def get_reference_blueprints(self): result = [] for col in self.columns: for ref_bp in col.ref_blueprints or []: + ref_bp.schema1 = self.schema ref_bp.table1 = self.name ref_bp.col1 = col.name result.append(ref_bp) From ff0c582bae7bb3d3da0a5614cb18a61e4316eeea Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 10 Jul 2022 12:38:38 +0200 Subject: [PATCH 050/125] add test for inline ref schema --- test/test_parser.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_parser.py b/test/test_parser.py index 9b53057..c54ad96 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -47,6 +47,26 @@ def test_refs(self) -> None: self.assertEqual(r[4].col2[0].table.name, 'merchants') self.assertEqual(r[4].col2[0].name, 'id') + def test_inline_refs_schema(self) -> None: + # Thanks @jens-koster for this example + source = ''' +Table core.pk_tbl { + pk_col varchar [pk] +} +Table core.fk_tbl { + fk_col varchar [ref: > core.pk_tbl.pk_col] +} +''' + p = PyDBMLParser(source) + p.parse() + r = p.refs + pk_tbl = p.tables[0] + fk_tble = p.tables[1] + ref = p.refs[0] + self.assertEqual(ref.table1, fk_tble.name) + self.assertEqual(ref.table2, pk_tbl.name) + self.assertEqual(ref.schema1, fk_tble.schema) + self.assertEqual(ref.schema2, pk_tbl.schema) class TestRefs(TestCase): def test_reference_aliases(self): From 2ecf2b9b48bc88eefcd83c3278ffec666b061861 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 10 Jul 2022 12:38:47 +0200 Subject: [PATCH 051/125] update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7b02e9b..1547f37 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ pydbml.egg-info .mypy_cache .coverage .eggs +.idea From 08a20b5cb8fc0be8acc5f0e9ab3baff12515b076 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 10 Jul 2022 12:44:46 +0200 Subject: [PATCH 052/125] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4fc624..30ac75a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.2 + +- Fix: inline ref schema bug, thanks to @jens-koster + # 1.0.1 - Fixed setup.py, thanks to @vosskj03. From 4eb695123a51597b2a6f19068c7abe67d160d361 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 10 Jul 2022 16:28:59 +0200 Subject: [PATCH 053/125] attempt to fix #16 --- pydbml/classes/note.py | 45 +++++------- pydbml/parser/blueprints.py | 6 +- pydbml/parser/parser.py | 27 ++++++-- pydbml/tools.py | 42 ++++++++++++ test/test_classes/test_note.py | 39 +++++++++++ test/test_data/notes.dbml | 60 ++++++++++++++++ test/test_docs.py | 2 +- test/test_parser.py | 122 +++++++++++++++++++++++++++++++++ test/utils.py | 12 ++++ 9 files changed, 321 insertions(+), 34 deletions(-) create mode 100644 test/test_data/notes.dbml create mode 100644 test/utils.py diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index 354638e..635e61d 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -1,11 +1,20 @@ from typing import Any from .base import SQLObject +from pydbml.tools import reformat_note_text +from pydbml.tools import remove_indentation +from pydbml.tools import indent class Note(SQLObject): - def __init__(self, text: Any): - self.text = str(text) if text else '' + def __init__(self, text: Any, reformat: bool = True): + if isinstance(text, Note): + raw_text = text.text + self.reformat = text.reformat + else: + raw_text = str(text) if text else '' + self.reformat = reformat + self.text = remove_indentation(raw_text.strip('\n')) def __str__(self): ''' @@ -35,30 +44,14 @@ def sql(self): @property def dbml(self): - lines = [] - line = '' - for word in self.text.split(' '): - if len(line) > 80: - lines.append(line) - line = '' - if '\n' in word: - sublines = word.split('\n') - for sl in sublines[:-1]: - line += sl - lines.append(line) - line = '' - line = sublines[-1] + ' ' - else: - line += f'{word} ' - if line: - lines.append(line) - result = 'Note {\n ' - - if len(lines) > 1: - lines_str = '\n '.join(lines)[:-1] + '\n' - result += f"'''\n {lines_str} '''" + if self.reformat: + note_text = reformat_note_text(self.text) else: - result += f"'{lines[0][:-1]}'" + if '\n' in self.text: + note_text = f"'''\n{self.text}'''" + else: + note_text = f"'{self.text}'" - result += '\n}' + note_text = indent(note_text) + result = f'Note {{\n{note_text}\n}}' return result diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index bf15f76..b69fbc5 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -30,7 +30,11 @@ class NoteBlueprint(Blueprint): text: str def build(self) -> 'Note': - return Note(self.text) + if self.parser: + reformat = self.parser.options['reformat_notes'] + return Note(self.text, reformat=reformat) + else: + return Note(self.text) @dataclass diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 705b9dc..4989bb2 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -46,7 +46,8 @@ class PyDBML: ''' def __new__(cls, - source_: Optional[Union[str, Path, TextIOWrapper]] = None): + source_: Optional[Union[str, Path, TextIOWrapper]] = None, + reformat_notes: bool = True): if source_ is not None: if isinstance(source_, str): source = source_ @@ -59,7 +60,7 @@ def __new__(cls, raise TypeError('Source must be str, path or file stream') source = remove_bom(source) - return cls.parse(source) + return cls.parse(source, reformat_notes) else: return super().__new__(cls) @@ -72,9 +73,9 @@ def __repr__(self): return "" @staticmethod - def parse(text: str) -> Database: + def parse(text: str, reformat_notes: bool = True) -> Database: text = remove_bom(text) - parser = PyDBMLParser(text) + parser = PyDBMLParser(text, reformat_notes) return parser.parse() @staticmethod @@ -90,13 +91,16 @@ def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: class PyDBMLParser: - def __init__(self, source: str): + def __init__(self, source: str, reformat_notes: bool = True): self.database = None + self.options = { + 'reformat_notes': reformat_notes + } self.ref_blueprints: List[ReferenceBlueprint] = [] self.table_groups: List[TableGroupBlueprint] = [] self.source = source - self.tables: List[TableGroupBlueprint] = [] + self.tables: List[TableBlueprint] = [] self.refs: List[ReferenceBlueprint] = [] self.enums: List[EnumBlueprint] = [] self.project: Optional[ProjectBlueprint] = None @@ -149,16 +153,27 @@ def parse_blueprint(self, s, loc, tok): ref_bp.parser = self for col_bp in col_bps: col_bp.parser = self + if col_bp.note: + col_bp.note.parser = self for index_bp in index_bps: index_bp.parser = self + if index_bp.note: + index_bp.note.parser = self + if blueprint.note: + blueprint.note.parser = self elif isinstance(blueprint, ReferenceBlueprint): self.refs.append(blueprint) elif isinstance(blueprint, EnumBlueprint): self.enums.append(blueprint) + for enum_item in blueprint.items: + if enum_item.note: + enum_item.note.parser = self elif isinstance(blueprint, TableGroupBlueprint): self.table_groups.append(blueprint) elif isinstance(blueprint, ProjectBlueprint): self.project = blueprint + if blueprint.note: + blueprint.note.parser = self else: raise RuntimeError(f'type unknown: {blueprint}') blueprint.parser = self diff --git a/pydbml/tools.py b/pydbml/tools.py index 63f2b81..5cc629b 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -1,3 +1,4 @@ +import re from typing import TYPE_CHECKING if TYPE_CHECKING: # pragma: no cover @@ -33,3 +34,44 @@ def remove_bom(source: str) -> str: if source and source[0] == '\ufeff': source = source[1:] return source + +def remove_indentation(source: str) -> str: + pattern = re.compile(r'(?<=\n)\s*') + spaces = pattern.findall(f'\n{source}') + if not spaces: + return source + indent = min(map(len, spaces)) + lines = source.split('\n') + lines = [l[indent:] for l in lines] + return '\n'.join(lines) + +def reformat_note_text(source: str, spaces=4) -> str: + """ + Add line breaks at approx 80-90 characters, indent text. + If source is less than 90 characters and has no line breaks, leave it unchanged. + """ + if '\n' not in source and len(source) <= 90: + return f"'{source}'" + + # text = source.strip('\n') + lines = [] + line = '' + for word in source.split(' '): + if len(line) > 80: + lines.append(line) + line = '' + if '\n' in word: + sublines = word.split('\n') + for sl in sublines[:-1]: + line += sl + lines.append(line) + line = '' + line = sublines[-1] + ' ' + else: + line += f'{word} ' + if line: + lines.append(line) + result = '\n'.join(lines).rstrip() + result = f"'''\n{result}\n'''" + # result = indent((result)) + return f'{result}' diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py index 8e3c6b1..c08d731 100644 --- a/test/test_classes/test_note.py +++ b/test/test_classes/test_note.py @@ -60,3 +60,42 @@ def test_sql(self) -> None: -- will -- be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.""" self.assertEqual(note3.sql, expected) + +class TestNoteReformat(TestCase): + def test_auto_reformat_long_line(self): + n = Note('Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.') + expected = \ +"""Note { + ''' + Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem + facere eos, quod error consectetur. + ''' +}""" + self.assertEqual(n.dbml, expected) + + def test_long_line_reformat_off(self): + n = Note( + 'Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.', + reformat=False + ) + expected = "Note {\n 'Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.'\n}" + self.assertEqual(n.dbml, expected) + + def test_short_line(self): + n = Note('Short line of note text.', reformat=False) + expected = "Note {\n 'Short line of note text.'\n}" + self.assertEqual(n.dbml, expected) + n.reformat = True + self.assertEqual(n.dbml, expected) + + def test_reindent(self): + n = Note(' Line1\n Line2\n Line3\n Line4') + expected = """Note { + ''' + Line1 + Line2 + Line3 + Line4 + ''' +}""" + self.assertEqual(n.dbml, expected) diff --git a/test/test_data/notes.dbml b/test/test_data/notes.dbml new file mode 100644 index 0000000..d944c26 --- /dev/null +++ b/test/test_data/notes.dbml @@ -0,0 +1,60 @@ +Project "my project" { + author: 'me' + reason: 'testing' + Note: ''' + # DBML - Database Markup Language + DBML (database markup language) is a simple, readable DSL language designed to define database structures. + + ## Benefits + + * It is simple, flexible and highly human-readable + * It is database agnostic, focusing on the essential database structure definition without worrying about the detailed syntaxes of each database + * Comes with a free, simple database visualiser at [dbdiagram.io](http://dbdiagram.io) + ''' +} + +Enum "level" { + "junior" [note: 'enum item note'] + "middle" + "senior" +} + +Table "orders" [headercolor: #fff] { + "id" int [pk, increment] + "user_id" int [unique, not null] + "status" orders_status [note: "test note"] + "created_at" varchar + Note: 'Simple one line note' +} + +Table "order_items" { + "order_id" int + "product_id" int + "quantity" int [default: 1] + Note: 'Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.' + indexes { + order_id [unique, Note: 'Index note'] + `ROUND(quantity)` + } +} + +Table "products" { + "id" int [pk] + "name" varchar + "merchant_id" int [not null] + "price" int + "status" "product status" + "created_at" datetime [default: `now()`] + Note { + '''Indented note which is actually a Markdown formated string: + + - List item 1 + - Another list item + + ```[python + def test(): + print('Hello world!') + return 1 + ```''' + } +} diff --git a/test/test_docs.py b/test/test_docs.py index f186545..8a8b84f 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -214,7 +214,7 @@ def test_project_notes(self) -> None: project = results.project self.assertEqual(project.name, 'DBML') - self.assertTrue(project.note.text.startswith('\n # DBML - Database Markup Language\n DBML')) + self.assertTrue(project.note.text.startswith('# DBML - Database Markup Language\nDBML')) def test_column_notes(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'column_notes.dbml') diff --git a/test/test_parser.py b/test/test_parser.py index c54ad96..8b1ab99 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -68,6 +68,7 @@ def test_inline_refs_schema(self) -> None: self.assertEqual(ref.schema1, fk_tble.schema) self.assertEqual(ref.schema2, pk_tbl.schema) + class TestRefs(TestCase): def test_reference_aliases(self): results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') @@ -124,3 +125,124 @@ def test_edge(self) -> None: p.locate_table('myschema', 'test') with self.assertRaises(RuntimeError): p.parse_blueprint(1, 1, [1]) + + +class TestReformatNotes(TestCase): + def test_reformat_by_default(self): + p = PyDBML(TEST_DATA_PATH / 'notes.dbml') + + for table in p.tables: + if table.note: + self.assertTrue(table.note.reformat) + for col in table.columns: + if col.note: + self.assertTrue(col.note.reformat) + for index in table.indexes: + if index.note: + self.assertTrue(index.note.reformat) + for enum in p.enums: + for item in enum.items: + if item.note: + self.assertTrue(item.note.reformat) + self.assertTrue(p.project.note.reformat) + + def test_reformat_off(self): + p = PyDBML(TEST_DATA_PATH / 'notes.dbml', reformat_notes=False) + + for table in p.tables: + if table.note: + self.assertFalse(table.note.reformat) + for col in table.columns: + if col.note: + self.assertFalse(col.note.reformat) + for index in table.indexes: + if index.note: + self.assertFalse(index.note.reformat) + for enum in p.enums: + for item in enum.items: + if item.note: + self.assertFalse(item.note.reformat) + self.assertFalse(p.project.note.reformat) + + def test_note_is_idempotent(self): + dbml_source = """ +Table test { + id integer + Note { + ''' + Indented note which is actually a Markdown formated string: + + - List item 1 + - Another list item + + ```python + def test(): + print('Hello world!') + return 1 + ``` + ''' + } +} +""" + source_text = \ +"""Indented note which is actually a Markdown formated string: + +- List item 1 +- Another list item + +```python +def test(): + print('Hello world!') + return 1 +```""" + p = PyDBML(dbml_source) + note = p.tables[0].note + self.assertEqual(source_text, note.text) + + p_mod = p + for _ in range(10): + p_mod = PyDBML(p_mod.dbml) + note2 = p_mod.tables[0].note + self.assertEqual(source_text, note2.text) + + + def test_unformatted_note_is_idempotent(self): + dbml_source = """ +Table test { + id integer + Note { + ''' + Indented note which is actually a Markdown formated string: + + - List item 1 + - Another list item + + ```python + def test(): + print('Hello world!') + return 1 + ``` + ''' + } +} +""" + source_text = \ +"""Indented note which is actually a Markdown formated string: + +- List item 1 +- Another list item + +```python +def test(): + print('Hello world!') + return 1 +```""" + p = PyDBML(dbml_source, reformat_notes=False) + note = p.tables[0].note + self.assertEqual(source_text, note.text) + + p_mod = p + for _ in range(10): + p_mod = PyDBML(p_mod.dbml, reformat_notes=False) + note2 = p_mod.tables[0].note + self.assertEqual(source_text, note2.text) diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000..5d1163d --- /dev/null +++ b/test/utils.py @@ -0,0 +1,12 @@ +from unittest.mock import Mock +from typing import Optional + + +DEFAULT_OPTIONS = { + 'reformat_notes': True, +} + +def mock_parser(options: Optional[dict] = None): + if options is None: + options = dict(DEFAULT_OPTIONS) + return Mock(options=options) From 35984b0fda04fd6c1fe66a8490af0a23de0fdc12 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 10 Jul 2022 16:39:29 +0200 Subject: [PATCH 054/125] fix #15, support note objects in project --- pydbml/definitions/project.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydbml/definitions/project.py b/pydbml/definitions/project.py index 16bd2e0..12fa2fd 100644 --- a/pydbml/definitions/project.py +++ b/pydbml/definitions/project.py @@ -4,6 +4,7 @@ from .common import _c from .common import n from .common import note +from .common import note_object from .generic import name from .generic import string_literal from pydbml.parser.blueprints import NoteBlueprint @@ -13,7 +14,7 @@ project_field = pp.Group(name + _ + pp.Suppress(':') + _ - string_literal) -project_element = _ + (note | project_field) + _ +project_element = _ + (note | note_object | project_field) + _ project_body = project_element[...] From a09ed68e739e974973c7f113cf4f0f62dcfe8c1f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 10 Jul 2022 17:48:38 +0200 Subject: [PATCH 055/125] fix note formatting --- pydbml/classes/note.py | 5 ++--- pydbml/parser/blueprints.py | 8 ++++++-- pydbml/tools.py | 32 ++++++++++++++++++++++++-------- test/test_parser.py | 8 ++++---- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index 635e61d..b45e220 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -9,12 +9,11 @@ class Note(SQLObject): def __init__(self, text: Any, reformat: bool = True): if isinstance(text, Note): - raw_text = text.text + self.text = text.text self.reformat = text.reformat else: - raw_text = str(text) if text else '' + self.text = str(text) if text else '' self.reformat = reformat - self.text = remove_indentation(raw_text.strip('\n')) def __str__(self): ''' diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index b69fbc5..6943ad6 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -19,6 +19,8 @@ from pydbml.classes import TableGroup from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError +from pydbml.tools import remove_indentation +from pydbml.tools import strip_empty_lines class Blueprint: @@ -30,11 +32,13 @@ class NoteBlueprint(Blueprint): text: str def build(self) -> 'Note': + text = strip_empty_lines(self.text) + text = remove_indentation(text) if self.parser: reformat = self.parser.options['reformat_notes'] - return Note(self.text, reformat=reformat) + return Note(text, reformat=reformat) else: - return Note(self.text) + return Note(text) @dataclass diff --git a/pydbml/tools.py b/pydbml/tools.py index 5cc629b..6c7f6a8 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -35,13 +35,29 @@ def remove_bom(source: str) -> str: source = source[1:] return source + +def strip_empty_lines(source: str) -> str: + """Remove empty lines or lines with just spaces from beginning and end.""" + first_line = 0 + lines = source.split('\n') + last_line = len(lines) - 1 + while not lines[first_line] or lines[first_line].isspace(): + first_line += 1 + while not lines[last_line] or lines[last_line].isspace(): + last_line -= 1 + return '\n'.join(lines[first_line: last_line + 1]) + + def remove_indentation(source: str) -> str: - pattern = re.compile(r'(?<=\n)\s*') - spaces = pattern.findall(f'\n{source}') - if not spaces: - return source - indent = min(map(len, spaces)) + pattern = re.compile(r'^\s*') + lines = source.split('\n') + spaces = [] + for line in lines: + if line and not line.isspace(): + spaces.append(len(pattern.search(line).group())) + + indent = min(spaces) lines = [l[indent:] for l in lines] return '\n'.join(lines) @@ -53,10 +69,10 @@ def reformat_note_text(source: str, spaces=4) -> str: if '\n' not in source and len(source) <= 90: return f"'{source}'" - # text = source.strip('\n') lines = [] line = '' - for word in source.split(' '): + text = remove_indentation(source.strip('\n')) + for word in text.split(' '): if len(line) > 80: lines.append(line) line = '' @@ -73,5 +89,5 @@ def reformat_note_text(source: str, spaces=4) -> str: lines.append(line) result = '\n'.join(lines).rstrip() result = f"'''\n{result}\n'''" - # result = indent((result)) + return f'{result}' diff --git a/test/test_parser.py b/test/test_parser.py index 8b1ab99..b355a6b 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -170,7 +170,7 @@ def test_note_is_idempotent(self): id integer Note { ''' - Indented note which is actually a Markdown formated string: + Indented note which is actually a Markdown formatted string: - List item 1 - Another list item @@ -185,7 +185,7 @@ def test(): } """ source_text = \ -"""Indented note which is actually a Markdown formated string: +"""Indented note which is actually a Markdown formatted string: - List item 1 - Another list item @@ -212,7 +212,7 @@ def test_unformatted_note_is_idempotent(self): id integer Note { ''' - Indented note which is actually a Markdown formated string: + Indented note which is actually a Markdown formatted string: - List item 1 - Another list item @@ -227,7 +227,7 @@ def test(): } """ source_text = \ -"""Indented note which is actually a Markdown formated string: +"""Indented note which is actually a Markdown formatted string: - List item 1 - Another list item From 92ad6945d20f5ba8002c91d225541e355db72f1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jens=20K=C3=B6ster?= Date: Tue, 12 Jul 2022 10:07:16 +0200 Subject: [PATCH 056/125] added the reformat_notes parameter to parse_file --- pydbml/parser/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 4989bb2..05c704d 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -79,14 +79,14 @@ def parse(text: str, reformat_notes: bool = True) -> Database: return parser.parse() @staticmethod - def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: + def parse_file(file: Union[str, Path, TextIOWrapper], reformat_notes: bool = True) -> Database: if isinstance(file, TextIOWrapper): source = file.read() else: with open(file, encoding='utf8') as f: source = f.read() source = remove_bom(source) - parser = PyDBMLParser(source) + parser = PyDBMLParser(source, reformat_notes=reformat_notes) return parser.parse() From 39bb79a6482d6d038515bd4fc73683960ccf58a1 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 24 Jul 2022 10:30:44 +0200 Subject: [PATCH 057/125] update todo and docstring --- TODO.md | 5 ++--- pydbml/tools.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/TODO.md b/TODO.md index 825ba3a..ccfe00c 100644 --- a/TODO.md +++ b/TODO.md @@ -1,4 +1,3 @@ - schema.add and .delete to support multiple arguments (handle errors properly) -- add docstrings -- comment class -- add CI +- support escape sequences in multiline strings +- support 2.4.2 (many to many relationships) diff --git a/pydbml/tools.py b/pydbml/tools.py index 6c7f6a8..3e33ced 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -63,7 +63,7 @@ def remove_indentation(source: str) -> str: def reformat_note_text(source: str, spaces=4) -> str: """ - Add line breaks at approx 80-90 characters, indent text. + Add line breaks at approx 80-90 characters. If source is less than 90 characters and has no line breaks, leave it unchanged. """ if '\n' not in source and len(source) <= 90: From dc171b70a673873c4f4c6ebbb06d0b7753d9e273 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 24 Jul 2022 11:14:02 +0200 Subject: [PATCH 058/125] add tests for note bp preformat --- pydbml/parser/blueprints.py | 9 ++++++-- pydbml/tools.py | 1 + test/test_blueprints/test_note.py | 34 +++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 6943ad6..8ddd79f 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -31,9 +31,14 @@ class Blueprint: class NoteBlueprint(Blueprint): text: str + def _preformat_text(self) -> str: + '''Preformat the note text for idempotence''' + result = strip_empty_lines(self.text) + result = remove_indentation(result) + return result + def build(self) -> 'Note': - text = strip_empty_lines(self.text) - text = remove_indentation(text) + text = self._preformat_text() if self.parser: reformat = self.parser.options['reformat_notes'] return Note(text, reformat=reformat) diff --git a/pydbml/tools.py b/pydbml/tools.py index 3e33ced..b8bcd60 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -61,6 +61,7 @@ def remove_indentation(source: str) -> str: lines = [l[indent:] for l in lines] return '\n'.join(lines) + def reformat_note_text(source: str, spaces=4) -> str: """ Add line breaks at approx 80-90 characters. diff --git a/test/test_blueprints/test_note.py b/test/test_blueprints/test_note.py index 4ae20d0..85028fa 100644 --- a/test/test_blueprints/test_note.py +++ b/test/test_blueprints/test_note.py @@ -10,3 +10,37 @@ def test_build(self) -> None: result = bp.build() self.assertIsInstance(result, Note) self.assertEqual(result.text, bp.text) + + def test_preformat_not_needed(self): + oneline = 'One line of note text' + multiline = 'Multiline\nnote\n\ntext' + long_line = 'Lorem ipsum dolor sit amet consectetur adipisicing elit. Aspernatur quidem adipisci, impedit, ut illum dolorum consequatur odio voluptate numquam ea itaque excepturi, a libero placeat corrupti. Amet beatae suscipit necessitatibus. Ea expedita explicabo iste quae rem aliquam minus cumque eveniet enim delectus, alias aut impedit quaerat quia ex, aliquid sint amet iusto rerum! Sunt deserunt ea saepe corrupti officiis. Assumenda.' + + bp = NoteBlueprint(text=oneline) + self.assertEqual(bp._preformat_text(), oneline) + bp = NoteBlueprint(text=multiline) + self.assertEqual(bp._preformat_text(), multiline) + bp = NoteBlueprint(text=long_line) + self.assertEqual(bp._preformat_text(), long_line) + + def test_preformat_needed(self): + uniform_indentation = ' line1\n line2\n line3' + varied_indentation = ' line1\n line2\n\n line3' + empty_lines = '\n\n\n\n\n\n\nline1\nline2\nline3\n\n\n\n\n\n\n' + empty_indented_lines = '\n \n\n \n\n line1\n line2\n line3\n\n\n\n \n\n\n' + + exptected = 'line1\nline2\nline3' + bp = NoteBlueprint(text=uniform_indentation) + self.assertEqual(bp._preformat_text(), exptected) + + exptected = 'line1\n line2\n\n line3' + bp = NoteBlueprint(text=varied_indentation) + self.assertEqual(bp._preformat_text(), exptected) + + exptected = 'line1\nline2\nline3' + bp = NoteBlueprint(text=empty_lines) + self.assertEqual(bp._preformat_text(), exptected) + + exptected = 'line1\nline2\nline3' + bp = NoteBlueprint(text=empty_indented_lines) + self.assertEqual(bp._preformat_text(), exptected) From 65fe1a80998da2344247bf70b1f4b4215b5f4ed0 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 24 Jul 2022 11:39:57 +0200 Subject: [PATCH 059/125] remove note reformatting --- pydbml/classes/note.py | 18 +++----- pydbml/parser/blueprints.py | 6 +-- pydbml/parser/parser.py | 19 +++----- pydbml/tools.py | 6 ++- test/test_classes/test_note.py | 60 ++----------------------- test/test_parser.py | 80 +--------------------------------- 6 files changed, 25 insertions(+), 164 deletions(-) diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index b45e220..0460ca9 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -1,19 +1,18 @@ from typing import Any from .base import SQLObject -from pydbml.tools import reformat_note_text -from pydbml.tools import remove_indentation +# from pydbml.tools import reformat_note_text +# from pydbml.tools import remove_indentation from pydbml.tools import indent class Note(SQLObject): - def __init__(self, text: Any, reformat: bool = True): + def __init__(self, text: Any): + self.text: str if isinstance(text, Note): self.text = text.text - self.reformat = text.reformat else: self.text = str(text) if text else '' - self.reformat = reformat def __str__(self): ''' @@ -43,13 +42,10 @@ def sql(self): @property def dbml(self): - if self.reformat: - note_text = reformat_note_text(self.text) + if '\n' in self.text: + note_text = f"'''\n{self.text}\n'''" else: - if '\n' in self.text: - note_text = f"'''\n{self.text}'''" - else: - note_text = f"'{self.text}'" + note_text = f"'{self.text}'" note_text = indent(note_text) result = f'Note {{\n{note_text}\n}}' diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 8ddd79f..aff27ae 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -39,11 +39,7 @@ def _preformat_text(self) -> str: def build(self) -> 'Note': text = self._preformat_text() - if self.parser: - reformat = self.parser.options['reformat_notes'] - return Note(text, reformat=reformat) - else: - return Note(text) + return Note(text) @dataclass diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 05c704d..9fd0531 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -45,9 +45,7 @@ class PyDBML: >>> p = PyDBML(Path('test_schema.dbml')) ''' - def __new__(cls, - source_: Optional[Union[str, Path, TextIOWrapper]] = None, - reformat_notes: bool = True): + def __new__(cls, source_: Optional[Union[str, Path, TextIOWrapper]] = None): if source_ is not None: if isinstance(source_, str): source = source_ @@ -60,7 +58,7 @@ def __new__(cls, raise TypeError('Source must be str, path or file stream') source = remove_bom(source) - return cls.parse(source, reformat_notes) + return cls.parse(source) else: return super().__new__(cls) @@ -73,30 +71,27 @@ def __repr__(self): return "" @staticmethod - def parse(text: str, reformat_notes: bool = True) -> Database: + def parse(text: str) -> Database: text = remove_bom(text) - parser = PyDBMLParser(text, reformat_notes) + parser = PyDBMLParser(text) return parser.parse() @staticmethod - def parse_file(file: Union[str, Path, TextIOWrapper], reformat_notes: bool = True) -> Database: + def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: if isinstance(file, TextIOWrapper): source = file.read() else: with open(file, encoding='utf8') as f: source = f.read() source = remove_bom(source) - parser = PyDBMLParser(source, reformat_notes=reformat_notes) + parser = PyDBMLParser(source) return parser.parse() class PyDBMLParser: - def __init__(self, source: str, reformat_notes: bool = True): + def __init__(self, source: str): self.database = None - self.options = { - 'reformat_notes': reformat_notes - } self.ref_blueprints: List[ReferenceBlueprint] = [] self.table_groups: List[TableGroupBlueprint] = [] self.source = source diff --git a/pydbml/tools.py b/pydbml/tools.py index b8bcd60..bb9f852 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -55,7 +55,9 @@ def remove_indentation(source: str) -> str: spaces = [] for line in lines: if line and not line.isspace(): - spaces.append(len(pattern.search(line).group())) + indent_match = pattern.search(line) + if indent_match is not None: # this is just for you mypy + spaces.append(len(indent_match[0])) indent = min(spaces) lines = [l[indent:] for l in lines] @@ -64,6 +66,8 @@ def remove_indentation(source: str) -> str: def reformat_note_text(source: str, spaces=4) -> str: """ + Currently not used. + Add line breaks at approx 80-90 characters. If source is less than 90 characters and has no line breaks, leave it unchanged. """ diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py index c08d731..4a6b8e1 100644 --- a/test/test_classes/test_note.py +++ b/test/test_classes/test_note.py @@ -24,27 +24,14 @@ def test_oneline(self): }''' self.assertEqual(note.dbml, expected) - def test_multiline(self): - note = Note('The number of spaces you use to indent a block string will be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') - expected = \ -"""Note { - ''' - The number of spaces you use to indent a block string will be the minimum number - of leading spaces among all lines. The parser will automatically remove the number - of indentation spaces in the final output. - ''' -}""" - self.assertEqual(note.dbml, expected) - def test_forced_multiline(self): - note = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + note = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.') expected = \ """Note { ''' The number of spaces you use to indent a block string will - be the minimum number of leading spaces among all lines. The parser will automatically - remove the number of indentation spaces in the final output. + be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output. ''' }""" self.assertEqual(note.dbml, expected) @@ -54,48 +41,9 @@ def test_sql(self) -> None: self.assertEqual(note1.sql, '') note2 = Note('One line of note text') self.assertEqual(note2.sql, '-- One line of note text') - note3 = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.') + note3 = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.') expected = \ """-- The number of spaces you use to indent a block string -- will --- be the minimum number of leading spaces among all lines. The parser will automatically remove the number of indentation spaces in the final output.""" +-- be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.""" self.assertEqual(note3.sql, expected) - -class TestNoteReformat(TestCase): - def test_auto_reformat_long_line(self): - n = Note('Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.') - expected = \ -"""Note { - ''' - Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem - facere eos, quod error consectetur. - ''' -}""" - self.assertEqual(n.dbml, expected) - - def test_long_line_reformat_off(self): - n = Note( - 'Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.', - reformat=False - ) - expected = "Note {\n 'Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.'\n}" - self.assertEqual(n.dbml, expected) - - def test_short_line(self): - n = Note('Short line of note text.', reformat=False) - expected = "Note {\n 'Short line of note text.'\n}" - self.assertEqual(n.dbml, expected) - n.reformat = True - self.assertEqual(n.dbml, expected) - - def test_reindent(self): - n = Note(' Line1\n Line2\n Line3\n Line4') - expected = """Note { - ''' - Line1 - Line2 - Line3 - Line4 - ''' -}""" - self.assertEqual(n.dbml, expected) diff --git a/test/test_parser.py b/test/test_parser.py index b355a6b..a20bf85 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -127,43 +127,7 @@ def test_edge(self) -> None: p.parse_blueprint(1, 1, [1]) -class TestReformatNotes(TestCase): - def test_reformat_by_default(self): - p = PyDBML(TEST_DATA_PATH / 'notes.dbml') - - for table in p.tables: - if table.note: - self.assertTrue(table.note.reformat) - for col in table.columns: - if col.note: - self.assertTrue(col.note.reformat) - for index in table.indexes: - if index.note: - self.assertTrue(index.note.reformat) - for enum in p.enums: - for item in enum.items: - if item.note: - self.assertTrue(item.note.reformat) - self.assertTrue(p.project.note.reformat) - - def test_reformat_off(self): - p = PyDBML(TEST_DATA_PATH / 'notes.dbml', reformat_notes=False) - - for table in p.tables: - if table.note: - self.assertFalse(table.note.reformat) - for col in table.columns: - if col.note: - self.assertFalse(col.note.reformat) - for index in table.indexes: - if index.note: - self.assertFalse(index.note.reformat) - for enum in p.enums: - for item in enum.items: - if item.note: - self.assertFalse(item.note.reformat) - self.assertFalse(p.project.note.reformat) - +class TestNotesIdempotent(TestCase): def test_note_is_idempotent(self): dbml_source = """ Table test { @@ -204,45 +168,3 @@ def test(): p_mod = PyDBML(p_mod.dbml) note2 = p_mod.tables[0].note self.assertEqual(source_text, note2.text) - - - def test_unformatted_note_is_idempotent(self): - dbml_source = """ -Table test { - id integer - Note { - ''' - Indented note which is actually a Markdown formatted string: - - - List item 1 - - Another list item - - ```python - def test(): - print('Hello world!') - return 1 - ``` - ''' - } -} -""" - source_text = \ -"""Indented note which is actually a Markdown formatted string: - -- List item 1 -- Another list item - -```python -def test(): - print('Hello world!') - return 1 -```""" - p = PyDBML(dbml_source, reformat_notes=False) - note = p.tables[0].note - self.assertEqual(source_text, note.text) - - p_mod = p - for _ in range(10): - p_mod = PyDBML(p_mod.dbml, reformat_notes=False) - note2 = p_mod.tables[0].note - self.assertEqual(source_text, note2.text) From 3907068092259d186af6c8664dad6747236b780d Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 24 Jul 2022 11:41:59 +0200 Subject: [PATCH 060/125] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30ac75a..e759db6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # 1.0.2 - Fix: inline ref schema bug, thanks to @jens-koster +- Fix: notes are not idempotent, thanks @jens-koster for reporting +- Fix: note objects are now supported in project definition, thanks @jens-koster for reporting # 1.0.1 From e032e0b14160671de0951efc33535e205807b70b Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 24 Jul 2022 12:09:11 +0200 Subject: [PATCH 061/125] fix schema parsing in TableGroupBlueprint --- CHANGELOG.md | 5 +++-- pydbml/parser/blueprints.py | 2 +- test/test_blueprints/test_table_group.py | 24 ++++++++++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e759db6..7233fce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,9 @@ # 1.0.2 - Fix: inline ref schema bug, thanks to @jens-koster -- Fix: notes are not idempotent, thanks @jens-koster for reporting -- Fix: note objects are now supported in project definition, thanks @jens-koster for reporting +- Fix: (#16) notes were not idempotent, thanks @jens-koster for reporting +- Fix: (#15) note objects were not supported in project definition, thanks @jens-koster for reporting +- Fix: (#20) schema didn't work in table group definition, thanks @mjfii for reporting # 1.0.1 diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index aff27ae..deee85c 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -279,7 +279,7 @@ def build(self) -> 'TableGroup': items = [] for table_name in self.items: components = table_name.split('.') - schema, table = components if len(components) == 2 else 'public', components[0] + schema, table = components if len(components) == 2 else ('public', components[0]) items.append(self.parser.locate_table(schema, table)) return TableGroup( name=self.name, diff --git a/test/test_blueprints/test_table_group.py b/test/test_blueprints/test_table_group.py index 69331dc..48a262f 100644 --- a/test/test_blueprints/test_table_group.py +++ b/test/test_blueprints/test_table_group.py @@ -27,3 +27,27 @@ def test_build(self) -> None: self.assertEqual(parserMock.locate_table.call_count, 2) for i in result.items: self.assertIsInstance(i, Table) + + def test_build_with_schema(self) -> None: + bp = TableGroupBlueprint( + name='TestTableGroup', + items=['myschema.table1', 'myschema.table2'], + comment='Comment text' + ) + with self.assertRaises(RuntimeError): + bp.build() + + parserMock = Mock() + parserMock.locate_table.side_effect = [ + Table(name='table1', schema='myschema'), + Table(name='table2', schema='myschema') + ] + bp.parser = parserMock + result = bp.build() + self.assertIsInstance(result, TableGroup) + locate_table_calls = parserMock.locate_table.call_args_list + self.assertEqual(len(locate_table_calls), 2) + self.assertEqual(locate_table_calls[0].args, ('myschema', 'table1')) + self.assertEqual(locate_table_calls[1].args, ('myschema', 'table2')) + for i in result.items: + self.assertIsInstance(i, Table) From a929f84651a74a973292fe97e1249d110bc33a3b Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 24 Jul 2022 16:14:04 +0200 Subject: [PATCH 062/125] note parent backref, support escape newline in notes, escape quotes in sql and dbml, note.sql depending on parent --- CHANGELOG.md | 3 ++ TODO.md | 1 - pydbml/classes/column.py | 9 +++++ pydbml/classes/enum.py | 9 +++++ pydbml/classes/index.py | 9 +++++ pydbml/classes/note.py | 49 ++++++++++++++++++++++---- pydbml/classes/project.py | 9 +++++ pydbml/classes/table.py | 13 +++++-- test/test_classes/test_column.py | 8 ++++- test/test_classes/test_enum.py | 7 ++++ test/test_classes/test_index.py | 9 +++++ test/test_classes/test_note.py | 57 ++++++++++++++++++++++++++++++- test/test_classes/test_project.py | 7 ++++ test/test_classes/test_table.py | 6 ++++ 14 files changed, 184 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7233fce..9cdc53b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,12 @@ # 1.0.2 +- New: "backslash newline" is supported in note text (line continuation) +- New: notes have reference to their parent. Note.sql now depends on type of parent (for tables and columns it's COMMENT ON clause) - Fix: inline ref schema bug, thanks to @jens-koster - Fix: (#16) notes were not idempotent, thanks @jens-koster for reporting - Fix: (#15) note objects were not supported in project definition, thanks @jens-koster for reporting - Fix: (#20) schema didn't work in table group definition, thanks @mjfii for reporting +- Fix: quotes in note text broke sql and dbml # 1.0.1 diff --git a/TODO.md b/TODO.md index ccfe00c..5926ed3 100644 --- a/TODO.md +++ b/TODO.md @@ -1,3 +1,2 @@ - schema.add and .delete to support multiple arguments (handle errors properly) -- support escape sequences in multiline strings - support 2.4.2 (many to many relationships) diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 97a22df..954a2b9 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -45,6 +45,15 @@ def __init__(self, self.default = default self.table: Optional['Table'] = None + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + def get_refs(self) -> List['Reference']: ''' get all references related to this column (where this col is col1 in) diff --git a/pydbml/classes/enum.py b/pydbml/classes/enum.py index 63b0bfe..ec4ad14 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/classes/enum.py @@ -22,6 +22,15 @@ def __init__(self, self.note = Note(note) self.comment = comment + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + def __repr__(self): ''' >>> EnumItem('en-US') diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index c2c0b62..b51a4b4 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -38,6 +38,15 @@ def __init__(self, self.note = Note(note) self.comment = comment + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + @property def subject_names(self): ''' diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index 0460ca9..2482b5a 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -1,18 +1,20 @@ +import re from typing import Any from .base import SQLObject -# from pydbml.tools import reformat_note_text -# from pydbml.tools import remove_indentation from pydbml.tools import indent +from pydbml import classes class Note(SQLObject): + def __init__(self, text: Any): self.text: str if isinstance(text, Note): self.text = text.text else: self.text = str(text) if text else '' + self.parent: Any = None def __str__(self): ''' @@ -33,19 +35,54 @@ def __repr__(self): return f'Note({repr(self.text)})' + def _prepare_text_for_sql(self) -> str: + ''' + - Process special escape sequence: slash before line break, which means no line break + https://www.dbml.org/docs/#multi-line-string + + - replace all single quotes with double quotes + ''' + + pattern = re.compile(r'\\\n') + result = pattern.sub('', self.text) + + result = result.replace("'", '"') + return result + + def _prepare_text_for_dbml(self): + '''Escape single quotes''' + pattern = re.compile(r"('''|')") + return pattern.sub(r'\\\1', self.text) + + def generate_comment_on(self, entity: str, name: str) -> str: + """Generate a COMMENT ON clause out from this note.""" + quoted_text = f"'{self._prepare_text_for_sql()}'" + note_sql = f'COMMENT ON {entity.upper()} "{name}" IS {quoted_text};' + return note_sql + @property def sql(self): + """ + For Tables and Columns Note is converted into COMMENT ON clause. All other entities don't + have notes generated in their SQL code, but as a fallback their notes are rendered as SQL + comments when sql property is called directly. + """ if self.text: - return '\n'.join(f'-- {line}' for line in self.text.split('\n')) + if isinstance(self.parent, (classes.Table, classes.Column)): + return self.generate_comment_on(self.parent.__class__.__name__, self.parent.name) + else: + text = self._prepare_text_for_sql() + return '\n'.join(f'-- {line}' for line in text.split('\n')) else: return '' @property def dbml(self): - if '\n' in self.text: - note_text = f"'''\n{self.text}\n'''" + text = self._prepare_text_for_dbml() + if '\n' in text: + note_text = f"'''\n{text}\n'''" else: - note_text = f"'{self.text}'" + note_text = f"'{text}'" note_text = indent(note_text) result = f'Note {{\n{note_text}\n}}' diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py index 582f8c4..3133a85 100644 --- a/pydbml/classes/project.py +++ b/pydbml/classes/project.py @@ -27,6 +27,15 @@ def __repr__(self): return f'' + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + @property def dbml(self): result = comment_to_dbml(self.comment) if self.comment else '' diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 82515f6..b01e6c7 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -52,6 +52,15 @@ def __init__(self, self.header_color = header_color self.comment = comment + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + @property def full_name(self) -> str: return f'{self.schema}.{self.name}' @@ -199,9 +208,7 @@ def sql(self): result += '\n'.join(components) if self.note: - quoted_note = f"'{self.note.text}'" - note_sql = f'COMMENT ON TABLE "{self.name}" IS {quoted_note};' - result += f'\n\n{note_sql}' + result += f'\n\n{self.note.sql}' for col in self.columns: if col.note: diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index cfd6b61..cd1975d 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -39,7 +39,7 @@ def test_attributes(self) -> None: self.assertEqual(col.pk, pk) self.assertEqual(col.autoinc, autoinc) self.assertEqual(col.default, default) - self.assertEqual(col.note, note) + self.assertEqual(col.note.text, note.text) self.assertEqual(col.comment, comment) def test_database_set(self) -> None: @@ -289,3 +289,9 @@ def test_dbml_with_ref_and_properties(self) -> None: self.assertEqual(c2.dbml, expected) expected = '"client_id" integer' self.assertEqual(c1.dbml, expected) + + def test_note_property(self): + note1 = Note('column note') + c1 = Column(name='client_id', type='integer') + c1.note = note1 + self.assertIs(c1.note.parent, c1) diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py index b28e84d..e3c8bb0 100644 --- a/test/test_classes/test_enum.py +++ b/test/test_classes/test_enum.py @@ -1,5 +1,6 @@ from pydbml.classes import Enum from pydbml.classes import EnumItem +from pydbml.classes import Note from unittest import TestCase @@ -21,6 +22,12 @@ def test_dbml_full(self): "en-US" [note: 'preferred']''' self.assertEqual(ei.dbml, expected) + def test_note_property(self): + note1 = Note('enum item note') + ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') + ei.note = note1 + self.assertIs(ei.note.parent, ei) + class TestEnum(TestCase): def test_simple_enum(self) -> None: diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py index 9c77571..9710d2a 100644 --- a/test/test_classes/test_index.py +++ b/test/test_classes/test_index.py @@ -3,6 +3,7 @@ from pydbml.classes import Column from pydbml.classes import Expression from pydbml.classes import Index +from pydbml.classes import Note from pydbml.classes import Table from pydbml.exceptions import ColumnNotFoundError @@ -128,3 +129,11 @@ def test_dbml_full(self): '''// Comment on the index (id, `getdate()`) [name: 'Dated id', pk, unique, type: hash, note: 'Note on the column']''' self.assertEqual(i.dbml, expected) + + def test_note_property(self): + note1 = Note('column note') + t = Table('products') + c = Column('id', 'integer') + i = Index(subjects=[c]) + i.note = note1 + self.assertIs(i.note.parent, i) diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py index 4a6b8e1..8ff2a9d 100644 --- a/test/test_classes/test_note.py +++ b/test/test_classes/test_note.py @@ -1,4 +1,7 @@ from pydbml.classes import Note +from pydbml.classes import Table +from pydbml.classes import Index +from pydbml.classes import Column from unittest import TestCase @@ -36,7 +39,7 @@ def test_forced_multiline(self): }""" self.assertEqual(note.dbml, expected) - def test_sql(self) -> None: + def test_sql_general(self) -> None: note1 = Note(None) self.assertEqual(note1.sql, '') note2 = Note('One line of note text') @@ -47,3 +50,55 @@ def test_sql(self) -> None: -- will -- be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.""" self.assertEqual(note3.sql, expected) + + def test_sql_table(self) -> None: + table = Table(name="test") + note1 = Note(None) + table.note = note1 + self.assertEqual(note1.sql, '') + note2 = Note('One line of note text') + table.note = note2 + self.assertEqual(note2.sql, 'COMMENT ON TABLE "test" IS \'One line of note text\';') + + def test_sql_column(self) -> None: + column = Column(name="test", type="int") + note1 = Note(None) + column.note = note1 + self.assertEqual(note1.sql, '') + note2 = Note('One line of note text') + column.note = note2 + self.assertEqual(note2.sql, 'COMMENT ON COLUMN "test" IS \'One line of note text\';') + + def test_sql_index(self) -> None: + t = Table('products') + t.add_column(Column('id', 'integer')) + index = Index(subjects=[t.columns[0]]) + + note1 = Note(None) + index.note = note1 + self.assertEqual(note1.sql, '') + note2 = Note('One line of note text') + index.note = note2 + self.assertEqual(note2.sql, '-- One line of note text') + + def test_prepare_text_for_sql(self): + line_escape = 'This text \\\nis not split \\\ninto lines' + quotes = "'asd' There's ''' asda '''' asd ''''' asdsa ''" + + note = Note(line_escape) + expected = 'This text is not split into lines' + self.assertEqual(note._prepare_text_for_sql(), expected) + + note = Note(quotes) + expected = '"asd" There"s """ asda """" asd """"" asdsa ""' + self.assertEqual(note._prepare_text_for_sql(), expected) + + def test_prepare_text_for_dbml(self): + quotes = "'asd' There's ''' asda '''' asd ''''' asdsa ''" + expected = "\\'asd\\' There\\'s \\''' asda \\'''\\' asd \\'''\\'\\' asdsa \\'\\'" + note = Note(quotes) + self.assertEqual(note._prepare_text_for_dbml(), expected) + + def test_escaped_newline_sql(self) -> None: + note = Note('One line of note text \\\nstill one line') + self.assertEqual(note.sql, '-- One line of note text still one line') diff --git a/test/test_classes/test_project.py b/test/test_classes/test_project.py index df86fa5..d77537a 100644 --- a/test/test_classes/test_project.py +++ b/test/test_classes/test_project.py @@ -1,4 +1,5 @@ from pydbml.classes import Project +from pydbml.classes import Note from unittest import TestCase @@ -47,3 +48,9 @@ def test_dbml_space(self) -> None: a: 'b' }''' self.assertEqual(p.dbml, expected) + + def test_note_property(self): + note1 = Note('column note') + p = Project('myproject') + p.note = note1 + self.assertIs(p.note.parent, p) diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 422a1fc..aeb4793 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -464,3 +464,9 @@ def test_dbml_full(self): } }""" self.assertEqual(t.dbml, expected) + + def test_note_property(self): + note1 = Note('table note') + t = Table(name='test') + t.note = note1 + self.assertIs(t.note.parent, t) From 389dd96c25dc51e433812d6fa4aa3bf8cf34dc1a Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 6 Aug 2022 12:22:16 +0200 Subject: [PATCH 063/125] support many to many --- CHANGELOG.md | 5 +- README.md | 2 +- TODO.md | 1 - pydbml/classes/column.py | 3 +- pydbml/classes/reference.py | 126 +++++++++++++++++----------- pydbml/classes/table.py | 21 ++++- pydbml/constants.py | 1 + pydbml/definitions/reference.py | 2 +- pydbml/parser/blueprints.py | 2 +- test/test_classes/test_reference.py | 79 +++++++++++++++++ test/test_classes/test_table.py | 21 +++++ 11 files changed, 205 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cdc53b..de2fb3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,15 @@ # 1.0.2 - New: "backslash newline" is supported in note text (line continuation) -- New: notes have reference to their parent. Note.sql now depends on type of parent (for tables and columns it's COMMENT ON clause) +- New: notes have reference to their parent. Note.sql now depends on type of parent (for tables and columns it's COMMENT ON clause) +- New: pydbml no longer splits long notes into multiple lines - Fix: inline ref schema bug, thanks to @jens-koster - Fix: (#16) notes were not idempotent, thanks @jens-koster for reporting - Fix: (#15) note objects were not supported in project definition, thanks @jens-koster for reporting - Fix: (#20) schema didn't work in table group definition, thanks @mjfii for reporting - Fix: quotes in note text broke sql and dbml +- New: proper support of composite primary keys without creating an index +- New: support of many-to-many relationships # 1.0.1 diff --git a/README.md b/README.md index b2f3148..09d27c0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # DBML parser for Python -*Compliant with DBML **v2.4.1** syntax* +*Compliant with DBML **v2.4.2** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. diff --git a/TODO.md b/TODO.md index 5926ed3..3fde65e 100644 --- a/TODO.md +++ b/TODO.md @@ -1,2 +1 @@ - schema.add and .delete to support multiple arguments (handle errors properly) -- support 2.4.2 (many to many relationships) diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 954a2b9..4743581 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -81,7 +81,8 @@ def sql(self): else: components.append(str(self.type)) - if self.pk: + table_has_composite_pk = False if self.table is None else self.table._has_composite_pk() + if self.pk and not table_has_composite_pk: # comp-PKs are rendered in table sql components.append('PRIMARY KEY') if self.autoinc: components.append('AUTOINCREMENT') diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 48ac163..7e74863 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -1,20 +1,24 @@ from typing import Collection from typing import Literal +from typing import List from typing import Optional -from typing import Union from typing import TYPE_CHECKING +from typing import Union +from itertools import chain -from .base import SQLObject -from .column import Column from pydbml.constants import MANY_TO_ONE +from pydbml.constants import MANY_TO_MANY +from pydbml.constants import ONE_TO_MANY from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import DBMLError from pydbml.exceptions import TableNotFoundError from pydbml.tools import comment_to_dbml from pydbml.tools import comment_to_sql +from .base import SQLObject +from .column import Column -if TYPE_CHECKING: # pragma: no cover - from .table import Table +# if TYPE_CHECKING: # pragma: no cover +from .table import Table class Reference(SQLObject): @@ -26,7 +30,7 @@ class Reference(SQLObject): required_attributes = ('type', 'col1', 'col2') def __init__(self, - type: Literal['>', '<', '-'], + type: Literal['>', '<', '-', '<>'], col1: Union[Column, Collection[Column]], col2: Union[Column, Collection[Column]], name: Optional[str] = None, @@ -44,6 +48,20 @@ def __init__(self, self.on_delete = on_delete self.inline = inline + @property + def join_table(self) -> Optional['Table']: + if self.type != MANY_TO_MANY: + return + + return Table( + name=f'{self.table1.name}_{self.table2.name}', + columns=( + Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) + for c in chain(self.col1, self.col2) + ), + abstract=True + ) + @property def table1(self) -> Optional['Table']: self._validate() @@ -101,8 +119,47 @@ def _validate_for_sql(self): if self.table2 is None: raise TableNotFoundError('Table on col2 is not set') + def _generate_inline_sql(self, source_col: List['Column'], ref_col: List['Column']) -> str: + result = comment_to_sql(self.comment) if self.comment else '' + result += ( + f'{{c}}FOREIGN KEY ({self._col_names(source_col)}) ' + f'REFERENCES {ref_col[0].table._get_full_name_for_sql()} ({self._col_names(ref_col)})' + ) + if self.on_update: + result += f' ON UPDATE {self.on_update.upper()}' + if self.on_delete: + result += f' ON DELETE {self.on_delete.upper()}' + return result + + def _generate_not_inline_sql(self, c1: List['Column'], c2: List['Column']): + result = comment_to_sql(self.comment) if self.comment else '' + result += ( + f'ALTER TABLE {c1[0].table._get_full_name_for_sql()} ADD {{c}}FOREIGN KEY ({self._col_names(c1)}) ' + f'REFERENCES {c2[0].table._get_full_name_for_sql()} ({self._col_names(c2)})' + ) + if self.on_update: + result += f' ON UPDATE {self.on_update.upper()}' + if self.on_delete: + result += f' ON DELETE {self.on_delete.upper()}' + return result + ';' + + def _generate_many_to_many_sql(self) -> str: + join_table = self.join_table + table_sql = join_table.sql + + n = len(self.col1) + ref1_sql = self._generate_not_inline_sql(join_table.columns[:n], self.col1) + ref2_sql = self._generate_not_inline_sql(join_table.columns[n:], self.col2) + + result = '\n\n'.join((table_sql, ref1_sql, ref2_sql)) + return result.format(c='') + + @staticmethod + def _col_names(cols: List[Column]) -> str: + return ', '.join(f'"{c.name}"' for c in cols) + @property - def sql(self): + def sql(self) -> str: ''' Returns SQL of the reference: @@ -111,55 +168,28 @@ def sql(self): ''' self.check_attributes_for_sql() self._validate_for_sql() - c = f'CONSTRAINT "{self.name}" ' if self.name else '' + if self.type == MANY_TO_MANY: + return self._generate_many_to_many_sql() + + result = '' if self.inline: if self.type in (MANY_TO_ONE, ONE_TO_ONE): - source_col = self.col1 - ref_table = self.table2 - ref_col = self.col2 - else: - source_col = self.col2 - ref_table = self.table1 - ref_col = self.col1 - - cols = '", "'.join(c.name for c in source_col) - ref_cols = '", "'.join(c.name for c in ref_col) - result = comment_to_sql(self.comment) if self.comment else '' - result += ( - f'{c}FOREIGN KEY ("{cols}") ' - f'REFERENCES {ref_table._get_full_name_for_sql()} ("{ref_cols}")' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result + result = self._generate_inline_sql(self.col1, self.col2) + elif self.type == ONE_TO_MANY: + result = self._generate_inline_sql(self.col2, self.col1) else: if self.type in (MANY_TO_ONE, ONE_TO_ONE): - t1 = self.table1 - c1 = ', '.join(f'"{c.name}"' for c in self.col1) - t2 = self.table2 - c2 = ', '.join(f'"{c.name}"' for c in self.col2) - else: - t1 = self.table2 - c1 = ', '.join(f'"{c.name}"' for c in self.col2) - t2 = self.table1 - c2 = ', '.join(f'"{c.name}"' for c in self.col1) + result = self._generate_not_inline_sql(c1=self.col1, c2=self.col2) + elif self.type == ONE_TO_MANY: + result = self._generate_not_inline_sql(c1=self.col2, c2=self.col1) - result = comment_to_sql(self.comment) if self.comment else '' - result += ( - f'ALTER TABLE {t1._get_full_name_for_sql()} ADD {c}FOREIGN KEY ({c1}) ' - f'REFERENCES {t2._get_full_name_for_sql()} ({c2})' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result + ';' + c = f'CONSTRAINT "{self.name}" ' if self.name else '' + + return result.format(c=c) @property - def dbml(self): + def dbml(self) -> str: self._validate_for_sql() if self.inline: # settings are ignored for inline ref diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index b01e6c7..bf13250 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -8,7 +8,6 @@ from .column import Column from .index import Index from .note import Note -from .reference import Reference from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_MANY from pydbml.constants import ONE_TO_ONE @@ -22,6 +21,7 @@ if TYPE_CHECKING: # pragma: no cover from pydbml.database import Database + from .reference import Reference class Table(SQLObject): @@ -37,7 +37,8 @@ def __init__(self, indexes: Optional[Iterable[Index]] = None, note: Optional[Union['Note', str]] = None, header_color: Optional[str] = None, - comment: Optional[str] = None): + comment: Optional[str] = None, + abstract: bool = False): self.database: Optional[Database] = None self.name = name self.schema = schema @@ -51,6 +52,7 @@ def __init__(self, self.note = Note(note) self.header_color = header_color self.comment = comment + self.abstract = abstract @property def note(self): @@ -65,6 +67,9 @@ def note(self, val: Note) -> None: def full_name(self) -> str: return f'{self.schema}.{self.name}' + def _has_composite_pk(self) -> bool: + return sum(c.pk for c in self.columns) > 1 + def add_column(self, c: Column) -> None: ''' Adds column to self.columns attribute and sets in this column the @@ -110,15 +115,17 @@ def delete_index(self, i: Union[Index, int]) -> Index: self.indexes[i].table = None return self.indexes.pop(i) - def get_refs(self) -> List[Reference]: + def get_refs(self) -> List['Reference']: if not self.database: raise UnknownDatabaseError('Database for the table is not set') return [ref for ref in self.database.refs if ref.table1 == self] - def _get_references_for_sql(self) -> List[Reference]: + def _get_references_for_sql(self) -> List['Reference']: ''' return inline references for this table sql definition ''' + if self.abstract: + return [] if not self.database: raise UnknownDatabaseError(f'Database for the table {self} is not set') result = [] @@ -200,6 +207,12 @@ def sql(self): body.extend(indent(c.sql, 2) for c in self.columns) body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) body.extend(indent(r.sql, 2) for r in self._get_references_for_sql()) + + if self._has_composite_pk(): + body.append( + " PRIMARY KEY (" + + ', '.join(f'"{c.name}"' for c in self.columns if c.pk) + + ')') components.append(',\n'.join(body)) components.append(');') components.extend('\n' + i.sql for i in self.indexes if not i.pk) diff --git a/pydbml/constants.py b/pydbml/constants.py index e4fa877..712ac61 100644 --- a/pydbml/constants.py +++ b/pydbml/constants.py @@ -1,3 +1,4 @@ ONE_TO_MANY = '<' MANY_TO_ONE = '>' ONE_TO_ONE = '-' +MANY_TO_MANY = '<>' diff --git a/pydbml/definitions/reference.py b/pydbml/definitions/reference.py index c8eebf8..9b6450e 100644 --- a/pydbml/definitions/reference.py +++ b/pydbml/definitions/reference.py @@ -9,7 +9,7 @@ pp.ParserElement.set_default_whitespace_chars(' \t\r') -relation = pp.oneOf("> - <") +relation = pp.oneOf("> - < <>") col_name = ( ( diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index deee85c..f052977 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -52,7 +52,7 @@ def build(self) -> Expression: @dataclass class ReferenceBlueprint(Blueprint): - type: Literal['>', '<', '-'] + type: Literal['>', '<', '-', '<>'] inline: bool name: Optional[str] = None schema1: str = 'public' diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 9a32533..3a7be58 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -115,6 +115,85 @@ def test_sql_full(self): self.assertEqual(ref.sql, expected) + def test_many_to_many_sql_simple(self) -> None: + t1 = Table('books') + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors') + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<>', c11, c21) + + expected = \ +'''CREATE TABLE "books_authors" ( + "books_id" integer NOT NULL, + "authors_id" integer NOT NULL, + PRIMARY KEY ("books_id", "authors_id") +); + +ALTER TABLE "books_authors" ADD FOREIGN KEY ("books_id") REFERENCES "books" ("id"); + +ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id") REFERENCES "authors" ("id");''' + self.assertEqual(expected, ref.sql) + + def test_many_to_many_sql_composite(self) -> None: + t1 = Table('books') + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors') + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<>', [c11, c12], [c21, c22]) + + expected = \ +'''CREATE TABLE "books_authors" ( + "books_id" integer NOT NULL, + "books_author" varchar NOT NULL, + "authors_id" integer NOT NULL, + "authors_name" varchar NOT NULL, + PRIMARY KEY ("books_id", "books_author", "authors_id", "authors_name") +); + +ALTER TABLE "books_authors" ADD FOREIGN KEY ("books_id", "books_author") REFERENCES "books" ("id", "author"); + +ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "authors" ("id", "name");''' + self.assertEqual(expected, ref.sql) + + def test_many_to_many_sql_composite_different_schemas(self) -> None: + t1 = Table('books', schema="schema1") + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors', schema="schema2") + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<>', [c11, c12], [c21, c22]) + + expected = \ +'''CREATE TABLE "books_authors" ( + "books_id" integer NOT NULL, + "books_author" varchar NOT NULL, + "authors_id" integer NOT NULL, + "authors_name" varchar NOT NULL, + PRIMARY KEY ("books_id", "books_author", "authors_id", "authors_name") +); + +ALTER TABLE "books_authors" ADD FOREIGN KEY ("books_id", "books_author") REFERENCES "schema1"."books" ("id", "author"); + +ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "schema2"."authors" ("id", "name");''' + self.assertEqual(expected, ref.sql) + def test_dbml_simple(self): t = Table('products') c1 = Column('id', 'integer') diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index aeb4793..ac9f0f6 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -242,6 +242,27 @@ def test_index_inline_and_comments(self) -> None: );''' self.assertEqual(t.sql, expected) + def test_composite_pk_sql(self): + table = Table( + 'products', + columns=( + Column('id', 'integer', pk=True), + Column('name', 'varchar2', pk=True), + Column('prop', 'object', pk=True), + ) + ) + s = Database() + s.add(table) + + expected = \ +'''CREATE TABLE "products" ( + "id" integer, + "name" varchar2, + "prop" object, + PRIMARY KEY ("id", "name", "prop") +);''' + self.assertEqual(table.sql, expected) + def test_add_column(self) -> None: t = Table('products') c1 = Column('id', 'integer') From b600fdfc81d10725fb53bd486dd6e66ac3c2799a Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 6 Aug 2022 12:44:24 +0200 Subject: [PATCH 064/125] Update readme, add test for join table --- README.md | 2 ++ pydbml/tools.py | 34 ----------------------------- test/test_classes/test_reference.py | 18 +++++++++++++++ 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 09d27c0..8c94615 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. * [Creating DBML schema](docs/creating_schema.md) * [Upgrading to PyDBML 1.0.0](docs/upgrading.md) +> PyDBML requires Python v3.8 or higher + ## Installation You can install PyDBML using pip: diff --git a/pydbml/tools.py b/pydbml/tools.py index bb9f852..947a814 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -62,37 +62,3 @@ def remove_indentation(source: str) -> str: indent = min(spaces) lines = [l[indent:] for l in lines] return '\n'.join(lines) - - -def reformat_note_text(source: str, spaces=4) -> str: - """ - Currently not used. - - Add line breaks at approx 80-90 characters. - If source is less than 90 characters and has no line breaks, leave it unchanged. - """ - if '\n' not in source and len(source) <= 90: - return f"'{source}'" - - lines = [] - line = '' - text = remove_indentation(source.strip('\n')) - for word in text.split(' '): - if len(line) > 80: - lines.append(line) - line = '' - if '\n' in word: - sublines = word.split('\n') - for sl in sublines[:-1]: - line += sl - lines.append(line) - line = '' - line = sublines[-1] + ' ' - else: - line += f'{word} ' - if line: - lines.append(line) - result = '\n'.join(lines).rstrip() - result = f"'''\n{result}\n'''" - - return f'{result}' diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 3a7be58..3d8d84a 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -194,6 +194,24 @@ def test_many_to_many_sql_composite_different_schemas(self) -> None: ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "schema2"."authors" ("id", "name");''' self.assertEqual(expected, ref.sql) + def test_join_table(self) -> None: + t1 = Table('books') + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors') + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref0 = Reference('>', [c11], [c21]) + ref = Reference('<>', [c11, c12], [c21, c22]) + + self.assertIsNone(ref0.join_table) + self.assertEqual(ref.join_table.name, 'books_authors') + self.assertEqual(len(ref.join_table.columns), 4) + def test_dbml_simple(self): t = Table('products') c1 = Column('id', 'integer') From b5331341b5d3248f54b7181a047f3122f3a3af63 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 6 Aug 2022 13:25:43 +0200 Subject: [PATCH 065/125] Fix typing --- pydbml/classes/reference.py | 42 +++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index 7e74863..badad4d 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -1,13 +1,12 @@ +from itertools import chain from typing import Collection -from typing import Literal from typing import List +from typing import Literal from typing import Optional -from typing import TYPE_CHECKING from typing import Union -from itertools import chain -from pydbml.constants import MANY_TO_ONE from pydbml.constants import MANY_TO_MANY +from pydbml.constants import MANY_TO_ONE from pydbml.constants import ONE_TO_MANY from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import DBMLError @@ -16,8 +15,6 @@ from pydbml.tools import comment_to_sql from .base import SQLObject from .column import Column - -# if TYPE_CHECKING: # pragma: no cover from .table import Table @@ -51,12 +48,17 @@ def __init__(self, @property def join_table(self) -> Optional['Table']: if self.type != MANY_TO_MANY: - return + return None + + if self.table1 is None: + raise TableNotFoundError(f"Cannot generate join table for {self}: table 1 is unknown") + if self.table2 is None: + raise TableNotFoundError(f"Cannot generate join table for {self}: table 2 is unknown") return Table( name=f'{self.table1.name}_{self.table2.name}', columns=( - Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) + Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) # type: ignore for c in chain(self.col1, self.col2) ), abstract=True @@ -114,15 +116,14 @@ def _validate(self): raise DBMLError('Columns in col2 are from different tables') def _validate_for_sql(self): - if self.table1 is None: - raise TableNotFoundError('Table on col1 is not set') - if self.table2 is None: - raise TableNotFoundError('Table on col2 is not set') + for col in chain(self.col1, self.col2): + if col.table is None: + raise TableNotFoundError(f'Table on {col} is not set') def _generate_inline_sql(self, source_col: List['Column'], ref_col: List['Column']) -> str: result = comment_to_sql(self.comment) if self.comment else '' result += ( - f'{{c}}FOREIGN KEY ({self._col_names(source_col)}) ' + f'{{c}}FOREIGN KEY ({self._col_names(source_col)}) ' # type: ignore f'REFERENCES {ref_col[0].table._get_full_name_for_sql()} ({self._col_names(ref_col)})' ) if self.on_update: @@ -134,8 +135,9 @@ def _generate_inline_sql(self, source_col: List['Column'], ref_col: List['Column def _generate_not_inline_sql(self, c1: List['Column'], c2: List['Column']): result = comment_to_sql(self.comment) if self.comment else '' result += ( - f'ALTER TABLE {c1[0].table._get_full_name_for_sql()} ADD {{c}}FOREIGN KEY ({self._col_names(c1)}) ' - f'REFERENCES {c2[0].table._get_full_name_for_sql()} ({self._col_names(c2)})' + f'ALTER TABLE {c1[0].table._get_full_name_for_sql()}' # type: ignore + f' ADD {{c}}FOREIGN KEY ({self._col_names(c1)})' + f' REFERENCES {c2[0].table._get_full_name_for_sql()} ({self._col_names(c2)})' ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' @@ -145,11 +147,11 @@ def _generate_not_inline_sql(self, c1: List['Column'], c2: List['Column']): def _generate_many_to_many_sql(self) -> str: join_table = self.join_table - table_sql = join_table.sql + table_sql = join_table.sql # type: ignore n = len(self.col1) - ref1_sql = self._generate_not_inline_sql(join_table.columns[:n], self.col1) - ref2_sql = self._generate_not_inline_sql(join_table.columns[n:], self.col2) + ref1_sql = self._generate_not_inline_sql(join_table.columns[:n], self.col1) # type: ignore + ref2_sql = self._generate_not_inline_sql(join_table.columns[n:], self.col2) # type: ignore result = '\n\n'.join((table_sql, ref1_sql, ref2_sql)) return result.format(c='') @@ -195,7 +197,7 @@ def dbml(self) -> str: # settings are ignored for inline ref if len(self.col2) > 1: raise DBMLError('Cannot render DBML: composite ref cannot be inline') - table_name = self.col2[0].table._get_full_name_for_sql() + table_name = self.col2[0].table._get_full_name_for_sql() # type: ignore return f'ref: {self.type} {table_name}."{self.col2[0].name}"' else: result = comment_to_dbml(self.comment) if self.comment else '' @@ -223,7 +225,7 @@ def dbml(self) -> str: options_str = f' [{", ".join(options)}]' if options else '' result += ( - ' {\n ' + ' {\n ' # type: ignore f'{self.table1._get_full_name_for_sql()}.{col1} ' f'{self.type} ' f'{self.table2._get_full_name_for_sql()}.{col2}' From f14833c580171db709113647bf7bcc7cbe9e1796 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 6 Aug 2022 13:26:06 +0200 Subject: [PATCH 066/125] bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 079f363..23def9e 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.1', + version='1.0.2', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From d6038059045d16f4d3e4c3931f86a7ecb0455340 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 25 Sep 2022 12:02:19 +0200 Subject: [PATCH 067/125] Fix inline many to many references didn't work --- pydbml/classes/reference.py | 10 +++++++++- test/test_classes/test_reference.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index badad4d..e94e523 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -43,7 +43,15 @@ def __init__(self, self.comment = comment self.on_update = on_update self.on_delete = on_delete - self.inline = inline + self._inline = inline + + @property + def inline(self) -> bool: + return self._inline and not self.type == MANY_TO_MANY + + @inline.setter + def inline(self, val) -> None: + self._inline = val @property def join_table(self) -> Optional['Table']: diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 3d8d84a..7414a1e 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -212,6 +212,29 @@ def test_join_table(self) -> None: self.assertEqual(ref.join_table.name, 'books_authors') self.assertEqual(len(ref.join_table.columns), 4) + def test_join_table_none(self) -> None: + t1 = Table('books') + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors') + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<>', [c11], [c21]) + + _table1 = ref.table1 + ref.col1[0].table = None + with self.assertRaises(TableNotFoundError): + ref.join_table + + ref.col1[0].table = _table1 + ref.col2[0].table = None + with self.assertRaises(TableNotFoundError): + ref.join_table + def test_dbml_simple(self): t = Table('products') c1 = Column('id', 'integer') From 81a90200231346de213f64289a3d2287c17725d8 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 25 Sep 2022 12:03:58 +0200 Subject: [PATCH 068/125] update changelog and bump version --- CHANGELOG.md | 4 ++++ setup.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de2fb3e..c198d5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.3 + +- Fix: inline many-to-many references were not rendered in sql + # 1.0.2 - New: "backslash newline" is supported in note text (line continuation) diff --git a/setup.py b/setup.py index 23def9e..e9472e0 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.2', + version='1.0.3', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 4f36dd9cb1fcf16ffdea8469f7a2bd78f9a4fbf5 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 23 Oct 2022 10:42:20 +0200 Subject: [PATCH 069/125] First defined referenced tables in SQL #23 --- CHANGELOG.md | 4 ++++ pydbml/database.py | 25 +++++++++++++++++++++++- test/test_data/integration1.sql | 16 ++++++++-------- test/test_database.py | 34 +++++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c198d5b..2b305c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.4 + +- New: referenced tables in SQL are now defined first in SQL (#23 reported by @minhl) + # 1.0.3 - Fix: inline many-to-many references were not rendered in sql diff --git a/pydbml/database.py b/pydbml/database.py index 3f3b922..27a875b 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -11,6 +11,28 @@ from .classes import TableGroup from .exceptions import DatabaseValidationError +from .constants import MANY_TO_ONE, ONE_TO_MANY + + +def reorder_tables_for_sql(tables: List['Table'], refs: list['Reference']) -> List['Table']: + """ + Attempt to reorder the tables, so that they are defined in SQL before they are referenced by + inline foreign keys. + + Won't aid the rare cases of cross-references and many-to-many relations. + """ + references: Dict[str: int] = {} + for ref in refs: + if ref.inline: + if ref.type == MANY_TO_ONE: + table_name = ref.table1.name + elif ref.type == ONE_TO_MANY: + table_name = ref.table2.name + else: + continue + references[table_name] = references.get(table_name, 0) + 1 + return sorted(tables, key=lambda t: references.get(t.name, 0), reverse=True) + class Database: def __init__(self) -> None: @@ -185,7 +207,8 @@ def delete_project(self) -> Project: def sql(self): '''Returs SQL of the parsed results''' refs = (ref for ref in self.refs if not ref.inline) - components = (i.sql for i in (*self.enums, *self.tables, *refs)) + tables = reorder_tables_for_sql(self.tables, self.refs) + components = (i.sql for i in (*self.enums, *tables, *refs)) return '\n\n'.join(components) @property diff --git a/test/test_data/integration1.sql b/test/test_data/integration1.sql index 726f25a..7330553 100644 --- a/test/test_data/integration1.sql +++ b/test/test_data/integration1.sql @@ -4,6 +4,14 @@ CREATE TYPE "level" AS ENUM ( 'senior', ); +CREATE TABLE "books" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "title" varchar, + "author" varchar, + "country_id" integer, + CONSTRAINT "Country Reference" FOREIGN KEY ("country_id") REFERENCES "countries" ("id") +); + CREATE TABLE "Employees" ( "id" integer PRIMARY KEY AUTOINCREMENT, "name" varchar, @@ -14,14 +22,6 @@ CREATE TABLE "Employees" ( COMMENT ON COLUMN "Employees"."name" IS 'Full employee name'; -CREATE TABLE "books" ( - "id" integer PRIMARY KEY AUTOINCREMENT, - "title" varchar, - "author" varchar, - "country_id" integer, - CONSTRAINT "Country Reference" FOREIGN KEY ("country_id") REFERENCES "countries" ("id") -); - CREATE TABLE "countries" ( "id" integer PRIMARY KEY AUTOINCREMENT, "name" varchar2 UNIQUE diff --git a/test/test_database.py b/test/test_database.py index 552512f..a33ea82 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -2,6 +2,7 @@ from pathlib import Path from unittest import TestCase +from unittest.mock import Mock from pydbml.classes import Column from pydbml.classes import Enum @@ -12,6 +13,8 @@ from pydbml.classes import TableGroup from pydbml.database import Database from pydbml.exceptions import DatabaseValidationError +from pydbml.constants import ONE_TO_MANY, MANY_TO_ONE, MANY_TO_MANY +from pydbml.database import reorder_tables_for_sql TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' @@ -366,3 +369,34 @@ class Test: database.delete(t) with self.assertRaises(AttributeError): t.database + + +class TestReorderTablesForSQL(TestCase): + + def test_reorder_tables(self) -> None: + t1 = Mock(name="table1") # 1 ref + t2 = Mock(name="table2") # 2 refs + t3 = Mock(name="table3") + t4 = Mock(name="table4") # 1 ref + t5 = Mock(name="table5") + t6 = Mock(name="table6") # 3 refs + t7 = Mock(name="table7") + t8 = Mock(name="table8") + t9 = Mock(name="table9") + t10 = Mock(name="table10") + + refs = [ + Mock(type=ONE_TO_MANY, table1=t1, table2=t2, inline=True), + Mock(type=MANY_TO_ONE, table1=t4, table2=t3, inline=True), + Mock(type=ONE_TO_MANY, table1=t6, table2=t2, inline=True), + Mock(type=ONE_TO_MANY, table1=t7, table2=t6, inline=True), + Mock(type=MANY_TO_ONE, table1=t6, table2=t8, inline=True), + Mock(type=ONE_TO_MANY, table1=t9, table2=t6, inline=True), + Mock(type=ONE_TO_MANY, table1=t1, table2=t2, inline=False), # ignored not inline + Mock(type=ONE_TO_MANY, table1=t10, table2=t1, inline=True), + Mock(type=MANY_TO_MANY, table1=t1, table2=t2, inline=True), # ignored m2m + ] + original = [t1, t2, t3, t4, t5, t6, t7, t8, t9, t10] + expected = [t6, t2, t1, t4, t3, t5, t7, t8, t9, t10] + result = reorder_tables_for_sql(original, refs) + self.assertEqual(expected, result) From 90f376cbf487ecd2ccc480d8bc07dd1968d18db1 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 23 Oct 2022 11:19:21 +0200 Subject: [PATCH 070/125] Fix quotes in note text were not properly rendered in DBML and SQL --- pydbml/classes/table.py | 2 +- pydbml/database.py | 2 +- pydbml/tools.py | 8 ++++---- test/test_tools.py | 9 +++++++++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index bf13250..9ffe87e 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -225,7 +225,7 @@ def sql(self): for col in self.columns: if col.note: - quoted_note = f"'{col.note.text}'" + quoted_note = f"'{col.note._prepare_text_for_sql()}'" note_sql = f'COMMENT ON COLUMN "{self.name}"."{col.name}" IS {quoted_note};' result += f'\n\n{note_sql}' return result diff --git a/pydbml/database.py b/pydbml/database.py index 27a875b..3dfb4ea 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -21,7 +21,7 @@ def reorder_tables_for_sql(tables: List['Table'], refs: list['Reference']) -> Li Won't aid the rare cases of cross-references and many-to-many relations. """ - references: Dict[str: int] = {} + references: Dict[str, int] = {} for ref in refs: if ref.inline: if ref.type == MANY_TO_ONE: diff --git a/pydbml/tools.py b/pydbml/tools.py index 947a814..2d58cfc 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -17,11 +17,11 @@ def comment_to_sql(val: str) -> str: return comment(val, '--') -def note_option_to_dbml(val: 'Note') -> str: - if '\n' in val.text: - return f"note: '''{val.text}'''" +def note_option_to_dbml(note: 'Note') -> str: + if '\n' in note.text: + return f"note: '''{note._prepare_text_for_dbml()}'''" else: - return f"note: '{val.text}'" + return f"note: '{note._prepare_text_for_dbml()}'" def indent(val: str, spaces=4) -> str: diff --git a/test/test_tools.py b/test/test_tools.py index db29063..81a1cae 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -44,11 +44,20 @@ def test_oneline(self) -> None: note = Note('one line note') self.assertEqual(f"note: 'one line note'", note_option_to_dbml(note)) + def test_oneline_with_quote(self) -> None: + note = Note('one line\'d note') + self.assertEqual(f"note: 'one line\\'d note'", note_option_to_dbml(note)) + def test_multiline(self) -> None: note = Note('line1\nline2\nline3') expected = "note: '''line1\nline2\nline3'''" self.assertEqual(expected, note_option_to_dbml(note)) + def test_multiline_with_quotes(self) -> None: + note = Note('line1\n\'\'\'line2\nline3') + expected = "note: '''line1\n\\'''line2\nline3'''" + self.assertEqual(expected, note_option_to_dbml(note)) + class TestIndent(TestCase): def test_empty(self) -> None: From 4a125f6ee0741c9bb126635dbadfcfeb8aa7474f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 23 Oct 2022 11:22:18 +0200 Subject: [PATCH 071/125] Silence wrong mypy errors, update reorder tables --- pydbml/classes/reference.py | 8 ++++---- pydbml/database.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index e94e523..af82568 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -132,7 +132,7 @@ def _generate_inline_sql(self, source_col: List['Column'], ref_col: List['Column result = comment_to_sql(self.comment) if self.comment else '' result += ( f'{{c}}FOREIGN KEY ({self._col_names(source_col)}) ' # type: ignore - f'REFERENCES {ref_col[0].table._get_full_name_for_sql()} ({self._col_names(ref_col)})' + f'REFERENCES {ref_col[0].table._get_full_name_for_sql()} ({self._col_names(ref_col)})' # type: ignore ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' @@ -145,7 +145,7 @@ def _generate_not_inline_sql(self, c1: List['Column'], c2: List['Column']): result += ( f'ALTER TABLE {c1[0].table._get_full_name_for_sql()}' # type: ignore f' ADD {{c}}FOREIGN KEY ({self._col_names(c1)})' - f' REFERENCES {c2[0].table._get_full_name_for_sql()} ({self._col_names(c2)})' + f' REFERENCES {c2[0].table._get_full_name_for_sql()} ({self._col_names(c2)})' # type: ignore ) if self.on_update: result += f' ON UPDATE {self.on_update.upper()}' @@ -234,9 +234,9 @@ def dbml(self) -> str: options_str = f' [{", ".join(options)}]' if options else '' result += ( ' {\n ' # type: ignore - f'{self.table1._get_full_name_for_sql()}.{col1} ' + f'{self.table1._get_full_name_for_sql()}.{col1} ' # type: ignore f'{self.type} ' - f'{self.table2._get_full_name_for_sql()}.{col2}' + f'{self.table2._get_full_name_for_sql()}.{col2}' # type: ignore f'{options_str}' '\n}' ) diff --git a/pydbml/database.py b/pydbml/database.py index 3dfb4ea..091a7d7 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -24,9 +24,9 @@ def reorder_tables_for_sql(tables: List['Table'], refs: list['Reference']) -> Li references: Dict[str, int] = {} for ref in refs: if ref.inline: - if ref.type == MANY_TO_ONE: + if ref.type == MANY_TO_ONE and ref.table1 is not None: table_name = ref.table1.name - elif ref.type == ONE_TO_MANY: + elif ref.type == ONE_TO_MANY and ref.table2 is not None: table_name = ref.table2.name else: continue From 48a8e5678ee53fd822b33eeafd7ad61250d14a32 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 23 Oct 2022 11:23:18 +0200 Subject: [PATCH 072/125] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b305c0..0863887 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # 1.0.4 - New: referenced tables in SQL are now defined first in SQL (#23 reported by @minhl) +- Fix: single quotes were not escaped in column notes (#24 reported by @fivegrant) # 1.0.3 From 1d79762665536e1178a0b0ac7026d4c42c267f46 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 23 Oct 2022 11:23:52 +0200 Subject: [PATCH 073/125] bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e9472e0..35d1561 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.3', + version='1.0.4', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 198fcdf2106a59dc74f4ab48b22e38861e1852ed Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 19 Nov 2022 18:00:04 +0100 Subject: [PATCH 074/125] update docs, fix type annotation, update pyparsing dep version --- docs/classes.md | 12 ++++++++++++ pydbml/database.py | 2 +- setup.py | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/classes.md b/docs/classes.md index 172f27d..9e999d9 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -1,3 +1,15 @@ + +* [Database](#database) +* [Table](#table) +* [Column](#column) +* [Index](#index) +* [Reference](#reference) +* [Enum](#enum) +* [Note](#note) +* [Expression](#expression) +* [Project](#project) +* [TableGroup](#tablegroup) + # Class Reference PyDBML classes represent database entities. They live in the `pydbml.classes` package. diff --git a/pydbml/database.py b/pydbml/database.py index 091a7d7..c4bacb6 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -14,7 +14,7 @@ from .constants import MANY_TO_ONE, ONE_TO_MANY -def reorder_tables_for_sql(tables: List['Table'], refs: list['Reference']) -> List['Table']: +def reorder_tables_for_sql(tables: List['Table'], refs: List['Reference']) -> List['Table']: """ Attempt to reorder the tables, so that they are defined in SQL before they are referenced by inline foreign keys. diff --git a/setup.py b/setup.py index 35d1561..06a1bbc 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ license='MIT', platforms='any', install_requires=[ - 'pyparsing>=2.4.7', + 'pyparsing>=3.0.0', ], classifiers=[ "Development Status :: 4 - Beta", From 8685a10f51e129e73bdb8d354c7098274943d9cf Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 20 Nov 2022 09:29:37 +0100 Subject: [PATCH 075/125] junction table now has the schema of the first referenced table --- pydbml/classes/reference.py | 1 + test/test_classes/test_reference.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index af82568..f093016 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -65,6 +65,7 @@ def join_table(self) -> Optional['Table']: return Table( name=f'{self.table1.name}_{self.table2.name}', + schema=self.table1.schema, columns=( Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) # type: ignore for c in chain(self.col1, self.col2) diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index 7414a1e..d00e698 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -181,7 +181,7 @@ def test_many_to_many_sql_composite_different_schemas(self) -> None: ref = Reference('<>', [c11, c12], [c21, c22]) expected = \ -'''CREATE TABLE "books_authors" ( +'''CREATE TABLE "schema1"."books_authors" ( "books_id" integer NOT NULL, "books_author" varchar NOT NULL, "authors_id" integer NOT NULL, @@ -189,9 +189,9 @@ def test_many_to_many_sql_composite_different_schemas(self) -> None: PRIMARY KEY ("books_id", "books_author", "authors_id", "authors_name") ); -ALTER TABLE "books_authors" ADD FOREIGN KEY ("books_id", "books_author") REFERENCES "schema1"."books" ("id", "author"); +ALTER TABLE "schema1"."books_authors" ADD FOREIGN KEY ("books_id", "books_author") REFERENCES "schema1"."books" ("id", "author"); -ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "schema2"."authors" ("id", "name");''' +ALTER TABLE "schema1"."books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "schema2"."authors" ("id", "name");''' self.assertEqual(expected, ref.sql) def test_join_table(self) -> None: From 873ed5af3878cbc9d44b47e13777d749e16c762f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 20 Nov 2022 09:33:58 +0100 Subject: [PATCH 076/125] update changelog and bump version --- CHANGELOG.md | 5 +++++ setup.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0863887..3a490bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.0.5 + +- Fix: junction table now has the schema of the first referenced table (as introduced in DBML 2.4.3) +- Fix: typing issue which failed for Python 3.8 and Python 3.9 + # 1.0.4 - New: referenced tables in SQL are now defined first in SQL (#23 reported by @minhl) diff --git a/setup.py b/setup.py index 06a1bbc..8df1598 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.4', + version='1.0.5', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 21cdfefed3e9915107eefc51f1994b94bee4146c Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 20 Nov 2022 09:37:14 +0100 Subject: [PATCH 077/125] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8c94615..3fcd617 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # DBML parser for Python -*Compliant with DBML **v2.4.2** syntax* +*Compliant with DBML **v2.4.4** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. From c3c754d1d855635350a5db749fe5a54ec58f4dc7 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 10:58:23 +0100 Subject: [PATCH 078/125] fixed empty line stripping #26 --- pydbml/tools.py | 10 ++-------- test/test_tools.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pydbml/tools.py b/pydbml/tools.py index 2d58cfc..75a0dfd 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -38,14 +38,8 @@ def remove_bom(source: str) -> str: def strip_empty_lines(source: str) -> str: """Remove empty lines or lines with just spaces from beginning and end.""" - first_line = 0 - lines = source.split('\n') - last_line = len(lines) - 1 - while not lines[first_line] or lines[first_line].isspace(): - first_line += 1 - while not lines[last_line] or lines[last_line].isspace(): - last_line -= 1 - return '\n'.join(lines[first_line: last_line + 1]) + pattern = re.compile(r'^([ \t]*\n)*(?P[\s\S]+?)(\n[ \t]*)*$') + return pattern.sub('\g', source) def remove_indentation(source: str) -> str: diff --git a/test/test_tools.py b/test/test_tools.py index 81a1cae..5aa714f 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -5,6 +5,7 @@ from pydbml.tools import comment_to_sql from pydbml.tools import indent from pydbml.tools import note_option_to_dbml +from pydbml.tools import strip_empty_lines class TestCommentToDBML(TestCase): @@ -71,3 +72,30 @@ def test_nonempty(self) -> None: self.assertEqual(indent(source), expected) expected2 = ' line1\n line2\n line3' self.assertEqual(indent(source, 2), expected2) + + +class TestStripEmptyLines(TestCase): + def test_empty(self) -> None: + source = '' + self.assertEqual(strip_empty_lines(source), source) + + def test_no_empty_lines(self) -> None: + source = 'line1\n\n\nline2' + self.assertEqual(strip_empty_lines(source), source) + + def test_empty_lines(self) -> None: + stripped = ' line1\n\n line2' + source = f'\n \n \n\t \t \n \n{stripped}\n\n\n \n \t \n\t \n \n' + self.assertEqual(strip_empty_lines(source), stripped) + + def test_one_empty_line(self) -> None: + stripped = ' line1\n\n line2' + source = f'\n{stripped}' + self.assertEqual(strip_empty_lines(source), stripped) + source = f'{stripped}\n' + self.assertEqual(strip_empty_lines(source), stripped) + + def test_end(self) -> None: + stripped = ' line1\n\n line2' + source = f'\n{stripped}\n ' + self.assertEqual(strip_empty_lines(source), stripped) \ No newline at end of file From 472ddfba2ee728712ac732f707a0bce157f14b9f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 11:00:53 +0100 Subject: [PATCH 079/125] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a490bb..b8628bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.6 + +- Fix: (#26) bug in note empty line stripping, thanks @Jaschenn for reporting + # 1.0.5 - Fix: junction table now has the schema of the first referenced table (as introduced in DBML 2.4.3) From 4cd31013ed2e2dac778c6094b1613f70cc4d707f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 11:12:51 +0100 Subject: [PATCH 080/125] add public get_references_for_sql table method --- pydbml/classes/table.py | 30 ++++++++++++++++++------------ test/test_classes/test_table.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 9ffe87e..99ab6c7 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -120,24 +120,30 @@ def get_refs(self) -> List['Reference']: raise UnknownDatabaseError('Database for the table is not set') return [ref for ref in self.database.refs if ref.table1 == self] - def _get_references_for_sql(self) -> List['Reference']: - ''' - return inline references for this table sql definition - ''' - if self.abstract: - return [] + def get_references_for_sql(self) -> List['Reference']: + """ + Return all references in the database where this table is on the left side of SQL + reference definition. + """ if not self.database: raise UnknownDatabaseError(f'Database for the table {self} is not set') result = [] for ref in self.database.refs: - if ref.inline: - if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ - (ref.table1 == self): - result.append(ref) - elif (ref.type == ONE_TO_MANY) and (ref.table2 == self): - result.append(ref) + if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ + (ref.table1 == self): + result.append(ref) + elif (ref.type == ONE_TO_MANY) and (ref.table2 == self): + result.append(ref) return result + def _get_references_for_sql(self) -> List['Reference']: + ''' + Return inline references for this table sql definition + ''' + if self.abstract: + return [] + return [r for r in self.get_references_for_sql() if r.inline] + def _get_full_name_for_sql(self) -> str: if self.schema == 'public': return f'"{self.name}"' diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index ac9f0f6..75fda72 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -353,6 +353,34 @@ def test_get_references_for_sql(self): self.assertEqual(t._get_references_for_sql(), [r1, r2]) self.assertEqual(t2._get_references_for_sql(), [r3]) + def test_get_references_for_sql_public(self): + t = Table('products') + with self.assertRaises(UnknownDatabaseError): + t._get_references_for_sql() + c11 = Column('id', 'integer') + c12 = Column('name', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('id', 'integer') + c22 = Column('name_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + s = Database() + s.add(t) + s.add(t2) + r1 = Reference('>', c12, c22, inline=True) + r2 = Reference('-', c11, c21, inline=True) + r3 = Reference('<', c11, c22, inline=True) + s.add(r1) + s.add(r2) + s.add(r3) + self.assertEqual(t.get_references_for_sql(), [r1, r2]) + self.assertEqual(t2.get_references_for_sql(), [r3]) + r1.inline = r2.inline = r3.inline = False + self.assertEqual(t.get_references_for_sql(), [r1, r2]) + self.assertEqual(t2.get_references_for_sql(), [r3]) + def test_get_refs(self): t = Table('products') with self.assertRaises(UnknownDatabaseError): From 07ce650a3b9b5f47d5dcbcc58e01e676fc94f291 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 12:11:02 +0100 Subject: [PATCH 081/125] update changelog and docs --- CHANGELOG.md | 1 + docs/classes.md | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8628bb..8c64c7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # 1.0.6 - Fix: (#26) bug in note empty line stripping, thanks @Jaschenn for reporting +- New: get_references_for_sql table method # 1.0.5 diff --git a/docs/classes.md b/docs/classes.md index 9e999d9..511a6d7 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -116,6 +116,7 @@ When you are creating PyDBML schema from scratch, you have to add each created o * **add_index** (i: `Index`) — add an index to the table, * **delete_index** (i: Index or int) — delete an index from the table by Index object or index number. * **get_refs** — get list of references, defined for this table. +* **get_references_for_sql** — get list of references where this table is on the left side of FOREIGN KEY definition in SQL. ## Column From 3f83a6f21e71f87a2670ffa1705048f1fa734031 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 12:11:47 +0100 Subject: [PATCH 082/125] bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8df1598..b3a7796 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.5', + version='1.0.6', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 93ee4a33d4d7ad84955214c914c4d41dbbe3fe5a Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 12:27:31 +0100 Subject: [PATCH 083/125] fix remove_indentation bug --- CHANGELOG.md | 4 ++++ pydbml/tools.py | 3 +++ test/test_tools.py | 15 +++++++++++++-- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c64c7d..d90ac3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # 1.0.6 +- Fix: removing indentation bug + +# 1.0.6 + - Fix: (#26) bug in note empty line stripping, thanks @Jaschenn for reporting - New: get_references_for_sql table method diff --git a/pydbml/tools.py b/pydbml/tools.py index 75a0dfd..fe08cf1 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -43,6 +43,9 @@ def strip_empty_lines(source: str) -> str: def remove_indentation(source: str) -> str: + if not source: + return source + pattern = re.compile(r'^\s*') lines = source.split('\n') diff --git a/test/test_tools.py b/test/test_tools.py index 5aa714f..57943f8 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -1,7 +1,7 @@ from unittest import TestCase from pydbml.classes import Note -from pydbml.tools import comment_to_dbml +from pydbml.tools import comment_to_dbml, remove_indentation from pydbml.tools import comment_to_sql from pydbml.tools import indent from pydbml.tools import note_option_to_dbml @@ -98,4 +98,15 @@ def test_one_empty_line(self) -> None: def test_end(self) -> None: stripped = ' line1\n\n line2' source = f'\n{stripped}\n ' - self.assertEqual(strip_empty_lines(source), stripped) \ No newline at end of file + self.assertEqual(strip_empty_lines(source), stripped) + + +class TestRemoveIndentation(TestCase): + def test_empty(self) -> None: + source = '' + self.assertEqual(remove_indentation(source), source) + + def test_not_empty(self) -> None: + source = ' line1\n line2' + expected = 'line1\n line2' + self.assertEqual(remove_indentation(source), expected) From 1dc5670667ed87bf4702cfad72cf831a0c1bb5f7 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Dec 2022 12:28:18 +0100 Subject: [PATCH 084/125] bump version --- CHANGELOG.md | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d90ac3f..c68481f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -# 1.0.6 +# 1.0.7 - Fix: removing indentation bug diff --git a/setup.py b/setup.py index b3a7796..be1d83e 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.6', + version='1.0.7', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 0f05fb9c0eaa2798e74170af18c356b1e8dad3c5 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 14 May 2023 13:46:58 +0200 Subject: [PATCH 085/125] allow comments after structures like table --- CHANGELOG.md | 4 ++++ pydbml/definitions/common.py | 4 +++- test/test_definitions/test_table.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c68481f..be0359d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.8 + +- Fix: (#27) allowing comments after Tables, Enums, etc. Thanks @marktaff for reporting + # 1.0.7 - Fix: removing indentation bug diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index 99d3601..fc5ed7a 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -20,7 +20,9 @@ c = comment('comment')[0, 1] n = pp.LineEnd() -end = n | pp.StringEnd() + +end = comment[...].suppress() + n | pp.StringEnd() + # obligatory newline # n = pp.Suppress('\n')[1, ...] diff --git a/test/test_definitions/test_table.py b/test/test_definitions/test_table.py index 9bcfcfa..561f430 100644 --- a/test/test_definitions/test_table.py +++ b/test/test_definitions/test_table.py @@ -172,6 +172,24 @@ def test_with_body_note(self) -> None: self.assertEqual(res[0].note.text, 'bodynote') self.assertEqual(len(res[0].columns), 1) + def test_comment_after(self) -> None: + val = ''' +// some comment before table +table ids as ii [ + headercolor: #ccc, + note: "headernote"] +{ + id integer + note: "bodynote" +} // some somment after table''' + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].comment, 'some comment before table') + self.assertEqual(res[0].name, 'ids') + self.assertEqual(res[0].alias, 'ii') + self.assertEqual(res[0].header_color, '#ccc') + self.assertEqual(res[0].note.text, 'bodynote') + self.assertEqual(len(res[0].columns), 1) + def test_with_indexes(self) -> None: val = ''' table ids as ii [ From f1d978c925bf4f7373e7e6ab150d2ee2cc59582e Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 14 May 2023 13:55:27 +0200 Subject: [PATCH 086/125] bump version, update readme --- README.md | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3fcd617..b73f6aa 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ # DBML parser for Python -*Compliant with DBML **v2.4.4** syntax* +*Compliant with DBML **v2.5.3** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. -> The project was rewritten in May 2022, the new version 1.0.0 is not compatible with the previous ones. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md). +> The project was rewritten in May 2022, the new version 1.0.0 is not compatible with versions 0.x.x. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md). **Docs:** diff --git a/setup.py b/setup.py index be1d83e..864b99d 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.7', + version='1.0.8', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From fb485d19bb937f86ae402633347870350db3f069 Mon Sep 17 00:00:00 2001 From: Ee Durbin Date: Mon, 22 May 2023 18:18:10 -0400 Subject: [PATCH 087/125] Fix enum collision Adds a reference DBML file from the core implementation, I was trying to use this dbml to validate my own project and ran accross the enum collision bug --- pydbml/database.py | 4 +- test/test_data/dbml_schema_def.dbml | 64 +++++++++++++++++++++++++++++ test/test_parser.py | 9 ++++ 3 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 test/test_data/dbml_schema_def.dbml diff --git a/pydbml/database.py b/pydbml/database.py index c4bacb6..910b6c5 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -118,8 +118,8 @@ def add_enum(self, obj: Enum) -> Enum: if obj in self.enums: raise DatabaseValidationError(f'{obj} is already in the database.') for enum in self.enums: - if enum.name == obj.name: - raise DatabaseValidationError(f'Enum {obj.name} is already in the database.') + if enum.name == obj.name and enum.schema == obj.schema: + raise DatabaseValidationError(f'Enum {obj.schema}.{obj.name} is already in the database.') self._set_database(obj) self.enums.append(obj) diff --git a/test/test_data/dbml_schema_def.dbml b/test/test_data/dbml_schema_def.dbml new file mode 100644 index 0000000..a5d509a --- /dev/null +++ b/test/test_data/dbml_schema_def.dbml @@ -0,0 +1,64 @@ +Table "ecommerce"."users" as EU { + id int [pk] + name varchar + ejs job_status + ejs2 public.job_status + eg schemaB.gender + eg2 gender +} + +Table public.users { + id int [pk] + name varchar + pjs job_status + pjs2 public.job_status + pg schemaB.gender + pg2 gender +} + +Table products { + id int [pk] + name varchar +} + +Table schemaA.products as A { + id int [pk] + name varchar [ref: > EU.id] +} + +Table schemaA.locations { + id int [pk] + name varchar [ref: > users.id ] +} + +Ref: "public".users.id < EU.id + +Ref name_optional { + users.name < ecommerce.users.id +} + +TableGroup tablegroup_name { // tablegroup is case-insensitive. + public.products + users + ecommerce.users + A +} + +enum job_status { + created2 [note: 'abcdef'] + running2 + done2 + failure2 +} + +enum schemaB.gender { + man + woman + nonbinary +} + +enum gender { + man2 + woman2 + nonbinary2 +} diff --git a/test/test_parser.py b/test/test_parser.py index a20bf85..d0ff6e1 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -106,6 +106,15 @@ def test_composite_references(self): self.assertEqual(rs[1].col2, [reviews2['post_id'], reviews2['tag']]) +class TestDBMLReferenceDef(TestCase): + def test_dbml_reference_def(self): + results = PyDBML.parse_file(TEST_DATA_PATH / 'dbml_schema_def.dbml') + self.assertEqual(len(results.aliases), 1) + self.assertEqual(len(results.tables), 5) + self.assertEqual(len(results.table_groups), 1) + self.assertEqual(len(results.enums), 3) + + class TestFaulty(TestCase): def test_bad_reference(self) -> None: with self.assertRaises(TableNotFoundError): From 935140bde527137dd4b754ce091aadedd1acc925 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 23 May 2023 19:30:36 +0200 Subject: [PATCH 088/125] fix test --- test/test_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_parser.py b/test/test_parser.py index d0ff6e1..ac26f53 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -109,7 +109,6 @@ def test_composite_references(self): class TestDBMLReferenceDef(TestCase): def test_dbml_reference_def(self): results = PyDBML.parse_file(TEST_DATA_PATH / 'dbml_schema_def.dbml') - self.assertEqual(len(results.aliases), 1) self.assertEqual(len(results.tables), 5) self.assertEqual(len(results.table_groups), 1) self.assertEqual(len(results.enums), 3) From 1aeae9af21fcfad57698fa15337b0fc288b00877 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 23 May 2023 19:31:09 +0200 Subject: [PATCH 089/125] update changelog, bump version --- CHANGELOG.md | 4 ++++ setup.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be0359d..0290b03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.9 + +- Fix: enum collision from different schemas. Thanks @ewdurbin for the contribution + # 1.0.8 - Fix: (#27) allowing comments after Tables, Enums, etc. Thanks @marktaff for reporting diff --git a/setup.py b/setup.py index 864b99d..0aa37c5 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.8', + version='1.0.9', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From fb8b4d739ac2e2cd6155ccd0e48e3734497ee023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20Beraud?= Date: Fri, 26 May 2023 11:33:40 +0200 Subject: [PATCH 090/125] fix class name in docx And also fix typos. --- docs/classes.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/classes.md b/docs/classes.md index 511a6d7..1dd14d4 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -166,7 +166,7 @@ Indexes are stored in the `indexes` attribute of a `Table` object. ## Reference -`Index` class represents a database relation. +`Reference` class represents a database relation. ```python >>> from pydbml import PyDBML @@ -183,7 +183,7 @@ Indexes are stored in the `indexes` attribute of a `Table` object. * `<` — one to many; * `>` — many to one; * `-` — one to one. -* **col1** (list os `Column`) — list of Column objects of the left side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **col1** (list of `Column`) — list of Column objects of the left side of the reference. Changed in **0.4.0**, previously was plain `Column`. * **table1** (`Table` or `None`) — link to the left `Table` object of the reference or `None` of it was not set. * **col2** (list of `Column`) — list of Column objects of the right side of the reference. Changed in **0.4.0**, previously was plain `Column`. * **table2** (`Table` or `None`) — link to the right `Table` object of the reference or `None` of it was not set. From 860f71f1ac2c1601600a65643c4e8964150f4795 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 31 Oct 2023 08:41:05 +0100 Subject: [PATCH 091/125] test: update enum test, readme version compliance --- README.md | 2 +- test/test_data/docs/enum_definition.dbml | 10 ++++++++++ test/test_docs.py | 14 +++++++++----- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index b73f6aa..1fe4415 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # DBML parser for Python -*Compliant with DBML **v2.5.3** syntax* +*Compliant with DBML **v2.6.1** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. diff --git a/test/test_data/docs/enum_definition.dbml b/test/test_data/docs/enum_definition.dbml index aeb07f7..b82579d 100644 --- a/test/test_data/docs/enum_definition.dbml +++ b/test/test_data/docs/enum_definition.dbml @@ -5,8 +5,18 @@ enum job_status { failure } +enum grade { + "A+" + "A" + "A-" + "Not Yet Set" +} + + + Table jobs { id integer status job_status + grade grade } diff --git a/test/test_docs.py b/test/test_docs.py index 8a8b84f..3880303 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -9,8 +9,7 @@ from unittest import TestCase from pydbml import PyDBML -from pydbml.classes import Expression - +from pydbml.classes import Expression, Enum TEST_DOCS_PATH = Path(os.path.abspath(__file__)).parent / 'test_data/docs' TestCase.maxDiff = None @@ -227,14 +226,19 @@ def test_column_notes(self) -> None: def test_enum_definition(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'enum_definition.dbml') jobs = results['public.jobs'] - jobs['status'].type == 'job_status' + self.assertIsInstance(jobs['status'].type, Enum) + self.assertIsInstance(jobs['grade'].type, Enum) - self.assertEqual(len(results.enums), 1) - js = results.enums[0] + self.assertEqual(len(results.enums), 2) + js, g = results.enums self.assertEqual(js.name, 'job_status') self.assertEqual([ei.name for ei in js.items], ['created', 'running', 'done', 'failure']) + self.assertEqual(g.name, 'grade') + self.assertEqual([ei.name for ei in g.items], ['A+', 'A', 'A-', 'Not Yet Set']) + + def test_table_group(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'table_group.dbml') From 71d63f5fc9a575b1e28534a2e245f67819c59c05 Mon Sep 17 00:00:00 2001 From: Tristan Grebot Date: Fri, 15 Mar 2024 12:06:07 +0100 Subject: [PATCH 092/125] fix missing headercolor setting in Table dbml generation --- pydbml/classes/table.py | 2 ++ test/test_classes/test_table.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 99ab6c7..5022725 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -245,6 +245,8 @@ def dbml(self): result += f'Table {name} ' if self.alias: result += f'as "{self.alias}" ' + if self.header_color: + result += f'[headercolor: {self.header_color}] ' result += '{\n' columns_str = '\n'.join(c.dbml for c in self.columns) result += indent(columns_str) + '\n' diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 75fda72..586aa15 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -422,6 +422,24 @@ def test_dbml_simple(self): }''' self.assertEqual(t.dbml, expected) + def test_header_color_dbml(self): + t = Table('products') + t.header_color = '#C84432' + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + s = Database() + s.add(t) + + expected = \ +'''Table "products" [headercolor: #C84432] { + "id" integer + "name" varchar2 +}''' + self.assertEqual(t.dbml, expected) + + def test_schema_dbml(self): t = Table('products', schema="myschema") c1 = Column('id', 'integer') From d276e464a4161ac84c7ba160a544f230b2d6a0e2 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 16 Mar 2024 18:53:55 +0100 Subject: [PATCH 093/125] feat: allow arrays in column type (v3.1.0) --- README.md | 2 +- pydbml/definitions/column.py | 2 +- test/test_definitions/test_column.py | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1fe4415..bec1991 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ Enum "product status" { "In Stock" } -Table "orders" { +Table "orders" [headercolor: #fff] { "id" int [pk, increment] "user_id" int [unique, not null] "status" "orders_status" diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index a9fe0c1..5c5e4b3 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -22,7 +22,7 @@ type_args = ("(" + pp.original_text_for(expression) + ")") # column type is parsed as a single string, it will be split by blueprint -column_type = pp.Combine((name + '.' + name) | ((name) + type_args[0, 1])) +column_type = pp.Combine((name + pp.Literal('[]')) | (name + '.' + name) | ((name) + type_args[0, 1])) default = pp.CaselessLiteral('default:').suppress() + _ - ( string_literal diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index 262699b..b8b6ea1 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -37,6 +37,11 @@ def test_expression(self) -> None: res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) + def test_array(self) -> None: + val = 'int[]' + res = column_type.parse_string(val, parseAll=True) + self.assertEqual(res[0], val) + def test_symbols(self) -> None: val = '(*#^)' with self.assertRaises(ParseException): From a21d32b23b42f71fed5a5ab724f38d6fbb4fcfaa Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sat, 16 Mar 2024 19:28:02 +0100 Subject: [PATCH 094/125] feat: allow double quotes in expression (v3.1.2) --- pydbml/definitions/generic.py | 4 ++-- test/test_definitions/test_generic.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pydbml/definitions/generic.py b/pydbml/definitions/generic.py index b29962f..69c1270 100644 --- a/pydbml/definitions/generic.py +++ b/pydbml/definitions/generic.py @@ -33,8 +33,8 @@ # Expression -expr_chars = pp.Word(pp.alphanums + "'`,._+- \n\t") -expr_chars_no_comma_space = pp.Word(pp.alphanums + "'`._+-") +expr_chars = pp.Word(pp.alphanums + "\"'`,._+- \n\t") +expr_chars_no_comma_space = pp.Word(pp.alphanums + "\"'`._+-") expression = pp.Forward() factor = ( pp.Word(pp.alphanums + '_')[0, 1] + '(' + expression + ')' diff --git a/test/test_definitions/test_generic.py b/test/test_definitions/test_generic.py index 25afe6e..9a78430 100644 --- a/test/test_definitions/test_generic.py +++ b/test/test_definitions/test_generic.py @@ -2,7 +2,7 @@ from pyparsing import ParserElement -from pydbml.definitions.generic import expression_literal +from pydbml.definitions.generic import expression_literal, expression from pydbml.parser.blueprints import ExpressionBlueprint @@ -15,3 +15,10 @@ def test_expression_literal(self) -> None: res = expression_literal.parse_string(val) self.assertIsInstance(res[0], ExpressionBlueprint) self.assertEqual(res[0].text, 'SUM(amount)') + +class TestExpression(TestCase): + def test_comma_separated_expression(self) -> None: + val = 'MAX, 3, "MAX", \'MAX\'' + expected = ['MAX', ',', '3', ',', '"MAX"', ',', "'MAX'"] + res = expression.parse_string(val, parseAll=True) + self.assertEqual(res.asList(), expected) From a6a4e8490effa2a7a78f156b17c7fde33bc4d332 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 17 Mar 2024 07:43:01 +0100 Subject: [PATCH 095/125] feat: fix equality check, don't allow duplicate tables in tablegroup (v3.1.6) --- pydbml/classes/base.py | 15 +++++++++++++-- pydbml/classes/note.py | 4 ++-- pydbml/database.py | 8 ++++---- pydbml/exceptions.py | 4 ++++ pydbml/parser/blueprints.py | 6 +++++- test/test_blueprints/test_table_group.py | 18 ++++++++++++++++++ 6 files changed, 46 insertions(+), 9 deletions(-) diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py index 4a1b844..07c4330 100644 --- a/pydbml/classes/base.py +++ b/pydbml/classes/base.py @@ -33,6 +33,17 @@ def __eq__(self, other: object) -> bool: attributes are equal. """ - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ + if not isinstance(other, self.__class__): + return False + # not comparing those because they are circular references + not_compared_fields = ('parent', 'table', 'database') + + self_dict = dict(self.__dict__) + other_dict = dict(other.__dict__) + + for field in not_compared_fields: + self_dict.pop(field, None) + other_dict.pop(field, None) + + return self_dict == other_dict return False diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index 2482b5a..a3965eb 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -1,5 +1,5 @@ import re -from typing import Any +from typing import Any, Union from .base import SQLObject from pydbml.tools import indent @@ -8,7 +8,7 @@ class Note(SQLObject): - def __init__(self, text: Any): + def __init__(self, text: Union[str, 'Note']) -> None: self.text: str if isinstance(text, Note): self.text = text.text diff --git a/pydbml/database.py b/pydbml/database.py index 910b6c5..f2e8a35 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -83,7 +83,7 @@ def add(self, obj: Any) -> Any: raise DatabaseValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: - if obj in self.tables: + if obj.database == self and obj in self.tables: raise DatabaseValidationError(f'{obj} is already in the database.') if obj.full_name in self.table_dict: raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.') @@ -107,7 +107,7 @@ def add_reference(self, obj: Reference): 'Cannot add reference. At least one of the referenced tables' ' should belong to this database' ) - if obj in self.refs: + if obj.database == self and obj in self.refs: raise DatabaseValidationError(f'{obj} is already in the database.') self._set_database(obj) @@ -115,7 +115,7 @@ def add_reference(self, obj: Reference): return obj def add_enum(self, obj: Enum) -> Enum: - if obj in self.enums: + if obj.database == self and obj in self.enums: raise DatabaseValidationError(f'{obj} is already in the database.') for enum in self.enums: if enum.name == obj.name and enum.schema == obj.schema: @@ -126,7 +126,7 @@ def add_enum(self, obj: Enum) -> Enum: return obj def add_table_group(self, obj: TableGroup) -> TableGroup: - if obj in self.table_groups: + if obj.database == self and obj in self.table_groups: raise DatabaseValidationError(f'{obj} is already in the database.') for table_group in self.table_groups: if table_group.name == obj.name: diff --git a/pydbml/exceptions.py b/pydbml/exceptions.py index 757c434..b5914f7 100644 --- a/pydbml/exceptions.py +++ b/pydbml/exceptions.py @@ -28,3 +28,7 @@ class DBMLError(Exception): class DatabaseValidationError(Exception): pass + + +class ValidationError(Exception): + pass diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index f052977..a4e0ecc 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -19,6 +19,7 @@ from pydbml.classes import TableGroup from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError +from pydbml.exceptions import ValidationError from pydbml.tools import remove_indentation from pydbml.tools import strip_empty_lines @@ -280,7 +281,10 @@ def build(self) -> 'TableGroup': for table_name in self.items: components = table_name.split('.') schema, table = components if len(components) == 2 else ('public', components[0]) - items.append(self.parser.locate_table(schema, table)) + table_obj = self.parser.locate_table(schema, table) + if table_obj in items: + raise ValidationError(f'Table "{table}" is already in group "{self.name}"') + items.append(table_obj) return TableGroup( name=self.name, items=items, diff --git a/test/test_blueprints/test_table_group.py b/test/test_blueprints/test_table_group.py index 48a262f..763ebf8 100644 --- a/test/test_blueprints/test_table_group.py +++ b/test/test_blueprints/test_table_group.py @@ -3,6 +3,7 @@ from pydbml.classes import Table from pydbml.classes import TableGroup +from pydbml.exceptions import ValidationError from pydbml.parser.blueprints import TableGroupBlueprint @@ -51,3 +52,20 @@ def test_build_with_schema(self) -> None: self.assertEqual(locate_table_calls[1].args, ('myschema', 'table2')) for i in result.items: self.assertIsInstance(i, Table) + + def test_duplicate_table(self) -> None: + bp = TableGroupBlueprint( + name='TestTableGroup', + items=['table1', 'table2', 'table1'], + comment='Comment text' + ) + + parserMock = Mock() + parserMock.locate_table.side_effect = [ + Table(name='table1'), + Table(name='table2'), + Table(name='table1') + ] + bp.parser = parserMock + with self.assertRaises(ValidationError): + bp.build() From 048b3480e198da933c696aec41bab49ce97e0568 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 17 Mar 2024 08:02:24 +0100 Subject: [PATCH 096/125] feat: fix equality check again, don't allow duplicate refs (v3.1.6) --- pydbml/classes/base.py | 5 ++--- pydbml/classes/column.py | 8 ++++++++ pydbml/classes/index.py | 1 + pydbml/classes/note.py | 1 + pydbml/classes/project.py | 2 ++ pydbml/classes/reference.py | 1 + pydbml/classes/table.py | 1 + pydbml/classes/table_group.py | 1 + pydbml/database.py | 8 ++++---- 9 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py index 07c4330..ec8bfd9 100644 --- a/pydbml/classes/base.py +++ b/pydbml/classes/base.py @@ -9,6 +9,7 @@ class SQLObject: Base class for all SQL objects. ''' required_attributes: Tuple[str, ...] = () + dont_compare_fields = () def check_attributes_for_sql(self): ''' @@ -36,14 +37,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False # not comparing those because they are circular references - not_compared_fields = ('parent', 'table', 'database') self_dict = dict(self.__dict__) other_dict = dict(other.__dict__) - for field in not_compared_fields: + for field in self.dont_compare_fields: self_dict.pop(field, None) other_dict.pop(field, None) return self_dict == other_dict - return False diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 4743581..d599c2a 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -21,6 +21,7 @@ class Column(SQLObject): '''Class representing table column.''' required_attributes = ('name', 'type') + dont_compare_fields = ('table',) def __init__(self, name: str, @@ -45,6 +46,13 @@ def __init__(self, self.default = default self.table: Optional['Table'] = None + def __eq__(self, other: 'Column') -> bool: + self_table = self.table.full_name if self.table else None + other_table = other.table.full_name if other.table else None + if self_table != other_table: + return False + return super().__eq__(other) + @property def note(self): return self._note diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index b51a4b4..a3d771a 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -19,6 +19,7 @@ class Index(SQLObject): '''Class representing index.''' required_attributes = ('subjects', 'table') + dont_compare_fields = ('table',) def __init__(self, subjects: List[Union[str, 'Column', 'Expression']], diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index a3965eb..eee65cf 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -7,6 +7,7 @@ class Note(SQLObject): + dont_compare_fields = ('parent',) def __init__(self, text: Union[str, 'Note']) -> None: self.text: str diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py index 3133a85..6069fc8 100644 --- a/pydbml/classes/project.py +++ b/pydbml/classes/project.py @@ -8,6 +8,8 @@ class Project: + dont_compare_fields = ('database',) + def __init__(self, name: str, items: Optional[Dict[str, str]] = None, diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index f093016..cb37e3b 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -25,6 +25,7 @@ class Reference(SQLObject): and its `sql` property contains the ALTER TABLE clause. ''' required_attributes = ('type', 'col1', 'col2') + dont_compare_fields = ('database', '_inline') def __init__(self, type: Literal['>', '<', '-', '<>'], diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 5022725..f493e9c 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -28,6 +28,7 @@ class Table(SQLObject): '''Class representing table.''' required_attributes = ('name', 'schema') + dont_compare_fields = ('database',) def __init__(self, name: str, diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index 3386fa0..1f38978 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -11,6 +11,7 @@ class TableGroup: but after parsing the whole document, PyDBMLParseResults class replaces them with references to actual tables. ''' + dont_compare_fields = ('database',) def __init__(self, name: str, diff --git a/pydbml/database.py b/pydbml/database.py index f2e8a35..910b6c5 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -83,7 +83,7 @@ def add(self, obj: Any) -> Any: raise DatabaseValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: - if obj.database == self and obj in self.tables: + if obj in self.tables: raise DatabaseValidationError(f'{obj} is already in the database.') if obj.full_name in self.table_dict: raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.') @@ -107,7 +107,7 @@ def add_reference(self, obj: Reference): 'Cannot add reference. At least one of the referenced tables' ' should belong to this database' ) - if obj.database == self and obj in self.refs: + if obj in self.refs: raise DatabaseValidationError(f'{obj} is already in the database.') self._set_database(obj) @@ -115,7 +115,7 @@ def add_reference(self, obj: Reference): return obj def add_enum(self, obj: Enum) -> Enum: - if obj.database == self and obj in self.enums: + if obj in self.enums: raise DatabaseValidationError(f'{obj} is already in the database.') for enum in self.enums: if enum.name == obj.name and enum.schema == obj.schema: @@ -126,7 +126,7 @@ def add_enum(self, obj: Enum) -> Enum: return obj def add_table_group(self, obj: TableGroup) -> TableGroup: - if obj.database == self and obj in self.table_groups: + if obj in self.table_groups: raise DatabaseValidationError(f'{obj} is already in the database.') for table_group in self.table_groups: if table_group.name == obj.name: From f945dc8e9ebe4f7991f7e4a25b83821a3c66197a Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 17 Mar 2024 09:21:43 +0100 Subject: [PATCH 097/125] feat: add sticky notes (v3.2.0) --- LICENSE | 2 +- TODO.md | 1 - docs/classes.md | 13 ++++++ pydbml/classes/base.py | 2 +- pydbml/classes/column.py | 4 +- pydbml/classes/note.py | 7 +-- pydbml/classes/sticky_note.py | 50 +++++++++++++++++++++ pydbml/database.py | 13 +++++- pydbml/definitions/sticky_note.py | 21 +++++++++ pydbml/parser/blueprints.py | 18 ++++++++ pydbml/parser/parser.py | 21 ++++++--- test/test_blueprints/test_sticky_note.py | 53 +++++++++++++++++++++++ test/test_classes/test_sticky_note.py | 46 ++++++++++++++++++++ test/test_data/docs/sticky_notes.dbml | 10 +++++ test/test_definitions/test_sticky_note.py | 35 +++++++++++++++ test/test_docs.py | 13 +++++- test_schema.dbml | 13 +++++- 17 files changed, 303 insertions(+), 19 deletions(-) delete mode 100644 TODO.md create mode 100644 pydbml/classes/sticky_note.py create mode 100644 pydbml/definitions/sticky_note.py create mode 100644 test/test_blueprints/test_sticky_note.py create mode 100644 test/test_classes/test_sticky_note.py create mode 100644 test/test_data/docs/sticky_notes.dbml create mode 100644 test/test_definitions/test_sticky_note.py diff --git a/LICENSE b/LICENSE index 4d20996..c13f991 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 +Copyright (c) 2024 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 3fde65e..0000000 --- a/TODO.md +++ /dev/null @@ -1 +0,0 @@ -- schema.add and .delete to support multiple arguments (handle errors properly) diff --git a/docs/classes.md b/docs/classes.md index 1dd14d4..276c834 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -6,6 +6,7 @@ * [Reference](#reference) * [Enum](#enum) * [Note](#note) +* [StickyNote](#sticky_note) * [Expression](#expression) * [Project](#project) * [TableGroup](#tablegroup) @@ -260,6 +261,18 @@ Note is a basic class, which may appear in some other classes' `note` attribute. * **sql** (str) — SQL definition for this note. * **dbml** (str) — DBML definition for this note. +## Note + +**new in PyDBML 1.0.10** + +Sticky notes are similar to regular notes, except that they are defined at the root of your DBML file and have a name. + +### Attributes + +**name** (str) — note name. +**text** (str) — note text. +* **dbml** (str) — DBML definition for this note. + ## Expression **new in PyDBML 1.0.0** diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py index ec8bfd9..c19ddec 100644 --- a/pydbml/classes/base.py +++ b/pydbml/classes/base.py @@ -9,7 +9,7 @@ class SQLObject: Base class for all SQL objects. ''' required_attributes: Tuple[str, ...] = () - dont_compare_fields = () + dont_compare_fields: Tuple[str, ...] = () def check_attributes_for_sql(self): ''' diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index d599c2a..c95f3e2 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -46,7 +46,9 @@ def __init__(self, self.default = default self.table: Optional['Table'] = None - def __eq__(self, other: 'Column') -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False self_table = self.table.full_name if self.table else None other_table = other.table.full_name if other.table else None if self_table != other_table: diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index eee65cf..5dd4dcf 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -9,12 +9,9 @@ class Note(SQLObject): dont_compare_fields = ('parent',) - def __init__(self, text: Union[str, 'Note']) -> None: + def __init__(self, text: Any) -> None: self.text: str - if isinstance(text, Note): - self.text = text.text - else: - self.text = str(text) if text else '' + self.text = str(text) if text is not None else '' self.parent: Any = None def __str__(self): diff --git a/pydbml/classes/sticky_note.py b/pydbml/classes/sticky_note.py new file mode 100644 index 0000000..b843cca --- /dev/null +++ b/pydbml/classes/sticky_note.py @@ -0,0 +1,50 @@ +import re +from typing import Any + +from pydbml.tools import indent + + +class StickyNote: + dont_compare_fields = ('database',) + + def __init__(self, name: str, text: Any) -> None: + self.name = name + self.text = str(text) if text is not None else '' + + self.database = None + + def __str__(self): + ''' + >>> print(StickyNote('mynote', 'Note text')) + StickyNote('mynote', 'Note text') + ''' + + return self.__class__.__name__ + f'({repr(self.name)}, {repr(self.text)})' + + def __bool__(self): + return bool(self.text) + + def __repr__(self): + ''' + >>> StickyNote('mynote', 'Note text') + + ''' + + return f'<{self.__class__.__name__} {self.name!r}, {self.text!r}>' + + def _prepare_text_for_dbml(self): + '''Escape single quotes''' + pattern = re.compile(r"('''|')") + return pattern.sub(r'\\\1', self.text) + + @property + def dbml(self): + text = self._prepare_text_for_dbml() + if '\n' in text: + note_text = f"'''\n{text}\n'''" + else: + note_text = f"'{text}'" + + note_text = indent(note_text) + result = f'Note {self.name} {{\n{note_text}\n}}' + return result diff --git a/pydbml/database.py b/pydbml/database.py index 910b6c5..1347909 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -4,11 +4,12 @@ from typing import Optional from typing import Union -from .classes import Enum +from .classes import Enum, Note from .classes import Project from .classes import Reference from .classes import Table from .classes import TableGroup +from .classes.sticky_note import StickyNote from .exceptions import DatabaseValidationError from .constants import MANY_TO_ONE, ONE_TO_MANY @@ -41,6 +42,7 @@ def __init__(self) -> None: self.refs: List['Reference'] = [] self.enums: List['Enum'] = [] self.table_groups: List['TableGroup'] = [] + self.sticky_notes: List['StickyNote'] = [] self.project: Optional['Project'] = None def __repr__(self) -> str: @@ -79,6 +81,8 @@ def add(self, obj: Any) -> Any: return self.add_table_group(obj) elif isinstance(obj, Project): return self.add_project(obj) + elif isinstance(obj, StickyNote): + return self.add_sticky_note(obj) else: raise DatabaseValidationError(f'Unsupported type {type(obj)}.') @@ -125,6 +129,11 @@ def add_enum(self, obj: Enum) -> Enum: self.enums.append(obj) return obj + def add_sticky_note(self, obj: StickyNote) -> StickyNote: + self._set_database(obj) + self.sticky_notes.append(obj) + return obj + def add_table_group(self, obj: TableGroup) -> TableGroup: if obj in self.table_groups: raise DatabaseValidationError(f'{obj} is already in the database.') @@ -216,7 +225,7 @@ def dbml(self): '''Generates DBML code out of parsed results''' items = [self.project] if self.project else [] refs = (ref for ref in self.refs if not ref.inline) - items.extend((*self.enums, *self.tables, *refs, *self.table_groups)) + items.extend((*self.enums, *self.tables, *refs, *self.table_groups, *self.sticky_notes)) components = ( i.dbml for i in items ) diff --git a/pydbml/definitions/sticky_note.py b/pydbml/definitions/sticky_note.py new file mode 100644 index 0000000..ebd1848 --- /dev/null +++ b/pydbml/definitions/sticky_note.py @@ -0,0 +1,21 @@ +import pyparsing as pp + +from .common import _, end, _c +from .generic import string_literal, name +from ..parser.blueprints import StickyNoteBlueprint + +sticky_note = _c + pp.CaselessLiteral('note') + _ + (name('name') + _ - '{' + _ - string_literal('text') + _ - '}') + end + + +def parse_sticky_note(s, loc, tok): + ''' + Note single_line_note { + 'This is a single line note' + } + ''' + init_dict = {'name': tok['name'], 'text': tok['text']} + + return StickyNoteBlueprint(**init_dict) + + +sticky_note.set_parse_action(parse_sticky_note) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index a4e0ecc..8ceea10 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -17,6 +17,7 @@ from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup +from pydbml.classes.sticky_note import StickyNote from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError from pydbml.exceptions import ValidationError @@ -43,6 +44,23 @@ def build(self) -> 'Note': return Note(text) +@dataclass +class StickyNoteBlueprint(Blueprint): + name: str + text: str + + def _preformat_text(self) -> str: + '''Preformat the note text for idempotence''' + result = strip_empty_lines(self.text) + result = remove_indentation(result) + return result + + def build(self) -> StickyNote: + text = self._preformat_text() + name = self.name + return StickyNote(name=name, text=text) + + @dataclass class ExpressionBlueprint(Blueprint): text: str diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 9fd0531..0f3dabb 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -1,4 +1,5 @@ from __future__ import annotations + from io import TextIOWrapper from pathlib import Path from typing import List @@ -7,22 +8,22 @@ import pyparsing as pp -from .blueprints import EnumBlueprint -from .blueprints import ProjectBlueprint -from .blueprints import ReferenceBlueprint -from .blueprints import TableBlueprint -from .blueprints import TableGroupBlueprint from pydbml.classes import Table from pydbml.database import Database from pydbml.definitions.common import comment from pydbml.definitions.enum import enum from pydbml.definitions.project import project from pydbml.definitions.reference import ref +from pydbml.definitions.sticky_note import sticky_note from pydbml.definitions.table import table from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError from pydbml.tools import remove_bom - +from .blueprints import EnumBlueprint, StickyNoteBlueprint +from .blueprints import ProjectBlueprint +from .blueprints import ReferenceBlueprint +from .blueprints import TableBlueprint +from .blueprints import TableGroupBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') @@ -99,6 +100,7 @@ def __init__(self, source: str): self.refs: List[ReferenceBlueprint] = [] self.enums: List[EnumBlueprint] = [] self.project: Optional[ProjectBlueprint] = None + self.sticky_notes: List[StickyNoteBlueprint] = [] def parse(self): self._set_syntax() @@ -120,12 +122,14 @@ def _set_syntax(self): enum_expr = enum.copy() table_group_expr = table_group.copy() project_expr = project.copy() + note_expr = sticky_note.copy() table_expr.addParseAction(self.parse_blueprint) ref_expr.addParseAction(self.parse_blueprint) enum_expr.addParseAction(self.parse_blueprint) table_group_expr.addParseAction(self.parse_blueprint) project_expr.addParseAction(self.parse_blueprint) + note_expr.addParseAction(self.parse_blueprint) expr = ( table_expr @@ -133,6 +137,7 @@ def _set_syntax(self): | enum_expr | table_group_expr | project_expr + | note_expr ) self._syntax = expr[...] + ('\n' | comment)[...] + pp.StringEnd() @@ -169,6 +174,8 @@ def parse_blueprint(self, s, loc, tok): self.project = blueprint if blueprint.note: blueprint.note.parser = self + elif isinstance(blueprint, StickyNoteBlueprint): + self.sticky_notes.append(blueprint) else: raise RuntimeError(f'type unknown: {blueprint}') blueprint.parser = self @@ -194,6 +201,8 @@ def build_database(self): self.ref_blueprints.extend(table_bp.get_reference_blueprints()) for table_group_bp in self.table_groups: self.database.add(table_group_bp.build()) + for note_bp in self.sticky_notes: + self.database.add(note_bp.build()) if self.project: self.database.add(self.project.build()) for ref_bp in self.refs: diff --git a/test/test_blueprints/test_sticky_note.py b/test/test_blueprints/test_sticky_note.py new file mode 100644 index 0000000..5cdcd9c --- /dev/null +++ b/test/test_blueprints/test_sticky_note.py @@ -0,0 +1,53 @@ +from unittest import TestCase + +from pydbml.classes.sticky_note import StickyNote +from pydbml.parser.blueprints import StickyNoteBlueprint + +class TestNote(TestCase): + def test_build(self) -> None: + bp = StickyNoteBlueprint(name='mynote', text='Note text') + result = bp.build() + self.assertIsInstance(result, StickyNote) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.text, bp.text) + + def test_preformat_not_needed(self): + oneline = 'One line of note text' + multiline = 'Multiline\nnote\n\ntext' + long_line = 'Lorem ipsum dolor sit amet consectetur adipisicing elit. Aspernatur quidem adipisci, impedit, ut illum dolorum consequatur odio voluptate numquam ea itaque excepturi, a libero placeat corrupti. Amet beatae suscipit necessitatibus. Ea expedita explicabo iste quae rem aliquam minus cumque eveniet enim delectus, alias aut impedit quaerat quia ex, aliquid sint amet iusto rerum! Sunt deserunt ea saepe corrupti officiis. Assumenda.' + + bp = StickyNoteBlueprint(name='mynote', text=oneline) + self.assertEqual(bp.name, bp.name) + self.assertEqual(bp._preformat_text(), oneline) + bp = StickyNoteBlueprint(name='mynote', text=multiline) + self.assertEqual(bp.name, bp.name) + self.assertEqual(bp._preformat_text(), multiline) + bp = StickyNoteBlueprint(name='mynote', text=long_line) + self.assertEqual(bp.name, bp.name) + self.assertEqual(bp._preformat_text(), long_line) + + def test_preformat_needed(self): + uniform_indentation = ' line1\n line2\n line3' + varied_indentation = ' line1\n line2\n\n line3' + empty_lines = '\n\n\n\n\n\n\nline1\nline2\nline3\n\n\n\n\n\n\n' + empty_indented_lines = '\n \n\n \n\n line1\n line2\n line3\n\n\n\n \n\n\n' + + exptected = 'line1\nline2\nline3' + bp = StickyNoteBlueprint(name='mynote', text=uniform_indentation) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) + + exptected = 'line1\n line2\n\n line3' + bp = StickyNoteBlueprint(name='mynote', text=varied_indentation) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) + + exptected = 'line1\nline2\nline3' + bp = StickyNoteBlueprint(name='mynote', text=empty_lines) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) + + exptected = 'line1\nline2\nline3' + bp = StickyNoteBlueprint(name='mynote', text=empty_indented_lines) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) diff --git a/test/test_classes/test_sticky_note.py b/test/test_classes/test_sticky_note.py new file mode 100644 index 0000000..1727e83 --- /dev/null +++ b/test/test_classes/test_sticky_note.py @@ -0,0 +1,46 @@ +from pydbml.classes import Table +from pydbml.classes import Index +from pydbml.classes import Column +from unittest import TestCase + +from pydbml.classes.sticky_note import StickyNote + + +class TestNote(TestCase): + def test_init_types(self): + n1 = StickyNote('mynote', 'My note text') + n2 = StickyNote('mynote', 3) + n3 = StickyNote('mynote', [1, 2, 3]) + n4 = StickyNote('mynote', None) + + self.assertEqual(n1.text, 'My note text') + self.assertEqual(n2.text, '3') + self.assertEqual(n3.text, '[1, 2, 3]') + self.assertEqual(n4.text, '') + self.assertTrue(n1.name == n2.name == n3.name == n4.name == 'mynote') + + def test_oneline(self): + note = StickyNote('mynote', 'One line of note text') + expected = \ +'''Note mynote { + 'One line of note text' +}''' + self.assertEqual(note.dbml, expected) + + def test_forced_multiline(self): + note = StickyNote('mynote', 'The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.') + expected = \ +"""Note mynote { + ''' + The number of spaces you use to indent a block string + will + be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output. + ''' +}""" + self.assertEqual(note.dbml, expected) + + def test_prepare_text_for_dbml(self): + quotes = "'asd' There's ''' asda '''' asd ''''' asdsa ''" + expected = "\\'asd\\' There\\'s \\''' asda \\'''\\' asd \\'''\\'\\' asdsa \\'\\'" + note = StickyNote('mynote', quotes) + self.assertEqual(note._prepare_text_for_dbml(), expected) diff --git a/test/test_data/docs/sticky_notes.dbml b/test/test_data/docs/sticky_notes.dbml new file mode 100644 index 0000000..7f03cf7 --- /dev/null +++ b/test/test_data/docs/sticky_notes.dbml @@ -0,0 +1,10 @@ +Note single_line_note { + 'This is a single line note' +} + +Note multiple_lines_note { +''' + This is a multiple lines note + This string can spans over multiple lines. +''' +} diff --git a/test/test_definitions/test_sticky_note.py b/test/test_definitions/test_sticky_note.py new file mode 100644 index 0000000..2870b9d --- /dev/null +++ b/test/test_definitions/test_sticky_note.py @@ -0,0 +1,35 @@ +from unittest import TestCase + +from pyparsing import ParseSyntaxException + +from pydbml.definitions.sticky_note import sticky_note + + +class TestSticky(TestCase): + def test_single_quote(self) -> None: + val = "note mynote {'test note'}" + res = sticky_note.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, 'mynote') + self.assertEqual(res[0].text, 'test note') + + def test_double_quote(self) -> None: + val = 'note \n\nmynote\n\n {\n\n"test note"\n\n}' + res = sticky_note.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, 'mynote') + self.assertEqual(res[0].text, 'test note') + + def test_multiline(self) -> None: + val = "note\nmynote\n{ '''line1\nline2\nline3'''}" + res = sticky_note.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, 'mynote') + self.assertEqual(res[0].text, 'line1\nline2\nline3') + + def test_unclosed_quote(self) -> None: + val = 'note mynote{ "test note}' + with self.assertRaises(ParseSyntaxException): + sticky_note.parse_string(val, parseAll=True) + + def test_not_allowed_multiline(self) -> None: + val = "note mynote { 'line1\nline2\nline3' }" + with self.assertRaises(ParseSyntaxException): + sticky_note.parse_string(val, parseAll=True) diff --git a/test/test_docs.py b/test/test_docs.py index 3880303..c5025bd 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -238,7 +238,6 @@ def test_enum_definition(self) -> None: self.assertEqual(g.name, 'grade') self.assertEqual([ei.name for ei in g.items], ['A+', 'A', 'A-', 'Not Yet Set']) - def test_table_group(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'table_group.dbml') @@ -253,3 +252,15 @@ def test_table_group(self) -> None: self.assertEqual(tg1.items, [tb1, tb2, tb3]) self.assertEqual(tg2.name, 'e_commerce1') self.assertEqual(tg2.items, [merchants, countries]) + + def test_sticky_notes(self) -> None: + results = PyDBML.parse_file(TEST_DOCS_PATH / 'sticky_notes.dbml') + + self.assertEqual(len(results.sticky_notes), 2) + + sn1, sn2 = results.sticky_notes + + self.assertEqual(sn1.name, 'single_line_note') + self.assertEqual(sn1.text, 'This is a single line note') + self.assertEqual(sn2.name, 'multiple_lines_note') + self.assertEqual(sn2.text, '''This is a multiple lines note\nThis string can spans over multiple lines.''') diff --git a/test_schema.dbml b/test_schema.dbml index 2739c7d..432c1ca 100644 --- a/test_schema.dbml +++ b/test_schema.dbml @@ -74,7 +74,7 @@ Table "merchants" { } -Ref:"products"."id" < "order_items"."product_id" +Ref:"products"."id" < "order_items"."product_id" [update: set default, delete: set null] Ref:"countries"."code" < "users"."country_code" @@ -89,3 +89,14 @@ Table "countries" { "name" varchar "continent_name" varchar } + +Note sticky_note1 { + 'One line note' +} + +Note sticky_note2 { + ''' + # Title + body + ''' +} From 4f21e39b72b3ca2c0a835221f68f8004d09c7986 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 17 Mar 2024 09:26:28 +0100 Subject: [PATCH 098/125] chore: bump version, update readme and changelog --- CHANGELOG.md | 8 ++++++++ README.md | 2 +- setup.py | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0290b03..c47a577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# 1.0.10 +- New: Sticky notes syntax (DBML v3.2.0) +- Fix: Table header color was not rendered in `dbml()` (thanks @tristangrebot for the contribution) +- New: allow array column types (DBML v3.1.0) +- New: allow double quotes in expressions (DBML v3.1.2) +- Fix: recursion in object equality check +- New: don't allow duplicate refs even if they have different inline method (DBML v3.1.6) + # 1.0.9 - Fix: enum collision from different schemas. Thanks @ewdurbin for the contribution diff --git a/README.md b/README.md index bec1991..2f1e426 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # DBML parser for Python -*Compliant with DBML **v2.6.1** syntax* +*Compliant with DBML **v3.2.0** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. diff --git a/setup.py b/setup.py index 0aa37c5..bc65cb8 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.9', + version='1.0.10', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 77e24d463aae92dbbaf976d6a54b1b2b96082c61 Mon Sep 17 00:00:00 2001 From: Pierre Souchay Date: Tue, 23 Apr 2024 18:09:03 +0200 Subject: [PATCH 099/125] fix(indexes): let pk being named (#35) * fix(indexes): let pk being named Ref: https://github.com/holistics/dbml/pull/549 At least postgresql let pk being named, handle that properly * tests(indexes): allow pk to have names and notes, updated the test case --- pydbml/definitions/index.py | 4 ++-- test/test_data/relationships_aliases.dbml | 8 ++++++++ test/test_definitions/test_index.py | 8 -------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index f08090b..5c7e096 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -22,10 +22,10 @@ | index_type | pp.CaselessLiteral("name:") + _ - string_literal('name') | note('note') + | pk('pk') ) index_settings = ( - '[' + _ + pk('pk') + _ - ']' + c - | '[' + _ + index_setting + (_ + ',' + _ - index_setting)[...] + _ - ']' + c + '[' + _ + index_setting + (_ + ',' + _ - index_setting)[...] + _ - ']' + c ) diff --git a/test/test_data/relationships_aliases.dbml b/test/test_data/relationships_aliases.dbml index 4e7265b..7e8589d 100644 --- a/test/test_data/relationships_aliases.dbml +++ b/test/test_data/relationships_aliases.dbml @@ -29,3 +29,11 @@ Table reviews2 as re2 { Table users2 as us2 { id integer } + +Table "alembic_version" { + "version_num" "character varying(32)" [not null] + +Indexes { + version_num [pk, name: "alembic_version_pk"] +} +} \ No newline at end of file diff --git a/test/test_definitions/test_index.py b/test/test_definitions/test_index.py index 00a7ee8..4396998 100644 --- a/test/test_definitions/test_index.py +++ b/test/test_definitions/test_index.py @@ -80,14 +80,6 @@ def test_pk(self) -> None: res = index_settings.parse_string(val, parseAll=True) self.assertTrue(res[0]['pk']) - def test_wrong_pk(self) -> None: - val = '[pk, name: "not allowed"]' - with self.assertRaises(ParseSyntaxException): - index_settings.parse_string(val, parseAll=True) - val2 = '[note: "pk not allowed", pk]' - with self.assertRaises(ParseSyntaxException): - index_settings.parse_string(val2, parseAll=True) - def test_all(self) -> None: val = '[type: hash, name: "index name", note: "index note", unique]' res = index_settings.parse_string(val, parseAll=True) From 6e0af331cfd039642567b6ee6c8db36a50995c56 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 23 Apr 2024 18:11:22 +0200 Subject: [PATCH 100/125] bump version --- CHANGELOG.md | 3 +++ setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c47a577..61c8825 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.0.11 +- Fix: allow pk in named indexes (thanks @pierresouchay for the contribution) + # 1.0.10 - New: Sticky notes syntax (DBML v3.2.0) - Fix: Table header color was not rendered in `dbml()` (thanks @tristangrebot for the contribution) diff --git a/setup.py b/setup.py index bc65cb8..634e0b7 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.10', + version='1.0.11', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 90c63c83190e5e104dfbb11caa932b935fd84068 Mon Sep 17 00:00:00 2001 From: Daniel Minukhin Date: Thu, 18 Jul 2024 20:50:56 +0200 Subject: [PATCH 101/125] Rewrite SQL and DBML rendering (#39) --- pydbml/__init__.py | 5 +- pydbml/_classes/__init__.py | 0 pydbml/{classes => _classes}/base.py | 22 + pydbml/_classes/column.py | 92 ++++ pydbml/{classes => _classes}/enum.py | 75 +--- pydbml/{classes => _classes}/expression.py | 12 +- pydbml/_classes/index.py | 70 +++ pydbml/_classes/note.py | 23 + pydbml/_classes/project.py | 34 ++ pydbml/_classes/reference.py | 120 +++++ pydbml/_classes/sticky_note.py | 24 + pydbml/{classes => _classes}/table.py | 124 +---- pydbml/{classes => _classes}/table_group.py | 16 +- pydbml/classes/__init__.py | 35 +- pydbml/classes/column.py | 164 ------- pydbml/classes/index.py | 170 ------- pydbml/classes/note.py | 87 ---- pydbml/classes/project.py | 56 --- pydbml/classes/reference.py | 245 ---------- pydbml/classes/sticky_note.py | 50 --- pydbml/database.py | 57 +-- pydbml/parser/blueprints.py | 2 +- pydbml/parser/parser.py | 10 - pydbml/renderer/__init__.py | 0 pydbml/renderer/base.py | 38 ++ pydbml/renderer/dbml/__init__.py | 0 pydbml/renderer/dbml/default/__init__.py | 10 + pydbml/renderer/dbml/default/column.py | 51 +++ pydbml/renderer/dbml/default/enum.py | 25 ++ pydbml/renderer/dbml/default/expression.py | 7 + pydbml/renderer/dbml/default/index.py | 49 ++ pydbml/renderer/dbml/default/note.py | 24 + pydbml/renderer/dbml/default/project.py | 27 ++ pydbml/renderer/dbml/default/reference.py | 68 +++ pydbml/renderer/dbml/default/renderer.py | 19 + pydbml/renderer/dbml/default/sticky_note.py | 18 + pydbml/renderer/dbml/default/table.py | 51 +++ pydbml/renderer/dbml/default/table_group.py | 14 + pydbml/renderer/dbml/default/utils.py | 13 + pydbml/renderer/sql/__init__.py | 0 pydbml/renderer/sql/default/__init__.py | 8 + pydbml/renderer/sql/default/column.py | 39 ++ pydbml/renderer/sql/default/enum.py | 33 ++ pydbml/renderer/sql/default/expression.py | 7 + pydbml/renderer/sql/default/index.py | 65 +++ pydbml/renderer/sql/default/note.py | 43 ++ pydbml/renderer/sql/default/reference.py | 82 ++++ pydbml/renderer/sql/default/renderer.py | 24 + pydbml/renderer/sql/default/table.py | 97 ++++ pydbml/renderer/sql/default/utils.py | 37 ++ pydbml/tools.py | 17 +- setup.py | 11 +- test/conftest.py | 137 ++++++ test/test_blueprints/test_sticky_note.py | 2 +- test/test_classes/test_base.py | 2 +- test/test_classes/test_column.py | 223 ++------- test/test_classes/test_enum.py | 120 +---- test/test_classes/test_expression.py | 14 +- test/test_classes/test_index.py | 140 +----- test/test_classes/test_note.py | 111 +---- test/test_classes/test_project.py | 59 +-- test/test_classes/test_reference.py | 424 +----------------- test/test_classes/test_sticky_note.py | 74 ++- test/test_classes/test_table.py | 380 +--------------- test/test_classes/test_table_group.py | 46 -- test/test_database.py | 34 +- test/test_doctest.py | 18 +- test/test_parser.py | 9 + test/test_renderer/__init__.py | 0 test/test_renderer/test_base.py | 40 ++ test/test_renderer/test_dbml/__init__.py | 0 test/test_renderer/test_dbml/test_column.py | 120 +++++ test/test_renderer/test_dbml/test_enum.py | 36 ++ .../test_dbml/test_expression.py | 6 + test/test_renderer/test_dbml/test_index.py | 92 ++++ test/test_renderer/test_dbml/test_note.py | 22 + test/test_renderer/test_dbml/test_project.py | 46 ++ .../test_renderer/test_dbml/test_reference.py | 123 +++++ test/test_renderer/test_dbml/test_renderer.py | 20 + .../test_dbml/test_sticky_note.py | 17 + test/test_renderer/test_dbml/test_table.py | 81 ++++ .../test_dbml/test_table_group.py | 18 + test/test_renderer/test_dbml/test_utils.py | 20 + test/test_renderer/test_sql/__init__.py | 0 .../test_sql/test_default/__init__.py | 0 .../test_sql/test_default/test_column.py | 19 + .../test_sql/test_default/test_enum.py | 39 ++ .../test_sql/test_default/test_expression.py | 6 + .../test_sql/test_default/test_index.py | 82 ++++ .../test_sql/test_default/test_note.py | 45 ++ .../test_sql/test_default/test_reference.py | 153 +++++++ .../test_sql/test_default/test_renderer.py | 29 ++ .../test_sql/test_default/test_table.py | 215 +++++++++ .../test_sql/test_default/test_utils.py | 50 +++ test/test_tools.py | 6 +- 95 files changed, 2895 insertions(+), 2553 deletions(-) create mode 100644 pydbml/_classes/__init__.py rename pydbml/{classes => _classes}/base.py (66%) create mode 100644 pydbml/_classes/column.py rename pydbml/{classes => _classes}/enum.py (51%) rename pydbml/{classes => _classes}/expression.py (65%) create mode 100644 pydbml/_classes/index.py create mode 100644 pydbml/_classes/note.py create mode 100644 pydbml/_classes/project.py create mode 100644 pydbml/_classes/reference.py create mode 100644 pydbml/_classes/sticky_note.py rename pydbml/{classes => _classes}/table.py (56%) rename pydbml/{classes => _classes}/table_group.py (75%) delete mode 100644 pydbml/classes/column.py delete mode 100644 pydbml/classes/index.py delete mode 100644 pydbml/classes/note.py delete mode 100644 pydbml/classes/project.py delete mode 100644 pydbml/classes/reference.py delete mode 100644 pydbml/classes/sticky_note.py create mode 100644 pydbml/renderer/__init__.py create mode 100644 pydbml/renderer/base.py create mode 100644 pydbml/renderer/dbml/__init__.py create mode 100644 pydbml/renderer/dbml/default/__init__.py create mode 100644 pydbml/renderer/dbml/default/column.py create mode 100644 pydbml/renderer/dbml/default/enum.py create mode 100644 pydbml/renderer/dbml/default/expression.py create mode 100644 pydbml/renderer/dbml/default/index.py create mode 100644 pydbml/renderer/dbml/default/note.py create mode 100644 pydbml/renderer/dbml/default/project.py create mode 100644 pydbml/renderer/dbml/default/reference.py create mode 100644 pydbml/renderer/dbml/default/renderer.py create mode 100644 pydbml/renderer/dbml/default/sticky_note.py create mode 100644 pydbml/renderer/dbml/default/table.py create mode 100644 pydbml/renderer/dbml/default/table_group.py create mode 100644 pydbml/renderer/dbml/default/utils.py create mode 100644 pydbml/renderer/sql/__init__.py create mode 100644 pydbml/renderer/sql/default/__init__.py create mode 100644 pydbml/renderer/sql/default/column.py create mode 100644 pydbml/renderer/sql/default/enum.py create mode 100644 pydbml/renderer/sql/default/expression.py create mode 100644 pydbml/renderer/sql/default/index.py create mode 100644 pydbml/renderer/sql/default/note.py create mode 100644 pydbml/renderer/sql/default/reference.py create mode 100644 pydbml/renderer/sql/default/renderer.py create mode 100644 pydbml/renderer/sql/default/table.py create mode 100644 pydbml/renderer/sql/default/utils.py create mode 100644 test/conftest.py create mode 100644 test/test_renderer/__init__.py create mode 100644 test/test_renderer/test_base.py create mode 100644 test/test_renderer/test_dbml/__init__.py create mode 100644 test/test_renderer/test_dbml/test_column.py create mode 100644 test/test_renderer/test_dbml/test_enum.py create mode 100644 test/test_renderer/test_dbml/test_expression.py create mode 100644 test/test_renderer/test_dbml/test_index.py create mode 100644 test/test_renderer/test_dbml/test_note.py create mode 100644 test/test_renderer/test_dbml/test_project.py create mode 100644 test/test_renderer/test_dbml/test_reference.py create mode 100644 test/test_renderer/test_dbml/test_renderer.py create mode 100644 test/test_renderer/test_dbml/test_sticky_note.py create mode 100644 test/test_renderer/test_dbml/test_table.py create mode 100644 test/test_renderer/test_dbml/test_table_group.py create mode 100644 test/test_renderer/test_dbml/test_utils.py create mode 100644 test/test_renderer/test_sql/__init__.py create mode 100644 test/test_renderer/test_sql/test_default/__init__.py create mode 100644 test/test_renderer/test_sql/test_default/test_column.py create mode 100644 test/test_renderer/test_sql/test_default/test_enum.py create mode 100644 test/test_renderer/test_sql/test_default/test_expression.py create mode 100644 test/test_renderer/test_sql/test_default/test_index.py create mode 100644 test/test_renderer/test_sql/test_default/test_note.py create mode 100644 test/test_renderer/test_sql/test_default/test_reference.py create mode 100644 test/test_renderer/test_sql/test_default/test_renderer.py create mode 100644 test/test_renderer/test_sql/test_default/test_table.py create mode 100644 test/test_renderer/test_sql/test_default/test_utils.py diff --git a/pydbml/__init__.py b/pydbml/__init__.py index df94b5c..a585105 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,6 +1,3 @@ -from . import classes +from . import _classes from .parser import PyDBML from .database import Database -from pydbml.constants import MANY_TO_ONE -from pydbml.constants import ONE_TO_MANY -from pydbml.constants import ONE_TO_ONE diff --git a/pydbml/_classes/__init__.py b/pydbml/_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/classes/base.py b/pydbml/_classes/base.py similarity index 66% rename from pydbml/classes/base.py rename to pydbml/_classes/base.py index c19ddec..d1f781f 100644 --- a/pydbml/classes/base.py +++ b/pydbml/_classes/base.py @@ -21,6 +21,15 @@ def check_attributes_for_sql(self): raise AttributeMissingError( f'Cannot render SQL. Missing required attribute "{attr}".' ) + @property + def sql(self) -> str: + if hasattr(self, 'database') and self.database is not None: + renderer = self.database.sql_renderer + else: + from pydbml.renderer.sql.default import DefaultSQLRenderer + renderer = DefaultSQLRenderer + + return renderer.render(self) def __setattr__(self, name: str, value: Any): """ @@ -46,3 +55,16 @@ def __eq__(self, other: object) -> bool: other_dict.pop(field, None) return self_dict == other_dict + + +class DBMLObject: + '''Base class for all DBML objects.''' + @property + def dbml(self) -> str: + if hasattr(self, 'database') and self.database is not None: + renderer = self.database.dbml_renderer + else: + from pydbml.renderer.dbml.default import DefaultDBMLRenderer + renderer = DefaultDBMLRenderer + + return renderer.render(self) diff --git a/pydbml/_classes/column.py b/pydbml/_classes/column.py new file mode 100644 index 0000000..c2ff98b --- /dev/null +++ b/pydbml/_classes/column.py @@ -0,0 +1,92 @@ +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from pydbml.exceptions import TableNotFoundError +from .base import SQLObject, DBMLObject +from .enum import Enum +from .expression import Expression +from .note import Note + +if TYPE_CHECKING: # pragma: no cover + from .table import Table + from .reference import Reference + + +class Column(SQLObject, DBMLObject): + '''Class representing table column.''' + + required_attributes = ('name', 'type') + dont_compare_fields = ('table',) + + def __init__(self, + name: str, + type: Union[str, Enum], + unique: bool = False, + not_null: bool = False, + pk: bool = False, + autoinc: bool = False, + default: Optional[Union[str, int, bool, float, Expression]] = None, + note: Optional[Union[Note, str]] = None, + # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, + comment: Optional[str] = None): + self.name = name + self.type = type + self.unique = unique + self.not_null = not_null + self.pk = pk + self.autoinc = autoinc + self.comment = comment + self.note = Note(note) + + self.default = default + self.table: Optional['Table'] = None + + def __eq__(self, other: object) -> bool: + if other is self: + return True + if not isinstance(other, self.__class__): + return False + self_table = self.table.full_name if self.table else None + other_table = other.table.full_name if other.table else None + if self_table != other_table: + return False + return super().__eq__(other) + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + + def get_refs(self) -> List['Reference']: + ''' + get all references related to this column (where this col is col1 in) + ''' + if not self.table: + raise TableNotFoundError('Table for the column is not set') + return [ref for ref in self.table.get_refs() if self in ref.col1] + + @property + def database(self): + return self.table.database if self.table else None + + def __repr__(self): + ''' + >>> Column('name', 'VARCHAR2') + + ''' + type_name = self.type if isinstance(self.type, str) else self.type.name + return f'' + + def __str__(self): + ''' + >>> print(Column('name', 'VARCHAR2')) + name[VARCHAR2] + ''' + + return f'{self.name}[{self.type}]' diff --git a/pydbml/classes/enum.py b/pydbml/_classes/enum.py similarity index 51% rename from pydbml/classes/enum.py rename to pydbml/_classes/enum.py index ec4ad14..0b2aee6 100644 --- a/pydbml/classes/enum.py +++ b/pydbml/_classes/enum.py @@ -3,20 +3,18 @@ from typing import Optional from typing import Union -from .base import SQLObject +from .base import SQLObject, DBMLObject from .note import Note -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql -from pydbml.tools import indent -from pydbml.tools import note_option_to_dbml -class EnumItem: +class EnumItem(SQLObject, DBMLObject): '''Single enum item''' + required_attributes = ('name',) + def __init__(self, name: str, - note: Optional[Union['Note', str]] = None, + note: Optional[Union[Note, str]] = None, comment: Optional[str] = None): self.name = name self.note = Note(note) @@ -32,37 +30,15 @@ def note(self, val: Note) -> None: val.parent = self def __repr__(self): - ''' - >>> EnumItem('en-US') - - ''' - + '''''' return f'' def __str__(self): - ''' - >>> print(EnumItem('en-US')) - en-US - ''' - + '''en-US''' return self.name - @property - def sql(self): - result = comment_to_sql(self.comment) if self.comment else '' - result += f"'{self.name}'," - return result - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'"{self.name}"' - if self.note: - result += f' [{note_option_to_dbml(self.note)}]' - return result - -class Enum(SQLObject): +class Enum(SQLObject, DBMLObject): required_attributes = ('name', 'schema', 'items') def __init__(self, @@ -111,38 +87,3 @@ def __str__(self): ''' return self.name - - def _get_full_name_for_sql(self) -> str: - if self.schema == 'public': - return f'"{self.name}"' - else: - return f'"{self.schema}"."{self.name}"' - - @property - def sql(self): - ''' - Returns SQL for enum type: - - CREATE TYPE "job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', - ); - - ''' - self.check_attributes_for_sql() - result = comment_to_sql(self.comment) if self.comment else '' - result += f'CREATE TYPE {self._get_full_name_for_sql()} AS ENUM (\n' - result += '\n'.join(f'{indent(i.sql, 2)}' for i in self.items) - result += '\n);' - return result - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Enum {self._get_full_name_for_sql()} {{\n' - items_str = '\n'.join(i.dbml for i in self.items) - result += indent(items_str) - result += '\n}' - return result diff --git a/pydbml/classes/expression.py b/pydbml/_classes/expression.py similarity index 65% rename from pydbml/classes/expression.py rename to pydbml/_classes/expression.py index 30d0d6b..8b34445 100644 --- a/pydbml/classes/expression.py +++ b/pydbml/_classes/expression.py @@ -1,7 +1,7 @@ -from .base import SQLObject +from .base import SQLObject, DBMLObject -class Expression(SQLObject): +class Expression(SQLObject, DBMLObject): def __init__(self, text: str): self.text = text @@ -20,11 +20,3 @@ def __repr__(self) -> str: ''' return f'Expression({repr(self.text)})' - - @property - def sql(self) -> str: - return f'({self.text})' - - @property - def dbml(self) -> str: - return f'`{self.text}`' diff --git a/pydbml/_classes/index.py b/pydbml/_classes/index.py new file mode 100644 index 0000000..35581bd --- /dev/null +++ b/pydbml/_classes/index.py @@ -0,0 +1,70 @@ +from typing import List +from typing import Literal +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from .base import SQLObject, DBMLObject +from .column import Column +from .expression import Expression +from .note import Note + +if TYPE_CHECKING: # pragma: no cover + from .table import Table + + +class Index(SQLObject, DBMLObject): + '''Class representing index.''' + required_attributes = ('subjects', 'table') + dont_compare_fields = ('table',) + + def __init__(self, + subjects: List[Union[str, Column, Expression]], + name: Optional[str] = None, + unique: bool = False, + type: Optional[Literal['hash', 'btree']] = None, + pk: bool = False, + note: Optional[Union[Note, str]] = None, + comment: Optional[str] = None): + self.subjects = subjects + self.table: Optional[Table] = None + + self.name = name if name else None + self.unique = unique + self.type = type + self.pk = pk + self.note = Note(note) + self.comment = comment + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + + @property + def subject_names(self): + ''' + Returns updated list of subject names. + ''' + return [s.name if isinstance(s, Column) else str(s) for s in self.subjects] + + def __repr__(self): + ''' + + ''' + + table_name = self.table.name if self.table else None + return f"" + + def __str__(self): + ''' + Index(test[col, (c*2)]) + ''' + + table_name = self.table.name if self.table else '' + subjects = ', '.join(self.subject_names) + return f"Index({table_name}[{subjects}])" diff --git a/pydbml/_classes/note.py b/pydbml/_classes/note.py new file mode 100644 index 0000000..627eb07 --- /dev/null +++ b/pydbml/_classes/note.py @@ -0,0 +1,23 @@ +from typing import Any + +from .base import SQLObject, DBMLObject + + +class Note(SQLObject, DBMLObject): + dont_compare_fields = ('parent',) + + def __init__(self, text: Any) -> None: + self.text: str + self.text = str(text) if text is not None else '' + self.parent: Any = None + + def __str__(self): + '''Note text''' + return self.text + + def __bool__(self): + return bool(self.text) + + def __repr__(self): + '''Note('Note text')''' + return f'Note({repr(self.text)})' diff --git a/pydbml/_classes/project.py b/pydbml/_classes/project.py new file mode 100644 index 0000000..1efa235 --- /dev/null +++ b/pydbml/_classes/project.py @@ -0,0 +1,34 @@ +from typing import Dict +from typing import Optional +from typing import Union + +from pydbml._classes.base import DBMLObject +from pydbml._classes.note import Note + + +class Project(DBMLObject): + dont_compare_fields = ('database',) + + def __init__(self, + name: str, + items: Optional[Dict[str, str]] = None, + note: Optional[Union[Note, str]] = None, + comment: Optional[str] = None): + self.database = None + self.name = name + self.items = items + self.note = Note(note) + self.comment = comment + + def __repr__(self): + """""" + return f'' + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self diff --git a/pydbml/_classes/reference.py b/pydbml/_classes/reference.py new file mode 100644 index 0000000..9da1e03 --- /dev/null +++ b/pydbml/_classes/reference.py @@ -0,0 +1,120 @@ +from itertools import chain +from typing import Collection +from typing import Literal +from typing import Optional +from typing import Union + +from pydbml.constants import MANY_TO_MANY +from pydbml.exceptions import DBMLError +from pydbml.exceptions import TableNotFoundError +from .base import SQLObject, DBMLObject +from .column import Column +from .table import Table + + +class Reference(SQLObject, DBMLObject): + ''' + Class, representing a foreign key constraint. + It is a separate object, which is not connected to Table or Column objects + and its `sql` property contains the ALTER TABLE clause. + ''' + required_attributes = ('type', 'col1', 'col2') + dont_compare_fields = ('database', '_inline') + + def __init__(self, + type: Literal['>', '<', '-', '<>'], + col1: Union[Column, Collection[Column]], + col2: Union[Column, Collection[Column]], + name: Optional[str] = None, + comment: Optional[str] = None, + on_update: Optional[str] = None, + on_delete: Optional[str] = None, + inline: bool = False): + self.database = None + self.type = type + self.col1 = [col1] if isinstance(col1, Column) else list(col1) + self.col2 = [col2] if isinstance(col2, Column) else list(col2) + self.name = name if name else None + self.comment = comment + self.on_update = on_update + self.on_delete = on_delete + self._inline = inline + + @property + def inline(self) -> bool: + return self._inline and not self.type == MANY_TO_MANY + + @inline.setter + def inline(self, val) -> None: + self._inline = val + + @property + def join_table(self) -> Optional[Table]: + if self.type != MANY_TO_MANY: + return None + + if self.table1 is None: + raise TableNotFoundError(f"Cannot generate join table for {self}: table 1 is unknown") + if self.table2 is None: + raise TableNotFoundError(f"Cannot generate join table for {self}: table 2 is unknown") + + return Table( + name=f'{self.table1.name}_{self.table2.name}', + schema=self.table1.schema, + columns=( + Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) # type: ignore + for c in chain(self.col1, self.col2) + ), + abstract=True + ) + + @property + def table1(self) -> Optional[Table]: + self._validate() + return self.col1[0].table if self.col1 else None + + @property + def table2(self) -> Optional[Table]: + self._validate() + return self.col2[0].table if self.col2 else None + + def __repr__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> Reference('>', col1=c1, col2=c2) + ', ['c1'], ['c2']> + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> Reference('<', col1=[c1, c12], col2=(c2, c22)) + + ''' + + col1 = ', '.join(f'{c.name!r}' for c in self.col1) + col2 = ', '.join(f'{c.name!r}' for c in self.col2) + return f"" + + def __str__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> print(Reference('>', col1=c1, col2=c2)) + Reference([c1] > [c2] + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> print(Reference('<', col1=[c1, c12], col2=(c2, c22))) + Reference([c1, c12] < [c2, c22] + ''' + + col1 = ', '.join(f'{c.name}' for c in self.col1) + col2 = ', '.join(f'{c.name}' for c in self.col2) + return f"Reference([{col1}] {self.type} [{col2}]" + + def _validate(self): + table1 = self.col1[0].table + if any(c.table != table1 for c in self.col1): + raise DBMLError('Columns in col1 are from different tables') + + table2 = self.col2[0].table + if any(c.table != table2 for c in self.col2): + raise DBMLError('Columns in col2 are from different tables') diff --git a/pydbml/_classes/sticky_note.py b/pydbml/_classes/sticky_note.py new file mode 100644 index 0000000..e01e1b0 --- /dev/null +++ b/pydbml/_classes/sticky_note.py @@ -0,0 +1,24 @@ +from typing import Any + +from pydbml._classes.base import DBMLObject + + +class StickyNote(DBMLObject): + dont_compare_fields = ('database',) + + def __init__(self, name: str, text: Any) -> None: + self.name = name + self.text = str(text) if text is not None else '' + + self.database = None + + def __str__(self): + '''StickyNote('mynote', 'Note text')''' + return self.__class__.__name__ + f'({repr(self.name)}, {repr(self.text)})' + + def __bool__(self): + return bool(self.text) + + def __repr__(self): + '''''' + return f'<{self.__class__.__name__} {self.name!r}, {self.text!r}>' diff --git a/pydbml/classes/table.py b/pydbml/_classes/table.py similarity index 56% rename from pydbml/classes/table.py rename to pydbml/_classes/table.py index f493e9c..b5132bf 100644 --- a/pydbml/classes/table.py +++ b/pydbml/_classes/table.py @@ -1,30 +1,23 @@ +from typing import Iterable from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import Union -from typing import Iterable -from .base import SQLObject -from .column import Column -from .index import Index -from .note import Note -from pydbml.constants import MANY_TO_ONE -from pydbml.constants import ONE_TO_MANY -from pydbml.constants import ONE_TO_ONE from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import IndexNotFoundError from pydbml.exceptions import UnknownDatabaseError -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql -from pydbml.tools import indent - +from .base import SQLObject, DBMLObject +from .column import Column +from .index import Index +from .note import Note if TYPE_CHECKING: # pragma: no cover from pydbml.database import Database from .reference import Reference -class Table(SQLObject): +class Table(SQLObject, DBMLObject): '''Class representing table.''' required_attributes = ('name', 'schema') @@ -36,7 +29,7 @@ def __init__(self, alias: Optional[str] = None, columns: Optional[Iterable[Column]] = None, indexes: Optional[Iterable[Index]] = None, - note: Optional[Union['Note', str]] = None, + note: Optional[Union[Note, str]] = None, header_color: Optional[str] = None, comment: Optional[str] = None, abstract: bool = False): @@ -100,7 +93,7 @@ def add_index(self, i: Index) -> None: if not isinstance(i, Index): raise TypeError('Indexes must be of type Index') for subject in i.subjects: - if isinstance(subject, Column) and subject.table != self: + if isinstance(subject, Column) and subject.table is not self: raise ColumnNotFoundError(f'Column {subject} not in the table') i.table = self self.indexes.append(i) @@ -121,36 +114,6 @@ def get_refs(self) -> List['Reference']: raise UnknownDatabaseError('Database for the table is not set') return [ref for ref in self.database.refs if ref.table1 == self] - def get_references_for_sql(self) -> List['Reference']: - """ - Return all references in the database where this table is on the left side of SQL - reference definition. - """ - if not self.database: - raise UnknownDatabaseError(f'Database for the table {self} is not set') - result = [] - for ref in self.database.refs: - if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ - (ref.table1 == self): - result.append(ref) - elif (ref.type == ONE_TO_MANY) and (ref.table2 == self): - result.append(ref) - return result - - def _get_references_for_sql(self) -> List['Reference']: - ''' - Return inline references for this table sql definition - ''' - if self.abstract: - return [] - return [r for r in self.get_references_for_sql() if r.inline] - - def _get_full_name_for_sql(self) -> str: - if self.schema == 'public': - return f'"{self.name}"' - else: - return f'"{self.schema}"."{self.name}"' - def __getitem__(self, k: Union[int, str]) -> Column: if isinstance(k, int): return self.columns[k] @@ -190,74 +153,3 @@ def __str__(self): ''' return f'{self.schema}.{self.name}({", ".join(c.name for c in self.columns)})' - - @property - def sql(self): - ''' - Returns full SQL for table definition: - - CREATE TABLE "countries" ( - "code" int PRIMARY KEY, - "name" varchar, - "continent_name" varchar - ); - - Also returns indexes if they were defined: - - CREATE INDEX ON "products" ("id", "name"); - ''' - self.check_attributes_for_sql() - name = self._get_full_name_for_sql() - components = [f'CREATE TABLE {name} ('] - - body = [] - body.extend(indent(c.sql, 2) for c in self.columns) - body.extend(indent(i.sql, 2) for i in self.indexes if i.pk) - body.extend(indent(r.sql, 2) for r in self._get_references_for_sql()) - - if self._has_composite_pk(): - body.append( - " PRIMARY KEY (" - + ', '.join(f'"{c.name}"' for c in self.columns if c.pk) - + ')') - components.append(',\n'.join(body)) - components.append(');') - components.extend('\n' + i.sql for i in self.indexes if not i.pk) - - result = comment_to_sql(self.comment) if self.comment else '' - result += '\n'.join(components) - - if self.note: - result += f'\n\n{self.note.sql}' - - for col in self.columns: - if col.note: - quoted_note = f"'{col.note._prepare_text_for_sql()}'" - note_sql = f'COMMENT ON COLUMN "{self.name}"."{col.name}" IS {quoted_note};' - result += f'\n\n{note_sql}' - return result - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - - name = self._get_full_name_for_sql() - - result += f'Table {name} ' - if self.alias: - result += f'as "{self.alias}" ' - if self.header_color: - result += f'[headercolor: {self.header_color}] ' - result += '{\n' - columns_str = '\n'.join(c.dbml for c in self.columns) - result += indent(columns_str) + '\n' - if self.note: - result += indent(self.note.dbml) + '\n' - if self.indexes: - result += '\n indexes {\n' - indexes_str = '\n'.join(i.dbml for i in self.indexes) - result += indent(indexes_str, 8) + '\n' - result += ' }\n' - - result += '}' - return result diff --git a/pydbml/classes/table_group.py b/pydbml/_classes/table_group.py similarity index 75% rename from pydbml/classes/table_group.py rename to pydbml/_classes/table_group.py index 1f38978..0b6d4dd 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/_classes/table_group.py @@ -1,11 +1,11 @@ from typing import List from typing import Optional -from .table import Table -from pydbml.tools import comment_to_dbml +from pydbml._classes.base import DBMLObject +from pydbml._classes.table import Table -class TableGroup: +class TableGroup(DBMLObject): ''' TableGroup `items` parameter initially holds just the names of the tables, but after parsing the whole document, PyDBMLParseResults class replaces @@ -42,13 +42,3 @@ def __getitem__(self, key: int) -> Table: def __iter__(self): return iter(self.items) - - @property - def dbml(self): - - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'TableGroup {self.name} {{\n' - for i in self.items: - result += f' {i._get_full_name_for_sql()}\n' - result += '}' - return result diff --git a/pydbml/classes/__init__.py b/pydbml/classes/__init__.py index bfdc6e2..a083781 100644 --- a/pydbml/classes/__init__.py +++ b/pydbml/classes/__init__.py @@ -1,10 +1,25 @@ -from .column import Column -from .enum import Enum -from .enum import EnumItem -from .expression import Expression -from .index import Index -from .note import Note -from .project import Project -from .reference import Reference -from .table import Table -from .table_group import TableGroup +from .._classes.column import Column +from .._classes.enum import Enum +from .._classes.enum import EnumItem +from .._classes.expression import Expression +from .._classes.index import Index +from .._classes.note import Note +from .._classes.project import Project +from .._classes.reference import Reference +from .._classes.sticky_note import StickyNote +from .._classes.table import Table +from .._classes.table_group import TableGroup + +__all__ = [ + "Column", + "Enum", + "EnumItem", + "Expression", + "Index", + "Note", + "Project", + "Reference", + "StickyNote", + "Table", + "TableGroup", +] diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py deleted file mode 100644 index c95f3e2..0000000 --- a/pydbml/classes/column.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import List -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union - -from .base import SQLObject -from .expression import Expression -from .enum import Enum -from .note import Note -from pydbml.exceptions import TableNotFoundError -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql -from pydbml.tools import note_option_to_dbml - -if TYPE_CHECKING: # pragma: no cover - from .table import Table - from .reference import Reference - - -class Column(SQLObject): - '''Class representing table column.''' - - required_attributes = ('name', 'type') - dont_compare_fields = ('table',) - - def __init__(self, - name: str, - type: Union[str, Enum], - unique: bool = False, - not_null: bool = False, - pk: bool = False, - autoinc: bool = False, - default: Optional[Union[str, int, bool, float, Expression]] = None, - note: Optional[Union['Note', str]] = None, - # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, - comment: Optional[str] = None): - self.name = name - self.type = type - self.unique = unique - self.not_null = not_null - self.pk = pk - self.autoinc = autoinc - self.comment = comment - self.note = Note(note) - - self.default = default - self.table: Optional['Table'] = None - - def __eq__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - return False - self_table = self.table.full_name if self.table else None - other_table = other.table.full_name if other.table else None - if self_table != other_table: - return False - return super().__eq__(other) - - @property - def note(self): - return self._note - - @note.setter - def note(self, val: Note) -> None: - self._note = val - val.parent = self - - def get_refs(self) -> List['Reference']: - ''' - get all references related to this column (where this col is col1 in) - ''' - if not self.table: - raise TableNotFoundError('Table for the column is not set') - return [ref for ref in self.table.get_refs() if self in ref.col1] - - @property - def database(self): - return self.table.database if self.table else None - - @property - def sql(self): - ''' - Returns inline SQL of the column, which should be a part of table definition: - - "id" integer PRIMARY KEY AUTOINCREMENT - ''' - - self.check_attributes_for_sql() - components = [f'"{self.name}"'] - if isinstance(self.type, Enum): - components.append(self.type._get_full_name_for_sql()) - else: - components.append(str(self.type)) - - table_has_composite_pk = False if self.table is None else self.table._has_composite_pk() - if self.pk and not table_has_composite_pk: # comp-PKs are rendered in table sql - components.append('PRIMARY KEY') - if self.autoinc: - components.append('AUTOINCREMENT') - if self.unique: - components.append('UNIQUE') - if self.not_null: - components.append('NOT NULL') - if self.default is not None: - default = self.default.sql \ - if isinstance(self.default, Expression) else self.default - components.append(f'DEFAULT {default}') - - result = comment_to_sql(self.comment) if self.comment else '' - result += ' '.join(components) - return result - - @property - def dbml(self): - def default_to_str(val: Union[Expression, str]) -> str: - if isinstance(val, str): - if val.lower() in ('null', 'true', 'false'): - return val.lower() - else: - return f"'{val}'" - elif isinstance(val, Expression): - return val.dbml - else: # int or float or bool - return val - - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'"{self.name}" ' - if isinstance(self.type, Enum): - result += self.type._get_full_name_for_sql() - else: - result += self.type - - options = [ref.dbml for ref in self.get_refs() if ref.inline] - if self.pk: - options.append('pk') - if self.autoinc: - options.append('increment') - if self.default: - options.append(f'default: {default_to_str(self.default)}') - if self.unique: - options.append('unique') - if self.not_null: - options.append('not null') - if self.note: - options.append(note_option_to_dbml(self.note)) - - if options: - result += f' [{", ".join(options)}]' - return result - - def __repr__(self): - ''' - >>> Column('name', 'VARCHAR2') - - ''' - type_name = self.type if isinstance(self.type, str) else self.type.name - return f'' - - def __str__(self): - ''' - >>> print(Column('name', 'VARCHAR2')) - name[VARCHAR2] - ''' - - return f'{self.name}[{self.type}]' diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py deleted file mode 100644 index a3d771a..0000000 --- a/pydbml/classes/index.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import List -from typing import Literal -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union - -from .base import SQLObject -from .column import Column -from .expression import Expression -from .note import Note -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql -from pydbml.tools import note_option_to_dbml - -if TYPE_CHECKING: # pragma: no cover - from .table import Table - - -class Index(SQLObject): - '''Class representing index.''' - required_attributes = ('subjects', 'table') - dont_compare_fields = ('table',) - - def __init__(self, - subjects: List[Union[str, 'Column', 'Expression']], - name: Optional[str] = None, - unique: bool = False, - type: Optional[Literal['hash', 'btree']] = None, - pk: bool = False, - note: Optional[Union['Note', str]] = None, - comment: Optional[str] = None): - self.subjects = subjects - self.table: Optional['Table'] = None - - self.name = name if name else None - self.unique = unique - self.type = type - self.pk = pk - self.note = Note(note) - self.comment = comment - - @property - def note(self): - return self._note - - @note.setter - def note(self, val: Note) -> None: - self._note = val - val.parent = self - - @property - def subject_names(self): - ''' - Returns updated list of subject names. - ''' - return [s.name if isinstance(s, Column) else str(s) for s in self.subjects] - - def __repr__(self): - ''' - >>> c = Column('col', 'int') - >>> i = Index([c, '(c*2)']) - >>> i - - >>> from .table import Table - >>> t = Table('test') - >>> t.add_column(c) - >>> t.add_index(i) - >>> i - - ''' - - table_name = self.table.name if self.table else None - return f"" - - def __str__(self): - ''' - >>> c = Column('col', 'int') - >>> i = Index([c, '(c*2)']) - >>> print(i) - Index([col, (c*2)]) - >>> from .table import Table - >>> t = Table('test') - >>> t.add_column(c) - >>> t.add_index(i) - >>> print(i) - Index(test[col, (c*2)]) - ''' - - table_name = self.table.name if self.table else '' - subjects = ', '.join(self.subject_names) - return f"Index({table_name}[{subjects}])" - - @property - def sql(self): - ''' - Returns inline SQL of the index to be created separately from table - definition: - - CREATE UNIQUE INDEX ON "products" USING HASH ("id"); - - But if it's a (composite) primary key index, returns an inline SQL for - composite primary key to be used inside table definition: - - PRIMARY KEY ("id", "name") - - ''' - self.check_attributes_for_sql() - subjects = [] - - for subj in self.subjects: - if isinstance(subj, Column): - subjects.append(f'"{subj.name}"') - elif isinstance(subj, Expression): - subjects.append(subj.sql) - else: - subjects.append(subj) - keys = ', '.join(subj for subj in subjects) - if self.pk: - result = comment_to_sql(self.comment) if self.comment else '' - result += f'PRIMARY KEY ({keys})' - return result - - components = ['CREATE'] - if self.unique: - components.append('UNIQUE') - components.append('INDEX') - if self.name: - components.append(f'"{self.name}"') - components.append(f'ON "{self.table.name}"') - if self.type: - components.append(f'USING {self.type.upper()}') - components.append(f'({keys})') - result = comment_to_sql(self.comment) if self.comment else '' - result += ' '.join(components) + ';' - return result - - @property - def dbml(self): - subjects = [] - - for subj in self.subjects: - if isinstance(subj, Column): - subjects.append(subj.name) - elif isinstance(subj, Expression): - subjects.append(subj.dbml) - else: - subjects.append(subj) - - result = comment_to_dbml(self.comment) if self.comment else '' - - if len(subjects) > 1: - result += f'({", ".join(subj for subj in subjects)})' - else: - result += subjects[0] - - options = [] - if self.name: - options.append(f"name: '{self.name}'") - if self.pk: - options.append('pk') - if self.unique: - options.append('unique') - if self.type: - options.append(f'type: {self.type}') - if self.note: - options.append(note_option_to_dbml(self.note)) - - if options: - result += f' [{", ".join(options)}]' - return result diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py deleted file mode 100644 index 5dd4dcf..0000000 --- a/pydbml/classes/note.py +++ /dev/null @@ -1,87 +0,0 @@ -import re -from typing import Any, Union - -from .base import SQLObject -from pydbml.tools import indent -from pydbml import classes - - -class Note(SQLObject): - dont_compare_fields = ('parent',) - - def __init__(self, text: Any) -> None: - self.text: str - self.text = str(text) if text is not None else '' - self.parent: Any = None - - def __str__(self): - ''' - >>> print(Note('Note text')) - Note text - ''' - - return self.text - - def __bool__(self): - return bool(self.text) - - def __repr__(self): - ''' - >>> Note('Note text') - Note('Note text') - ''' - - return f'Note({repr(self.text)})' - - def _prepare_text_for_sql(self) -> str: - ''' - - Process special escape sequence: slash before line break, which means no line break - https://www.dbml.org/docs/#multi-line-string - - - replace all single quotes with double quotes - ''' - - pattern = re.compile(r'\\\n') - result = pattern.sub('', self.text) - - result = result.replace("'", '"') - return result - - def _prepare_text_for_dbml(self): - '''Escape single quotes''' - pattern = re.compile(r"('''|')") - return pattern.sub(r'\\\1', self.text) - - def generate_comment_on(self, entity: str, name: str) -> str: - """Generate a COMMENT ON clause out from this note.""" - quoted_text = f"'{self._prepare_text_for_sql()}'" - note_sql = f'COMMENT ON {entity.upper()} "{name}" IS {quoted_text};' - return note_sql - - @property - def sql(self): - """ - For Tables and Columns Note is converted into COMMENT ON clause. All other entities don't - have notes generated in their SQL code, but as a fallback their notes are rendered as SQL - comments when sql property is called directly. - """ - if self.text: - if isinstance(self.parent, (classes.Table, classes.Column)): - return self.generate_comment_on(self.parent.__class__.__name__, self.parent.name) - else: - text = self._prepare_text_for_sql() - return '\n'.join(f'-- {line}' for line in text.split('\n')) - else: - return '' - - @property - def dbml(self): - text = self._prepare_text_for_dbml() - if '\n' in text: - note_text = f"'''\n{text}\n'''" - else: - note_text = f"'{text}'" - - note_text = indent(note_text) - result = f'Note {{\n{note_text}\n}}' - return result diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py deleted file mode 100644 index 6069fc8..0000000 --- a/pydbml/classes/project.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Dict -from typing import Optional -from typing import Union - -from .note import Note -from pydbml.tools import comment_to_dbml -from pydbml.tools import indent - - -class Project: - dont_compare_fields = ('database',) - - def __init__(self, - name: str, - items: Optional[Dict[str, str]] = None, - note: Optional[Union['Note', str]] = None, - comment: Optional[str] = None): - self.database = None - self.name = name - self.items = items - self.note = Note(note) - self.comment = comment - - def __repr__(self): - """ - >>> Project('myproject') - - """ - - return f'' - - @property - def note(self): - return self._note - - @note.setter - def note(self, val: Note) -> None: - self._note = val - val.parent = self - - @property - def dbml(self): - result = comment_to_dbml(self.comment) if self.comment else '' - result += f'Project "{self.name}" {{\n' - if self.items: - items_str = '' - for k, v in self.items.items(): - if '\n' in v: - items_str += f"{k}: '''{v}'''\n" - else: - items_str += f"{k}: '{v}'\n" - result += indent(items_str[:-1]) + '\n' - if self.note: - result += indent(self.note.dbml) + '\n' - result += '}' - return result diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py deleted file mode 100644 index cb37e3b..0000000 --- a/pydbml/classes/reference.py +++ /dev/null @@ -1,245 +0,0 @@ -from itertools import chain -from typing import Collection -from typing import List -from typing import Literal -from typing import Optional -from typing import Union - -from pydbml.constants import MANY_TO_MANY -from pydbml.constants import MANY_TO_ONE -from pydbml.constants import ONE_TO_MANY -from pydbml.constants import ONE_TO_ONE -from pydbml.exceptions import DBMLError -from pydbml.exceptions import TableNotFoundError -from pydbml.tools import comment_to_dbml -from pydbml.tools import comment_to_sql -from .base import SQLObject -from .column import Column -from .table import Table - - -class Reference(SQLObject): - ''' - Class, representing a foreign key constraint. - It is a separate object, which is not connected to Table or Column objects - and its `sql` property contains the ALTER TABLE clause. - ''' - required_attributes = ('type', 'col1', 'col2') - dont_compare_fields = ('database', '_inline') - - def __init__(self, - type: Literal['>', '<', '-', '<>'], - col1: Union[Column, Collection[Column]], - col2: Union[Column, Collection[Column]], - name: Optional[str] = None, - comment: Optional[str] = None, - on_update: Optional[str] = None, - on_delete: Optional[str] = None, - inline: bool = False): - self.database = None - self.type = type - self.col1 = [col1] if isinstance(col1, Column) else list(col1) - self.col2 = [col2] if isinstance(col2, Column) else list(col2) - self.name = name if name else None - self.comment = comment - self.on_update = on_update - self.on_delete = on_delete - self._inline = inline - - @property - def inline(self) -> bool: - return self._inline and not self.type == MANY_TO_MANY - - @inline.setter - def inline(self, val) -> None: - self._inline = val - - @property - def join_table(self) -> Optional['Table']: - if self.type != MANY_TO_MANY: - return None - - if self.table1 is None: - raise TableNotFoundError(f"Cannot generate join table for {self}: table 1 is unknown") - if self.table2 is None: - raise TableNotFoundError(f"Cannot generate join table for {self}: table 2 is unknown") - - return Table( - name=f'{self.table1.name}_{self.table2.name}', - schema=self.table1.schema, - columns=( - Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) # type: ignore - for c in chain(self.col1, self.col2) - ), - abstract=True - ) - - @property - def table1(self) -> Optional['Table']: - self._validate() - return self.col1[0].table if self.col1 else None - - @property - def table2(self) -> Optional['Table']: - self._validate() - return self.col2[0].table if self.col2 else None - - def __repr__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> Reference('>', col1=c1, col2=c2) - ', ['c1'], ['c2']> - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> Reference('<', col1=[c1, c12], col2=(c2, c22)) - - ''' - - col1 = ', '.join(f'{c.name!r}' for c in self.col1) - col2 = ', '.join(f'{c.name!r}' for c in self.col2) - return f"" - - def __str__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> print(Reference('>', col1=c1, col2=c2)) - Reference([c1] > [c2] - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> print(Reference('<', col1=[c1, c12], col2=(c2, c22))) - Reference([c1, c12] < [c2, c22] - ''' - - col1 = ', '.join(f'{c.name}' for c in self.col1) - col2 = ', '.join(f'{c.name}' for c in self.col2) - return f"Reference([{col1}] {self.type} [{col2}]" - - def _validate(self): - table1 = self.col1[0].table - if any(c.table != table1 for c in self.col1): - raise DBMLError('Columns in col1 are from different tables') - - table2 = self.col2[0].table - if any(c.table != table2 for c in self.col2): - raise DBMLError('Columns in col2 are from different tables') - - def _validate_for_sql(self): - for col in chain(self.col1, self.col2): - if col.table is None: - raise TableNotFoundError(f'Table on {col} is not set') - - def _generate_inline_sql(self, source_col: List['Column'], ref_col: List['Column']) -> str: - result = comment_to_sql(self.comment) if self.comment else '' - result += ( - f'{{c}}FOREIGN KEY ({self._col_names(source_col)}) ' # type: ignore - f'REFERENCES {ref_col[0].table._get_full_name_for_sql()} ({self._col_names(ref_col)})' # type: ignore - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result - - def _generate_not_inline_sql(self, c1: List['Column'], c2: List['Column']): - result = comment_to_sql(self.comment) if self.comment else '' - result += ( - f'ALTER TABLE {c1[0].table._get_full_name_for_sql()}' # type: ignore - f' ADD {{c}}FOREIGN KEY ({self._col_names(c1)})' - f' REFERENCES {c2[0].table._get_full_name_for_sql()} ({self._col_names(c2)})' # type: ignore - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result + ';' - - def _generate_many_to_many_sql(self) -> str: - join_table = self.join_table - table_sql = join_table.sql # type: ignore - - n = len(self.col1) - ref1_sql = self._generate_not_inline_sql(join_table.columns[:n], self.col1) # type: ignore - ref2_sql = self._generate_not_inline_sql(join_table.columns[n:], self.col2) # type: ignore - - result = '\n\n'.join((table_sql, ref1_sql, ref2_sql)) - return result.format(c='') - - @staticmethod - def _col_names(cols: List[Column]) -> str: - return ', '.join(f'"{c.name}"' for c in cols) - - @property - def sql(self) -> str: - ''' - Returns SQL of the reference: - - ALTER TABLE "orders" ADD FOREIGN KEY ("customer_id") REFERENCES "customers ("id"); - - ''' - self.check_attributes_for_sql() - self._validate_for_sql() - - if self.type == MANY_TO_MANY: - return self._generate_many_to_many_sql() - - result = '' - if self.inline: - if self.type in (MANY_TO_ONE, ONE_TO_ONE): - result = self._generate_inline_sql(self.col1, self.col2) - elif self.type == ONE_TO_MANY: - result = self._generate_inline_sql(self.col2, self.col1) - else: - if self.type in (MANY_TO_ONE, ONE_TO_ONE): - result = self._generate_not_inline_sql(c1=self.col1, c2=self.col2) - elif self.type == ONE_TO_MANY: - result = self._generate_not_inline_sql(c1=self.col2, c2=self.col1) - - c = f'CONSTRAINT "{self.name}" ' if self.name else '' - - return result.format(c=c) - - @property - def dbml(self) -> str: - self._validate_for_sql() - if self.inline: - # settings are ignored for inline ref - if len(self.col2) > 1: - raise DBMLError('Cannot render DBML: composite ref cannot be inline') - table_name = self.col2[0].table._get_full_name_for_sql() # type: ignore - return f'ref: {self.type} {table_name}."{self.col2[0].name}"' - else: - result = comment_to_dbml(self.comment) if self.comment else '' - result += 'Ref' - if self.name: - result += f' {self.name}' - - if len(self.col1) == 1: - col1 = f'"{self.col1[0].name}"' - else: - names = (f'"{c.name}"' for c in self.col1) - col1 = f'({", ".join(names)})' - - if len(self.col2) == 1: - col2 = f'"{self.col2[0].name}"' - else: - names = (f'"{c.name}"' for c in self.col2) - col2 = f'({", ".join(names)})' - - options = [] - if self.on_update: - options.append(f'update: {self.on_update}') - if self.on_delete: - options.append(f'delete: {self.on_delete}') - - options_str = f' [{", ".join(options)}]' if options else '' - result += ( - ' {\n ' # type: ignore - f'{self.table1._get_full_name_for_sql()}.{col1} ' # type: ignore - f'{self.type} ' - f'{self.table2._get_full_name_for_sql()}.{col2}' # type: ignore - f'{options_str}' - '\n}' - ) - return result diff --git a/pydbml/classes/sticky_note.py b/pydbml/classes/sticky_note.py deleted file mode 100644 index b843cca..0000000 --- a/pydbml/classes/sticky_note.py +++ /dev/null @@ -1,50 +0,0 @@ -import re -from typing import Any - -from pydbml.tools import indent - - -class StickyNote: - dont_compare_fields = ('database',) - - def __init__(self, name: str, text: Any) -> None: - self.name = name - self.text = str(text) if text is not None else '' - - self.database = None - - def __str__(self): - ''' - >>> print(StickyNote('mynote', 'Note text')) - StickyNote('mynote', 'Note text') - ''' - - return self.__class__.__name__ + f'({repr(self.name)}, {repr(self.text)})' - - def __bool__(self): - return bool(self.text) - - def __repr__(self): - ''' - >>> StickyNote('mynote', 'Note text') - - ''' - - return f'<{self.__class__.__name__} {self.name!r}, {self.text!r}>' - - def _prepare_text_for_dbml(self): - '''Escape single quotes''' - pattern = re.compile(r"('''|')") - return pattern.sub(r'\\\1', self.text) - - @property - def dbml(self): - text = self._prepare_text_for_dbml() - if '\n' in text: - note_text = f"'''\n{text}\n'''" - else: - note_text = f"'{text}'" - - note_text = indent(note_text) - result = f'Note {self.name} {{\n{note_text}\n}}' - return result diff --git a/pydbml/database.py b/pydbml/database.py index 1347909..5585751 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -1,42 +1,31 @@ -from typing import Any +from typing import Any, Type from typing import Dict from typing import List from typing import Optional from typing import Union -from .classes import Enum, Note +from .classes import Enum from .classes import Project from .classes import Reference from .classes import Table from .classes import TableGroup -from .classes.sticky_note import StickyNote +from ._classes.sticky_note import StickyNote from .exceptions import DatabaseValidationError -from .constants import MANY_TO_ONE, ONE_TO_MANY - - -def reorder_tables_for_sql(tables: List['Table'], refs: List['Reference']) -> List['Table']: - """ - Attempt to reorder the tables, so that they are defined in SQL before they are referenced by - inline foreign keys. - - Won't aid the rare cases of cross-references and many-to-many relations. - """ - references: Dict[str, int] = {} - for ref in refs: - if ref.inline: - if ref.type == MANY_TO_ONE and ref.table1 is not None: - table_name = ref.table1.name - elif ref.type == ONE_TO_MANY and ref.table2 is not None: - table_name = ref.table2.name - else: - continue - references[table_name] = references.get(table_name, 0) + 1 - return sorted(tables, key=lambda t: references.get(t.name, 0), reverse=True) +from .renderer.base import BaseRenderer +from .renderer.dbml.default.renderer import DefaultDBMLRenderer +from .renderer.sql.default import DefaultSQLRenderer +from .renderer.sql.default.utils import reorder_tables_for_sql class Database: - def __init__(self) -> None: + def __init__( + self, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer + ) -> None: + self.sql_renderer = sql_renderer + self.dbml_renderer = dbml_renderer self.tables: List['Table'] = [] self.table_dict: Dict[str, 'Table'] = {} self.refs: List['Reference'] = [] @@ -46,11 +35,6 @@ def __init__(self) -> None: self.project: Optional['Project'] = None def __repr__(self) -> str: - """ - >>> Database() - - """ - return f"" def __getitem__(self, k: Union[int, str]) -> Table: @@ -215,18 +199,9 @@ def delete_project(self) -> Project: @property def sql(self): '''Returs SQL of the parsed results''' - refs = (ref for ref in self.refs if not ref.inline) - tables = reorder_tables_for_sql(self.tables, self.refs) - components = (i.sql for i in (*self.enums, *tables, *refs)) - return '\n\n'.join(components) + return self.sql_renderer.render_db(self) @property def dbml(self): '''Generates DBML code out of parsed results''' - items = [self.project] if self.project else [] - refs = (ref for ref in self.refs if not ref.inline) - items.extend((*self.enums, *self.tables, *refs, *self.table_groups, *self.sticky_notes)) - components = ( - i.dbml for i in items - ) - return '\n\n'.join(components) + return self.dbml_renderer.render_db(self) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 8ceea10..1077c29 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -17,7 +17,7 @@ from pydbml.classes import Reference from pydbml.classes import Table from pydbml.classes import TableGroup -from pydbml.classes.sticky_note import StickyNote +from pydbml._classes.sticky_note import StickyNote from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError from pydbml.exceptions import ValidationError diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 0f3dabb..7b0d854 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -64,11 +64,6 @@ def __new__(cls, source_: Optional[Union[str, Path, TextIOWrapper]] = None): return super().__new__(cls) def __repr__(self): - """ - >>> PyDBML() - - """ - return "" @staticmethod @@ -109,11 +104,6 @@ def parse(self): return self.database def __repr__(self): - """ - >>> PyDBMLParser('') - - """ - return "" def _set_syntax(self): diff --git a/pydbml/renderer/__init__.py b/pydbml/renderer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/renderer/base.py b/pydbml/renderer/base.py new file mode 100644 index 0000000..a8a0c77 --- /dev/null +++ b/pydbml/renderer/base.py @@ -0,0 +1,38 @@ +from typing import Type, Callable, Dict, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + + +def unsupported_renderer(model) -> str: + return '' + + +class BaseRenderer: + _unsupported_renderer = unsupported_renderer + + @property + def model_renderers(cls) -> Dict[Type, Callable]: + """A class attribute dictionary to store the model renderers.""" + raise NotImplementedError # pragma: no cover + + @classmethod + def render(cls, model) -> str: + """ + Render the model to a string. If the model is not supported, fall back to + `self._unsupported_renderer` that by default returns an empty string. + """ + + return cls.model_renderers.get(type(model), cls._unsupported_renderer)(model) + + @classmethod + def renderer_for(cls, model_cls: Type) -> Callable: + """A decorator to register a renderer for a model class.""" + def decorator(func) -> Callable: + cls.model_renderers[model_cls] = func + return func + return decorator + + @classmethod + def render_db(cls, db: 'Database') -> str: + raise NotImplementedError # pragma: no cover diff --git a/pydbml/renderer/dbml/__init__.py b/pydbml/renderer/dbml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/renderer/dbml/default/__init__.py b/pydbml/renderer/dbml/default/__init__.py new file mode 100644 index 0000000..82f7386 --- /dev/null +++ b/pydbml/renderer/dbml/default/__init__.py @@ -0,0 +1,10 @@ +from .renderer import DefaultDBMLRenderer +from .column import render_column +from .enum import render_enum, render_enum_item +from .expression import render_expression +from .index import render_index +from .project import render_project +from .reference import render_reference +from .sticky_note import render_sticky_note +from .table import render_table +from .table_group import render_table_group diff --git a/pydbml/renderer/dbml/default/column.py b/pydbml/renderer/dbml/default/column.py new file mode 100644 index 0000000..6b4dd62 --- /dev/null +++ b/pydbml/renderer/dbml/default/column.py @@ -0,0 +1,51 @@ +from typing import Union + +from pydbml.classes import Column, Enum, Expression +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml +from pydbml.renderer.sql.default.utils import get_full_name_for_sql + + +def default_to_str(val: Union[Expression, str, int, float]) -> str: + if isinstance(val, str): + if val.lower() in ('null', 'true', 'false'): + return val.lower() + else: + return f"'{val}'" + elif isinstance(val, Expression): + return val.dbml + else: # int or float or bool + return str(val) + + +def render_options(model: Column) -> str: + options = [ref.dbml for ref in model.get_refs() if ref.inline] + if model.pk: + options.append('pk') + if model.autoinc: + options.append('increment') + if model.default: + options.append(f'default: {default_to_str(model.default)}') + if model.unique: + options.append('unique') + if model.not_null: + options.append('not null') + if model.note: + options.append(note_option_to_dbml(model.note)) + + if options: + return f' [{", ".join(options)}]' + return '' + + +@DefaultDBMLRenderer.renderer_for(Column) +def render_column(model: Column) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'"{model.name}" ' + if isinstance(model.type, Enum): + result += get_full_name_for_sql(model.type) + else: + result += model.type + + result += render_options(model) + return result diff --git a/pydbml/renderer/dbml/default/enum.py b/pydbml/renderer/dbml/default/enum.py new file mode 100644 index 0000000..f89c310 --- /dev/null +++ b/pydbml/renderer/dbml/default/enum.py @@ -0,0 +1,25 @@ +from textwrap import indent + +from pydbml.classes import Enum, EnumItem +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml +from pydbml.renderer.sql.default.utils import get_full_name_for_sql + + +@DefaultDBMLRenderer.renderer_for(Enum) +def render_enum(model: Enum) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'Enum {get_full_name_for_sql(model)} {{\n' + items_str = '\n'.join(DefaultDBMLRenderer.render(i) for i in model.items) + result += indent(items_str, ' ') + result += '\n}' + return result + + +@DefaultDBMLRenderer.renderer_for(EnumItem) +def render_enum_item(model: EnumItem) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'"{model.name}"' + if model.note: + result += f' [{note_option_to_dbml(model.note)}]' + return result diff --git a/pydbml/renderer/dbml/default/expression.py b/pydbml/renderer/dbml/default/expression.py new file mode 100644 index 0000000..627f286 --- /dev/null +++ b/pydbml/renderer/dbml/default/expression.py @@ -0,0 +1,7 @@ +from pydbml.classes import Expression +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer + + +@DefaultDBMLRenderer.renderer_for(Expression) +def render_expression(model: Expression) -> str: + return f'`{model.text}`' diff --git a/pydbml/renderer/dbml/default/index.py b/pydbml/renderer/dbml/default/index.py new file mode 100644 index 0000000..74d6081 --- /dev/null +++ b/pydbml/renderer/dbml/default/index.py @@ -0,0 +1,49 @@ +from typing import List, Any + +from pydbml.classes import Index, Expression, Column +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml + + +def render_subjects(source_subjects: List[Any]) -> str: + subjects = [] + + for subj in source_subjects: + if isinstance(subj, Column): + subjects.append(subj.name) + elif isinstance(subj, Expression): + subjects.append(DefaultDBMLRenderer.render(subj)) + else: + subjects.append(subj) + + if len(subjects) > 1: + return f'({", ".join(subj for subj in subjects)})' + else: + return subjects[0] + + +def render_options(model: Index) -> str: + options = [] + if model.name: + options.append(f"name: '{model.name}'") + if model.pk: + options.append('pk') + if model.unique: + options.append('unique') + if model.type: + options.append(f'type: {model.type}') + if model.note: + options.append(note_option_to_dbml(model.note)) + + if options: + return f' [{", ".join(options)}]' + return '' + + +@DefaultDBMLRenderer.renderer_for(Index) +def render_index(model: Index) -> str: + return ( + (comment_to_dbml(model.comment) if model.comment else '') + + render_subjects(model.subjects) + + render_options(model) + ) diff --git a/pydbml/renderer/dbml/default/note.py b/pydbml/renderer/dbml/default/note.py new file mode 100644 index 0000000..b07e023 --- /dev/null +++ b/pydbml/renderer/dbml/default/note.py @@ -0,0 +1,24 @@ +import re +from textwrap import indent + +from pydbml.classes import Note +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer + + +def prepare_text_for_dbml(model): + '''Escape single quotes''' + pattern = re.compile(r"('''|')") + return pattern.sub(r'\\\1', model.text) + + +@DefaultDBMLRenderer.renderer_for(Note) +def render_note(model: Note) -> str: + text = prepare_text_for_dbml(model) + if '\n' in text: + note_text = f"'''\n{text}\n'''" + else: + note_text = f"'{text}'" + + note_text = indent(note_text, ' ') + result = f'Note {{\n{note_text}\n}}' + return result diff --git a/pydbml/renderer/dbml/default/project.py b/pydbml/renderer/dbml/default/project.py new file mode 100644 index 0000000..eed74a8 --- /dev/null +++ b/pydbml/renderer/dbml/default/project.py @@ -0,0 +1,27 @@ +from textwrap import indent +from typing import Dict + +from pydbml.classes import Project +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml + + +def render_items(items: Dict[str, str]) -> str: + items_str = '' + for k, v in items.items(): + if '\n' in v: + items_str += f"{k}: '''{v}'''\n" + else: + items_str += f"{k}: '{v}'\n" + return indent(items_str.rstrip('\n'), ' ') + '\n' + + +@DefaultDBMLRenderer.renderer_for(Project) +def render_project(model: Project) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'Project "{model.name}" {{\n' + result += render_items(model.items) + if model.note: + result += indent(DefaultDBMLRenderer.render(model.note), ' ') + '\n' + result += '}' + return result diff --git a/pydbml/renderer/dbml/default/reference.py b/pydbml/renderer/dbml/default/reference.py new file mode 100644 index 0000000..abf8874 --- /dev/null +++ b/pydbml/renderer/dbml/default/reference.py @@ -0,0 +1,68 @@ +from itertools import chain +from textwrap import indent +from typing import List + +from pydbml.classes import Reference, Column +from pydbml.exceptions import TableNotFoundError, DBMLError +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml +from .table import get_full_name_for_dbml + + +def validate_for_dbml(model: Reference): + for col in chain(model.col1, model.col2): + if col.table is None: + raise TableNotFoundError(f'Table on {col} is not set') + + +def render_inline_reference(model: Reference) -> str: + # settings are ignored for inline ref + if len(model.col2) > 1: + raise DBMLError('Cannot render DBML: composite ref cannot be inline') + table_name = get_full_name_for_dbml(model.col2[0].table) + return f'ref: {model.type} {table_name}."{model.col2[0].name}"' + + +def render_col(col: List[Column]) -> str: + if len(col) == 1: + return f'"{col[0].name}"' + else: + names = (f'"{c.name}"' for c in col) + return f'({", ".join(names)})' + + +def render_options(model: Reference) -> str: + options = [] + if model.on_update: + options.append(f'update: {model.on_update}') + if model.on_delete: + options.append(f'delete: {model.on_delete}') + if options: + return f' [{", ".join(options)}]' + return '' + + +def render_not_inline_reference(model: Reference) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += 'Ref' + if model.name: + result += f' {model.name}' + + result += ( + ' {\n ' # type: ignore + f'{get_full_name_for_dbml(model.table1)}.{render_col(model.col1)} ' + f'{model.type} ' + f'{get_full_name_for_dbml(model.table2)}.{render_col(model.col2)}' + f'{render_options(model)}' + '\n}' + ) + return result + + +@DefaultDBMLRenderer.renderer_for(Reference) +def render_reference(model: Reference) -> str: + validate_for_dbml(model) + if model.inline: + return render_inline_reference(model) + else: + return render_not_inline_reference(model) diff --git a/pydbml/renderer/dbml/default/renderer.py b/pydbml/renderer/dbml/default/renderer.py new file mode 100644 index 0000000..0445c75 --- /dev/null +++ b/pydbml/renderer/dbml/default/renderer.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING, List + +from pydbml.renderer.base import BaseRenderer +from pydbml._classes.base import DBMLObject + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + + +class DefaultDBMLRenderer(BaseRenderer): + model_renderers = {} + + @classmethod + def render_db(cls, db: 'Database') -> str: + items: List[DBMLObject] = [db.project] if db.project else [] + refs = (ref for ref in db.refs if not ref.inline) + items.extend((*db.enums, *db.tables, *refs, *db.table_groups, *db.sticky_notes)) + + return '\n\n'.join(cls.render(i) for i in items) diff --git a/pydbml/renderer/dbml/default/sticky_note.py b/pydbml/renderer/dbml/default/sticky_note.py new file mode 100644 index 0000000..e2af3c8 --- /dev/null +++ b/pydbml/renderer/dbml/default/sticky_note.py @@ -0,0 +1,18 @@ +from textwrap import indent + +from pydbml.classes import StickyNote +from pydbml.renderer.dbml.default.note import prepare_text_for_dbml +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer + + +@DefaultDBMLRenderer.renderer_for(StickyNote) +def render_sticky_note(model: StickyNote) -> str: + text = prepare_text_for_dbml(model) + if '\n' in text: + note_text = f"'''\n{text}\n'''" + else: + note_text = f"'{text}'" + + note_text = indent(note_text, ' ') + result = f'Note {model.name} {{\n{note_text}\n}}' + return result diff --git a/pydbml/renderer/dbml/default/table.py b/pydbml/renderer/dbml/default/table.py new file mode 100644 index 0000000..bfd02ed --- /dev/null +++ b/pydbml/renderer/dbml/default/table.py @@ -0,0 +1,51 @@ +import re +from textwrap import indent + +from pydbml.classes import Table +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml + + +def get_full_name_for_dbml(model) -> str: + if model.schema == 'public': + return f'"{model.name}"' + else: + return f'"{model.schema}"."{model.name}"' + + +def render_header(model: Table) -> str: + name = get_full_name_for_dbml(model) + + result = f'Table {name} ' + if model.alias: + result += f'as "{model.alias}" ' + if model.header_color: + result += f'[headercolor: {model.header_color}] ' + return result + + +def render_indexes(model: Table) -> str: + if model.indexes: + result = '\n indexes {\n' + indexes_str = '\n'.join(DefaultDBMLRenderer.render(i) for i in model.indexes) + result += indent(indexes_str, ' ') + '\n' + result += ' }\n' + return result + return '' + + +@DefaultDBMLRenderer.renderer_for(Table) +def render_table(model: Table) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += render_header(model) + + result += '{\n' + columns_str = '\n'.join(DefaultDBMLRenderer.render(c) for c in model.columns) + result += indent(columns_str, ' ') + '\n' + if model.note: + result += indent(model.note.dbml, ' ') + '\n' + + result += render_indexes(model) + + result += '}' + return result diff --git a/pydbml/renderer/dbml/default/table_group.py b/pydbml/renderer/dbml/default/table_group.py new file mode 100644 index 0000000..29d4e85 --- /dev/null +++ b/pydbml/renderer/dbml/default/table_group.py @@ -0,0 +1,14 @@ +from pydbml.classes import TableGroup +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.table import get_full_name_for_dbml +from pydbml.renderer.dbml.default.utils import comment_to_dbml + + +@DefaultDBMLRenderer.renderer_for(TableGroup) +def render_table_group(model: TableGroup) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'TableGroup {model.name} {{\n' + for i in model.items: + result += f' {get_full_name_for_dbml(i)}\n' + result += '}' + return result diff --git a/pydbml/renderer/dbml/default/utils.py b/pydbml/renderer/dbml/default/utils.py new file mode 100644 index 0000000..60f8544 --- /dev/null +++ b/pydbml/renderer/dbml/default/utils.py @@ -0,0 +1,13 @@ +from pydbml.renderer.dbml.default.note import prepare_text_for_dbml +from pydbml.tools import comment + + +def note_option_to_dbml(note: 'Note') -> str: + if '\n' in note.text: + return f"note: '''{prepare_text_for_dbml(note)}'''" + else: + return f"note: '{prepare_text_for_dbml(note)}'" + + +def comment_to_dbml(val: str) -> str: + return comment(val, '//') diff --git a/pydbml/renderer/sql/__init__.py b/pydbml/renderer/sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/renderer/sql/default/__init__.py b/pydbml/renderer/sql/default/__init__.py new file mode 100644 index 0000000..c3b1e48 --- /dev/null +++ b/pydbml/renderer/sql/default/__init__.py @@ -0,0 +1,8 @@ +from .renderer import DefaultSQLRenderer +from .column import render_column +from .enum import render_enum, render_enum_item +from .expression import render_expression +from .index import render_index +from .note import render_note +from .reference import render_reference +from .table import render_table diff --git a/pydbml/renderer/sql/default/column.py b/pydbml/renderer/sql/default/column.py new file mode 100644 index 0000000..d890ad3 --- /dev/null +++ b/pydbml/renderer/sql/default/column.py @@ -0,0 +1,39 @@ +from pydbml.classes import Column, Enum, Expression +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from .utils import comment_to_sql +from .enum import get_full_name_for_sql as get_full_name_for_sql_enum + + +@DefaultSQLRenderer.renderer_for(Column) +def render_column(model: Column) -> str: + ''' + Returns inline SQL of the column, which should be a part of table definition: + + "id" integer PRIMARY KEY AUTOINCREMENT + ''' + + components = [f'"{model.name}"'] + if isinstance(model.type, Enum): + components.append(get_full_name_for_sql_enum(model.type)) + else: + components.append(str(model.type)) + + table_has_composite_pk = model.table._has_composite_pk() if model.table else False + if model.pk and not table_has_composite_pk: # composite PKs are rendered in table sql + components.append('PRIMARY KEY') + if model.autoinc: + components.append('AUTOINCREMENT') + if model.unique: + components.append('UNIQUE') + if model.not_null: + components.append('NOT NULL') + if model.default is not None: + if isinstance(model.default, Expression): + default = DefaultSQLRenderer.render(model.default) + else: + default = model.default + components.append(f'DEFAULT {default}') + + result = comment_to_sql(model.comment) if model.comment else '' + result += ' '.join(components) + return result diff --git a/pydbml/renderer/sql/default/enum.py b/pydbml/renderer/sql/default/enum.py new file mode 100644 index 0000000..6bc8626 --- /dev/null +++ b/pydbml/renderer/sql/default/enum.py @@ -0,0 +1,33 @@ +from textwrap import indent + +from pydbml._classes.enum import EnumItem +from pydbml.classes import Enum +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql, get_full_name_for_sql + + +@DefaultSQLRenderer.renderer_for(Enum) +def render_enum(model: Enum) -> str: + ''' + Returns SQL for enum type: + + CREATE TYPE "job_status" AS ENUM ( + 'created', + 'running', + 'done', + 'failure', + ); + ''' + + result = comment_to_sql(model.comment) if model.comment else '' + result += f'CREATE TYPE {get_full_name_for_sql(model)} AS ENUM (\n' + result += '\n'.join(f'{indent(DefaultSQLRenderer.render(i), " ")}' for i in model.items) + result += '\n);' + return result + + +@DefaultSQLRenderer.renderer_for(EnumItem) +def render_enum_item(model: EnumItem) -> str: + result = comment_to_sql(model.comment) if model.comment else '' + result += f"'{model.name}'," + return result diff --git a/pydbml/renderer/sql/default/expression.py b/pydbml/renderer/sql/default/expression.py new file mode 100644 index 0000000..b080ee7 --- /dev/null +++ b/pydbml/renderer/sql/default/expression.py @@ -0,0 +1,7 @@ +from pydbml.classes import Expression +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer + + +@DefaultSQLRenderer.renderer_for(Expression) +def render_expression(model: Expression) -> str: + return f'({model.text})' diff --git a/pydbml/renderer/sql/default/index.py b/pydbml/renderer/sql/default/index.py new file mode 100644 index 0000000..3202496 --- /dev/null +++ b/pydbml/renderer/sql/default/index.py @@ -0,0 +1,65 @@ +from typing import Any + +from pydbml.classes import Expression, Index, Column +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql + + +def render_subject(subject: Any) -> str: + if isinstance(subject, Column): + return f'"{subject.name}"' + elif isinstance(subject, Expression): + return DefaultSQLRenderer.render(subject) + else: + return subject + + +def render_pk(model: Index, keys: str) -> str: + result = comment_to_sql(model.comment) if model.comment else '' + result += f'PRIMARY KEY ({keys})' + return result + + +def create_components(model: Index, keys: str) -> str: + components = [] + if model.comment: + components.append(comment_to_sql(model.comment)) + + components.append('CREATE ') + + if model.unique: + components.append('UNIQUE ') + + components.append('INDEX ') + + if model.name: + components.append(f'"{model.name}" ') + if model.table: + components.append(f'ON "{model.table.name}" ') + + if model.type: + components.append(f'USING {model.type.upper()} ') + components.append(f'({keys})') + return ''.join(components) + ';' + + +@DefaultSQLRenderer.renderer_for(Index) +def render_index(model: Index) -> str: + ''' + Returns inline SQL of the index to be created separately from table + definition: + + CREATE UNIQUE INDEX ON "products" USING HASH ("id"); + + But if it's a (composite) primary key index, returns an inline SQL for + composite primary key to be used inside table definition: + + PRIMARY KEY ("id", "name") + ''' + + keys = ', '.join(render_subject(s) for s in model.subjects) + + if model.pk: + return render_pk(model, keys) + + return create_components(model, keys) diff --git a/pydbml/renderer/sql/default/note.py b/pydbml/renderer/sql/default/note.py new file mode 100644 index 0000000..751bfd5 --- /dev/null +++ b/pydbml/renderer/sql/default/note.py @@ -0,0 +1,43 @@ +import re + +from pydbml.classes import Note, Table, Column +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer + + +def prepare_text_for_sql(model: Note) -> str: + ''' + - Process special escape sequence: slash before line break, which means no line break + https://www.dbml.org/docs/#multi-line-string + - replace all single quotes with double quotes + ''' + + pattern = re.compile(r'\\\n') + result = pattern.sub('', model.text) + + result = result.replace("'", '"') + return result + + +def generate_comment_on(model: Note, entity: str, name: str) -> str: + """Generate a COMMENT ON clause out from this note.""" + quoted_text = f"'{prepare_text_for_sql(model)}'" + note_sql = f'COMMENT ON {entity.upper()} "{name}" IS {quoted_text};' + return note_sql + + +@DefaultSQLRenderer.renderer_for(Note) +def render_note(model: Note) -> str: + """ + For Tables and Columns Note is converted into COMMENT ON clause. All other entities don't + have notes generated in their SQL code, but as a fallback their notes are rendered as SQL + comments when sql property is called directly. + """ + + if model.text: + if isinstance(model.parent, (Table, Column)): + return generate_comment_on(model, model.parent.__class__.__name__, model.parent.name) + else: + text = prepare_text_for_sql(model) + return '\n'.join(f'-- {line}' for line in text.split('\n')) + else: + return '' diff --git a/pydbml/renderer/sql/default/reference.py b/pydbml/renderer/sql/default/reference.py new file mode 100644 index 0000000..ee83f11 --- /dev/null +++ b/pydbml/renderer/sql/default/reference.py @@ -0,0 +1,82 @@ +from itertools import chain +from typing import List + +from pydbml.classes import Reference, Column +from pydbml.constants import MANY_TO_MANY, MANY_TO_ONE, ONE_TO_ONE, ONE_TO_MANY +from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql, get_full_name_for_sql + + +def col_names(cols: List[Column]) -> str: + return ', '.join(f'"{c.name}"' for c in cols) + + +def validate_for_sql(model: Reference): + for col in chain(model.col1, model.col2): + if col.table is None: + raise TableNotFoundError(f'Table on {col} is not set') + + +def generate_inline_sql(model: Reference, source_col: List[Column], ref_col: List[Column]) -> str: + result = comment_to_sql(model.comment) if model.comment else '' + result += ( + f'{{c}}FOREIGN KEY ({col_names(source_col)}) ' # type: ignore + f'REFERENCES {get_full_name_for_sql(ref_col[0].table)} ({col_names(ref_col)})' # type: ignore + ) + if model.on_update: + result += f' ON UPDATE {model.on_update.upper()}' + if model.on_delete: + result += f' ON DELETE {model.on_delete.upper()}' + return result + + +def generate_not_inline_sql(model: Reference, source_col: List['Column'], ref_col: List['Column']): + result = comment_to_sql(model.comment) if model.comment else '' + result += ( + f'ALTER TABLE {get_full_name_for_sql(source_col[0].table)}' # type: ignore + f' ADD {{c}}FOREIGN KEY ({col_names(source_col)})' + f' REFERENCES {get_full_name_for_sql(ref_col[0].table)} ({col_names(ref_col)})' # type: ignore + ) + if model.on_update: + result += f' ON UPDATE {model.on_update.upper()}' + if model.on_delete: + result += f' ON DELETE {model.on_delete.upper()}' + return result + ';' + + +def generate_many_to_many_sql(model: Reference) -> str: + join_table = model.join_table + table_sql = join_table.sql # type: ignore + + n = len(model.col1) + ref1_sql = generate_not_inline_sql(model, join_table.columns[:n], model.col1) # type: ignore + ref2_sql = generate_not_inline_sql(model, join_table.columns[n:], model.col2) # type: ignore + + result = '\n\n'.join((table_sql, ref1_sql, ref2_sql)) + return result.format(c='') + + +@DefaultSQLRenderer.renderer_for(Reference) +def render_reference(model: Reference) -> str: + ''' + Returns SQL of the reference: + + ALTER TABLE "orders" ADD FOREIGN KEY ("customer_id") REFERENCES "customers ("id"); + + ''' + validate_for_sql(model) + + if model.type == MANY_TO_MANY: + return generate_many_to_many_sql(model) + + result = '' + func = generate_inline_sql if model.inline else generate_not_inline_sql + if model.type in (MANY_TO_ONE, ONE_TO_ONE): + result = func(model=model, source_col=model.col1, ref_col=model.col2) + elif model.type == ONE_TO_MANY: + result = func(model=model, source_col=model.col2, ref_col=model.col1) + + c = f'CONSTRAINT "{model.name}" ' if model.name else '' + + return result.format(c=c) diff --git a/pydbml/renderer/sql/default/renderer.py b/pydbml/renderer/sql/default/renderer.py new file mode 100644 index 0000000..188d07d --- /dev/null +++ b/pydbml/renderer/sql/default/renderer.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from pydbml.renderer.sql.default.utils import reorder_tables_for_sql +from pydbml.renderer.base import BaseRenderer + + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + + +class DefaultSQLRenderer(BaseRenderer): + model_renderers = {} + + @classmethod + def render(cls, model) -> str: + model.check_attributes_for_sql() + return super().render(model) + + @classmethod + def render_db(cls, db: 'Database') -> str: + refs = (ref for ref in db.refs if not ref.inline) + tables = reorder_tables_for_sql(db.tables, db.refs) + components = (cls.render(i) for i in (*db.enums, *tables, *refs)) + return '\n\n'.join(components) diff --git a/pydbml/renderer/sql/default/table.py b/pydbml/renderer/sql/default/table.py new file mode 100644 index 0000000..d208143 --- /dev/null +++ b/pydbml/renderer/sql/default/table.py @@ -0,0 +1,97 @@ +from textwrap import indent +from typing import List + +from pydbml.constants import MANY_TO_ONE, ONE_TO_ONE, ONE_TO_MANY +from pydbml.classes import Table, Reference, Column +from pydbml.exceptions import UnknownDatabaseError +from pydbml.renderer.sql.default.note import prepare_text_for_sql +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql, get_full_name_for_sql + + +def get_references_for_sql(model: Table) -> List[Reference]: + """ + Return all references in the database where this table is on the left side of SQL + reference definition. + """ + if not model.database: + raise UnknownDatabaseError(f'Database for the table {model} is not set') + result = [] + for ref in model.database.refs: + if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ + (ref.table1 == model): + result.append(ref) + elif (ref.type == ONE_TO_MANY) and (ref.table2 == model): + result.append(ref) + return result + + +def get_inline_references_for_sql(model: Table) -> List[Reference]: + ''' + Return inline references for this table sql definition + ''' + if model.abstract: + return [] + return [r for r in get_references_for_sql(model) if r.inline] + + +def create_body(model: Table) -> str: + body: List[str] = [] + body.extend(indent(DefaultSQLRenderer.render(c), " ") for c in model.columns) + body.extend(indent(DefaultSQLRenderer.render(i), " ") for i in model.indexes if i.pk) + body.extend(indent(DefaultSQLRenderer.render(r), " ") for r in get_inline_references_for_sql(model)) + + if model._has_composite_pk(): + body.append( + " PRIMARY KEY (" + + ', '.join(f'"{c.name}"' for c in model.columns if c.pk) + + ')') + + return ',\n'.join(body) + + +def create_components(model: Table) -> str: + components = [comment_to_sql(model.comment)] if model.comment else [] + components.append(f'CREATE TABLE {get_full_name_for_sql(model)} (') + + body = create_body(model) + + components.append(body) + components.append(');') + components.extend('\n' + DefaultSQLRenderer.render(i) for i in model.indexes if not i.pk) + + return '\n'.join(components) + + +def render_column_notes(model: Table) -> str: + result = '' + for col in model.columns: + if col.note: + quoted_note = f"'{prepare_text_for_sql(col.note)}'" + note_sql = f'COMMENT ON COLUMN "{model.name}"."{col.name}" IS {quoted_note};' + result += f'\n\n{note_sql}' + return result + + +@DefaultSQLRenderer.renderer_for(Table) +def render_table(model: Table) -> str: + ''' + Returns full SQL for table definition: + + CREATE TABLE "countries" ( + "code" int PRIMARY KEY, + "name" varchar, + "continent_name" varchar + ); + + Also returns indexes if they were defined: + + CREATE INDEX ON "products" ("id", "name"); + ''' + result = create_components(model) + + if model.note: + result += f'\n\n{model.note.sql}' + + result += render_column_notes(model) + return result diff --git a/pydbml/renderer/sql/default/utils.py b/pydbml/renderer/sql/default/utils.py new file mode 100644 index 0000000..8befac3 --- /dev/null +++ b/pydbml/renderer/sql/default/utils.py @@ -0,0 +1,37 @@ +from typing import List, Dict, Union + +from pydbml.classes import Enum, Reference, Table +from pydbml.constants import MANY_TO_ONE, ONE_TO_MANY +from pydbml.tools import comment + + +def comment_to_sql(val: str) -> str: + return comment(val, '--') + + +def reorder_tables_for_sql(tables: List['Table'], refs: List['Reference']) -> List['Table']: + """ + Attempt to reorder the tables, so that they are defined in SQL before they are referenced by + inline foreign keys. + + Won't aid the rare cases of cross-references and many-to-many relations. + """ + + references: Dict[str, int] = {} + for ref in refs: + if ref.inline: + if ref.type == MANY_TO_ONE and ref.table1 is not None: + table_name = ref.table1.name + elif ref.type == ONE_TO_MANY and ref.table2 is not None: + table_name = ref.table2.name + else: # pragma: no cover + continue + references[table_name] = references.get(table_name, 0) + 1 + return sorted(tables, key=lambda t: references.get(t.name, 0), reverse=True) + + +def get_full_name_for_sql(model: Union[Table, Enum]) -> str: + if model.schema == 'public': + return f'"{model.name}"' + else: + return f'"{model.schema}"."{model.name}"' diff --git a/pydbml/tools.py b/pydbml/tools.py index fe08cf1..4892cbc 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -2,28 +2,13 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: # pragma: no cover - from .classes import Note + pass def comment(val: str, comb: str) -> str: return '\n'.join(f'{comb} {cl}' for cl in val.split('\n')) + '\n' -def comment_to_dbml(val: str) -> str: - return comment(val, '//') - - -def comment_to_sql(val: str) -> str: - return comment(val, '--') - - -def note_option_to_dbml(note: 'Note') -> str: - if '\n' in note.text: - return f"note: '''{note._prepare_text_for_dbml()}'''" - else: - return f"note: '{note._prepare_text_for_dbml()}'" - - def indent(val: str, spaces=4) -> str: if val == '': return val diff --git a/setup.py b/setup.py index 634e0b7..12c1887 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ -from setuptools import setup - +from setuptools import setup, find_packages SHORT_DESCRIPTION = 'Python parser and builder for DBML' @@ -21,14 +20,12 @@ author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', - packages=['pydbml', 'pydbml.classes', 'pydbml.definitions', 'pydbml.parser'], + packages=find_packages(exclude=['test', 'test.*']), license='MIT', platforms='any', - install_requires=[ - 'pyparsing>=3.0.0', - ], + install_requires=['pyparsing>=3.0.0'], classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Environment :: Console", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..d3ba83f --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,137 @@ +from textwrap import dedent + +import pytest + +from pydbml import Database +from pydbml._classes.reference import Reference +from pydbml._classes.sticky_note import StickyNote +from pydbml.classes import Column, Enum, EnumItem, Note, Expression, Table, Index + + +@pytest.fixture +def db(): + return Database() + + +@pytest.fixture +def enum_item1(): + return EnumItem('en-US') + + +@pytest.fixture +def enum1(): + return Enum('product status', ('production', 'development')) + + +@pytest.fixture +def expression1() -> Expression: + return Expression('SUM(amount)') + + +@pytest.fixture +def simple_column() -> Column: + return Column( + name='id', + type='integer' + ) + + +@pytest.fixture +def simple_column_with_table(db: Database, table1: Table, simple_column: Column) -> Column: + table1.add_column(simple_column) + db.add(table1) + return simple_column + + +@pytest.fixture +def complex_column(enum1: Enum) -> Column: + return Column( + name='counter', + type=enum1, + pk=True, + autoinc=True, + unique=True, + not_null=True, + default=0, + comment='This is a counter column', + note=Note('This is a note for the column') + ) + + +@pytest.fixture +def complex_column_with_table(db: Database, table1: Table, complex_column: Column) -> Column: + table1.add_column(complex_column) + db.add(table1) + return complex_column + + +@pytest.fixture +def table1() -> Table: + return Table( + name='products', + columns=[ + Column('id', 'integer'), + Column('name', 'varchar'), + ] + ) + + +@pytest.fixture +def table2() -> Table: + return Table( + name='products', + columns=[ + Column('id', 'integer'), + Column('name', 'varchar'), + ] + ) + + +@pytest.fixture +def table3() -> Table: + return Table( + name='orders', + columns=[ + Column('id', 'integer'), + Column('product_id', 'integer'), + Column('price', 'float'), + ] + ) + +@pytest.fixture +def reference1(table2: Table, table3: Table) -> Reference: + return Reference( + type='>', + col1=[table3.columns[1]], + col2=[table2.columns[0]], + ) + + +@pytest.fixture +def index1(table1: Table) -> Index: + result = Index( + subjects=[table1.columns[1]] + ) + table1.add_index(result) + return result + + +@pytest.fixture +def note1(): + return Note('Simple note') + + +@pytest.fixture +def sticky_note1(): + return StickyNote(name='mynote', text='Simple note') + + +@pytest.fixture +def multiline_note(): + return Note( + dedent( + '''\ + This is a multiline note. + It has multiple lines.''' + ) + ) diff --git a/test/test_blueprints/test_sticky_note.py b/test/test_blueprints/test_sticky_note.py index 5cdcd9c..1f423a5 100644 --- a/test/test_blueprints/test_sticky_note.py +++ b/test/test_blueprints/test_sticky_note.py @@ -1,6 +1,6 @@ from unittest import TestCase -from pydbml.classes.sticky_note import StickyNote +from pydbml._classes.sticky_note import StickyNote from pydbml.parser.blueprints import StickyNoteBlueprint class TestNote(TestCase): diff --git a/test/test_classes/test_base.py b/test/test_classes/test_base.py index 4305665..ebfae0e 100644 --- a/test/test_classes/test_base.py +++ b/test/test_classes/test_base.py @@ -1,6 +1,6 @@ from unittest import TestCase -from pydbml.classes.base import SQLObject +from pydbml._classes.base import SQLObject from pydbml.exceptions import AttributeMissingError diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py index cd1975d..dca0cd8 100644 --- a/test/test_classes/test_column.py +++ b/test/test_classes/test_column.py @@ -1,11 +1,9 @@ from unittest import TestCase from pydbml.classes import Column -from pydbml.classes import Expression from pydbml.classes import Note from pydbml.classes import Reference from pydbml.classes import Table -from pydbml.classes import Enum from pydbml.database import Database from pydbml.exceptions import TableNotFoundError @@ -52,27 +50,6 @@ def test_database_set(self) -> None: database.add(table) self.assertIs(col.database, database) - def test_basic_sql(self) -> None: - r = Column(name='id', - type='integer') - expected = '"id" integer' - self.assertEqual(r.sql, expected) - - def test_sql_enum_type(self) -> None: - et = Enum('product status', ('production', 'development')) - db = Database() - db.add_enum(et) - r = Column(name='id', - type=et, - pk=True, - autoinc=True) - expected = '"id" "product status" PRIMARY KEY AUTOINCREMENT' - self.assertEqual(r.sql, expected) - - et.schema = 'myschema' - expected = '"id" "myschema"."product status" PRIMARY KEY AUTOINCREMENT' - self.assertEqual(r.sql, expected) - def test_pk_autoinc(self) -> None: r = Column(name='id', type='integer', @@ -107,116 +84,6 @@ def test_comment(self) -> None: "id" integer UNIQUE NOT NULL''' self.assertEqual(r.sql, expected) - def test_dbml_simple(self): - c = Column( - name='order', - type='integer' - ) - t = Table(name='Test') - t.add_column(c) - s = Database() - s.add(t) - expected = '"order" integer' - - self.assertEqual(c.dbml, expected) - - def test_dbml_enum_type(self) -> None: - et = Enum('product status', ('production', 'development')) - db = Database() - db.add_enum(et) - r = Column(name='id', - type=et, - pk=True, - autoinc=True) - t = Table('products') - t.add_column(r) - db.add_table(t) - expected = '"id" "product status" [pk, increment]' - self.assertEqual(r.dbml, expected) - - et.schema = 'myschema' - expected = '"id" "myschema"."product status" [pk, increment]' - self.assertEqual(r.dbml, expected) - - def test_dbml_full(self): - c = Column( - name='order', - type='integer', - unique=True, - not_null=True, - pk=True, - autoinc=True, - default='Def_value', - note='Note on the column', - comment='Comment on the column' - ) - t = Table(name='Test') - t.add_column(c) - s = Database() - s.add(t) - expected = \ -'''// Comment on the column -"order" integer [pk, increment, default: 'Def_value', unique, not null, note: 'Note on the column']''' - - self.assertEqual(c.dbml, expected) - - def test_dbml_multiline_note(self): - c = Column( - name='order', - type='integer', - not_null=True, - note='Note on the column\nmultiline', - comment='Comment on the column' - ) - t = Table(name='Test') - t.add_column(c) - s = Database() - s.add(t) - expected = \ -"""// Comment on the column -"order" integer [not null, note: '''Note on the column -multiline''']""" - - self.assertEqual(c.dbml, expected) - - def test_dbml_default(self): - c = Column( - name='order', - type='integer', - default='String value' - ) - t = Table(name='Test') - t.add_column(c) - s = Database() - s.add(t) - - expected = "\"order\" integer [default: 'String value']" - self.assertEqual(c.dbml, expected) - - c.default = 3 - expected = '"order" integer [default: 3]' - self.assertEqual(c.dbml, expected) - - c.default = 3.33 - expected = '"order" integer [default: 3.33]' - self.assertEqual(c.dbml, expected) - - c.default = Expression("now() - interval '5 days'") - expected = "\"order\" integer [default: `now() - interval '5 days'`]" - self.assertEqual(c.dbml, expected) - - c.default = 'NULL' - expected = '"order" integer [default: null]' - self.assertEqual(c.dbml, expected) - - c.default = 'TRue' - expected = '"order" integer [default: true]' - self.assertEqual(c.dbml, expected) - - c.default = 'false' - expected = '"order" integer [default: false]' - self.assertEqual(c.dbml, expected) - def test_database(self): c1 = Column(name='client_id', type='integer') t1 = Table(name='products') @@ -246,52 +113,54 @@ def test_get_refs(self) -> None: self.assertEqual(c1.get_refs(), [ref]) - def test_dbml_with_ref(self) -> None: - c1 = Column(name='client_id', type='integer') - t1 = Table(name='products') - t1.add_column(c1) - c2 = Column(name='id', type='integer', autoinc=True, pk=True) - t2 = Table(name='clients') - t2.add_column(c2) - - ref = Reference(type='>', col1=c1, col2=c2) - s = Database() - s.add(t1) - s.add(t2) - s.add(ref) - - expected = '"client_id" integer' - self.assertEqual(c1.dbml, expected) - ref.inline = True - expected = '"client_id" integer [ref: > "clients"."id"]' - self.assertEqual(c1.dbml, expected) - expected = '"id" integer [pk, increment]' - self.assertEqual(c2.dbml, expected) - - def test_dbml_with_ref_and_properties(self) -> None: - c1 = Column(name='client_id', type='integer') - t1 = Table(name='products') - t1.add_column(c1) - c2 = Column(name='id', type='integer', autoinc=True, pk=True) - t2 = Table(name='clients') - t2.add_column(c2) - - ref = Reference(type='<', col1=c2, col2=c1) - s = Database() - s.add(t1) - s.add(t2) - s.add(ref) - - expected = '"id" integer [pk, increment]' - self.assertEqual(c2.dbml, expected) - ref.inline = True - expected = '"id" integer [ref: < "products"."client_id", pk, increment]' - self.assertEqual(c2.dbml, expected) - expected = '"client_id" integer' - self.assertEqual(c1.dbml, expected) - def test_note_property(self): note1 = Note('column note') c1 = Column(name='client_id', type='integer') c1.note = note1 self.assertIs(c1.note.parent, c1) + + +class TestEqual: + @staticmethod + def test_other_type() -> None: + c1 = Column('name', 'VARCHAR2') + assert c1 != 'name' + + @staticmethod + def test_different_tables() -> None: + t1 = Table('table1', columns=[Column('name', 'VARCHAR2')]) + t2 = Table('table2', columns=[Column('name', 'VARCHAR2')]) + assert t1.columns[0] != t2.columns[0] + + @staticmethod + def test_same_table() -> None: + t1 = Table('table1', columns=[Column('name', 'VARCHAR2')]) + t2 = Table('table1', columns=[Column('name', 'VARCHAR2')]) + assert t1.columns[0] == t2.columns[0] + + @staticmethod + def test_same_column() -> None: + c1 = Column('name', 'VARCHAR2') + assert c1 == c1 + + @staticmethod + def test_table_not_set() -> None: + c1 = Column('name', 'VARCHAR2') + c2 = Column('name', 'VARCHAR2') + assert c1 == c2 + + @staticmethod + def test_ont_table_not_set() -> None: + c1 = Column('name', 'VARCHAR2') + c2 = Column('name', 'VARCHAR2') + t1 = Table('table1') + c1.table = t1 + assert c1 != c2 + + c1.table, c2.table = None, t1 + assert c1 != c2 + + +def test_repr() -> None: + c1 = Column('name', 'VARCHAR2') + assert repr(c1) == "" \ No newline at end of file diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py index e3c8bb0..fb053fa 100644 --- a/test/test_classes/test_enum.py +++ b/test/test_classes/test_enum.py @@ -5,23 +5,6 @@ class TestEnumItem(TestCase): - def test_dbml_simple(self): - ei = EnumItem('en-US') - expected = '"en-US"' - self.assertEqual(ei.dbml, expected) - - def test_sql(self): - ei = EnumItem('en-US') - expected = "'en-US'," - self.assertEqual(ei.sql, expected) - - def test_dbml_full(self): - ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') - expected = \ -'''// EnumItem comment -"en-US" [note: 'preferred']''' - self.assertEqual(ei.dbml, expected) - def test_note_property(self): note1 = Note('enum item note') ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') @@ -30,100 +13,6 @@ def test_note_property(self): class TestEnum(TestCase): - def test_simple_enum(self) -> None: - items = [ - EnumItem('created'), - EnumItem('running'), - EnumItem('donef'), - EnumItem('failure'), - ] - e = Enum('job_status', items) - expected = \ -'''CREATE TYPE "job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', -);''' - self.assertEqual(e.sql, expected) - - def test_schema(self) -> None: - items = [ - EnumItem('created'), - EnumItem('running'), - EnumItem('donef'), - EnumItem('failure'), - ] - e = Enum('job_status', items, schema="myschema") - expected = \ -'''CREATE TYPE "myschema"."job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', -);''' - self.assertEqual(e.sql, expected) - - def test_comments(self) -> None: - items = [ - EnumItem('created', comment='EnumItem comment'), - EnumItem('running'), - EnumItem('donef', comment='EnumItem\nmultiline comment'), - EnumItem('failure'), - ] - e = Enum('job_status', items, comment='Enum comment') - expected = \ -'''-- Enum comment -CREATE TYPE "job_status" AS ENUM ( - -- EnumItem comment - 'created', - 'running', - -- EnumItem - -- multiline comment - 'donef', - 'failure', -);''' - self.assertEqual(e.sql, expected) - - def test_dbml_simple(self): - items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] - e = Enum('lang', items) - expected = \ -'''Enum "lang" { - "en-US" - "ru-RU" - "en-GB" -}''' - self.assertEqual(e.dbml, expected) - - def test_dbml_schema(self): - items = [EnumItem('en-US'), EnumItem('ru-RU'), EnumItem('en-GB')] - e = Enum('lang', items, schema="myschema") - expected = \ -'''Enum "myschema"."lang" { - "en-US" - "ru-RU" - "en-GB" -}''' - self.assertEqual(e.dbml, expected) - - def test_dbml_full(self): - items = [ - EnumItem('en-US', note='preferred'), - EnumItem('ru-RU', comment='Multiline\ncomment'), - EnumItem('en-GB')] - e = Enum('lang', items, comment="Enum comment") - expected = \ -'''// Enum comment -Enum "lang" { - "en-US" [note: 'preferred'] - // Multiline - // comment - "ru-RU" - "en-GB" -}''' - self.assertEqual(e.dbml, expected) - def test_getitem(self) -> None: ei = EnumItem('created') items = [ @@ -154,3 +43,12 @@ def test_iter(self) -> None: for i1, i2 in zip(e, [ei1, ei2, ei3, ei4]): self.assertIs(i1, i2) + + +def test_repr(enum_item1: EnumItem) -> None: + assert repr(enum_item1) == "" + + +def test_str() -> None: + ei = EnumItem('en-US') + assert str(ei) == 'en-US' diff --git a/test/test_classes/test_expression.py b/test/test_classes/test_expression.py index e61fdf3..171ac9e 100644 --- a/test/test_classes/test_expression.py +++ b/test/test_classes/test_expression.py @@ -3,11 +3,9 @@ from pydbml.classes import Expression -class TestNote(TestCase): - def test_sql(self): - e = Expression('SUM(amount)') - self.assertEqual(e.sql, '(SUM(amount))') - - def test_dbml(self): - e = Expression('SUM(amount)') - self.assertEqual(e.dbml, '`SUM(amount)`') +def test_str(expression1: Expression) -> None: + assert str(expression1) == 'SUM(amount)' + + +def test_repr(expression1: Expression) -> None: + assert repr(expression1) == "Expression('SUM(amount)')" diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py index 9710d2a..25db20d 100644 --- a/test/test_classes/test_index.py +++ b/test/test_classes/test_index.py @@ -1,139 +1,21 @@ -from unittest import TestCase - from pydbml.classes import Column -from pydbml.classes import Expression from pydbml.classes import Index from pydbml.classes import Note from pydbml.classes import Table -from pydbml.exceptions import ColumnNotFoundError - - -class TestIndex(TestCase): - def test_basic_sql(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subjects=[t.columns[0]]) - t.add_index(r) - self.assertIs(r.table, t) - expected = 'CREATE INDEX ON "products" ("id");' - self.assertEqual(r.sql, expected) - - def test_basic_sql_str(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subjects=['id']) - t.add_index(r) - self.assertIs(r.table, t) - expected = 'CREATE INDEX ON "products" (id);' - self.assertEqual(r.sql, expected) - - def test_column_not_in_table(self) -> None: - t = Table('products') - c = Column('id', 'integer') - i = Index(subjects=[c]) - with self.assertRaises(ColumnNotFoundError): - t.add_index(i) - self.assertIsNone(i.table) - t2 = Table('customers') - t2.add_column(c) - i2 = Index(subjects=[c]) - with self.assertRaises(ColumnNotFoundError): - t.add_index(i2) - self.assertIsNone(i2.table) - - def test_comment(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subjects=[t.columns[0]], - comment='Index comment') - t.add_index(r) - self.assertIs(r.table, t) - expected = \ -'''-- Index comment -CREATE INDEX ON "products" ("id");''' - - self.assertEqual(r.sql, expected) - - def test_unique_type_composite(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - t.add_column(Column('name', 'varchar')) - r = Index( - subjects=[ - t.columns[0], - t.columns[1] - ], - type='hash', - unique=True - ) - t.add_index(r) - expected = 'CREATE UNIQUE INDEX ON "products" USING HASH ("id", "name");' - self.assertEqual(r.sql, expected) - - def test_pk(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - t.add_column(Column('name', 'varchar')) - r = Index( - subjects=[ - t.columns[0], - t.columns[1] - ], - pk=True - ) - t.add_index(r) - expected = 'PRIMARY KEY ("id", "name")' - self.assertEqual(r.sql, expected) - - def test_composite_with_expression(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subjects=[t.columns[0], Expression('id*3')]) - t.add_index(r) - self.assertEqual(r.subjects, [t['id'], Expression('id*3')]) - expected = 'CREATE INDEX ON "products" ("id", (id*3));' - self.assertEqual(r.sql, expected) - def test_dbml_simple(self): - t = Table('products') - t.add_column(Column('id', 'integer')) - i = Index(subjects=[t.columns[0]]) - t.add_index(i) - expected = 'id' - self.assertEqual(i.dbml, expected) +def test_note_property(): + note1 = Note('column note') + t = Table('products') + c = Column('id', 'integer') + i = Index(subjects=[c]) + i.note = note1 + assert i.note.parent is i - def test_dbml_composite(self): - t = Table('products') - t.add_column(Column('id', 'integer')) - i = Index(subjects=[t.columns[0], Expression('id*3')]) - t.add_index(i) - expected = '(id, `id*3`)' - self.assertEqual(i.dbml, expected) +def test_repr(index1: Index) -> None: + assert repr(index1) == "" - def test_dbml_full(self): - t = Table('products') - t.add_column(Column('id', 'integer')) - i = Index( - subjects=[t.columns[0], Expression('getdate()')], - name='Dated id', - unique=True, - type='hash', - pk=True, - note='Note on the column', - comment='Comment on the index' - ) - t.add_index(i) - expected = \ -'''// Comment on the index -(id, `getdate()`) [name: 'Dated id', pk, unique, type: hash, note: 'Note on the column']''' - self.assertEqual(i.dbml, expected) - def test_note_property(self): - note1 = Note('column note') - t = Table('products') - c = Column('id', 'integer') - i = Index(subjects=[c]) - i.note = note1 - self.assertIs(i.note.parent, i) +def test_str(index1: Index) -> None: + assert str(index1) == 'Index(products[name])' diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py index 8ff2a9d..00b700a 100644 --- a/test/test_classes/test_note.py +++ b/test/test_classes/test_note.py @@ -1,104 +1,23 @@ from pydbml.classes import Note -from pydbml.classes import Table -from pydbml.classes import Index -from pydbml.classes import Column -from unittest import TestCase -class TestNote(TestCase): - def test_init_types(self): - n1 = Note('My note text') - n2 = Note(3) - n3 = Note([1, 2, 3]) - n4 = Note(None) - n5 = Note(n1) +def test_init_types(): + n1 = Note('My note text') + n2 = Note(3) + n3 = Note([1, 2, 3]) + n4 = Note(None) + n5 = Note(n1) - self.assertEqual(n1.text, 'My note text') - self.assertEqual(n2.text, '3') - self.assertEqual(n3.text, '[1, 2, 3]') - self.assertEqual(n4.text, '') - self.assertEqual(n5.text, 'My note text') + assert n1.text == 'My note text' + assert n2.text == '3' + assert n3.text == '[1, 2, 3]' + assert n4.text == '' + assert n5.text == 'My note text' - def test_oneline(self): - note = Note('One line of note text') - expected = \ -'''Note { - 'One line of note text' -}''' - self.assertEqual(note.dbml, expected) - def test_forced_multiline(self): - note = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.') - expected = \ -"""Note { - ''' - The number of spaces you use to indent a block string - will - be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output. - ''' -}""" - self.assertEqual(note.dbml, expected) +def test_str(note1: Note) -> None: + assert str(note1) == 'Simple note' - def test_sql_general(self) -> None: - note1 = Note(None) - self.assertEqual(note1.sql, '') - note2 = Note('One line of note text') - self.assertEqual(note2.sql, '-- One line of note text') - note3 = Note('The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.') - expected = \ -"""-- The number of spaces you use to indent a block string --- will --- be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.""" - self.assertEqual(note3.sql, expected) - def test_sql_table(self) -> None: - table = Table(name="test") - note1 = Note(None) - table.note = note1 - self.assertEqual(note1.sql, '') - note2 = Note('One line of note text') - table.note = note2 - self.assertEqual(note2.sql, 'COMMENT ON TABLE "test" IS \'One line of note text\';') - - def test_sql_column(self) -> None: - column = Column(name="test", type="int") - note1 = Note(None) - column.note = note1 - self.assertEqual(note1.sql, '') - note2 = Note('One line of note text') - column.note = note2 - self.assertEqual(note2.sql, 'COMMENT ON COLUMN "test" IS \'One line of note text\';') - - def test_sql_index(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - index = Index(subjects=[t.columns[0]]) - - note1 = Note(None) - index.note = note1 - self.assertEqual(note1.sql, '') - note2 = Note('One line of note text') - index.note = note2 - self.assertEqual(note2.sql, '-- One line of note text') - - def test_prepare_text_for_sql(self): - line_escape = 'This text \\\nis not split \\\ninto lines' - quotes = "'asd' There's ''' asda '''' asd ''''' asdsa ''" - - note = Note(line_escape) - expected = 'This text is not split into lines' - self.assertEqual(note._prepare_text_for_sql(), expected) - - note = Note(quotes) - expected = '"asd" There"s """ asda """" asd """"" asdsa ""' - self.assertEqual(note._prepare_text_for_sql(), expected) - - def test_prepare_text_for_dbml(self): - quotes = "'asd' There's ''' asda '''' asd ''''' asdsa ''" - expected = "\\'asd\\' There\\'s \\''' asda \\'''\\' asd \\'''\\'\\' asdsa \\'\\'" - note = Note(quotes) - self.assertEqual(note._prepare_text_for_dbml(), expected) - - def test_escaped_newline_sql(self) -> None: - note = Note('One line of note text \\\nstill one line') - self.assertEqual(note.sql, '-- One line of note text still one line') +def test_repr(note1: Note) -> None: + assert repr(note1) == "Note('Simple note')" diff --git a/test/test_classes/test_project.py b/test/test_classes/test_project.py index d77537a..3442b3a 100644 --- a/test/test_classes/test_project.py +++ b/test/test_classes/test_project.py @@ -1,56 +1,15 @@ -from pydbml.classes import Project from pydbml.classes import Note - -from unittest import TestCase +from pydbml.classes import Project -class TestProject(TestCase): - def test_dbml_note(self): - p = Project('myproject', note='Project note') - expected = \ -'''Project "myproject" { - Note { - 'Project note' - } -}''' - self.assertEqual(p.dbml, expected) +def test_note_property(): + note1 = Note('column note') + p = Project('myproject') + p.note = note1 + assert p.note.parent is p - def test_dbml_full(self): - p = Project( - 'myproject', - items={ - 'database_type': 'PostgreSQL', - 'story': "One day I was eating my cantaloupe and\nI thought, why shouldn't I?\nWhy shouldn't I create a database?" - }, - comment='Multiline\nProject comment', - note='Multiline\nProject note') - expected = \ -"""// Multiline -// Project comment -Project "myproject" { - database_type: 'PostgreSQL' - story: '''One day I was eating my cantaloupe and - I thought, why shouldn't I? - Why shouldn't I create a database?''' - Note { - ''' - Multiline - Project note - ''' - } -}""" - self.assertEqual(p.dbml, expected) - def test_dbml_space(self) -> None: - p = Project('My project', {'a': 'b'}) - expected = \ -'''Project "My project" { - a: 'b' -}''' - self.assertEqual(p.dbml, expected) +def test_repr() -> None: + project = Project('myproject') + assert repr(project) == "" - def test_note_property(self): - note1 = Note('column note') - p = Project('myproject') - p.note = note1 - self.assertIs(p.note.parent, p) diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py index d00e698..5e7cc14 100644 --- a/test/test_classes/test_reference.py +++ b/test/test_classes/test_reference.py @@ -5,21 +5,10 @@ from pydbml.classes import Table from pydbml.exceptions import DBMLError from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.sql.default.reference import validate_for_sql class TestReference(TestCase): - def test_sql_single(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('>', c1, c2) - - expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val");' - self.assertEqual(ref.sql, expected) - def test_table1(self): t = Table('products') c1 = Column('name', 'varchar2') @@ -31,169 +20,6 @@ def test_table1(self): t.add_column(c1) self.assertIs(ref.table1, t) - def test_sql_schema_single(self): - t = Table('products', schema='myschema1') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names', schema='myschema2') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('>', c1, c2) - - expected = 'ALTER TABLE "myschema1"."products" ADD FOREIGN KEY ("name") REFERENCES "myschema2"."names" ("name_val");' - self.assertEqual(ref.sql, expected) - - def test_sql_reverse(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('<', c1, c2) - - expected = 'ALTER TABLE "names" ADD FOREIGN KEY ("name_val") REFERENCES "products" ("name");' - self.assertEqual(ref.sql, expected) - - def test_sql_multiple(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('>', [c11, c12], (c21, c22)) - - expected = 'ALTER TABLE "products" ADD FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val");' - self.assertEqual(ref.sql, expected) - - def test_sql_schema_multiple(self): - t = Table('products', schema="myschema1") - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names', schema="myschema2") - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('<', [c11, c12], (c21, c22)) - - expected = 'ALTER TABLE "myschema2"."names" ADD FOREIGN KEY ("name_val", "country_val") REFERENCES "myschema1"."products" ("name", "country");' - self.assertEqual(ref.sql, expected) - - def test_sql_full(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference( - '>', - [c11, c12], - (c21, c22), - name="country_name", - comment="Multiline\ncomment for the constraint", - on_update="CASCADE", - on_delete="SET NULL" - ) - - expected = \ -'''-- Multiline --- comment for the constraint -ALTER TABLE "products" ADD CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL;''' - - self.assertEqual(ref.sql, expected) - - def test_many_to_many_sql_simple(self) -> None: - t1 = Table('books') - c11 = Column('id', 'integer', pk=True) - c12 = Column('author', 'varchar') - t1.add_column(c11) - t1.add_column(c12) - t2 = Table('authors') - c21 = Column('id', 'integer', pk=True) - c22 = Column('name', 'varchar') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('<>', c11, c21) - - expected = \ -'''CREATE TABLE "books_authors" ( - "books_id" integer NOT NULL, - "authors_id" integer NOT NULL, - PRIMARY KEY ("books_id", "authors_id") -); - -ALTER TABLE "books_authors" ADD FOREIGN KEY ("books_id") REFERENCES "books" ("id"); - -ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id") REFERENCES "authors" ("id");''' - self.assertEqual(expected, ref.sql) - - def test_many_to_many_sql_composite(self) -> None: - t1 = Table('books') - c11 = Column('id', 'integer', pk=True) - c12 = Column('author', 'varchar') - t1.add_column(c11) - t1.add_column(c12) - t2 = Table('authors') - c21 = Column('id', 'integer', pk=True) - c22 = Column('name', 'varchar') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('<>', [c11, c12], [c21, c22]) - - expected = \ -'''CREATE TABLE "books_authors" ( - "books_id" integer NOT NULL, - "books_author" varchar NOT NULL, - "authors_id" integer NOT NULL, - "authors_name" varchar NOT NULL, - PRIMARY KEY ("books_id", "books_author", "authors_id", "authors_name") -); - -ALTER TABLE "books_authors" ADD FOREIGN KEY ("books_id", "books_author") REFERENCES "books" ("id", "author"); - -ALTER TABLE "books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "authors" ("id", "name");''' - self.assertEqual(expected, ref.sql) - - def test_many_to_many_sql_composite_different_schemas(self) -> None: - t1 = Table('books', schema="schema1") - c11 = Column('id', 'integer', pk=True) - c12 = Column('author', 'varchar') - t1.add_column(c11) - t1.add_column(c12) - t2 = Table('authors', schema="schema2") - c21 = Column('id', 'integer', pk=True) - c22 = Column('name', 'varchar') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('<>', [c11, c12], [c21, c22]) - - expected = \ -'''CREATE TABLE "schema1"."books_authors" ( - "books_id" integer NOT NULL, - "books_author" varchar NOT NULL, - "authors_id" integer NOT NULL, - "authors_name" varchar NOT NULL, - PRIMARY KEY ("books_id", "books_author", "authors_id", "authors_name") -); - -ALTER TABLE "schema1"."books_authors" ADD FOREIGN KEY ("books_id", "books_author") REFERENCES "schema1"."books" ("id", "author"); - -ALTER TABLE "schema1"."books_authors" ADD FOREIGN KEY ("authors_id", "authors_name") REFERENCES "schema2"."authors" ("id", "name");''' - self.assertEqual(expected, ref.sql) - def test_join_table(self) -> None: t1 = Table('books') c11 = Column('id', 'integer', pk=True) @@ -235,248 +61,8 @@ def test_join_table_none(self) -> None: with self.assertRaises(TableNotFoundError): ref.join_table - def test_dbml_simple(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - ref = Reference('>', c2, c21) - - expected = \ -'''Ref { - "products"."name" > "names"."name_val" -}''' - self.assertEqual(ref.dbml, expected) - - def test_dbml_schema(self): - t = Table('products', schema="myschema1") - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names', schema="myschema2") - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - ref = Reference('>', c2, c21) - - expected = \ -'''Ref { - "myschema1"."products"."name" > "myschema2"."names"."name_val" -}''' - self.assertEqual(ref.dbml, expected) - - def test_dbml_full(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - c3 = Column('country', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t.add_column(c3) - t2 = Table('names', schema="myschema") - c21 = Column('name_val', 'varchar2') - c22 = Column('country', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference( - '<', - [c2, c3], - (c21, c22), - name='nameref', - comment='Reference comment\nmultiline', - on_update='CASCADE', - on_delete='SET NULL' - ) - - expected = \ -'''// Reference comment -// multiline -Ref nameref { - "products".("name", "country") < "myschema"."names".("name_val", "country") [update: CASCADE, delete: SET NULL] -}''' - self.assertEqual(ref.dbml, expected) - class TestReferenceInline(TestCase): - def test_sql_single(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('>', c1, c2, inline=True) - - expected = 'FOREIGN KEY ("name") REFERENCES "names" ("name_val")' - self.assertEqual(ref.sql, expected) - - def test_sql_schema_single(self): - t = Table('products', schema="myschema1") - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names', schema="myschema2") - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('>', c1, c2, inline=True) - - expected = 'FOREIGN KEY ("name") REFERENCES "myschema2"."names" ("name_val")' - self.assertEqual(ref.sql, expected) - - def test_sql_reverse(self): - t = Table('products') - c1 = Column('name', 'varchar2') - t.add_column(c1) - t2 = Table('names') - c2 = Column('name_val', 'varchar2') - t2.add_column(c2) - ref = Reference('<', c1, c2, inline=True) - - expected = 'FOREIGN KEY ("name_val") REFERENCES "products" ("name")' - self.assertEqual(ref.sql, expected) - - def test_sql_multiple(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('>', [c11, c12], (c21, c22), inline=True) - - expected = 'FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val")' - self.assertEqual(ref.sql, expected) - - def test_sql_schema_multiple(self): - t = Table('products', schema="myschema1") - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names', schema="myschema2") - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference('<', [c11, c12], (c21, c22), inline=True) - - expected = 'FOREIGN KEY ("name_val", "country_val") REFERENCES "myschema1"."products" ("name", "country")' - self.assertEqual(ref.sql, expected) - - def test_sql_full(self): - t = Table('products') - c11 = Column('name', 'varchar2') - c12 = Column('country', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference( - '>', - [c11, c12], - (c21, c22), - name="country_name", - comment="Multiline\ncomment for the constraint", - on_update="CASCADE", - on_delete="SET NULL", - inline=True - ) - - expected = \ -'''-- Multiline --- comment for the constraint -CONSTRAINT "country_name" FOREIGN KEY ("name", "country") REFERENCES "names" ("name_val", "country_val") ON UPDATE CASCADE ON DELETE SET NULL''' - - self.assertEqual(ref.sql, expected) - - def test_dbml_simple(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - ref = Reference('>', c2, c21, inline=True) - - expected = 'ref: > "names"."name_val"' - self.assertEqual(ref.dbml, expected) - - def test_dbml_schema(self): - t = Table('products', schema="myschema1") - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names', schema="myschema2") - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - ref = Reference('>', c2, c21, inline=True) - - expected = 'ref: > "myschema2"."names"."name_val"' - self.assertEqual(ref.dbml, expected) - - def test_dbml_settings_ignored(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - ref = Reference( - '<', - c2, - c21, - name='nameref', - comment='Reference comment\nmultiline', - on_update='CASCADE', - on_delete='SET NULL', - inline=True - ) - - expected = 'ref: < "names"."name_val"' - self.assertEqual(ref.dbml, expected) - - def test_dbml_composite_inline_ref_forbidden(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - c3 = Column('country', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t.add_column(c3) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - c22 = Column('country', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - ref = Reference( - '<', - [c2, c3], - (c21, c22), - name='nameref', - comment='Reference comment\nmultiline', - on_update='CASCADE', - on_delete='SET NULL', - inline=True - ) - - with self.assertRaises(DBMLError): - ref.dbml - def test_validate_different_tables(self): t1 = Table('products') c11 = Column('id', 'integer') @@ -523,11 +109,11 @@ def test_validate_no_table(self): c2 ) with self.assertRaises(TableNotFoundError): - ref1._validate_for_sql() + validate_for_sql(ref1) table = Table('name') table.add_column(c1) with self.assertRaises(TableNotFoundError): - ref1._validate_for_sql() + validate_for_sql(ref1) table.delete_column(c1) ref2 = Reference( @@ -536,9 +122,9 @@ def test_validate_no_table(self): [c3, c4] ) with self.assertRaises(TableNotFoundError): - ref2._validate_for_sql() + validate_for_sql(ref2) table = Table('name') table.add_column(c1) table.add_column(c2) with self.assertRaises(TableNotFoundError): - ref2._validate_for_sql() + validate_for_sql(ref2) diff --git a/test/test_classes/test_sticky_note.py b/test/test_classes/test_sticky_note.py index 1727e83..24acb44 100644 --- a/test/test_classes/test_sticky_note.py +++ b/test/test_classes/test_sticky_note.py @@ -1,46 +1,28 @@ -from pydbml.classes import Table -from pydbml.classes import Index -from pydbml.classes import Column -from unittest import TestCase - -from pydbml.classes.sticky_note import StickyNote - - -class TestNote(TestCase): - def test_init_types(self): - n1 = StickyNote('mynote', 'My note text') - n2 = StickyNote('mynote', 3) - n3 = StickyNote('mynote', [1, 2, 3]) - n4 = StickyNote('mynote', None) - - self.assertEqual(n1.text, 'My note text') - self.assertEqual(n2.text, '3') - self.assertEqual(n3.text, '[1, 2, 3]') - self.assertEqual(n4.text, '') - self.assertTrue(n1.name == n2.name == n3.name == n4.name == 'mynote') - - def test_oneline(self): - note = StickyNote('mynote', 'One line of note text') - expected = \ -'''Note mynote { - 'One line of note text' -}''' - self.assertEqual(note.dbml, expected) - - def test_forced_multiline(self): - note = StickyNote('mynote', 'The number of spaces you use to indent a block string\nwill\nbe the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output.') - expected = \ -"""Note mynote { - ''' - The number of spaces you use to indent a block string - will - be the minimum number of leading spaces among all lines. The parser wont automatically remove the number of indentation spaces in the final output. - ''' -}""" - self.assertEqual(note.dbml, expected) - - def test_prepare_text_for_dbml(self): - quotes = "'asd' There's ''' asda '''' asd ''''' asdsa ''" - expected = "\\'asd\\' There\\'s \\''' asda \\'''\\' asd \\'''\\'\\' asdsa \\'\\'" - note = StickyNote('mynote', quotes) - self.assertEqual(note._prepare_text_for_dbml(), expected) +from pydbml._classes.sticky_note import StickyNote + + +def test_init_types(): + n1 = StickyNote('mynote', 'My note text') + n2 = StickyNote('mynote', 3) + n3 = StickyNote('mynote', [1, 2, 3]) + n4 = StickyNote('mynote', None) + + assert n1.text == 'My note text' + assert n2.text == '3' + assert n3.text == '[1, 2, 3]' + assert n4.text == '' + assert n1.name == n2.name == n3.name == n4.name == 'mynote' + + +def test_str(sticky_note1: StickyNote) -> None: + assert str(sticky_note1) == "StickyNote('mynote', 'Simple note')" + + +def test_repr(sticky_note1: StickyNote) -> None: + assert repr(sticky_note1) == "" + + +def test_bool(sticky_note1: StickyNote) -> None: + assert bool(sticky_note1) is True + sticky_note1.text = '' + assert bool(sticky_note1) is False diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 586aa15..345ee99 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -1,5 +1,7 @@ from unittest import TestCase +import pytest + from pydbml.classes import Column from pydbml.classes import Expression from pydbml.classes import Index @@ -19,15 +21,6 @@ def test_schema(self) -> None: t2 = Table('test', 'schema1') self.assertEqual(t2.schema, 'schema1') - def test_one_column(self) -> None: - t = Table('products') - c = Column('id', 'integer') - t.add_column(c) - s = Database() - s.add(t) - expected = 'CREATE TABLE "products" (\n "id" integer\n);' - self.assertEqual(t.sql, expected) - def test_getitem(self) -> None: t = Table('products') c1 = Column('col1', 'integer') @@ -95,174 +88,6 @@ def test_iter(self) -> None: for i1, i2 in zip(t, [c1, c2, c3]): self.assertIs(i1, i2) - def test_ref(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - s = Database() - s.add(t) - s.add(t2) - r = Reference('>', c2, c21) - s.add(r) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2 -);''' - self.assertEqual(t.sql, expected) - r.inline = True - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names" ("name_val") -);''' - self.assertEqual(t.sql, expected) - - def test_notes(self) -> None: - n = Note('Table note') - nc1 = Note('First column note') - nc2 = Note('Another column\nmultiline note') - t = Table('products', note=n) - c1 = Column('id', 'integer', note=nc1) - c2 = Column('name', 'varchar') - c3 = Column('country', 'varchar', note=nc2) - t.add_column(c1) - t.add_column(c2) - t.add_column(c3) - s = Database() - s.add(t) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar, - "country" varchar -); - -COMMENT ON TABLE "products" IS 'Table note'; - -COMMENT ON COLUMN "products"."id" IS 'First column note'; - -COMMENT ON COLUMN "products"."country" IS 'Another column -multiline note';''' - self.assertEqual(t.sql, expected) - - def test_ref_index(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - s = Database() - s.add(t) - - r = Reference('>', c2, c21, inline=True) - s.add(r) - i = Index(subjects=[c1, c2]) - t.add_index(i) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names" ("name_val") -); - -CREATE INDEX ON "products" ("id", "name");''' - self.assertEqual(t.sql, expected) - - def test_index_inline(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - i = Index(subjects=[c1, c2], pk=True) - t.add_index(i) - s = Database() - s.add(t) - - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - PRIMARY KEY ("id", "name") -);''' - self.assertEqual(t.sql, expected) - - def test_schema_sql(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - s = Database() - s.add(t) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2 -);''' - self.assertEqual(t.sql, expected) - t.schema = 'myschema' - expected = \ -'''CREATE TABLE "myschema"."products" ( - "id" integer, - "name" varchar2 -);''' - self.assertEqual(t.sql, expected) - - def test_index_inline_and_comments(self) -> None: - t = Table('products', comment='Multiline\ntable comment') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - i = Index(subjects=[c1, c2], pk=True, comment='Multiline\nindex comment') - t.add_index(i) - s = Database() - s.add(t) - - expected = \ -'''-- Multiline --- table comment -CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - -- Multiline - -- index comment - PRIMARY KEY ("id", "name") -);''' - self.assertEqual(t.sql, expected) - - def test_composite_pk_sql(self): - table = Table( - 'products', - columns=( - Column('id', 'integer', pk=True), - Column('name', 'varchar2', pk=True), - Column('prop', 'object', pk=True), - ) - ) - s = Database() - s.add(table) - - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - "prop" object, - PRIMARY KEY ("id", "name", "prop") -);''' - self.assertEqual(table.sql, expected) - def test_add_column(self) -> None: t = Table('products') c1 = Column('id', 'integer') @@ -325,62 +150,6 @@ def test_delete_index(self) -> None: with self.assertRaises(IndexNotFoundError): t.delete_index(i1) - def test_get_references_for_sql(self): - t = Table('products') - with self.assertRaises(UnknownDatabaseError): - t._get_references_for_sql() - c11 = Column('id', 'integer') - c12 = Column('name', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('id', 'integer') - c22 = Column('name_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - s = Database() - s.add(t) - s.add(t2) - r1 = Reference('>', c12, c22) - r2 = Reference('-', c11, c21) - r3 = Reference('<', c11, c22) - s.add(r1) - s.add(r2) - s.add(r3) - self.assertEqual(t._get_references_for_sql(), []) - self.assertEqual(t2._get_references_for_sql(), []) - r1.inline = r2.inline = r3.inline = True - self.assertEqual(t._get_references_for_sql(), [r1, r2]) - self.assertEqual(t2._get_references_for_sql(), [r3]) - - def test_get_references_for_sql_public(self): - t = Table('products') - with self.assertRaises(UnknownDatabaseError): - t._get_references_for_sql() - c11 = Column('id', 'integer') - c12 = Column('name', 'varchar2') - t.add_column(c11) - t.add_column(c12) - t2 = Table('names') - c21 = Column('id', 'integer') - c22 = Column('name_val', 'varchar2') - t2.add_column(c21) - t2.add_column(c22) - s = Database() - s.add(t) - s.add(t2) - r1 = Reference('>', c12, c22, inline=True) - r2 = Reference('-', c11, c21, inline=True) - r3 = Reference('<', c11, c22, inline=True) - s.add(r1) - s.add(r2) - s.add(r3) - self.assertEqual(t.get_references_for_sql(), [r1, r2]) - self.assertEqual(t2.get_references_for_sql(), [r3]) - r1.inline = r2.inline = r3.inline = False - self.assertEqual(t.get_references_for_sql(), [r1, r2]) - self.assertEqual(t2.get_references_for_sql(), [r3]) - def test_get_refs(self): t = Table('products') with self.assertRaises(UnknownDatabaseError): @@ -406,134 +175,27 @@ def test_get_refs(self): self.assertEqual(t.get_refs(), [r1, r2, r3]) self.assertEqual(t2.get_refs(), []) - def test_dbml_simple(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - s = Database() - s.add(t) - - expected = \ -'''Table "products" { - "id" integer - "name" varchar2 -}''' - self.assertEqual(t.dbml, expected) - - def test_header_color_dbml(self): - t = Table('products') - t.header_color = '#C84432' - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - s = Database() - s.add(t) - - expected = \ -'''Table "products" [headercolor: #C84432] { - "id" integer - "name" varchar2 -}''' - self.assertEqual(t.dbml, expected) - - - def test_schema_dbml(self): - t = Table('products', schema="myschema") - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - s = Database() - s.add(t) - - expected = \ -'''Table "myschema"."products" { - "id" integer - "name" varchar2 -}''' - self.assertEqual(t.dbml, expected) - - def test_dbml_reference(self): - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - s = Database() - s.add(t) - s.add(t2) - r = Reference('>', c2, c21) - s.add(r) - expected = \ -'''Table "products" { - "id" integer - "name" varchar2 -}''' - self.assertEqual(t.dbml, expected) - r.inline = True - expected = \ -'''Table "products" { - "id" integer - "name" varchar2 [ref: > "names"."name_val"] -}''' - self.assertEqual(t.dbml, expected) - expected = \ -'''Table "names" { - "name_val" varchar2 -}''' - self.assertEqual(t2.dbml, expected) - - def test_dbml_full(self): - t = Table( - 'products', - alias='pd', - note='My multiline\nnote', - comment='My multiline\ncomment' - ) - c0 = Column('zero', 'number') - c1 = Column('id', 'integer', unique=True, note='Multiline\ncomment note') - c2 = Column('name', 'varchar2') - t.add_column(c0) - t.add_column(c1) - t.add_column(c2) - i1 = Index(['zero', 'id'], unique=True) - i2 = Index([Expression('capitalize(name)')], comment="index comment") - t.add_index(i1) - t.add_index(i2) - s = Database() - s.add(t) - - expected = \ -"""// My multiline -// comment -Table "products" as "pd" { - "zero" number - "id" integer [unique, note: '''Multiline - comment note'''] - "name" varchar2 - Note { - ''' - My multiline - note - ''' - } - - indexes { - (zero, id) [unique] - // index comment - `capitalize(name)` - } -}""" - self.assertEqual(t.dbml, expected) - def test_note_property(self): note1 = Note('table note') t = Table(name='test') t.note = note1 self.assertIs(t.note.parent, t) + + +class TestAddIndex: + @staticmethod + def test_wrong_type(table1: Table) -> None: + with pytest.raises(TypeError): + table1.add_index('wrong_type') + + + @staticmethod + def test_column_not_in_table(table1: Table, table2: Table) -> None: + with pytest.raises(ColumnNotFoundError): + table1.add_index(Index([table2.columns[0]])) + + @staticmethod + def test_ok(table1: Table) -> None: + i = Index([table1.columns[0]]) + table1.add_index(i) + assert i.table is table1 diff --git a/test/test_classes/test_table_group.py b/test/test_classes/test_table_group.py index d7c8a9e..a718fac 100644 --- a/test/test_classes/test_table_group.py +++ b/test/test_classes/test_table_group.py @@ -5,52 +5,6 @@ class TestTableGroup(TestCase): -# string items no longer supported -# def test_dbml(self): -# tg = TableGroup('mytg', ['merchants', 'countries', 'customers']) -# expected = \ -# '''TableGroup mytg { -# merchants -# countries -# customers -# }''' -# self.assertEqual(tg.dbml, expected) - - def test_dbml_with_comment_and_real_tables(self): - merchants = Table('merchants') - countries = Table('countries') - customers = Table('customers') - tg = TableGroup( - 'mytg', - [merchants, countries, customers], - comment='My table group\nmultiline comment' - ) - expected = \ -'''// My table group -// multiline comment -TableGroup mytg { - "merchants" - "countries" - "customers" -}''' - self.assertEqual(tg.dbml, expected) - - def test_dbml_schema(self): - merchants = Table('merchants', schema="myschema1") - countries = Table('countries', schema="myschema2") - customers = Table('customers', schema="myschema3") - tg = TableGroup( - 'mytg', - [merchants, countries, customers], - ) - expected = \ -'''TableGroup mytg { - "myschema1"."merchants" - "myschema2"."countries" - "myschema3"."customers" -}''' - self.assertEqual(tg.dbml, expected) - def test_getitem(self) -> None: merchants = Table('merchants') countries = Table('countries') diff --git a/test/test_database.py b/test/test_database.py index a33ea82..095e75c 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -14,8 +14,7 @@ from pydbml.database import Database from pydbml.exceptions import DatabaseValidationError from pydbml.constants import ONE_TO_MANY, MANY_TO_ONE, MANY_TO_MANY -from pydbml.database import reorder_tables_for_sql - +from pydbml.renderer.sql.default.utils import reorder_tables_for_sql TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' @@ -371,32 +370,5 @@ class Test: t.database -class TestReorderTablesForSQL(TestCase): - - def test_reorder_tables(self) -> None: - t1 = Mock(name="table1") # 1 ref - t2 = Mock(name="table2") # 2 refs - t3 = Mock(name="table3") - t4 = Mock(name="table4") # 1 ref - t5 = Mock(name="table5") - t6 = Mock(name="table6") # 3 refs - t7 = Mock(name="table7") - t8 = Mock(name="table8") - t9 = Mock(name="table9") - t10 = Mock(name="table10") - - refs = [ - Mock(type=ONE_TO_MANY, table1=t1, table2=t2, inline=True), - Mock(type=MANY_TO_ONE, table1=t4, table2=t3, inline=True), - Mock(type=ONE_TO_MANY, table1=t6, table2=t2, inline=True), - Mock(type=ONE_TO_MANY, table1=t7, table2=t6, inline=True), - Mock(type=MANY_TO_ONE, table1=t6, table2=t8, inline=True), - Mock(type=ONE_TO_MANY, table1=t9, table2=t6, inline=True), - Mock(type=ONE_TO_MANY, table1=t1, table2=t2, inline=False), # ignored not inline - Mock(type=ONE_TO_MANY, table1=t10, table2=t1, inline=True), - Mock(type=MANY_TO_MANY, table1=t1, table2=t2, inline=True), # ignored m2m - ] - original = [t1, t2, t3, t4, t5, t6, t7, t8, t9, t10] - expected = [t6, t2, t1, t4, t3, t5, t7, t8, t9, t10] - result = reorder_tables_for_sql(original, refs) - self.assertEqual(expected, result) +def test_repr() -> None: + assert repr(Database()) == "" diff --git a/test/test_doctest.py b/test/test_doctest.py index 877cca4..9354d1e 100644 --- a/test/test_doctest.py +++ b/test/test_doctest.py @@ -1,15 +1,15 @@ import doctest from pydbml import database -from pydbml.classes import column -from pydbml.classes import enum -from pydbml.classes import expression -from pydbml.classes import index -from pydbml.classes import note -from pydbml.classes import project -from pydbml.classes import reference -from pydbml.classes import table -from pydbml.classes import table_group +from pydbml._classes import column +from pydbml._classes import enum +from pydbml._classes import expression +from pydbml._classes import index +from pydbml._classes import note +from pydbml._classes import project +from pydbml._classes import reference +from pydbml._classes import table +from pydbml._classes import table_group from pydbml.parser import parser diff --git a/test/test_parser.py b/test/test_parser.py index ac26f53..745188e 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -176,3 +176,12 @@ def test(): p_mod = PyDBML(p_mod.dbml) note2 = p_mod.tables[0].note self.assertEqual(source_text, note2.text) + + +def test_repr_pydbml() -> None: + assert repr(PyDBML()) == "" + + +def test_repr_pydbml_parser() -> None: + assert repr(PyDBMLParser('')) == "" + diff --git a/test/test_renderer/__init__.py b/test/test_renderer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_base.py b/test/test_renderer/test_base.py new file mode 100644 index 0000000..9e9f80c --- /dev/null +++ b/test/test_renderer/test_base.py @@ -0,0 +1,40 @@ +from pydbml.renderer.base import BaseRenderer + + +class SampleRenderer(BaseRenderer): + model_renderers = {} + + +def test_renderer_for() -> None: + @SampleRenderer.renderer_for(str) + def render_str(model): + return 'str' + + assert len(SampleRenderer.model_renderers) == 1 + assert str in SampleRenderer.model_renderers + assert SampleRenderer.model_renderers[str] is render_str + + +class TestRender: + @staticmethod + def test_render() -> None: + @SampleRenderer.renderer_for(str) + def render_str(model): + return 'str' + + assert SampleRenderer.render('') == 'str' + + @staticmethod + def test_render_not_supported() -> None: + assert SampleRenderer.render(1) == '' + + @staticmethod + def test_unsupported_renderer_override() -> None: + def unsupported_renderer(model): + return 'unsupported' + + class SampleRenderer2(BaseRenderer): + model_renderers = {} + _unsupported_renderer = unsupported_renderer + + assert SampleRenderer2.render(1) == 'unsupported' diff --git a/test/test_renderer/test_dbml/__init__.py b/test/test_renderer/test_dbml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_dbml/test_column.py b/test/test_renderer/test_dbml/test_column.py new file mode 100644 index 0000000..d785c97 --- /dev/null +++ b/test/test_renderer/test_dbml/test_column.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from pydbml.classes import Column, Note +from pydbml.renderer.dbml.default.column import ( + default_to_str, + render_options, + render_column, +) + + +@pytest.mark.parametrize( + "input,expected", + [ + ("test", "'test'"), + (1, "1"), + (1.0, "1.0"), + (True, "True"), + ("False", "false"), + ("null", "null"), + ], +) +def test_default_to_str(input: Any, expected: str) -> None: + assert default_to_str(input) == expected + + +class TestRenderOptions: + @staticmethod + def test_refs(simple_column: Column) -> None: + simple_column.get_refs = Mock( + return_value=[ + Mock(dbml="ref1", inline=True), + Mock(dbml="ref2", inline=False), + Mock(dbml="ref3", inline=True), + ] + ) + assert render_options(simple_column) == " [ref1, ref3]" + + @staticmethod + def test_pk(simple_column_with_table: Column) -> None: + simple_column_with_table.pk = True + assert render_options(simple_column_with_table) == " [pk]" + + @staticmethod + def test_autoinc(simple_column_with_table: Column) -> None: + simple_column_with_table.autoinc = True + assert render_options(simple_column_with_table) == " [increment]" + + @staticmethod + def test_default(simple_column_with_table: Column) -> None: + simple_column_with_table.default = "6" + assert render_options(simple_column_with_table) == " [default: '6']" + + @staticmethod + def test_unique(simple_column_with_table: Column) -> None: + simple_column_with_table.unique = True + assert render_options(simple_column_with_table) == " [unique]" + + @staticmethod + def test_not_null(simple_column_with_table: Column) -> None: + simple_column_with_table.not_null = True + assert render_options(simple_column_with_table) == " [not null]" + + @staticmethod + def test_note(simple_column_with_table: Column) -> None: + simple_column_with_table.note = Note("note") + with patch( + "pydbml.renderer.dbml.default.column.note_option_to_dbml", + Mock(return_value="note"), + ): + assert render_options(simple_column_with_table) == " [note]" + + @staticmethod + def test_no_options(simple_column_with_table: Column) -> None: + assert render_options(simple_column_with_table) == "" + + @staticmethod + def test_all_options(complex_column: Column) -> None: + complex_column.get_refs = Mock( + return_value=[ + Mock(dbml="ref1", inline=True), + Mock(dbml="ref2", inline=False), + Mock(dbml="ref3", inline=True), + ] + ) + complex_column.default = "null" + with patch( + "pydbml.renderer.dbml.default.column.note_option_to_dbml", + Mock(return_value="note"), + ): + assert ( + render_options(complex_column) + == " [ref1, ref3, pk, increment, default: null, unique, not null, note]" + ) + + +class TestRenderColumn: + @staticmethod + def test_comment(simple_column_with_table: Column) -> None: + simple_column_with_table.comment = "Simple comment" + expected = '// Simple comment\n"id" integer' + assert render_column(simple_column_with_table) == expected + + @staticmethod + def test_enum(simple_column_with_table: Column, enum1: Enum) -> None: + simple_column_with_table.type = enum1 + expected = '"id" "product status"' + assert render_column(simple_column_with_table) == expected + + @staticmethod + def test_complex(complex_column_with_table: Column) -> None: + expected = ( + "// This is a counter column\n" + '"counter" "product status" [pk, increment, unique, not null, note: \'This is ' + "a note for the column']" + ) + assert render_column(complex_column_with_table) == expected diff --git a/test/test_renderer/test_dbml/test_enum.py b/test/test_renderer/test_dbml/test_enum.py new file mode 100644 index 0000000..64d2daf --- /dev/null +++ b/test/test_renderer/test_dbml/test_enum.py @@ -0,0 +1,36 @@ +from pydbml.classes import Enum, EnumItem, Note +from pydbml.renderer.dbml.default.enum import ( + render_enum_item, + render_enum, +) + + +class TestRenderEnumItem: + @staticmethod + def test_simple(enum_item1: EnumItem) -> None: + assert render_enum_item(enum_item1) == '"en-US"' + + @staticmethod + def test_comment(enum_item1: EnumItem) -> None: + enum_item1.comment = "comment" + expected = '// comment\n"en-US"' + assert render_enum_item(enum_item1) == expected + + @staticmethod + def test_note(enum_item1: EnumItem) -> None: + enum_item1.note = Note("Enum item note") + expected = "\"en-US\" [note: 'Enum item note']" + assert render_enum_item(enum_item1) == expected + + +class TestEnum: + @staticmethod + def test_simple(enum1: Enum) -> None: + expected = 'Enum "product status" {\n "production"\n "development"\n}' + assert render_enum(enum1) == expected + + @staticmethod + def test_comment(enum1: Enum) -> None: + enum1.comment = "comment" + expected = '// comment\nEnum "product status" {\n "production"\n "development"\n}' + assert render_enum(enum1) == expected diff --git a/test/test_renderer/test_dbml/test_expression.py b/test/test_renderer/test_dbml/test_expression.py new file mode 100644 index 0000000..dc189b0 --- /dev/null +++ b/test/test_renderer/test_dbml/test_expression.py @@ -0,0 +1,6 @@ +from pydbml.classes import Expression +from pydbml.renderer.dbml.default import render_expression + + +def test_render_expression(expression1: Expression) -> None: + assert render_expression(expression1) == "`SUM(amount)`" diff --git a/test/test_renderer/test_dbml/test_index.py b/test/test_renderer/test_dbml/test_index.py new file mode 100644 index 0000000..06ff29b --- /dev/null +++ b/test/test_renderer/test_dbml/test_index.py @@ -0,0 +1,92 @@ +from unittest.mock import patch, Mock + +from pydbml.classes import Index, Expression, Note +from pydbml.renderer.dbml.default.index import render_subjects, render_options, render_index + + +class TestRenderSubjects: + @staticmethod + def test_column(index1: Index) -> None: + assert render_subjects(index1.subjects) == "name" + + @staticmethod + def test_expression(index1: Index) -> None: + index1.subjects = [Expression("SUM(amount)")] + assert render_subjects(index1.subjects) == "`SUM(amount)`" + + @staticmethod + def test_string(index1: Index) -> None: + index1.subjects = ["name"] + assert render_subjects(index1.subjects) == "name" + + @staticmethod + def test_multiple(index1: Index) -> None: + index1.subjects.append(Expression("SUM(amount)")) + index1.subjects.append("name") + assert render_subjects(index1.subjects) == "(name, `SUM(amount)`, name)" + + +class TestRenderOptions: + @staticmethod + def test_name(index1: Index) -> None: + index1.name = "index_name" + assert render_options(index1) == " [name: 'index_name']" + + @staticmethod + def test_pk(index1: Index) -> None: + index1.pk = True + assert render_options(index1) == " [pk]" + + @staticmethod + def test_unique(index1: Index) -> None: + index1.unique = True + assert render_options(index1) == " [unique]" + + @staticmethod + def test_type(index1: Index) -> None: + index1.type = "hash" + assert render_options(index1) == " [type: hash]" + + @staticmethod + def test_note(index1: Index) -> None: + index1.note = Note("note") + with patch( + "pydbml.renderer.dbml.default.index.note_option_to_dbml", + Mock(return_value="note"), + ): + assert render_options(index1) == " [note]" + + @staticmethod + def test_no_options(index1: Index) -> None: + assert render_options(index1) == "" + + @staticmethod + def test_all_options(index1: Index) -> None: + index1.name = "index_name" + index1.pk = True + index1.unique = True + index1.type = "hash" + index1.note = Note("note") + with patch( + "pydbml.renderer.dbml.default.index.note_option_to_dbml", + Mock(return_value="note"), + ): + assert ( + render_options(index1) + == " [name: 'index_name', pk, unique, type: hash, note]" + ) + + +def test_render_index(index1: Index) -> None: + index1.comment = "Index comment" + with patch( + "pydbml.renderer.dbml.default.index.render_subjects", + Mock(return_value="subjects "), + ) as render_subjects_mock: + with patch( + "pydbml.renderer.dbml.default.index.render_options", + Mock(return_value="options"), + ) as render_options_mock: + assert render_index(index1) == '// Index comment\nsubjects options' + assert render_subjects_mock.called + assert render_options_mock.called diff --git a/test/test_renderer/test_dbml/test_note.py b/test/test_renderer/test_dbml/test_note.py new file mode 100644 index 0000000..0030fe7 --- /dev/null +++ b/test/test_renderer/test_dbml/test_note.py @@ -0,0 +1,22 @@ +from pydbml.classes import Note +from pydbml.renderer.dbml.default.note import prepare_text_for_dbml, render_note + + +def test_prepare_text_for_dbml() -> None: + note = Note("""Three quotes: ''', one quote: '.""") + assert prepare_text_for_dbml(note) == "Three quotes: \\''', one quote: \\'." + + +class TestRenderNote: + @staticmethod + def test_oneline() -> None: + note = Note("Note text") + assert render_note(note) == "Note {\n 'Note text'\n}" + + @staticmethod + def test_multiline() -> None: + note = Note("Note text\nwith multiple lines") + assert ( + render_note(note) + == "Note {\n '''\n Note text\n with multiple lines\n '''\n}" + ) diff --git a/test/test_renderer/test_dbml/test_project.py b/test/test_renderer/test_dbml/test_project.py new file mode 100644 index 0000000..c2d2b76 --- /dev/null +++ b/test/test_renderer/test_dbml/test_project.py @@ -0,0 +1,46 @@ +from pydbml.classes import Project, Note +from pydbml.renderer.dbml.default.project import render_items, render_project + + +class TestRenderItems: + @staticmethod + def test_oneline(): + project = Project(name="test", items={"key1": "value1"}) + assert render_items(project.items) == " key1: 'value1'\n" + + @staticmethod + def test_multiline(): + project = Project(name="test", items={"key1": "value1\nvalue2"}) + assert render_items(project.items) == " key1: '''value1\n value2'''\n" + + @staticmethod + def test_multiple(): + project = Project( + name="test", items={"key1": "value1", "key2": "value2\nnewline"} + ) + assert ( + render_items(project.items) + == " key1: 'value1'\n key2: '''value2\n newline'''\n" + ) + + +class TestRenderProject: + @staticmethod + def test_no_note() -> None: + project = Project(name="test", items={"key1": "value1"}) + expected = "Project \"test\" {\n key1: 'value1'\n}" + assert render_project(project) == expected + + @staticmethod + def test_note() -> None: + project = Project(name="test", items={"key1": "value1"}) + project.note = Note("Note text") + expected = ( + 'Project "test" {\n' + " key1: 'value1'\n" + " Note {\n" + " 'Note text'\n" + " }\n" + "}" + ) + assert render_project(project) == expected diff --git a/test/test_renderer/test_dbml/test_reference.py b/test/test_renderer/test_dbml/test_reference.py new file mode 100644 index 0000000..85fdac1 --- /dev/null +++ b/test/test_renderer/test_dbml/test_reference.py @@ -0,0 +1,123 @@ +from unittest.mock import patch + +import pytest + +from pydbml._classes.reference import Reference +from pydbml.exceptions import TableNotFoundError, DBMLError +from pydbml.renderer.dbml.default.reference import ( + validate_for_dbml, + render_inline_reference, + render_col, + render_options, + render_not_inline_reference, + render_reference, +) + + +class TestValidateFroDBML: + @staticmethod + def test_ok(reference1: Reference) -> None: + validate_for_dbml(reference1) + + @staticmethod + def test_no_table(reference1: Reference) -> None: + reference1.col2[0].table = None + with pytest.raises(TableNotFoundError): + validate_for_dbml(reference1) + + +class TestRenderInlineReference: + @staticmethod + def test_ok(reference1: Reference) -> None: + reference1.inline = True + assert render_inline_reference(reference1) == 'ref: > "products"."id"' + + @staticmethod + def test_composite(reference1: Reference) -> None: + reference1.col2.append(reference1.col2[0]) + with pytest.raises(DBMLError): + render_inline_reference(reference1) + + +class TestRendeCol: + @staticmethod + def test_single(reference1: Reference) -> None: + assert render_col(reference1.col2) == '"id"' + + @staticmethod + def test_multiple(reference1: Reference) -> None: + reference1.col2.append(reference1.col2[0]) + assert render_col(reference1.col2) == '("id", "id")' + + +class TestRenderOptions: + @staticmethod + def test_on_update(reference1: Reference) -> None: + reference1.on_update = "cascade" + assert render_options(reference1) == " [update: cascade]" + + @staticmethod + def test_on_delete(reference1: Reference) -> None: + reference1.on_delete = "set null" + assert render_options(reference1) == " [delete: set null]" + + @staticmethod + def test_both(reference1: Reference) -> None: + reference1.on_update = "cascade" + reference1.on_delete = "set null" + assert render_options(reference1) == " [update: cascade, delete: set null]" + + @staticmethod + def test_no_options(reference1: Reference) -> None: + assert render_options(reference1) == "" + + +class TestRenderNotInlineReference: + @staticmethod + def test_ok(reference1: Reference) -> None: + assert render_not_inline_reference(reference1) == ( + 'Ref {\n "orders"."product_id" > "products"."id"\n}' + ) + + @staticmethod + def test_comment(reference1: Reference) -> None: + reference1.comment = "comment" + assert render_not_inline_reference(reference1) == ( + '// comment\nRef {\n "orders"."product_id" > "products"."id"\n}' + ) + + @staticmethod + def test_name(reference1: Reference) -> None: + reference1.name = "ref_name" + assert render_not_inline_reference(reference1) == ( + 'Ref ref_name {\n "orders"."product_id" > "products"."id"\n}' + ) + + +class TestRenderReference: + @staticmethod + def test_inline(reference1: Reference) -> None: + reference1.inline = True + with patch( + "pydbml.renderer.dbml.default.reference.render_inline_reference", + return_value="inline", + ) as mock_render: + with patch( + "pydbml.renderer.dbml.default.reference.validate_for_dbml", + ) as mock_validate: + assert render_reference(reference1) == "inline" + assert mock_render.called + assert mock_validate.called + + @staticmethod + def test_not_inline(reference1: Reference) -> None: + with patch( + "pydbml.renderer.dbml.default.reference.render_not_inline_reference", + return_value="not inline", + ) as mock_render: + with patch( + "pydbml.renderer.dbml.default.reference.validate_for_dbml", + ) as mock_validate: + assert render_reference(reference1) == "not inline" + assert mock_render.called + assert mock_validate.called diff --git a/test/test_renderer/test_dbml/test_renderer.py b/test/test_renderer/test_dbml/test_renderer.py new file mode 100644 index 0000000..2e8acf6 --- /dev/null +++ b/test/test_renderer/test_dbml/test_renderer.py @@ -0,0 +1,20 @@ +from unittest.mock import Mock, patch + +from pydbml.renderer.dbml.default import DefaultDBMLRenderer + + +def test_render_db() -> None: + db = Mock( + project=Mock(), # #1 + refs=(Mock(inline=False), Mock(inline=False), Mock(inline=True)), # #2, #3 + tables=[Mock(), Mock(), Mock()], # #4, #5, #6 + enums=[Mock(), Mock()], # #7, #8 + table_groups=[Mock(), Mock()], # #9, #10 + sticky_notes=[Mock(), Mock()], # #11, #12 + ) + + with patch.object( + DefaultDBMLRenderer, "render", Mock(return_value="") + ) as render_mock: + DefaultDBMLRenderer.render_db(db) + assert render_mock.call_count == 12 diff --git a/test/test_renderer/test_dbml/test_sticky_note.py b/test/test_renderer/test_dbml/test_sticky_note.py new file mode 100644 index 0000000..bc800d0 --- /dev/null +++ b/test/test_renderer/test_dbml/test_sticky_note.py @@ -0,0 +1,17 @@ +from pydbml._classes.sticky_note import StickyNote +from pydbml.renderer.dbml.default import render_sticky_note + + +class TestRenderNote: + @staticmethod + def test_oneline() -> None: + note = StickyNote(name='mynote', text="Note text") + assert render_sticky_note(note) == "Note mynote {\n 'Note text'\n}" + + @staticmethod + def test_multiline() -> None: + note = StickyNote(name='mynote', text="Note text\nwith multiple lines") + assert ( + render_sticky_note(note) + == "Note mynote {\n '''\n Note text\n with multiple lines\n '''\n}" + ) diff --git a/test/test_renderer/test_dbml/test_table.py b/test/test_renderer/test_dbml/test_table.py new file mode 100644 index 0000000..2e584bd --- /dev/null +++ b/test/test_renderer/test_dbml/test_table.py @@ -0,0 +1,81 @@ +from pydbml import Database +from pydbml.classes import Table, Index, Note +from pydbml.renderer.dbml.default.table import ( + get_full_name_for_dbml, + render_header, + render_indexes, + render_table, +) + + +class TestGetFullNameForDBML: + @staticmethod + def test_no_schema(table1: Table) -> None: + table1.schema = "public" + assert get_full_name_for_dbml(table1) == '"products"' + + @staticmethod + def test_with_schema(table1: Table) -> None: + table1.schema = "myschema" + assert get_full_name_for_dbml(table1) == '"myschema"."products"' + + +class TestRenderHeader: + @staticmethod + def test_simple(table1: Table) -> None: + expected = 'Table "products" ' + assert render_header(table1) == expected + + @staticmethod + def test_alias(table1: Table) -> None: + table1.alias = "p" + expected = 'Table "products" as "p" ' + assert render_header(table1) == expected + + @staticmethod + def test_header_color(table1: Table) -> None: + table1.header_color = "red" + expected = 'Table "products" [headercolor: red] ' + assert render_header(table1) == expected + + @staticmethod + def test_all(table1: Table) -> None: + table1.alias = "p" + table1.header_color = "red" + expected = 'Table "products" as "p" [headercolor: red] ' + assert render_header(table1) == expected + + +class TestRenderIndexes: + @staticmethod + def test_no_indexes(table1: Table) -> None: + assert render_indexes(table1) == "" + + @staticmethod + def test_one_index(index1: Index) -> None: + assert render_indexes(index1.table) == "\n indexes {\n name\n }\n" + + +class TestRenderTable: + @staticmethod + def test_simple(db: Database, table1: Table) -> None: + db.add(table1) + expected = 'Table "products" {\n "id" integer\n "name" varchar\n}' + assert render_table(table1) == expected + + @staticmethod + def test_note_and_comment(db: Database, table1: Table) -> None: + table1.comment = "Table comment" + table1.note = Note("Table note") + db.add(table1) + expected = ( + "// Table comment\n" + 'Table "products" {\n' + ' "id" integer\n' + ' "name" varchar\n' + " Note {\n" + " 'Table note'\n" + " }\n" + "}" + ) + assert render_table(table1) == expected diff --git a/test/test_renderer/test_dbml/test_table_group.py b/test/test_renderer/test_dbml/test_table_group.py new file mode 100644 index 0000000..171a270 --- /dev/null +++ b/test/test_renderer/test_dbml/test_table_group.py @@ -0,0 +1,18 @@ +from pydbml._classes.table_group import TableGroup +from pydbml.classes import Table +from pydbml.renderer.dbml.default import render_table_group + + +def test_render_table_group(table1: Table, table2: Table, table3: Table) -> None: + tg = TableGroup( + name="mygroup", items=[table1, table2, table3], comment="My comment" + ) + expected = ( + "// My comment\n" + "TableGroup mygroup {\n" + ' "products"\n' + ' "products"\n' + ' "orders"\n' + "}" + ) + assert render_table_group(tg) == expected diff --git a/test/test_renderer/test_dbml/test_utils.py b/test/test_renderer/test_dbml/test_utils.py new file mode 100644 index 0000000..11a2151 --- /dev/null +++ b/test/test_renderer/test_dbml/test_utils.py @@ -0,0 +1,20 @@ +from pydbml.classes import Note +from pydbml.renderer.dbml.default.utils import note_option_to_dbml, comment_to_dbml + + +class TestNoteOptionsToDBML: + @staticmethod + def test_oneline() -> None: + note = Note("One line note") + expected = "note: 'One line note'" + assert note_option_to_dbml(note) == expected + + @staticmethod + def test_multiline() -> None: + note = Note("Multiline\nnote") + expected = "note: '''Multiline\nnote'''" + assert note_option_to_dbml(note) == expected + + +def test_comment_to_dbml() -> None: + assert comment_to_dbml("Simple comment") == "// Simple comment\n" diff --git a/test/test_renderer/test_sql/__init__.py b/test/test_renderer/test_sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_sql/test_default/__init__.py b/test/test_renderer/test_sql/test_default/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_sql/test_default/test_column.py b/test/test_renderer/test_sql/test_default/test_column.py new file mode 100644 index 0000000..05f9c20 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_column.py @@ -0,0 +1,19 @@ +from pydbml.classes import Column +from pydbml.renderer.sql.default import render_column + + +class TestRenderColumn: + @staticmethod + def test_simple(simple_column: Column) -> None: + expected = '"id" integer' + + assert render_column(simple_column), expected + + @staticmethod + def test_complex(complex_column: Column) -> None: + expected = ( + "-- This is a counter column\n" + '"counter" "product status" PRIMARY KEY AUTOINCREMENT UNIQUE NOT NULL DEFAULT ' + "0" + ) + assert render_column(complex_column) == expected diff --git a/test/test_renderer/test_sql/test_default/test_enum.py b/test/test_renderer/test_sql/test_default/test_enum.py new file mode 100644 index 0000000..9d76866 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_enum.py @@ -0,0 +1,39 @@ +from pydbml.classes import EnumItem, Enum +from pydbml.renderer.sql.default import render_enum, render_enum_item + + +class TestRenderEnumItem: + @staticmethod + def test_simple(enum_item1: EnumItem): + expected = "'en-US'," + assert render_enum_item(enum_item1) == expected + + @staticmethod + def test_comment(enum_item1: EnumItem): + enum_item1.comment = "Test comment" + expected = "-- Test comment\n'en-US'," + assert render_enum_item(enum_item1) == expected + + +class TestRenderEnum: + @staticmethod + def test_simple_enum(enum1: Enum) -> None: + expected = ( + 'CREATE TYPE "product status" AS ENUM (\n' + " 'production',\n" + " 'development',\n" + ");" + ) + assert render_enum(enum1) == expected + + @staticmethod + def test_comments(enum1: Enum) -> None: + enum1.comment = "Enum comment" + expected = ( + "-- Enum comment\n" + 'CREATE TYPE "product status" AS ENUM (\n' + " 'production',\n" + " 'development',\n" + ");" + ) + assert render_enum(enum1) == expected diff --git a/test/test_renderer/test_sql/test_default/test_expression.py b/test/test_renderer/test_sql/test_default/test_expression.py new file mode 100644 index 0000000..2d10c1f --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_expression.py @@ -0,0 +1,6 @@ +from pydbml.classes import Expression +from pydbml.renderer.sql.default import render_expression + + +def test_render_expression(expression1: Expression): + assert render_expression(expression1) == '(SUM(amount))' diff --git a/test/test_renderer/test_sql/test_default/test_index.py b/test/test_renderer/test_sql/test_default/test_index.py new file mode 100644 index 0000000..83859e7 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_index.py @@ -0,0 +1,82 @@ +from pydbml.classes import Column, Expression, Index +from pydbml.renderer.sql.default.index import render_subject, render_index, render_pk + + +class TestRenderSubject: + @staticmethod + def test_column(simple_column: Column) -> None: + expected = '"id"' + assert render_subject(simple_column) == expected + + @staticmethod + def test_expression(expression1: Expression) -> None: + expected = "(SUM(amount))" + assert render_subject(expression1) == expected + + @staticmethod + def test_other() -> None: + expected = "test" + assert render_subject(expected) == expected + + +class TestRenderPK: + @staticmethod + def test_comment(index1: Index) -> None: + index1.comment = "Test comment" + expected = '-- Test comment\nPRIMARY KEY ("name")' + assert render_pk(index1, '"name"') == expected + + @staticmethod + def test_no_comment(index1: Index) -> None: + expected = 'PRIMARY KEY ("name")' + assert render_pk(index1, '"name"') == expected + + +class TestRenderComponents: + @staticmethod + def test_comment(index1: Index) -> None: + index1.comment = "Test comment" + expected = '-- Test comment\nCREATE INDEX ON "products" ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_unique(index1: Index) -> None: + index1.unique = True + expected = 'CREATE UNIQUE INDEX ON "products" ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_name(index1: Index) -> None: + index1.name = "test" + expected = 'CREATE INDEX "test" ON "products" ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_no_table(index1: Index) -> None: + index1.table = None + expected = 'CREATE INDEX ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_type(index1: Index) -> None: + index1.type = "hash" + expected = 'CREATE INDEX ON "products" USING HASH ("name");' + assert render_index(index1) == expected + + +class TestRenderIndex: + @staticmethod + def test_render_index(index1: Index) -> None: + index1.comment = "Test comment" + index1.unique = True + index1.name = "test" + index1.type = "hash" + + expected = '-- Test comment\nCREATE UNIQUE INDEX "test" ON "products" USING HASH ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_render_pk(index1: Index) -> None: + index1.pk = True + expected = 'PRIMARY KEY ("name")' + assert render_index(index1) == expected diff --git a/test/test_renderer/test_sql/test_default/test_note.py b/test/test_renderer/test_sql/test_default/test_note.py new file mode 100644 index 0000000..fc3e11c --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_note.py @@ -0,0 +1,45 @@ +from textwrap import dedent + +from pydbml.classes import Note, Table, Index +from pydbml.renderer.sql.default.note import prepare_text_for_sql, generate_comment_on, render_note + + +def test_prepare_text_for_sql() -> None: + text = dedent( + """\ + First line break is preserved + second line break \\ + is 'ignored' """ + ) + expected = 'First line break is preserved\nsecond line break is "ignored" ' + assert prepare_text_for_sql(Note(text)) == expected + + +def test_generate_comment_on(note1: Note) -> None: + expected = "COMMENT ON TABLE \"table1\" IS 'Simple note';" + + assert generate_comment_on(note1, "Table", "table1") == expected + + +class TestRenderNote: + @staticmethod + def test_table_note_with_text(note1: Note, table1: Table) -> None: + table1.note = note1 + expected = "COMMENT ON TABLE \"products\" IS 'Simple note';" + assert render_note(note1) == expected + + @staticmethod + def test_table_note_without_text(note1: Note, table1: Table) -> None: + table1.note = note1 + note1.text = "" + assert render_note(note1) == "" + + @staticmethod + def test_index_note(index1: Index, multiline_note: Note) -> None: + index1.note = multiline_note + expected = dedent( + """\ + -- This is a multiline note. + -- It has multiple lines.""" + ) + assert render_note(multiline_note) == expected diff --git a/test/test_renderer/test_sql/test_default/test_reference.py b/test/test_renderer/test_sql/test_default/test_reference.py new file mode 100644 index 0000000..14bfced --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_reference.py @@ -0,0 +1,153 @@ +from textwrap import dedent +from unittest.mock import patch + +import pytest + +from pydbml.classes import Table, Reference +from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.sql.default.reference import ( + validate_for_sql, + col_names, + generate_inline_sql, + generate_not_inline_sql, + generate_many_to_many_sql, + render_reference, +) + + +def test_col_names(table1: Table) -> None: + assert col_names(table1.columns) == '"id", "name"' + + +class TestValidateForSQL: + @staticmethod + def test_ok(reference1: Reference) -> None: + validate_for_sql(reference1) + + @staticmethod + def test_faulty(reference1: Reference) -> None: + reference1.col2[0].table = None + with pytest.raises(TableNotFoundError): + validate_for_sql(reference1) + + +class TestGenerateInlineSQL: + @staticmethod + def test_simple(reference1: Reference) -> None: + expected = '{c}FOREIGN KEY ("product_id") REFERENCES "products" ("id")' + assert ( + generate_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + @staticmethod + def test_on_update_on_delete(reference1: Reference) -> None: + reference1.on_update = "cascade" + reference1.on_delete = "set null" + expected = '{c}FOREIGN KEY ("product_id") REFERENCES "products" ("id") ON UPDATE CASCADE ON DELETE SET NULL' + assert ( + generate_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + +class TestGenerateNotInlineSQL: + @staticmethod + def test_simple(reference1: Reference) -> None: + expected = 'ALTER TABLE "orders" ADD {c}FOREIGN KEY ("product_id") REFERENCES "products" ("id");' + assert ( + generate_not_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + @staticmethod + def test_on_update_on_delete(reference1: Reference) -> None: + reference1.on_update = "cascade" + reference1.on_delete = "set null" + expected = 'ALTER TABLE "orders" ADD {c}FOREIGN KEY ("product_id") REFERENCES "products" ("id") ON UPDATE CASCADE ON DELETE SET NULL;' + assert ( + generate_not_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + +def test_generate_many_to_many_sql(reference1: Reference) -> None: + reference1.type = "<>" + expected = dedent( + """\ + CREATE TABLE "orders_products" ( + "orders_product_id" integer NOT NULL, + "products_id" integer NOT NULL, + PRIMARY KEY ("orders_product_id", "products_id") + ); + + ALTER TABLE "orders_products" ADD FOREIGN KEY ("orders_product_id") REFERENCES "orders" ("product_id"); + + ALTER TABLE "orders_products" ADD FOREIGN KEY ("products_id") REFERENCES "products" ("id");""" + ) + assert generate_many_to_many_sql(reference1) == expected + + +class TestRenderReference: + @staticmethod + def test_many_to_many(reference1: Reference) -> None: + reference1.type = "<>" + with patch( + "pydbml.renderer.sql.default.reference.generate_many_to_many_sql" + ) as mock: + render_reference(reference1) + mock.assert_called_once_with(reference1) + + @staticmethod + def test_inline_to_one(reference1: Reference) -> None: + reference1.type = ">" + reference1.inline = True + with patch("pydbml.renderer.sql.default.reference.generate_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + reference1.type = "-" + render_reference(reference1) + assert mock.call_count == 2 + + @staticmethod + def test_inline_to_many(reference1: Reference) -> None: + reference1.type = "<" + reference1.inline = True + with patch("pydbml.renderer.sql.default.reference.generate_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col2, ref_col=reference1.col1 + ) + + @staticmethod + def test_not_inline_to_one(reference1: Reference) -> None: + reference1.type = ">" + reference1.inline = False + with patch("pydbml.renderer.sql.default.reference.generate_not_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + reference1.type = "-" + render_reference(reference1) + assert mock.call_count == 2 + + @staticmethod + def test_not_inline_to_many(reference1: Reference) -> None: + reference1.type = "<" + reference1.inline = False + with patch("pydbml.renderer.sql.default.reference.generate_not_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col2, ref_col=reference1.col1 + ) diff --git a/test/test_renderer/test_sql/test_default/test_renderer.py b/test/test_renderer/test_sql/test_default/test_renderer.py new file mode 100644 index 0000000..7b7606d --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_renderer.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock, patch + +from pydbml.renderer.sql.default import DefaultSQLRenderer + + +def test_render() -> None: + model = Mock() + result = DefaultSQLRenderer.render(model) + assert model.check_attributes_for_sql.called + assert result == "" + + +def test_render_db() -> None: + db = Mock( + refs=(Mock(inline=False), Mock(inline=False), Mock(inline=True)), + tables=[Mock(), Mock(), Mock()], + enums=[Mock(), Mock()], + ) + + with patch( + "pydbml.renderer.sql.default.renderer.reorder_tables_for_sql", + Mock(return_value=db.tables), + ) as reorder_mock: + with patch.object( + DefaultSQLRenderer, "render", Mock(return_value="") + ) as render_mock: + result = DefaultSQLRenderer.render_db(db) + assert reorder_mock.called + assert render_mock.call_count == 7 diff --git a/test/test_renderer/test_sql/test_default/test_table.py b/test/test_renderer/test_sql/test_default/test_table.py new file mode 100644 index 0000000..1504ef0 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_table.py @@ -0,0 +1,215 @@ +from typing import Tuple +from unittest.mock import Mock, patch + +import pytest + +import pydbml.renderer.sql.default.table +from pydbml import Database +from pydbml.classes import Table, Column, Reference, Note +from pydbml.exceptions import UnknownDatabaseError +from pydbml.renderer.sql.default.table import ( + get_references_for_sql, + get_inline_references_for_sql, + create_components, + render_column_notes, + create_body, +) + + +@pytest.fixture +def db1(): + return Database() + + +@pytest.fixture +def table1(db1: Database) -> Table: + t = Table( + name="products", + columns=[ + Column("id", "integer", pk=True), + Column("name", "varchar"), + ], + ) + db1.add(t) + return t + + +@pytest.fixture +def table2(db1: Database) -> Table: + t = Table( + name="names", + columns=[ + Column("id", "integer"), + Column("name_val", "varchar"), + ], + ) + db1.add(t) + return t + + +@pytest.fixture +def not_inline_refs( + db1: Database, table1: Table, table2: Table +) -> Tuple[Reference, Reference, Reference]: + r1 = Reference(">", table1[1], table2[1], inline=False) + r2 = Reference("-", table1[0], table2[0], inline=False) + r3 = Reference("<", table1[0], table2[1], inline=False) + db1.add(r1) + db1.add(r2) + db1.add(r3) + return r1, r2, r3 + + +@pytest.fixture +def inline_refs( + db1: Database, table1: Table, table2: Table +) -> Tuple[Reference, Reference, Reference]: + r1 = Reference(">", table1[1], table2[1], inline=True) + r2 = Reference("-", table1[0], table2[0], inline=True) + r3 = Reference("<", table1[0], table2[1], inline=True) + db1.add(r1) + db1.add(r2) + db1.add(r3) + return r1, r2, r3 + + +class TestGetReferencesForSQL: + @staticmethod + def test_get_references_for_sql_not_inline( + table1: Table, table2: Table, not_inline_refs + ) -> None: + r1, r2, r3 = not_inline_refs + assert get_references_for_sql(table1) == [r1, r2] + assert get_references_for_sql(table2) == [r3] + + @staticmethod + def test_get_references_for_sql_inline( + table1: Table, table2: Table, inline_refs + ) -> None: + r1, r2, r3 = inline_refs + assert get_references_for_sql(table1) == [r1, r2] + assert get_references_for_sql(table2) == [r3] + + @staticmethod + def test_db_not_set(table1: Table) -> None: + table1.database = None + with pytest.raises(UnknownDatabaseError): + get_references_for_sql(table1) + + +class TestGetInlineReferencesForSQL: + @staticmethod + def test_inline(table1: Table, table2: Table, inline_refs) -> None: + r1, r2, r3 = inline_refs + assert get_inline_references_for_sql(table1) == [r1, r2] + assert get_inline_references_for_sql(table2) == [r3] + + @staticmethod + def test_not_inline(table1: Table, table2: Table, not_inline_refs) -> None: + assert get_inline_references_for_sql(table1) == [] + assert get_inline_references_for_sql(table2) == [] + + @staticmethod + def test_abstract(table1: Table, table2: Table, inline_refs) -> None: + table1.abstract = table2.abstract = True + assert get_inline_references_for_sql(table1) == [] + assert get_inline_references_for_sql(table2) == [] + + +class TestCreateBody: + @staticmethod + def test_create_body() -> None: + table = Mock( + columns=[Mock(), Mock()], + indexes=[Mock(pk=True), Mock(pk=False)], + ) + with patch( + "pydbml.renderer.sql.default.table.get_inline_references_for_sql", + Mock(return_value=[Mock()]), + ) as get_inline_mock: + with patch( + "pydbml.renderer.sql.default.renderer.DefaultSQLRenderer.render", + Mock(return_value=""), + ) as render_mock: + create_body(table) + assert get_inline_mock.called + assert render_mock.call_count == 4 + + @staticmethod + def test_composite_pk(table1: Table) -> None: + table1.add_column(Column("id2", "integer", pk=True)) + expected = ( + ' "id" integer,\n' + ' "name" varchar,\n' + ' "id2" integer,\n' + ' PRIMARY KEY ("id", "id2")' + ) + assert create_body(table1) == expected + + +class TestCreateComponents: + @staticmethod + def test_simple(table1: Table) -> None: + with patch( + "pydbml.renderer.sql.default.table.create_body", Mock(return_value="body") + ) as create_body_mock: + expected = 'CREATE TABLE "products" (\nbody\n);' + assert create_components(table1) == expected + + @staticmethod + def test_comment(table1: Table) -> None: + table1.comment = "Simple comment" + with patch( + "pydbml.renderer.sql.default.table.create_body", Mock(return_value="body") + ) as create_body_mock: + expected = '-- Simple comment\n\nCREATE TABLE "products" (\nbody\n);' + assert create_components(table1) == expected + + @staticmethod + def test_indexes(table1: Table) -> None: + table1.indexes = [Mock(pk=False), Mock(pk=True)] + with patch( + "pydbml.renderer.sql.default.table.create_body", Mock(return_value="body") + ) as create_body_mock: + with patch( + "pydbml.renderer.sql.default.renderer.DefaultSQLRenderer.render", + Mock(return_value="index"), + ) as render_mock: + expected = 'CREATE TABLE "products" (\nbody\n);\n\nindex' + assert create_components(table1) == expected + + +class TestRenderColumnNotes: + @staticmethod + def test_notes(table1: Table) -> None: + table1.columns[0].note = Note("First column note") + table1.columns[1].note = Note("Second column note") + expected = ( + "\n" + "\n" + 'COMMENT ON COLUMN "products"."id" IS \'First column note\';\n' + "\n" + 'COMMENT ON COLUMN "products"."name" IS \'Second column note\';' + ) + assert render_column_notes(table1) == expected + + @staticmethod + def test_no_notes(table1: Table) -> None: + assert render_column_notes(table1) == "" + + +def test_render_table(table1: Table) -> None: + table1.note = Mock(sql="-- Simple note") + with patch( + "pydbml.renderer.sql.default.table.create_components", + Mock(return_value="components"), + ) as create_components_mock: + with patch( + "pydbml.renderer.sql.default.table.render_column_notes", + Mock(return_value="\n\ncolumn notes"), + ) as render_column_notes_mock: + assert pydbml.renderer.sql.default.table.render_table(table1) == ( + "components\n\n-- Simple note\n\ncolumn notes" + ) + assert create_components_mock.called + assert render_column_notes_mock.called diff --git a/test/test_renderer/test_sql/test_default/test_utils.py b/test/test_renderer/test_sql/test_default/test_utils.py new file mode 100644 index 0000000..82c3477 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_utils.py @@ -0,0 +1,50 @@ +from unittest.mock import Mock + +from pydbml.classes import Enum +from pydbml.constants import ONE_TO_MANY, MANY_TO_ONE, MANY_TO_MANY +from pydbml.renderer.sql.default.utils import ( + get_full_name_for_sql, + reorder_tables_for_sql, +) + + +class TestGetFullNameForSQL: + @staticmethod + def test_public(enum1: Enum) -> None: + assert get_full_name_for_sql(enum1) == '"product status"' + + @staticmethod + def test_schema(enum1: Enum) -> None: + enum1.schema = "myschema" + assert get_full_name_for_sql(enum1) == '"myschema"."product status"' + + +def test_reorder_tables() -> None: + t1 = Mock(name="table1") # 1 ref + t2 = Mock(name="table2") # 2 refs + t3 = Mock(name="table3") + t4 = Mock(name="table4") # 1 ref + t5 = Mock(name="table5") + t6 = Mock(name="table6") # 3 refs + t7 = Mock(name="table7") + t8 = Mock(name="table8") + t9 = Mock(name="table9") + t10 = Mock(name="table10") + + refs = [ + Mock(type=ONE_TO_MANY, table1=t1, table2=t2, inline=True), + Mock(type=MANY_TO_ONE, table1=t4, table2=t3, inline=True), + Mock(type=ONE_TO_MANY, table1=t6, table2=t2, inline=True), + Mock(type=ONE_TO_MANY, table1=t7, table2=t6, inline=True), + Mock(type=MANY_TO_ONE, table1=t6, table2=t8, inline=True), + Mock(type=ONE_TO_MANY, table1=t9, table2=t6, inline=True), + Mock( + type=ONE_TO_MANY, table1=t1, table2=t2, inline=False + ), # ignored not inline + Mock(type=ONE_TO_MANY, table1=t10, table2=t1, inline=True), + Mock(type=MANY_TO_MANY, table1=t1, table2=t2, inline=True), # ignored m2m + ] + original = [t1, t2, t3, t4, t5, t6, t7, t8, t9, t10] + expected = [t6, t2, t1, t4, t3, t5, t7, t8, t9, t10] + result = reorder_tables_for_sql(original, refs) # type: ignore + assert expected == result diff --git a/test/test_tools.py b/test/test_tools.py index 57943f8..6c1f26c 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -1,10 +1,10 @@ from unittest import TestCase from pydbml.classes import Note -from pydbml.tools import comment_to_dbml, remove_indentation -from pydbml.tools import comment_to_sql +from pydbml.tools import remove_indentation +from pydbml.renderer.sql.default.utils import comment_to_sql from pydbml.tools import indent -from pydbml.tools import note_option_to_dbml +from pydbml.renderer.dbml.default.utils import note_option_to_dbml, comment_to_dbml from pydbml.tools import strip_empty_lines From b5c16620c0736c72520e8a4b40e8af5206957061 Mon Sep 17 00:00:00 2001 From: Daniel Minukhin Date: Thu, 25 Jul 2024 21:17:30 +0200 Subject: [PATCH 102/125] Dynamic Options (#41) --- README.md | 1 + docs/properties.md | 77 ++++++++ pydbml/_classes/column.py | 8 +- pydbml/_classes/project.py | 2 +- pydbml/_classes/table.py | 7 +- pydbml/database.py | 8 +- pydbml/definitions/column.py | 24 ++- pydbml/definitions/table.py | 21 ++- pydbml/parser/blueprints.py | 8 +- pydbml/parser/parser.py | 47 +++-- pydbml/renderer/base.py | 4 +- pydbml/renderer/dbml/default/__init__.py | 3 +- pydbml/renderer/dbml/default/column.py | 6 +- pydbml/renderer/dbml/default/note.py | 18 +- pydbml/renderer/dbml/default/sticky_note.py | 13 +- pydbml/renderer/dbml/default/table.py | 9 +- pydbml/renderer/dbml/default/utils.py | 24 ++- pydbml/renderer/sql/default/column.py | 2 +- test.sh | 6 +- test/conftest.py | 3 +- test/test_definitions/test_column.py | 177 ++++++++++-------- test/test_definitions/test_table.py | 155 ++++++++------- test/test_renderer/test_dbml/test_column.py | 26 ++- test/test_renderer/test_dbml/test_note.py | 7 +- .../test_dbml/test_sticky_note.py | 2 +- test/test_renderer/test_dbml/test_table.py | 28 +++ test_schema.dbml | 1 - 27 files changed, 452 insertions(+), 235 deletions(-) create mode 100644 docs/properties.md diff --git a/README.md b/README.md index 2f1e426..563c589 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. * [Class Reference](docs/classes.md) * [Creating DBML schema](docs/creating_schema.md) * [Upgrading to PyDBML 1.0.0](docs/upgrading.md) +* [Arbitrary Properties](docs/properties.md) > PyDBML requires Python v3.8 or higher diff --git a/docs/properties.md b/docs/properties.md new file mode 100644 index 0000000..f3f878a --- /dev/null +++ b/docs/properties.md @@ -0,0 +1,77 @@ +# Arbitrary Properties + +Since 1.1.0 PyDBML supports arbitrary properties in Table and Column definitions. Arbitrary properties is a dictionary of key-value pairs that can be added to any Table or Column manually, or parsed from a DBML file. This may be useful for extending the standard DBML syntax or keeping additional information in the schema. + +Arbitrary properties are turned off by default. To enable parsing properties in DBML files, set `allow_properties` argument to `True` in the parser call. To enable rendering properties in the output DBML of an existing database, set `allow_properties` database attribute to `True`. + +## Properties in DBML + +In a DBML file arbitrary properties are defined like this: + +```python +>>> dbml_str = ''' +... Table "products" { +... "id" integer +... "name" varchar [col_prop: 'some value'] +... table_prop: 'another value' +... }''' + +``` + +In this example we've added a property `col_prop` to the column `name` and a property `table_prop` to the table `products`. Note that property values must me single-quoted strings. Multiline strings (with `'''`) are supported. + +Now let's parse this DBML string: + +```python +>>> from pydbml import PyDBML +>>> mydb = PyDBML(dbml_str, allow_properties=True) +>>> mydb.tables[0].columns[1].properties +{'col_prop': 'some value'} +>>> mydb.tables[0].properties +{'table_prop': 'another value'} + +``` + +The `allow_properties=True` argument is crucial here. Without it, the parser will raise syntax errors. + +## Rendering Properties + +To render properties in the output DBML, set `allow_properties` attribute of the Database object to `True`. If you parsed the DBML with `allow_properties=True`, the result database will already have this attribute set to `True`. + +We will reuse the `mydb` database from the previous example: + +```python +>>> print(mydb.allow_properties) +True + +``` + +Let's set a new property on the table and render the DBML: + +```python +>>> mydb.tables[0].properties['new_prop'] = 'Multiline\nproperty\nvalue' +>>> print(mydb.dbml) +Table "products" { + "id" integer + "name" varchar [col_prop: 'some value'] + + table_prop: 'another value' + new_prop: ''' + Multiline + property + value''' +} + +``` + +As you see, properties are also rendered in the output DBML correctly. But if `allow_properties` is set to `False`, the properties will be ignored: + +```python +>>> mydb.allow_properties = False +>>> print(mydb.dbml) +Table "products" { + "id" integer + "name" varchar +} + +``` diff --git a/pydbml/_classes/column.py b/pydbml/_classes/column.py index c2ff98b..4868cd1 100644 --- a/pydbml/_classes/column.py +++ b/pydbml/_classes/column.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict from typing import Optional from typing import TYPE_CHECKING from typing import Union @@ -29,8 +29,9 @@ def __init__(self, autoinc: bool = False, default: Optional[Union[str, int, bool, float, Expression]] = None, note: Optional[Union[Note, str]] = None, - # ref_blueprints: Optional[List[ReferenceBlueprint]] = None, - comment: Optional[str] = None): + comment: Optional[str] = None, + properties: Union[Dict[str, str], None] = None + ): self.name = name self.type = type self.unique = unique @@ -39,6 +40,7 @@ def __init__(self, self.autoinc = autoinc self.comment = comment self.note = Note(note) + self.properties = properties if properties else {} self.default = default self.table: Optional['Table'] = None diff --git a/pydbml/_classes/project.py b/pydbml/_classes/project.py index 1efa235..4e303c5 100644 --- a/pydbml/_classes/project.py +++ b/pydbml/_classes/project.py @@ -16,7 +16,7 @@ def __init__(self, comment: Optional[str] = None): self.database = None self.name = name - self.items = items + self.items = items or {} self.note = Note(note) self.comment = comment diff --git a/pydbml/_classes/table.py b/pydbml/_classes/table.py index b5132bf..b8f340d 100644 --- a/pydbml/_classes/table.py +++ b/pydbml/_classes/table.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, Dict from typing import List from typing import Optional from typing import TYPE_CHECKING @@ -32,7 +32,9 @@ def __init__(self, note: Optional[Union[Note, str]] = None, header_color: Optional[str] = None, comment: Optional[str] = None, - abstract: bool = False): + abstract: bool = False, + properties: Union[Dict[str, str], None] = None + ): self.database: Optional[Database] = None self.name = name self.schema = schema @@ -47,6 +49,7 @@ def __init__(self, self.header_color = header_color self.comment = comment self.abstract = abstract + self.properties = properties if properties else {} @property def note(self): diff --git a/pydbml/database.py b/pydbml/database.py index 5585751..41dd589 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -4,25 +4,24 @@ from typing import Optional from typing import Union +from ._classes.sticky_note import StickyNote from .classes import Enum from .classes import Project from .classes import Reference from .classes import Table from .classes import TableGroup -from ._classes.sticky_note import StickyNote from .exceptions import DatabaseValidationError - from .renderer.base import BaseRenderer from .renderer.dbml.default.renderer import DefaultDBMLRenderer from .renderer.sql.default import DefaultSQLRenderer -from .renderer.sql.default.utils import reorder_tables_for_sql class Database: def __init__( self, sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, - dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, + allow_properties: bool = False ) -> None: self.sql_renderer = sql_renderer self.dbml_renderer = dbml_renderer @@ -33,6 +32,7 @@ def __init__( self.table_groups: List['TableGroup'] = [] self.sticky_notes: List['StickyNote'] = [] self.project: Optional['Project'] = None + self.allow_properties = allow_properties def __repr__(self) -> str: return f"" diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index 5c5e4b3..a33ffdd 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -39,8 +39,9 @@ ) ) +prop = name + pp.Suppress(":") + string_literal -column_setting = _ + ( +column_setting = ( pp.CaselessLiteral("not null").set_parse_action( lambda s, loc, tok: True )('notnull') @@ -54,8 +55,13 @@ | note('note') | ref_inline('ref*') | default('default') -) + _ -column_settings = '[' - column_setting + ("," + column_setting)[...] + ']' + c +) + +column_setting_with_property = column_setting | prop.set_results_name('property', list_all_matches=True) + +column_settings = '[' - (_ + column_setting + _) + ("," + column_setting)[...] + ']' + c + +column_settings_with_properties = '[' - (_ + column_setting_with_property + _) + ("," + column_setting_with_property)[...] + ']' + c def parse_column_settings(s, loc, tok): @@ -79,10 +85,13 @@ def parse_column_settings(s, loc, tok): result['ref_blueprints'] = list(tok['ref']) if 'comment' in tok: result['comment'] = tok['comment'][0] + if 'property' in tok: + result['properties'] = {k: v for k, v in tok['property']} return result column_settings.set_parse_action(parse_column_settings) +column_settings_with_properties.set_parse_action(parse_column_settings) constraint = pp.CaselessLiteral("unique") | pp.CaselessLiteral("pk") @@ -95,6 +104,14 @@ def parse_column_settings(s, loc, tok): ) + n +table_column_with_properties = _c + ( + name('name') + + column_type('type') + + constraint[...]('constraints') + c + + column_settings_with_properties('settings')[0, 1] +) + n + + def parse_column(s, loc, tok): ''' address varchar(255) [unique, not null, note: 'to include unit number'] @@ -124,3 +141,4 @@ def parse_column(s, loc, tok): table_column.set_parse_action(parse_column) +table_column_with_properties.set_parse_action(parse_column) diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index 1b4a58e..fdf051d 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -1,12 +1,12 @@ import pyparsing as pp -from .column import table_column +from .column import table_column, table_column_with_properties from .common import _ from .common import _c from .common import end from .common import note from .common import note_object -from .generic import name +from .generic import name, string_literal from .index import indexes from pydbml.parser.blueprints import TableBlueprint @@ -42,9 +42,13 @@ def parse_table_settings(s, loc, tok): note_element = note | note_object +prop = name + pp.Suppress(":") + string_literal + table_element = _ + (note_element('note') | indexes('indexes')) + _ +table_element_with_property = _ + (note_element('note') | indexes('indexes') | prop.set_results_name('property', list_all_matches=True)) + _ table_body = table_column[1, ...]('columns') + _ + table_element[...] +table_body_with_properties = table_column_with_properties[1, ...]('columns') + _ + table_element_with_property[...] table_name = (name('schema') + '.' + name('name')) | (name('name')) @@ -56,6 +60,14 @@ def parse_table_settings(s, loc, tok): + '{' - table_body + _ + '}' ) + end +table_with_properties = _c + ( + pp.CaselessLiteral("table").suppress() + + table_name + + alias('alias')[0, 1] + + table_settings('settings')[0, 1] + _ + + '{' - table_body_with_properties + _ + '}' +) + end + def parse_table(s, loc, tok): ''' @@ -86,12 +98,15 @@ def parse_table(s, loc, tok): init_dict['indexes'] = tok['indexes'] if 'columns' in tok: init_dict['columns'] = tok['columns'] - if'comment_before' in tok: + if 'comment_before' in tok: comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment + if 'property' in tok: + init_dict['properties'] = {k: v for k, v in tok['property']} result = TableBlueprint(**init_dict) return result table.set_parse_action(parse_table) +table_with_properties.set_parse_action(parse_table) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 1077c29..6632a45 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -133,6 +133,7 @@ class ColumnBlueprint(Blueprint): note: Optional[NoteBlueprint] = None ref_blueprints: Optional[List[ReferenceBlueprint]] = None comment: Optional[str] = None + properties: Optional[Dict[str, str]] = None def build(self) -> 'Column': if isinstance(self.default, ExpressionBlueprint): @@ -155,7 +156,8 @@ def build(self) -> 'Column': autoinc=self.autoinc, default=self.default, note=self.note.build() if self.note else None, - comment=self.comment + comment=self.comment, + properties=self.properties, ) @@ -194,6 +196,7 @@ class TableBlueprint(Blueprint): note: Optional[NoteBlueprint] = None header_color: Optional[str] = None comment: Optional[str] = None + properties: Optional[Dict[str, str]] = None def build(self) -> 'Table': result = Table( @@ -202,7 +205,8 @@ def build(self) -> 'Table': alias=self.alias, note=self.note.build() if self.note else None, header_color=self.header_color, - comment=self.comment + comment=self.comment, + properties=self.properties ) columns = self.columns or [] indexes = self.indexes or [] diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 7b0d854..2954188 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -15,7 +15,7 @@ from pydbml.definitions.project import project from pydbml.definitions.reference import ref from pydbml.definitions.sticky_note import sticky_note -from pydbml.definitions.table import table +from pydbml.definitions.table import table, table_with_properties from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError from pydbml.tools import remove_bom @@ -25,11 +25,11 @@ from .blueprints import TableBlueprint from .blueprints import TableGroupBlueprint -pp.ParserElement.set_default_whitespace_chars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(" \t\r") class PyDBML: - ''' + """ PyDBML parser factory. If properly initiated, returns parsed Database. Usage option 1: @@ -44,22 +44,26 @@ class PyDBML: >>> # or >>> from pathlib import Path >>> p = PyDBML(Path('test_schema.dbml')) - ''' + """ - def __new__(cls, source_: Optional[Union[str, Path, TextIOWrapper]] = None): + def __new__( + cls, + source_: Optional[Union[str, Path, TextIOWrapper]] = None, + allow_properties: bool = False, + ): if source_ is not None: if isinstance(source_, str): source = source_ elif isinstance(source_, Path): - with open(source_, encoding='utf8') as f: + with open(source_, encoding="utf8") as f: source = f.read() elif isinstance(source_, TextIOWrapper): source = source_.read() else: - raise TypeError('Source must be str, path or file stream') + raise TypeError("Source must be str, path or file stream") source = remove_bom(source) - return cls.parse(source) + return cls.parse(source, allow_properties=allow_properties) else: return super().__new__(cls) @@ -67,9 +71,9 @@ def __repr__(self): return "" @staticmethod - def parse(text: str) -> Database: + def parse(text: str, allow_properties: bool = False) -> Database: text = remove_bom(text) - parser = PyDBMLParser(text) + parser = PyDBMLParser(text, allow_properties=allow_properties) return parser.parse() @staticmethod @@ -77,7 +81,7 @@ def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: if isinstance(file, TextIOWrapper): source = file.read() else: - with open(file, encoding='utf8') as f: + with open(file, encoding="utf8") as f: source = f.read() source = remove_bom(source) parser = PyDBMLParser(source) @@ -85,7 +89,7 @@ def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: class PyDBMLParser: - def __init__(self, source: str): + def __init__(self, source: str, allow_properties: bool = False): self.database = None self.ref_blueprints: List[ReferenceBlueprint] = [] @@ -96,6 +100,7 @@ def __init__(self, source: str): self.enums: List[EnumBlueprint] = [] self.project: Optional[ProjectBlueprint] = None self.sticky_notes: List[StickyNoteBlueprint] = [] + self._allow_properties = allow_properties def parse(self): self._set_syntax() @@ -107,7 +112,9 @@ def __repr__(self): return "" def _set_syntax(self): - table_expr = table.copy() + table_expr = ( + table_with_properties.copy() if self._allow_properties else table.copy() + ) ref_expr = ref.copy() enum_expr = enum.copy() table_group_expr = table_group.copy() @@ -129,7 +136,7 @@ def _set_syntax(self): | project_expr | note_expr ) - self._syntax = expr[...] + ('\n' | comment)[...] + pp.StringEnd() + self._syntax = expr[...] + ("\n" | comment)[...] + pp.StringEnd() def parse_blueprint(self, s, loc, tok): blueprint = tok[0] @@ -167,23 +174,23 @@ def parse_blueprint(self, s, loc, tok): elif isinstance(blueprint, StickyNoteBlueprint): self.sticky_notes.append(blueprint) else: - raise RuntimeError(f'type unknown: {blueprint}') + raise RuntimeError(f"type unknown: {blueprint}") blueprint.parser = self - def locate_table(self, schema: str, name: str) -> 'Table': + def locate_table(self, schema: str, name: str) -> "Table": if not self.database: - raise RuntimeError('Database is not ready') + raise RuntimeError("Database is not ready") # first by alias result = self.database.table_dict.get(name) if result is None: - full_name = f'{schema}.{name}' + full_name = f"{schema}.{name}" result = self.database.table_dict.get(full_name) if result is None: - raise TableNotFoundError(f'Table {full_name} not present in the database') + raise TableNotFoundError(f"Table {full_name} not present in the database") return result def build_database(self): - self.database = Database() + self.database = Database(allow_properties=self._allow_properties) for enum_bp in self.enums: self.database.add(enum_bp.build()) for table_bp in self.tables: diff --git a/pydbml/renderer/base.py b/pydbml/renderer/base.py index a8a0c77..4f1dd28 100644 --- a/pydbml/renderer/base.py +++ b/pydbml/renderer/base.py @@ -23,13 +23,13 @@ def render(cls, model) -> str: `self._unsupported_renderer` that by default returns an empty string. """ - return cls.model_renderers.get(type(model), cls._unsupported_renderer)(model) + return cls.model_renderers.get(type(model), cls._unsupported_renderer)(model) # type: ignore @classmethod def renderer_for(cls, model_cls: Type) -> Callable: """A decorator to register a renderer for a model class.""" def decorator(func) -> Callable: - cls.model_renderers[model_cls] = func + cls.model_renderers[model_cls] = func # type: ignore return func return decorator diff --git a/pydbml/renderer/dbml/default/__init__.py b/pydbml/renderer/dbml/default/__init__.py index 82f7386..ca43367 100644 --- a/pydbml/renderer/dbml/default/__init__.py +++ b/pydbml/renderer/dbml/default/__init__.py @@ -1,10 +1,11 @@ -from .renderer import DefaultDBMLRenderer from .column import render_column from .enum import render_enum, render_enum_item from .expression import render_expression from .index import render_index +from .note import render_note from .project import render_project from .reference import render_reference +from .renderer import DefaultDBMLRenderer from .sticky_note import render_sticky_note from .table import render_table from .table_group import render_table_group diff --git a/pydbml/renderer/dbml/default/column.py b/pydbml/renderer/dbml/default/column.py index 6b4dd62..a06a580 100644 --- a/pydbml/renderer/dbml/default/column.py +++ b/pydbml/renderer/dbml/default/column.py @@ -2,7 +2,7 @@ from pydbml.classes import Column, Enum, Expression from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer -from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml, quote_string from pydbml.renderer.sql.default.utils import get_full_name_for_sql @@ -32,6 +32,10 @@ def render_options(model: Column) -> str: options.append('not null') if model.note: options.append(note_option_to_dbml(model.note)) + if model.properties: + if model.table and model.table.database and model.table.database.allow_properties: + for key, value in model.properties.items(): + options.append(f'{key}: {quote_string(value)}') if options: return f' [{", ".join(options)}]' diff --git a/pydbml/renderer/dbml/default/note.py b/pydbml/renderer/dbml/default/note.py index b07e023..0bff77f 100644 --- a/pydbml/renderer/dbml/default/note.py +++ b/pydbml/renderer/dbml/default/note.py @@ -1,24 +1,14 @@ -import re from textwrap import indent from pydbml.classes import Note from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer - - -def prepare_text_for_dbml(model): - '''Escape single quotes''' - pattern = re.compile(r"('''|')") - return pattern.sub(r'\\\1', model.text) +from pydbml.renderer.dbml.default.utils import quote_string @DefaultDBMLRenderer.renderer_for(Note) def render_note(model: Note) -> str: - text = prepare_text_for_dbml(model) - if '\n' in text: - note_text = f"'''\n{text}\n'''" - else: - note_text = f"'{text}'" + text = quote_string(model.text) - note_text = indent(note_text, ' ') - result = f'Note {{\n{note_text}\n}}' + text = indent(text, ' ') + result = f'Note {{\n{text}\n}}' return result diff --git a/pydbml/renderer/dbml/default/sticky_note.py b/pydbml/renderer/dbml/default/sticky_note.py index e2af3c8..36d1122 100644 --- a/pydbml/renderer/dbml/default/sticky_note.py +++ b/pydbml/renderer/dbml/default/sticky_note.py @@ -1,18 +1,15 @@ from textwrap import indent + from pydbml.classes import StickyNote -from pydbml.renderer.dbml.default.note import prepare_text_for_dbml from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import quote_string @DefaultDBMLRenderer.renderer_for(StickyNote) def render_sticky_note(model: StickyNote) -> str: - text = prepare_text_for_dbml(model) - if '\n' in text: - note_text = f"'''\n{text}\n'''" - else: - note_text = f"'{text}'" + text = quote_string(model.text) - note_text = indent(note_text, ' ') - result = f'Note {model.name} {{\n{note_text}\n}}' + text = indent(text, ' ') + result = f'Note {model.name} {{\n{text}\n}}' return result diff --git a/pydbml/renderer/dbml/default/table.py b/pydbml/renderer/dbml/default/table.py index bfd02ed..c896066 100644 --- a/pydbml/renderer/dbml/default/table.py +++ b/pydbml/renderer/dbml/default/table.py @@ -3,7 +3,7 @@ from pydbml.classes import Table from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer -from pydbml.renderer.dbml.default.utils import comment_to_dbml +from pydbml.renderer.dbml.default.utils import comment_to_dbml, quote_string def get_full_name_for_dbml(model) -> str: @@ -42,6 +42,13 @@ def render_table(model: Table) -> str: result += '{\n' columns_str = '\n'.join(DefaultDBMLRenderer.render(c) for c in model.columns) result += indent(columns_str, ' ') + '\n' + + if model.properties: + if model.database and model.database.allow_properties: + properties_str = '\n' + '\n'.join(f'{key}: {quote_string(value)}' for key, value in model.properties.items()) + '\n' + properties_str = indent(properties_str, ' ') + result += properties_str + if model.note: result += indent(model.note.dbml, ' ') + '\n' diff --git a/pydbml/renderer/dbml/default/utils.py b/pydbml/renderer/dbml/default/utils.py index 60f8544..c46e893 100644 --- a/pydbml/renderer/dbml/default/utils.py +++ b/pydbml/renderer/dbml/default/utils.py @@ -1,12 +1,30 @@ -from pydbml.renderer.dbml.default.note import prepare_text_for_dbml +import re +from typing import TYPE_CHECKING + from pydbml.tools import comment +if TYPE_CHECKING: + from pydbml.classes import Note + + +def prepare_text_for_dbml(text: str) -> str: + '''Escape single quotes''' + pattern = re.compile(r"('''|')") + return pattern.sub(r'\\\1', text) + + +def quote_string(text: str) -> str: + if '\n' in text: + return f"'''\n{prepare_text_for_dbml(text)}'''" + else: + return f"'{prepare_text_for_dbml(text)}'" + def note_option_to_dbml(note: 'Note') -> str: if '\n' in note.text: - return f"note: '''{prepare_text_for_dbml(note)}'''" + return f"note: '''{prepare_text_for_dbml(note.text)}'''" else: - return f"note: '{prepare_text_for_dbml(note)}'" + return f"note: '{prepare_text_for_dbml(note.text)}'" def comment_to_dbml(val: str) -> str: diff --git a/pydbml/renderer/sql/default/column.py b/pydbml/renderer/sql/default/column.py index d890ad3..d04b061 100644 --- a/pydbml/renderer/sql/default/column.py +++ b/pydbml/renderer/sql/default/column.py @@ -31,7 +31,7 @@ def render_column(model: Column) -> str: if isinstance(model.default, Expression): default = DefaultSQLRenderer.render(model.default) else: - default = model.default + default = model.default # type: ignore components.append(f'DEFAULT {default}') result = comment_to_sql(model.comment) if model.comment else '' diff --git a/test.sh b/test.sh index 54ae6b8..8e64632 100755 --- a/test.sh +++ b/test.sh @@ -1,6 +1,2 @@ -python3 -m doctest README.md &&\ - python3 -m doctest docs/classes.md &&\ - python3 -m doctest docs/upgrading.md &&\ - python3 -m doctest docs/creating_schema.md &&\ - python3 -m unittest discover &&\ +pytest --doctest-glob="*.md" &&\ mypy pydbml --ignore-missing-imports diff --git a/test/conftest.py b/test/conftest.py index d3ba83f..8ef74aa 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -54,7 +54,8 @@ def complex_column(enum1: Enum) -> Column: not_null=True, default=0, comment='This is a counter column', - note=Note('This is a note for the column') + note=Note('This is a note for the column'), + properties={'foo': 'bar', 'baz': "qux\nqux"} ) diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index b8b6ea1..90767e2 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -4,7 +4,7 @@ from pyparsing import ParseSyntaxException from pyparsing import ParserElement -from pydbml.definitions.column import column_setting +from pydbml.definitions.column import column_setting, table_column_with_properties from pydbml.definitions.column import column_settings from pydbml.definitions.column import column_type from pydbml.definitions.column import constraint @@ -13,37 +13,37 @@ from pydbml.parser.blueprints import ExpressionBlueprint -ParserElement.set_default_whitespace_chars(' \t\r') +ParserElement.set_default_whitespace_chars(" \t\r") class TestColumnType(TestCase): def test_simple(self) -> None: - val = 'int' + val = "int" res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_quoted(self) -> None: val = '"mytype"' res = column_type.parse_string(val, parseAll=True) - self.assertEqual(res[0], 'mytype') + self.assertEqual(res[0], "mytype") def test_with_schema(self) -> None: - val = 'myschema.mytype' + val = "myschema.mytype" res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_expression(self) -> None: - val = 'varchar(255)' + val = "varchar(255)" res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_array(self) -> None: - val = 'int[]' + val = "int[]" res = column_type.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_symbols(self) -> None: - val = '(*#^)' + val = "(*#^)" with self.assertRaises(ParseException): column_type.parse_string(val, parseAll=True) @@ -57,15 +57,15 @@ class TestDefault(TestCase): def test_string(self) -> None: val = "default: 'string'" val2 = "default: \n\n'string'" - expected = 'string' + expected = "string" res = default.parse_string(val, parseAll=True) self.assertEqual(res[0], expected) res = default.parse_string(val2, parseAll=True) self.assertEqual(res[0], expected) def test_expression(self) -> None: - expr1 = 'datetime.now()' - expr2 = 'datetime\nnow()' + expr1 = "datetime.now()" + expr2 = "datetime\nnow()" val = f"default: `{expr1}`" val2 = f"default: `{expr2}`" val3 = f"default: ``" @@ -77,20 +77,20 @@ def test_expression(self) -> None: self.assertEqual(res[0].text, expr2) res = default.parse_string(val3, parseAll=True) self.assertIsInstance(res[0], ExpressionBlueprint) - self.assertEqual(res[0].text, '') + self.assertEqual(res[0].text, "") def test_bool(self) -> None: - vals = ['true', 'false', 'null'] - exps = [True, False, 'NULL'] + vals = ["true", "false", "null"] + exps = [True, False, "NULL"] while len(vals) > 0: - res = default.parse_string(f'default: {vals.pop()}', parseAll=True) + res = default.parse_string(f"default: {vals.pop()}", parseAll=True) self.assertEqual(exps.pop(), res[0]) def test_numbers(self) -> None: vals = [0, 17, 13.3, 2.0] while len(vals) > 0: cur = vals.pop() - res = default.parse_string(f'default: {cur}', parseAll=True) + res = default.parse_string(f"default: {cur}", parseAll=True) self.assertEqual((cur), res[0]) def test_wrong(self) -> None: @@ -101,20 +101,20 @@ def test_wrong(self) -> None: class TestColumnSetting(TestCase): def test_pass(self) -> None: - vals = ['not null', - 'null', - 'primary key', - 'pk', - 'unique', - 'default: 123', - 'ref: > table.column'] + vals = [ + "not null", + "null", + "primary key", + "pk", + "unique", + "default: 123", + "ref: > table.column", + ] for val in vals: column_setting.parse_string(val, parseAll=True) def test_fail(self) -> None: - vals = ['wrong', - '`null`', - '"pk"'] + vals = ["wrong", "`null`", '"pk"'] for val in vals: with self.assertRaises(ParseException): column_setting.parse_string(val, parseAll=True) @@ -122,38 +122,42 @@ def test_fail(self) -> None: class TestColumnSettings(TestCase): def test_nulls(self) -> None: - res = column_settings.parse_string('[NULL]', parseAll=True) - self.assertNotIn('not_null', res[0]) - res = column_settings.parse_string('[NOT NULL]', parseAll=True) - self.assertTrue(res[0]['not_null']) - res = column_settings.parse_string('[NULL, NOT NULL]', parseAll=True) - self.assertTrue(res[0]['not_null']) - res = column_settings.parse_string('[NOT NULL, NULL]', parseAll=True) - self.assertNotIn('not_null', res[0]) + res = column_settings.parse_string("[NULL]", parseAll=True) + self.assertNotIn("not_null", res[0]) + res = column_settings.parse_string("[NOT NULL]", parseAll=True) + self.assertTrue(res[0]["not_null"]) + res = column_settings.parse_string("[NULL, NOT NULL]", parseAll=True) + self.assertTrue(res[0]["not_null"]) + res = column_settings.parse_string("[NOT NULL, NULL]", parseAll=True) + self.assertNotIn("not_null", res[0]) def test_pk(self) -> None: - res = column_settings.parse_string('[pk]', parseAll=True) - self.assertTrue(res[0]['pk']) - res = column_settings.parse_string('[primary key]', parseAll=True) - self.assertTrue(res[0]['pk']) - res = column_settings.parse_string('[primary key, pk]', parseAll=True) - self.assertTrue(res[0]['pk']) + res = column_settings.parse_string("[pk]", parseAll=True) + self.assertTrue(res[0]["pk"]) + res = column_settings.parse_string("[primary key]", parseAll=True) + self.assertTrue(res[0]["pk"]) + res = column_settings.parse_string("[primary key, pk]", parseAll=True) + self.assertTrue(res[0]["pk"]) def test_unique_increment(self) -> None: - res = column_settings.parse_string('[unique, increment]', parseAll=True) - self.assertTrue(res[0]['unique']) - self.assertTrue(res[0]['autoinc']) + res = column_settings.parse_string("[unique, increment]", parseAll=True) + self.assertTrue(res[0]["unique"]) + self.assertTrue(res[0]["autoinc"]) def test_refs(self) -> None: - res = column_settings.parse_string('[ref: > table.column]', parseAll=True) - self.assertEqual(len(res[0]['ref_blueprints']), 1) - res = column_settings.parse_string('[ref: - table.column, ref: < table2.column2]', parseAll=True) - self.assertEqual(len(res[0]['ref_blueprints']), 2) + res = column_settings.parse_string("[ref: > table.column]", parseAll=True) + self.assertEqual(len(res[0]["ref_blueprints"]), 1) + res = column_settings.parse_string( + "[ref: - table.column, ref: < table2.column2]", parseAll=True + ) + self.assertEqual(len(res[0]["ref_blueprints"]), 2) def test_note_default(self) -> None: - res = column_settings.parse_string('[default: 123, note: "mynote"]', parseAll=True) - self.assertIn('note', res[0]) - self.assertEqual(res[0]['default'], 123) + res = column_settings.parse_string( + '[default: 123, note: "mynote"]', parseAll=True + ) + self.assertIn("note", res[0]) + self.assertEqual(res[0]["default"], 123) def test_wrong(self) -> None: val = "[wrong]" @@ -163,39 +167,39 @@ def test_wrong(self) -> None: class TestConstraint(TestCase): def test_should_parse(self) -> None: - constraint.parse_string('unique', parseAll=True) - constraint.parse_string('pk', parseAll=True) + constraint.parse_string("unique", parseAll=True) + constraint.parse_string("pk", parseAll=True) def test_should_fail(self) -> None: with self.assertRaises(ParseException): - constraint.parse_string('wrong', parseAll=True) + constraint.parse_string("wrong", parseAll=True) class TestColumn(TestCase): def test_no_settings(self) -> None: - val = 'address varchar(255)\n' + val = "address varchar(255)\n" res = table_column.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'address') - self.assertEqual(res[0].type, 'varchar(255)') + self.assertEqual(res[0].name, "address") + self.assertEqual(res[0].type, "varchar(255)") def test_with_constraint(self) -> None: - val = 'user_id integer unique\n' + val = "user_id integer unique\n" res = table_column.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'user_id') - self.assertEqual(res[0].type, 'integer') + self.assertEqual(res[0].name, "user_id") + self.assertEqual(res[0].type, "integer") self.assertTrue(res[0].unique) - val2 = 'user_id integer pk unique\n' + val2 = "user_id integer pk unique\n" res2 = table_column.parse_string(val2, parseAll=True) - self.assertEqual(res2[0].name, 'user_id') - self.assertEqual(res2[0].type, 'integer') + self.assertEqual(res2[0].name, "user_id") + self.assertEqual(res2[0].type, "integer") self.assertTrue(res2[0].unique) self.assertTrue(res2[0].pk) def test_with_settings(self) -> None: val = "_test_ \"mytype\" [unique, not null, note: 'to include unit number']\n" res = table_column.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, '_test_') - self.assertEqual(res[0].type, 'mytype') + self.assertEqual(res[0].name, "_test_") + self.assertEqual(res[0].type, "mytype") self.assertTrue(res[0].unique) self.assertTrue(res[0].not_null) self.assertTrue(res[0].note is not None) @@ -206,39 +210,46 @@ def test_enum_type_bad(self) -> None: table_column.parse_string(val, parseAll=True) def test_settings_and_constraints(self) -> None: - val = "_test_ \"mytype\" unique pk [not null]\n" + val = '_test_ "mytype" unique pk [not null]\n' res = table_column.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, '_test_') - self.assertEqual(res[0].type, 'mytype') + self.assertEqual(res[0].name, "_test_") + self.assertEqual(res[0].type, "mytype") self.assertTrue(res[0].unique) self.assertTrue(res[0].not_null) self.assertTrue(res[0].pk) def test_comment_above(self) -> None: - val = '//comment above\naddress varchar\n' + val = "//comment above\naddress varchar\n" res = table_column.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'address') - self.assertEqual(res[0].type, 'varchar') - self.assertEqual(res[0].comment, 'comment above') + self.assertEqual(res[0].name, "address") + self.assertEqual(res[0].type, "varchar") + self.assertEqual(res[0].comment, "comment above") def test_comment_after(self) -> None: - val = 'address varchar //comment after\n' + val = "address varchar //comment after\n" res = table_column.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'address') - self.assertEqual(res[0].type, 'varchar') - self.assertEqual(res[0].comment, 'comment after') - val2 = 'user_id integer pk unique //comment after\n' + self.assertEqual(res[0].name, "address") + self.assertEqual(res[0].type, "varchar") + self.assertEqual(res[0].comment, "comment after") + val2 = "user_id integer pk unique //comment after\n" res2 = table_column.parse_string(val2, parseAll=True) - self.assertEqual(res2[0].name, 'user_id') - self.assertEqual(res2[0].type, 'integer') + self.assertEqual(res2[0].name, "user_id") + self.assertEqual(res2[0].type, "integer") self.assertTrue(res2[0].unique) self.assertTrue(res2[0].pk) - self.assertEqual(res2[0].comment, 'comment after') - val3 = "_test_ \"mytype\" unique pk [not null] //comment after\n" + self.assertEqual(res2[0].comment, "comment after") + val3 = '_test_ "mytype" unique pk [not null] //comment after\n' res3 = table_column.parse_string(val3, parseAll=True) - self.assertEqual(res3[0].name, '_test_') - self.assertEqual(res3[0].type, 'mytype') + self.assertEqual(res3[0].name, "_test_") + self.assertEqual(res3[0].type, "mytype") self.assertTrue(res3[0].unique) self.assertTrue(res3[0].not_null) self.assertTrue(res3[0].pk) - self.assertEqual(res3[0].comment, 'comment after') + self.assertEqual(res3[0].comment, "comment after") + + +def test_properties() -> None: + val = "address varchar(255) [unique, foo: 'bar', baz: '''123''']" + res = table_column_with_properties.parse_string(val, parseAll=True) + assert res[0].properties == {"foo": "bar", "baz": '123'} + diff --git a/test/test_definitions/test_table.py b/test/test_definitions/test_table.py index 561f430..f002fdd 100644 --- a/test/test_definitions/test_table.py +++ b/test/test_definitions/test_table.py @@ -4,89 +4,89 @@ from pyparsing import ParseSyntaxException from pyparsing import ParserElement -from pydbml.definitions.table import alias +from pydbml.definitions.table import alias, table_with_properties from pydbml.definitions.table import header_color from pydbml.definitions.table import table from pydbml.definitions.table import table_body from pydbml.definitions.table import table_settings -ParserElement.set_default_whitespace_chars(' \t\r') +ParserElement.set_default_whitespace_chars(" \t\r") class TestAlias(TestCase): def test_ok(self) -> None: - val = 'as Alias' + val = "as Alias" alias.parse_string(val, parseAll=True) def test_nok(self) -> None: - val = 'asalias' + val = "asalias" with self.assertRaises(ParseSyntaxException): alias.parse_string(val, parseAll=True) class TestHeaderColor(TestCase): def test_oneline(self) -> None: - val = 'headercolor: #CCCCCC' + val = "headercolor: #CCCCCC" res = header_color.parse_string(val, parseAll=True) - self.assertEqual(res['header_color'], '#CCCCCC') + self.assertEqual(res["header_color"], "#CCCCCC") def test_multiline(self) -> None: - val = 'headercolor:\n\n#E02' + val = "headercolor:\n\n#E02" res = header_color.parse_string(val, parseAll=True) - self.assertEqual(res['header_color'], '#E02') + self.assertEqual(res["header_color"], "#E02") class TestTableSettings(TestCase): def test_one(self) -> None: - val = '[headercolor: #E024DF]' + val = "[headercolor: #E024DF]" res = table_settings.parse_string(val, parseAll=True) - self.assertEqual(res[0]['header_color'], '#E024DF') + self.assertEqual(res[0]["header_color"], "#E024DF") def test_both(self) -> None: val = '[note: "note content", headercolor: #E024DF]' res = table_settings.parse_string(val, parseAll=True) - self.assertEqual(res[0]['header_color'], '#E024DF') - self.assertIn('note', res[0]) + self.assertEqual(res[0]["header_color"], "#E024DF") + self.assertIn("note", res[0]) class TestTableBody(TestCase): def test_one_column(self) -> None: - val = 'id integer [pk, increment]\n' + val = "id integer [pk, increment]\n" res = table_body.parse_string(val, parseAll=True) - self.assertEqual(len(res['columns']), 1) + self.assertEqual(len(res["columns"]), 1) def test_two_columns(self) -> None: - val = 'id integer [pk, increment]\nname string\n' + val = "id integer [pk, increment]\nname string\n" res = table_body.parse_string(val, parseAll=True) - self.assertEqual(len(res['columns']), 2) + self.assertEqual(len(res["columns"]), 2) def test_columns_indexes(self) -> None: - val = ''' + val = """ id integer country varchar [NOT NULL, ref: > countries.country_name] booking_date date unique pk indexes { (id, country) [pk] // composite primary key -}''' +}""" res = table_body.parse_string(val, parseAll=True) - self.assertEqual(len(res['columns']), 3) - self.assertEqual(len(res['indexes']), 1) + self.assertEqual(len(res["columns"]), 3) + self.assertEqual(len(res["indexes"]), 1) def test_columns_indexes_note(self) -> None: - val = ''' + val = """ id integer country varchar [NOT NULL, ref: > countries.country_name] booking_date date unique pk note: 'mynote' indexes { (id, country) [pk] // composite primary key -}''' +}""" res = table_body.parse_string(val, parseAll=True) - self.assertEqual(len(res['columns']), 3) - self.assertEqual(len(res['indexes']), 1) - self.assertIsNotNone(res['note']) - val2 = ''' + self.assertEqual(len(res["columns"]), 3) + self.assertEqual(len(res["indexes"]), 1) + self.assertIsNotNone(res["note"]) + val2 = """ id integer country varchar [NOT NULL, ref: > countries.country_name] booking_date date unique pk @@ -95,85 +95,85 @@ def test_columns_indexes_note(self) -> None: } indexes { (id, country) [pk] // composite primary key -}''' +}""" res2 = table_body.parse_string(val2, parseAll=True) - self.assertEqual(len(res2['columns']), 3) - self.assertEqual(len(res2['indexes']), 1) - self.assertIsNotNone(res2['note']) + self.assertEqual(len(res2["columns"]), 3) + self.assertEqual(len(res2["indexes"]), 1) + self.assertIsNotNone(res2["note"]) def test_no_columns(self) -> None: - val = ''' + val = """ note: 'mynote' indexes { (id, country) [pk] // composite primary key -}''' +}""" with self.assertRaises(ParseException): table_body.parse_string(val, parseAll=True) def test_columns_after_indexes(self) -> None: - val = ''' + val = """ note: 'mynote' indexes { (id, country) [pk] // composite primary key } -id integer''' +id integer""" with self.assertRaises(ParseException): table_body.parse_string(val, parseAll=True) class TestTable(TestCase): def test_simple(self) -> None: - val = 'table ids {\nid integer\n}' + val = "table ids {\nid integer\n}" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') + self.assertEqual(res[0].name, "ids") self.assertEqual(len(res[0].columns), 1) def test_with_alias(self) -> None: - val = 'table ids as ii {\nid integer\n}' + val = "table ids as ii {\nid integer\n}" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") self.assertEqual(len(res[0].columns), 1) def test_schema(self) -> None: - val = 'table ids as ii {\nid integer\n}' + val = "table ids as ii {\nid integer\n}" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].schema, 'public') # default + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].schema, "public") # default self.assertEqual(len(res[0].columns), 1) - val = 'table myschema.ids as ii {\nid integer\n}' + val = "table myschema.ids as ii {\nid integer\n}" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].schema, 'myschema') + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].schema, "myschema") def test_with_settings(self) -> None: val = 'table ids as ii [headercolor: #ccc, note: "headernote"] {\nid integer\n}' res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'headernote') + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "headernote") self.assertEqual(len(res[0].columns), 1) def test_with_body_note(self) -> None: - val = ''' + val = """ table ids as ii [ headercolor: #ccc, note: "headernote"] { id integer note: "bodynote" -}''' +}""" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'bodynote') + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "bodynote") self.assertEqual(len(res[0].columns), 1) def test_comment_after(self) -> None: - val = ''' + val = """ // some comment before table table ids as ii [ headercolor: #ccc, @@ -181,17 +181,17 @@ def test_comment_after(self) -> None: { id integer note: "bodynote" -} // some somment after table''' +} // some somment after table""" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].comment, 'some comment before table') - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'bodynote') + self.assertEqual(res[0].comment, "some comment before table") + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "bodynote") self.assertEqual(len(res[0].columns), 1) def test_with_indexes(self) -> None: - val = ''' + val = """ table ids as ii [ headercolor: #ccc, note: "headernote"] @@ -202,11 +202,30 @@ def test_with_indexes(self) -> None: indexes { (id, country) [pk] // composite primary key } -}''' +}""" res = table.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'bodynote') + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "bodynote") self.assertEqual(len(res[0].columns), 2) self.assertEqual(len(res[0].indexes), 1) + + +def test_properties() -> None: + val = """ +table ids as ii [ + headercolor: #ccc, + note: "headernote"] +{ + id integer + country varchar + note: "bodynote" + foo: 'bar' + baz: '''123''' + indexes { + (id, country) [pk] // composite primary key + } +}""" + res = table_with_properties.parse_string(val, parseAll=True) + assert res[0].properties == {"foo": "bar", "baz": "123"} diff --git a/test/test_renderer/test_dbml/test_column.py b/test/test_renderer/test_dbml/test_column.py index d785c97..1b52807 100644 --- a/test/test_renderer/test_dbml/test_column.py +++ b/test/test_renderer/test_dbml/test_column.py @@ -77,8 +77,21 @@ def test_note(simple_column_with_table: Column) -> None: def test_no_options(simple_column_with_table: Column) -> None: assert render_options(simple_column_with_table) == "" + @staticmethod + def test_properties(simple_column_with_table: Column) -> None: + simple_column_with_table.properties = {"key": "value"} + simple_column_with_table.table.database.allow_properties = True + assert render_options(simple_column_with_table) == " [key: 'value']" + + @staticmethod + def test_properties_not_allowed(simple_column_with_table: Column) -> None: + simple_column_with_table.properties = {"key": "value"} + simple_column_with_table.table.database.allow_properties = False + assert render_options(simple_column_with_table) == "" + @staticmethod def test_all_options(complex_column: Column) -> None: + complex_column.table = Mock(database=Mock(allow_properties=True)) complex_column.get_refs = Mock( return_value=[ Mock(dbml="ref1", inline=True), @@ -87,14 +100,19 @@ def test_all_options(complex_column: Column) -> None: ] ) complex_column.default = "null" + + expected = ( + " [ref1, ref3, pk, increment, default: null, unique, not null, note, foo: " + "'bar', baz: '''\n" + "qux\n" + "qux''']" + ) + with patch( "pydbml.renderer.dbml.default.column.note_option_to_dbml", Mock(return_value="note"), ): - assert ( - render_options(complex_column) - == " [ref1, ref3, pk, increment, default: null, unique, not null, note]" - ) + assert render_options(complex_column) == expected class TestRenderColumn: diff --git a/test/test_renderer/test_dbml/test_note.py b/test/test_renderer/test_dbml/test_note.py index 0030fe7..20df62d 100644 --- a/test/test_renderer/test_dbml/test_note.py +++ b/test/test_renderer/test_dbml/test_note.py @@ -1,10 +1,11 @@ from pydbml.classes import Note -from pydbml.renderer.dbml.default.note import prepare_text_for_dbml, render_note +from pydbml.renderer.dbml.default.note import render_note +from pydbml.renderer.dbml.default.utils import prepare_text_for_dbml def test_prepare_text_for_dbml() -> None: note = Note("""Three quotes: ''', one quote: '.""") - assert prepare_text_for_dbml(note) == "Three quotes: \\''', one quote: \\'." + assert prepare_text_for_dbml(note.text) == "Three quotes: \\''', one quote: \\'." class TestRenderNote: @@ -18,5 +19,5 @@ def test_multiline() -> None: note = Note("Note text\nwith multiple lines") assert ( render_note(note) - == "Note {\n '''\n Note text\n with multiple lines\n '''\n}" + == "Note {\n '''\n Note text\n with multiple lines'''\n}" ) diff --git a/test/test_renderer/test_dbml/test_sticky_note.py b/test/test_renderer/test_dbml/test_sticky_note.py index bc800d0..1f5bc2a 100644 --- a/test/test_renderer/test_dbml/test_sticky_note.py +++ b/test/test_renderer/test_dbml/test_sticky_note.py @@ -13,5 +13,5 @@ def test_multiline() -> None: note = StickyNote(name='mynote', text="Note text\nwith multiple lines") assert ( render_sticky_note(note) - == "Note mynote {\n '''\n Note text\n with multiple lines\n '''\n}" + == "Note mynote {\n '''\n Note text\n with multiple lines'''\n}" ) diff --git a/test/test_renderer/test_dbml/test_table.py b/test/test_renderer/test_dbml/test_table.py index 2e584bd..7843a5d 100644 --- a/test/test_renderer/test_dbml/test_table.py +++ b/test/test_renderer/test_dbml/test_table.py @@ -79,3 +79,31 @@ def test_note_and_comment(db: Database, table1: Table) -> None: "}" ) assert render_table(table1) == expected + + @staticmethod + def test_properties(db: Database, table1: Table) -> None: + table1.properties = {"key": "value"} + db.add(table1) + db.allow_properties = True + expected = ( + 'Table "products" {\n' + ' "id" integer\n' + ' "name" varchar\n' + "\n" + " key: 'value'\n" + "}" + ) + assert render_table(table1) == expected + + @staticmethod + def test_properties_not_allowed(db: Database, table1: Table) -> None: + table1.properties = {"key": "value"} + db.add(table1) + db.allow_properties = False + expected = ( + 'Table "products" {\n' + ' "id" integer\n' + ' "name" varchar\n' + "}" + ) + assert render_table(table1) == expected diff --git a/test_schema.dbml b/test_schema.dbml index 432c1ca..0d68b2c 100644 --- a/test_schema.dbml +++ b/test_schema.dbml @@ -36,7 +36,6 @@ Table "products" { "status" "product status" "created_at" datetime [default: `now()`] - Indexes { (merchant_id, status) [name: "product_status"] id [type: hash, unique] From 4fbe049f509295b9065eb3bc6580f993b3a3ab52 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 25 Jul 2024 21:42:43 +0200 Subject: [PATCH 103/125] feat: unicode identifiers --- CHANGELOG.md | 5 +++++ pydbml/definitions/generic.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61c8825..df77d48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.1.0 + +- New: allow unicode characters in identifiers (DBML v3.3.0) +- + # 1.0.11 - Fix: allow pk in named indexes (thanks @pierresouchay for the contribution) diff --git a/pydbml/definitions/generic.py b/pydbml/definitions/generic.py index 69c1270..180cca4 100644 --- a/pydbml/definitions/generic.py +++ b/pydbml/definitions/generic.py @@ -4,7 +4,7 @@ pp.ParserElement.set_default_whitespace_chars(' \t\r') -name = pp.Word(pp.alphanums + '_') | pp.QuotedString('"') +name = pp.Word(pp.unicode.alphanums + '_') | pp.QuotedString('"') # Literals From f43deae193ce7bc9af16fc51cacd65863444509b Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Thu, 25 Jul 2024 22:09:06 +0200 Subject: [PATCH 104/125] update changelog, readme, bump version --- CHANGELOG.md | 3 ++- README.md | 2 +- setup.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df77d48..d406ca3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,8 @@ # 1.1.0 +- New: SQL and DBML rendering rewritten tow support external renderers - New: allow unicode characters in identifiers (DBML v3.3.0) -- +- New: support for arbitrary table and column properties (#37) # 1.0.11 - Fix: allow pk in named indexes (thanks @pierresouchay for the contribution) diff --git a/README.md b/README.md index 2f1e426..fd16e2e 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # DBML parser for Python -*Compliant with DBML **v3.2.0** syntax* +*Compliant with DBML **v3.6.1** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. diff --git a/setup.py b/setup.py index 12c1887..9e470be 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.0.11', + version='1.1.0', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From c9ef6eda3faedd8043d309584cff42da0a762739 Mon Sep 17 00:00:00 2001 From: Colin <50527015+big-c-note@users.noreply.github.com> Date: Tue, 6 Aug 2024 09:13:41 -0400 Subject: [PATCH 105/125] adds sql_renderer arg to parser --- pydbml/parser/parser.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index 2954188..beae3eb 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import List from typing import Optional +from typing import Type from typing import Union import pyparsing as pp @@ -18,6 +19,8 @@ from pydbml.definitions.table import table, table_with_properties from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.base import BaseRenderer +from pydbml.renderer.sql.default import DefaultSQLRenderer from pydbml.tools import remove_bom from .blueprints import EnumBlueprint, StickyNoteBlueprint from .blueprints import ProjectBlueprint @@ -50,6 +53,7 @@ def __new__( cls, source_: Optional[Union[str, Path, TextIOWrapper]] = None, allow_properties: bool = False, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, ): if source_ is not None: if isinstance(source_, str): @@ -63,7 +67,9 @@ def __new__( raise TypeError("Source must be str, path or file stream") source = remove_bom(source) - return cls.parse(source, allow_properties=allow_properties) + return cls.parse( + source, allow_properties=allow_properties, sql_renderer=sql_renderer + ) else: return super().__new__(cls) @@ -71,9 +77,15 @@ def __repr__(self): return "" @staticmethod - def parse(text: str, allow_properties: bool = False) -> Database: + def parse( + text: str, + allow_properties: bool = False, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + ) -> Database: text = remove_bom(text) - parser = PyDBMLParser(text, allow_properties=allow_properties) + parser = PyDBMLParser( + text, allow_properties=allow_properties, sql_renderer=sql_renderer + ) return parser.parse() @staticmethod @@ -89,7 +101,12 @@ def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: class PyDBMLParser: - def __init__(self, source: str, allow_properties: bool = False): + def __init__( + self, + source: str, + allow_properties: bool = False, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + ): self.database = None self.ref_blueprints: List[ReferenceBlueprint] = [] @@ -101,6 +118,7 @@ def __init__(self, source: str, allow_properties: bool = False): self.project: Optional[ProjectBlueprint] = None self.sticky_notes: List[StickyNoteBlueprint] = [] self._allow_properties = allow_properties + self._sql_renderer = sql_renderer def parse(self): self._set_syntax() @@ -190,7 +208,9 @@ def locate_table(self, schema: str, name: str) -> "Table": return result def build_database(self): - self.database = Database(allow_properties=self._allow_properties) + self.database = Database( + allow_properties=self._allow_properties, sql_renderer=self._sql_renderer + ) for enum_bp in self.enums: self.database.add(enum_bp.build()) for table_bp in self.tables: From 056ee4ef8ddd06d496ec436e0c7a9810ed4dd0e6 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Aug 2024 09:19:13 +0200 Subject: [PATCH 106/125] add dbml renderer to parser --- pydbml/parser/parser.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py index beae3eb..53a5e92 100644 --- a/pydbml/parser/parser.py +++ b/pydbml/parser/parser.py @@ -20,6 +20,7 @@ from pydbml.definitions.table_group import table_group from pydbml.exceptions import TableNotFoundError from pydbml.renderer.base import BaseRenderer +from pydbml.renderer.dbml.default import DefaultDBMLRenderer from pydbml.renderer.sql.default import DefaultSQLRenderer from pydbml.tools import remove_bom from .blueprints import EnumBlueprint, StickyNoteBlueprint @@ -54,6 +55,7 @@ def __new__( source_: Optional[Union[str, Path, TextIOWrapper]] = None, allow_properties: bool = False, sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, ): if source_ is not None: if isinstance(source_, str): @@ -68,7 +70,10 @@ def __new__( source = remove_bom(source) return cls.parse( - source, allow_properties=allow_properties, sql_renderer=sql_renderer + source, + allow_properties=allow_properties, + sql_renderer=sql_renderer, + dbml_renderer=dbml_renderer, ) else: return super().__new__(cls) @@ -81,10 +86,14 @@ def parse( text: str, allow_properties: bool = False, sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, ) -> Database: text = remove_bom(text) parser = PyDBMLParser( - text, allow_properties=allow_properties, sql_renderer=sql_renderer + text, + allow_properties=allow_properties, + sql_renderer=sql_renderer, + dbml_renderer=dbml_renderer, ) return parser.parse() @@ -106,6 +115,7 @@ def __init__( source: str, allow_properties: bool = False, sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, ): self.database = None @@ -119,6 +129,7 @@ def __init__( self.sticky_notes: List[StickyNoteBlueprint] = [] self._allow_properties = allow_properties self._sql_renderer = sql_renderer + self._dbml_renderer = dbml_renderer def parse(self): self._set_syntax() @@ -209,7 +220,9 @@ def locate_table(self, schema: str, name: str) -> "Table": def build_database(self): self.database = Database( - allow_properties=self._allow_properties, sql_renderer=self._sql_renderer + allow_properties=self._allow_properties, + sql_renderer=self._sql_renderer, + dbml_renderer=self._dbml_renderer, ) for enum_bp in self.enums: self.database.add(enum_bp.build()) From 2606d18d30c04a2a28c0573e3f33e492ab84aba7 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 11 Aug 2024 09:20:06 +0200 Subject: [PATCH 107/125] update changelog, bump version --- CHANGELOG.md | 3 +++ setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d406ca3..8156a22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.1.1 +- New: SQL and DBML renderers can now be supplied to parser + # 1.1.0 - New: SQL and DBML rendering rewritten tow support external renderers diff --git a/setup.py b/setup.py index 9e470be..37766f4 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.1.0', + version='1.1.1', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 7ccd5c6e94e9c68159f46ceb2bb484daa6725a3a Mon Sep 17 00:00:00 2001 From: Sangwoon Park Date: Tue, 26 Nov 2024 17:16:01 +0900 Subject: [PATCH 108/125] Fix bug in DBML column default rendering with single quotes - Integrated `prepare_text_for_dbml` into `default_to_str` for consistent handling of special characters in DBML column strings. - Updated the test suite to include a new test case for handling binary string input (`b'0'`). --- pydbml/renderer/dbml/default/column.py | 4 ++-- test/test_renderer/test_dbml/test_column.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pydbml/renderer/dbml/default/column.py b/pydbml/renderer/dbml/default/column.py index a06a580..3ec037a 100644 --- a/pydbml/renderer/dbml/default/column.py +++ b/pydbml/renderer/dbml/default/column.py @@ -2,7 +2,7 @@ from pydbml.classes import Column, Enum, Expression from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer -from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml, quote_string +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml, quote_string, prepare_text_for_dbml from pydbml.renderer.sql.default.utils import get_full_name_for_sql @@ -11,7 +11,7 @@ def default_to_str(val: Union[Expression, str, int, float]) -> str: if val.lower() in ('null', 'true', 'false'): return val.lower() else: - return f"'{val}'" + return f"'{prepare_text_for_dbml(val)}'" elif isinstance(val, Expression): return val.dbml else: # int or float or bool diff --git a/test/test_renderer/test_dbml/test_column.py b/test/test_renderer/test_dbml/test_column.py index 1b52807..b2cf3b8 100644 --- a/test/test_renderer/test_dbml/test_column.py +++ b/test/test_renderer/test_dbml/test_column.py @@ -21,6 +21,7 @@ (True, "True"), ("False", "false"), ("null", "null"), + ("b'0'", "'b\\'0\\''"), ], ) def test_default_to_str(input: Any, expected: str) -> None: From 85a7d07bf9c343f81830c8216140bd661b9739c4 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 6 Jan 2025 15:49:57 +0100 Subject: [PATCH 109/125] fix: line breaks in column and index options are allowed (#48) --- pydbml/definitions/column.py | 6 +++--- pydbml/definitions/index.py | 6 +++--- pydbml/renderer/dbml/default/utils.py | 2 +- test/test_definitions/test_column.py | 15 +++++++++++++++ test_schema.dbml | 5 ++++- 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index a33ffdd..1cd37d0 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -41,7 +41,7 @@ prop = name + pp.Suppress(":") + string_literal -column_setting = ( +column_setting = _ + ( pp.CaselessLiteral("not null").set_parse_action( lambda s, loc, tok: True )('notnull') @@ -55,11 +55,11 @@ | note('note') | ref_inline('ref*') | default('default') -) +) + _ column_setting_with_property = column_setting | prop.set_results_name('property', list_all_matches=True) -column_settings = '[' - (_ + column_setting + _) + ("," + column_setting)[...] + ']' + c +column_settings = '[' - column_setting + ("," + column_setting)[...] + ']' + c column_settings_with_properties = '[' - (_ + column_setting_with_property + _) + ("," + column_setting_with_property)[...] + ']' + c diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index 5c7e096..9364901 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -17,15 +17,15 @@ index_type = pp.CaselessLiteral("type:").suppress() + _ - ( pp.CaselessLiteral("btree")('type') | pp.CaselessLiteral("hash")('type') ) -index_setting = ( +index_setting = _ + ( unique('unique') | index_type | pp.CaselessLiteral("name:") + _ - string_literal('name') | note('note') | pk('pk') -) +) + _ index_settings = ( - '[' + _ + index_setting + (_ + ',' + _ - index_setting)[...] + _ - ']' + c + '[' + index_setting + (',' - index_setting)[...] - ']' + c ) diff --git a/pydbml/renderer/dbml/default/utils.py b/pydbml/renderer/dbml/default/utils.py index c46e893..3e71c47 100644 --- a/pydbml/renderer/dbml/default/utils.py +++ b/pydbml/renderer/dbml/default/utils.py @@ -3,7 +3,7 @@ from pydbml.tools import comment -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from pydbml.classes import Note diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py index 90767e2..335949a 100644 --- a/test/test_definitions/test_column.py +++ b/test/test_definitions/test_column.py @@ -1,3 +1,4 @@ +from textwrap import dedent from unittest import TestCase from pyparsing import ParseException @@ -204,6 +205,20 @@ def test_with_settings(self) -> None: self.assertTrue(res[0].not_null) self.assertTrue(res[0].note is not None) + def test_multiline_settings(self) -> None: + val = dedent("""_test_ \"mytype\" [ + unique, + not null, + note: 'to include unit number' + ] + """) + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "_test_") + self.assertEqual(res[0].type, "mytype") + self.assertTrue(res[0].unique) + self.assertTrue(res[0].not_null) + self.assertTrue(res[0].note is not None) + def test_enum_type_bad(self) -> None: val = "_test_ myschema.mytype(12) [unique]\n" with self.assertRaises(ParseException): diff --git a/test_schema.dbml b/test_schema.dbml index 0d68b2c..48f340d 100644 --- a/test_schema.dbml +++ b/test_schema.dbml @@ -17,7 +17,10 @@ Enum "product status" { Table "orders" [headercolor: #fff] { "id" int [pk, increment] - "user_id" int [unique, not null] + "user_id" int [ + unique, + not null + ] "status" orders_status "created_at" varchar } From 6df3c2f585f4ad3d3a68c6f92c9445c2fa069f07 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Mon, 6 Jan 2025 17:47:54 +0100 Subject: [PATCH 110/125] fix: table elements may go in any order (#49) --- pydbml/definitions/table.py | 25 +++++++++++++++++++------ test/test_definitions/test_table.py | 21 +++++++++------------ 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index fdf051d..b5ae190 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -44,11 +44,20 @@ def parse_table_settings(s, loc, tok): prop = name + pp.Suppress(":") + string_literal -table_element = _ + (note_element('note') | indexes('indexes')) + _ -table_element_with_property = _ + (note_element('note') | indexes('indexes') | prop.set_results_name('property', list_all_matches=True)) + _ - -table_body = table_column[1, ...]('columns') + _ + table_element[...] -table_body_with_properties = table_column_with_properties[1, ...]('columns') + _ + table_element_with_property[...] +table_element = _ + ( + table_column.set_results_name('columns', list_all_matches=True) | + note_element('note') | + indexes.set_results_name('indexes', list_all_matches=True) +) + _ +table_element_with_property = _ + ( + table_column_with_properties.set_results_name('columns', list_all_matches=True) | + note_element('note') | + indexes.set_results_name('indexes', list_all_matches=True) | + prop.set_results_name('property', list_all_matches=True) +) + _ + +table_body = table_element[...] +table_body_with_properties = table_element_with_property[...] table_name = (name('schema') + '.' + name('name')) | (name('name')) @@ -95,7 +104,7 @@ def parse_table(s, loc, tok): # will override one from settings init_dict['note'] = tok['note'][0] if 'indexes' in tok: - init_dict['indexes'] = tok['indexes'] + init_dict['indexes'] = tok['indexes'][0] if 'columns' in tok: init_dict['columns'] = tok['columns'] if 'comment_before' in tok: @@ -103,6 +112,10 @@ def parse_table(s, loc, tok): init_dict['comment'] = comment if 'property' in tok: init_dict['properties'] = {k: v for k, v in tok['property']} + + if not init_dict.get('columns'): + raise SyntaxError(f'Table {init_dict["name"]} at position {loc} has no columns!') + result = TableBlueprint(**init_dict) return result diff --git a/test/test_definitions/test_table.py b/test/test_definitions/test_table.py index f002fdd..cb65ef6 100644 --- a/test/test_definitions/test_table.py +++ b/test/test_definitions/test_table.py @@ -101,24 +101,16 @@ def test_columns_indexes_note(self) -> None: self.assertEqual(len(res2["indexes"]), 1) self.assertIsNotNone(res2["note"]) - def test_no_columns(self) -> None: - val = """ -note: 'mynote' -indexes { - (id, country) [pk] // composite primary key -}""" - with self.assertRaises(ParseException): - table_body.parse_string(val, parseAll=True) - - def test_columns_after_indexes(self) -> None: + def test_columns_after_indexes_are_allowed(self) -> None: val = """ note: 'mynote' indexes { (id, country) [pk] // composite primary key } id integer""" - with self.assertRaises(ParseException): - table_body.parse_string(val, parseAll=True) + res = table_body.parse_string(val, parseAll=True) + self.assertEqual(len(res["columns"]), 1) + self.assertEqual(len(res["indexes"]), 1) class TestTable(TestCase): @@ -128,6 +120,11 @@ def test_simple(self) -> None: self.assertEqual(res[0].name, "ids") self.assertEqual(len(res[0].columns), 1) + def test_no_columns(self) -> None: + val = "table ids {\nNote: 'No columns!'\n}" + with self.assertRaises(SyntaxError): + res = table.parse_string(val, parseAll=True) + def test_with_alias(self) -> None: val = "table ids as ii {\nid integer\n}" res = table.parse_string(val, parseAll=True) From 598dd2f796f571619ce37d9caa53e2610a12038d Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 7 Jan 2025 14:57:47 +0100 Subject: [PATCH 111/125] feat: TableGroup can have notes (DBML 3.7.2) --- pydbml/_classes/table_group.py | 5 +- pydbml/definitions/table_group.py | 19 ++++++-- pydbml/parser/blueprints.py | 4 +- pydbml/renderer/dbml/default/table_group.py | 4 ++ test/test_definitions/test_table_group.py | 47 +++++++++++++++---- .../test_dbml/test_table_group.py | 17 +++++-- test_schema.dbml | 3 +- 7 files changed, 81 insertions(+), 18 deletions(-) diff --git a/pydbml/_classes/table_group.py b/pydbml/_classes/table_group.py index 0b6d4dd..35b8c30 100644 --- a/pydbml/_classes/table_group.py +++ b/pydbml/_classes/table_group.py @@ -2,6 +2,7 @@ from typing import Optional from pydbml._classes.base import DBMLObject +from pydbml._classes.note import Note from pydbml._classes.table import Table @@ -16,11 +17,13 @@ class TableGroup(DBMLObject): def __init__(self, name: str, items: List[Table], - comment: Optional[str] = None): + comment: Optional[str] = None, + note: Optional[Note] = None): self.database = None self.name = name self.items = items self.comment = comment + self.note = note def __repr__(self): """ diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index cf56938..812522a 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -1,20 +1,30 @@ import pyparsing as pp -from .common import _ +from .common import _, note, note_object from .common import _c from .common import end from .generic import name -from pydbml.parser.blueprints import TableGroupBlueprint +from pydbml.parser.blueprints import TableGroupBlueprint, NoteBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') table_name = pp.Combine(name + '.' + name) | name +note_element = note | note_object + +tg_element = _ + (note_element('note') | table_name.set_results_name('items', list_all_matches=True)) + _ + +tg_body = tg_element[...] + +tg_settings = '[' + _ + note('note') + _ + ']' + + table_group = _c + ( pp.CaselessLiteral('TableGroup') - name('name') + _ + + tg_settings[0, 1] + _ - '{' + _ - - (table_name + _)[...]('items') + _ + - tg_body + _ - '}' ) + end @@ -34,6 +44,9 @@ def parse_table_group(s, loc, tok): if 'comment_before' in tok: comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment + if 'note' in tok: + note = tok['note'] + init_dict['note'] = note if isinstance(note, NoteBlueprint) else note[0] return TableGroupBlueprint(**init_dict) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 6632a45..b27e1fe 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -295,6 +295,7 @@ class TableGroupBlueprint(Blueprint): name: str items: List[str] comment: Optional[str] = None + note: Optional[NoteBlueprint] = None def build(self) -> 'TableGroup': if not self.parser: @@ -310,5 +311,6 @@ def build(self) -> 'TableGroup': return TableGroup( name=self.name, items=items, - comment=self.comment + comment=self.comment, + note=self.note.build() if self.note else None, ) diff --git a/pydbml/renderer/dbml/default/table_group.py b/pydbml/renderer/dbml/default/table_group.py index 29d4e85..a5c708f 100644 --- a/pydbml/renderer/dbml/default/table_group.py +++ b/pydbml/renderer/dbml/default/table_group.py @@ -1,3 +1,5 @@ +from textwrap import indent + from pydbml.classes import TableGroup from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer from pydbml.renderer.dbml.default.table import get_full_name_for_dbml @@ -10,5 +12,7 @@ def render_table_group(model: TableGroup) -> str: result += f'TableGroup {model.name} {{\n' for i in model.items: result += f' {get_full_name_for_dbml(i)}\n' + if model.note: + result += indent(model.note.dbml, ' ') + '\n' result += '}' return result diff --git a/test/test_definitions/test_table_group.py b/test/test_definitions/test_table_group.py index d8e2a7a..1c38560 100644 --- a/test/test_definitions/test_table_group.py +++ b/test/test_definitions/test_table_group.py @@ -1,3 +1,4 @@ +from textwrap import dedent from unittest import TestCase from pyparsing import ParserElement @@ -5,24 +6,54 @@ from pydbml.definitions.table_group import table_group -ParserElement.set_default_whitespace_chars(' \t\r') +ParserElement.set_default_whitespace_chars(" \t\r") class TestProject(TestCase): def test_empty(self) -> None: - val = 'TableGroup name {}' + val = "TableGroup name {}" res = table_group.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'name') + self.assertEqual(res[0].name, "name") def test_fields(self) -> None: val = "TableGroup name {table1 table2}" res = table_group.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'name') - self.assertEqual(res[0].items, ['table1', 'table2']) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) def test_comment(self) -> None: val = "//comment before\nTableGroup name\n{\ntable1\ntable2\n}" res = table_group.parse_string(val, parseAll=True) - self.assertEqual(res[0].name, 'name') - self.assertEqual(res[0].items, ['table1', 'table2']) - self.assertEqual(res[0].comment, 'comment before') + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].comment, "comment before") + + def test_note_settings(self) -> None: + val = "TableGroup name [note: 'My note'] \n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].note.text, "My note") + + def test_note_body(self) -> None: + val = dedent("""\ + TableGroup name { + table1 + Note: ''' + Note line1 + Note line2 + ''' + table2 + } + """) + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertIn("Note line1\n", res[0].note.text,) + + def test_note_settings_overriden_by_note_body(self) -> None: + val = "TableGroup name [note: 'Settings note'] \n{\ntable1\nnote: 'Body note'\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].note.text, "Body note") diff --git a/test/test_renderer/test_dbml/test_table_group.py b/test/test_renderer/test_dbml/test_table_group.py index 171a270..28def64 100644 --- a/test/test_renderer/test_dbml/test_table_group.py +++ b/test/test_renderer/test_dbml/test_table_group.py @@ -1,3 +1,4 @@ +from pydbml._classes.note import Note from pydbml._classes.table_group import TableGroup from pydbml.classes import Table from pydbml.renderer.dbml.default import render_table_group @@ -5,14 +6,22 @@ def test_render_table_group(table1: Table, table2: Table, table3: Table) -> None: tg = TableGroup( - name="mygroup", items=[table1, table2, table3], comment="My comment" + name="mygroup", + items=[table1, table2, table3], + comment="My comment", + note=Note('Note line1\nNote line2') ) expected = ( - "// My comment\n" - "TableGroup mygroup {\n" + '// My comment\n' + 'TableGroup mygroup {\n' ' "products"\n' ' "products"\n' ' "orders"\n' - "}" + ' Note {\n' + " '''\n" + ' Note line1\n' + " Note line2'''\n" + ' }\n' + '}' ) assert render_table_group(tg) == expected diff --git a/test_schema.dbml b/test_schema.dbml index 48f340d..2e2eee8 100644 --- a/test_schema.dbml +++ b/test_schema.dbml @@ -57,9 +57,10 @@ Table "users" { Ref:"orders"."id" < "order_items"."order_id" -TableGroup g1 { +TableGroup g1 [note: 'test note'] { users merchants + note: 'test note 2' } TableGroup g2 { From 4075e62743bccd952df3d79eeb36fed4a9753624 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 7 Jan 2025 15:01:50 +0100 Subject: [PATCH 112/125] update docs --- docs/classes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/classes.md b/docs/classes.md index 276c834..6d817fd 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -324,4 +324,5 @@ Sticky notes are similar to regular notes, except that they are defined at the r * **name** (str) — table group name, * **items** (str) — dictionary with tables in the group, * **comment** (str) — comment, if was added before table group definition. +* **note** (Note) — table group's note if was defined. * **dbml** (str) — DBML definition for this table group. From 64ecc294050497634ea9b5ba995f60f93c87f6f8 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 7 Jan 2025 17:24:47 +0100 Subject: [PATCH 113/125] feat: TableGroup can have color (DBML 3.7.4) --- docs/classes.md | 1 + pydbml/_classes/table_group.py | 4 +- pydbml/definitions/common.py | 3 + pydbml/definitions/table.py | 6 +- pydbml/definitions/table_group.py | 13 +++- pydbml/parser/blueprints.py | 2 + pydbml/renderer/dbml/default/table_group.py | 5 +- test/test_definitions/test_table_group.py | 17 +++++- .../test_dbml/test_table_group.py | 60 ++++++++++++------- test_schema.dbml | 2 +- 10 files changed, 81 insertions(+), 32 deletions(-) diff --git a/docs/classes.md b/docs/classes.md index 6d817fd..f4e0fd5 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -325,4 +325,5 @@ Sticky notes are similar to regular notes, except that they are defined at the r * **items** (str) — dictionary with tables in the group, * **comment** (str) — comment, if was added before table group definition. * **note** (Note) — table group's note if was defined. +* **color** (str) — the color param, if defined. * **dbml** (str) — DBML definition for this table group. diff --git a/pydbml/_classes/table_group.py b/pydbml/_classes/table_group.py index 35b8c30..cd9abc6 100644 --- a/pydbml/_classes/table_group.py +++ b/pydbml/_classes/table_group.py @@ -18,12 +18,14 @@ def __init__(self, name: str, items: List[Table], comment: Optional[str] = None, - note: Optional[Note] = None): + note: Optional[Note] = None, + color: Optional[str] = None): self.database = None self.name = name self.items = items self.comment = comment self.note = note + self.color = color def __repr__(self): """ diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index fc5ed7a..8288de3 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -34,3 +34,6 @@ pk = pp.CaselessLiteral("pk") unique = pp.CaselessLiteral("unique") + +hex_char = pp.Word(pp.srange('[0-9a-fA-F]'), exact=1) +hex_color = ("#" - (hex_char * 3 ^ hex_char * 6)).leaveWhitespace() diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index b5ae190..aa0dfc9 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -1,22 +1,20 @@ import pyparsing as pp +from pydbml.parser.blueprints import TableBlueprint from .column import table_column, table_column_with_properties -from .common import _ +from .common import _, hex_color from .common import _c from .common import end from .common import note from .common import note_object from .generic import name, string_literal from .index import indexes -from pydbml.parser.blueprints import TableBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') alias = pp.WordStart() + pp.Literal('as').suppress() - pp.WordEnd() - name -hex_char = pp.Word(pp.srange('[0-9a-fA-F]'), exact=1) -hex_color = ("#" - (hex_char * 3 ^ hex_char * 6)).leaveWhitespace() header_color = ( pp.CaselessLiteral('headercolor:').suppress() + _ - pp.Combine(hex_color)('header_color') diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index 812522a..b8e5b16 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -1,10 +1,10 @@ import pyparsing as pp -from .common import _, note, note_object +from pydbml.parser.blueprints import TableGroupBlueprint, NoteBlueprint +from .common import _, note, note_object, hex_color from .common import _c from .common import end from .generic import name -from pydbml.parser.blueprints import TableGroupBlueprint, NoteBlueprint pp.ParserElement.set_default_whitespace_chars(' \t\r') @@ -15,9 +15,14 @@ tg_body = tg_element[...] -tg_settings = '[' + _ + note('note') + _ + ']' +tg_color = ( + pp.CaselessLiteral('color:').suppress() + _ + - pp.Combine(hex_color)('color') +) +tg_setting = _ + (note('note') | tg_color) + _ +tg_settings = '[' + tg_setting + (',' + tg_setting)[...] + ']' table_group = _c + ( pp.CaselessLiteral('TableGroup') @@ -47,6 +52,8 @@ def parse_table_group(s, loc, tok): if 'note' in tok: note = tok['note'] init_dict['note'] = note if isinstance(note, NoteBlueprint) else note[0] + if 'color' in tok: + init_dict['color'] = tok['color'] return TableGroupBlueprint(**init_dict) diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index b27e1fe..5c7b3a2 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -296,6 +296,7 @@ class TableGroupBlueprint(Blueprint): items: List[str] comment: Optional[str] = None note: Optional[NoteBlueprint] = None + color: Optional[str] = None def build(self) -> 'TableGroup': if not self.parser: @@ -313,4 +314,5 @@ def build(self) -> 'TableGroup': items=items, comment=self.comment, note=self.note.build() if self.note else None, + color=self.color ) diff --git a/pydbml/renderer/dbml/default/table_group.py b/pydbml/renderer/dbml/default/table_group.py index a5c708f..311d461 100644 --- a/pydbml/renderer/dbml/default/table_group.py +++ b/pydbml/renderer/dbml/default/table_group.py @@ -9,7 +9,10 @@ @DefaultDBMLRenderer.renderer_for(TableGroup) def render_table_group(model: TableGroup) -> str: result = comment_to_dbml(model.comment) if model.comment else '' - result += f'TableGroup {model.name} {{\n' + result += f'TableGroup {model.name}' + if model.color: + result += f' [color: {model.color}]' + result += ' {\n' for i in model.items: result += f' {get_full_name_for_dbml(i)}\n' if model.note: diff --git a/test/test_definitions/test_table_group.py b/test/test_definitions/test_table_group.py index 1c38560..b119d1d 100644 --- a/test/test_definitions/test_table_group.py +++ b/test/test_definitions/test_table_group.py @@ -9,7 +9,7 @@ ParserElement.set_default_whitespace_chars(" \t\r") -class TestProject(TestCase): +class TestTableGroup(TestCase): def test_empty(self) -> None: val = "TableGroup name {}" res = table_group.parse_string(val, parseAll=True) @@ -35,6 +35,21 @@ def test_note_settings(self) -> None: self.assertEqual(res[0].items, ["table1", "table2"]) self.assertEqual(res[0].note.text, "My note") + def test_color(self) -> None: + val = "TableGroup name [color: #FFF] \n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].color, "#FFF") + + def test_all_settings(self) -> None: + val = "TableGroup name [color: #FFF, note: 'My note'] \n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].color, "#FFF") + self.assertEqual(res[0].note.text, "My note") + def test_note_body(self) -> None: val = dedent("""\ TableGroup name { diff --git a/test/test_renderer/test_dbml/test_table_group.py b/test/test_renderer/test_dbml/test_table_group.py index 28def64..adfa1a3 100644 --- a/test/test_renderer/test_dbml/test_table_group.py +++ b/test/test_renderer/test_dbml/test_table_group.py @@ -4,24 +4,42 @@ from pydbml.renderer.dbml.default import render_table_group -def test_render_table_group(table1: Table, table2: Table, table3: Table) -> None: - tg = TableGroup( - name="mygroup", - items=[table1, table2, table3], - comment="My comment", - note=Note('Note line1\nNote line2') - ) - expected = ( - '// My comment\n' - 'TableGroup mygroup {\n' - ' "products"\n' - ' "products"\n' - ' "orders"\n' - ' Note {\n' - " '''\n" - ' Note line1\n' - " Note line2'''\n" - ' }\n' - '}' - ) - assert render_table_group(tg) == expected +class TestTableGroup: + @staticmethod + def test_simple(table1: Table, table2: Table, table3: Table) -> None: + tg = TableGroup( + name="mygroup", + items=[table1, table2, table3], + ) + expected = ( + 'TableGroup mygroup {\n' + ' "products"\n' + ' "products"\n' + ' "orders"\n' + '}' + ) + assert render_table_group(tg) == expected + + @staticmethod + def test_full(table1: Table, table2: Table, table3: Table) -> None: + tg = TableGroup( + name="mygroup", + items=[table1, table2, table3], + comment="My comment", + note=Note('Note line1\nNote line2'), + color='#FFF' + ) + expected = ( + '// My comment\n' + 'TableGroup mygroup [color: #FFF] {\n' + ' "products"\n' + ' "products"\n' + ' "orders"\n' + ' Note {\n' + " '''\n" + ' Note line1\n' + " Note line2'''\n" + ' }\n' + '}' + ) + assert render_table_group(tg) == expected diff --git a/test_schema.dbml b/test_schema.dbml index 2e2eee8..3d0b481 100644 --- a/test_schema.dbml +++ b/test_schema.dbml @@ -57,7 +57,7 @@ Table "users" { Ref:"orders"."id" < "order_items"."order_id" -TableGroup g1 [note: 'test note'] { +TableGroup g1 [note: 'test note', color: #FFF] { users merchants note: 'test note 2' From eb0c5101e834b264f96d214f194ec7124887544f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 7 Jan 2025 17:30:57 +0100 Subject: [PATCH 114/125] resolve conflicts --- pydbml/renderer/dbml/default/project.py | 4 +- pydbml/renderer/dbml/default/table_group.py | 4 +- pydbml/tools.py | 8 ++ test/test_data/integration1.dbml | 2 +- .../test_dbml/test_table_group.py | 4 +- test/test_tools.py | 89 +++++++++++-------- 6 files changed, 71 insertions(+), 40 deletions(-) diff --git a/pydbml/renderer/dbml/default/project.py b/pydbml/renderer/dbml/default/project.py index eed74a8..f885b10 100644 --- a/pydbml/renderer/dbml/default/project.py +++ b/pydbml/renderer/dbml/default/project.py @@ -4,6 +4,7 @@ from pydbml.classes import Project from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer from pydbml.renderer.dbml.default.utils import comment_to_dbml +from pydbml.tools import doublequote_string def render_items(items: Dict[str, str]) -> str: @@ -19,7 +20,8 @@ def render_items(items: Dict[str, str]) -> str: @DefaultDBMLRenderer.renderer_for(Project) def render_project(model: Project) -> str: result = comment_to_dbml(model.comment) if model.comment else '' - result += f'Project "{model.name}" {{\n' + quoted_name = doublequote_string(model.name) + result += f'Project {quoted_name} {{\n' result += render_items(model.items) if model.note: result += indent(DefaultDBMLRenderer.render(model.note), ' ') + '\n' diff --git a/pydbml/renderer/dbml/default/table_group.py b/pydbml/renderer/dbml/default/table_group.py index 311d461..ca7cb44 100644 --- a/pydbml/renderer/dbml/default/table_group.py +++ b/pydbml/renderer/dbml/default/table_group.py @@ -4,12 +4,14 @@ from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer from pydbml.renderer.dbml.default.table import get_full_name_for_dbml from pydbml.renderer.dbml.default.utils import comment_to_dbml +from pydbml.tools import doublequote_string @DefaultDBMLRenderer.renderer_for(TableGroup) def render_table_group(model: TableGroup) -> str: result = comment_to_dbml(model.comment) if model.comment else '' - result += f'TableGroup {model.name}' + quoted_name = doublequote_string(model.name) + result += f'TableGroup {quoted_name}' if model.color: result += f' [color: {model.color}]' result += ' {\n' diff --git a/pydbml/tools.py b/pydbml/tools.py index 4892cbc..9ebd8fb 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -27,6 +27,14 @@ def strip_empty_lines(source: str) -> str: return pattern.sub('\g', source) +def doublequote_string(source: str) -> str: + """Safely wrap a single-line string in double quotes""" + if '\n' in source: + raise ValueError(f'Multiline strings are not allowed: {source!r}') + result = source.strip('"').replace('"', '\\"') + return f'"{result}"' + + def remove_indentation(source: str) -> str: if not source: return source diff --git a/test/test_data/integration1.dbml b/test/test_data/integration1.dbml index 9c59723..b24c9d7 100644 --- a/test/test_data/integration1.dbml +++ b/test/test_data/integration1.dbml @@ -38,7 +38,7 @@ Ref { "Employees"."favorite_book_id" > "books"."id" } -TableGroup Unanimate { +TableGroup "Unanimate" { "books" "countries" } \ No newline at end of file diff --git a/test/test_renderer/test_dbml/test_table_group.py b/test/test_renderer/test_dbml/test_table_group.py index adfa1a3..5342b39 100644 --- a/test/test_renderer/test_dbml/test_table_group.py +++ b/test/test_renderer/test_dbml/test_table_group.py @@ -12,7 +12,7 @@ def test_simple(table1: Table, table2: Table, table3: Table) -> None: items=[table1, table2, table3], ) expected = ( - 'TableGroup mygroup {\n' + 'TableGroup "mygroup" {\n' ' "products"\n' ' "products"\n' ' "orders"\n' @@ -31,7 +31,7 @@ def test_full(table1: Table, table2: Table, table3: Table) -> None: ) expected = ( '// My comment\n' - 'TableGroup mygroup [color: #FFF] {\n' + 'TableGroup "mygroup" [color: #FFF] {\n' ' "products"\n' ' "products"\n' ' "orders"\n' diff --git a/test/test_tools.py b/test/test_tools.py index 6c1f26c..a452fb2 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -1,7 +1,9 @@ from unittest import TestCase +import pytest + from pydbml.classes import Note -from pydbml.tools import remove_indentation +from pydbml.tools import remove_indentation, doublequote_string from pydbml.renderer.sql.default.utils import comment_to_sql from pydbml.tools import indent from pydbml.renderer.dbml.default.utils import note_option_to_dbml, comment_to_dbml @@ -10,103 +12,120 @@ class TestCommentToDBML(TestCase): def test_comment(self) -> None: - oneline = 'comment' - self.assertEqual(f'// {oneline}\n', comment_to_dbml(oneline)) + oneline = "comment" + self.assertEqual(f"// {oneline}\n", comment_to_dbml(oneline)) - expected = \ -'''// + expected = """// // line1 // line2 // line3 // -''' - source = '\nline1\nline2\nline3\n' +""" + source = "\nline1\nline2\nline3\n" self.assertEqual(comment_to_dbml(source), expected) class TestCommentToSQL(TestCase): def test_comment(self) -> None: - oneline = 'comment' - self.assertEqual(f'-- {oneline}\n', comment_to_sql(oneline)) + oneline = "comment" + self.assertEqual(f"-- {oneline}\n", comment_to_sql(oneline)) - expected = \ -'''-- + expected = """-- -- line1 -- line2 -- line3 -- -''' - source = '\nline1\nline2\nline3\n' +""" + source = "\nline1\nline2\nline3\n" self.assertEqual(comment_to_sql(source), expected) class TestNoteOptionToDBML(TestCase): def test_oneline(self) -> None: - note = Note('one line note') + note = Note("one line note") self.assertEqual(f"note: 'one line note'", note_option_to_dbml(note)) def test_oneline_with_quote(self) -> None: - note = Note('one line\'d note') + note = Note("one line'd note") self.assertEqual(f"note: 'one line\\'d note'", note_option_to_dbml(note)) def test_multiline(self) -> None: - note = Note('line1\nline2\nline3') + note = Note("line1\nline2\nline3") expected = "note: '''line1\nline2\nline3'''" self.assertEqual(expected, note_option_to_dbml(note)) def test_multiline_with_quotes(self) -> None: - note = Note('line1\n\'\'\'line2\nline3') + note = Note("line1\n'''line2\nline3") expected = "note: '''line1\n\\'''line2\nline3'''" self.assertEqual(expected, note_option_to_dbml(note)) class TestIndent(TestCase): def test_empty(self) -> None: - self.assertEqual(indent(''), '') + self.assertEqual(indent(""), "") def test_nonempty(self) -> None: - oneline = 'one line text' - self.assertEqual(indent(oneline), f' {oneline}') - source = 'line1\nline2\nline3' - expected = ' line1\n line2\n line3' + oneline = "one line text" + self.assertEqual(indent(oneline), f" {oneline}") + source = "line1\nline2\nline3" + expected = " line1\n line2\n line3" self.assertEqual(indent(source), expected) - expected2 = ' line1\n line2\n line3' + expected2 = " line1\n line2\n line3" self.assertEqual(indent(source, 2), expected2) class TestStripEmptyLines(TestCase): def test_empty(self) -> None: - source = '' + source = "" self.assertEqual(strip_empty_lines(source), source) def test_no_empty_lines(self) -> None: - source = 'line1\n\n\nline2' + source = "line1\n\n\nline2" self.assertEqual(strip_empty_lines(source), source) def test_empty_lines(self) -> None: - stripped = ' line1\n\n line2' - source = f'\n \n \n\t \t \n \n{stripped}\n\n\n \n \t \n\t \n \n' + stripped = " line1\n\n line2" + source = f"\n \n \n\t \t \n \n{stripped}\n\n\n \n \t \n\t \n \n" self.assertEqual(strip_empty_lines(source), stripped) def test_one_empty_line(self) -> None: - stripped = ' line1\n\n line2' - source = f'\n{stripped}' + stripped = " line1\n\n line2" + source = f"\n{stripped}" self.assertEqual(strip_empty_lines(source), stripped) - source = f'{stripped}\n' + source = f"{stripped}\n" self.assertEqual(strip_empty_lines(source), stripped) def test_end(self) -> None: - stripped = ' line1\n\n line2' - source = f'\n{stripped}\n ' + stripped = " line1\n\n line2" + source = f"\n{stripped}\n " self.assertEqual(strip_empty_lines(source), stripped) class TestRemoveIndentation(TestCase): def test_empty(self) -> None: - source = '' + source = "" self.assertEqual(remove_indentation(source), source) def test_not_empty(self) -> None: - source = ' line1\n line2' - expected = 'line1\n line2' + source = " line1\n line2" + expected = "line1\n line2" self.assertEqual(remove_indentation(source), expected) + + +class TestDoublequoteString: + @staticmethod + @pytest.mark.parametrize( + "source,expected", + [ + ("Test string", '"Test string"'), + ('String with "quotes"!', '"String with \\"quotes\\"!"'), + ('"Quoted string"', '"Quoted string"'), + ], + ) + def test_oneline(source: str, expected: str) -> None: + assert doublequote_string(source) == expected + + @staticmethod + def test_multiline() -> None: + with pytest.raises(ValueError): + doublequote_string('line1\nline2') From a3c5875c6737e972ba982aceb415584ae093d912 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 7 Jan 2025 17:39:18 +0100 Subject: [PATCH 115/125] update changelog and bump version --- CHANGELOG.md | 8 ++++++++ README.md | 2 +- setup.py | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8156a22..390b291 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# 1.1.2 +- Fix: escaping single quotes in a column's default value. Thanks @ryanproback for the contribution +- Fix: TableGroup and Project name are now safely quoted on render. Thanks @ryanproback for reporting +- Fix: line breaks in column and index options are now allowed. Thanks @aardjon for reporting +- Fix: table elements order is now not enforced by the parser. Thanks @aardjon for reporting +- New: TableGroup now can have notes (DBML v.3.7.2) +- New: TableGroup now can have color (DBML v.3.7.4) + # 1.1.1 - New: SQL and DBML renderers can now be supplied to parser diff --git a/README.md b/README.md index f4005cb..d30ecb6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # DBML parser for Python -*Compliant with DBML **v3.6.1** syntax* +*Compliant with DBML **v3.9.5** syntax* PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. diff --git a/setup.py b/setup.py index 37766f4..a3bb235 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.1.1', + version='1.1.2', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 3b9f8df338f4b99d0af032766a84a23ef9e6c847 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Tue, 7 Jan 2025 17:41:11 +0100 Subject: [PATCH 116/125] fix escape sequence --- pydbml/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydbml/tools.py b/pydbml/tools.py index 9ebd8fb..353136a 100644 --- a/pydbml/tools.py +++ b/pydbml/tools.py @@ -24,7 +24,7 @@ def remove_bom(source: str) -> str: def strip_empty_lines(source: str) -> str: """Remove empty lines or lines with just spaces from beginning and end.""" pattern = re.compile(r'^([ \t]*\n)*(?P[\s\S]+?)(\n[ \t]*)*$') - return pattern.sub('\g', source) + return pattern.sub(r'\g', source) def doublequote_string(source: str) -> str: From ea21dc1c3edb32fe3f4489ad59c1eb51f85a2a14 Mon Sep 17 00:00:00 2001 From: Pierre Souchay Date: Wed, 15 Jan 2025 08:13:25 +0100 Subject: [PATCH 117/125] feat(indexes): support all Postgresqsl index types Ref: https://github.com/Vanderhoof/PyDBML/issues/56 --- pydbml/definitions/index.py | 7 ++++++- pydbml/parser/blueprints.py | 12 +++++++++++- test/test_definitions/test_index.py | 16 ++++++++++------ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index 9364901..ed39fda 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -15,7 +15,12 @@ pp.ParserElement.set_default_whitespace_chars(' \t\r') index_type = pp.CaselessLiteral("type:").suppress() + _ - ( - pp.CaselessLiteral("btree")('type') | pp.CaselessLiteral("hash")('type') + pp.CaselessLiteral("brin")('type') | + pp.CaselessLiteral("btree")('type') | + pp.CaselessLiteral("gin")('type') | + pp.CaselessLiteral("gist")('type') | + pp.CaselessLiteral("hash")('type') | + pp.CaselessLiteral("spgist")('type') ) index_setting = _ + ( unique('unique') diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py index 5c7b3a2..614a251 100644 --- a/pydbml/parser/blueprints.py +++ b/pydbml/parser/blueprints.py @@ -166,7 +166,17 @@ class IndexBlueprint(Blueprint): subject_names: List[Union[str, ExpressionBlueprint]] name: Optional[str] = None unique: bool = False - type: Optional[Literal['hash', 'btree']] = None + type: Optional[ + Literal[ + # https://www.postgresql.org/docs/current/indexes-types.html + "brin", + "btree", + "gin", + "gist", + "hash", + "spgist", + ] + ] = None pk: bool = False note: Optional[NoteBlueprint] = None comment: Optional[str] = None diff --git a/test/test_definitions/test_index.py b/test/test_definitions/test_index.py index 4396998..b3239f4 100644 --- a/test/test_definitions/test_index.py +++ b/test/test_definitions/test_index.py @@ -20,12 +20,16 @@ class TestIndexType(TestCase): def test_correct(self) -> None: - val = 'Type: BTREE' - res = index_type.parse_string(val, parseAll=True) - self.assertEqual(res['type'], 'btree') - val2 = 'type:\nhash' - res2 = index_type.parse_string(val2, parseAll=True) - self.assertEqual(res2['type'], 'hash') + for val, expected in [ + ("Type: BTREE", "btree"), + ("type: hash", "hash"), + ("type: gist", "gist"), + ("TYPE:SPGiST", "spgist"), + ("type: GIN", "gin"), + ("Type:\tbRiN", "brin"), + ]: + res = index_type.parse_string(val, parseAll=True) + self.assertEqual(res["type"], expected) def test_incorrect(self) -> None: val = 'type: wrong' From 502fb092abcf0dd9470c03bdd97698b0027325aa Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 15 Jan 2025 10:05:06 +0100 Subject: [PATCH 118/125] chore: update index class type hint and docs --- docs/classes.md | 2 +- pydbml/_classes/index.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/classes.md b/docs/classes.md index f4e0fd5..d4605bd 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -158,7 +158,7 @@ Indexes are stored in the `indexes` attribute of a `Table` object. * **table** (`Table`) — link to table, for which this index is defined. * **name** (str) — index name, if defined. * **unique** (bool) — indicates whether the index is unique. -* **type** (str) — index type, if defined. Can be either `hash` or `btree`. +* **type** (str) — index type, if defined. Accepted values: `brin`, `btree`, `gin`, `gist`, `hash`, `spgist`. * **pk** (bool) — indicates whether this a primary key index. * **note** (note) — index note, if defined. * **comment** (str) — comment, if it was added just before index definition. diff --git a/pydbml/_classes/index.py b/pydbml/_classes/index.py index 35581bd..0d6ee24 100644 --- a/pydbml/_classes/index.py +++ b/pydbml/_classes/index.py @@ -22,7 +22,17 @@ def __init__(self, subjects: List[Union[str, Column, Expression]], name: Optional[str] = None, unique: bool = False, - type: Optional[Literal['hash', 'btree']] = None, + type: Optional[ + Literal[ + # https://www.postgresql.org/docs/current/indexes-types.html + "brin", + "btree", + "gin", + "gist", + "hash", + "spgist", + ] + ] = None, pk: bool = False, note: Optional[Union[Note, str]] = None, comment: Optional[str] = None): From 47de70d1a130fc35cffd8cf97da286884881b462 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 15 Jan 2025 10:13:42 +0100 Subject: [PATCH 119/125] update changelog and bump version --- CHANGELOG.md | 3 +++ setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 390b291..81db133 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.1.3 +- New: support more index types. Thanks @pierresouchay for the contribution. + # 1.1.2 - Fix: escaping single quotes in a column's default value. Thanks @ryanproback for the contribution - Fix: TableGroup and Project name are now safely quoted on render. Thanks @ryanproback for reporting diff --git a/setup.py b/setup.py index a3bb235..f0ed360 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.1.2', + version='1.1.3', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From e8eab9cbe3b8ee78a43a0b66930e38ade9d78c26 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Fri, 14 Feb 2025 13:33:57 +0100 Subject: [PATCH 120/125] feat(#58): remove trailing comma in Enum SQL --- pydbml/renderer/sql/default/enum.py | 3 ++- test/test_data/integration1.sql | 2 +- test/test_renderer/test_sql/test_default/test_enum.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pydbml/renderer/sql/default/enum.py b/pydbml/renderer/sql/default/enum.py index 6bc8626..5d1b3ba 100644 --- a/pydbml/renderer/sql/default/enum.py +++ b/pydbml/renderer/sql/default/enum.py @@ -21,7 +21,8 @@ def render_enum(model: Enum) -> str: result = comment_to_sql(model.comment) if model.comment else '' result += f'CREATE TYPE {get_full_name_for_sql(model)} AS ENUM (\n' - result += '\n'.join(f'{indent(DefaultSQLRenderer.render(i), " ")}' for i in model.items) + enum_body = '\n'.join(f'{indent(DefaultSQLRenderer.render(i), " ")}' for i in model.items) + result += enum_body.rstrip(',') result += '\n);' return result diff --git a/test/test_data/integration1.sql b/test/test_data/integration1.sql index 7330553..f648fb1 100644 --- a/test/test_data/integration1.sql +++ b/test/test_data/integration1.sql @@ -1,7 +1,7 @@ CREATE TYPE "level" AS ENUM ( 'junior', 'middle', - 'senior', + 'senior' ); CREATE TABLE "books" ( diff --git a/test/test_renderer/test_sql/test_default/test_enum.py b/test/test_renderer/test_sql/test_default/test_enum.py index 9d76866..9a5d4ee 100644 --- a/test/test_renderer/test_sql/test_default/test_enum.py +++ b/test/test_renderer/test_sql/test_default/test_enum.py @@ -21,7 +21,7 @@ def test_simple_enum(enum1: Enum) -> None: expected = ( 'CREATE TYPE "product status" AS ENUM (\n' " 'production',\n" - " 'development',\n" + " 'development'\n" ");" ) assert render_enum(enum1) == expected @@ -33,7 +33,7 @@ def test_comments(enum1: Enum) -> None: "-- Enum comment\n" 'CREATE TYPE "product status" AS ENUM (\n' " 'production',\n" - " 'development',\n" + " 'development'\n" ");" ) assert render_enum(enum1) == expected From 2bf015cf0d4a242339dad45095afec2c0efa197f Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Fri, 14 Feb 2025 13:35:11 +0100 Subject: [PATCH 121/125] update changelog and bump version --- CHANGELOG.md | 3 +++ setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81db133..0ea6361 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.1.4 +- Fix: Remove trailing comma in Enum SQL (#58 thanks @ralfschulze for reporting) + # 1.1.3 - New: support more index types. Thanks @pierresouchay for the contribution. diff --git a/setup.py b/setup.py index f0ed360..36d152e 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.1.3', + version='1.1.4', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 4af1b0b9c4bfe28b7f90cc29c672254b503a2db3 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 12 Mar 2025 12:50:39 +0100 Subject: [PATCH 122/125] Disable unicode characters support in identifiers (#59) --- pydbml/definitions/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydbml/definitions/generic.py b/pydbml/definitions/generic.py index 180cca4..69c1270 100644 --- a/pydbml/definitions/generic.py +++ b/pydbml/definitions/generic.py @@ -4,7 +4,7 @@ pp.ParserElement.set_default_whitespace_chars(' \t\r') -name = pp.Word(pp.unicode.alphanums + '_') | pp.QuotedString('"') +name = pp.Word(pp.alphanums + '_') | pp.QuotedString('"') # Literals From b4c915f2f6f7d89fe0992a7a7b55bb1dc48887e5 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 12 Mar 2025 12:51:44 +0100 Subject: [PATCH 123/125] update changelog and bump version --- CHANGELOG.md | 3 +++ prof.py | 7 +++++++ setup.py | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 prof.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ea6361..3f820d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.2.0 +- Fix: Temporarily disable unicode characters support in identifiers for performance (#59) + # 1.1.4 - Fix: Remove trailing comma in Enum SQL (#58 thanks @ralfschulze for reporting) diff --git a/prof.py b/prof.py new file mode 100644 index 0000000..f4f15ec --- /dev/null +++ b/prof.py @@ -0,0 +1,7 @@ +from pydbml import PyDBML +from pathlib import Path + + +if __name__ == '__main__': + d = PyDBML(Path('test_schema.dbml')) + print(d) diff --git a/setup.py b/setup.py index 36d152e..fc5cf27 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='1.1.4', + version='1.2.0', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', From 5c6a9173580d7273a6799cdd725eabd3a2e76911 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Wed, 12 Mar 2025 12:52:28 +0100 Subject: [PATCH 124/125] remove temp file --- prof.py | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 prof.py diff --git a/prof.py b/prof.py deleted file mode 100644 index f4f15ec..0000000 --- a/prof.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydbml import PyDBML -from pathlib import Path - - -if __name__ == '__main__': - d = PyDBML(Path('test_schema.dbml')) - print(d) From 52d59ad2044f963aa34ba6ceca00180595705fa0 Mon Sep 17 00:00:00 2001 From: samhaese <63563702+samhaese@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:46:30 -0700 Subject: [PATCH 125/125] Add load(s) and dump(s) methods --- pydbml/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pydbml/__init__.py b/pydbml/__init__.py index a585105..868de24 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,3 +1,15 @@ +import os + from . import _classes from .parser import PyDBML from .database import Database + +load = PyDBML.parse_file +loads = PyDBML.parse + +def dump(db: Database, fp: str | os.PathLike): + with open(fp, 'w') as f: + f.write(db.dbml) + +def dumps(db: Database) -> str: + return db.dbml \ No newline at end of file