Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import select
from sqlalchemy.exc import ProgrammingError
from starlette.concurrency import run_in_threadpool
from starlette.status import HTTP_201_CREATED, HTTP_409_CONFLICT, HTTP_204_NO_CONTENT

from api.pagination import CustomPage
Expand Down Expand Up @@ -84,7 +85,8 @@ async def upload_asset(
) -> dict:
from services.gcs_helper import gcs_upload

uri, blob_name = gcs_upload(file, bucket)
# GCS client calls are synchronous and can block for large uploads.
uri, blob_name = await run_in_threadpool(gcs_upload, file, bucket)
return {
"uri": uri,
"storage_path": blob_name,
Expand Down
47 changes: 34 additions & 13 deletions services/gcs_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import base64
import datetime
import json
import logging
import os
from functools import lru_cache
from hashlib import md5

from fastapi import UploadFile
Expand All @@ -27,8 +29,12 @@

GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME")
GCS_BUCKET_BASE_URL = f"https://storage.cloud.google.com/{GCS_BUCKET_NAME}/uploads"
GCS_LOOKUP_TIMEOUT_SECS = float(os.environ.get("GCS_LOOKUP_TIMEOUT_SECS", "15"))
GCS_UPLOAD_TIMEOUT_SECS = float(os.environ.get("GCS_UPLOAD_TIMEOUT_SECS", "120"))
logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def get_storage_client():
from google.cloud import storage
from google.oauth2 import service_account
Expand All @@ -51,14 +57,16 @@ def get_storage_client():
return client


def get_storage_bucket(client=None, bucket: str = None):
if client is None:
client = get_storage_client()
@lru_cache(maxsize=8)
def _get_cached_bucket(bucket_name: str):
    """Return a bucket handle for *bucket_name*, memoized per bucket name.

    The LRU cache (max 8 entries) avoids re-creating bucket objects on every
    call; the underlying client is itself cached by get_storage_client().
    """
    client = get_storage_client()
    return client.bucket(bucket_name)

if bucket is None:
bucket = GCS_BUCKET_NAME

return client.bucket(bucket)
def get_storage_bucket(client=None, bucket: str = None):
    """Resolve a GCS bucket handle.

    Falls back to the module-default bucket name when *bucket* is not given.
    An explicitly supplied *client* bypasses the cached-bucket path so callers
    can use their own credentials/session.
    """
    chosen = bucket or GCS_BUCKET_NAME
    if client is None:
        # No caller-supplied client: use the memoized bucket handle.
        return _get_cached_bucket(chosen)
    return client.bucket(chosen)


def make_blob_name_and_uri(file):
Expand All @@ -78,12 +86,16 @@ def gcs_upload(file: UploadFile, bucket=None):
file.file.seek(0)

blob_name, uri = make_blob_name_and_uri(file)
eblob = bucket.get_blob(blob_name)
eblob = bucket.get_blob(blob_name, timeout=GCS_LOOKUP_TIMEOUT_SECS)

if not eblob:
blob = bucket.blob(blob_name)
file.file.seek(0)
blob.upload_from_file(file.file, content_type=file.content_type)
blob.upload_from_file(
file.file,
content_type=file.content_type,
timeout=GCS_UPLOAD_TIMEOUT_SECS,
)
return uri, blob_name


Expand All @@ -93,11 +105,20 @@ def gcs_remove(uri: str, bucket):


def add_signed_url(asset: Asset, bucket):
    """Attach a short-lived (15 min) V4 signed GET URL to *asset*.

    Best-effort: if URL signing fails for any reason (e.g. credentials
    unusable for signing), the failure is logged and ``asset.signed_url``
    is set to None rather than propagating the error.
    Returns the (mutated) asset.
    """
    try:
        # Blob lookup and signing are both inside the try: either step may raise.
        signed = bucket.blob(asset.storage_path).generate_signed_url(
            method="GET",
            version="v4",
            expiration=datetime.timedelta(minutes=15),
        )
    except Exception:
        logger.warning(
            "Failed to generate signed URL for asset_id=%s storage_path=%s",
            getattr(asset, "id", None),
            getattr(asset, "storage_path", None),
            exc_info=True,
        )
        signed = None
    asset.signed_url = signed
    return asset


Expand Down
Loading