Skip to main content

opendal_core/raw/oio/read/
position_read.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use futures::Future;
21
22use crate::raw::*;
23use crate::*;
24
/// Default cap, in bytes, on the size requested by a single positioned read (2 MiB).
const DEFAULT_POSITION_READ_MAX_BUF_SIZE: usize = 2 << 20;
26
27/// PositionRead is used to implement [`oio::Read`] based on positioned reads.
28///
29/// Services that implement [`PositionRead`] must support position-independent
30/// reads. `size` is the maximum number of bytes to read, and implementations may
31/// return fewer bytes. Returning an empty buffer means EOF.
32pub trait PositionRead: Send + Sync + Unpin + 'static {
33    /// Read up to `size` bytes from `offset`.
34    fn read_at(&self, offset: u64, size: usize)
35    -> impl Future<Output = Result<Buffer>> + MaybeSend;
36}
37
38/// PositionReader implements [`oio::Read`] based on [`PositionRead`].
39pub struct PositionReader<R: PositionRead> {
40    inner: Arc<R>,
41    max_buf_size: usize,
42}
43
44impl<R: PositionRead> PositionReader<R> {
45    /// Create a new [`PositionReader`].
46    pub fn new(inner: R) -> Self {
47        Self {
48            inner: Arc::new(inner),
49            max_buf_size: DEFAULT_POSITION_READ_MAX_BUF_SIZE,
50        }
51    }
52
53    /// Set the maximum buffer size used by [`PositionReader`].
54    pub fn with_max_buf_size(mut self, buf_size: usize) -> Self {
55        assert!(
56            buf_size > 0,
57            "position read max buffer size must not be zero"
58        );
59
60        self.max_buf_size = buf_size;
61        self
62    }
63
64    /// Consume the reader and return the inner [`PositionRead`].
65    ///
66    /// # Panics
67    ///
68    /// Panics if there are active streams that still share the inner reader.
69    pub fn into_inner(self) -> R {
70        Arc::into_inner(self.inner).expect("position reader must not be shared")
71    }
72}
73
74impl<R: PositionRead> oio::Read for PositionReader<R> {
75    async fn open(&self, range: BytesRange) -> Result<(RpRead, Box<dyn oio::ReadStreamDyn>)> {
76        let stream = PositionReadStream::new(self.inner.clone(), range, self.max_buf_size);
77        Ok((
78            RpRead::default(),
79            Box::new(stream) as Box<dyn oio::ReadStreamDyn>,
80        ))
81    }
82
83    async fn read(&self, range: BytesRange) -> Result<(RpRead, Buffer)> {
84        let size = range
85            .size()
86            .ok_or_else(|| Error::new(ErrorKind::Unsupported, "read requires a bounded range"))?;
87
88        let mut offset = range.offset();
89        let mut remaining = size;
90        let mut bufs = Vec::new();
91
92        while remaining > 0 {
93            let read_size = remaining.min(self.max_buf_size as u64) as usize;
94            let buf = self.inner.read_at(offset, read_size).await?;
95            check_position_read_size(read_size, buf.len())?;
96            if buf.is_empty() {
97                return Err(Error::new(
98                    ErrorKind::RangeNotSatisfied,
99                    "range exceeds content length",
100                )
101                .with_context("offset", offset)
102                .with_context("remaining", remaining));
103            }
104
105            let n = buf.len() as u64;
106            offset += n;
107            remaining -= n;
108            bufs.push(buf);
109        }
110
111        Ok((RpRead::default(), bufs.into_iter().flatten().collect()))
112    }
113}
114
115struct PositionReadStream<R: PositionRead> {
116    inner: Arc<R>,
117    offset: u64,
118    remaining: Option<u64>,
119    max_buf_size: usize,
120    done: bool,
121}
122
123impl<R: PositionRead> PositionReadStream<R> {
124    fn new(inner: Arc<R>, range: BytesRange, max_buf_size: usize) -> Self {
125        Self {
126            inner,
127            offset: range.offset(),
128            remaining: range.size(),
129            max_buf_size,
130            done: false,
131        }
132    }
133}
134
135impl<R: PositionRead> oio::ReadStream for PositionReadStream<R> {
136    async fn read(&mut self) -> Result<Buffer> {
137        if self.done || self.remaining == Some(0) {
138            return Ok(Buffer::new());
139        }
140
141        let read_size = self
142            .remaining
143            .map(|remaining| remaining.min(self.max_buf_size as u64) as usize)
144            .unwrap_or(self.max_buf_size);
145
146        let buf = self.inner.read_at(self.offset, read_size).await?;
147        check_position_read_size(read_size, buf.len())?;
148        if buf.is_empty() {
149            self.done = true;
150            if let Some(remaining) = self.remaining {
151                return Err(Error::new(
152                    ErrorKind::RangeNotSatisfied,
153                    "range exceeds content length",
154                )
155                .with_context("offset", self.offset)
156                .with_context("remaining", remaining));
157            }
158            return Ok(Buffer::new());
159        }
160
161        let n = buf.len() as u64;
162        self.offset += n;
163        if let Some(remaining) = &mut self.remaining {
164            *remaining -= n;
165        }
166
167        Ok(buf)
168    }
169}
170
171fn check_position_read_size(expected: usize, actual: usize) -> Result<()> {
172    if actual > expected {
173        return Err(
174            Error::new(ErrorKind::Unexpected, "reader got unexpected data size")
175                .with_context("expect", expected)
176                .with_context("actual", actual),
177        );
178    }
179
180    Ok(())
181}
182
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use bytes::Bytes;

    use super::*;
    use crate::raw::oio::Read;
    use crate::raw::oio::ReadStream;

    /// In-memory [`PositionRead`] that serves at most `max_read` bytes per
    /// call and records every `(offset, size)` request it receives.
    struct TestPositionRead {
        content: Bytes,
        max_read: usize,
        calls: Arc<Mutex<Vec<(u64, usize)>>>,
    }

    impl TestPositionRead {
        /// Build a reader over `content` with an empty call log.
        fn new(content: &'static [u8], max_read: usize) -> Self {
            let content = Bytes::from_static(content);
            let calls = Arc::default();

            Self {
                content,
                max_read,
                calls,
            }
        }
    }

    impl PositionRead for TestPositionRead {
        async fn read_at(&self, offset: u64, size: usize) -> Result<Buffer> {
            self.calls.lock().unwrap().push((offset, size));

            let start = offset as usize;
            let total = self.content.len();
            if start >= total {
                return Ok(Buffer::new());
            }

            // Serve the smallest of: requested size, per-call cap, bytes left.
            let len = size.min(self.max_read).min(total - start);
            Ok(Buffer::from(self.content.slice(start..start + len)))
        }
    }

    /// A bounded read keeps issuing requests until short reads fill the range.
    #[tokio::test]
    async fn test_position_reader_read_handles_partial_reads() -> Result<()> {
        let source = TestPositionRead::new(b"0123456789", 2);
        let calls = Arc::clone(&source.calls);
        let reader = PositionReader::new(source).with_max_buf_size(4);

        let (_, data) = reader.read(BytesRange::from(2..8)).await?;
        assert_eq!(data.to_vec(), b"234567");

        let recorded = calls.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(2, 4), (4, 4), (6, 2)]);

        Ok(())
    }

    /// A bounded read that runs past the content fails with `RangeNotSatisfied`.
    #[tokio::test]
    async fn test_position_reader_read_reports_early_eof() -> Result<()> {
        let source = TestPositionRead::new(b"0123456789", 4);
        let reader = PositionReader::new(source).with_max_buf_size(4);

        let outcome = reader.read(BytesRange::from(8..12)).await;
        let err = outcome.unwrap_err();

        assert_eq!(err.kind(), ErrorKind::RangeNotSatisfied);
        Ok(())
    }

    /// A stream over an open-ended range ends cleanly at EOF.
    #[tokio::test]
    async fn test_position_reader_open_stops_at_eof() -> Result<()> {
        let source = TestPositionRead::new(b"0123456789", 2);
        let reader = PositionReader::new(source).with_max_buf_size(4);

        let (_, mut stream) = reader.open(BytesRange::from(8..)).await?;
        let data = stream.read_all().await?;

        assert_eq!(data.to_vec(), b"89");
        Ok(())
    }
}