Skip to content
75 changes: 55 additions & 20 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo

from docarray.base_doc.doc import BaseDocWithoutId
from docarray import BaseDoc, DocList
from docarray.typing import AnyTensor
from docarray.utils._internal._typing import safe_issubclass
Expand Down Expand Up @@ -50,16 +51,19 @@ class MyDoc(BaseDoc):
:param model: The input model
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
"""
if is_pydantic_v2:
raise NotImplementedError(
'This method is not supported in Pydantic 2.0. Please use Pydantic 1.8.2 or lower.'
)

fields: Dict[str, Any] = {}
for field_name, field in model.__annotations__.items():
if field_name not in model.__fields__:
import copy

fields_copy = copy.deepcopy(model.__fields__)
annotations_copy = copy.deepcopy(model.__annotations__)
for field_name, field in annotations_copy.items():
if field_name not in fields_copy:
continue
field_info = model.__fields__[field_name].field_info

if is_pydantic_v2:
field_info = fields_copy[field_name]
else:
field_info = fields_copy[field_name].field_info
try:
if safe_issubclass(field, DocList):
t: Any = field.doc_type
Expand All @@ -68,9 +72,8 @@ class MyDoc(BaseDoc):
fields[field_name] = (field, field_info)
except TypeError:
fields[field_name] = (field, field_info)
return create_model(
model.__name__, __base__=model, __validators__=model.__validators__, **fields
)

return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)


def _get_field_annotation_from_schema(
Expand Down Expand Up @@ -201,6 +204,8 @@ def _get_field_annotation_from_schema(
num_recursions=num_recursions + 1,
definitions=definitions,
)
elif field_type == 'null':
ret = None
else:
if num_recursions > 0:
raise ValueError(
Expand Down Expand Up @@ -255,14 +260,18 @@ class MyDoc(BaseDoc):
:return: A BaseDoc class dynamically created following the `schema`.
"""
if not definitions:
definitions = schema.get('definitions', {})
definitions = (
schema.get('definitions', {}) if not is_pydantic_v2 else schema.get('$defs')
)

cached_models = cached_models if cached_models is not None else {}
fields: Dict[str, Any] = {}
if base_doc_name in cached_models:
return cached_models[base_doc_name]
has_id = False
for field_name, field_schema in schema.get('properties', {}).items():

if field_name == 'id':
has_id = True
field_type = _get_field_annotation_from_schema(
field_schema=field_schema,
field_name=field_name,
Expand All @@ -272,17 +281,43 @@ class MyDoc(BaseDoc):
num_recursions=0,
definitions=definitions,
)
fields[field_name] = (
field_type,
FieldInfo(default=field_schema.pop('default', None), **field_schema),
)
if not is_pydantic_v2:
field_schema['default'] = field_schema.get('default', None)
fields[field_name] = (
field_type,
FieldInfo(**field_schema),
)
else:
field_kwargs = {}
field_json_schema_extra = {}
for k, v in field_schema.items():
if k in FieldInfo.__slots__:
field_kwargs[k] = v
else:
field_json_schema_extra[k] = v
fields[field_name] = (
field_type,
FieldInfo(
json_schema_extra=field_json_schema_extra,
**field_kwargs,
),
)

model = create_model(base_doc_name, __base__=BaseDoc, **fields)
model.__config__.title = schema.get('title', model.__config__.title)
base_model = BaseDoc if has_id else BaseDocWithoutId
model = create_model(base_doc_name, __base__=base_model, **fields)
if not is_pydantic_v2:
model.__config__.title = schema.get('title', model.__config__.title)
else:
set_title = schema.get('title', model.model_config.get('title', None))
if set_title:
model.model_config['title'] = set_title

for k in RESERVED_KEYS:
if k in schema:
schema.pop(k)
model.__config__.schema_extra = schema
if not is_pydantic_v2:
model.__config__.schema_extra = schema
else:
model.model_config['json_schema_extra'] = schema
cached_models[base_doc_name] = model
return model
17 changes: 10 additions & 7 deletions tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from docarray import BaseDoc, DocList
from docarray.documents import TextDoc
from docarray.typing import AnyTensor, ImageUrl
from docarray.utils._internal.pydantic import is_pydantic_v2
from docarray.utils.create_dynamic_doc_class import (
create_base_doc_from_schema,
create_pure_python_type_model,
)
from docarray.utils._internal.pydantic import is_pydantic_v2


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
@pytest.mark.parametrize('transformation', ['proto', 'json'])
def test_create_pydantic_model_from_schema(transformation):
class Nested2Doc(BaseDoc):
Expand All @@ -26,7 +25,7 @@ class Nested1Doc(BaseDoc):
classvar: ClassVar[str] = 'classvar1'

class CustomDoc(BaseDoc):
tensor: Optional[AnyTensor]
tensor: Optional[AnyTensor] = None
url: ImageUrl
lll: List[List[List[int]]] = [[[5]]]
fff: List[List[List[float]]] = [[[5.2]]]
Expand Down Expand Up @@ -80,7 +79,10 @@ class CustomDoc(BaseDoc):
assert len(custom_partial_da) == 1
assert custom_partial_da[0].url == 'photo.jpg'
assert custom_partial_da[0].lll == [[[40]]]
assert custom_partial_da[0].lu == ['3', '4'] # Union validates back to string
if is_pydantic_v2:
assert custom_partial_da[0].lu == [3, 4]
else:
assert custom_partial_da[0].lu == ['3', '4'] # Union validates back to string
assert custom_partial_da[0].fff == [[[40.2]]]
assert custom_partial_da[0].di == {'a': 2}
assert custom_partial_da[0].d == {'b': 'a'}
Expand All @@ -99,7 +101,10 @@ class CustomDoc(BaseDoc):
assert len(original_back) == 1
assert original_back[0].url == 'photo.jpg'
assert original_back[0].lll == [[[40]]]
assert original_back[0].lu == ['3', '4'] # Union validates back to string
if is_pydantic_v2:
assert original_back[0].lu == [3, 4] # Union validates back to string
else:
assert original_back[0].lu == ['3', '4'] # Union validates back to string
assert original_back[0].fff == [[[40.2]]]
assert original_back[0].di == {'a': 2}
assert original_back[0].d == {'b': 'a'}
Expand Down Expand Up @@ -174,7 +179,6 @@ class ResultTestDoc(BaseDoc):
assert doc.ia == f'ID {i}'


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
@pytest.mark.parametrize('transformation', ['proto', 'json'])
def test_create_empty_doc_list_from_schema(transformation):
class CustomDoc(BaseDoc):
Expand Down Expand Up @@ -260,7 +264,6 @@ class ResultTestDoc(BaseDoc):
assert len(custom_da) == 0


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_create_with_field_info():
class CustomDoc(BaseDoc):
"""Here I have the description of the class"""
Expand Down