在计算机科学中,浮点数的精度限制常常成为数值计算的瓶颈,特别是在金融、科学计算等需要高精度的领域。IEEE 754 标准的浮点数无法精确表示所有十进制小数,例如 0.1 + 0.2 并不等于 0.3。在 Exercism 的 “decimal” 练习中,我们将实现一个支持任意精度的十进制数类型,这不仅能帮助我们深入理解数值表示和计算原理,还能学习 Rust 中的运算符重载、错误处理和高精度计算技术。

什么是高精度十进制数?

高精度十进制数(Arbitrary-Precision Decimal)是一种可以表示任意位数的十进制数的数据类型,它克服了标准浮点数精度限制的问题。这种数据类型通常用于:

  1. 金融计算:需要精确到分的货币计算
  2. 科学计算:需要高精度的数学运算
  3. 数据库系统:DECIMAL/NUMERIC 类型的实现
  4. 工程计算:需要避免浮点误差的场合

让我们先看看练习提供的结构和函数签名:

/// Type implementing arbitrary-precision decimal arithmetic
pub struct Decimal {
    // implement your type here
}

impl Decimal {
    pub fn try_from(input: &str) -> Option<Decimal> {
        unimplemented!("Create a new decimal with a value of {}", input)
    }
}

我们需要实现这个十进制数结构体,它应该支持:

  1. 从字符串创建十进制数
  2. 基本算术运算(加、减、乘)
  3. 比较运算
  4. 高精度计算

算法设计

1. 数据结构设计

use std::cmp::Ordering;

#[derive(Debug, Clone)]
pub struct Decimal {
    // 符号位:true 表示正数,false 表示负数
    positive: bool,
    // 整数部分的数字(从低位到高位存储)
    integer: Vec<u8>,
    // 小数部分的数字(从高位到低位存储)
    fractional: Vec<u8>,
}

impl Decimal {
    pub fn try_from(input: &str) -> Option<Decimal> {
        if input.is_empty() {
            return None;
        }
        
        let mut chars = input.chars().peekable();
        
        // 处理符号
        let positive = match chars.peek() {
            Some(&'+') => {
                chars.next(); // 跳过 '+' 符号
                true
            }
            Some(&'-') => {
                chars.next(); // 跳过 '-' 符号
                false
            }
            _ => true,
        };
        
        // 检查是否还有字符
        if chars.peek().is_none() {
            return None;
        }
        
        // 查找小数点位置
        let mut integer_part = String::new();
        let mut fractional_part = String::new();
        let mut found_dot = false;
        
        for ch in chars {
            if ch == '.' {
                if found_dot {
                    return None; // 多个小数点
                }
                found_dot = true;
            } else if ch.is_ascii_digit() {
                if found_dot {
                    fractional_part.push(ch);
                } else {
                    integer_part.push(ch);
                }
            } else {
                return None; // 非法字符
            }
        }
        
        // 处理空的整数部分
        if integer_part.is_empty() {
            integer_part.push('0');
        }
        
        // 转换整数部分(从低位到高位存储)
        let mut integer = Vec::new();
        for ch in integer_part.chars().rev() {
            integer.push(ch.to_digit(10).unwrap() as u8);
        }
        
        // 转换小数部分(从高位到低位存储)
        let mut fractional = Vec::new();
        for ch in fractional_part.chars() {
            fractional.push(ch.to_digit(10).unwrap() as u8);
        }
        
        // 规范化(移除前导和后缀零)
        let decimal = Decimal {
            positive,
            integer,
            fractional,
        };
        
        Some(decimal.normalize())
    }
    
    // 规范化数字(移除前导和后缀零)
    fn normalize(mut self) -> Self {
        // 移除整数部分的前导零
        while self.integer.len() > 1 && self.integer.last() == Some(&0) {
            self.integer.pop();
        }
        
        // 移除小数部分的后缀零
        while !self.fractional.is_empty() && self.fractional.last() == Some(&0) {
            self.fractional.pop();
        }
        
        // 处理 -0 的情况
        if !self.positive && self.integer.len() == 1 && self.integer[0] == 0 && self.fractional.is_empty() {
            self.positive = true;
        }
        
        self
    }
}

2. 比较运算实现

use std::cmp::Ordering;

impl PartialEq for Decimal {
    fn eq(&self, other: &Self) -> bool {
        // 如果都是零,不管符号如何都相等
        if self.is_zero() && other.is_zero() {
            return true;
        }
        
        self.positive == other.positive 
            && self.integer == other.integer 
            && self.fractional == other.fractional
    }
}

impl Eq for Decimal {}

impl PartialOrd for Decimal {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Decimal {
    fn cmp(&self, other: &Self) -> Ordering {
        // 正数大于负数
        match (self.positive, other.positive) {
            (true, false) => {
                if self.is_zero() && other.is_zero() {
                    Ordering::Equal
                } else {
                    Ordering::Greater
                }
            }
            (false, true) => {
                if self.is_zero() && other.is_zero() {
                    Ordering::Equal
                } else {
                    Ordering::Less
                }
            }
            (true, true) => self.cmp_magnitude(other),
            (false, false) => other.cmp_magnitude(self),
        }
    }
}

impl Decimal {
    // 比较绝对值大小
    fn cmp_magnitude(&self, other: &Self) -> Ordering {
        // 先比较整数部分长度
        match self.integer.len().cmp(&other.integer.len()) {
            Ordering::Equal => {
                // 整数部分长度相同,逐位比较
                for i in (0..self.integer.len()).rev() {
                    match self.integer[i].cmp(&other.integer[i]) {
                        Ordering::Equal => continue,
                        other => return other,
                    }
                }
                
                // 整数部分相同,比较小数部分
                let max_fractional_len = self.fractional.len().max(other.fractional.len());
                for i in 0..max_fractional_len {
                    let self_digit = self.fractional.get(i).copied().unwrap_or(0);
                    let other_digit = other.fractional.get(i).copied().unwrap_or(0);
                    match self_digit.cmp(&other_digit) {
                        Ordering::Equal => continue,
                        other => return other,
                    }
                }
                
                Ordering::Equal
            }
            other => other,
        }
    }
    
    fn is_zero(&self) -> bool {
        self.integer.len() == 1 && self.integer[0] == 0 && self.fractional.is_empty()
    }
}

3. 算术运算实现

use std::ops::{Add, Sub, Mul};

impl Add for Decimal {
    type Output = Decimal;
    
    fn add(self, other: Decimal) -> Decimal {
        match (self.positive, other.positive) {
            (true, true) => add_positive(self, other),
            (false, false) => {
                let mut result = add_positive(self.negate(), other.negate());
                result.positive = false;
                result
            }
            (true, false) => sub_positive(self, other.negate()),
            (false, true) => sub_positive(other, self.negate()),
        }.normalize()
    }
}

impl Sub for Decimal {
    type Output = Decimal;
    
    fn sub(self, other: Decimal) -> Decimal {
        self.add(other.negate())
    }
}

impl Mul for Decimal {
    type Output = Decimal;
    
    fn mul(self, other: Decimal) -> Decimal {
        let positive = self.positive == other.positive;
        
        let result = mul_positive(self, other);
        if !positive && !result.is_zero() {
            Decimal { positive: false, ..result }
        } else {
            Decimal { positive: true, ..result }
        }.normalize()
    }
}

// 加法辅助函数(两个正数相加)
fn add_positive(a: Decimal, b: Decimal) -> Decimal {
    // 对齐小数点
    let (a_int, a_frac, b_int, b_frac) = align_decimals(&a, &b);
    
    // 计算小数部分之和
    let (mut fractional, carry) = add_digit_vectors(&a_frac, &b_frac, false);
    
    // 计算整数部分之和
    let (integer, _) = add_digit_vectors(&a_int, &b_int, carry);
    
    Decimal {
        positive: true,
        integer,
        fractional,
    }
}

// 减法辅助函数(两个正数相减,假设 a >= b)
fn sub_positive(a: Decimal, b: Decimal) -> Decimal {
    // 对齐小数点
    let (a_int, a_frac, b_int, b_frac) = align_decimals(&a, &b);
    
    // 计算小数部分之差
    let (mut fractional, borrow) = sub_digit_vectors(&a_frac, &b_frac, false);
    
    // 计算整数部分之差
    let (integer, _) = sub_digit_vectors(&a_int, &b_int, borrow);
    
    Decimal {
        positive: true,
        integer,
        fractional,
    }
}

// 乘法辅助函数
fn mul_positive(a: Decimal, b: Decimal) -> Decimal {
    // 将两个数都转换为整数形式进行计算
    let a_digits: Vec<u8> = a.integer.iter().chain(a.fractional.iter()).copied().rev().collect();
    let b_digits: Vec<u8> = b.integer.iter().chain(b.fractional.iter()).copied().rev().collect();
    
    // 计算结果的位数
    let result_len = a_digits.len() + b_digits.len();
    let mut result = vec![0u16; result_len];
    
    // 执行乘法
    for (i, &a_digit) in a_digits.iter().enumerate() {
        for (j, &b_digit) in b_digits.iter().enumerate() {
            let prod = a_digit as u16 * b_digit as u16;
            let pos = i + j;
            result[pos] += prod;
        }
    }
    
    // 处理进位
    for i in 0..result_len - 1 {
        let carry = result[i] / 10;
        result[i] %= 10;
        result[i + 1] += carry;
    }
    
    // 计算小数点位置
    let fractional_digits = a.fractional.len() + b.fractional.len();
    
    // 分离整数和小数部分
    let mut integer = Vec::new();
    let mut fractional = Vec::new();
    
    for (i, &digit) in result.iter().enumerate() {
        let digit = digit as u8;
        if i < fractional_digits {
            fractional.push(digit);
        } else {
            integer.push(digit);
        }
    }
    
    // 反转顺序
    integer.reverse();
    fractional.reverse();
    
    // 移除前导零
    while integer.len() > 1 && integer.last() == Some(&0) {
        integer.pop();
    }
    
    // 移除后缀零
    while !fractional.is_empty() && fractional.last() == Some(&0) {
        fractional.pop();
    }
    
    Decimal {
        positive: true,
        integer,
        fractional,
    }
}

// 对齐两个数的小数部分
fn align_decimals(a: &Decimal, b: &Decimal) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
    let max_fractional_len = a.fractional.len().max(b.fractional.len());
    
    let mut a_frac = a.fractional.clone();
    a_frac.resize(max_fractional_len, 0);
    
    let mut b_frac = b.fractional.clone();
    b_frac.resize(max_fractional_len, 0);
    
    (a.integer.clone(), a_frac, b.integer.clone(), b_frac)
}

// 两个数字向量相加(低位在前)
fn add_digit_vectors(a: &[u8], b: &[u8], mut carry: bool) -> (Vec<u8>, bool) {
    let max_len = a.len().max(b.len());
    let mut result = Vec::new();
    let mut new_carry = false;
    
    for i in 0..max_len {
        let a_digit = a.get(i).copied().unwrap_or(0);
        let b_digit = b.get(i).copied().unwrap_or(0);
        
        let mut sum = a_digit + b_digit;
        if carry {
            sum += 1;
        }
        
        if sum >= 10 {
            result.push(sum - 10);
            carry = true;
        } else {
            result.push(sum);
            carry = false;
        }
    }
    
    if carry {
        result.push(1);
        new_carry = true;
    }
    
    (result, new_carry)
}

// 两个数字向量相减(低位在前,假设 a >= b)
fn sub_digit_vectors(a: &[u8], b: &[u8], mut borrow: bool) -> (Vec<u8>, bool) {
    let mut result = Vec::new();
    let mut new_borrow = false;
    
    for i in 0..a.len() {
        let a_digit = a[i];
        let b_digit = b.get(i).copied().unwrap_or(0);
        
        let mut diff = if borrow {
            10 + a_digit - b_digit - 1
        } else {
            if a_digit >= b_digit {
                a_digit - b_digit
            } else {
                10 + a_digit - b_digit
            }
        };
        
        result.push(diff);
        borrow = if borrow {
            a_digit <= b_digit
        } else {
            a_digit < b_digit
        };
    }
    
    // 移除前导零
    while result.len() > 1 && result.last() == Some(&0) {
        result.pop();
    }
    
    (result, new_borrow)
}

impl Decimal {
    fn negate(mut self) -> Self {
        if !self.is_zero() {
            self.positive = !self.positive;
        }
        self
    }
}

4. 显示格式化实现

use std::fmt;

impl fmt::Display for Decimal {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if !self.positive {
            write!(f, "-")?;
        }
        
        // 显示整数部分
        for digit in self.integer.iter().rev() {
            write!(f, "{}", digit)?;
        }
        
        // 显示小数部分
        if !self.fractional.is_empty() {
            write!(f, ".")?;
            for digit in &self.fractional {
                write!(f, "{}", digit)?;
            }
        }
        
        Ok(())
    }
}

测试用例分析

通过查看测试用例,我们可以更好地理解需求:

#[test]
fn test_eq() {
    assert!(decimal("0.0") == decimal("0.0"));
    assert!(decimal("1.0") == decimal("1.0"));
    for big in BIGS.iter() {
        assert!(decimal(big) == decimal(big));
    }
}

相等性比较应该正确处理各种情况,包括大数。

#[test]
fn test_add() {
    assert_eq!(decimal("0.1") + decimal("0.2"), decimal("0.3"));
    assert_eq!(decimal(BIGS[0]) + decimal(BIGS[1]), decimal(BIGS[2]));
    assert_eq!(decimal(BIGS[1]) + decimal(BIGS[0]), decimal(BIGS[2]));
}

加法应该精确计算,包括大数和进位情况。

#[test]
fn test_gt_varying_positive_precisions() {
    assert!(decimal("1.1") > decimal("1.01"));
    assert!(decimal("1.01") > decimal("1.0"));
    assert!(decimal("1.0") > decimal("0.1"));
    assert!(decimal("0.1") > decimal("0.01"));
}

比较运算应该正确处理不同精度的数字。

#[test]
fn test_negatives() {
    assert!(Decimal::try_from("-1").is_some());
    assert_eq!(decimal("0") - decimal("1"), decimal("-1"));
    assert_eq!(decimal("5.5") + decimal("-6.5"), decimal("-1"));
}

负数应该被正确处理。

完整实现

考虑所有边界情况的完整实现:

use std::cmp::Ordering;
use std::ops::{Add, Sub, Mul};
use std::fmt;

#[derive(Debug, Clone)]
pub struct Decimal {
    positive: bool,
    integer: Vec<u8>,
    fractional: Vec<u8>,
}

impl Decimal {
    pub fn try_from(input: &str) -> Option<Decimal> {
        if input.is_empty() {
            return None;
        }
        
        let mut chars = input.chars().peekable();
        
        // 处理符号
        let positive = match chars.peek() {
            Some(&'+') => {
                chars.next();
                true
            }
            Some(&'-') => {
                chars.next();
                false
            }
            _ => true,
        };
        
        if chars.peek().is_none() {
            return None;
        }
        
        // 解析数字
        let mut integer_part = String::new();
        let mut fractional_part = String::new();
        let mut found_dot = false;
        
        for ch in chars {
            if ch == '.' {
                if found_dot {
                    return None;
                }
                found_dot = true;
            } else if ch.is_ascii_digit() {
                if found_dot {
                    fractional_part.push(ch);
                } else {
                    integer_part.push(ch);
                }
            } else {
                return None;
            }
        }
        
        if integer_part.is_empty() {
            integer_part.push('0');
        }
        
        let mut integer = Vec::new();
        for ch in integer_part.chars().rev() {
            integer.push(ch.to_digit(10).unwrap() as u8);
        }
        
        let mut fractional = Vec::new();
        for ch in fractional_part.chars() {
            fractional.push(ch.to_digit(10).unwrap() as u8);
        }
        
        let decimal = Decimal {
            positive,
            integer,
            fractional,
        };
        
        Some(decimal.normalize())
    }
    
    fn normalize(mut self) -> Self {
        // 移除整数部分的前导零
        while self.integer.len() > 1 && self.integer.last() == Some(&0) {
            self.integer.pop();
        }
        
        // 移除小数部分的后缀零
        while !self.fractional.is_empty() && self.fractional.last() == Some(&0) {
            self.fractional.pop();
        }
        
        // 处理 -0 的情况
        if !self.positive && self.integer.len() == 1 && self.integer[0] == 0 && self.fractional.is_empty() {
            self.positive = true;
        }
        
        self
    }
    
    fn is_zero(&self) -> bool {
        self.integer.len() == 1 && self.integer[0] == 0 && self.fractional.is_empty()
    }
    
    fn negate(mut self) -> Self {
        if !self.is_zero() {
            self.positive = !self.positive;
        }
        self
    }
}

impl PartialEq for Decimal {
    fn eq(&self, other: &Self) -> bool {
        if self.is_zero() && other.is_zero() {
            return true;
        }
        
        self.positive == other.positive 
            && self.integer == other.integer 
            && self.fractional == other.fractional
    }
}

impl Eq for Decimal {}

impl PartialOrd for Decimal {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Decimal {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.positive, other.positive) {
            (true, false) => {
                if self.is_zero() && other.is_zero() {
                    Ordering::Equal
                } else {
                    Ordering::Greater
                }
            }
            (false, true) => {
                if self.is_zero() && other.is_zero() {
                    Ordering::Equal
                } else {
                    Ordering::Less
                }
            }
            (true, true) => self.cmp_magnitude(other),
            (false, false) => other.cmp_magnitude(self),
        }
    }
}

impl Decimal {
    fn cmp_magnitude(&self, other: &Self) -> Ordering {
        match self.integer.len().cmp(&other.integer.len()) {
            Ordering::Equal => {
                for i in (0..self.integer.len()).rev() {
                    match self.integer[i].cmp(&other.integer[i]) {
                        Ordering::Equal => continue,
                        other => return other,
                    }
                }
                
                let max_fractional_len = self.fractional.len().max(other.fractional.len());
                for i in 0..max_fractional_len {
                    let self_digit = self.fractional.get(i).copied().unwrap_or(0);
                    let other_digit = other.fractional.get(i).copied().unwrap_or(0);
                    match self_digit.cmp(&other_digit) {
                        Ordering::Equal => continue,
                        other => return other,
                    }
                }
                
                Ordering::Equal
            }
            other => other,
        }
    }
}

impl Add for Decimal {
    type Output = Decimal;
    
    fn add(self, other: Decimal) -> Decimal {
        match (self.positive, other.positive) {
            (true, true) => add_positive(self, other),
            (false, false) => {
                let mut result = add_positive(self.negate(), other.negate());
                result.positive = false;
                result
            }
            (true, false) => sub_positive(self, other.negate()),
            (false, true) => sub_positive(other, self.negate()),
        }.normalize()
    }
}

impl Sub for Decimal {
    type Output = Decimal;
    
    fn sub(self, other: Decimal) -> Decimal {
        self.add(other.negate())
    }
}

impl Mul for Decimal {
    type Output = Decimal;
    
    fn mul(self, other: Decimal) -> Decimal {
        let positive = self.positive == other.positive;
        
        let result = mul_positive(self, other);
        if !positive && !result.is_zero() {
            Decimal { positive: false, ..result }
        } else {
            Decimal { positive: true, ..result }
        }.normalize()
    }
}

fn add_positive(a: Decimal, b: Decimal) -> Decimal {
    let (a_int, a_frac, b_int, b_frac) = align_decimals(&a, &b);
    let (mut fractional, carry) = add_digit_vectors(&a_frac, &b_frac, false);
    let (integer, _) = add_digit_vectors(&a_int, &b_int, carry);
    
    Decimal {
        positive: true,
        integer,
        fractional,
    }
}

fn sub_positive(a: Decimal, b: Decimal) -> Decimal {
    let (a_int, a_frac, b_int, b_frac) = align_decimals(&a, &b);
    let (mut fractional, borrow) = sub_digit_vectors(&a_frac, &b_frac, false);
    let (integer, _) = sub_digit_vectors(&a_int, &b_int, borrow);
    
    Decimal {
        positive: true,
        integer,
        fractional,
    }
}

fn mul_positive(a: Decimal, b: Decimal) -> Decimal {
    let a_digits: Vec<u8> = a.integer.iter().chain(a.fractional.iter()).copied().rev().collect();
    let b_digits: Vec<u8> = b.integer.iter().chain(b.fractional.iter()).copied().rev().collect();
    
    let result_len = a_digits.len() + b_digits.len();
    let mut result = vec![0u16; result_len];
    
    for (i, &a_digit) in a_digits.iter().enumerate() {
        for (j, &b_digit) in b_digits.iter().enumerate() {
            let prod = a_digit as u16 * b_digit as u16;
            let pos = i + j;
            result[pos] += prod;
        }
    }
    
    for i in 0..result_len - 1 {
        let carry = result[i] / 10;
        result[i] %= 10;
        result[i + 1] += carry;
    }
    
    let fractional_digits = a.fractional.len() + b.fractional.len();
    
    let mut integer = Vec::new();
    let mut fractional = Vec::new();
    
    for (i, &digit) in result.iter().enumerate() {
        let digit = digit as u8;
        if i < fractional_digits {
            fractional.push(digit);
        } else {
            integer.push(digit);
        }
    }
    
    integer.reverse();
    fractional.reverse();
    
    while integer.len() > 1 && integer.last() == Some(&0) {
        integer.pop();
    }
    
    while !fractional.is_empty() && fractional.last() == Some(&0) {
        fractional.pop();
    }
    
    Decimal {
        positive: true,
        integer,
        fractional,
    }
}

fn align_decimals(a: &Decimal, b: &Decimal) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
    let max_fractional_len = a.fractional.len().max(b.fractional.len());
    
    let mut a_frac = a.fractional.clone();
    a_frac.resize(max_fractional_len, 0);
    
    let mut b_frac = b.fractional.clone();
    b_frac.resize(max_fractional_len, 0);
    
    (a.integer.clone(), a_frac, b.integer.clone(), b_frac)
}

fn add_digit_vectors(a: &[u8], b: &[u8], mut carry: bool) -> (Vec<u8>, bool) {
    let max_len = a.len().max(b.len());
    let mut result = Vec::new();
    let mut new_carry = false;
    
    for i in 0..max_len {
        let a_digit = a.get(i).copied().unwrap_or(0);
        let b_digit = b.get(i).copied().unwrap_or(0);
        
        let mut sum = a_digit + b_digit;
        if carry {
            sum += 1;
        }
        
        if sum >= 10 {
            result.push(sum - 10);
            carry = true;
        } else {
            result.push(sum);
            carry = false;
        }
    }
    
    if carry {
        result.push(1);
        new_carry = true;
    }
    
    (result, new_carry)
}

fn sub_digit_vectors(a: &[u8], b: &[u8], mut borrow: bool) -> (Vec<u8>, bool) {
    let mut result = Vec::new();
    let mut new_borrow = false;
    
    for i in 0..a.len() {
        let a_digit = a[i];
        let b_digit = b.get(i).copied().unwrap_or(0);
        
        let mut diff = if borrow {
            10 + a_digit - b_digit - 1
        } else {
            if a_digit >= b_digit {
                a_digit - b_digit
            } else {
                10 + a_digit - b_digit
            }
        };
        
        result.push(diff);
        borrow = if borrow {
            a_digit <= b_digit
        } else {
            a_digit < b_digit
        };
    }
    
    while result.len() > 1 && result.last() == Some(&0) {
        result.pop();
    }
    
    (result, new_borrow)
}

impl fmt::Display for Decimal {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if !self.positive {
            write!(f, "-")?;
        }
        
        for digit in self.integer.iter().rev() {
            write!(f, "{}", digit)?;
        }
        
        if !self.fractional.is_empty() {
            write!(f, ".")?;
            for digit in &self.fractional {
                write!(f, "{}", digit)?;
            }
        }
        
        Ok(())
    }
}

性能优化版本

考虑性能的优化实现:

// 使用第三方库实现(如 rust_decimal)
use rust_decimal::Decimal as ExternalDecimal;
use std::str::FromStr;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Decimal {
    inner: ExternalDecimal,
}

impl Decimal {
    pub fn try_from(input: &str) -> Option<Decimal> {
        ExternalDecimal::from_str(input)
            .ok()
            .map(|d| Decimal { inner: d })
    }
}

use std::ops::{Add, Sub, Mul};

impl Add for Decimal {
    type Output = Decimal;
    
    fn add(self, other: Decimal) -> Decimal {
        Decimal {
            inner: self.inner + other.inner,
        }
    }
}

impl Sub for Decimal {
    type Output = Decimal;
    
    fn sub(self, other: Decimal) -> Decimal {
        Decimal {
            inner: self.inner - other.inner,
        }
    }
}

impl Mul for Decimal {
    type Output = Decimal;
    
    fn mul(self, other: Decimal) -> Decimal {
        Decimal {
            inner: self.inner * other.inner,
        }
    }
}

impl std::fmt::Display for Decimal {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", self.inner)
    }
}

错误处理和边界情况

考虑更多边界情况的实现:

use std::cmp::Ordering;
use std::ops::{Add, Sub, Mul};
use std::fmt;

#[derive(Debug, Clone)]
pub struct Decimal {
    positive: bool,
    integer: Vec<u8>,
    fractional: Vec<u8>,
}

#[derive(Debug, PartialEq)]
pub enum ParseDecimalError {
    EmptyString,
    InvalidCharacter,
    MultipleDecimalPoints,
}

impl Decimal {
    pub fn try_from(input: &str) -> Result<Decimal, ParseDecimalError> {
        if input.is_empty() {
            return Err(ParseDecimalError::EmptyString);
        }
        
        let mut chars = input.chars().peekable();
        
        let positive = match chars.peek() {
            Some(&'+') => {
                chars.next();
                true
            }
            Some(&'-') => {
                chars.next();
                false
            }
            _ => true,
        };
        
        if chars.peek().is_none() {
            return Err(ParseDecimalError::EmptyString);
        }
        
        let mut integer_part = String::new();
        let mut fractional_part = String::new();
        let mut found_dot = false;
        
        for ch in chars {
            if ch == '.' {
                if found_dot {
                    return Err(ParseDecimalError::MultipleDecimalPoints);
                }
                found_dot = true;
            } else if ch.is_ascii_digit() {
                if found_dot {
                    fractional_part.push(ch);
                } else {
                    integer_part.push(ch);
                }
            } else {
                return Err(ParseDecimalError::InvalidCharacter);
            }
        }
        
        if integer_part.is_empty() {
            integer_part.push('0');
        }
        
        let mut integer = Vec::new();
        for ch in integer_part.chars().rev() {
            integer.push(ch.to_digit(10).unwrap() as u8);
        }
        
        let mut fractional = Vec::new();
        for ch in fractional_part.chars() {
            fractional.push(ch.to_digit(10).unwrap() as u8);
        }
        
        let decimal = Decimal {
            positive,
            integer,
            fractional,
        };
        
        Ok(decimal.normalize())
    }
    
    // 其他方法保持不变...
    fn normalize(mut self) -> Self {
        while self.integer.len() > 1 && self.integer.last() == Some(&0) {
            self.integer.pop();
        }
        
        while !self.fractional.is_empty() && self.fractional.last() == Some(&0) {
            self.fractional.pop();
        }
        
        if !self.positive && self.integer.len() == 1 && self.integer[0] == 0 && self.fractional.is_empty() {
            self.positive = true;
        }
        
        self
    }
    
    fn is_zero(&self) -> bool {
        self.integer.len() == 1 && self.integer[0] == 0 && self.fractional.is_empty()
    }
    
    fn negate(mut self) -> Self {
        if !self.is_zero() {
            self.positive = !self.positive;
        }
        self
    }
}

// 实现 FromStr trait
impl std::str::FromStr for Decimal {
    type Err = ParseDecimalError;
    
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Decimal::try_from(s)
    }
}

impl fmt::Display for Decimal {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if !self.positive {
            write!(f, "-")?;
        }
        
        for digit in self.integer.iter().rev() {
            write!(f, "{}", digit)?;
        }
        
        if !self.fractional.is_empty() {
            write!(f, ".")?;
            for digit in &self.fractional {
                write!(f, "{}", digit)?;
            }
        }
        
        Ok(())
    }
}

// 其他实现保持不变...

实际应用场景

高精度十进制数在实际开发中有以下应用:

  1. 金融系统:货币计算需要精确到分
  2. 科学计算:需要高精度的数学运算
  3. 数据库系统:DECIMAL/NUMERIC 类型的实现
  4. 电子商务:价格计算和订单处理
  5. 税务系统:税额计算需要精确处理
  6. 工程计算:避免浮点误差的场合

算法复杂度分析

  1. 时间复杂度

    • 加法:O(n + m),其中 n 和 m 是两个数的位数
    • 减法:O(n + m)
    • 乘法:O(n × m)
    • 比较:O(n + m)
  2. 空间复杂度

    • 存储:O(n + m)
    • 运算过程:取决于具体运算

与其他实现方式的比较

// 使用 num-bigint 库的实现
use num_bigint::BigInt;
use num_traits::{Zero, One};

#[derive(Debug, Clone)]
pub struct Decimal {
    numerator: BigInt,
    denominator: BigInt,
    scale: usize, // 小数点后的位数
}

impl Decimal {
    pub fn try_from(input: &str) -> Option<Decimal> {
        // 实现解析逻辑
        unimplemented!()
    }
}

// 使用 BigDecimal 库的实现
use bigdecimal::BigDecimal;

pub struct Decimal {
    inner: BigDecimal,
}

impl Decimal {
    pub fn try_from(input: &str) -> Option<Decimal> {
        input.parse::<BigDecimal>()
            .ok()
            .map(|bd| Decimal { inner: bd })
    }
}

总结

通过 decimal 练习,我们学到了:

  1. 数值表示:掌握了高精度十进制数的内部表示方法
  2. 算法实现:学会了实现基本算术运算的算法
  3. 运算符重载:理解了 Rust 中运算符重载的机制
  4. 错误处理:熟练使用 Result 类型处理解析错误
  5. 性能优化:了解了不同实现方式的性能特点
  6. 边界处理:学会了处理各种边界情况

这些技能在实际开发中非常有用,特别是在实现金融系统、科学计算库和需要高精度数值处理的应用时。高精度十进制数虽然实现复杂,但它涉及到了数值计算、算法设计和错误处理等许多核心概念,是学习 Rust 数值处理的良好起点。

通过这个练习,我们也看到了 Rust 在数值处理和运算符重载方面的强大能力,以及如何用安全且高效的方式实现复杂的数学算法。这种结合了安全性和性能的语言特性正是 Rust 的魅力所在。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐