Skip to content

Commit df53b0e

Browse files
committed
Handle native Postgres numeric values
1 parent b44f901 commit df53b0e

2 files changed

Lines changed: 118 additions & 6 deletions

File tree

src/webserver/database/csv_import.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,18 @@ pub(super) async fn run_csv_import(
163163
)
164164
})?;
165165
let buffered = tokio::io::BufReader::new(file);
166-
run_csv_import_insert(db, csv_import, buffered).await
167-
.with_context(|| {
166+
run_csv_import_insert(db, csv_import, buffered).await.with_context(|| {
168167
let table_name = &csv_import.table_name;
169-
format!(
168+
let import_error = format!(
170169
"{} was uploaded correctly, but its records could not be imported into the table {}",
171170
file_path.display(),
172171
table_name
173-
)
172+
);
173+
if db.kind() == DbKind::Postgres {
174+
format!("The postgres COPY FROM STDIN command failed. {import_error}")
175+
} else {
176+
import_error
177+
}
174178
})
175179
}
176180

src/webserver/database/driver.rs

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@ impl From<tokio_rusqlite::Error> for DbError {
119119

120120
impl From<tokio_postgres::Error> for DbError {
121121
fn from(error: tokio_postgres::Error) -> Self {
122+
if let Some(db_error) = error.as_db_error() {
123+
return DbError::Database {
124+
message: db_error.message().to_string(),
125+
offset: db_error.position().and_then(postgres_error_position),
126+
};
127+
}
122128
db_error(error)
123129
}
124130
}
@@ -142,6 +148,13 @@ fn db_error(error: impl std::error::Error) -> DbError {
142148
}
143149
}
144150

151+
fn postgres_error_position(position: &tokio_postgres::error::ErrorPosition) -> Option<usize> {
152+
match position {
153+
tokio_postgres::error::ErrorPosition::Original(position) => usize::try_from(*position).ok(),
154+
tokio_postgres::error::ErrorPosition::Internal { .. } => None,
155+
}
156+
}
157+
145158
#[derive(Clone)]
146159
pub struct DbPool {
147160
inner: Arc<DbPoolInner>,
@@ -887,8 +900,8 @@ fn postgres_value(
887900
.try_get::<_, f64>(idx)
888901
.map_or(DbValue::Null, DbValue::Real),
889902
Type::NUMERIC => row
890-
.try_get::<_, String>(idx)
891-
.map_or(DbValue::Null, |value| numeric_text_value(&value)),
903+
.try_get::<_, PgNumericValue>(idx)
904+
.map_or(DbValue::Null, |value| value.0),
892905
Type::BYTEA => row
893906
.try_get::<_, Vec<u8>>(idx)
894907
.map_or(DbValue::Null, DbValue::Bytes),
@@ -913,6 +926,101 @@ fn postgres_value(
913926
}
914927
}
915928

929+
struct PgNumericValue(DbValue);
930+
931+
impl<'a> tokio_postgres::types::FromSql<'a> for PgNumericValue {
932+
fn from_sql(
933+
ty: &tokio_postgres::types::Type,
934+
raw: &'a [u8],
935+
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
936+
if *ty != tokio_postgres::types::Type::NUMERIC {
937+
return Err("PgNumericValue only supports NUMERIC".into());
938+
}
939+
Ok(Self(numeric_text_value(&postgres_numeric_to_string(raw)?)))
940+
}
941+
942+
fn accepts(ty: &tokio_postgres::types::Type) -> bool {
943+
*ty == tokio_postgres::types::Type::NUMERIC
944+
}
945+
}
946+
947+
fn postgres_numeric_to_string(
948+
raw: &[u8],
949+
) -> Result<String, Box<dyn std::error::Error + Sync + Send>> {
950+
use std::fmt::Write as _;
951+
952+
const SIGN_POSITIVE: u16 = 0x0000;
953+
const SIGN_NEGATIVE: u16 = 0x4000;
954+
const SIGN_NAN: u16 = 0xC000;
955+
const SIGN_POSITIVE_INFINITY: u16 = 0xD000;
956+
const SIGN_NEGATIVE_INFINITY: u16 = 0xF000;
957+
958+
if raw.len() < 8 {
959+
return Err("invalid PostgreSQL NUMERIC payload".into());
960+
}
961+
let digit_count = usize::from(read_u16_be(raw, 0));
962+
let weight = read_i16_be(raw, 2);
963+
let sign = read_u16_be(raw, 4);
964+
let decimal_scale = usize::from(read_u16_be(raw, 6));
965+
if raw.len() < 8 + digit_count * 2 {
966+
return Err("truncated PostgreSQL NUMERIC payload".into());
967+
}
968+
match sign {
969+
SIGN_NAN => return Ok("NaN".to_string()),
970+
SIGN_POSITIVE_INFINITY => return Ok("Infinity".to_string()),
971+
SIGN_NEGATIVE_INFINITY => return Ok("-Infinity".to_string()),
972+
SIGN_POSITIVE | SIGN_NEGATIVE => {}
973+
_ => return Err("invalid PostgreSQL NUMERIC sign".into()),
974+
}
975+
976+
let digits = (0..digit_count)
977+
.map(|idx| read_u16_be(raw, 8 + idx * 2))
978+
.collect::<Vec<_>>();
979+
let integer_group_count = i32::from(weight) + 1;
980+
981+
let mut integer_part = String::new();
982+
if integer_group_count <= 0 {
983+
integer_part.push('0');
984+
} else {
985+
for group_idx in 0..usize::try_from(integer_group_count)? {
986+
let digit = digits.get(group_idx).copied().unwrap_or_default();
987+
if group_idx == 0 {
988+
integer_part.push_str(&digit.to_string());
989+
} else {
990+
write!(integer_part, "{digit:04}")?;
991+
}
992+
}
993+
}
994+
995+
let mut fraction_part = String::new();
996+
for _ in 0..(-integer_group_count).max(0) {
997+
fraction_part.push_str("0000");
998+
}
999+
let first_fraction_group = usize::try_from(integer_group_count.max(0))?;
1000+
for digit in digits.iter().skip(first_fraction_group) {
1001+
write!(fraction_part, "{digit:04}")?;
1002+
}
1003+
fraction_part.truncate(decimal_scale);
1004+
while fraction_part.len() < decimal_scale {
1005+
fraction_part.push('0');
1006+
}
1007+
1008+
let prefix = if sign == SIGN_NEGATIVE { "-" } else { "" };
1009+
if decimal_scale == 0 {
1010+
Ok(format!("{prefix}{integer_part}"))
1011+
} else {
1012+
Ok(format!("{prefix}{integer_part}.{fraction_part}"))
1013+
}
1014+
}
1015+
1016+
fn read_i16_be(bytes: &[u8], offset: usize) -> i16 {
1017+
i16::from_be_bytes([bytes[offset], bytes[offset + 1]])
1018+
}
1019+
1020+
fn read_u16_be(bytes: &[u8], offset: usize) -> u16 {
1021+
u16::from_be_bytes([bytes[offset], bytes[offset + 1]])
1022+
}
1023+
9161024
async fn execute_mysql(
9171025
conn: &mut mysql_async::Conn,
9181026
sql: &str,

0 commit comments

Comments
 (0)