diff --git a/CHANGELOG.md b/CHANGELOG.md index 4168c37..8871a29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,19 @@ +## Unreleased + +### Features + +- Make LLM support optional and installable via `litecli[ai]`. + +### Bug Fixes + +- Avoid completion refresh crashes when no database is connected. + +### Internal + +- Clean up ty type-checking for optional sqlean/llm imports. +- Add an llm module alias for test patching. +- Avoid ty conflicts for optional sqlite/llm imports. + ## 1.18.0 ### Internal diff --git a/litecli/__init__.py b/litecli/__init__.py index fcca276..3c05333 100644 --- a/litecli/__init__.py +++ b/litecli/__init__.py @@ -1,4 +1,3 @@ -# type: ignore from __future__ import annotations import importlib.metadata diff --git a/litecli/main.py b/litecli/main.py index 2f0e594..8151ffa 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -11,11 +11,6 @@ from collections import namedtuple from datetime import datetime from io import open - -try: - from sqlean import OperationalError, sqlite_version -except ImportError: - from sqlite3 import OperationalError, sqlite_version from time import time from typing import Any, Generator, Iterable, cast @@ -51,6 +46,21 @@ from .sqlcompleter import SQLCompleter from .sqlexecute import SQLExecute + +def _load_sqlite3() -> Any: + try: + import sqlean + except ImportError: + import sqlite3 + + return sqlite3 + return sqlean + + +_sqlite3 = _load_sqlite3() +OperationalError = _sqlite3.OperationalError +sqlite_version = _sqlite3.sqlite_version + # Query tuples are used for maintaining history Query = namedtuple("Query", ["query", "successful", "mutating"]) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index e2d7efa..11e71d5 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import importlib import io import logging import os @@ -13,20 +14,53 @@ from typing import Any import click -import llm -from llm.cli import cli from . import export from .main import Verbosity, parse_special_command from .types import DBCursor + +def _load_llm_module() -> Any | None: + try: + return importlib.import_module("llm") + except ImportError: + return None + + +def _load_llm_cli_module() -> Any | None: + try: + return importlib.import_module("llm.cli") + except ImportError: + return None + + +llm_module = _load_llm_module() +llm_cli_module = _load_llm_cli_module() + +# Alias for tests and patching. +llm = llm_module + +LLM_IMPORTED = llm_module is not None + +cli: click.Command | None +if llm_cli_module is not None: + llm_cli = getattr(llm_cli_module, "cli", None) + cli = llm_cli if isinstance(llm_cli, click.Command) else None +else: + cli = None + +LLM_CLI_IMPORTED = cli is not None + log = logging.getLogger(__name__) LLM_TEMPLATE_NAME = "litecli-llm-template" -LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys()) +LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys()) if isinstance(cli, click.Group) else [] # Mapping of model_id to None used for completion tree leaves. -# the file name is llm.py and module name is llm, hence ty is complaining that get_models is missing. -MODELS: dict[str, None] = {x.model_id: None for x in llm.get_models()} # type: ignore[attr-defined] +if llm_module is not None: + get_models = getattr(llm_module, "get_models", None) + MODELS: dict[str, None] = {x.model_id: None for x in get_models()} if callable(get_models) else {} +else: + MODELS = {} def run_external_cmd( @@ -110,7 +144,7 @@ def build_command_tree(cmd: click.Command) -> dict[str, Any] | None: # Generate the tree -COMMAND_TREE: dict[str, Any] | None = build_command_tree(cli) +COMMAND_TREE: dict[str, Any] | None = build_command_tree(cli) if cli is not None else {} def get_completions(tokens: list[str], tree: dict[str, Any] | None = COMMAND_TREE) -> list[str]: @@ -123,6 +157,8 @@ def get_completions(tokens: list[str], tree: dict[str, Any] | None = COMMAND_TRE Returns: list[str]: List of possible completions. """ + if not LLM_CLI_IMPORTED: + return [] for token in tokens: if token.startswith("-"): # Skip options (flags) @@ -171,6 +207,18 @@ def __init__(self, results: Any | None = None) -> None: # https://llm.datasette.io/en/stable/plugins/directory.html """ +NEED_DEPENDENCIES = """ +To enable LLM features you need to install litecli with AI support: + + pip install 'litecli[ai]' + +or install LLM libraries separately + + pip install llm + +This is required to use the \\llm command. +""" + _SQL_CODE_FENCE = r"```sql\n(.*?)\n```" PROMPT = """ You are a helpful assistant who is a SQLite expert. You are embedded in a SQLite @@ -230,6 +278,10 @@ def handle_llm(text: str, cur: DBCursor) -> tuple[str, str | None, float]: is_verbose = mode is Verbosity.VERBOSE is_succinct = mode is Verbosity.SUCCINCT + if not LLM_IMPORTED: + output = [(None, None, None, NEED_DEPENDENCIES)] + raise FinishIteration(output) + if not arg.strip(): # No question provided. Print usage and bail. output = [(None, None, None, USAGE)] raise FinishIteration(output) diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py index c345c76..08f7bd5 100644 --- a/litecli/packages/special/main.py +++ b/litecli/packages/special/main.py @@ -8,6 +8,13 @@ log = logging.getLogger(__name__) +try: + import llm # noqa: F401 + + LLM_IMPORTED = True +except ImportError: + LLM_IMPORTED = False + NO_QUERY = 0 PARSED_QUERY = 1 RAW_QUERY = 2 @@ -176,13 +183,19 @@ def quit(*_args: Any) -> None: arg_type=NO_QUERY, case_sensitive=True, ) -@special_command( - "\\llm", - "\\ai", - "Use LLM to construct a SQL query.", - arg_type=NO_QUERY, - case_sensitive=False, - aliases=(".ai", ".llm"), -) def stub() -> None: raise NotImplementedError + + +if LLM_IMPORTED: + + @special_command( + "\\llm", + "\\ai", + "Use LLM to construct a SQL query.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=(".ai", ".llm"), + ) + def llm_stub() -> None: + raise NotImplementedError diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py index c81123a..5e52b73 100644 --- a/litecli/sqlexecute.py +++ b/litecli/sqlexecute.py @@ -1,25 +1,26 @@ from __future__ import annotations import logging +import os.path from contextlib import closing -from typing import Any, Generator, Iterable +from typing import Any, Generator, Iterable, cast +from urllib.parse import urlparse + +import sqlparse try: - import sqlean as sqlite3 - from sqlean import OperationalError + import sqlean as _sqlite3 - sqlite3.extensions.enable_all() + _sqlite3.extensions.enable_all() except ImportError: - import sqlite3 - from sqlite3 import OperationalError -import os.path -from urllib.parse import urlparse - -import sqlparse + import sqlite3 as _sqlite3 from litecli.packages import special from litecli.packages.special.utils import check_if_sqlitedotcommand +sqlite3 = cast(Any, _sqlite3) +OperationalError = sqlite3.OperationalError + _logger = logging.getLogger(__name__) # FIELD_TYPES = decoders.copy() @@ -179,7 +180,8 @@ def get_result(self, cursor: Any) -> tuple[str | None, list | None, list | None, def tables(self) -> Generator[tuple[str], None, None]: """Yields table names""" - assert self.conn is not None + if not self.conn: + return with closing(self.conn.cursor()) as cur: _logger.debug("Tables Query. sql: %r", self.tables_query) cur.execute(self.tables_query) @@ -188,7 +190,8 @@ def tables(self) -> Generator[tuple[str], None, None]: def table_columns(self) -> Generator[tuple[str, str], None, None]: """Yields column names""" - assert self.conn is not None + if not self.conn: + return with closing(self.conn.cursor()) as cur: _logger.debug("Columns Query. sql: %r", self.table_columns_query) cur.execute(self.table_columns_query) @@ -206,7 +209,8 @@ def databases(self) -> Generator[str, None, None]: def functions(self) -> Iterable[tuple]: """Yields tuples of (schema_name, function_name)""" - assert self.conn is not None + if not self.conn: + return with closing(self.conn.cursor()) as cur: _logger.debug("Functions Query. sql: %r", self.functions_query) cur.execute(self.functions_query % self.dbname) diff --git a/pyproject.toml b/pyproject.toml index 49650c9..939b997 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,10 +13,7 @@ dependencies = [ "configobj>=5.0.5", "prompt-toolkit>=3.0.3,<4.0.0", "pygments>=1.6", - "sqlparse>=0.4.4", - "setuptools", # Required by llm commands to install models - "pip", - "llm>=0.25.0" + "sqlparse>=0.4.4" ] [build-system] @@ -33,7 +30,11 @@ build-backend = "setuptools.build_meta" litecli = "litecli.main:cli" [project.optional-dependencies] -ai = ["llm"] +ai = [ + "llm>=0.25.0", + "setuptools", # Required by llm commands to install models + "pip", +] sqlean = ["sqlean-py>=3.47.0", "sqlean-stubs>=0.0.3"] @@ -45,7 +46,9 @@ dev = [ "pytest-cov>=4.1.0", "tox>=4.8.0", "pdbpp>=0.10.3", - "llm>=0.19.0", + "llm>=0.25.0", + "setuptools", + "pip", "ty>=0.0.4" ] diff --git a/tests/test_llm_special.py b/tests/test_llm_special.py index d7de461..b62a8b9 100644 --- a/tests/test_llm_special.py +++ b/tests/test_llm_special.py @@ -2,9 +2,16 @@ import pytest +import litecli.packages.special.llm as llm_module from litecli.packages.special.llm import USAGE, FinishIteration, handle_llm +@pytest.fixture(autouse=True) +def enable_llm(monkeypatch): + monkeypatch.setattr(llm_module, "LLM_IMPORTED", True) + monkeypatch.setattr(llm_module, "LLM_CLI_COMMANDS", ["models"]) + + @patch("litecli.packages.special.llm.llm") def test_llm_command_without_args(mock_llm, executor): r""" diff --git a/tests/test_main.py b/tests/test_main.py index 4c880e7..0a47c9b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -243,7 +243,7 @@ def stub_terminal_size(): shutil.get_terminal_size = stub_terminal_size # type: ignore[assignment] lc = LiteCli() assert isinstance(lc.get_reserved_space(), int) - shutil.get_terminal_size = old_func # type: ignore[assignment] + shutil.get_terminal_size = old_func @dbtest diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py index 5d20851..3fdaef1 100644 --- a/tests/test_sqlexecute.py +++ b/tests/test_sqlexecute.py @@ -1,15 +1,26 @@ # coding=UTF-8 import os +from typing import Any import pytest from .utils import assert_result_equal, dbtest, is_expanded_output, run, set_expanded_output -try: - from sqlean import OperationalError, ProgrammingError -except ImportError: - from sqlite3 import OperationalError, ProgrammingError + +def _load_sqlite3() -> Any: + try: + import sqlean + except ImportError: + import sqlite3 + + return sqlite3 + return sqlean + + +_sqlite3 = _load_sqlite3() +OperationalError = _sqlite3.OperationalError +ProgrammingError = _sqlite3.ProgrammingError @dbtest