Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ env.sh
.mypy_cache
notebooks/output
notebooks/repos
.venv/
.vscode/
72 changes: 52 additions & 20 deletions notebooks/codesearchnet-opennmt.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
"""
CLI tool for converting CodeSearchNet dataset to OpenNMT format for
function name suggestion task.

Usage example:
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
unzip java.zip
python notebooks/codesearchnet-opennmt.py \
--data_dir='java/final/jsonl/valid' \
--newline='\\n'
"""
from argparse import ArgumentParser, Namespace
import logging
from pathlib import Path
from time import time
from typing import List, Tuple

import pandas as pd
from torch.utils.data import Dataset


logging.basicConfig(level=logging.INFO)


class CodeSearchNetRAM(Dataset):
"""Stores one split of CodeSearchNet data in memory

Usage example:
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
unzip java.zip
python notebooks/codesearchnet-opennmt.py \
--data_dir='java/final/jsonl/valid' \
--newline='\\n'
"""
class CodeSearchNetRAM(object):
"""Stores one split of CodeSearchNet data in memory"""

def __init__(self, split_path: Path, newline_repl: str):
super().__init__()
self.pd = pd
self.newline_repl = newline_repl

files = sorted(split_path.glob("**/*.gz"))
logging.info(f"Total number of files: {len(files):,}")
assert files, "could not find files under %s" % split_path

columns_list = ["code", "func_name"]
columns_list = ["code", "func_name", "code_tokens"]

start = time()
self.pd = self._jsonl_list_to_dataframe(files, columns_list)
Expand Down Expand Up @@ -61,10 +64,21 @@ def __getitem__(self, idx: int) -> Tuple[str, str]:

# drop fn signature
code = row["code"]
fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
fn_body = fn_body.replace("\n", "\\n")
fn_body = (
code[
code.find("{", code.find(fn_name) + len(fn_name)) + 1 : code.rfind("}")
]
.lstrip()
.rstrip()
)
fn_body = fn_body.replace("\n", self.newline_repl)
# fn_body_enc = self.enc.encode(fn_body)
return (fn_name, fn_body)

tokens = row["code_tokens"]
body_tokens = tokens[tokens.index(fn_name) + 2 :]
fn_body_tokens = body_tokens[body_tokens.index("{") + 1 : len(body_tokens) - 1]

return (fn_name, fn_body, fn_body_tokens)

def __len__(self) -> int:
return len(self.pd)
Expand All @@ -76,11 +90,16 @@ def main(args: Namespace) -> None:
with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
args.tgt_file % split_name, mode="w", encoding="utf8"
) as t:
for fn_name, fn_body in dataset:
for fn_name, fn_body, fn_body_tokens in dataset:
if not fn_name or not fn_body:
continue
print(fn_body, file=s)
print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)
src = " ".join(fn_body_tokens) if args.token_level_sources else fn_body
tgt = fn_name if args.word_level_targets else " ".join(fn_name)
if args.print:
print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
else:
print(src, file=s)
print(tgt, file=t)


if __name__ == "__main__":
Expand All @@ -96,18 +115,31 @@ def main(args: Namespace) -> None:
"--newline", type=str, default="\\n", help="Replace newline with this"
)

parser.add_argument(
"--token-level-sources",
action="store_true",
help="Use language-specific token sources instead of word level ones",
)

parser.add_argument(
"--word-level-targets",
action="store_true",
help="Use word level targets instead of char level ones",
)

parser.add_argument(
"--src_file", type=str, default="src-%s.txt", help="File with function bodies",
"--src_file",
type=str,
default="src-%s.token",
help="File with function bodies",
)

parser.add_argument(
"--tgt_file", type=str, default="tgt-%s.token", help="File with function texts"
)

parser.add_argument(
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
"--print", action="store_true", help="Print data preview to the STDOUT"
)

args = parser.parse_args()
Expand Down