opendal/services/huggingface/
config.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Debug;
19
20use serde::Deserialize;
21use serde::Serialize;
22
23use super::HUGGINGFACE_SCHEME;
24use super::backend::HuggingfaceBuilder;
25
/// Configuration for Huggingface service support.
///
/// Every field is optional at the type level. Because of `#[serde(default)]`,
/// keys that are absent during deserialization are left as `None`.
///
/// A config can also be derived from an operator URI through
/// `Configurator::from_uri`, which fills `repo_type`, `repo_id`, `revision`
/// and `root` from the URI name and path.
///
/// The hand-written `Debug` implementation for this type never prints `token`.
#[derive(Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(default)]
#[non_exhaustive]
pub struct HuggingfaceConfig {
    /// Repo type of this backend. Default is model.
    ///
    /// Available values:
    /// - model
    /// - dataset
    /// - datasets (alias for dataset)
    ///
    /// When built from a URI, a non-empty URI name is stored here.
    pub repo_type: Option<String>,
    /// Repo id of this backend.
    ///
    /// This is required.
    ///
    /// When built from a URI, this is `"{owner}/{repo}"`, taken from the
    /// first two segments of the URI path.
    pub repo_id: Option<String>,
    /// Revision of this backend.
    ///
    /// Default is main.
    ///
    /// When built from a URI that carries no explicit `revision` option, a
    /// non-empty third path segment is used as the revision.
    pub revision: Option<String>,
    /// Root of this backend. Can be "/path/to/dir".
    ///
    /// Default is "/".
    pub root: Option<String>,
    /// Token of this backend.
    ///
    /// This is optional.
    pub token: Option<String>,
    /// Endpoint of the Huggingface Hub.
    ///
    /// Default is "https://huggingface.co".
    pub endpoint: Option<String>,
}
59
60impl Debug for HuggingfaceConfig {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("HuggingfaceConfig")
63            .field("repo_type", &self.repo_type)
64            .field("repo_id", &self.repo_id)
65            .field("revision", &self.revision)
66            .field("root", &self.root)
67            .finish_non_exhaustive()
68    }
69}
70
71impl crate::Configurator for HuggingfaceConfig {
72    type Builder = HuggingfaceBuilder;
73
74    fn from_uri(uri: &crate::types::OperatorUri) -> crate::Result<Self> {
75        let mut map = uri.options().clone();
76
77        if let Some(repo_type) = uri.name() {
78            if !repo_type.is_empty() {
79                map.insert("repo_type".to_string(), repo_type.to_string());
80            }
81        }
82
83        let raw_path = uri.root().ok_or_else(|| {
84            crate::Error::new(
85                crate::ErrorKind::ConfigInvalid,
86                "uri path must include owner and repo",
87            )
88            .with_context("service", HUGGINGFACE_SCHEME)
89        })?;
90
91        let mut segments = raw_path.splitn(4, '/');
92        let owner = segments.next().filter(|s| !s.is_empty()).ok_or_else(|| {
93            crate::Error::new(
94                crate::ErrorKind::ConfigInvalid,
95                "repository owner is required in uri path",
96            )
97            .with_context("service", HUGGINGFACE_SCHEME)
98        })?;
99        let repo = segments.next().filter(|s| !s.is_empty()).ok_or_else(|| {
100            crate::Error::new(
101                crate::ErrorKind::ConfigInvalid,
102                "repository name is required in uri path",
103            )
104            .with_context("service", HUGGINGFACE_SCHEME)
105        })?;
106
107        map.insert("repo_id".to_string(), format!("{owner}/{repo}"));
108
109        if let Some(segment) = segments.next() {
110            if map.contains_key("revision") {
111                let mut root_value = segment.to_string();
112                if let Some(rest) = segments.next() {
113                    if !rest.is_empty() {
114                        if !root_value.is_empty() {
115                            root_value.push('/');
116                            root_value.push_str(rest);
117                        } else {
118                            root_value = rest.to_string();
119                        }
120                    }
121                }
122                if !root_value.is_empty() {
123                    map.insert("root".to_string(), root_value);
124                }
125            } else {
126                if !segment.is_empty() {
127                    map.insert("revision".to_string(), segment.to_string());
128                }
129                if let Some(rest) = segments.next() {
130                    if !rest.is_empty() {
131                        map.insert("root".to_string(), rest.to_string());
132                    }
133                }
134            }
135        }
136
137        Self::from_iter(map)
138    }
139
140    fn into_builder(self) -> Self::Builder {
141        HuggingfaceBuilder { config: self }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Configurator;
    use crate::types::OperatorUri;

    /// Builds an `OperatorUri` from `uri` plus extra `options` and feeds it to
    /// `HuggingfaceConfig::from_uri`. Panics if the URI itself is rejected.
    fn parse(uri: &str, options: Vec<(String, String)>) -> crate::Result<HuggingfaceConfig> {
        let uri = OperatorUri::new(uri, options).unwrap();
        HuggingfaceConfig::from_uri(&uri)
    }

    /// Borrowed view of the fields that `from_uri` derives, in the order
    /// repo type, repo id, revision, root.
    fn derived(cfg: &HuggingfaceConfig) -> [Option<&str>; 4] {
        [
            cfg.repo_type.as_deref(),
            cfg.repo_id.as_deref(),
            cfg.revision.as_deref(),
            cfg.root.as_deref(),
        ]
    }

    #[test]
    fn from_uri_sets_repo_type_id_and_revision() {
        let cfg = parse(
            "huggingface://model/opendal/sample/main/dataset",
            Vec::new(),
        )
        .unwrap();

        assert_eq!(
            derived(&cfg),
            [
                Some("model"),
                Some("opendal/sample"),
                Some("main"),
                Some("dataset"),
            ]
        );
    }

    #[test]
    fn from_uri_uses_existing_revision_and_sets_root() {
        let options = vec![("revision".to_string(), "dev".to_string())];
        let cfg = parse("huggingface://dataset/opendal/sample/data/train", options).unwrap();

        assert_eq!(
            derived(&cfg),
            [
                Some("dataset"),
                Some("opendal/sample"),
                Some("dev"),
                Some("data/train"),
            ]
        );
    }

    #[test]
    fn from_uri_requires_owner_and_repo() {
        let outcome = parse("huggingface://model/opendal", Vec::new());

        assert!(outcome.is_err());
    }
}