33# SPDX-License-Identifier: MIT
44
55import os
6- from typing import Optional
6+ from pathlib import Path
7+ from typing import Optional , Union
78from urllib .parse import urlparse
89
910from context_logger import get_logger
1617
class IFileDownloader(object):
    """Interface of a file downloader.

    Both methods only raise ``NotImplementedError`` here; ``FileDownloader``
    in this module overrides them with the real behaviour.
    """

    def download(self, file_url: str, file_name: Optional[str] = None,
                 sub_dir: Optional[Union[str, Path]] = None,
                 headers: Optional[dict[str, str]] = None, skip_if_exists: bool = True,
                 chunk_size: int = 1000 * 1000) -> Path:
        """Fetch a single file and return the path it is available at.

        Args:
            file_url: Location of the file to fetch.
            file_name: Optional name for the stored file.
            sub_dir: Optional sub directory, given as a string or a ``Path``.
            headers: Optional request headers.
            skip_if_exists: Presumably skips the transfer when the target
                already exists - confirm against the implementation.
            chunk_size: Size of the chunks used while transferring the
                content (default ``1000 * 1000``).

        Returns:
            The path of the file.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()

    def download_and_copy(self, file_url: str, sub_dirs: list[Union[str, Path]],
                          file_name: Optional[str] = None,
                          headers: Optional[dict[str, str]] = None, skip_if_exists: bool = True,
                          chunk_size: int = 1000 * 1000) -> list[Path]:
        """Fetch a file for several sub directories and return one path per result.

        Args:
            file_url: Location of the file to fetch.
            sub_dirs: Sub directories, each given as a string or a ``Path``.
            file_name: Optional name for the stored file.
            headers: Optional request headers.
            skip_if_exists: Presumably skips work for targets that already
                exist - confirm against the implementation.
            chunk_size: Size of the chunks used while transferring the
                content (default ``1000 * 1000``).

        Returns:
            The paths of the resulting files.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()
4029
4130
4231class FileDownloader (IFileDownloader ):
4332
def __init__(self, session_provider: ISessionProvider, download_location: Union[str, Path]) -> None:
    """Store the collaborators used by the download methods.

    Args:
        session_provider: Kept on the instance unchanged; presumably supplies
            the session used for the HTTP requests (its use is outside this
            block).
        download_location: Base directory of the downloaded files. A plain
            string is still accepted and converted, because the rest of the
            class joins this value with the ``/`` operator, which ``str``
            does not support.
    """
    self._session_provider = session_provider
    # Normalise once so every later join can rely on pathlib semantics.
    self._download_location = Path(download_location)
4736
48- def download (
49- self ,
50- file_url : str ,
51- file_name : Optional [str ] = None ,
52- sub_dir : Optional [str ] = None ,
53- headers : Optional [dict [str , str ]] = None ,
54- skip_if_exists : bool = True ,
55- chunk_size : int = 1000 * 1000 ,
56- ) -> str :
37+ def download (self , file_url : str , file_name : Optional [str ] = None , sub_dir : Optional [Union [str , Path ]] = None ,
38+ headers : Optional [dict [str , str ]] = None , skip_if_exists : bool = True , chunk_size : int = 1000 * 1000
39+ ) -> Path :
5740 if not urlparse (file_url ).scheme :
5841 return self ._check_local_file (file_url )
5942
@@ -75,15 +58,9 @@ def download(
7558
7659 return file_path
7760
78- def download_and_copy (
79- self ,
80- file_url : str ,
81- sub_dirs : list [str ],
82- file_name : Optional [str ] = None ,
83- headers : Optional [dict [str , str ]] = None ,
84- skip_if_exists : bool = True ,
85- chunk_size : int = 1000 * 1000 ,
86- ) -> list [str ]:
61+ def download_and_copy (self , file_url : str , sub_dirs : list [Union [str , Path ]], file_name : Optional [str ] = None ,
62+ headers : Optional [dict [str , str ]] = None , skip_if_exists : bool = True ,
63+ chunk_size : int = 1000 * 1000 ) -> list [Path ]:
8764 if not sub_dirs :
8865 raise ValueError ('At least one sub directory must be provided' )
8966
@@ -102,11 +79,11 @@ def download_and_copy(
10279
10380 return downloaded_files
10481
def _check_local_file(self, file_url: str) -> Path:
    """Return the absolute path of an already existing local file.

    Args:
        file_url: File system path of the file; it is made absolute with
            ``os.path.abspath`` before it is checked.

    Returns:
        The absolute location of the file as a ``Path``.

    Raises:
        ValueError: If the path does not point to an existing file.
    """
    file_path = os.path.abspath(file_url)

    # Guard clause: report and reject anything that is not an existing file.
    if not os.path.isfile(file_path):
        log.error('Local file does not exist', file=file_path)
        raise ValueError('Local file does not exist')

    log.info('Local file path provided, skipping download', file=file_path)
    return Path(file_path)
@@ -121,21 +98,21 @@ def _send_request(self, file_url: str, headers: dict[str, str]) -> Response:
12198
12299 return response
123100
def _get_target_path(self, file_url: str, file_name: Optional[str],
                     sub_dir: Optional[Union[str, Path]] = None) -> Path:
    """Build the path a downloaded file is stored at.

    Args:
        file_url: Source location; when ``file_name`` is not given, the text
            after its last ``/`` is used as the file name.
        file_name: Explicit name of the stored file; a falsy value means the
            name is derived from ``file_url``.
        sub_dir: Optional directory joined onto the download location and
            created with ``create_directory``. NOTE: with pathlib joining, an
            absolute ``sub_dir`` replaces the download location instead of
            being appended to it.

    Returns:
        The download directory (including ``sub_dir`` when given) joined with
        the file name.

    Raises:
        ValueError: If no file name was given and none can be derived, e.g.
            for a URL ending in ``/``.
    """
    if not file_name:
        file_name = file_url.split('/')[-1]

    # Joining an empty name would return the directory itself, not a file path.
    if not file_name:
        raise ValueError('File name cannot be determined from the URL')

    # Tolerate a download location that is still stored as a plain string.
    download_dir = Path(self._download_location)

    if sub_dir:
        download_dir = download_dir / sub_dir
        create_directory(download_dir)

    return download_dir / file_name
135113
136- def _download_file (self , response : Response , file_path : str , chunk_size : int ) -> None :
137- if not os .path .exists (self ._download_location ):
138- os .makedirs (self ._download_location )
114+ def _download_file (self , response : Response , file_path : Path , chunk_size : int ) -> None :
115+ create_directory (self ._download_location )
139116
140117 with open (file_path , 'wb' ) as asset_file :
141118 for chunk in response .iter_content (chunk_size ):
0 commit comments